kodit 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic. Click here for more details.

Files changed (50) hide show
  1. kodit/_version.py +16 -3
  2. kodit/app.py +10 -3
  3. kodit/application/factories/code_indexing_factory.py +54 -7
  4. kodit/application/factories/reporting_factory.py +27 -0
  5. kodit/application/services/auto_indexing_service.py +16 -4
  6. kodit/application/services/code_indexing_application_service.py +115 -133
  7. kodit/application/services/indexing_worker_service.py +18 -20
  8. kodit/application/services/queue_service.py +15 -12
  9. kodit/application/services/reporting.py +86 -0
  10. kodit/application/services/sync_scheduler.py +21 -20
  11. kodit/cli.py +14 -18
  12. kodit/config.py +35 -17
  13. kodit/database.py +2 -1
  14. kodit/domain/protocols.py +9 -1
  15. kodit/domain/services/bm25_service.py +1 -6
  16. kodit/domain/services/index_service.py +22 -58
  17. kodit/domain/value_objects.py +57 -9
  18. kodit/infrastructure/api/v1/__init__.py +2 -2
  19. kodit/infrastructure/api/v1/dependencies.py +23 -10
  20. kodit/infrastructure/api/v1/routers/__init__.py +2 -1
  21. kodit/infrastructure/api/v1/routers/queue.py +76 -0
  22. kodit/infrastructure/api/v1/schemas/queue.py +35 -0
  23. kodit/infrastructure/cloning/git/working_copy.py +36 -7
  24. kodit/infrastructure/embedding/embedding_factory.py +18 -19
  25. kodit/infrastructure/embedding/embedding_providers/litellm_embedding_provider.py +156 -0
  26. kodit/infrastructure/enrichment/enrichment_factory.py +7 -16
  27. kodit/infrastructure/enrichment/{openai_enrichment_provider.py → litellm_enrichment_provider.py} +70 -60
  28. kodit/infrastructure/git/git_utils.py +9 -2
  29. kodit/infrastructure/mappers/index_mapper.py +1 -0
  30. kodit/infrastructure/reporting/__init__.py +1 -0
  31. kodit/infrastructure/reporting/log_progress.py +65 -0
  32. kodit/infrastructure/reporting/tdqm_progress.py +73 -0
  33. kodit/infrastructure/sqlalchemy/embedding_repository.py +47 -68
  34. kodit/infrastructure/sqlalchemy/entities.py +28 -2
  35. kodit/infrastructure/sqlalchemy/index_repository.py +274 -236
  36. kodit/infrastructure/sqlalchemy/task_repository.py +55 -39
  37. kodit/infrastructure/sqlalchemy/unit_of_work.py +59 -0
  38. kodit/log.py +6 -0
  39. kodit/mcp.py +10 -2
  40. {kodit-0.4.0.dist-info → kodit-0.4.2.dist-info}/METADATA +3 -2
  41. {kodit-0.4.0.dist-info → kodit-0.4.2.dist-info}/RECORD +44 -41
  42. kodit/domain/interfaces.py +0 -27
  43. kodit/infrastructure/embedding/embedding_providers/openai_embedding_provider.py +0 -183
  44. kodit/infrastructure/ui/__init__.py +0 -1
  45. kodit/infrastructure/ui/progress.py +0 -170
  46. kodit/infrastructure/ui/spinner.py +0 -74
  47. kodit/reporting.py +0 -78
  48. {kodit-0.4.0.dist-info → kodit-0.4.2.dist-info}/WHEEL +0 -0
  49. {kodit-0.4.0.dist-info → kodit-0.4.2.dist-info}/entry_points.txt +0 -0
  50. {kodit-0.4.0.dist-info → kodit-0.4.2.dist-info}/licenses/LICENSE +0 -0
@@ -1,6 +1,7 @@
1
1
  """SQLAlchemy implementation of IndexRepository using Index aggregate root."""
2
2
 
3
- from collections.abc import Sequence
3
+ from collections.abc import Callable, Sequence
4
+ from datetime import UTC, datetime
4
5
  from typing import cast
5
6
 
6
7
  from pydantic import AnyUrl
@@ -15,6 +16,15 @@ from kodit.domain.value_objects import (
15
16
  )
16
17
  from kodit.infrastructure.mappers.index_mapper import IndexMapper
