kodit 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic. Click here for more details.

Files changed (95) hide show
  1. kodit/_version.py +2 -2
  2. kodit/app.py +53 -23
  3. kodit/application/factories/reporting_factory.py +6 -2
  4. kodit/application/factories/server_factory.py +311 -0
  5. kodit/application/services/code_search_application_service.py +144 -0
  6. kodit/application/services/commit_indexing_application_service.py +543 -0
  7. kodit/application/services/indexing_worker_service.py +13 -44
  8. kodit/application/services/queue_service.py +24 -3
  9. kodit/application/services/reporting.py +0 -2
  10. kodit/application/services/sync_scheduler.py +15 -31
  11. kodit/cli.py +2 -753
  12. kodit/cli_utils.py +2 -9
  13. kodit/config.py +1 -94
  14. kodit/database.py +38 -1
  15. kodit/domain/{entities.py → entities/__init__.py} +50 -195
  16. kodit/domain/entities/git.py +190 -0
  17. kodit/domain/factories/__init__.py +1 -0
  18. kodit/domain/factories/git_repo_factory.py +76 -0
  19. kodit/domain/protocols.py +263 -64
  20. kodit/domain/services/bm25_service.py +5 -1
  21. kodit/domain/services/embedding_service.py +3 -0
  22. kodit/domain/services/git_repository_service.py +429 -0
  23. kodit/domain/services/git_service.py +300 -0
  24. kodit/domain/services/task_status_query_service.py +2 -2
  25. kodit/domain/value_objects.py +83 -114
  26. kodit/infrastructure/api/client/__init__.py +0 -2
  27. kodit/infrastructure/api/v1/__init__.py +0 -4
  28. kodit/infrastructure/api/v1/dependencies.py +92 -46
  29. kodit/infrastructure/api/v1/routers/__init__.py +0 -6
  30. kodit/infrastructure/api/v1/routers/commits.py +271 -0
  31. kodit/infrastructure/api/v1/routers/queue.py +2 -2
  32. kodit/infrastructure/api/v1/routers/repositories.py +282 -0
  33. kodit/infrastructure/api/v1/routers/search.py +31 -14
  34. kodit/infrastructure/api/v1/schemas/__init__.py +0 -24
  35. kodit/infrastructure/api/v1/schemas/commit.py +96 -0
  36. kodit/infrastructure/api/v1/schemas/context.py +2 -0
  37. kodit/infrastructure/api/v1/schemas/repository.py +128 -0
  38. kodit/infrastructure/api/v1/schemas/search.py +12 -9
  39. kodit/infrastructure/api/v1/schemas/snippet.py +58 -0
  40. kodit/infrastructure/api/v1/schemas/tag.py +31 -0
  41. kodit/infrastructure/api/v1/schemas/task_status.py +2 -0
  42. kodit/infrastructure/bm25/local_bm25_repository.py +16 -4
  43. kodit/infrastructure/bm25/vectorchord_bm25_repository.py +68 -52
  44. kodit/infrastructure/cloning/git/git_python_adaptor.py +467 -0
  45. kodit/infrastructure/cloning/git/working_copy.py +1 -1
  46. kodit/infrastructure/embedding/embedding_factory.py +3 -2
  47. kodit/infrastructure/embedding/local_vector_search_repository.py +1 -1
  48. kodit/infrastructure/embedding/vectorchord_vector_search_repository.py +111 -84
  49. kodit/infrastructure/enrichment/litellm_enrichment_provider.py +19 -26
  50. kodit/infrastructure/indexing/fusion_service.py +1 -1
  51. kodit/infrastructure/mappers/git_mapper.py +193 -0
  52. kodit/infrastructure/mappers/snippet_mapper.py +106 -0
  53. kodit/infrastructure/mappers/task_mapper.py +5 -44
  54. kodit/infrastructure/reporting/log_progress.py +8 -5
  55. kodit/infrastructure/reporting/telemetry_progress.py +21 -0
  56. kodit/infrastructure/slicing/slicer.py +32 -31
  57. kodit/infrastructure/sqlalchemy/embedding_repository.py +43 -23
  58. kodit/infrastructure/sqlalchemy/entities.py +394 -158
  59. kodit/infrastructure/sqlalchemy/git_branch_repository.py +263 -0
  60. kodit/infrastructure/sqlalchemy/git_commit_repository.py +337 -0
  61. kodit/infrastructure/sqlalchemy/git_repository.py +252 -0
  62. kodit/infrastructure/sqlalchemy/git_tag_repository.py +257 -0
  63. kodit/infrastructure/sqlalchemy/snippet_v2_repository.py +484 -0
  64. kodit/infrastructure/sqlalchemy/task_repository.py +29 -23
  65. kodit/infrastructure/sqlalchemy/task_status_repository.py +24 -12
  66. kodit/infrastructure/sqlalchemy/unit_of_work.py +10 -14
  67. kodit/mcp.py +12 -30
  68. kodit/migrations/env.py +1 -0
  69. kodit/migrations/versions/04b80f802e0c_foreign_key_review.py +100 -0
  70. kodit/migrations/versions/7f15f878c3a1_add_new_git_entities.py +690 -0
  71. kodit/migrations/versions/f9e5ef5e688f_add_git_commits_number.py +43 -0
  72. kodit/py.typed +0 -0
  73. kodit/utils/dump_openapi.py +7 -4
  74. kodit/utils/path_utils.py +29 -0
  75. {kodit-0.4.3.dist-info → kodit-0.5.0.dist-info}/METADATA +3 -3
  76. kodit-0.5.0.dist-info/RECORD +137 -0
  77. kodit/application/factories/code_indexing_factory.py +0 -195
  78. kodit/application/services/auto_indexing_service.py +0 -99
  79. kodit/application/services/code_indexing_application_service.py +0 -410
  80. kodit/domain/services/index_query_service.py +0 -70
  81. kodit/domain/services/index_service.py +0 -269
  82. kodit/infrastructure/api/client/index_client.py +0 -57
  83. kodit/infrastructure/api/v1/routers/indexes.py +0 -164
  84. kodit/infrastructure/api/v1/schemas/index.py +0 -101
  85. kodit/infrastructure/bm25/bm25_factory.py +0 -28
  86. kodit/infrastructure/cloning/__init__.py +0 -1
  87. kodit/infrastructure/cloning/metadata.py +0 -98
  88. kodit/infrastructure/mappers/index_mapper.py +0 -345
  89. kodit/infrastructure/reporting/tdqm_progress.py +0 -38
  90. kodit/infrastructure/slicing/language_detection_service.py +0 -18
  91. kodit/infrastructure/sqlalchemy/index_repository.py +0 -646
  92. kodit-0.4.3.dist-info/RECORD +0 -125
  93. {kodit-0.4.3.dist-info → kodit-0.5.0.dist-info}/WHEEL +0 -0
  94. {kodit-0.4.3.dist-info → kodit-0.5.0.dist-info}/entry_points.txt +0 -0
  95. {kodit-0.4.3.dist-info → kodit-0.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,484 @@
1
+ """SQLAlchemy implementation of SnippetRepositoryV2."""
2
+
3
+ import zlib
4
+ from collections.abc import Callable
5
+
6
+ from sqlalchemy import delete, insert, select
7
+ from sqlalchemy.ext.asyncio import AsyncSession
8
+
9
+ from kodit.domain.entities.git import SnippetV2
10
+ from kodit.domain.protocols import SnippetRepositoryV2
11
+ from kodit.domain.value_objects import MultiSearchRequest
12
+ from kodit.infrastructure.mappers.snippet_mapper import SnippetMapper
13
+ from kodit.infrastructure.sqlalchemy import entities as db_entities
14
+ from kodit.infrastructure.sqlalchemy.unit_of_work import SqlAlchemyUnitOfWork
15
+
16
+
17
def create_snippet_v2_repository(
    session_factory: Callable[[], AsyncSession],
) -> SnippetRepositoryV2:
    """Build a snippet v2 repository backed by SQLAlchemy.

    Args:
        session_factory: Zero-argument callable returning an ``AsyncSession``;
            it is passed unchanged to the repository.

    Returns:
        A ``SqlAlchemySnippetRepositoryV2`` instance.
    """
    repository = SqlAlchemySnippetRepositoryV2(session_factory=session_factory)
    return repository
22
+
23
+
24
+ class SqlAlchemySnippetRepositoryV2(SnippetRepositoryV2):
25
+ """SQLAlchemy implementation of SnippetRepositoryV2."""
26
+
27
def __init__(self, session_factory: Callable[[], AsyncSession]) -> None:
    """Remember the session factory for later use.

    Args:
        session_factory: Zero-argument callable returning an ``AsyncSession``.
            The public methods of this class pass it to
            ``SqlAlchemyUnitOfWork`` each time they touch the database.
    """
    self.session_factory = session_factory
30
+
31
@property
def _mapper(self) -> SnippetMapper:
    """Return a freshly constructed ``SnippetMapper`` on every access."""
    mapper = SnippetMapper()
    return mapper
34
+
35
async def save_snippets(self, commit_sha: str, snippets: list[SnippetV2]) -> None:
    """Persist a batch of snippets for one commit.

    Inside a single unit of work this stores the snippet rows, links them to
    ``commit_sha``, links them to the files they derive from, and finally
    synchronises their enrichments. An empty ``snippets`` list is a no-op and
    does not open a session.

    Args:
        commit_sha: SHA of the commit the snippets belong to.
        snippets: Domain snippets to store.
    """
    if snippets:
        async with SqlAlchemyUnitOfWork(self.session_factory) as uow_session:
            # Steps run in the same order as before: snippet rows, commit
            # links, file links, enrichments.
            await self._bulk_save_snippets(uow_session, snippets)
            await self._bulk_create_commit_associations(
                uow_session, commit_sha, snippets
            )
            await self._bulk_create_file_associations(
                uow_session, commit_sha, snippets
            )
            await self._bulk_update_enrichments(uow_session, snippets)
46
+
47
async def _bulk_save_snippets(
    self, session: AsyncSession, snippets: list[SnippetV2]
) -> None:
    """Insert the snippets that are not stored yet.

    Snippets are identified by ``sha``. Rows whose sha is already present in
    the database are skipped, and a sha that occurs more than once in
    ``snippets`` is inserted only once (first occurrence wins).

    Args:
        session: Open session of the surrounding unit of work.
        snippets: Domain snippets to store.
    """
    # Collapse in-batch duplicates first. Other code in this class looks the
    # table up with ``session.get(SnippetV2, sha)``, i.e. sha is the primary
    # key, so sending the same sha twice in one INSERT would violate it.
    unique_snippets: dict[str, SnippetV2] = {}
    for snippet in snippets:
        unique_snippets.setdefault(snippet.sha, snippet)
    if not unique_snippets:
        return

    # One round trip to learn which shas already exist.
    existing_stmt = select(db_entities.SnippetV2.sha).where(
        db_entities.SnippetV2.sha.in_(list(unique_snippets))
    )
    existing_shas = set((await session.scalars(existing_stmt)).all())

    new_rows = [
        {
            "sha": sha,
            "content": snippet.content,
            "extension": snippet.extension,
        }
        for sha, snippet in unique_snippets.items()
        if sha not in existing_shas
    ]

    # Insert in chunks to stay below driver bind-parameter limits.
    chunk_size = 1000
    for start in range(0, len(new_rows), chunk_size):
        chunk = new_rows[start : start + chunk_size]
        await session.execute(insert(db_entities.SnippetV2).values(chunk))
79
+
80
async def _bulk_create_commit_associations(
    self, session: AsyncSession, commit_sha: str, snippets: list[SnippetV2]
) -> None:
    """Link snippets to a commit, skipping links that already exist.

    A link is the pair ``(commit_sha, snippet_sha)``. Links already stored
    for this commit are left untouched, and a snippet sha that appears more
    than once in ``snippets`` yields a single new link.

    Args:
        session: Open session of the surrounding unit of work.
        commit_sha: SHA of the commit to link the snippets to.
        snippets: Domain snippets to link.
    """
    # Ordered, de-duplicated shas: the batch itself must not contain the same
    # link twice, otherwise the duplicate check against the DB is pointless.
    snippet_shas = list(dict.fromkeys(snippet.sha for snippet in snippets))
    if not snippet_shas:
        return

    existing_stmt = select(db_entities.CommitSnippetV2.snippet_sha).where(
        db_entities.CommitSnippetV2.commit_sha == commit_sha,
        db_entities.CommitSnippetV2.snippet_sha.in_(snippet_shas),
    )
    already_linked = set((await session.scalars(existing_stmt)).all())

    new_rows = [
        {"commit_sha": commit_sha, "snippet_sha": sha}
        for sha in snippet_shas
        if sha not in already_linked
    ]

    # Insert in chunks to stay below driver bind-parameter limits.
    chunk_size = 1000
    for start in range(0, len(new_rows), chunk_size):
        chunk = new_rows[start : start + chunk_size]
        await session.execute(insert(db_entities.CommitSnippetV2).values(chunk))
114
+
115
async def _bulk_create_file_associations(
    self, session: AsyncSession, commit_sha: str, snippets: list[SnippetV2]
) -> None:
    """Link snippets to the commit files they derive from.

    For every ``(snippet, file)`` pair in ``snippet.derives_from`` a
    ``SnippetV2File`` row is inserted unless that pair is already stored for
    ``commit_sha``. Files that have no ``GitCommitFile`` row for this commit
    are skipped silently, as before.

    Args:
        session: Open session of the surrounding unit of work.
        commit_sha: SHA of the commit the files belong to.
        snippets: Domain snippets whose ``derives_from`` files are linked.
    """
    file_paths = {
        file.path for snippet in snippets for file in snippet.derives_from
    }
    if not file_paths:
        return

    # path -> blob_sha for the files known for this commit.
    files_stmt = select(
        db_entities.GitCommitFile.path,
        db_entities.GitCommitFile.blob_sha,
    ).where(
        db_entities.GitCommitFile.commit_sha == commit_sha,
        db_entities.GitCommitFile.path.in_(list(file_paths)),
    )
    files_result = await session.execute(files_stmt)
    blob_sha_by_path: dict[str, str] = {
        row[0]: row[1] for row in files_result.fetchall()
    }

    # Links already stored. Materialise them as plain tuples so the
    # membership test below does not rely on Row-vs-tuple equality.
    links_stmt = select(
        db_entities.SnippetV2File.snippet_sha,
        db_entities.SnippetV2File.file_path,
    ).where(
        db_entities.SnippetV2File.commit_sha == commit_sha,
        db_entities.SnippetV2File.snippet_sha.in_(
            [snippet.sha for snippet in snippets]
        ),
    )
    seen_links: set[tuple[str, str]] = {
        (row[0], row[1]) for row in await session.execute(links_stmt)
    }

    new_rows = []
    for snippet in snippets:
        for file in snippet.derives_from:
            link = (snippet.sha, file.path)
            if link in seen_links or file.path not in blob_sha_by_path:
                continue
            # Remember the link so a repeat inside this batch is not
            # inserted a second time.
            seen_links.add(link)
            new_rows.append(
                {
                    "snippet_sha": snippet.sha,
                    "blob_sha": blob_sha_by_path[file.path],
                    "commit_sha": commit_sha,
                    "file_path": file.path,
                }
            )

    # Insert in chunks to stay below driver bind-parameter limits.
    chunk_size = 1000
    for start in range(0, len(new_rows), chunk_size):
        chunk = new_rows[start : start + chunk_size]
        await session.execute(insert(db_entities.SnippetV2File).values(chunk))
173
+
174
async def _bulk_update_enrichments(
    self, session: AsyncSession, snippets: list[SnippetV2]
) -> None:
    """Synchronise stored enrichments with those carried by ``snippets``.

    For each ``(snippet sha, enrichment type)`` pair:

    * no stored row: the enrichment is inserted;
    * stored row with identical content: nothing happens;
    * stored row with different content: the stored rows for that pair are
      deleted and the new content is inserted.

    Stored enrichments that the domain snippets do not mention are left
    alone.

    Args:
        session: Open session of the surrounding unit of work.
        snippets: Domain snippets whose enrichments should be persisted.
    """
    snippet_shas = [snippet.sha for snippet in snippets]

    existing_stmt = select(
        db_entities.Enrichment.snippet_sha,
        db_entities.Enrichment.type,
        db_entities.Enrichment.content,
    ).where(db_entities.Enrichment.snippet_sha.in_(snippet_shas))

    # Compare the content itself rather than a 32-bit CRC of it: a checksum
    # collision would otherwise make a real change look like "unchanged".
    stored_content = {
        (snippet_sha, enrichment_type): content
        for snippet_sha, enrichment_type, content in await session.execute(
            existing_stmt
        )
    }

    stale_keys = []
    rows_to_add = []
    for snippet in snippets:
        for enrichment in snippet.enrichments:
            enrichment_type = db_entities.EnrichmentType(enrichment.type.value)
            key = (snippet.sha, enrichment_type)
            if key in stored_content:
                if stored_content[key] == enrichment.content:
                    continue  # unchanged
                stale_keys.append(key)
            rows_to_add.append(
                {
                    "snippet_sha": snippet.sha,
                    "type": enrichment_type,
                    "content": enrichment.content,
                }
            )

    # Remove outdated rows before inserting their replacements. Each key is
    # deleted once even if it was collected several times.
    for snippet_sha, enrichment_type in dict.fromkeys(stale_keys):
        await session.execute(
            delete(db_entities.Enrichment).where(
                db_entities.Enrichment.snippet_sha == snippet_sha,
                db_entities.Enrichment.type == enrichment_type,
            )
        )

    # Insert in chunks to stay below driver bind-parameter limits.
    chunk_size = 1000
    for start in range(0, len(rows_to_add), chunk_size):
        chunk = rows_to_add[start : start + chunk_size]
        await session.execute(insert(db_entities.Enrichment).values(chunk))
239
+
240
async def _get_or_create_raw_snippet(
    self, session: AsyncSession, commit_sha: str, domain_snippet: SnippetV2
) -> db_entities.SnippetV2:
    """Fetch the stored snippet for ``domain_snippet`` or create it.

    The snippet is looked up by its sha; when missing it is mapped from the
    domain object, added to the session and flushed. In both cases a
    commit-snippet link and one snippet-file link per ``derives_from`` entry
    are then added to the session.

    Args:
        session: Open session to work in.
        commit_sha: SHA of the commit the snippet is attached to.
        domain_snippet: Domain snippet to look up or store.

    Returns:
        The database snippet row.

    Raises:
        ValueError: If a file in ``derives_from`` has no ``GitCommitFile``
            row for ``commit_sha``.
    """
    stored = await session.get(db_entities.SnippetV2, domain_snippet.sha)
    if not stored:
        stored = self._mapper.from_domain_snippet_v2(domain_snippet)
        session.add(stored)
        await session.flush()

    # Link the snippet to the commit.
    session.add(
        db_entities.CommitSnippetV2(
            commit_sha=commit_sha,
            snippet_sha=stored.sha,
        )
    )

    # Link the snippet to each file it derives from. The file rows are
    # expected to exist already (the original comment says they are created
    # during the scan).
    for source_file in domain_snippet.derives_from:
        file_key = (commit_sha, source_file.path)
        file_row = await session.get(db_entities.GitCommitFile, file_key)
        if not file_row:
            raise ValueError(
                f"File {source_file.path} not found for commit {commit_sha}"
            )
        session.add(
            db_entities.SnippetV2File(
                snippet_sha=stored.sha,
                blob_sha=file_row.blob_sha,
                commit_sha=commit_sha,
                file_path=source_file.path,
            )
        )
    return stored
276
+
277
async def _update_enrichments_if_changed(
    self,
    session: AsyncSession,
    db_snippet: db_entities.SnippetV2,
    domain_snippet: SnippetV2,
) -> None:
    """Replace stored enrichments whose content is not present yet.

    The content hashes of all enrichments stored for ``db_snippet`` are
    collected (regardless of type). Every domain enrichment whose content
    hash is not in that set replaces the stored enrichment(s) of the same
    type for this snippet.

    Args:
        session: Open session to work in.
        db_snippet: Stored snippet row the enrichments belong to.
        domain_snippet: Domain snippet carrying the desired enrichments.
    """
    stored_stmt = select(db_entities.Enrichment).where(
        db_entities.Enrichment.snippet_sha == db_snippet.sha
    )
    known_hashes = set()
    for stored in await session.scalars(stored_stmt):
        known_hashes.add(self._hash_string(stored.content))

    for enrichment in domain_snippet.enrichments:
        if self._hash_string(enrichment.content) not in known_hashes:
            enrichment_type = db_entities.EnrichmentType(enrichment.type.value)

            # Drop whatever is stored for this type, then add the new row.
            await session.execute(
                delete(db_entities.Enrichment).where(
                    db_entities.Enrichment.snippet_sha == db_snippet.sha,
                    db_entities.Enrichment.type == enrichment_type,
                )
            )
            session.add(
                db_entities.Enrichment(
                    snippet_sha=db_snippet.sha,
                    type=enrichment_type,
                    content=enrichment.content,
                )
            )
311
+
312
async def get_snippets_for_commit(self, commit_sha: str) -> list[SnippetV2]:
    """Return the domain snippets linked to ``commit_sha``.

    The commit-snippet links are read first; when there are none an empty
    list is returned. Otherwise the linked snippet rows are loaded and each
    one is converted to a domain snippet (files and enrichments included).

    Args:
        commit_sha: SHA of the commit to read snippets for.
    """
    async with SqlAlchemyUnitOfWork(self.session_factory) as session:
        links_stmt = select(db_entities.CommitSnippetV2).where(
            db_entities.CommitSnippetV2.commit_sha == commit_sha
        )
        links = (await session.scalars(links_stmt)).all()
        if not links:
            return []

        linked_shas = [link.snippet_sha for link in links]
        snippets_stmt = select(db_entities.SnippetV2).where(
            db_entities.SnippetV2.sha.in_(linked_shas)
        )
        snippet_rows = (await session.scalars(snippets_stmt)).all()

        domain_snippets = []
        for snippet_row in snippet_rows:
            domain_snippets.append(
                await self._to_domain_snippet_v2(session, snippet_row)
            )
        return domain_snippets
342
+
343
async def delete_snippets_for_commit(self, commit_sha: str) -> None:
    """Remove every commit-snippet link of ``commit_sha``.

    Only the link rows are deleted. The snippet rows stay in place because,
    per the original note, other commits may still use them.

    Args:
        commit_sha: SHA of the commit whose links are removed.
    """
    unlink_stmt = delete(db_entities.CommitSnippetV2).where(
        db_entities.CommitSnippetV2.commit_sha == commit_sha
    )
    async with SqlAlchemyUnitOfWork(self.session_factory) as session:
        await session.execute(unlink_stmt)
352
+
353
+ def _hash_string(self, string: str) -> int:
354
+ """Hash a string."""
355
+ return zlib.crc32(string.encode())
356
+
357
async def search(self, request: MultiSearchRequest) -> list[SnippetV2]:
    """Search snippets with filters (not implemented).

    Args:
        request: Search request; currently ignored.

    Raises:
        NotImplementedError: Always.
    """
    # The draft query that used to follow this statement could never run,
    # referenced a non-existent ``self._session`` and joined tables on
    # mismatched columns, so it has been removed rather than kept as dead
    # code.
    raise NotImplementedError("Not implemented")
437
+
438
async def get_by_ids(self, ids: list[str]) -> list[SnippetV2]:
    """Return the domain snippets whose sha is in ``ids``.

    Shas that are not stored are simply absent from the result. The result
    order is whatever the database returns; it is not aligned with ``ids``.

    Args:
        ids: Snippet shas to load.
    """
    # Nothing to look up: skip opening a unit of work for an empty IN ().
    if not ids:
        return []

    async with SqlAlchemyUnitOfWork(self.session_factory) as session:
        # Load the snippet rows directly by sha.
        snippets_stmt = select(db_entities.SnippetV2).where(
            db_entities.SnippetV2.sha.in_(ids)
        )
        db_snippets = (await session.scalars(snippets_stmt)).all()

        return [
            await self._to_domain_snippet_v2(session, db_snippet)
            for db_snippet in db_snippets
        ]
454
+
455
async def _to_domain_snippet_v2(
    self, session: AsyncSession, db_snippet: db_entities.SnippetV2
) -> SnippetV2:
    """Convert a stored snippet row into a domain ``SnippetV2``.

    Loads the commit files the snippet is linked to and the enrichments
    stored for it, then hands all three to the mapper.

    Args:
        session: Open session used for the two lookups.
        db_snippet: Stored snippet row to convert.
    """
    commit_file = db_entities.GitCommitFile
    file_link = db_entities.SnippetV2File

    # Commit files reached through the snippet-file link table, matched on
    # both path and commit sha.
    same_path = commit_file.path == file_link.file_path
    same_commit = commit_file.commit_sha == file_link.commit_sha
    files_stmt = (
        select(commit_file)
        .join(file_link, same_path & same_commit)
        .where(file_link.snippet_sha == db_snippet.sha)
    )
    file_rows = await session.scalars(files_stmt)

    # Enrichments stored for this snippet.
    enrichments_stmt = select(db_entities.Enrichment).where(
        db_entities.Enrichment.snippet_sha == db_snippet.sha
    )
    enrichment_rows = await session.scalars(enrichments_stmt)

    return self._mapper.to_domain_snippet_v2(
        db_snippet=db_snippet,
        db_files=list(file_rows),
        db_enrichments=list(enrichment_rows),
    )
@@ -8,8 +8,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
8
8
 
9
9
  from kodit.domain.entities import Task
10
10
  from kodit.domain.protocols import TaskRepository
11
- from kodit.domain.value_objects import TaskType
12
- from kodit.infrastructure.mappers.task_mapper import TaskMapper, TaskTypeMapper
11
+ from kodit.domain.value_objects import TaskOperation
12
+ from kodit.infrastructure.mappers.task_mapper import TaskMapper
13
13
  from kodit.infrastructure.sqlalchemy import entities as db_entities
14
14
  from kodit.infrastructure.sqlalchemy.unit_of_work import SqlAlchemyUnitOfWork
15
15
 
@@ -18,16 +18,15 @@ def create_task_repository(
18
18
  session_factory: Callable[[], AsyncSession],
19
19
  ) -> TaskRepository:
20
20
  """Create an index repository."""
21
- uow = SqlAlchemyUnitOfWork(session_factory=session_factory)
22
- return SqlAlchemyTaskRepository(uow)
21
+ return SqlAlchemyTaskRepository(session_factory=session_factory)
23
22
 
24
23
 
25
24
  class SqlAlchemyTaskRepository(TaskRepository):
26
25
  """Repository for task persistence using the existing Task entity."""
27
26
 
28
- def __init__(self, uow: SqlAlchemyUnitOfWork) -> None:
27
+ def __init__(self, session_factory: Callable[[], AsyncSession]) -> None:
29
28
  """Initialize the repository."""
30
- self.uow = uow
29
+ self.session_factory = session_factory
31
30
  self.log = structlog.get_logger(__name__)
32
31
 
33
32
  async def add(
@@ -35,39 +34,48 @@ class SqlAlchemyTaskRepository(TaskRepository):
35
34
  task: Task,
36
35
  ) -> None:
37
36
  """Create a new task in the database."""
38
- async with self.uow:
39
- self.uow.session.add(TaskMapper.from_domain_task(task))
37
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
38
+ session.add(TaskMapper.from_domain_task(task))
40
39
 
41
40
async def get(self, task_id: str) -> Task | None:
    """Look up a task by its dedup key.

    Args:
        task_id: Value matched against ``Task.dedup_key``.

    Returns:
        The domain task, or ``None`` when no row matches.
    """
    lookup = select(db_entities.Task).where(db_entities.Task.dedup_key == task_id)
    async with SqlAlchemyUnitOfWork(self.session_factory) as session:
        db_task = (await session.execute(lookup)).scalar_one_or_none()
        return TaskMapper.to_domain_task(db_task) if db_task else None
50
49
 
51
async def next(self) -> Task | None:
    """Return the next task to process without removing it.

    The task with the highest priority is chosen; ties are broken by the
    earliest ``created_at``. The row stays in the database — call
    ``remove`` to delete it.

    Returns:
        The domain task, or ``None`` when there are no tasks.
    """
    async with SqlAlchemyUnitOfWork(self.session_factory) as session:
        stmt = (
            select(db_entities.Task)
            .order_by(db_entities.Task.priority.desc(), db_entities.Task.created_at)
            .limit(1)
        )
        result = await session.execute(stmt)
        db_task = result.scalar_one_or_none()
        if not db_task:
            return None
        return TaskMapper.to_domain_task(db_task)
65
63
 
64
async def remove(self, task: Task) -> None:
    """Delete the stored row for ``task``.

    Args:
        task: Domain task; its ``id`` is matched against ``Task.dedup_key``.

    Raises:
        ValueError: If no stored task has that dedup key.
    """
    lookup = select(db_entities.Task).where(db_entities.Task.dedup_key == task.id)
    async with SqlAlchemyUnitOfWork(self.session_factory) as session:
        db_task = await session.scalar(lookup)
        if db_task:
            await session.delete(db_task)
            return
        raise ValueError(f"Task not found: {task.id}")
73
+
66
74
  async def update(self, task: Task) -> None:
67
75
  """Update a task in the database."""
68
- async with self.uow:
76
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
69
77
  stmt = select(db_entities.Task).where(db_entities.Task.dedup_key == task.id)
70
- result = await self.uow.session.execute(stmt)
78
+ result = await session.execute(stmt)
71
79
  db_task = result.scalar_one_or_none()
72
80
 
73
81
  if not db_task:
@@ -76,21 +84,19 @@ class SqlAlchemyTaskRepository(TaskRepository):
76
84
  db_task.priority = task.priority
77
85
  db_task.payload = task.payload
78
86
 
79
- async def list(self, task_type: TaskType | None = None) -> list[Task]:
87
+ async def list(self, task_operation: TaskOperation | None = None) -> list[Task]:
80
88
  """List tasks with optional status filter."""
81
- async with self.uow:
89
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
82
90
  stmt = select(db_entities.Task)
83
91
 
84
- if task_type:
85
- stmt = stmt.where(
86
- db_entities.Task.type == TaskTypeMapper.from_domain_type(task_type)
87
- )
92
+ if task_operation:
93
+ stmt = stmt.where(db_entities.Task.type == task_operation.value)
88
94
 
89
95
  stmt = stmt.order_by(
90
96
  db_entities.Task.priority.desc(), db_entities.Task.created_at
91
97
  )
92
98
 
93
- result = await self.uow.session.execute(stmt)
99
+ result = await session.execute(stmt)
94
100
  records = result.scalars().all()
95
101
 
96
102
  # Convert to domain entities
@@ -17,32 +17,44 @@ def create_task_status_repository(
17
17
  session_factory: Callable[[], AsyncSession],
18
18
  ) -> TaskStatusRepository:
19
19
  """Create an index repository."""
20
- uow = SqlAlchemyUnitOfWork(session_factory=session_factory)
21
- return SqlAlchemyTaskStatusRepository(uow)
20
+ return SqlAlchemyTaskStatusRepository(session_factory=session_factory)
22
21
 
23
22
 
24
23
  class SqlAlchemyTaskStatusRepository(TaskStatusRepository):
25
24
  """Repository for persisting TaskStatus entities."""
26
25
 
27
- def __init__(self, uow: SqlAlchemyUnitOfWork) -> None:
26
+ def __init__(self, session_factory: Callable[[], AsyncSession]) -> None:
28
27
  """Initialize the repository."""
29
- self.uow = uow
30
- self.log = structlog.get_logger(__name__)
28
+ self.session_factory = session_factory
31
29
  self.mapper = TaskStatusMapper()
30
+ self.log = structlog.get_logger(__name__)
32
31
 
33
32
  async def save(self, status: domain_entities.TaskStatus) -> None:
34
33
  """Save a TaskStatus to database."""
35
- async with self.uow:
34
+ # If this task has a parent, ensure the parent exists in the database first
35
+ if status.parent is not None:
36
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
37
+ parent_stmt = select(db_entities.TaskStatus).where(
38
+ db_entities.TaskStatus.id == status.parent.id,
39
+ )
40
+ parent_result = await session.execute(parent_stmt)
41
+ existing_parent = parent_result.scalar_one_or_none()
42
+
43
+ if not existing_parent:
44
+ # Recursively save the parent first
45
+ await self.save(status.parent)
46
+
47
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
36
48
  # Convert domain entity to database entity
37
49
  db_status = self.mapper.from_domain_task_status(status)
38
50
  stmt = select(db_entities.TaskStatus).where(
39
51
  db_entities.TaskStatus.id == db_status.id,
40
52
  )
41
- result = await self.uow.session.execute(stmt)
53
+ result = await session.execute(stmt)
42
54
  existing = result.scalar_one_or_none()
43
55
 
44
56
  if not existing:
45
- self.uow.session.add(db_status)
57
+ session.add(db_status)
46
58
  else:
47
59
  # Update existing record with new values
48
60
  existing.operation = db_status.operation
@@ -59,12 +71,12 @@ class SqlAlchemyTaskStatusRepository(TaskStatusRepository):
59
71
  self, trackable_type: str, trackable_id: int
60
72
  ) -> list[domain_entities.TaskStatus]:
61
73
  """Load TaskStatus entities with hierarchy from database."""
62
- async with self.uow:
74
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
63
75
  stmt = select(db_entities.TaskStatus).where(
64
76
  db_entities.TaskStatus.trackable_id == trackable_id,
65
77
  db_entities.TaskStatus.trackable_type == trackable_type,
66
78
  )
67
- result = await self.uow.session.execute(stmt)
79
+ result = await session.execute(stmt)
68
80
  db_statuses = list(result.scalars().all())
69
81
 
70
82
  # Use mapper to convert and reconstruct hierarchy
@@ -72,8 +84,8 @@ class SqlAlchemyTaskStatusRepository(TaskStatusRepository):
72
84
 
73
85
  async def delete(self, status: domain_entities.TaskStatus) -> None:
74
86
  """Delete a TaskStatus."""
75
- async with self.uow:
87
+ async with SqlAlchemyUnitOfWork(self.session_factory) as session:
76
88
  stmt = delete(db_entities.TaskStatus).where(
77
89
  db_entities.TaskStatus.id == status.id,
78
90
  )
79
- await self.uow.session.execute(stmt)
91
+ await session.execute(stmt)