kodit 0.4.2__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic. Click here for more details.

Files changed (100) hide show
  1. kodit/_version.py +2 -2
  2. kodit/app.py +59 -24
  3. kodit/application/factories/reporting_factory.py +16 -7
  4. kodit/application/factories/server_factory.py +311 -0
  5. kodit/application/services/code_search_application_service.py +144 -0
  6. kodit/application/services/commit_indexing_application_service.py +543 -0
  7. kodit/application/services/indexing_worker_service.py +13 -46
  8. kodit/application/services/queue_service.py +24 -3
  9. kodit/application/services/reporting.py +70 -54
  10. kodit/application/services/sync_scheduler.py +15 -31
  11. kodit/cli.py +2 -763
  12. kodit/cli_utils.py +2 -9
  13. kodit/config.py +3 -96
  14. kodit/database.py +38 -1
  15. kodit/domain/entities/__init__.py +276 -0
  16. kodit/domain/entities/git.py +190 -0
  17. kodit/domain/factories/__init__.py +1 -0
  18. kodit/domain/factories/git_repo_factory.py +76 -0
  19. kodit/domain/protocols.py +270 -46
  20. kodit/domain/services/bm25_service.py +5 -1
  21. kodit/domain/services/embedding_service.py +3 -0
  22. kodit/domain/services/git_repository_service.py +429 -0
  23. kodit/domain/services/git_service.py +300 -0
  24. kodit/domain/services/task_status_query_service.py +19 -0
  25. kodit/domain/value_objects.py +113 -147
  26. kodit/infrastructure/api/client/__init__.py +0 -2
  27. kodit/infrastructure/api/v1/__init__.py +0 -4
  28. kodit/infrastructure/api/v1/dependencies.py +105 -44
  29. kodit/infrastructure/api/v1/routers/__init__.py +0 -6
  30. kodit/infrastructure/api/v1/routers/commits.py +271 -0
  31. kodit/infrastructure/api/v1/routers/queue.py +2 -2
  32. kodit/infrastructure/api/v1/routers/repositories.py +282 -0
  33. kodit/infrastructure/api/v1/routers/search.py +31 -14
  34. kodit/infrastructure/api/v1/schemas/__init__.py +0 -24
  35. kodit/infrastructure/api/v1/schemas/commit.py +96 -0
  36. kodit/infrastructure/api/v1/schemas/context.py +2 -0
  37. kodit/infrastructure/api/v1/schemas/repository.py +128 -0
  38. kodit/infrastructure/api/v1/schemas/search.py +12 -9
  39. kodit/infrastructure/api/v1/schemas/snippet.py +58 -0
  40. kodit/infrastructure/api/v1/schemas/tag.py +31 -0
  41. kodit/infrastructure/api/v1/schemas/task_status.py +41 -0
  42. kodit/infrastructure/bm25/local_bm25_repository.py +16 -4
  43. kodit/infrastructure/bm25/vectorchord_bm25_repository.py +68 -52
  44. kodit/infrastructure/cloning/git/git_python_adaptor.py +467 -0
  45. kodit/infrastructure/cloning/git/working_copy.py +10 -3
  46. kodit/infrastructure/embedding/embedding_factory.py +3 -2
  47. kodit/infrastructure/embedding/local_vector_search_repository.py +1 -1
  48. kodit/infrastructure/embedding/vectorchord_vector_search_repository.py +111 -84
  49. kodit/infrastructure/enrichment/litellm_enrichment_provider.py +19 -26
  50. kodit/infrastructure/enrichment/local_enrichment_provider.py +41 -30
  51. kodit/infrastructure/indexing/fusion_service.py +1 -1
  52. kodit/infrastructure/mappers/git_mapper.py +193 -0
  53. kodit/infrastructure/mappers/snippet_mapper.py +106 -0
  54. kodit/infrastructure/mappers/task_mapper.py +5 -44
  55. kodit/infrastructure/mappers/task_status_mapper.py +85 -0
  56. kodit/infrastructure/reporting/db_progress.py +23 -0
  57. kodit/infrastructure/reporting/log_progress.py +13 -38
  58. kodit/infrastructure/reporting/telemetry_progress.py +21 -0
  59. kodit/infrastructure/slicing/slicer.py +32 -31
  60. kodit/infrastructure/sqlalchemy/embedding_repository.py +43 -23
  61. kodit/infrastructure/sqlalchemy/entities.py +428 -131
  62. kodit/infrastructure/sqlalchemy/git_branch_repository.py +263 -0
  63. kodit/infrastructure/sqlalchemy/git_commit_repository.py +337 -0
  64. kodit/infrastructure/sqlalchemy/git_repository.py +252 -0
  65. kodit/infrastructure/sqlalchemy/git_tag_repository.py +257 -0
  66. kodit/infrastructure/sqlalchemy/snippet_v2_repository.py +484 -0
  67. kodit/infrastructure/sqlalchemy/task_repository.py +29 -23
  68. kodit/infrastructure/sqlalchemy/task_status_repository.py +91 -0
  69. kodit/infrastructure/sqlalchemy/unit_of_work.py +10 -14
  70. kodit/mcp.py +12 -26
  71. kodit/migrations/env.py +1 -1
  72. kodit/migrations/versions/04b80f802e0c_foreign_key_review.py +100 -0
  73. kodit/migrations/versions/7f15f878c3a1_add_new_git_entities.py +690 -0
  74. kodit/migrations/versions/b9cd1c3fd762_add_task_status.py +77 -0
  75. kodit/migrations/versions/f9e5ef5e688f_add_git_commits_number.py +43 -0
  76. kodit/py.typed +0 -0
  77. kodit/utils/dump_openapi.py +7 -4
  78. kodit/utils/path_utils.py +29 -0
  79. {kodit-0.4.2.dist-info → kodit-0.5.0.dist-info}/METADATA +3 -3
  80. kodit-0.5.0.dist-info/RECORD +137 -0
  81. kodit/application/factories/code_indexing_factory.py +0 -193
  82. kodit/application/services/auto_indexing_service.py +0 -103
  83. kodit/application/services/code_indexing_application_service.py +0 -393
  84. kodit/domain/entities.py +0 -323
  85. kodit/domain/services/index_query_service.py +0 -70
  86. kodit/domain/services/index_service.py +0 -267
  87. kodit/infrastructure/api/client/index_client.py +0 -57
  88. kodit/infrastructure/api/v1/routers/indexes.py +0 -119
  89. kodit/infrastructure/api/v1/schemas/index.py +0 -101
  90. kodit/infrastructure/bm25/bm25_factory.py +0 -28
  91. kodit/infrastructure/cloning/__init__.py +0 -1
  92. kodit/infrastructure/cloning/metadata.py +0 -98
  93. kodit/infrastructure/mappers/index_mapper.py +0 -345
  94. kodit/infrastructure/reporting/tdqm_progress.py +0 -73
  95. kodit/infrastructure/slicing/language_detection_service.py +0 -18
  96. kodit/infrastructure/sqlalchemy/index_repository.py +0 -646
  97. kodit-0.4.2.dist-info/RECORD +0 -119
  98. {kodit-0.4.2.dist-info → kodit-0.5.0.dist-info}/WHEEL +0 -0
  99. {kodit-0.4.2.dist-info → kodit-0.5.0.dist-info}/entry_points.txt +0 -0
  100. {kodit-0.4.2.dist-info → kodit-0.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,484 @@
1
+ """SQLAlchemy implementation of SnippetRepositoryV2."""
2
+
3
+ import zlib
4
+ from collections.abc import Callable
5
+
6
+ from sqlalchemy import delete, insert, select
7
+ from sqlalchemy.ext.asyncio import AsyncSession
8
+
9
+ from kodit.domain.entities.git import SnippetV2
10
+ from kodit.domain.protocols import SnippetRepositoryV2
11
+ from kodit.domain.value_objects import MultiSearchRequest
12
+ from kodit.infrastructure.mappers.snippet_mapper import SnippetMapper
13
+ from kodit.infrastructure.sqlalchemy import entities as db_entities
14
+ from kodit.infrastructure.sqlalchemy.unit_of_work import SqlAlchemyUnitOfWork
15
+
16
+
17
def create_snippet_v2_repository(
    session_factory: Callable[[], AsyncSession],
) -> SnippetRepositoryV2:
    """Build the SQLAlchemy-backed snippet v2 repository.

    Args:
        session_factory: Zero-argument callable that returns an AsyncSession.

    Returns:
        A SqlAlchemySnippetRepositoryV2 constructed with ``session_factory``.
    """
    repository = SqlAlchemySnippetRepositoryV2(session_factory=session_factory)
    return repository
22
+
23
+
24
+ class SqlAlchemySnippetRepositoryV2(SnippetRepositoryV2):
25
+ """SQLAlchemy implementation of SnippetRepositoryV2."""
26
+
27
def __init__(self, session_factory: Callable[[], AsyncSession]) -> None:
    """Keep the session factory; each operation opens its own unit of work.

    Args:
        session_factory: Zero-argument callable that returns an AsyncSession.
    """
    self.session_factory = session_factory
30
+
31
@property
def _mapper(self) -> SnippetMapper:
    """Return a newly constructed SnippetMapper on every access."""
    mapper = SnippetMapper()
    return mapper
34
+
35
async def save_snippets(self, commit_sha: str, snippets: list[SnippetV2]) -> None:
    """Persist a batch of snippets for one commit in a single unit of work.

    The four steps run in order on the same session: store snippets, link
    them to the commit, link them to the commit's files, then reconcile
    their enrichments. An empty ``snippets`` list is a no-op and no session
    is opened.

    Args:
        commit_sha: SHA of the commit the snippets belong to.
        snippets: Domain snippets to persist.
    """
    if snippets:
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            # Order matters: associations and enrichments refer to the
            # snippet rows written by the first step.
            await self._bulk_save_snippets(session, snippets)
            await self._bulk_create_commit_associations(
                session, commit_sha, snippets
            )
            await self._bulk_create_file_associations(session, commit_sha, snippets)
            await self._bulk_update_enrichments(session, snippets)
46
+
47
async def _bulk_save_snippets(
    self, session: AsyncSession, snippets: list[SnippetV2]
) -> None:
    """Insert the snippets that are not stored yet, in chunks.

    Snippets are identified by ``sha``. A sha that occurs more than once in
    ``snippets`` is inserted only once (first occurrence wins); without this
    the same key would appear twice in one INSERT.

    Args:
        session: Active session to execute the statements on.
        snippets: Domain snippets to store.
    """
    # Collapse repeated shas while keeping first-seen order.
    unique_by_sha: dict[str, SnippetV2] = {}
    for snippet in snippets:
        unique_by_sha.setdefault(snippet.sha, snippet)
    if not unique_by_sha:
        return

    # One query to find which shas are already present.
    existing_stmt = select(db_entities.SnippetV2.sha).where(
        db_entities.SnippetV2.sha.in_(list(unique_by_sha))
    )
    existing_shas = set((await session.scalars(existing_stmt)).all())

    new_rows = [
        {
            "sha": sha,
            "content": snippet.content,
            "extension": snippet.extension,
        }
        for sha, snippet in unique_by_sha.items()
        if sha not in existing_shas
    ]

    # Insert in chunks to stay below bound-parameter limits.
    chunk_size = 1000
    for start in range(0, len(new_rows), chunk_size):
        chunk = new_rows[start : start + chunk_size]
        await session.execute(insert(db_entities.SnippetV2).values(chunk))
79
+
80
async def _bulk_create_commit_associations(
    self, session: AsyncSession, commit_sha: str, snippets: list[SnippetV2]
) -> None:
    """Link snippets to a commit, skipping links that already exist.

    A snippet sha that occurs more than once in ``snippets`` yields a single
    link, so the same (commit, snippet) pair is never inserted twice in one
    batch.

    Args:
        session: Active session to execute the statements on.
        commit_sha: SHA of the commit to link the snippets to.
        snippets: Domain snippets to link.
    """
    # dict.fromkeys de-duplicates while preserving first-seen order.
    unique_shas = list(dict.fromkeys(snippet.sha for snippet in snippets))
    if not unique_shas:
        return

    existing_stmt = select(db_entities.CommitSnippetV2.snippet_sha).where(
        db_entities.CommitSnippetV2.commit_sha == commit_sha,
        db_entities.CommitSnippetV2.snippet_sha.in_(unique_shas),
    )
    already_linked = set((await session.scalars(existing_stmt)).all())

    new_links = [
        {"commit_sha": commit_sha, "snippet_sha": sha}
        for sha in unique_shas
        if sha not in already_linked
    ]

    # Insert in chunks to stay below bound-parameter limits.
    chunk_size = 1000
    for start in range(0, len(new_links), chunk_size):
        chunk = new_links[start : start + chunk_size]
        await session.execute(insert(db_entities.CommitSnippetV2).values(chunk))
114
+
115
async def _bulk_create_file_associations(
    self, session: AsyncSession, commit_sha: str, snippets: list[SnippetV2]
) -> None:
    """Link snippets to the commit files they derive from.

    A link is created only when the file path exists for ``commit_sha`` in
    the commit-file table and the (snippet, path) link is not stored yet.
    Pairs repeated within the batch are inserted once.

    Args:
        session: Active session to execute the statements on.
        commit_sha: SHA of the commit whose files are being linked.
        snippets: Domain snippets whose ``derives_from`` files are linked.
    """
    file_paths = {
        file.path for snippet in snippets for file in snippet.derives_from
    }
    if not file_paths:
        return

    # Resolve path -> blob sha for the files known for this commit.
    files_stmt = select(
        db_entities.GitCommitFile.path,
        db_entities.GitCommitFile.blob_sha,
    ).where(
        db_entities.GitCommitFile.commit_sha == commit_sha,
        db_entities.GitCommitFile.path.in_(list(file_paths)),
    )
    files_result = await session.execute(files_stmt)
    blob_sha_by_path: dict[str, str] = {
        row[0]: row[1] for row in files_result.fetchall()
    }

    # Links already stored for these snippets in this commit.
    snippet_shas = list(dict.fromkeys(snippet.sha for snippet in snippets))
    links_stmt = select(
        db_entities.SnippetV2File.snippet_sha,
        db_entities.SnippetV2File.file_path,
    ).where(
        db_entities.SnippetV2File.commit_sha == commit_sha,
        db_entities.SnippetV2File.snippet_sha.in_(snippet_shas),
    )
    links_result = await session.execute(links_stmt)
    # Plain tuples so membership tests do not depend on Row comparison rules.
    seen_links: set[tuple[str, str]] = {(row[0], row[1]) for row in links_result}

    new_links = []
    for snippet in snippets:
        for file in snippet.derives_from:
            link_key = (snippet.sha, file.path)
            if link_key in seen_links or file.path not in blob_sha_by_path:
                continue
            # Remember the pair so a repeat later in the batch is skipped.
            seen_links.add(link_key)
            new_links.append(
                {
                    "snippet_sha": snippet.sha,
                    "blob_sha": blob_sha_by_path[file.path],
                    "commit_sha": commit_sha,
                    "file_path": file.path,
                }
            )

    # Insert in chunks to stay below bound-parameter limits.
    chunk_size = 1000
    for start in range(0, len(new_links), chunk_size):
        chunk = new_links[start : start + chunk_size]
        await session.execute(insert(db_entities.SnippetV2File).values(chunk))
173
+
174
async def _bulk_update_enrichments(
    self, session: AsyncSession, snippets: list[SnippetV2]
) -> None:
    """Reconcile stored enrichments with those carried by ``snippets``.

    Enrichments are keyed by (snippet sha, enrichment type). For each key in
    the batch: if nothing is stored it is inserted; if the stored content
    hash differs the stored rows for that key are deleted and the new content
    inserted; if the hash matches nothing is written. When the batch holds
    the same key more than once, the last occurrence wins, so one key never
    produces more than one inserted row per call.

    Args:
        session: Active session to execute the statements on.
        snippets: Domain snippets whose enrichments should be stored.
    """
    snippet_shas = list(dict.fromkeys(snippet.sha for snippet in snippets))

    existing_stmt = select(
        db_entities.Enrichment.snippet_sha,
        db_entities.Enrichment.type,
        db_entities.Enrichment.content,
    ).where(db_entities.Enrichment.snippet_sha.in_(snippet_shas))
    existing_rows = await session.execute(existing_stmt)

    # (sha, type) -> hash of the stored content.
    stored_hash_by_key = {
        (snippet_sha, enrichment_type): self._hash_string(content)
        for snippet_sha, enrichment_type, content in existing_rows
    }

    # Desired content per key; a later duplicate overwrites an earlier one.
    desired_content: dict[tuple, str] = {}
    for snippet in snippets:
        for enrichment in snippet.enrichments:
            key = (snippet.sha, db_entities.EnrichmentType(enrichment.type.value))
            desired_content[key] = enrichment.content

    keys_to_delete = []
    rows_to_add = []
    for key, content in desired_content.items():
        if key in stored_hash_by_key:
            if stored_hash_by_key[key] == self._hash_string(content):
                continue  # unchanged
            # Content changed: replace the stored rows for this key.
            keys_to_delete.append(key)
        rows_to_add.append(
            {"snippet_sha": key[0], "type": key[1], "content": content}
        )

    # One DELETE per changed key.
    for snippet_sha, enrichment_type in keys_to_delete:
        await session.execute(
            delete(db_entities.Enrichment).where(
                db_entities.Enrichment.snippet_sha == snippet_sha,
                db_entities.Enrichment.type == enrichment_type,
            )
        )

    # Insert in chunks to stay below bound-parameter limits.
    chunk_size = 1000
    for start in range(0, len(rows_to_add), chunk_size):
        chunk = rows_to_add[start : start + chunk_size]
        await session.execute(insert(db_entities.Enrichment).values(chunk))
239
+
240
async def _get_or_create_raw_snippet(
    self, session: AsyncSession, commit_sha: str, domain_snippet: SnippetV2
) -> db_entities.SnippetV2:
    """Return the stored snippet row for ``domain_snippet``, creating it if absent.

    When the row is created, it is also linked to ``commit_sha`` and to each
    file in ``domain_snippet.derives_from``.

    Args:
        session: Active session to work on.
        commit_sha: SHA of the commit the snippet is being stored for.
        domain_snippet: Domain snippet to look up or create.

    Returns:
        The database snippet row.

    Raises:
        ValueError: If a file the snippet derives from is not stored for
            ``commit_sha``.
    """
    found = await session.get(db_entities.SnippetV2, domain_snippet.sha)
    if found:
        return found

    created = self._mapper.from_domain_snippet_v2(domain_snippet)
    session.add(created)
    await session.flush()

    # Link the new snippet to the commit.
    session.add(
        db_entities.CommitSnippetV2(
            commit_sha=commit_sha,
            snippet_sha=created.sha,
        )
    )

    # Link the new snippet to each source file; the file rows are looked up
    # by (commit_sha, path) and must already exist.
    for file in domain_snippet.derives_from:
        db_file = await session.get(
            db_entities.GitCommitFile, (commit_sha, file.path)
        )
        if not db_file:
            raise ValueError(
                f"File {file.path} not found for commit {commit_sha}"
            )
        session.add(
            db_entities.SnippetV2File(
                snippet_sha=created.sha,
                blob_sha=db_file.blob_sha,
                commit_sha=commit_sha,
                file_path=file.path,
            )
        )
    return created
276
+
277
async def _update_enrichments_if_changed(
    self,
    session: AsyncSession,
    db_snippet: db_entities.SnippetV2,
    domain_snippet: SnippetV2,
) -> None:
    """Replace stored enrichments whose content differs from the domain snippet.

    An enrichment counts as unchanged only when a stored enrichment of the
    same type has the same content hash. Matching on content alone would
    skip an enrichment whose text happens to equal that of a different type.

    Args:
        session: Active session to work on.
        db_snippet: Stored snippet row whose enrichments are reconciled.
        domain_snippet: Domain snippet carrying the desired enrichments.
    """
    stored = await session.scalars(
        select(db_entities.Enrichment).where(
            db_entities.Enrichment.snippet_sha == db_snippet.sha
        )
    )
    stored_keys = {
        (enrichment.type, self._hash_string(enrichment.content))
        for enrichment in stored
    }

    for enrichment in domain_snippet.enrichments:
        db_type = db_entities.EnrichmentType(enrichment.type.value)
        if (db_type, self._hash_string(enrichment.content)) in stored_keys:
            continue

        # Drop whatever is stored for this type before adding the new content.
        await session.execute(
            delete(db_entities.Enrichment).where(
                db_entities.Enrichment.snippet_sha == db_snippet.sha,
                db_entities.Enrichment.type == db_type,
            )
        )
        session.add(
            db_entities.Enrichment(
                snippet_sha=db_snippet.sha,
                type=db_type,
                content=enrichment.content,
            )
        )
311
+
312
async def get_snippets_for_commit(self, commit_sha: str) -> list[SnippetV2]:
    """Load every snippet linked to ``commit_sha`` as domain objects.

    Args:
        commit_sha: SHA of the commit to load snippets for.

    Returns:
        The linked snippets, or an empty list when the commit has none.
    """
    async with SqlAlchemyUnitOfWork(self.session_factory) as session:
        # Step 1: commit -> snippet links.
        links_stmt = select(db_entities.CommitSnippetV2).where(
            db_entities.CommitSnippetV2.commit_sha == commit_sha
        )
        links = (await session.scalars(links_stmt)).all()
        if not links:
            return []

        # Step 2: the snippet rows those links point at.
        linked_shas = [link.snippet_sha for link in links]
        snippets_stmt = select(db_entities.SnippetV2).where(
            db_entities.SnippetV2.sha.in_(linked_shas)
        )
        db_snippets = (await session.scalars(snippets_stmt)).all()

        # Step 3: convert one by one (each conversion queries files/enrichments).
        domain_snippets = []
        for db_snippet in db_snippets:
            domain_snippets.append(
                await self._to_domain_snippet_v2(session, db_snippet)
            )
        return domain_snippets
342
+
343
async def delete_snippets_for_commit(self, commit_sha: str) -> None:
    """Remove the commit-to-snippet links of ``commit_sha``.

    Only the association rows are deleted; the snippet rows stay because
    other commits may still reference them.

    Args:
        commit_sha: SHA of the commit whose links are removed.
    """
    unlink_stmt = delete(db_entities.CommitSnippetV2).where(
        db_entities.CommitSnippetV2.commit_sha == commit_sha
    )
    async with SqlAlchemyUnitOfWork(self.session_factory) as session:
        await session.execute(unlink_stmt)
352
+
353
+ def _hash_string(self, string: str) -> int:
354
+ """Hash a string."""
355
+ return zlib.crc32(string.encode())
356
+
357
async def search(self, request: MultiSearchRequest) -> list[SnippetV2]:
    """Search snippets with filters (not implemented).

    The draft query that used to follow the ``raise`` could never run and has
    been removed. It also referred to a ``self._session`` attribute this
    class never sets, so it could not have worked as written.

    Args:
        request: Search request; currently unused.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("Not implemented")
437
+
438
async def get_by_ids(self, ids: list[str]) -> list[SnippetV2]:
    """Load the snippets whose sha is in ``ids`` as domain objects.

    Args:
        ids: Snippet shas to look up.

    Returns:
        The snippets that were found, converted to domain objects.
    """
    by_sha_stmt = select(db_entities.SnippetV2).where(
        db_entities.SnippetV2.sha.in_(ids)
    )
    async with SqlAlchemyUnitOfWork(self.session_factory) as session:
        db_snippets = (await session.scalars(by_sha_stmt)).all()

        # Convert one by one; each conversion queries files and enrichments.
        domain_snippets = []
        for db_snippet in db_snippets:
            domain_snippets.append(
                await self._to_domain_snippet_v2(session, db_snippet)
            )
        return domain_snippets
454
+
455
async def _to_domain_snippet_v2(
    self, session: AsyncSession, db_snippet: db_entities.SnippetV2
) -> SnippetV2:
    """Build a domain SnippetV2 from a stored snippet row.

    Loads the commit files the snippet is linked to and the snippet's
    enrichments, then hands all three to the mapper.

    Args:
        session: Active session to query on.
        db_snippet: Stored snippet row to convert.

    Returns:
        The domain snippet produced by the mapper.
    """
    # Commit files reachable through the snippet-file link table; the join
    # matches on both path and commit sha.
    link = db_entities.SnippetV2File
    commit_file = db_entities.GitCommitFile
    files_stmt = (
        select(commit_file)
        .join(
            link,
            (commit_file.path == link.file_path)
            & (commit_file.commit_sha == link.commit_sha),
        )
        .where(link.snippet_sha == db_snippet.sha)
    )
    enrichments_stmt = select(db_entities.Enrichment).where(
        db_entities.Enrichment.snippet_sha == db_snippet.sha
    )

    file_rows = await session.scalars(files_stmt)
    enrichment_rows = await session.scalars(enrichments_stmt)

    return self._mapper.to_domain_snippet_v2(
        db_snippet=db_snippet,
        db_files=list(file_rows),
        db_enrichments=list(enrichment_rows),
    )
@@ -8,8 +8,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
8
8
 
9
9
  from kodit.domain.entities import Task
10
10
  from kodit.domain.protocols import TaskRepository
11
- from kodit.domain.value_objects import TaskType
12
- from kodit.infrastructure.mappers.task_mapper import TaskMapper, TaskTypeMapper
11
+ from kodit.domain.value_objects import TaskOperation
12
+ from kodit.infrastructure.mappers.task_mapper import TaskMapper
13
13
  from kodit.infrastructure.sqlalchemy import entities as db_entities
14
14
  from kodit.infrastructure.sqlalchemy.unit_of_work import SqlAlchemyUnitOfWork
15
15
 
@@ -18,16 +18,15 @@ def create_task_repository(
18
18
  session_factory: Callable[[], AsyncSession],
19
19
  ) -> TaskRepository:
20
20
  """Create an index repository."""
21
- uow = SqlAlchemyUnitOfWork(session_factory=session_factory)
22
- return SqlAlchemyTaskRepository(uow)
21
+ return SqlAlchemyTaskRepository(session_factory=session_factory)
23
22
 
24
23
 
25
24
  class SqlAlchemyTaskRepository(TaskRepository):
26
25
  """Repository for task persistence using the existing Task entity."""
27
26
 
28
- def __init__(self, uow: SqlAlchemyUnitOfWork) -> None:
27
+ def __init__(self, session_factory: Callable[[], AsyncSession]) -> None:
29
28
  """Initialize the repository."""
30
- self.uow = uow
29
+ self.session_factory = session_factory
31
30
  self.log = structlog.get_logger(__name__)
32
31
 
33
32
  async def add(
@@ -35,39 +34,48 @@ class SqlAlchemyTaskRepository(TaskRepository):
35
34
  task: Task,
36
35
  ) -> None:
37
36
  """Create a new task in the database."""
38
- async with self.uow:
39
- self.uow.session.add(TaskMapper.from_domain_task(task))
37
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
38
+ session.add(TaskMapper.from_domain_task(task))
40
39
 
41
40
  async def get(self, task_id: str) -> Task | None:
42
41
  """Get a task by ID."""
43
- async with self.uow:
42
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
44
43
  stmt = select(db_entities.Task).where(db_entities.Task.dedup_key == task_id)
45
- result = await self.uow.session.execute(stmt)
44
+ result = await session.execute(stmt)
46
45
  db_task = result.scalar_one_or_none()
47
46
  if not db_task:
48
47
  return None
49
48
  return TaskMapper.to_domain_task(db_task)
50
49
 
51
- async def take(self) -> Task | None:
50
+ async def next(self) -> Task | None:
52
51
  """Take a task for processing and remove it from the database."""
53
- async with self.uow:
52
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
54
53
  stmt = (
55
54
  select(db_entities.Task)
56
55
  .order_by(db_entities.Task.priority.desc(), db_entities.Task.created_at)
57
56
  .limit(1)
58
57
  )
59
- result = await self.uow.session.execute(stmt)
58
+ result = await session.execute(stmt)
60
59
  db_task = result.scalar_one_or_none()
61
60
  if not db_task:
62
61
  return None
63
- await self.uow.session.delete(db_task)
64
62
  return TaskMapper.to_domain_task(db_task)
65
63
 
64
async def remove(self, task: Task) -> None:
    """Delete the stored task whose dedup key equals ``task.id``.

    Args:
        task: Domain task to remove.

    Raises:
        ValueError: If no stored task has that dedup key.
    """
    async with SqlAlchemyUnitOfWork(self.session_factory) as session:
        lookup = select(db_entities.Task).where(
            db_entities.Task.dedup_key == task.id
        )
        db_task = await session.scalar(lookup)
        if db_task:
            await session.delete(db_task)
        else:
            raise ValueError(f"Task not found: {task.id}")
73
+
66
74
  async def update(self, task: Task) -> None:
67
75
  """Update a task in the database."""
68
- async with self.uow:
76
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
69
77
  stmt = select(db_entities.Task).where(db_entities.Task.dedup_key == task.id)
70
- result = await self.uow.session.execute(stmt)
78
+ result = await session.execute(stmt)
71
79
  db_task = result.scalar_one_or_none()
72
80
 
73
81
  if not db_task:
@@ -76,21 +84,19 @@ class SqlAlchemyTaskRepository(TaskRepository):
76
84
  db_task.priority = task.priority
77
85
  db_task.payload = task.payload
78
86
 
79
- async def list(self, task_type: TaskType | None = None) -> list[Task]:
87
+ async def list(self, task_operation: TaskOperation | None = None) -> list[Task]:
80
88
  """List tasks with optional status filter."""
81
- async with self.uow:
89
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
82
90
  stmt = select(db_entities.Task)
83
91
 
84
- if task_type:
85
- stmt = stmt.where(
86
- db_entities.Task.type == TaskTypeMapper.from_domain_type(task_type)
87
- )
92
+ if task_operation:
93
+ stmt = stmt.where(db_entities.Task.type == task_operation.value)
88
94
 
89
95
  stmt = stmt.order_by(
90
96
  db_entities.Task.priority.desc(), db_entities.Task.created_at
91
97
  )
92
98
 
93
- result = await self.uow.session.execute(stmt)
99
+ result = await session.execute(stmt)
94
100
  records = result.scalars().all()
95
101
 
96
102
  # Convert to domain entities
@@ -0,0 +1,91 @@
1
+ """Task repository for the task queue."""
2
+
3
+ from collections.abc import Callable
4
+
5
+ import structlog
6
+ from sqlalchemy import delete, select
7
+ from sqlalchemy.ext.asyncio import AsyncSession
8
+
9
+ from kodit.domain import entities as domain_entities
10
+ from kodit.domain.protocols import TaskStatusRepository
11
+ from kodit.infrastructure.mappers.task_status_mapper import TaskStatusMapper
12
+ from kodit.infrastructure.sqlalchemy import entities as db_entities
13
+ from kodit.infrastructure.sqlalchemy.unit_of_work import SqlAlchemyUnitOfWork
14
+
15
+
16
def create_task_status_repository(
    session_factory: Callable[[], AsyncSession],
) -> TaskStatusRepository:
    """Build the SQLAlchemy-backed task status repository.

    Args:
        session_factory: Zero-argument callable that returns an AsyncSession.

    Returns:
        A SqlAlchemyTaskStatusRepository constructed with ``session_factory``.
    """
    repository = SqlAlchemyTaskStatusRepository(session_factory=session_factory)
    return repository
21
+
22
+
23
+ class SqlAlchemyTaskStatusRepository(TaskStatusRepository):
24
+ """Repository for persisting TaskStatus entities."""
25
+
26
def __init__(self, session_factory: Callable[[], AsyncSession]) -> None:
    """Set up the repository.

    Args:
        session_factory: Zero-argument callable that returns an AsyncSession;
            each operation opens its own unit of work from it.
    """
    self.log = structlog.get_logger(__name__)
    self.mapper = TaskStatusMapper()
    self.session_factory = session_factory
31
+
32
async def save(self, status: domain_entities.TaskStatus) -> None:
    """Insert or update a TaskStatus, saving a missing parent first.

    If ``status`` has a parent that is not stored yet, the parent is saved
    recursively before ``status`` itself. The status row is then inserted
    when its id is unknown, otherwise the existing row's fields are
    overwritten with the new values.

    Args:
        status: Domain task status to persist.
    """
    parent = status.parent
    if parent is not None:
        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
            parent_lookup = select(db_entities.TaskStatus).where(
                db_entities.TaskStatus.id == parent.id,
            )
            parent_row = (await session.execute(parent_lookup)).scalar_one_or_none()
            parent_missing = not parent_row

        if parent_missing:
            # The parent row must exist before the child that refers to it.
            await self.save(parent)

    async with SqlAlchemyUnitOfWork(self.session_factory) as session:
        db_status = self.mapper.from_domain_task_status(status)
        lookup = select(db_entities.TaskStatus).where(
            db_entities.TaskStatus.id == db_status.id,
        )
        stored = (await session.execute(lookup)).scalar_one_or_none()

        if not stored:
            session.add(db_status)
            return

        # Copy the mutable fields onto the row that is already stored.
        for field in (
            "operation",
            "state",
            "error",
            "total",
            "current",
            "updated_at",
            "parent",
            "trackable_id",
            "trackable_type",
        ):
            setattr(stored, field, getattr(db_status, field))
69
+
70
async def load_with_hierarchy(
    self, trackable_type: str, trackable_id: int
) -> list[domain_entities.TaskStatus]:
    """Load the TaskStatus rows of one trackable and rebuild their hierarchy.

    Args:
        trackable_type: Type of the tracked object to filter on.
        trackable_id: Id of the tracked object to filter on.

    Returns:
        Domain task statuses produced by the mapper from the matching rows.
    """
    by_trackable = select(db_entities.TaskStatus).where(
        db_entities.TaskStatus.trackable_id == trackable_id,
        db_entities.TaskStatus.trackable_type == trackable_type,
    )
    async with SqlAlchemyUnitOfWork(self.session_factory) as session:
        rows = (await session.execute(by_trackable)).scalars().all()
        db_statuses = list(rows)

        # The mapper converts the rows and reconstructs parent/child links.
        return self.mapper.to_domain_task_status_with_hierarchy(db_statuses)
84
+
85
async def delete(self, status: domain_entities.TaskStatus) -> None:
    """Delete the stored TaskStatus row whose id equals ``status.id``.

    Args:
        status: Domain task status to remove.
    """
    removal = delete(db_entities.TaskStatus).where(
        db_entities.TaskStatus.id == status.id,
    )
    async with SqlAlchemyUnitOfWork(self.session_factory) as session:
        await session.execute(removal)