17
18
  from kodit.infrastructure.sqlalchemy import entities as db_entities
19
+ from kodit.infrastructure.sqlalchemy.unit_of_work import SqlAlchemyUnitOfWork
20
+
21
+
22
+ def create_index_repository(
23
+ session_factory: Callable[[], AsyncSession],
24
+ ) -> IndexRepository:
25
+ """Create an index repository."""
26
+ uow = SqlAlchemyUnitOfWork(session_factory=session_factory)
27
+ return SqlAlchemyIndexRepository(uow)
18
28
 
19
29
 
20
30
  class SqlAlchemyIndexRepository(IndexRepository):
@@ -27,120 +37,134 @@ class SqlAlchemyIndexRepository(IndexRepository):
27
37
  - Snippet entities with their contents
28
38
  """
29
39
 
30
- def __init__(self, session: AsyncSession) -> None:
40
+ def __init__(self, uow: SqlAlchemyUnitOfWork) -> None:
31
41
  """Initialize the repository."""
32
- self._session = session
33
- self._mapper = IndexMapper(session)
42
+ self.uow = uow
43
+
44
+ @property
45
+ def _mapper(self) -> IndexMapper:
46
+ if self.uow.session is None:
47
+ raise RuntimeError("UnitOfWork must be used within async context")
48
+ return IndexMapper(self.uow.session)
49
+
50
+ @property
51
+ def _session(self) -> AsyncSession:
52
+ if self.uow.session is None:
53
+ raise RuntimeError("UnitOfWork must be used within async context")
54
+ return self.uow.session
34
55
 
35
56
  async def create(
36
57
  self, uri: AnyUrl, working_copy: domain_entities.WorkingCopy
37
58
  ) -> domain_entities.Index:
38
59
  """Create an index with all the files and authors in the working copy."""
39
- # 1. Verify that a source with this URI does not exist
40
- existing_source = await self._get_source_by_uri(uri)
41
- if existing_source:
42
- # Check if index already exists for this source
43
- existing_index = await self._get_index_by_source_id(existing_source.id)
44
- if existing_index:
45
- return await self._mapper.to_domain_index(existing_index)
46
-
47
- # 2. Create the source
48
- db_source = db_entities.Source(
49
- uri=str(uri),
50
- cloned_path=str(working_copy.cloned_path),
51
- source_type=db_entities.SourceType(working_copy.source_type.value),
52
- )
53
- self._session.add(db_source)
54
- await self._session.flush() # Get source ID
55
-
56
- # 3. Create a set of unique authors
57
- unique_authors = {}
58
- for domain_file in working_copy.files:
59
- for author in domain_file.authors:
60
- key = (author.name, author.email)
61
- if key not in unique_authors:
62
- unique_authors[key] = author
63
-
64
- # 4. Create authors if they don't exist and store their IDs
65
- author_id_map = {}
66
- for domain_author in unique_authors.values():
67
- db_author = await self._find_or_create_author(domain_author)
68
- author_id_map[(domain_author.name, domain_author.email)] = db_author.id
69
-
70
- # 5. Create files
71
- for domain_file in working_copy.files:
72
- db_file = db_entities.File(
73
- created_at=domain_file.created_at or db_source.created_at,
74
- updated_at=domain_file.updated_at or db_source.updated_at,
75
- source_id=db_source.id,
76
- mime_type=domain_file.mime_type,
77
- uri=str(domain_file.uri),
78
- cloned_path=str(domain_file.uri), # Use URI as cloned path
79
- sha256=domain_file.sha256,
80
- size_bytes=0, # Deprecated
81
- extension="", # Deprecated
82
- file_processing_status=domain_file.file_processing_status.value,
60
+ async with self.uow:
61
+ # 1. Verify that a source with this URI does not exist
62
+ existing_source = await self._get_source_by_uri(uri)
63
+ if existing_source:
64
+ # Check if index already exists for this source
65
+ existing_index = await self._get_index_by_source_id(existing_source.id)
66
+ if existing_index:
67
+ return await self._mapper.to_domain_index(existing_index)
68
+
69
+ # 2. Create the source
70
+ db_source = db_entities.Source(
71
+ uri=str(uri),
72
+ cloned_path=str(working_copy.cloned_path),
73
+ source_type=db_entities.SourceType(working_copy.source_type.value),
83
74
  )
84
- self._session.add(db_file)
85
- await self._session.flush() # Get file ID
86
-
87
- # 6. Create author_file_mappings
88
- for author in domain_file.authors:
89
- author_id = author_id_map[(author.name, author.email)]
90
- mapping = db_entities.AuthorFileMapping(
91
- author_id=author_id, file_id=db_file.id
75
+ self._session.add(db_source)
76
+ await self._session.flush() # Get source ID
77
+
78
+ # 3. Create a set of unique authors
79
+ unique_authors = {}
80
+ for domain_file in working_copy.files:
81
+ for author in domain_file.authors:
82
+ key = (author.name, author.email)
83
+ if key not in unique_authors:
84
+ unique_authors[key] = author
85
+
86
+ # 4. Create authors if they don't exist and store their IDs
87
+ author_id_map = {}
88
+ for domain_author in unique_authors.values():
89
+ db_author = await self._find_or_create_author(domain_author)
90
+ author_id_map[(domain_author.name, domain_author.email)] = db_author.id
91
+
92
+ # 5. Create files
93
+ for domain_file in working_copy.files:
94
+ db_file = db_entities.File(
95
+ created_at=domain_file.created_at or db_source.created_at,
96
+ updated_at=domain_file.updated_at or db_source.updated_at,
97
+ source_id=db_source.id,
98
+ mime_type=domain_file.mime_type,
99
+ uri=str(domain_file.uri),
100
+ cloned_path=str(domain_file.uri), # Use URI as cloned path
101
+ sha256=domain_file.sha256,
102
+ size_bytes=0, # Deprecated
103
+ extension="", # Deprecated
104
+ file_processing_status=domain_file.file_processing_status.value,
92
105
  )
93
- await self._upsert_author_file_mapping(mapping)
106
+ self._session.add(db_file)
107
+ await self._session.flush() # Get file ID
108
+
109
+ # 6. Create author_file_mappings
110
+ for author in domain_file.authors:
111
+ author_id = author_id_map[(author.name, author.email)]
112
+ mapping = db_entities.AuthorFileMapping(
113
+ author_id=author_id, file_id=db_file.id
114
+ )
115
+ await self._upsert_author_file_mapping(mapping)
94
116
 
95
- # 7. Create the index
96
- db_index = db_entities.Index(source_id=db_source.id)
97
- self._session.add(db_index)
98
- await self._session.flush() # Get index ID
117
+ # 7. Create the index
118
+ db_index = db_entities.Index(source_id=db_source.id)
119
+ self._session.add(db_index)
120
+ await self._session.flush() # Get index ID
99
121
 
100
- # 8. Return the new index
101
- return await self._mapper.to_domain_index(db_index)
122
+ # 8. Return the new index
123
+ return await self._mapper.to_domain_index(db_index)
102
124
 
103
125
  async def get(self, index_id: int) -> domain_entities.Index | None:
104
126
  """Get an index by ID."""
105
- db_index = await self._session.get(db_entities.Index, index_id)
106
- if not db_index:
107
- return None
127
+ async with self.uow:
128
+ db_index = await self._session.get(db_entities.Index, index_id)
129
+ if not db_index:
130
+ return None
108
131
 
109
- return await self._mapper.to_domain_index(db_index)
132
+ return await self._mapper.to_domain_index(db_index)
110
133
 
111
134
  async def get_by_uri(self, uri: AnyUrl) -> domain_entities.Index | None:
112
135
  """Get an index by source URI."""
113
- db_source = await self._get_source_by_uri(uri)
114
- if not db_source:
115
- return None
136
+ async with self.uow:
137
+ db_source = await self._get_source_by_uri(uri)
138
+ if not db_source:
139
+ return None
116
140
 
117
- db_index = await self._get_index_by_source_id(db_source.id)
118
- if not db_index:
119
- return None
141
+ db_index = await self._get_index_by_source_id(db_source.id)
142
+ if not db_index:
143
+ return None
120
144
 
121
- return await self._mapper.to_domain_index(db_index)
145
+ return await self._mapper.to_domain_index(db_index)
122
146
 
123
147
  async def all(self) -> list[domain_entities.Index]:
124
148
  """List all indexes."""
125
- stmt = select(db_entities.Index)
126
- result = await self._session.scalars(stmt)
127
- db_indexes = result.all()
149
+ async with self.uow:
150
+ stmt = select(db_entities.Index)
151
+ result = await self._session.scalars(stmt)
152
+ db_indexes = result.all()
128
153
 
129
- domain_indexes = []
130
- for db_index in db_indexes:
131
- domain_index = await self._mapper.to_domain_index(db_index)
132
- domain_indexes.append(domain_index)
154
+ domain_indexes = []
155
+ for db_index in db_indexes:
156
+ domain_index = await self._mapper.to_domain_index(db_index)
157
+ domain_indexes.append(domain_index)
133
158
 
134
- return domain_indexes
159
+ return domain_indexes
135
160
 
136
161
  async def update_index_timestamp(self, index_id: int) -> None:
137
162
  """Update the timestamp of an index."""
138
- from datetime import UTC, datetime
139
-
140
- db_index = await self._session.get(db_entities.Index, index_id)
141
- if db_index:
163
+ async with self.uow:
164
+ db_index = await self._session.get(db_entities.Index, index_id)
165
+ if not db_index:
166
+ raise ValueError(f"Index {index_id} not found")
142
167
  db_index.updated_at = datetime.now(UTC)
143
- # SQLAlchemy will automatically track this change
144
168
 
145
169
  async def add_snippets(
146
170
  self, index_id: int, snippets: list[domain_entities.Snippet]
@@ -152,17 +176,18 @@ class SqlAlchemyIndexRepository(IndexRepository):
152
176
  if not snippets:
153
177
  return
154
178
 
155
- # Validate the index exists
156
- db_index = await self._session.get(db_entities.Index, index_id)
157
- if not db_index:
158
- raise ValueError(f"Index {index_id} not found")
179
+ async with self.uow:
180
+ # Validate the index exists
181
+ db_index = await self._session.get(db_entities.Index, index_id)
182
+ if not db_index:
183
+ raise ValueError(f"Index {index_id} not found")
159
184
 
160
- # Convert domain snippets to database entities
161
- for domain_snippet in snippets:
162
- db_snippet = await self._mapper.from_domain_snippet(
163
- domain_snippet, index_id
164
- )
165
- self._session.add(db_snippet)
185
+ # Convert domain snippets to database entities
186
+ for domain_snippet in snippets:
187
+ db_snippet = await self._mapper.from_domain_snippet(
188
+ domain_snippet, index_id
189
+ )
190
+ self._session.add(db_snippet)
166
191
 
167
192
  async def update_snippets(
168
193
  self, index_id: int, snippets: list[domain_entities.Snippet]
@@ -175,27 +200,30 @@ class SqlAlchemyIndexRepository(IndexRepository):
175
200
  if not snippets:
176
201
  return
177
202
 
178
- # Validate the index exists
179
- db_index = await self._session.get(db_entities.Index, index_id)
180
- if not db_index:
181
- raise ValueError(f"Index {index_id} not found")
203
+ async with self.uow:
204
+ # Validate the index exists
205
+ db_index = await self._session.get(db_entities.Index, index_id)
206
+ if not db_index:
207
+ raise ValueError(f"Index {index_id} not found")
182
208
 
183
- # Update each snippet
184
- for domain_snippet in snippets:
185
- if not domain_snippet.id:
186
- raise ValueError("Snippet must have an ID for update")
209
+ # Update each snippet
210
+ for domain_snippet in snippets:
211
+ if not domain_snippet.id:
212
+ raise ValueError("Snippet must have an ID for update")
187
213
 
188
- # Get the existing snippet
189
- db_snippet = await self._session.get(db_entities.Snippet, domain_snippet.id)
190
- if not db_snippet:
191
- raise ValueError(f"Snippet {domain_snippet.id} not found")
214
+ # Get the existing snippet
215
+ db_snippet = await self._session.get(
216
+ db_entities.Snippet, domain_snippet.id
217
+ )
218
+ if not db_snippet:
219
+ raise ValueError(f"Snippet {domain_snippet.id} not found")
192
220
 
193
- db_snippet.content = domain_snippet.original_text()
194
- db_snippet.summary = domain_snippet.summary_text()
221
+ db_snippet.content = domain_snippet.original_text()
222
+ db_snippet.summary = domain_snippet.summary_text()
195
223
 
196
- # Update timestamps if provided
197
- if domain_snippet.updated_at:
198
- db_snippet.updated_at = domain_snippet.updated_at
224
+ # Update timestamps if provided
225
+ if domain_snippet.updated_at:
226
+ db_snippet.updated_at = domain_snippet.updated_at
199
227
 
200
228
  async def search( # noqa: C901
201
229
  self, request: MultiSearchRequest
@@ -258,71 +286,77 @@ class SqlAlchemyIndexRepository(IndexRepository):
258
286
  query = query.limit(request.top_k)
259
287
 
260
288
  # Execute query
261
- result = await self._session.scalars(query)
262
- db_snippets = result.all()
263
-
264
- # Convert to SnippetWithContext
265
- snippet_contexts = []
266
- for db_snippet in db_snippets:
267
- # Get the file for this snippet
268
- db_file = await self._session.get(db_entities.File, db_snippet.file_id)
269
- if not db_file:
270
- continue
271
-
272
- # Get the source for this file
273
- db_source = await self._session.get(db_entities.Source, db_file.source_id)
274
- if not db_source:
275
- continue
276
-
277
- domain_file = await self._mapper.to_domain_file(db_file)
278
- snippet_context = SnippetWithContext(
279
- source=await self._mapper.to_domain_source(db_source),
280
- file=domain_file,
281
- authors=domain_file.authors,
282
- snippet=await self._mapper.to_domain_snippet(
283
- db_snippet=db_snippet, domain_files=[domain_file]
284
- ),
285
- )
286
- snippet_contexts.append(snippet_context)
289
+ async with self.uow:
290
+ result = await self._session.scalars(query)
291
+ db_snippets = result.all()
292
+
293
+ # Convert to SnippetWithContext
294
+ snippet_contexts = []
295
+ for db_snippet in db_snippets:
296
+ # Get the file for this snippet
297
+ db_file = await self._session.get(db_entities.File, db_snippet.file_id)
298
+ if not db_file:
299
+ continue
300
+
301
+ # Get the source for this file
302
+ db_source = await self._session.get(
303
+ db_entities.Source, db_file.source_id
304
+ )
305
+ if not db_source:
306
+ continue
307
+
308
+ domain_file = await self._mapper.to_domain_file(db_file)
309
+ snippet_context = SnippetWithContext(
310
+ source=await self._mapper.to_domain_source(db_source),
311
+ file=domain_file,
312
+ authors=domain_file.authors,
313
+ snippet=await self._mapper.to_domain_snippet(
314
+ db_snippet=db_snippet, domain_files=[domain_file]
315
+ ),
316
+ )
317
+ snippet_contexts.append(snippet_context)
287
318
 
288
- return snippet_contexts
319
+ return snippet_contexts
289
320
 
290
321
  async def get_snippets_by_ids(self, ids: list[int]) -> list[SnippetWithContext]:
291
322
  """Get snippets by their IDs."""
292
323
  if not ids:
293
324
  return []
294
325
 
295
- # Query snippets by IDs
296
- query = select(db_entities.Snippet).where(db_entities.Snippet.id.in_(ids))
326
+ async with self.uow:
327
+ # Query snippets by IDs
328
+ query = select(db_entities.Snippet).where(db_entities.Snippet.id.in_(ids))
297
329
 
298
- result = await self._session.scalars(query)
299
- db_snippets = result.all()
330
+ result = await self._session.scalars(query)
331
+ db_snippets = result.all()
300
332
 
301
- # Convert to SnippetWithContext using similar logic as search
302
- snippet_contexts = []
303
- for db_snippet in db_snippets:
304
- # Get the file for this snippet
305
- db_file = await self._session.get(db_entities.File, db_snippet.file_id)
306
- if not db_file:
307
- continue
333
+ # Convert to SnippetWithContext using similar logic as search
334
+ snippet_contexts = []
335
+ for db_snippet in db_snippets:
336
+ # Get the file for this snippet
337
+ db_file = await self._session.get(db_entities.File, db_snippet.file_id)
338
+ if not db_file:
339
+ continue
308
340
 
309
- # Get the source for this file
310
- db_source = await self._session.get(db_entities.Source, db_file.source_id)
311
- if not db_source:
312
- continue
313
-
314
- domain_file = await self._mapper.to_domain_file(db_file)
315
- snippet_context = SnippetWithContext(
316
- source=await self._mapper.to_domain_source(db_source),
317
- file=domain_file,
318
- authors=domain_file.authors,
319
- snippet=await self._mapper.to_domain_snippet(
320
- db_snippet=db_snippet, domain_files=[domain_file]
321
- ),
322
- )
323
- snippet_contexts.append(snippet_context)
341
+ # Get the source for this file
342
+ db_source = await self._session.get(
343
+ db_entities.Source, db_file.source_id
344
+ )
345
+ if not db_source:
346
+ continue
347
+
348
+ domain_file = await self._mapper.to_domain_file(db_file)
349
+ snippet_context = SnippetWithContext(
350
+ source=await self._mapper.to_domain_source(db_source),
351
+ file=domain_file,
352
+ authors=domain_file.authors,
353
+ snippet=await self._mapper.to_domain_snippet(
354
+ db_snippet=db_snippet, domain_files=[domain_file]
355
+ ),
356
+ )
357
+ snippet_contexts.append(snippet_context)
324
358
 
325
- return snippet_contexts
359
+ return snippet_contexts
326
360
 
327
361
  async def _get_source_by_uri(self, uri: AnyUrl) -> db_entities.Source | None:
328
362
  """Get source by URI."""
@@ -379,25 +413,26 @@ class SqlAlchemyIndexRepository(IndexRepository):
379
413
 
380
414
  async def delete_snippets(self, index_id: int) -> None:
381
415
  """Delete all snippets from an index."""
382
- # First get all snippets for this index
383
- stmt = select(db_entities.Snippet).where(
384
- db_entities.Snippet.index_id == index_id
385
- )
386
- result = await self._session.scalars(stmt)
387
- snippets = result.all()
388
-
389
- # Delete all embeddings for these snippets
390
- for snippet in snippets:
391
- embedding_stmt = delete(db_entities.Embedding).where(
392
- db_entities.Embedding.snippet_id == snippet.id
416
+ async with self.uow:
417
+ # First get all snippets for this index
418
+ stmt = select(db_entities.Snippet).where(
419
+ db_entities.Snippet.index_id == index_id
393
420
  )
394
- await self._session.execute(embedding_stmt)
421
+ result = await self._session.scalars(stmt)
422
+ snippets = result.all()
395
423
 
396
- # Now delete the snippets
397
- snippet_stmt = delete(db_entities.Snippet).where(
398
- db_entities.Snippet.index_id == index_id
399
- )
400
- await self._session.execute(snippet_stmt)
424
+ # Delete all embeddings for these snippets
425
+ for snippet in snippets:
426
+ embedding_stmt = delete(db_entities.Embedding).where(
427
+ db_entities.Embedding.snippet_id == snippet.id
428
+ )
429
+ await self._session.execute(embedding_stmt)
430
+
431
+ # Now delete the snippets
432
+ snippet_stmt = delete(db_entities.Snippet).where(
433
+ db_entities.Snippet.index_id == index_id
434
+ )
435
+ await self._session.execute(snippet_stmt)
401
436
 
402
437
  async def delete_snippets_by_file_ids(self, file_ids: list[int]) -> None:
403
438
  """Delete snippets by file IDs.
