kodit 0.1.15__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kodit might be problematic. Click here for more details.
- kodit/_version.py +2 -2
- kodit/cli.py +105 -19
- kodit/embedding/embedding_factory.py +2 -2
- kodit/embedding/embedding_provider/embedding_provider.py +9 -2
- kodit/embedding/embedding_provider/openai_embedding_provider.py +19 -7
- kodit/embedding/vectorchord_vector_search_service.py +24 -15
- kodit/enrichment/__init__.py +1 -0
- kodit/enrichment/enrichment_factory.py +23 -0
- kodit/enrichment/enrichment_provider/__init__.py +1 -0
- kodit/enrichment/enrichment_provider/enrichment_provider.py +16 -0
- kodit/enrichment/enrichment_provider/local_enrichment_provider.py +63 -0
- kodit/enrichment/enrichment_provider/openai_enrichment_provider.py +77 -0
- kodit/enrichment/enrichment_service.py +33 -0
- kodit/indexing/fusion.py +67 -0
- kodit/indexing/indexing_repository.py +20 -0
- kodit/indexing/indexing_service.py +120 -4
- kodit/mcp.py +25 -16
- kodit/snippets/languages/go.scm +26 -0
- {kodit-0.1.15.dist-info → kodit-0.1.16.dist-info}/METADATA +3 -1
- {kodit-0.1.15.dist-info → kodit-0.1.16.dist-info}/RECORD +23 -17
- kodit/search/__init__.py +0 -1
- kodit/search/search_repository.py +0 -57
- kodit/search/search_service.py +0 -135
- {kodit-0.1.15.dist-info → kodit-0.1.16.dist-info}/WHEEL +0 -0
- {kodit-0.1.15.dist-info → kodit-0.1.16.dist-info}/entry_points.txt +0 -0
- {kodit-0.1.15.dist-info → kodit-0.1.16.dist-info}/licenses/LICENSE +0 -0
kodit/_version.py
CHANGED
kodit/cli.py
CHANGED
|
@@ -17,11 +17,10 @@ from kodit.config import (
|
|
|
17
17
|
with_session,
|
|
18
18
|
)
|
|
19
19
|
from kodit.embedding.embedding_factory import embedding_factory
|
|
20
|
+
from kodit.enrichment.enrichment_factory import enrichment_factory
|
|
20
21
|
from kodit.indexing.indexing_repository import IndexRepository
|
|
21
|
-
from kodit.indexing.indexing_service import IndexService
|
|
22
|
+
from kodit.indexing.indexing_service import IndexService, SearchRequest
|
|
22
23
|
from kodit.log import configure_logging, configure_telemetry, log_event
|
|
23
|
-
from kodit.search.search_repository import SearchRepository
|
|
24
|
-
from kodit.search.search_service import SearchRequest, SearchService
|
|
25
24
|
from kodit.source.source_repository import SourceRepository
|
|
26
25
|
from kodit.source.source_service import SourceService
|
|
27
26
|
|
|
@@ -72,9 +71,13 @@ async def index(
|
|
|
72
71
|
repository=repository,
|
|
73
72
|
source_service=source_service,
|
|
74
73
|
keyword_search_provider=keyword_search_factory(app_context, session),
|
|
75
|
-
|
|
76
|
-
app_context=app_context, session=session
|
|
74
|
+
code_search_service=embedding_factory(
|
|
75
|
+
task_name="code", app_context=app_context, session=session
|
|
77
76
|
),
|
|
77
|
+
text_search_service=embedding_factory(
|
|
78
|
+
task_name="text", app_context=app_context, session=session
|
|
79
|
+
),
|
|
80
|
+
enrichment_service=enrichment_factory(app_context),
|
|
78
81
|
)
|
|
79
82
|
|
|
80
83
|
if not sources:
|
|
@@ -131,11 +134,20 @@ async def code(
|
|
|
131
134
|
|
|
132
135
|
This works best if your query is code.
|
|
133
136
|
"""
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
+
source_repository = SourceRepository(session)
|
|
138
|
+
source_service = SourceService(app_context.get_clone_dir(), source_repository)
|
|
139
|
+
repository = IndexRepository(session)
|
|
140
|
+
service = IndexService(
|
|
141
|
+
repository=repository,
|
|
142
|
+
source_service=source_service,
|
|
137
143
|
keyword_search_provider=keyword_search_factory(app_context, session),
|
|
138
|
-
|
|
144
|
+
code_search_service=embedding_factory(
|
|
145
|
+
task_name="code", app_context=app_context, session=session
|
|
146
|
+
),
|
|
147
|
+
text_search_service=embedding_factory(
|
|
148
|
+
task_name="text", app_context=app_context, session=session
|
|
149
|
+
),
|
|
150
|
+
enrichment_service=enrichment_factory(app_context),
|
|
139
151
|
)
|
|
140
152
|
|
|
141
153
|
snippets = await service.search(SearchRequest(code_query=query, top_k=top_k))
|
|
@@ -147,6 +159,7 @@ async def code(
|
|
|
147
159
|
for snippet in snippets:
|
|
148
160
|
click.echo("-" * 80)
|
|
149
161
|
click.echo(f"{snippet.uri}")
|
|
162
|
+
click.echo(f"Original scores: {snippet.original_scores}")
|
|
150
163
|
click.echo(snippet.content)
|
|
151
164
|
click.echo("-" * 80)
|
|
152
165
|
click.echo()
|
|
@@ -164,11 +177,20 @@ async def keyword(
|
|
|
164
177
|
top_k: int,
|
|
165
178
|
) -> None:
|
|
166
179
|
"""Search for snippets using keyword search."""
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
180
|
+
source_repository = SourceRepository(session)
|
|
181
|
+
source_service = SourceService(app_context.get_clone_dir(), source_repository)
|
|
182
|
+
repository = IndexRepository(session)
|
|
183
|
+
service = IndexService(
|
|
184
|
+
repository=repository,
|
|
185
|
+
source_service=source_service,
|
|
170
186
|
keyword_search_provider=keyword_search_factory(app_context, session),
|
|
171
|
-
|
|
187
|
+
code_search_service=embedding_factory(
|
|
188
|
+
task_name="code", app_context=app_context, session=session
|
|
189
|
+
),
|
|
190
|
+
text_search_service=embedding_factory(
|
|
191
|
+
task_name="text", app_context=app_context, session=session
|
|
192
|
+
),
|
|
193
|
+
enrichment_service=enrichment_factory(app_context),
|
|
172
194
|
)
|
|
173
195
|
|
|
174
196
|
snippets = await service.search(SearchRequest(keywords=keywords, top_k=top_k))
|
|
@@ -180,6 +202,53 @@ async def keyword(
|
|
|
180
202
|
for snippet in snippets:
|
|
181
203
|
click.echo("-" * 80)
|
|
182
204
|
click.echo(f"{snippet.uri}")
|
|
205
|
+
click.echo(f"Original scores: {snippet.original_scores}")
|
|
206
|
+
click.echo(snippet.content)
|
|
207
|
+
click.echo("-" * 80)
|
|
208
|
+
click.echo()
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
@search.command()
|
|
212
|
+
@click.argument("query")
|
|
213
|
+
@click.option("--top-k", default=10, help="Number of snippets to retrieve")
|
|
214
|
+
@with_app_context
|
|
215
|
+
@with_session
|
|
216
|
+
async def text(
|
|
217
|
+
session: AsyncSession,
|
|
218
|
+
app_context: AppContext,
|
|
219
|
+
query: str,
|
|
220
|
+
top_k: int,
|
|
221
|
+
) -> None:
|
|
222
|
+
"""Search for snippets using semantic text search.
|
|
223
|
+
|
|
224
|
+
This works best if your query is text.
|
|
225
|
+
"""
|
|
226
|
+
source_repository = SourceRepository(session)
|
|
227
|
+
source_service = SourceService(app_context.get_clone_dir(), source_repository)
|
|
228
|
+
repository = IndexRepository(session)
|
|
229
|
+
service = IndexService(
|
|
230
|
+
repository=repository,
|
|
231
|
+
source_service=source_service,
|
|
232
|
+
keyword_search_provider=keyword_search_factory(app_context, session),
|
|
233
|
+
code_search_service=embedding_factory(
|
|
234
|
+
task_name="code", app_context=app_context, session=session
|
|
235
|
+
),
|
|
236
|
+
text_search_service=embedding_factory(
|
|
237
|
+
task_name="text", app_context=app_context, session=session
|
|
238
|
+
),
|
|
239
|
+
enrichment_service=enrichment_factory(app_context),
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
snippets = await service.search(SearchRequest(text_query=query, top_k=top_k))
|
|
243
|
+
|
|
244
|
+
if len(snippets) == 0:
|
|
245
|
+
click.echo("No snippets found")
|
|
246
|
+
return
|
|
247
|
+
|
|
248
|
+
for snippet in snippets:
|
|
249
|
+
click.echo("-" * 80)
|
|
250
|
+
click.echo(f"{snippet.uri}")
|
|
251
|
+
click.echo(f"Original scores: {snippet.original_scores}")
|
|
183
252
|
click.echo(snippet.content)
|
|
184
253
|
click.echo("-" * 80)
|
|
185
254
|
click.echo()
|
|
@@ -189,28 +258,44 @@ async def keyword(
|
|
|
189
258
|
@click.option("--top-k", default=10, help="Number of snippets to retrieve")
|
|
190
259
|
@click.option("--keywords", required=True, help="Comma separated list of keywords")
|
|
191
260
|
@click.option("--code", required=True, help="Semantic code search query")
|
|
261
|
+
@click.option("--text", required=True, help="Semantic text search query")
|
|
192
262
|
@with_app_context
|
|
193
263
|
@with_session
|
|
194
|
-
async def hybrid(
|
|
264
|
+
async def hybrid( # noqa: PLR0913
|
|
195
265
|
session: AsyncSession,
|
|
196
266
|
app_context: AppContext,
|
|
197
267
|
top_k: int,
|
|
198
268
|
keywords: str,
|
|
199
269
|
code: str,
|
|
270
|
+
text: str,
|
|
200
271
|
) -> None:
|
|
201
272
|
"""Search for snippets using hybrid search."""
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
273
|
+
source_repository = SourceRepository(session)
|
|
274
|
+
source_service = SourceService(app_context.get_clone_dir(), source_repository)
|
|
275
|
+
repository = IndexRepository(session)
|
|
276
|
+
service = IndexService(
|
|
277
|
+
repository=repository,
|
|
278
|
+
source_service=source_service,
|
|
205
279
|
keyword_search_provider=keyword_search_factory(app_context, session),
|
|
206
|
-
|
|
280
|
+
code_search_service=embedding_factory(
|
|
281
|
+
task_name="code", app_context=app_context, session=session
|
|
282
|
+
),
|
|
283
|
+
text_search_service=embedding_factory(
|
|
284
|
+
task_name="text", app_context=app_context, session=session
|
|
285
|
+
),
|
|
286
|
+
enrichment_service=enrichment_factory(app_context),
|
|
207
287
|
)
|
|
208
288
|
|
|
209
289
|
# Parse keywords into a list of strings
|
|
210
290
|
keywords_list = [k.strip().lower() for k in keywords.split(",")]
|
|
211
291
|
|
|
212
292
|
snippets = await service.search(
|
|
213
|
-
SearchRequest(
|
|
293
|
+
SearchRequest(
|
|
294
|
+
text_query=text,
|
|
295
|
+
keywords=keywords_list,
|
|
296
|
+
code_query=code,
|
|
297
|
+
top_k=top_k,
|
|
298
|
+
)
|
|
214
299
|
)
|
|
215
300
|
|
|
216
301
|
if len(snippets) == 0:
|
|
@@ -220,6 +305,7 @@ async def hybrid(
|
|
|
220
305
|
for snippet in snippets:
|
|
221
306
|
click.echo("-" * 80)
|
|
222
307
|
click.echo(f"{snippet.uri}")
|
|
308
|
+
click.echo(f"Original scores: {snippet.original_scores}")
|
|
223
309
|
click.echo(snippet.content)
|
|
224
310
|
click.echo("-" * 80)
|
|
225
311
|
click.echo()
|
|
@@ -21,7 +21,7 @@ from kodit.embedding.vectorchord_vector_search_service import (
|
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
def embedding_factory(
|
|
24
|
-
app_context: AppContext, session: AsyncSession
|
|
24
|
+
task_name: str, app_context: AppContext, session: AsyncSession
|
|
25
25
|
) -> VectorSearchService:
|
|
26
26
|
"""Create an embedding service."""
|
|
27
27
|
embedding_repository = EmbeddingRepository(session=session)
|
|
@@ -33,7 +33,7 @@ def embedding_factory(
|
|
|
33
33
|
embedding_provider = LocalEmbeddingProvider(CODE)
|
|
34
34
|
|
|
35
35
|
if app_context.default_search.provider == "vectorchord":
|
|
36
|
-
return VectorChordVectorSearchService(session, embedding_provider)
|
|
36
|
+
return VectorChordVectorSearchService(task_name, session, embedding_provider)
|
|
37
37
|
if app_context.default_search.provider == "sqlite":
|
|
38
38
|
return LocalVectorSearchService(
|
|
39
39
|
embedding_repository=embedding_repository,
|
|
@@ -38,8 +38,15 @@ def split_sub_batches(encoding: tiktoken.Encoding, data: list[str]) -> list[list
|
|
|
38
38
|
item_tokens = len(encoding.encode(next_item))
|
|
39
39
|
|
|
40
40
|
if item_tokens > OPENAI_MAX_EMBEDDING_SIZE:
|
|
41
|
-
|
|
42
|
-
|
|
41
|
+
# Loop around trying to truncate the snippet until it fits in the max
|
|
42
|
+
# embedding size
|
|
43
|
+
while item_tokens > OPENAI_MAX_EMBEDDING_SIZE:
|
|
44
|
+
next_item = next_item[:-1]
|
|
45
|
+
item_tokens = len(encoding.encode(next_item))
|
|
46
|
+
|
|
47
|
+
data_to_process[0] = next_item
|
|
48
|
+
|
|
49
|
+
log.warning("Truncated snippet", snippet=next_item)
|
|
43
50
|
|
|
44
51
|
if current_tokens + item_tokens > OPENAI_MAX_EMBEDDING_SIZE:
|
|
45
52
|
break
|
|
@@ -38,26 +38,38 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
|
|
38
38
|
# Process batches in parallel with a semaphore to limit concurrent requests
|
|
39
39
|
sem = asyncio.Semaphore(OPENAI_NUM_PARALLEL_TASKS)
|
|
40
40
|
|
|
41
|
-
|
|
41
|
+
# Create a list of tuples with a temporary id for each batch
|
|
42
|
+
# We need to do this so that we can return the results in the same order as the
|
|
43
|
+
# input data
|
|
44
|
+
input_data = [(i, batch) for i, batch in enumerate(batched_data)]
|
|
45
|
+
|
|
46
|
+
async def process_batch(
|
|
47
|
+
data: tuple[int, list[str]],
|
|
48
|
+
) -> tuple[int, list[Vector]]:
|
|
49
|
+
batch_id, batch = data
|
|
42
50
|
async with sem:
|
|
43
51
|
try:
|
|
44
52
|
response = await self.openai_client.embeddings.create(
|
|
45
53
|
model=self.model_name,
|
|
46
54
|
input=batch,
|
|
47
55
|
)
|
|
48
|
-
return [
|
|
56
|
+
return batch_id, [
|
|
49
57
|
[float(x) for x in embedding.embedding]
|
|
50
58
|
for embedding in response.data
|
|
51
59
|
]
|
|
52
60
|
except Exception as e:
|
|
53
61
|
self.log.exception("Error embedding batch", error=str(e))
|
|
54
|
-
return []
|
|
62
|
+
return batch_id, []
|
|
55
63
|
|
|
56
64
|
# Create tasks for all batches
|
|
57
|
-
tasks = [process_batch(batch) for batch in
|
|
65
|
+
tasks = [process_batch(batch) for batch in input_data]
|
|
58
66
|
|
|
59
67
|
# Process all batches and yield results as they complete
|
|
60
|
-
results: list[Vector] = []
|
|
68
|
+
results: list[tuple[int, list[Vector]]] = []
|
|
61
69
|
for task in asyncio.as_completed(tasks):
|
|
62
|
-
|
|
63
|
-
|
|
70
|
+
result = await task
|
|
71
|
+
results.append(result)
|
|
72
|
+
|
|
73
|
+
# Output in the same order as the input data
|
|
74
|
+
ordered_results = [result for _, result in sorted(results, key=lambda x: x[0])]
|
|
75
|
+
return [item for sublist in ordered_results for item in sublist]
|
|
@@ -12,23 +12,20 @@ from kodit.embedding.vector_search_service import (
|
|
|
12
12
|
VectorSearchService,
|
|
13
13
|
)
|
|
14
14
|
|
|
15
|
-
TABLE_NAME = "vectorchord_embeddings"
|
|
16
|
-
INDEX_NAME = f"{TABLE_NAME}_idx"
|
|
17
|
-
|
|
18
15
|
# SQL Queries
|
|
19
16
|
CREATE_VCHORD_EXTENSION = """
|
|
20
17
|
CREATE EXTENSION IF NOT EXISTS vchord CASCADE;
|
|
21
18
|
"""
|
|
22
19
|
|
|
23
|
-
CHECK_VCHORD_EMBEDDING_DIMENSION =
|
|
20
|
+
CHECK_VCHORD_EMBEDDING_DIMENSION = """
|
|
24
21
|
SELECT a.atttypmod as dimension
|
|
25
22
|
FROM pg_attribute a
|
|
26
23
|
JOIN pg_class c ON a.attrelid = c.oid
|
|
27
24
|
WHERE c.relname = '{TABLE_NAME}'
|
|
28
25
|
AND a.attname = 'embedding';
|
|
29
|
-
"""
|
|
26
|
+
"""
|
|
30
27
|
|
|
31
|
-
CREATE_VCHORD_INDEX =
|
|
28
|
+
CREATE_VCHORD_INDEX = """
|
|
32
29
|
CREATE INDEX IF NOT EXISTS {INDEX_NAME}
|
|
33
30
|
ON {TABLE_NAME}
|
|
34
31
|
USING vchordrq (embedding vector_l2_ops) WITH (options = $$
|
|
@@ -38,21 +35,21 @@ lists = []
|
|
|
38
35
|
$$);
|
|
39
36
|
"""
|
|
40
37
|
|
|
41
|
-
INSERT_QUERY =
|
|
38
|
+
INSERT_QUERY = """
|
|
42
39
|
INSERT INTO {TABLE_NAME} (snippet_id, embedding)
|
|
43
40
|
VALUES (:snippet_id, :embedding)
|
|
44
41
|
ON CONFLICT (snippet_id) DO UPDATE
|
|
45
42
|
SET embedding = EXCLUDED.embedding
|
|
46
|
-
"""
|
|
43
|
+
"""
|
|
47
44
|
|
|
48
45
|
# Note that <=> in vectorchord is cosine distance
|
|
49
46
|
# So scores go from 0 (similar) to 2 (opposite)
|
|
50
|
-
SEARCH_QUERY =
|
|
47
|
+
SEARCH_QUERY = """
|
|
51
48
|
SELECT snippet_id, embedding <=> :query as score
|
|
52
49
|
FROM {TABLE_NAME}
|
|
53
50
|
ORDER BY score ASC
|
|
54
51
|
LIMIT :top_k;
|
|
55
|
-
"""
|
|
52
|
+
"""
|
|
56
53
|
|
|
57
54
|
|
|
58
55
|
class VectorChordVectorSearchService(VectorSearchService):
|
|
@@ -60,6 +57,7 @@ class VectorChordVectorSearchService(VectorSearchService):
|
|
|
60
57
|
|
|
61
58
|
def __init__(
|
|
62
59
|
self,
|
|
60
|
+
task_name: str,
|
|
63
61
|
session: AsyncSession,
|
|
64
62
|
embedding_provider: EmbeddingProvider,
|
|
65
63
|
) -> None:
|
|
@@ -67,6 +65,8 @@ class VectorChordVectorSearchService(VectorSearchService):
|
|
|
67
65
|
self.embedding_provider = embedding_provider
|
|
68
66
|
self._session = session
|
|
69
67
|
self._initialized = False
|
|
68
|
+
self.table_name = f"vectorchord_{task_name}_embeddings"
|
|
69
|
+
self.index_name = f"{self.table_name}_idx"
|
|
70
70
|
|
|
71
71
|
async def _initialize(self) -> None:
|
|
72
72
|
"""Initialize the VectorChord environment."""
|
|
@@ -88,15 +88,23 @@ class VectorChordVectorSearchService(VectorSearchService):
|
|
|
88
88
|
vector_dim = (await self.embedding_provider.embed(["dimension"]))[0]
|
|
89
89
|
await self._session.execute(
|
|
90
90
|
text(
|
|
91
|
-
f"""CREATE TABLE IF NOT EXISTS {
|
|
91
|
+
f"""CREATE TABLE IF NOT EXISTS {self.table_name} (
|
|
92
92
|
id SERIAL PRIMARY KEY,
|
|
93
93
|
snippet_id INT NOT NULL UNIQUE,
|
|
94
94
|
embedding VECTOR({len(vector_dim)}) NOT NULL
|
|
95
95
|
);"""
|
|
96
96
|
)
|
|
97
97
|
)
|
|
98
|
-
await self._session.execute(
|
|
99
|
-
|
|
98
|
+
await self._session.execute(
|
|
99
|
+
text(
|
|
100
|
+
CREATE_VCHORD_INDEX.format(
|
|
101
|
+
TABLE_NAME=self.table_name, INDEX_NAME=self.index_name
|
|
102
|
+
)
|
|
103
|
+
)
|
|
104
|
+
)
|
|
105
|
+
result = await self._session.execute(
|
|
106
|
+
text(CHECK_VCHORD_EMBEDDING_DIMENSION.format(TABLE_NAME=self.table_name))
|
|
107
|
+
)
|
|
100
108
|
vector_dim_from_db = result.scalar_one()
|
|
101
109
|
if vector_dim_from_db != len(vector_dim):
|
|
102
110
|
msg = (
|
|
@@ -123,7 +131,7 @@ class VectorChordVectorSearchService(VectorSearchService):
|
|
|
123
131
|
embeddings = await self.embedding_provider.embed([doc.text for doc in data])
|
|
124
132
|
# Execute inserts
|
|
125
133
|
await self._execute(
|
|
126
|
-
text(INSERT_QUERY),
|
|
134
|
+
text(INSERT_QUERY.format(TABLE_NAME=self.table_name)),
|
|
127
135
|
[
|
|
128
136
|
{"snippet_id": doc.snippet_id, "embedding": str(embedding)}
|
|
129
137
|
for doc, embedding in zip(data, embeddings, strict=True)
|
|
@@ -135,7 +143,8 @@ class VectorChordVectorSearchService(VectorSearchService):
|
|
|
135
143
|
"""Query the embedding model."""
|
|
136
144
|
embedding = await self.embedding_provider.embed([query])
|
|
137
145
|
result = await self._execute(
|
|
138
|
-
text(SEARCH_QUERY
|
|
146
|
+
text(SEARCH_QUERY.format(TABLE_NAME=self.table_name)),
|
|
147
|
+
{"query": str(embedding[0]), "top_k": top_k},
|
|
139
148
|
)
|
|
140
149
|
rows = result.mappings().all()
|
|
141
150
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Enrichment."""
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""Embedding service."""
|
|
2
|
+
|
|
3
|
+
from kodit.config import AppContext
|
|
4
|
+
from kodit.enrichment.enrichment_provider.local_enrichment_provider import (
|
|
5
|
+
LocalEnrichmentProvider,
|
|
6
|
+
)
|
|
7
|
+
from kodit.enrichment.enrichment_provider.openai_enrichment_provider import (
|
|
8
|
+
OpenAIEnrichmentProvider,
|
|
9
|
+
)
|
|
10
|
+
from kodit.enrichment.enrichment_service import (
|
|
11
|
+
EnrichmentService,
|
|
12
|
+
LLMEnrichmentService,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def enrichment_factory(app_context: AppContext) -> EnrichmentService:
|
|
17
|
+
"""Create an embedding service."""
|
|
18
|
+
openai_client = app_context.get_default_openai_client()
|
|
19
|
+
if openai_client is not None:
|
|
20
|
+
enrichment_provider = OpenAIEnrichmentProvider(openai_client=openai_client)
|
|
21
|
+
return LLMEnrichmentService(enrichment_provider)
|
|
22
|
+
|
|
23
|
+
return LLMEnrichmentService(LocalEnrichmentProvider())
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Enrichment provider."""
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Enrichment provider."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
|
|
5
|
+
ENRICHMENT_SYSTEM_PROMPT = """
|
|
6
|
+
You are a professional software developer. You will be given a snippet of code.
|
|
7
|
+
Please provide a concise explanation of the code.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class EnrichmentProvider(ABC):
|
|
12
|
+
"""Enrichment provider."""
|
|
13
|
+
|
|
14
|
+
@abstractmethod
|
|
15
|
+
async def enrich(self, data: list[str]) -> list[str]:
|
|
16
|
+
"""Enrich a list of strings."""
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""Local embedding service."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
import structlog
|
|
6
|
+
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
|
|
7
|
+
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
|
8
|
+
|
|
9
|
+
from kodit.enrichment.enrichment_provider.enrichment_provider import (
|
|
10
|
+
ENRICHMENT_SYSTEM_PROMPT,
|
|
11
|
+
EnrichmentProvider,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class LocalEnrichmentProvider(EnrichmentProvider):
|
|
16
|
+
"""Local embedder."""
|
|
17
|
+
|
|
18
|
+
def __init__(self, model_name: str = "Qwen/Qwen3-0.6B") -> None:
|
|
19
|
+
"""Initialize the local enrichment provider."""
|
|
20
|
+
self.log = structlog.get_logger(__name__)
|
|
21
|
+
self.model_name = model_name
|
|
22
|
+
self.model = None
|
|
23
|
+
self.tokenizer = None
|
|
24
|
+
|
|
25
|
+
async def enrich(self, data: list[str]) -> list[str]:
|
|
26
|
+
"""Enrich a list of strings."""
|
|
27
|
+
if self.tokenizer is None:
|
|
28
|
+
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
|
29
|
+
if self.model is None:
|
|
30
|
+
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Avoid warnings
|
|
31
|
+
self.model = AutoModelForCausalLM.from_pretrained(
|
|
32
|
+
self.model_name,
|
|
33
|
+
torch_dtype="auto",
|
|
34
|
+
trust_remote_code=True,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
results = []
|
|
38
|
+
for snippet in data:
|
|
39
|
+
# prepare the model input
|
|
40
|
+
messages = [
|
|
41
|
+
{"role": "system", "content": ENRICHMENT_SYSTEM_PROMPT},
|
|
42
|
+
{"role": "user", "content": snippet},
|
|
43
|
+
]
|
|
44
|
+
text = self.tokenizer.apply_chat_template(
|
|
45
|
+
messages,
|
|
46
|
+
tokenize=False,
|
|
47
|
+
add_generation_prompt=True,
|
|
48
|
+
enable_thinking=False,
|
|
49
|
+
)
|
|
50
|
+
model_inputs = self.tokenizer([text], return_tensors="pt").to(
|
|
51
|
+
self.model.device
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
# conduct text completion
|
|
55
|
+
generated_ids = self.model.generate(**model_inputs, max_new_tokens=32768)
|
|
56
|
+
output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
|
|
57
|
+
content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip(
|
|
58
|
+
"\n"
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
results.append(content)
|
|
62
|
+
|
|
63
|
+
return results
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""OpenAI embedding service."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
|
|
5
|
+
import structlog
|
|
6
|
+
import tiktoken
|
|
7
|
+
from openai import AsyncOpenAI
|
|
8
|
+
from tqdm import tqdm
|
|
9
|
+
|
|
10
|
+
from kodit.enrichment.enrichment_provider.enrichment_provider import (
|
|
11
|
+
ENRICHMENT_SYSTEM_PROMPT,
|
|
12
|
+
EnrichmentProvider,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
OPENAI_NUM_PARALLEL_TASKS = 10
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class OpenAIEnrichmentProvider(EnrichmentProvider):
|
|
19
|
+
"""OpenAI enrichment provider."""
|
|
20
|
+
|
|
21
|
+
def __init__(
|
|
22
|
+
self,
|
|
23
|
+
openai_client: AsyncOpenAI,
|
|
24
|
+
model_name: str = "gpt-4o-mini",
|
|
25
|
+
) -> None:
|
|
26
|
+
"""Initialize the OpenAI enrichment provider."""
|
|
27
|
+
self.log = structlog.get_logger(__name__)
|
|
28
|
+
self.openai_client = openai_client
|
|
29
|
+
self.model_name = model_name
|
|
30
|
+
self.encoding = tiktoken.encoding_for_model(model_name)
|
|
31
|
+
|
|
32
|
+
async def enrich(self, data: list[str]) -> list[str]:
|
|
33
|
+
"""Enrich a list of documents."""
|
|
34
|
+
# Process batches in parallel with a semaphore to limit concurrent requests
|
|
35
|
+
sem = asyncio.Semaphore(OPENAI_NUM_PARALLEL_TASKS)
|
|
36
|
+
|
|
37
|
+
# Create a list of tuples with a temporary id for each snippet
|
|
38
|
+
# We need to do this so that we can return the results in the same order as the
|
|
39
|
+
# input data
|
|
40
|
+
input_data = [(i, snippet) for i, snippet in enumerate(data)]
|
|
41
|
+
|
|
42
|
+
async def process_data(data: tuple[int, str]) -> tuple[int, str]:
|
|
43
|
+
snippet_id, snippet = data
|
|
44
|
+
if not snippet:
|
|
45
|
+
return snippet_id, ""
|
|
46
|
+
async with sem:
|
|
47
|
+
try:
|
|
48
|
+
response = await self.openai_client.chat.completions.create(
|
|
49
|
+
model=self.model_name,
|
|
50
|
+
messages=[
|
|
51
|
+
{
|
|
52
|
+
"role": "system",
|
|
53
|
+
"content": ENRICHMENT_SYSTEM_PROMPT,
|
|
54
|
+
},
|
|
55
|
+
{"role": "user", "content": snippet},
|
|
56
|
+
],
|
|
57
|
+
)
|
|
58
|
+
return snippet_id, response.choices[0].message.content or ""
|
|
59
|
+
except Exception as e:
|
|
60
|
+
self.log.exception("Error enriching data", error=str(e))
|
|
61
|
+
return snippet_id, ""
|
|
62
|
+
|
|
63
|
+
# Create tasks for all data
|
|
64
|
+
tasks = [process_data(snippet) for snippet in input_data]
|
|
65
|
+
|
|
66
|
+
# Process all data and yield results as they complete
|
|
67
|
+
results: list[tuple[int, str]] = []
|
|
68
|
+
for task in tqdm(
|
|
69
|
+
asyncio.as_completed(tasks),
|
|
70
|
+
total=len(tasks),
|
|
71
|
+
leave=False,
|
|
72
|
+
):
|
|
73
|
+
result = await task
|
|
74
|
+
results.append(result)
|
|
75
|
+
|
|
76
|
+
# Output in the same order as the input data
|
|
77
|
+
return [result for _, result in sorted(results, key=lambda x: x[0])]
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""Enrichment service."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
|
|
5
|
+
from kodit.enrichment.enrichment_provider.enrichment_provider import EnrichmentProvider
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class EnrichmentService(ABC):
|
|
9
|
+
"""Enrichment service."""
|
|
10
|
+
|
|
11
|
+
@abstractmethod
|
|
12
|
+
async def enrich(self, data: list[str]) -> list[str]:
|
|
13
|
+
"""Enrich a list of strings."""
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class NullEnrichmentService(EnrichmentService):
|
|
17
|
+
"""Null enrichment service."""
|
|
18
|
+
|
|
19
|
+
async def enrich(self, data: list[str]) -> list[str]:
|
|
20
|
+
"""Enrich a list of strings."""
|
|
21
|
+
return [""] * len(data)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class LLMEnrichmentService(EnrichmentService):
|
|
25
|
+
"""Enrichment service using an LLM."""
|
|
26
|
+
|
|
27
|
+
def __init__(self, enrichment_provider: EnrichmentProvider) -> None:
|
|
28
|
+
"""Initialize the enrichment service."""
|
|
29
|
+
self.enrichment_provider = enrichment_provider
|
|
30
|
+
|
|
31
|
+
async def enrich(self, data: list[str]) -> list[str]:
|
|
32
|
+
"""Enrich a list of strings."""
|
|
33
|
+
return await self.enrichment_provider.enrich(data)
|
kodit/indexing/fusion.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""Fusion functions for combining search results."""
|
|
2
|
+
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
|
|
8
|
+
class FusionResult:
|
|
9
|
+
"""Result of a fusion operation."""
|
|
10
|
+
|
|
11
|
+
id: int
|
|
12
|
+
score: float
|
|
13
|
+
original_scores: list[float]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
|
|
17
|
+
class FusionRequest:
|
|
18
|
+
"""Result of a RRF operation."""
|
|
19
|
+
|
|
20
|
+
id: int
|
|
21
|
+
score: float
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def reciprocal_rank_fusion(
|
|
25
|
+
rankings: list[list[FusionRequest]], k: float = 60
|
|
26
|
+
) -> list[FusionResult]:
|
|
27
|
+
"""RRF prioritises results that are present in all results.
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
rankings: List of rankers, each containing a list of document ids. Top of the
|
|
31
|
+
list is considered to be the best result.
|
|
32
|
+
k: Parameter for RRF.
|
|
33
|
+
|
|
34
|
+
Returns:
|
|
35
|
+
Dictionary of ids and their scores.
|
|
36
|
+
|
|
37
|
+
"""
|
|
38
|
+
scores = {}
|
|
39
|
+
for ranker in rankings:
|
|
40
|
+
for rank in ranker:
|
|
41
|
+
scores[rank.id] = float(0)
|
|
42
|
+
|
|
43
|
+
for ranker in rankings:
|
|
44
|
+
for i, rank in enumerate(ranker):
|
|
45
|
+
scores[rank.id] += 1.0 / (k + i)
|
|
46
|
+
|
|
47
|
+
# Create a list of tuples of ids and their scores
|
|
48
|
+
results = [(rank, scores[rank]) for rank in scores]
|
|
49
|
+
|
|
50
|
+
# Sort results by score
|
|
51
|
+
results.sort(key=lambda x: x[1], reverse=True)
|
|
52
|
+
|
|
53
|
+
# Create a map of original scores to ids
|
|
54
|
+
original_scores_to_ids = defaultdict(list)
|
|
55
|
+
for ranker in rankings:
|
|
56
|
+
for rank in ranker:
|
|
57
|
+
original_scores_to_ids[rank.id].append(rank.score)
|
|
58
|
+
|
|
59
|
+
# Rebuild a list of final results with their original scores
|
|
60
|
+
return [
|
|
61
|
+
FusionResult(
|
|
62
|
+
id=result[0],
|
|
63
|
+
score=result[1],
|
|
64
|
+
original_scores=original_scores_to_ids[result[0]],
|
|
65
|
+
)
|
|
66
|
+
for result in results
|
|
67
|
+
]
|
|
@@ -196,3 +196,23 @@ class IndexRepository:
|
|
|
196
196
|
"""
|
|
197
197
|
self.session.add(embedding)
|
|
198
198
|
await self.session.commit()
|
|
199
|
+
|
|
200
|
+
async def list_snippets_by_ids(self, ids: list[int]) -> list[tuple[File, Snippet]]:
|
|
201
|
+
"""List snippets by IDs.
|
|
202
|
+
|
|
203
|
+
Returns:
|
|
204
|
+
A list of snippets in the same order as the input IDs.
|
|
205
|
+
|
|
206
|
+
"""
|
|
207
|
+
query = (
|
|
208
|
+
select(Snippet, File)
|
|
209
|
+
.where(Snippet.id.in_(ids))
|
|
210
|
+
.join(File, Snippet.file_id == File.id)
|
|
211
|
+
)
|
|
212
|
+
rows = await self.session.execute(query)
|
|
213
|
+
|
|
214
|
+
# Create a dictionary for O(1) lookup of results by ID
|
|
215
|
+
id_to_result = {snippet.id: (file, snippet) for snippet, file in rows.all()}
|
|
216
|
+
|
|
217
|
+
# Return results in the same order as input IDs
|
|
218
|
+
return [id_to_result[i] for i in ids]
|
|
@@ -13,11 +13,17 @@ import pydantic
|
|
|
13
13
|
import structlog
|
|
14
14
|
from tqdm.asyncio import tqdm
|
|
15
15
|
|
|
16
|
-
from kodit.bm25.keyword_search_service import
|
|
16
|
+
from kodit.bm25.keyword_search_service import (
|
|
17
|
+
BM25Document,
|
|
18
|
+
BM25Result,
|
|
19
|
+
KeywordSearchProvider,
|
|
20
|
+
)
|
|
17
21
|
from kodit.embedding.vector_search_service import (
|
|
18
22
|
VectorSearchRequest,
|
|
19
23
|
VectorSearchService,
|
|
20
24
|
)
|
|
25
|
+
from kodit.enrichment.enrichment_service import EnrichmentService
|
|
26
|
+
from kodit.indexing.fusion import FusionRequest, reciprocal_rank_fusion
|
|
21
27
|
from kodit.indexing.indexing_models import Snippet
|
|
22
28
|
from kodit.indexing.indexing_repository import IndexRepository
|
|
23
29
|
from kodit.snippets.snippets import SnippetService
|
|
@@ -42,6 +48,28 @@ class IndexView(pydantic.BaseModel):
|
|
|
42
48
|
num_snippets: int | None = None
|
|
43
49
|
|
|
44
50
|
|
|
51
|
+
class SearchRequest(pydantic.BaseModel):
|
|
52
|
+
"""Request for a search."""
|
|
53
|
+
|
|
54
|
+
text_query: str | None = None
|
|
55
|
+
code_query: str | None = None
|
|
56
|
+
keywords: list[str] | None = None
|
|
57
|
+
top_k: int = 10
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class SearchResult(pydantic.BaseModel):
|
|
61
|
+
"""Data transfer object for search results.
|
|
62
|
+
|
|
63
|
+
This model represents a single search result, containing both the file path
|
|
64
|
+
and the matching snippet content.
|
|
65
|
+
"""
|
|
66
|
+
|
|
67
|
+
id: int
|
|
68
|
+
uri: str
|
|
69
|
+
content: str
|
|
70
|
+
original_scores: list[float]
|
|
71
|
+
|
|
72
|
+
|
|
45
73
|
class IndexService:
|
|
46
74
|
"""Service for managing code indexes.
|
|
47
75
|
|
|
@@ -50,12 +78,14 @@ class IndexService:
|
|
|
50
78
|
IndexRepository), and provides a clean API for index management.
|
|
51
79
|
"""
|
|
52
80
|
|
|
53
|
-
def __init__(
|
|
81
|
+
def __init__( # noqa: PLR0913
|
|
54
82
|
self,
|
|
55
83
|
repository: IndexRepository,
|
|
56
84
|
source_service: SourceService,
|
|
57
85
|
keyword_search_provider: KeywordSearchProvider,
|
|
58
|
-
|
|
86
|
+
code_search_service: VectorSearchService,
|
|
87
|
+
text_search_service: VectorSearchService,
|
|
88
|
+
enrichment_service: EnrichmentService,
|
|
59
89
|
) -> None:
|
|
60
90
|
"""Initialize the index service.
|
|
61
91
|
|
|
@@ -69,7 +99,9 @@ class IndexService:
|
|
|
69
99
|
self.snippet_service = SnippetService()
|
|
70
100
|
self.log = structlog.get_logger(__name__)
|
|
71
101
|
self.keyword_search_provider = keyword_search_provider
|
|
72
|
-
self.code_search_service =
|
|
102
|
+
self.code_search_service = code_search_service
|
|
103
|
+
self.text_search_service = text_search_service
|
|
104
|
+
self.enrichment_service = enrichment_service
|
|
73
105
|
|
|
74
106
|
async def create(self, source_id: int) -> IndexView:
|
|
75
107
|
"""Create a new index for a source.
|
|
@@ -152,9 +184,93 @@ class IndexService:
|
|
|
152
184
|
]
|
|
153
185
|
)
|
|
154
186
|
|
|
187
|
+
self.log.info("Enriching snippets")
|
|
188
|
+
enriched_contents = await self.enrichment_service.enrich(
|
|
189
|
+
[snippet.content for snippet in snippets]
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
self.log.info("Creating semantic text index")
|
|
193
|
+
with Spinner():
|
|
194
|
+
await self.text_search_service.index(
|
|
195
|
+
[
|
|
196
|
+
VectorSearchRequest(snippet.id, enriched_content)
|
|
197
|
+
for snippet, enriched_content in zip(
|
|
198
|
+
snippets, enriched_contents, strict=True
|
|
199
|
+
)
|
|
200
|
+
]
|
|
201
|
+
)
|
|
202
|
+
# Add the enriched text back to the snippets and write to the database
|
|
203
|
+
for snippet, enriched_content in zip(
|
|
204
|
+
snippets, enriched_contents, strict=True
|
|
205
|
+
):
|
|
206
|
+
snippet.content = (
|
|
207
|
+
enriched_content + "\n\n```\n" + snippet.content + "\n```"
|
|
208
|
+
)
|
|
209
|
+
await self.repository.add_snippet_or_update_content(snippet)
|
|
210
|
+
|
|
155
211
|
# Update index timestamp
|
|
156
212
|
await self.repository.update_index_timestamp(index)
|
|
157
213
|
|
|
214
|
+
async def search(self, request: SearchRequest) -> list[SearchResult]:
|
|
215
|
+
"""Search for relevant data."""
|
|
216
|
+
fusion_list: list[list[FusionRequest]] = []
|
|
217
|
+
if request.keywords:
|
|
218
|
+
# Gather results for each keyword
|
|
219
|
+
result_ids: list[BM25Result] = []
|
|
220
|
+
for keyword in request.keywords:
|
|
221
|
+
results = await self.keyword_search_provider.retrieve(
|
|
222
|
+
keyword, request.top_k
|
|
223
|
+
)
|
|
224
|
+
result_ids.extend(results)
|
|
225
|
+
|
|
226
|
+
fusion_list.append(
|
|
227
|
+
[FusionRequest(id=x.snippet_id, score=x.score) for x in result_ids]
|
|
228
|
+
)
|
|
229
|
+
|
|
230
|
+
# Compute embedding for semantic query
|
|
231
|
+
if request.code_query:
|
|
232
|
+
query_embedding = await self.code_search_service.retrieve(
|
|
233
|
+
request.code_query, top_k=request.top_k
|
|
234
|
+
)
|
|
235
|
+
fusion_list.append(
|
|
236
|
+
[FusionRequest(id=x.snippet_id, score=x.score) for x in query_embedding]
|
|
237
|
+
)
|
|
238
|
+
|
|
239
|
+
if request.text_query:
|
|
240
|
+
query_embedding = await self.text_search_service.retrieve(
|
|
241
|
+
request.text_query, top_k=request.top_k
|
|
242
|
+
)
|
|
243
|
+
fusion_list.append(
|
|
244
|
+
[FusionRequest(id=x.snippet_id, score=x.score) for x in query_embedding]
|
|
245
|
+
)
|
|
246
|
+
|
|
247
|
+
if len(fusion_list) == 0:
|
|
248
|
+
return []
|
|
249
|
+
|
|
250
|
+
# Combine all results together with RFF if required
|
|
251
|
+
final_results = reciprocal_rank_fusion(
|
|
252
|
+
rankings=fusion_list,
|
|
253
|
+
k=60,
|
|
254
|
+
)
|
|
255
|
+
|
|
256
|
+
# Only keep top_k results
|
|
257
|
+
final_results = final_results[: request.top_k]
|
|
258
|
+
|
|
259
|
+
# Get snippets from database (up to top_k)
|
|
260
|
+
search_results = await self.repository.list_snippets_by_ids(
|
|
261
|
+
[x.id for x in final_results]
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
return [
|
|
265
|
+
SearchResult(
|
|
266
|
+
id=snippet.id,
|
|
267
|
+
uri=file.uri,
|
|
268
|
+
content=snippet.content,
|
|
269
|
+
original_scores=fr.original_scores,
|
|
270
|
+
)
|
|
271
|
+
for (file, snippet), fr in zip(search_results, final_results, strict=True)
|
|
272
|
+
]
|
|
273
|
+
|
|
158
274
|
async def _create_snippets(
|
|
159
275
|
self,
|
|
160
276
|
index_id: int,
|
kodit/mcp.py
CHANGED
|
@@ -16,8 +16,11 @@ from kodit.bm25.keyword_search_factory import keyword_search_factory
|
|
|
16
16
|
from kodit.config import AppContext
|
|
17
17
|
from kodit.database import Database
|
|
18
18
|
from kodit.embedding.embedding_factory import embedding_factory
|
|
19
|
-
from kodit.
|
|
20
|
-
from kodit.
|
|
19
|
+
from kodit.enrichment.enrichment_factory import enrichment_factory
|
|
20
|
+
from kodit.indexing.indexing_repository import IndexRepository
|
|
21
|
+
from kodit.indexing.indexing_service import IndexService, SearchRequest, SearchResult
|
|
22
|
+
from kodit.source.source_repository import SourceRepository
|
|
23
|
+
from kodit.source.source_service import SourceService
|
|
21
24
|
|
|
22
25
|
|
|
23
26
|
@dataclass
|
|
@@ -123,32 +126,38 @@ async def search(
|
|
|
123
126
|
|
|
124
127
|
mcp_context: MCPContext = ctx.request_context.lifespan_context
|
|
125
128
|
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
+
source_repository = SourceRepository(mcp_context.session)
|
|
130
|
+
source_service = SourceService(
|
|
131
|
+
mcp_context.app_context.get_clone_dir(), source_repository
|
|
129
132
|
)
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
)
|
|
135
|
-
|
|
136
|
-
log.debug("Creating search service")
|
|
137
|
-
search_service = SearchService(
|
|
138
|
-
repository=search_repository,
|
|
133
|
+
repository = IndexRepository(mcp_context.session)
|
|
134
|
+
service = IndexService(
|
|
135
|
+
repository=repository,
|
|
136
|
+
source_service=source_service,
|
|
139
137
|
keyword_search_provider=keyword_search_factory(
|
|
138
|
+
mcp_context.app_context, mcp_context.session
|
|
139
|
+
),
|
|
140
|
+
code_search_service=embedding_factory(
|
|
141
|
+
task_name="code",
|
|
140
142
|
app_context=mcp_context.app_context,
|
|
141
143
|
session=mcp_context.session,
|
|
142
144
|
),
|
|
143
|
-
|
|
145
|
+
text_search_service=embedding_factory(
|
|
146
|
+
task_name="text",
|
|
147
|
+
app_context=mcp_context.app_context,
|
|
148
|
+
session=mcp_context.session,
|
|
149
|
+
),
|
|
150
|
+
enrichment_service=enrichment_factory(mcp_context.app_context),
|
|
144
151
|
)
|
|
145
152
|
|
|
146
153
|
search_request = SearchRequest(
|
|
147
154
|
keywords=keywords,
|
|
148
155
|
code_query="\n".join(related_file_contents),
|
|
156
|
+
text_query=user_intent,
|
|
149
157
|
)
|
|
158
|
+
|
|
150
159
|
log.debug("Searching for snippets")
|
|
151
|
-
snippets = await
|
|
160
|
+
snippets = await service.search(request=search_request)
|
|
152
161
|
|
|
153
162
|
log.debug("Fusing output")
|
|
154
163
|
output = output_fusion(snippets=snippets)
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
(function_declaration
|
|
2
|
+
name: (identifier) @function.name
|
|
3
|
+
body: (block) @function.body
|
|
4
|
+
) @function.def
|
|
5
|
+
|
|
6
|
+
(method_declaration
|
|
7
|
+
name: (field_identifier) @method.name
|
|
8
|
+
body: (block) @method.body
|
|
9
|
+
) @method.def
|
|
10
|
+
|
|
11
|
+
(import_declaration
|
|
12
|
+
(import_spec
|
|
13
|
+
path: (interpreted_string_literal) @import.name
|
|
14
|
+
)
|
|
15
|
+
) @import.statement
|
|
16
|
+
|
|
17
|
+
(identifier) @ident
|
|
18
|
+
|
|
19
|
+
(parameter_declaration
|
|
20
|
+
name: (identifier) @param.name
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
(package_clause "package" (package_identifier) @name.definition.module)
|
|
24
|
+
|
|
25
|
+
;; Exclude comments from being captured
|
|
26
|
+
(comment) @comment
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: kodit
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.16
|
|
4
4
|
Summary: Code indexing for better AI code generation
|
|
5
5
|
Project-URL: Homepage, https://docs.helixml.tech/kodit/
|
|
6
6
|
Project-URL: Documentation, https://docs.helixml.tech/kodit/
|
|
@@ -15,6 +15,7 @@ Keywords: ai,indexing,mcp,rag
|
|
|
15
15
|
Classifier: Development Status :: 2 - Pre-Alpha
|
|
16
16
|
Classifier: Intended Audience :: Developers
|
|
17
17
|
Classifier: Programming Language :: Python :: 3.12
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
18
19
|
Classifier: Topic :: Software Development :: Code Generators
|
|
19
20
|
Requires-Python: >=3.12
|
|
20
21
|
Requires-Dist: aiofiles>=24.1.0
|
|
@@ -42,6 +43,7 @@ Requires-Dist: sqlalchemy[asyncio]>=2.0.40
|
|
|
42
43
|
Requires-Dist: structlog>=25.3.0
|
|
43
44
|
Requires-Dist: tdqm>=0.0.1
|
|
44
45
|
Requires-Dist: tiktoken>=0.9.0
|
|
46
|
+
Requires-Dist: transformers>=4.51.3
|
|
45
47
|
Requires-Dist: tree-sitter-language-pack>=0.7.3
|
|
46
48
|
Requires-Dist: tree-sitter>=0.24.0
|
|
47
49
|
Requires-Dist: uritools>=5.0.0
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
kodit/.gitignore,sha256=ztkjgRwL9Uud1OEi36hGQeDGk3OLK1NfDEO8YqGYy8o,11
|
|
2
2
|
kodit/__init__.py,sha256=aEKHYninUq1yh6jaNfvJBYg-6fenpN132nJt1UU6Jxs,59
|
|
3
|
-
kodit/_version.py,sha256=
|
|
3
|
+
kodit/_version.py,sha256=VYJNWHISWEW-KD_clKUYcTY_Z30r993Sjws4URJIL0g,513
|
|
4
4
|
kodit/app.py,sha256=Mr5BFHOHx5zppwjC4XPWVvHjwgl1yrKbUjTWXKubJQM,891
|
|
5
|
-
kodit/cli.py,sha256=
|
|
5
|
+
kodit/cli.py,sha256=i7eEt0FdIQGEfXKFte-8fBcZZGE8BPXBp40aGwJDQGI,11323
|
|
6
6
|
kodit/config.py,sha256=2W2u5J8j-Mbt-C4xzOuK-PeuDCx0S_rnCXPhBwvfLT4,4353
|
|
7
7
|
kodit/database.py,sha256=WB1KpVxUYPgiJGU0gJa2hqytYB8wJEJ5z3WayhWzNMU,2403
|
|
8
8
|
kodit/log.py,sha256=HU1OmuxO4FcVw61k4WW7Y4WM7BrDaeplw1PcBHhuIZY,5434
|
|
9
|
-
kodit/mcp.py,sha256=
|
|
9
|
+
kodit/mcp.py,sha256=QruyPskWB0_x59pkfj5BBeXuR13GMny5TAZEa2j4U9s,5752
|
|
10
10
|
kodit/middleware.py,sha256=I6FOkqG9-8RH5kR1-0ZoQWfE4qLCB8lZYv8H_OCH29o,2714
|
|
11
11
|
kodit/bm25/__init__.py,sha256=j8zyriNWhbwE5Lbybzg1hQAhANlU9mKHWw4beeUR6og,19
|
|
12
12
|
kodit/bm25/keyword_search_factory.py,sha256=rp-wx3DJsc2KlELK1V337EyeYvmwnMQwUqOo1WVPSmg,631
|
|
@@ -14,21 +14,29 @@ kodit/bm25/keyword_search_service.py,sha256=aBbWQKgQmi2re3EIHdXFS00n7Wj3b2D0pZsL
|
|
|
14
14
|
kodit/bm25/local_bm25.py,sha256=AAbFhbQDqyL3d7jsPL7W4HsLxdoYctaDsREUXOLy6jM,3260
|
|
15
15
|
kodit/bm25/vectorchord_bm25.py,sha256=_nGrkUReYLLV-L8RIuIVLwjuhSYZl9T532n5OVf0kWs,6393
|
|
16
16
|
kodit/embedding/__init__.py,sha256=h9NXzDA1r-K23nvBajBV-RJzHJN0p3UJ7UQsmdnOoRw,24
|
|
17
|
-
kodit/embedding/embedding_factory.py,sha256=
|
|
17
|
+
kodit/embedding/embedding_factory.py,sha256=UGnFRyyQXazSUOwyW4Hg7Vq2-kfAoDj9lD4CTLu8x04,1630
|
|
18
18
|
kodit/embedding/embedding_models.py,sha256=rN90vSs86dYiqoawcp8E9jtwY31JoJXYfaDlsJK7uqc,656
|
|
19
19
|
kodit/embedding/embedding_repository.py,sha256=-ux3scpBzel8c0pMH9fNOEsSXFIzl-IfgaWrkTb1szo,6907
|
|
20
20
|
kodit/embedding/local_vector_search_service.py,sha256=hkF0qlfzjyGt400qIX9Mr6B7b7i8WvYIYWN2Z2C_pcs,1907
|
|
21
21
|
kodit/embedding/vector_search_service.py,sha256=pQJ129QjGrAWOXzqkywmgtDRpy8_gtzYgkivyqF9Vrs,1009
|
|
22
|
-
kodit/embedding/vectorchord_vector_search_service.py,sha256=
|
|
22
|
+
kodit/embedding/vectorchord_vector_search_service.py,sha256=KSs0IMFHHIllwq2d3A0LGqGGZDqO1Ht6K-BCfBBWW0Y,5051
|
|
23
23
|
kodit/embedding/embedding_provider/__init__.py,sha256=h9NXzDA1r-K23nvBajBV-RJzHJN0p3UJ7UQsmdnOoRw,24
|
|
24
|
-
kodit/embedding/embedding_provider/embedding_provider.py,sha256=
|
|
24
|
+
kodit/embedding/embedding_provider/embedding_provider.py,sha256=Tf3bwUsUMzAgoyLFM5qBtOLqPp1qr03TzrwGczkDvy0,1835
|
|
25
25
|
kodit/embedding/embedding_provider/hash_embedding_provider.py,sha256=nAhlhh8j8PqqCCbhVl26Y8ntFBm2vJBCtB4X04g5Wwg,2638
|
|
26
26
|
kodit/embedding/embedding_provider/local_embedding_provider.py,sha256=4ER-UPq506Y0TWU6qcs0nUqw6bSKQkSrdog-DhNQWM8,1906
|
|
27
|
-
kodit/embedding/embedding_provider/openai_embedding_provider.py,sha256=
|
|
27
|
+
kodit/embedding/embedding_provider/openai_embedding_provider.py,sha256=V_jdUXiaGdslplwxMlfgFc4_hAVS2eaJXMTs2C7RiLI,2666
|
|
28
|
+
kodit/enrichment/__init__.py,sha256=vBEolHpKaHUhfINX0dSGyAPlvgpLNAer9YzFtdvCB24,18
|
|
29
|
+
kodit/enrichment/enrichment_factory.py,sha256=vKjkUTdhj74IW2S4GENDWdWMJx6BwUSZjJGDC0i7DSk,787
|
|
30
|
+
kodit/enrichment/enrichment_service.py,sha256=87Sd3gGbEMJYb_wVrHG8L1yGIZmQNR7foUS4_y94azI,977
|
|
31
|
+
kodit/enrichment/enrichment_provider/__init__.py,sha256=klf8iuLVWX4iRz-DZQauFFNAoJC5CByczh48TBZPW-o,27
|
|
32
|
+
kodit/enrichment/enrichment_provider/enrichment_provider.py,sha256=E0H5rq3OENM0yYbA8K_3nSnj5lUHCpoIOqpWLo-2MVU,413
|
|
33
|
+
kodit/enrichment/enrichment_provider/local_enrichment_provider.py,sha256=bR6HR1gH7wtZdMLOwaKdASjvllRo1FlNW9GyZC11zAM,2164
|
|
34
|
+
kodit/enrichment/enrichment_provider/openai_enrichment_provider.py,sha256=gYuFTAeIVdQNlCUvNSPgRoiRwCvRD0C8419h8ubyABA,2725
|
|
28
35
|
kodit/indexing/__init__.py,sha256=cPyi2Iej3G1JFWlWr7X80_UrsMaTu5W5rBwgif1B3xo,75
|
|
36
|
+
kodit/indexing/fusion.py,sha256=TZb4fPAedXdEUXzwzOofW98QIOymdbclBOP1KOijuEk,1674
|
|
29
37
|
kodit/indexing/indexing_models.py,sha256=6NX9HVcj6Pu9ePwHC7n-PWSyAgukpJq0nCNmUIigtbo,1282
|
|
30
|
-
kodit/indexing/indexing_repository.py,sha256=
|
|
31
|
-
kodit/indexing/indexing_service.py,sha256=
|
|
38
|
+
kodit/indexing/indexing_repository.py,sha256=GYHoACUWYKQdVTwP7tfik_TMUD1WUK76nywH88eCSwg,7006
|
|
39
|
+
kodit/indexing/indexing_service.py,sha256=tKcZpi0pzsmF6OpqnqF0Q5HfSXxi5iLTysrVSou4JiQ,10579
|
|
32
40
|
kodit/migrations/README,sha256=ISVtAOvqvKk_5ThM5ioJE-lMkvf9IbknFUFVU_vPma4,58
|
|
33
41
|
kodit/migrations/__init__.py,sha256=lP5MuwlyWRMO6UcDWnQcQ3G-GYHcFb6rl9gYPHJ1sjo,40
|
|
34
42
|
kodit/migrations/env.py,sha256=w1M7OZh-ZeR2dPHS0ByXAUxQjfZQ8xIzMseWuzLDTWw,2469
|
|
@@ -36,14 +44,12 @@ kodit/migrations/script.py.mako,sha256=zWziKtiwYKEWuwPV_HBNHwa9LCT45_bi01-uSNFaO
|
|
|
36
44
|
kodit/migrations/versions/7c3bbc2ab32b_add_embeddings_table.py,sha256=-61qol9PfQKILCDQRA5jEaats9aGZs9Wdtp-j-38SF4,1644
|
|
37
45
|
kodit/migrations/versions/85155663351e_initial.py,sha256=Cg7zlF871o9ShV5rQMQ1v7hRV7fI59veDY9cjtTrs-8,3306
|
|
38
46
|
kodit/migrations/versions/__init__.py,sha256=9-lHzptItTzq_fomdIRBegQNm4Znx6pVjwD4MiqRIdo,36
|
|
39
|
-
kodit/search/__init__.py,sha256=4QbdjbrlhNKMovmuKHxJnUeZT7KNjTTFU0GdnuwUHdQ,36
|
|
40
|
-
kodit/search/search_repository.py,sha256=6q0k7JMTM_7hPK2TSA30CykGbc5N16kCL7HTjlbai0w,1563
|
|
41
|
-
kodit/search/search_service.py,sha256=-XlbP_50e1dKFJ5jBvex5FjLnffW43LcwQV_SeYNFB0,3944
|
|
42
47
|
kodit/snippets/__init__.py,sha256=-2coNoCRjTixU9KcP6alpmt7zqf37tCRWH3D7FPJ8dg,48
|
|
43
48
|
kodit/snippets/method_snippets.py,sha256=EVHhSNWahAC5nSXv9fWVFJY2yq25goHdCSCuENC07F8,4145
|
|
44
49
|
kodit/snippets/snippets.py,sha256=mwN0bM1Msu8ZeEsUHyQ7tx3Hj3vZsm8G7Wu4eWSkLY8,1539
|
|
45
50
|
kodit/snippets/languages/__init__.py,sha256=Bj5KKZSls2MQ8ZY1S_nHg447MgGZW-2WZM-oq6vjwwA,1187
|
|
46
51
|
kodit/snippets/languages/csharp.scm,sha256=gbBN4RiV1FBuTJF6orSnDFi8H9JwTw-d4piLJYsWUsc,222
|
|
52
|
+
kodit/snippets/languages/go.scm,sha256=SEX9mTOrhP2KiQW7oflDKkd21u5dK56QbJ4LvTDxY8A,533
|
|
47
53
|
kodit/snippets/languages/python.scm,sha256=ee85R9PBzwye3IMTE7-iVoKWd_ViU3EJISTyrFGrVeo,429
|
|
48
54
|
kodit/source/__init__.py,sha256=1NTZyPdjThVQpZO1Mp1ColVsS7sqYanOVLqnoqV9Ipo,83
|
|
49
55
|
kodit/source/source_models.py,sha256=xb42CaNDO1CUB8SIW-xXMrB6Ji8cFw-yeJ550xBEg9Q,2398
|
|
@@ -51,8 +57,8 @@ kodit/source/source_repository.py,sha256=0EksMpoLzdkfe8S4eeCm4Sf7TuxsOzOzaF4BBsM
|
|
|
51
57
|
kodit/source/source_service.py,sha256=u_GaH07ewakThQJRfT8O_yZ54A52qLtJuM1bF3xUT2A,9633
|
|
52
58
|
kodit/util/__init__.py,sha256=bPu6CtqDWCRGU7VgW2_aiQrCBi8G89FS6k1PjvDajJ0,37
|
|
53
59
|
kodit/util/spinner.py,sha256=R9bzrHtBiIH6IfLbmsIVHL53s8vg-tqW4lwGGALu4dw,1932
|
|
54
|
-
kodit-0.1.
|
|
55
|
-
kodit-0.1.
|
|
56
|
-
kodit-0.1.
|
|
57
|
-
kodit-0.1.
|
|
58
|
-
kodit-0.1.
|
|
60
|
+
kodit-0.1.16.dist-info/METADATA,sha256=1lR4ZSTiRBzUv9Gj8FPspv4GU2vWGQU6HSiffWgU2Do,2467
|
|
61
|
+
kodit-0.1.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
62
|
+
kodit-0.1.16.dist-info/entry_points.txt,sha256=hoTn-1aKyTItjnY91fnO-rV5uaWQLQ-Vi7V5et2IbHY,40
|
|
63
|
+
kodit-0.1.16.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
64
|
+
kodit-0.1.16.dist-info/RECORD,,
|
kodit/search/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
"""Search for relevant snippets."""
|
|
@@ -1,57 +0,0 @@
|
|
|
1
|
-
"""Repository for searching for relevant snippets."""
|
|
2
|
-
|
|
3
|
-
from typing import TypeVar
|
|
4
|
-
|
|
5
|
-
from sqlalchemy import (
|
|
6
|
-
select,
|
|
7
|
-
)
|
|
8
|
-
from sqlalchemy.ext.asyncio import AsyncSession
|
|
9
|
-
|
|
10
|
-
from kodit.indexing.indexing_models import Snippet
|
|
11
|
-
from kodit.source.source_models import File
|
|
12
|
-
|
|
13
|
-
T = TypeVar("T")
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class SearchRepository:
|
|
17
|
-
"""Repository for searching for relevant snippets."""
|
|
18
|
-
|
|
19
|
-
def __init__(self, session: AsyncSession) -> None:
|
|
20
|
-
"""Initialize the search repository.
|
|
21
|
-
|
|
22
|
-
Args:
|
|
23
|
-
session: The SQLAlchemy async session to use for database operations.
|
|
24
|
-
|
|
25
|
-
"""
|
|
26
|
-
self.session = session
|
|
27
|
-
|
|
28
|
-
async def list_snippet_ids(self) -> list[int]:
|
|
29
|
-
"""List all snippet IDs.
|
|
30
|
-
|
|
31
|
-
Returns:
|
|
32
|
-
A list of all snippets.
|
|
33
|
-
|
|
34
|
-
"""
|
|
35
|
-
query = select(Snippet.id)
|
|
36
|
-
rows = await self.session.execute(query)
|
|
37
|
-
return list(rows.scalars().all())
|
|
38
|
-
|
|
39
|
-
async def list_snippets_by_ids(self, ids: list[int]) -> list[tuple[File, Snippet]]:
|
|
40
|
-
"""List snippets by IDs.
|
|
41
|
-
|
|
42
|
-
Returns:
|
|
43
|
-
A list of snippets in the same order as the input IDs.
|
|
44
|
-
|
|
45
|
-
"""
|
|
46
|
-
query = (
|
|
47
|
-
select(Snippet, File)
|
|
48
|
-
.where(Snippet.id.in_(ids))
|
|
49
|
-
.join(File, Snippet.file_id == File.id)
|
|
50
|
-
)
|
|
51
|
-
rows = await self.session.execute(query)
|
|
52
|
-
|
|
53
|
-
# Create a dictionary for O(1) lookup of results by ID
|
|
54
|
-
id_to_result = {snippet.id: (file, snippet) for snippet, file in rows.all()}
|
|
55
|
-
|
|
56
|
-
# Return results in the same order as input IDs
|
|
57
|
-
return [id_to_result[i] for i in ids]
|
kodit/search/search_service.py
DELETED
|
@@ -1,135 +0,0 @@
|
|
|
1
|
-
"""Search service."""
|
|
2
|
-
|
|
3
|
-
import pydantic
|
|
4
|
-
import structlog
|
|
5
|
-
|
|
6
|
-
from kodit.bm25.keyword_search_service import BM25Result, KeywordSearchProvider
|
|
7
|
-
from kodit.embedding.vector_search_service import VectorSearchService
|
|
8
|
-
from kodit.search.search_repository import SearchRepository
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
class SearchRequest(pydantic.BaseModel):
|
|
12
|
-
"""Request for a search."""
|
|
13
|
-
|
|
14
|
-
code_query: str | None = None
|
|
15
|
-
keywords: list[str] | None = None
|
|
16
|
-
top_k: int = 10
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class SearchResult(pydantic.BaseModel):
|
|
20
|
-
"""Data transfer object for search results.
|
|
21
|
-
|
|
22
|
-
This model represents a single search result, containing both the file path
|
|
23
|
-
and the matching snippet content.
|
|
24
|
-
"""
|
|
25
|
-
|
|
26
|
-
id: int
|
|
27
|
-
uri: str
|
|
28
|
-
content: str
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
class Snippet(pydantic.BaseModel):
|
|
32
|
-
"""Snippet model."""
|
|
33
|
-
|
|
34
|
-
content: str
|
|
35
|
-
file_path: str
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
class SearchService:
|
|
39
|
-
"""Service for searching for relevant data."""
|
|
40
|
-
|
|
41
|
-
def __init__(
|
|
42
|
-
self,
|
|
43
|
-
repository: SearchRepository,
|
|
44
|
-
keyword_search_provider: KeywordSearchProvider,
|
|
45
|
-
embedding_service: VectorSearchService,
|
|
46
|
-
) -> None:
|
|
47
|
-
"""Initialize the search service."""
|
|
48
|
-
self.repository = repository
|
|
49
|
-
self.log = structlog.get_logger(__name__)
|
|
50
|
-
self.keyword_search_provider = keyword_search_provider
|
|
51
|
-
self.code_embedding_service = embedding_service
|
|
52
|
-
|
|
53
|
-
async def search(self, request: SearchRequest) -> list[SearchResult]:
|
|
54
|
-
"""Search for relevant data."""
|
|
55
|
-
fusion_list = []
|
|
56
|
-
if request.keywords:
|
|
57
|
-
# Gather results for each keyword
|
|
58
|
-
result_ids: list[BM25Result] = []
|
|
59
|
-
for keyword in request.keywords:
|
|
60
|
-
results = await self.keyword_search_provider.retrieve(
|
|
61
|
-
keyword, request.top_k
|
|
62
|
-
)
|
|
63
|
-
result_ids.extend(results)
|
|
64
|
-
|
|
65
|
-
# Sort results by score
|
|
66
|
-
result_ids.sort(key=lambda x: x[1], reverse=True)
|
|
67
|
-
|
|
68
|
-
self.log.debug("Search results (BM25)", results=result_ids)
|
|
69
|
-
|
|
70
|
-
bm25_results = [x[0] for x in result_ids]
|
|
71
|
-
fusion_list.append(bm25_results)
|
|
72
|
-
|
|
73
|
-
# Compute embedding for semantic query
|
|
74
|
-
semantic_results = []
|
|
75
|
-
if request.code_query:
|
|
76
|
-
query_embedding = await self.code_embedding_service.retrieve(
|
|
77
|
-
request.code_query, top_k=request.top_k
|
|
78
|
-
)
|
|
79
|
-
semantic_results = [x.snippet_id for x in query_embedding]
|
|
80
|
-
fusion_list.append(semantic_results)
|
|
81
|
-
|
|
82
|
-
if len(fusion_list) == 0:
|
|
83
|
-
return []
|
|
84
|
-
|
|
85
|
-
# Combine all results together with RFF if required
|
|
86
|
-
final_results = reciprocal_rank_fusion(fusion_list, k=60)
|
|
87
|
-
|
|
88
|
-
# Extract ids from final results
|
|
89
|
-
final_ids = [x[0] for x in final_results]
|
|
90
|
-
|
|
91
|
-
# Get snippets from database (up to top_k)
|
|
92
|
-
search_results = await self.repository.list_snippets_by_ids(
|
|
93
|
-
final_ids[: request.top_k]
|
|
94
|
-
)
|
|
95
|
-
|
|
96
|
-
return [
|
|
97
|
-
SearchResult(
|
|
98
|
-
id=snippet.id,
|
|
99
|
-
uri=file.uri,
|
|
100
|
-
content=snippet.content,
|
|
101
|
-
)
|
|
102
|
-
for file, snippet in search_results
|
|
103
|
-
]
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
def reciprocal_rank_fusion(
|
|
107
|
-
rankings: list[list[int]], k: float = 60
|
|
108
|
-
) -> list[tuple[int, float]]:
|
|
109
|
-
"""RRF prioritises results that are present in all results.
|
|
110
|
-
|
|
111
|
-
Args:
|
|
112
|
-
rankings: List of rankers, each containing a list of document ids. Top of the
|
|
113
|
-
list is considered to be the best result.
|
|
114
|
-
k: Parameter for RRF.
|
|
115
|
-
|
|
116
|
-
Returns:
|
|
117
|
-
Dictionary of ids and their scores.
|
|
118
|
-
|
|
119
|
-
"""
|
|
120
|
-
scores = {}
|
|
121
|
-
for ranker in rankings:
|
|
122
|
-
for rank in ranker:
|
|
123
|
-
scores[rank] = float(0)
|
|
124
|
-
|
|
125
|
-
for ranker in rankings:
|
|
126
|
-
for i, rank in enumerate(ranker):
|
|
127
|
-
scores[rank] += 1.0 / (k + i)
|
|
128
|
-
|
|
129
|
-
# Create a list of tuples of ids and their scores
|
|
130
|
-
results = [(rank, scores[rank]) for rank in scores]
|
|
131
|
-
|
|
132
|
-
# Sort results by score
|
|
133
|
-
results.sort(key=lambda x: x[1], reverse=True)
|
|
134
|
-
|
|
135
|
-
return results
|
|
File without changes
|
|
File without changes
|
|
File without changes
|