kodit-0.4.1-py3-none-any.whl → kodit-0.4.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kodit has been flagged as potentially problematic; see the registry's advisory for this release for more details.
- kodit/_version.py +2 -2
- kodit/app.py +9 -2
- kodit/application/factories/code_indexing_factory.py +62 -13
- kodit/application/factories/reporting_factory.py +32 -0
- kodit/application/services/auto_indexing_service.py +41 -33
- kodit/application/services/code_indexing_application_service.py +137 -138
- kodit/application/services/indexing_worker_service.py +26 -30
- kodit/application/services/queue_service.py +12 -14
- kodit/application/services/reporting.py +104 -0
- kodit/application/services/sync_scheduler.py +21 -20
- kodit/cli.py +71 -85
- kodit/config.py +26 -3
- kodit/database.py +2 -1
- kodit/domain/entities.py +99 -1
- kodit/domain/protocols.py +34 -1
- kodit/domain/services/bm25_service.py +1 -6
- kodit/domain/services/index_service.py +23 -57
- kodit/domain/services/task_status_query_service.py +19 -0
- kodit/domain/value_objects.py +53 -8
- kodit/infrastructure/api/v1/dependencies.py +40 -12
- kodit/infrastructure/api/v1/routers/indexes.py +45 -0
- kodit/infrastructure/api/v1/schemas/task_status.py +39 -0
- kodit/infrastructure/cloning/git/working_copy.py +43 -7
- kodit/infrastructure/embedding/embedding_factory.py +8 -3
- kodit/infrastructure/embedding/embedding_providers/litellm_embedding_provider.py +48 -55
- kodit/infrastructure/enrichment/local_enrichment_provider.py +41 -30
- kodit/infrastructure/git/git_utils.py +3 -2
- kodit/infrastructure/mappers/index_mapper.py +1 -0
- kodit/infrastructure/mappers/task_status_mapper.py +85 -0
- kodit/infrastructure/reporting/__init__.py +1 -0
- kodit/infrastructure/reporting/db_progress.py +23 -0
- kodit/infrastructure/reporting/log_progress.py +37 -0
- kodit/infrastructure/reporting/tdqm_progress.py +38 -0
- kodit/infrastructure/sqlalchemy/embedding_repository.py +47 -68
- kodit/infrastructure/sqlalchemy/entities.py +89 -2
- kodit/infrastructure/sqlalchemy/index_repository.py +274 -236
- kodit/infrastructure/sqlalchemy/task_repository.py +55 -39
- kodit/infrastructure/sqlalchemy/task_status_repository.py +79 -0
- kodit/infrastructure/sqlalchemy/unit_of_work.py +59 -0
- kodit/mcp.py +15 -3
- kodit/migrations/env.py +0 -1
- kodit/migrations/versions/b9cd1c3fd762_add_task_status.py +77 -0
- {kodit-0.4.1.dist-info → kodit-0.4.3.dist-info}/METADATA +1 -1
- {kodit-0.4.1.dist-info → kodit-0.4.3.dist-info}/RECORD +47 -40
- kodit/domain/interfaces.py +0 -27
- kodit/infrastructure/ui/__init__.py +0 -1
- kodit/infrastructure/ui/progress.py +0 -170
- kodit/infrastructure/ui/spinner.py +0 -74
- kodit/reporting.py +0 -78
- {kodit-0.4.1.dist-info → kodit-0.4.3.dist-info}/WHEEL +0 -0
- {kodit-0.4.1.dist-info → kodit-0.4.3.dist-info}/entry_points.txt +0 -0
- {kodit-0.4.1.dist-info → kodit-0.4.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""SQLAlchemy implementation of IndexRepository using Index aggregate root."""
|
|
2
2
|
|
|
3
|
-
from collections.abc import Sequence
|
|
3
|
+
from collections.abc import Callable, Sequence
|
|
4
|
+
from datetime import UTC, datetime
|
|
4
5
|
from typing import cast
|
|
5
6
|
|
|
6
7
|
from pydantic import AnyUrl
|
|
@@ -15,6 +16,15 @@ from kodit.domain.value_objects import (
|
|
|
15
16
|
)
|
|
16
17
|
from kodit.infrastructure.mappers.index_mapper import IndexMapper
|
|
17
18
|
from kodit.infrastructure.sqlalchemy import entities as db_entities
|
|
19
|
+
from kodit.infrastructure.sqlalchemy.unit_of_work import SqlAlchemyUnitOfWork
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def create_index_repository(
|
|
23
|
+
session_factory: Callable[[], AsyncSession],
|
|
24
|
+
) -> IndexRepository:
|
|
25
|
+
"""Create an index repository."""
|
|
26
|
+
uow = SqlAlchemyUnitOfWork(session_factory=session_factory)
|
|
27
|
+
return SqlAlchemyIndexRepository(uow)
|
|
18
28
|
|
|
19
29
|
|
|
20
30
|
class SqlAlchemyIndexRepository(IndexRepository):
|
|
@@ -27,120 +37,134 @@ class SqlAlchemyIndexRepository(IndexRepository):
|
|
|
27
37
|
- Snippet entities with their contents
|
|
28
38
|
"""
|
|
29
39
|
|
|
30
|
-
def __init__(self,
|
|
40
|
+
def __init__(self, uow: SqlAlchemyUnitOfWork) -> None:
|
|
31
41
|
"""Initialize the repository."""
|
|
32
|
-
self.
|
|
33
|
-
|
|
42
|
+
self.uow = uow
|
|
43
|
+
|
|
44
|
+
@property
|
|
45
|
+
def _mapper(self) -> IndexMapper:
|
|
46
|
+
if self.uow.session is None:
|
|
47
|
+
raise RuntimeError("UnitOfWork must be used within async context")
|
|
48
|
+
return IndexMapper(self.uow.session)
|
|
49
|
+
|
|
50
|
+
@property
|
|
51
|
+
def _session(self) -> AsyncSession:
|
|
52
|
+
if self.uow.session is None:
|
|
53
|
+
raise RuntimeError("UnitOfWork must be used within async context")
|
|
54
|
+
return self.uow.session
|
|
34
55
|
|
|
35
56
|
async def create(
|
|
36
57
|
self, uri: AnyUrl, working_copy: domain_entities.WorkingCopy
|
|
37
58
|
) -> domain_entities.Index:
|
|
38
59
|
"""Create an index with all the files and authors in the working copy."""
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
self._session.add(db_source)
|
|
54
|
-
await self._session.flush() # Get source ID
|
|
55
|
-
|
|
56
|
-
# 3. Create a set of unique authors
|
|
57
|
-
unique_authors = {}
|
|
58
|
-
for domain_file in working_copy.files:
|
|
59
|
-
for author in domain_file.authors:
|
|
60
|
-
key = (author.name, author.email)
|
|
61
|
-
if key not in unique_authors:
|
|
62
|
-
unique_authors[key] = author
|
|
63
|
-
|
|
64
|
-
# 4. Create authors if they don't exist and store their IDs
|
|
65
|
-
author_id_map = {}
|
|
66
|
-
for domain_author in unique_authors.values():
|
|
67
|
-
db_author = await self._find_or_create_author(domain_author)
|
|
68
|
-
author_id_map[(domain_author.name, domain_author.email)] = db_author.id
|
|
69
|
-
|
|
70
|
-
# 5. Create files
|
|
71
|
-
for domain_file in working_copy.files:
|
|
72
|
-
db_file = db_entities.File(
|
|
73
|
-
created_at=domain_file.created_at or db_source.created_at,
|
|
74
|
-
updated_at=domain_file.updated_at or db_source.updated_at,
|
|
75
|
-
source_id=db_source.id,
|
|
76
|
-
mime_type=domain_file.mime_type,
|
|
77
|
-
uri=str(domain_file.uri),
|
|
78
|
-
cloned_path=str(domain_file.uri), # Use URI as cloned path
|
|
79
|
-
sha256=domain_file.sha256,
|
|
80
|
-
size_bytes=0, # Deprecated
|
|
81
|
-
extension="", # Deprecated
|
|
82
|
-
file_processing_status=domain_file.file_processing_status.value,
|
|
60
|
+
async with self.uow:
|
|
61
|
+
# 1. Verify that a source with this URI does not exist
|
|
62
|
+
existing_source = await self._get_source_by_uri(uri)
|
|
63
|
+
if existing_source:
|
|
64
|
+
# Check if index already exists for this source
|
|
65
|
+
existing_index = await self._get_index_by_source_id(existing_source.id)
|
|
66
|
+
if existing_index:
|
|
67
|
+
return await self._mapper.to_domain_index(existing_index)
|
|
68
|
+
|
|
69
|
+
# 2. Create the source
|
|
70
|
+
db_source = db_entities.Source(
|
|
71
|
+
uri=str(uri),
|
|
72
|
+
cloned_path=str(working_copy.cloned_path),
|
|
73
|
+
source_type=db_entities.SourceType(working_copy.source_type.value),
|
|
83
74
|
)
|
|
84
|
-
self._session.add(
|
|
85
|
-
await self._session.flush() # Get
|
|
86
|
-
|
|
87
|
-
#
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
75
|
+
self._session.add(db_source)
|
|
76
|
+
await self._session.flush() # Get source ID
|
|
77
|
+
|
|
78
|
+
# 3. Create a set of unique authors
|
|
79
|
+
unique_authors = {}
|
|
80
|
+
for domain_file in working_copy.files:
|
|
81
|
+
for author in domain_file.authors:
|
|
82
|
+
key = (author.name, author.email)
|
|
83
|
+
if key not in unique_authors:
|
|
84
|
+
unique_authors[key] = author
|
|
85
|
+
|
|
86
|
+
# 4. Create authors if they don't exist and store their IDs
|
|
87
|
+
author_id_map = {}
|
|
88
|
+
for domain_author in unique_authors.values():
|
|
89
|
+
db_author = await self._find_or_create_author(domain_author)
|
|
90
|
+
author_id_map[(domain_author.name, domain_author.email)] = db_author.id
|
|
91
|
+
|
|
92
|
+
# 5. Create files
|
|
93
|
+
for domain_file in working_copy.files:
|
|
94
|
+
db_file = db_entities.File(
|
|
95
|
+
created_at=domain_file.created_at or db_source.created_at,
|
|
96
|
+
updated_at=domain_file.updated_at or db_source.updated_at,
|
|
97
|
+
source_id=db_source.id,
|
|
98
|
+
mime_type=domain_file.mime_type,
|
|
99
|
+
uri=str(domain_file.uri),
|
|
100
|
+
cloned_path=str(domain_file.uri), # Use URI as cloned path
|
|
101
|
+
sha256=domain_file.sha256,
|
|
102
|
+
size_bytes=0, # Deprecated
|
|
103
|
+
extension="", # Deprecated
|
|
104
|
+
file_processing_status=domain_file.file_processing_status.value,
|
|
92
105
|
)
|
|
93
|
-
|
|
106
|
+
self._session.add(db_file)
|
|
107
|
+
await self._session.flush() # Get file ID
|
|
108
|
+
|
|
109
|
+
# 6. Create author_file_mappings
|
|
110
|
+
for author in domain_file.authors:
|
|
111
|
+
author_id = author_id_map[(author.name, author.email)]
|
|
112
|
+
mapping = db_entities.AuthorFileMapping(
|
|
113
|
+
author_id=author_id, file_id=db_file.id
|
|
114
|
+
)
|
|
115
|
+
await self._upsert_author_file_mapping(mapping)
|
|
94
116
|
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
117
|
+
# 7. Create the index
|
|
118
|
+
db_index = db_entities.Index(source_id=db_source.id)
|
|
119
|
+
self._session.add(db_index)
|
|
120
|
+
await self._session.flush() # Get index ID
|
|
99
121
|
|
|
100
|
-
|
|
101
|
-
|
|
122
|
+
# 8. Return the new index
|
|
123
|
+
return await self._mapper.to_domain_index(db_index)
|
|
102
124
|
|
|
103
125
|
async def get(self, index_id: int) -> domain_entities.Index | None:
|
|
104
126
|
"""Get an index by ID."""
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
127
|
+
async with self.uow:
|
|
128
|
+
db_index = await self._session.get(db_entities.Index, index_id)
|
|
129
|
+
if not db_index:
|
|
130
|
+
return None
|
|
108
131
|
|
|
109
|
-
|
|
132
|
+
return await self._mapper.to_domain_index(db_index)
|
|
110
133
|
|
|
111
134
|
async def get_by_uri(self, uri: AnyUrl) -> domain_entities.Index | None:
|
|
112
135
|
"""Get an index by source URI."""
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
136
|
+
async with self.uow:
|
|
137
|
+
db_source = await self._get_source_by_uri(uri)
|
|
138
|
+
if not db_source:
|
|
139
|
+
return None
|
|
116
140
|
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
141
|
+
db_index = await self._get_index_by_source_id(db_source.id)
|
|
142
|
+
if not db_index:
|
|
143
|
+
return None
|
|
120
144
|
|
|
121
|
-
|
|
145
|
+
return await self._mapper.to_domain_index(db_index)
|
|
122
146
|
|
|
123
147
|
async def all(self) -> list[domain_entities.Index]:
|
|
124
148
|
"""List all indexes."""
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
149
|
+
async with self.uow:
|
|
150
|
+
stmt = select(db_entities.Index)
|
|
151
|
+
result = await self._session.scalars(stmt)
|
|
152
|
+
db_indexes = result.all()
|
|
128
153
|
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
154
|
+
domain_indexes = []
|
|
155
|
+
for db_index in db_indexes:
|
|
156
|
+
domain_index = await self._mapper.to_domain_index(db_index)
|
|
157
|
+
domain_indexes.append(domain_index)
|
|
133
158
|
|
|
134
|
-
|
|
159
|
+
return domain_indexes
|
|
135
160
|
|
|
136
161
|
async def update_index_timestamp(self, index_id: int) -> None:
|
|
137
162
|
"""Update the timestamp of an index."""
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
163
|
+
async with self.uow:
|
|
164
|
+
db_index = await self._session.get(db_entities.Index, index_id)
|
|
165
|
+
if not db_index:
|
|
166
|
+
raise ValueError(f"Index {index_id} not found")
|
|
142
167
|
db_index.updated_at = datetime.now(UTC)
|
|
143
|
-
# SQLAlchemy will automatically track this change
|
|
144
168
|
|
|
145
169
|
async def add_snippets(
|
|
146
170
|
self, index_id: int, snippets: list[domain_entities.Snippet]
|
|
@@ -152,17 +176,18 @@ class SqlAlchemyIndexRepository(IndexRepository):
|
|
|
152
176
|
if not snippets:
|
|
153
177
|
return
|
|
154
178
|
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
179
|
+
async with self.uow:
|
|
180
|
+
# Validate the index exists
|
|
181
|
+
db_index = await self._session.get(db_entities.Index, index_id)
|
|
182
|
+
if not db_index:
|
|
183
|
+
raise ValueError(f"Index {index_id} not found")
|
|
159
184
|
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
185
|
+
# Convert domain snippets to database entities
|
|
186
|
+
for domain_snippet in snippets:
|
|
187
|
+
db_snippet = await self._mapper.from_domain_snippet(
|
|
188
|
+
domain_snippet, index_id
|
|
189
|
+
)
|
|
190
|
+
self._session.add(db_snippet)
|
|
166
191
|
|
|
167
192
|
async def update_snippets(
|
|
168
193
|
self, index_id: int, snippets: list[domain_entities.Snippet]
|
|
@@ -175,27 +200,30 @@ class SqlAlchemyIndexRepository(IndexRepository):
|
|
|
175
200
|
if not snippets:
|
|
176
201
|
return
|
|
177
202
|
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
203
|
+
async with self.uow:
|
|
204
|
+
# Validate the index exists
|
|
205
|
+
db_index = await self._session.get(db_entities.Index, index_id)
|
|
206
|
+
if not db_index:
|
|
207
|
+
raise ValueError(f"Index {index_id} not found")
|
|
182
208
|
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
209
|
+
# Update each snippet
|
|
210
|
+
for domain_snippet in snippets:
|
|
211
|
+
if not domain_snippet.id:
|
|
212
|
+
raise ValueError("Snippet must have an ID for update")
|
|
187
213
|
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
214
|
+
# Get the existing snippet
|
|
215
|
+
db_snippet = await self._session.get(
|
|
216
|
+
db_entities.Snippet, domain_snippet.id
|
|
217
|
+
)
|
|
218
|
+
if not db_snippet:
|
|
219
|
+
raise ValueError(f"Snippet {domain_snippet.id} not found")
|
|
192
220
|
|
|
193
|
-
|
|
194
|
-
|
|
221
|
+
db_snippet.content = domain_snippet.original_text()
|
|
222
|
+
db_snippet.summary = domain_snippet.summary_text()
|
|
195
223
|
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
224
|
+
# Update timestamps if provided
|
|
225
|
+
if domain_snippet.updated_at:
|
|
226
|
+
db_snippet.updated_at = domain_snippet.updated_at
|
|
199
227
|
|
|
200
228
|
async def search( # noqa: C901
|
|
201
229
|
self, request: MultiSearchRequest
|
|
@@ -258,71 +286,77 @@ class SqlAlchemyIndexRepository(IndexRepository):
|
|
|
258
286
|
query = query.limit(request.top_k)
|
|
259
287
|
|
|
260
288
|
# Execute query
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
289
|
+
async with self.uow:
|
|
290
|
+
result = await self._session.scalars(query)
|
|
291
|
+
db_snippets = result.all()
|
|
292
|
+
|
|
293
|
+
# Convert to SnippetWithContext
|
|
294
|
+
snippet_contexts = []
|
|
295
|
+
for db_snippet in db_snippets:
|
|
296
|
+
# Get the file for this snippet
|
|
297
|
+
db_file = await self._session.get(db_entities.File, db_snippet.file_id)
|
|
298
|
+
if not db_file:
|
|
299
|
+
continue
|
|
300
|
+
|
|
301
|
+
# Get the source for this file
|
|
302
|
+
db_source = await self._session.get(
|
|
303
|
+
db_entities.Source, db_file.source_id
|
|
304
|
+
)
|
|
305
|
+
if not db_source:
|
|
306
|
+
continue
|
|
307
|
+
|
|
308
|
+
domain_file = await self._mapper.to_domain_file(db_file)
|
|
309
|
+
snippet_context = SnippetWithContext(
|
|
310
|
+
source=await self._mapper.to_domain_source(db_source),
|
|
311
|
+
file=domain_file,
|
|
312
|
+
authors=domain_file.authors,
|
|
313
|
+
snippet=await self._mapper.to_domain_snippet(
|
|
314
|
+
db_snippet=db_snippet, domain_files=[domain_file]
|
|
315
|
+
),
|
|
316
|
+
)
|
|
317
|
+
snippet_contexts.append(snippet_context)
|
|
287
318
|
|
|
288
|
-
|
|
319
|
+
return snippet_contexts
|
|
289
320
|
|
|
290
321
|
async def get_snippets_by_ids(self, ids: list[int]) -> list[SnippetWithContext]:
|
|
291
322
|
"""Get snippets by their IDs."""
|
|
292
323
|
if not ids:
|
|
293
324
|
return []
|
|
294
325
|
|
|
295
|
-
|
|
296
|
-
|
|
326
|
+
async with self.uow:
|
|
327
|
+
# Query snippets by IDs
|
|
328
|
+
query = select(db_entities.Snippet).where(db_entities.Snippet.id.in_(ids))
|
|
297
329
|
|
|
298
|
-
|
|
299
|
-
|
|
330
|
+
result = await self._session.scalars(query)
|
|
331
|
+
db_snippets = result.all()
|
|
300
332
|
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
333
|
+
# Convert to SnippetWithContext using similar logic as search
|
|
334
|
+
snippet_contexts = []
|
|
335
|
+
for db_snippet in db_snippets:
|
|
336
|
+
# Get the file for this snippet
|
|
337
|
+
db_file = await self._session.get(db_entities.File, db_snippet.file_id)
|
|
338
|
+
if not db_file:
|
|
339
|
+
continue
|
|
308
340
|
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
341
|
+
# Get the source for this file
|
|
342
|
+
db_source = await self._session.get(
|
|
343
|
+
db_entities.Source, db_file.source_id
|
|
344
|
+
)
|
|
345
|
+
if not db_source:
|
|
346
|
+
continue
|
|
347
|
+
|
|
348
|
+
domain_file = await self._mapper.to_domain_file(db_file)
|
|
349
|
+
snippet_context = SnippetWithContext(
|
|
350
|
+
source=await self._mapper.to_domain_source(db_source),
|
|
351
|
+
file=domain_file,
|
|
352
|
+
authors=domain_file.authors,
|
|
353
|
+
snippet=await self._mapper.to_domain_snippet(
|
|
354
|
+
db_snippet=db_snippet, domain_files=[domain_file]
|
|
355
|
+
),
|
|
356
|
+
)
|
|
357
|
+
snippet_contexts.append(snippet_context)
|
|
324
358
|
|
|
325
|
-
|
|
359
|
+
return snippet_contexts
|
|
326
360
|
|
|
327
361
|
async def _get_source_by_uri(self, uri: AnyUrl) -> db_entities.Source | None:
|
|
328
362
|
"""Get source by URI."""
|
|
@@ -379,25 +413,26 @@ class SqlAlchemyIndexRepository(IndexRepository):
|
|
|
379
413
|
|
|
380
414
|
async def delete_snippets(self, index_id: int) -> None:
|
|
381
415
|
"""Delete all snippets from an index."""
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
db_entities.Snippet.
|
|
385
|
-
|
|
386
|
-
result = await self._session.scalars(stmt)
|
|
387
|
-
snippets = result.all()
|
|
388
|
-
|
|
389
|
-
# Delete all embeddings for these snippets
|
|
390
|
-
for snippet in snippets:
|
|
391
|
-
embedding_stmt = delete(db_entities.Embedding).where(
|
|
392
|
-
db_entities.Embedding.snippet_id == snippet.id
|
|
416
|
+
async with self.uow:
|
|
417
|
+
# First get all snippets for this index
|
|
418
|
+
stmt = select(db_entities.Snippet).where(
|
|
419
|
+
db_entities.Snippet.index_id == index_id
|
|
393
420
|
)
|
|
394
|
-
await self._session.
|
|
421
|
+
result = await self._session.scalars(stmt)
|
|
422
|
+
snippets = result.all()
|
|
395
423
|
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
424
|
+
# Delete all embeddings for these snippets
|
|
425
|
+
for snippet in snippets:
|
|
426
|
+
embedding_stmt = delete(db_entities.Embedding).where(
|
|
427
|
+
db_entities.Embedding.snippet_id == snippet.id
|
|
428
|
+
)
|
|
429
|
+
await self._session.execute(embedding_stmt)
|
|
430
|
+
|
|
431
|
+
# Now delete the snippets
|
|
432
|
+
snippet_stmt = delete(db_entities.Snippet).where(
|
|
433
|
+
db_entities.Snippet.index_id == index_id
|
|
434
|
+
)
|
|
435
|
+
await self._session.execute(snippet_stmt)
|
|
401
436
|
|
|
402
437
|
async def delete_snippets_by_file_ids(self, file_ids: list[int]) -> None:
|
|
403
438
|
"""Delete snippets by file IDs.
|
|
@@ -408,50 +443,52 @@ class SqlAlchemyIndexRepository(IndexRepository):
|
|
|
408
443
|
if not file_ids:
|
|
409
444
|
return
|
|
410
445
|
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
db_entities.Snippet.
|
|
414
|
-
|
|
415
|
-
result = await self._session.scalars(stmt)
|
|
416
|
-
snippets = result.all()
|
|
417
|
-
|
|
418
|
-
# Delete all embeddings for these snippets
|
|
419
|
-
for snippet in snippets:
|
|
420
|
-
embedding_stmt = delete(db_entities.Embedding).where(
|
|
421
|
-
db_entities.Embedding.snippet_id == snippet.id
|
|
446
|
+
async with self.uow:
|
|
447
|
+
# First get all snippets for these files
|
|
448
|
+
stmt = select(db_entities.Snippet).where(
|
|
449
|
+
db_entities.Snippet.file_id.in_(file_ids)
|
|
422
450
|
)
|
|
423
|
-
await self._session.
|
|
451
|
+
result = await self._session.scalars(stmt)
|
|
452
|
+
snippets = result.all()
|
|
424
453
|
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
454
|
+
# Delete all embeddings for these snippets
|
|
455
|
+
for snippet in snippets:
|
|
456
|
+
embedding_stmt = delete(db_entities.Embedding).where(
|
|
457
|
+
db_entities.Embedding.snippet_id == snippet.id
|
|
458
|
+
)
|
|
459
|
+
await self._session.execute(embedding_stmt)
|
|
460
|
+
|
|
461
|
+
# Now delete the snippets
|
|
462
|
+
snippet_stmt = delete(db_entities.Snippet).where(
|
|
463
|
+
db_entities.Snippet.file_id.in_(file_ids)
|
|
464
|
+
)
|
|
465
|
+
await self._session.execute(snippet_stmt)
|
|
430
466
|
|
|
431
467
|
async def update(self, index: domain_entities.Index) -> None:
|
|
432
468
|
"""Update an index by ensuring all domain objects are saved to database."""
|
|
433
469
|
if not index.id:
|
|
434
470
|
raise ValueError("Index must have an ID to be updated")
|
|
435
471
|
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
472
|
+
async with self.uow:
|
|
473
|
+
# 1. Verify the index exists in the database
|
|
474
|
+
db_index = await self._session.get(db_entities.Index, index.id)
|
|
475
|
+
if not db_index:
|
|
476
|
+
raise ValueError(f"Index {index.id} not found")
|
|
440
477
|
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
478
|
+
# 2. Update index timestamps
|
|
479
|
+
if index.updated_at:
|
|
480
|
+
db_index.updated_at = index.updated_at
|
|
444
481
|
|
|
445
|
-
|
|
446
|
-
|
|
482
|
+
# 3. Update source if it exists
|
|
483
|
+
await self._update_source(index, db_index)
|
|
447
484
|
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
485
|
+
# 4. Handle files and authors from working copy
|
|
486
|
+
if index.source and index.source.working_copy:
|
|
487
|
+
await self._update_files_and_authors(index, db_index)
|
|
451
488
|
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
489
|
+
# 5. Handle snippets
|
|
490
|
+
if index.snippets:
|
|
491
|
+
await self._update_snippets(index)
|
|
455
492
|
|
|
456
493
|
async def _update_source(
|
|
457
494
|
self, index: domain_entities.Index, db_index: db_entities.Index
|
|
@@ -583,26 +620,27 @@ class SqlAlchemyIndexRepository(IndexRepository):
|
|
|
583
620
|
# Delete all snippets and embeddings
|
|
584
621
|
await self.delete_snippets(index.id)
|
|
585
622
|
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
db_entities.AuthorFileMapping.
|
|
589
|
-
|
|
623
|
+
async with self.uow:
|
|
624
|
+
# Delete all author file mappings
|
|
625
|
+
stmt = delete(db_entities.AuthorFileMapping).where(
|
|
626
|
+
db_entities.AuthorFileMapping.file_id.in_(
|
|
627
|
+
[file.id for file in index.source.working_copy.files]
|
|
628
|
+
)
|
|
590
629
|
)
|
|
591
|
-
|
|
592
|
-
await self._session.execute(stmt)
|
|
630
|
+
await self._session.execute(stmt)
|
|
593
631
|
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
632
|
+
# Delete all files
|
|
633
|
+
stmt = delete(db_entities.File).where(
|
|
634
|
+
db_entities.File.source_id == index.source.id
|
|
635
|
+
)
|
|
636
|
+
await self._session.execute(stmt)
|
|
599
637
|
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
638
|
+
# Delete the index
|
|
639
|
+
stmt = delete(db_entities.Index).where(db_entities.Index.id == index.id)
|
|
640
|
+
await self._session.execute(stmt)
|
|
603
641
|
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
642
|
+
# Delete the source
|
|
643
|
+
stmt = delete(db_entities.Source).where(
|
|
644
|
+
db_entities.Source.id == index.source.id
|
|
645
|
+
)
|
|
646
|
+
await self._session.execute(stmt)
|