@@ -408,50 +443,52 @@ class SqlAlchemyIndexRepository(IndexRepository):
408
443
  if not file_ids:
409
444
  return
410
445
 
411
- # First get all snippets for these files
412
- stmt = select(db_entities.Snippet).where(
413
- db_entities.Snippet.file_id.in_(file_ids)
414
- )
415
- result = await self._session.scalars(stmt)
416
- snippets = result.all()
417
-
418
- # Delete all embeddings for these snippets
419
- for snippet in snippets:
420
- embedding_stmt = delete(db_entities.Embedding).where(
421
- db_entities.Embedding.snippet_id == snippet.id
446
+ async with self.uow:
447
+ # First get all snippets for these files
448
+ stmt = select(db_entities.Snippet).where(
449
+ db_entities.Snippet.file_id.in_(file_ids)
422
450
  )
423
- await self._session.execute(embedding_stmt)
451
+ result = await self._session.scalars(stmt)
452
+ snippets = result.all()
424
453
 
425
- # Now delete the snippets
426
- snippet_stmt = delete(db_entities.Snippet).where(
427
- db_entities.Snippet.file_id.in_(file_ids)
428
- )
429
- await self._session.execute(snippet_stmt)
454
+ # Delete all embeddings for these snippets
455
+ for snippet in snippets:
456
+ embedding_stmt = delete(db_entities.Embedding).where(
457
+ db_entities.Embedding.snippet_id == snippet.id
458
+ )
459
+ await self._session.execute(embedding_stmt)
460
+
461
+ # Now delete the snippets
462
+ snippet_stmt = delete(db_entities.Snippet).where(
463
+ db_entities.Snippet.file_id.in_(file_ids)
464
+ )
465
+ await self._session.execute(snippet_stmt)
430
466
 
431
467
  async def update(self, index: domain_entities.Index) -> None:
432
468
  """Update an index by ensuring all domain objects are saved to database."""
433
469
  if not index.id:
434
470
  raise ValueError("Index must have an ID to be updated")
435
471
 
436
- # 1. Verify the index exists in the database
437
- db_index = await self._session.get(db_entities.Index, index.id)
438
- if not db_index:
439
- raise ValueError(f"Index {index.id} not found")
472
+ async with self.uow:
473
+ # 1. Verify the index exists in the database
474
+ db_index = await self._session.get(db_entities.Index, index.id)
475
+ if not db_index:
476
+ raise ValueError(f"Index {index.id} not found")
440
477
 
441
- # 2. Update index timestamps
442
- if index.updated_at:
443
- db_index.updated_at = index.updated_at
478
+ # 2. Update index timestamps
479
+ if index.updated_at:
480
+ db_index.updated_at = index.updated_at
444
481
 
445
- # 3. Update source if it exists
446
- await self._update_source(index, db_index)
482
+ # 3. Update source if it exists
483
+ await self._update_source(index, db_index)
447
484
 
448
- # 4. Handle files and authors from working copy
449
- if index.source and index.source.working_copy:
450
- await self._update_files_and_authors(index, db_index)
485
+ # 4. Handle files and authors from working copy
486
+ if index.source and index.source.working_copy:
487
+ await self._update_files_and_authors(index, db_index)
451
488
 
452
- # 5. Handle snippets
453
- if index.snippets:
454
- await self._update_snippets(index)
489
+ # 5. Handle snippets
490
+ if index.snippets:
491
+ await self._update_snippets(index)
455
492
 
456
493
  async def _update_source(
457
494
  self, index: domain_entities.Index, db_index: db_entities.Index
@@ -583,26 +620,27 @@ class SqlAlchemyIndexRepository(IndexRepository):
583
620
  # Delete all snippets and embeddings
584
621
  await self.delete_snippets(index.id)
585
622
 
586
- # Delete all author file mappings
587
- stmt = delete(db_entities.AuthorFileMapping).where(
588
- db_entities.AuthorFileMapping.file_id.in_(
589
- [file.id for file in index.source.working_copy.files]
623
+ async with self.uow:
624
+ # Delete all author file mappings
625
+ stmt = delete(db_entities.AuthorFileMapping).where(
626
+ db_entities.AuthorFileMapping.file_id.in_(
627
+ [file.id for file in index.source.working_copy.files]
628
+ )
590
629
  )
591
- )
592
- await self._session.execute(stmt)
630
+ await self._session.execute(stmt)
593
631
 
594
- # Delete all files
595
- stmt = delete(db_entities.File).where(
596
- db_entities.File.source_id == index.source.id
597
- )
598
- await self._session.execute(stmt)
632
+ # Delete all files
633
+ stmt = delete(db_entities.File).where(
634
+ db_entities.File.source_id == index.source.id
635
+ )
636
+ await self._session.execute(stmt)
599
637
 
600
- # Delete the index
601
- stmt = delete(db_entities.Index).where(db_entities.Index.id == index.id)
602
- await self._session.execute(stmt)
638
+ # Delete the index
639
+ stmt = delete(db_entities.Index).where(db_entities.Index.id == index.id)
640
+ await self._session.execute(stmt)
603
641
 
604
- # Delete the source
605
- stmt = delete(db_entities.Source).where(
606
- db_entities.Source.id == index.source.id
607
- )
608
- await self._session.execute(stmt)
642
+ # Delete the source
643
+ stmt = delete(db_entities.Source).where(
644
+ db_entities.Source.id == index.source.id
645
+ )
646
+ await self._session.execute(stmt)