kodit 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kodit might be problematic.
- kodit/_version.py +2 -2
- kodit/bm25/keyword_search_factory.py +17 -0
- kodit/bm25/keyword_search_service.py +34 -0
- kodit/bm25/{bm25.py → local_bm25.py} +40 -14
- kodit/bm25/vectorchord_bm25.py +193 -0
- kodit/cli.py +114 -25
- kodit/config.py +9 -2
- kodit/database.py +4 -2
- kodit/embedding/embedding_factory.py +44 -0
- kodit/embedding/embedding_provider/__init__.py +1 -0
- kodit/embedding/embedding_provider/embedding_provider.py +60 -0
- kodit/embedding/embedding_provider/hash_embedding_provider.py +77 -0
- kodit/embedding/embedding_provider/local_embedding_provider.py +58 -0
- kodit/embedding/embedding_provider/openai_embedding_provider.py +75 -0
- kodit/{search/search_repository.py → embedding/embedding_repository.py} +61 -33
- kodit/embedding/local_vector_search_service.py +50 -0
- kodit/embedding/vector_search_service.py +38 -0
- kodit/embedding/vectorchord_vector_search_service.py +154 -0
- kodit/enrichment/__init__.py +1 -0
- kodit/enrichment/enrichment_factory.py +23 -0
- kodit/enrichment/enrichment_provider/__init__.py +1 -0
- kodit/enrichment/enrichment_provider/enrichment_provider.py +16 -0
- kodit/enrichment/enrichment_provider/local_enrichment_provider.py +63 -0
- kodit/enrichment/enrichment_provider/openai_enrichment_provider.py +77 -0
- kodit/enrichment/enrichment_service.py +33 -0
- kodit/indexing/fusion.py +67 -0
- kodit/indexing/indexing_repository.py +44 -4
- kodit/indexing/indexing_service.py +142 -31
- kodit/mcp.py +31 -18
- kodit/snippets/languages/go.scm +26 -0
- kodit/source/source_service.py +9 -3
- kodit/util/__init__.py +1 -0
- kodit/util/spinner.py +59 -0
- {kodit-0.1.14.dist-info → kodit-0.1.16.dist-info}/METADATA +4 -1
- kodit-0.1.16.dist-info/RECORD +64 -0
- kodit/embedding/embedding.py +0 -203
- kodit/search/__init__.py +0 -1
- kodit/search/search_service.py +0 -147
- kodit-0.1.14.dist-info/RECORD +0 -44
- {kodit-0.1.14.dist-info → kodit-0.1.16.dist-info}/WHEEL +0 -0
- {kodit-0.1.14.dist-info → kodit-0.1.16.dist-info}/entry_points.txt +0 -0
- {kodit-0.1.14.dist-info → kodit-0.1.16.dist-info}/licenses/LICENSE +0 -0

kodit/enrichment/enrichment_provider/enrichment_provider.py
ADDED

@@ -0,0 +1,16 @@
+"""Enrichment provider."""
+
+from abc import ABC, abstractmethod
+
+ENRICHMENT_SYSTEM_PROMPT = """
+You are a professional software developer. You will be given a snippet of code.
+Please provide a concise explanation of the code.
+"""
+
+
+class EnrichmentProvider(ABC):
+    """Enrichment provider."""
+
+    @abstractmethod
+    async def enrich(self, data: list[str]) -> list[str]:
+        """Enrich a list of strings."""
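
The contract above is small: a provider only has to implement enrich. A minimal sketch of a custom implementation against this interface (the UppercaseEnrichmentProvider class is hypothetical, for illustration only; assumes kodit >= 0.1.16 is installed):

import asyncio

from kodit.enrichment.enrichment_provider.enrichment_provider import (
    EnrichmentProvider,
)


class UppercaseEnrichmentProvider(EnrichmentProvider):
    """Toy provider: 'enriches' each snippet with an upper-cased echo."""

    async def enrich(self, data: list[str]) -> list[str]:
        # One output string per input string, in the same order.
        return [f"SNIPPET ({len(s)} chars): {s.upper()}" for s in data]


print(asyncio.run(UppercaseEnrichmentProvider().enrich(["def f(): pass"])))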

kodit/enrichment/enrichment_provider/local_enrichment_provider.py
ADDED

@@ -0,0 +1,63 @@
+"""Local embedding service."""
+
+import os
+
+import structlog
+from transformers.models.auto.modeling_auto import AutoModelForCausalLM
+from transformers.models.auto.tokenization_auto import AutoTokenizer
+
+from kodit.enrichment.enrichment_provider.enrichment_provider import (
+    ENRICHMENT_SYSTEM_PROMPT,
+    EnrichmentProvider,
+)
+
+
+class LocalEnrichmentProvider(EnrichmentProvider):
+    """Local embedder."""
+
+    def __init__(self, model_name: str = "Qwen/Qwen3-0.6B") -> None:
+        """Initialize the local enrichment provider."""
+        self.log = structlog.get_logger(__name__)
+        self.model_name = model_name
+        self.model = None
+        self.tokenizer = None
+
+    async def enrich(self, data: list[str]) -> list[str]:
+        """Enrich a list of strings."""
+        if self.tokenizer is None:
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        if self.model is None:
+            os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Avoid warnings
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.model_name,
+                torch_dtype="auto",
+                trust_remote_code=True,
+            )
+
+        results = []
+        for snippet in data:
+            # prepare the model input
+            messages = [
+                {"role": "system", "content": ENRICHMENT_SYSTEM_PROMPT},
+                {"role": "user", "content": snippet},
+            ]
+            text = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+                enable_thinking=False,
+            )
+            model_inputs = self.tokenizer([text], return_tensors="pt").to(
+                self.model.device
+            )
+
+            # conduct text completion
+            generated_ids = self.model.generate(**model_inputs, max_new_tokens=32768)
+            output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
+            content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip(
+                "\n"
+            )
+
+            results.append(content)
+
+        return results
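
Driving this provider directly looks roughly like the following (a sketch: it assumes kodit >= 0.1.16 plus torch and transformers are installed; the default Qwen/Qwen3-0.6B weights are fetched from the Hugging Face Hub on first use):

import asyncio

from kodit.enrichment.enrichment_provider.local_enrichment_provider import (
    LocalEnrichmentProvider,
)

# Model and tokenizer are loaded lazily on the first enrich() call.
provider = LocalEnrichmentProvider()  # or LocalEnrichmentProvider(model_name=...)
summaries = asyncio.run(provider.enrich(["def add(a, b):\n    return a + b"]))
print(summaries[0])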

kodit/enrichment/enrichment_provider/openai_enrichment_provider.py
ADDED

@@ -0,0 +1,77 @@
+"""OpenAI embedding service."""
+
+import asyncio
+
+import structlog
+import tiktoken
+from openai import AsyncOpenAI
+from tqdm import tqdm
+
+from kodit.enrichment.enrichment_provider.enrichment_provider import (
+    ENRICHMENT_SYSTEM_PROMPT,
+    EnrichmentProvider,
+)
+
+OPENAI_NUM_PARALLEL_TASKS = 10
+
+
+class OpenAIEnrichmentProvider(EnrichmentProvider):
+    """OpenAI enrichment provider."""
+
+    def __init__(
+        self,
+        openai_client: AsyncOpenAI,
+        model_name: str = "gpt-4o-mini",
+    ) -> None:
+        """Initialize the OpenAI enrichment provider."""
+        self.log = structlog.get_logger(__name__)
+        self.openai_client = openai_client
+        self.model_name = model_name
+        self.encoding = tiktoken.encoding_for_model(model_name)
+
+    async def enrich(self, data: list[str]) -> list[str]:
+        """Enrich a list of documents."""
+        # Process batches in parallel with a semaphore to limit concurrent requests
+        sem = asyncio.Semaphore(OPENAI_NUM_PARALLEL_TASKS)
+
+        # Create a list of tuples with a temporary id for each snippet
+        # We need to do this so that we can return the results in the same order as the
+        # input data
+        input_data = [(i, snippet) for i, snippet in enumerate(data)]
+
+        async def process_data(data: tuple[int, str]) -> tuple[int, str]:
+            snippet_id, snippet = data
+            if not snippet:
+                return snippet_id, ""
+            async with sem:
+                try:
+                    response = await self.openai_client.chat.completions.create(
+                        model=self.model_name,
+                        messages=[
+                            {
+                                "role": "system",
+                                "content": ENRICHMENT_SYSTEM_PROMPT,
+                            },
+                            {"role": "user", "content": snippet},
+                        ],
+                    )
+                    return snippet_id, response.choices[0].message.content or ""
+                except Exception as e:
+                    self.log.exception("Error enriching data", error=str(e))
+                    return snippet_id, ""
+
+        # Create tasks for all data
+        tasks = [process_data(snippet) for snippet in input_data]
+
+        # Process all data and yield results as they complete
+        results: list[tuple[int, str]] = []
+        for task in tqdm(
+            asyncio.as_completed(tasks),
+            total=len(tasks),
+            leave=False,
+        ):
+            result = await task
+            results.append(result)
+
+        # Output in the same order as the input data
+        return [result for _, result in sorted(results, key=lambda x: x[0])]
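
Wiring the OpenAI-backed provider up is similar (a sketch; assumes kodit >= 0.1.16 and an OPENAI_API_KEY in the environment, which AsyncOpenAI() reads by default):

import asyncio

from openai import AsyncOpenAI

from kodit.enrichment.enrichment_provider.openai_enrichment_provider import (
    OpenAIEnrichmentProvider,
)

provider = OpenAIEnrichmentProvider(AsyncOpenAI(), model_name="gpt-4o-mini")
# Empty snippets short-circuit to ""; failed requests are logged and also
# come back as "", so the output list always matches the input length.
print(asyncio.run(provider.enrich(["print('hello')"]))[0])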

kodit/enrichment/enrichment_service.py
ADDED

@@ -0,0 +1,33 @@
+"""Enrichment service."""
+
+from abc import ABC, abstractmethod
+
+from kodit.enrichment.enrichment_provider.enrichment_provider import EnrichmentProvider
+
+
+class EnrichmentService(ABC):
+    """Enrichment service."""
+
+    @abstractmethod
+    async def enrich(self, data: list[str]) -> list[str]:
+        """Enrich a list of strings."""
+
+
+class NullEnrichmentService(EnrichmentService):
+    """Null enrichment service."""
+
+    async def enrich(self, data: list[str]) -> list[str]:
+        """Enrich a list of strings."""
+        return [""] * len(data)
+
+
+class LLMEnrichmentService(EnrichmentService):
+    """Enrichment service using an LLM."""
+
+    def __init__(self, enrichment_provider: EnrichmentProvider) -> None:
+        """Initialize the enrichment service."""
+        self.enrichment_provider = enrichment_provider
+
+    async def enrich(self, data: list[str]) -> list[str]:
+        """Enrich a list of strings."""
+        return await self.enrichment_provider.enrich(data)
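
How these pieces are likely meant to compose, as a sketch (the released wiring lives in the new kodit/enrichment/enrichment_factory.py, whose body is not shown in this view; assumes kodit >= 0.1.16):

from kodit.enrichment.enrichment_provider.local_enrichment_provider import (
    LocalEnrichmentProvider,
)
from kodit.enrichment.enrichment_service import (
    LLMEnrichmentService,
    NullEnrichmentService,
)

# LLM-backed enrichment, delegating to any EnrichmentProvider implementation.
service = LLMEnrichmentService(LocalEnrichmentProvider())

# Opting out: NullEnrichmentService returns one empty string per input.
no_op = NullEnrichmentService()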
kodit/indexing/fusion.py
ADDED

@@ -0,0 +1,67 @@
+"""Fusion functions for combining search results."""
+
+from collections import defaultdict
+from dataclasses import dataclass
+
+
+@dataclass
+class FusionResult:
+    """Result of a fusion operation."""
+
+    id: int
+    score: float
+    original_scores: list[float]
+
+
+@dataclass
+class FusionRequest:
+    """Result of a RRF operation."""
+
+    id: int
+    score: float
+
+
+def reciprocal_rank_fusion(
+    rankings: list[list[FusionRequest]], k: float = 60
+) -> list[FusionResult]:
+    """RRF prioritises results that are present in all results.
+
+    Args:
+        rankings: List of rankers, each containing a list of document ids. Top of the
+            list is considered to be the best result.
+        k: Parameter for RRF.
+
+    Returns:
+        Dictionary of ids and their scores.
+
+    """
+    scores = {}
+    for ranker in rankings:
+        for rank in ranker:
+            scores[rank.id] = float(0)
+
+    for ranker in rankings:
+        for i, rank in enumerate(ranker):
+            scores[rank.id] += 1.0 / (k + i)
+
+    # Create a list of tuples of ids and their scores
+    results = [(rank, scores[rank]) for rank in scores]
+
+    # Sort results by score
+    results.sort(key=lambda x: x[1], reverse=True)
+
+    # Create a map of original scores to ids
+    original_scores_to_ids = defaultdict(list)
+    for ranker in rankings:
+        for rank in ranker:
+            original_scores_to_ids[rank.id].append(rank.score)
+
+    # Rebuild a list of final results with their original scores
+    return [
+        FusionResult(
+            id=result[0],
+            score=result[1],
+            original_scores=original_scores_to_ids[result[0]],
+        )
+        for result in results
+    ]
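
A small worked example of the fused scoring (runnable assuming kodit >= 0.1.16 is installed; the ids and scores are made up for illustration):

from kodit.indexing.fusion import FusionRequest, reciprocal_rank_fusion

# Two rankers, best result first. Only rank positions feed the fused score;
# the original scores are just carried through for reporting.
bm25 = [FusionRequest(id=1, score=12.3), FusionRequest(id=2, score=9.1)]
vector = [FusionRequest(id=2, score=0.91), FusionRequest(id=3, score=0.80)]

for r in reciprocal_rank_fusion([bm25, vector], k=60):
    print(r.id, round(r.score, 5), r.original_scores)

# id=2 wins because it appears in both rankings:
#   1/(60+1) + 1/(60+0) ≈ 0.03306, vs ≈ 0.01667 (id=1) and ≈ 0.01639 (id=3).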
kodit/indexing/indexing_repository.py
CHANGED

@@ -10,6 +10,7 @@ from typing import TypeVar
 
 from sqlalchemy import delete, func, select
 from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm.exc import MultipleResultsFound
 
 from kodit.embedding.embedding_models import Embedding
 from kodit.indexing.indexing_models import Index, Snippet

@@ -124,15 +125,34 @@ class IndexRepository:
         index.updated_at = datetime.now(UTC)
         await self.session.commit()
 
-    async def
-        """Add a new snippet to the database.
+    async def add_snippet_or_update_content(self, snippet: Snippet) -> None:
+        """Add a new snippet to the database if it doesn't exist, otherwise update it.
 
         Args:
             snippet: The Snippet instance to add.
 
         """
-
-
+        query = select(Snippet).where(
+            Snippet.file_id == snippet.file_id,
+            Snippet.index_id == snippet.index_id,
+        )
+        result = await self.session.execute(query)
+        try:
+            existing_snippet = result.scalar_one_or_none()
+
+            if existing_snippet:
+                existing_snippet.content = snippet.content
+            else:
+                self.session.add(snippet)
+
+            await self.session.commit()
+        except MultipleResultsFound as e:
+            msg = (
+                f"Multiple snippets found for file_id {snippet.file_id}, this "
+                "shouldn't happen. "
+                "Please report this as a bug then delete your index and start again."
+            )
+            raise ValueError(msg) from e
 
     async def delete_all_snippets(self, index_id: int) -> None:
         """Delete all snippets for an index.

@@ -176,3 +196,23 @@ class IndexRepository:
         """
         self.session.add(embedding)
         await self.session.commit()
+
+    async def list_snippets_by_ids(self, ids: list[int]) -> list[tuple[File, Snippet]]:
+        """List snippets by IDs.
+
+        Returns:
+            A list of snippets in the same order as the input IDs.
+
+        """
+        query = (
+            select(Snippet, File)
+            .where(Snippet.id.in_(ids))
+            .join(File, Snippet.file_id == File.id)
+        )
+        rows = await self.session.execute(query)
+
+        # Create a dictionary for O(1) lookup of results by ID
+        id_to_result = {snippet.id: (file, snippet) for snippet, file in rows.all()}
+
+        # Return results in the same order as input IDs
+        return [id_to_result[i] for i in ids]
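
The order-preserving trick in list_snippets_by_ids, shown in isolation: SQL IN (...) returns rows in no guaranteed order, so rows are re-sequenced through a dict keyed by id (the string values below are stand-ins for the real (File, Snippet) tuples):

ids = [7, 3, 5]  # order requested by the caller
rows = [(3, "snippet-3"), (5, "snippet-5"), (7, "snippet-7")]  # arbitrary DB order

id_to_result = {row_id: content for row_id, content in rows}  # O(1) lookup
ordered = [id_to_result[i] for i in ids]

assert ordered == ["snippet-7", "snippet-3", "snippet-5"]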
kodit/indexing/indexing_service.py
CHANGED

@@ -13,13 +13,22 @@ import pydantic
 import structlog
 from tqdm.asyncio import tqdm
 
-from kodit.bm25.
-
-
+from kodit.bm25.keyword_search_service import (
+    BM25Document,
+    BM25Result,
+    KeywordSearchProvider,
+)
+from kodit.embedding.vector_search_service import (
+    VectorSearchRequest,
+    VectorSearchService,
+)
+from kodit.enrichment.enrichment_service import EnrichmentService
+from kodit.indexing.fusion import FusionRequest, reciprocal_rank_fusion
 from kodit.indexing.indexing_models import Snippet
 from kodit.indexing.indexing_repository import IndexRepository
 from kodit.snippets.snippets import SnippetService
 from kodit.source.source_service import SourceService
+from kodit.util.spinner import Spinner
 
 # List of MIME types that are blacklisted from being indexed
 MIME_BLACKLIST = ["unknown/unknown"]

@@ -39,6 +48,28 @@ class IndexView(pydantic.BaseModel):
     num_snippets: int | None = None
 
 
+class SearchRequest(pydantic.BaseModel):
+    """Request for a search."""
+
+    text_query: str | None = None
+    code_query: str | None = None
+    keywords: list[str] | None = None
+    top_k: int = 10
+
+
+class SearchResult(pydantic.BaseModel):
+    """Data transfer object for search results.
+
+    This model represents a single search result, containing both the file path
+    and the matching snippet content.
+    """
+
+    id: int
+    uri: str
+    content: str
+    original_scores: list[float]
+
+
 class IndexService:
     """Service for managing code indexes.
 

@@ -47,12 +78,14 @@ class IndexService:
     IndexRepository), and provides a clean API for index management.
     """
 
-    def __init__(
+    def __init__(  # noqa: PLR0913
         self,
         repository: IndexRepository,
         source_service: SourceService,
-
-
+        keyword_search_provider: KeywordSearchProvider,
+        code_search_service: VectorSearchService,
+        text_search_service: VectorSearchService,
+        enrichment_service: EnrichmentService,
     ) -> None:
         """Initialize the index service.
 

@@ -65,8 +98,10 @@ class IndexService:
         self.source_service = source_service
         self.snippet_service = SnippetService()
         self.log = structlog.get_logger(__name__)
-        self.
-        self.
+        self.keyword_search_provider = keyword_search_provider
+        self.code_search_service = code_search_service
+        self.text_search_service = text_search_service
+        self.enrichment_service = enrichment_service
 
     async def create(self, source_id: int) -> IndexView:
         """Create a new index for a source.

@@ -126,41 +161,116 @@ class IndexService:
             msg = f"Index not found: {index_id}"
             raise ValueError(msg)
 
-        # First delete all old snippets, if they exist
-        await self.repository.delete_all_snippets(index_id)
-
         # Create snippets for supported file types
         await self._create_snippets(index_id)
 
         snippets = await self.repository.get_all_snippets(index_id)
 
         self.log.info("Creating keyword index")
-
-
-
-
-
-
+        with Spinner():
+            await self.keyword_search_provider.index(
+                [
+                    BM25Document(snippet_id=snippet.id, text=snippet.content)
+                    for snippet in snippets
+                ]
+            )
 
         self.log.info("Creating semantic code index")
-
-        self.
-            [
-
-
-
-        ):
-            await self.repository.add_embedding(
-                Embedding(
-                    snippet_id=e.id,
-                    embedding=e.embedding,
-                    type=EmbeddingType.CODE,
-                )
+        with Spinner():
+            await self.code_search_service.index(
+                [
+                    VectorSearchRequest(snippet.id, snippet.content)
+                    for snippet in snippets
+                ]
             )
 
+        self.log.info("Enriching snippets")
+        enriched_contents = await self.enrichment_service.enrich(
+            [snippet.content for snippet in snippets]
+        )
+
+        self.log.info("Creating semantic text index")
+        with Spinner():
+            await self.text_search_service.index(
+                [
+                    VectorSearchRequest(snippet.id, enriched_content)
+                    for snippet, enriched_content in zip(
+                        snippets, enriched_contents, strict=True
+                    )
+                ]
+            )
+        # Add the enriched text back to the snippets and write to the database
+        for snippet, enriched_content in zip(
+            snippets, enriched_contents, strict=True
+        ):
+            snippet.content = (
+                enriched_content + "\n\n```\n" + snippet.content + "\n```"
+            )
+            await self.repository.add_snippet_or_update_content(snippet)
+
         # Update index timestamp
         await self.repository.update_index_timestamp(index)
 
+    async def search(self, request: SearchRequest) -> list[SearchResult]:
+        """Search for relevant data."""
+        fusion_list: list[list[FusionRequest]] = []
+        if request.keywords:
+            # Gather results for each keyword
+            result_ids: list[BM25Result] = []
+            for keyword in request.keywords:
+                results = await self.keyword_search_provider.retrieve(
+                    keyword, request.top_k
+                )
+                result_ids.extend(results)
+
+            fusion_list.append(
+                [FusionRequest(id=x.snippet_id, score=x.score) for x in result_ids]
+            )
+
+        # Compute embedding for semantic query
+        if request.code_query:
+            query_embedding = await self.code_search_service.retrieve(
+                request.code_query, top_k=request.top_k
+            )
+            fusion_list.append(
+                [FusionRequest(id=x.snippet_id, score=x.score) for x in query_embedding]
+            )
+
+        if request.text_query:
+            query_embedding = await self.text_search_service.retrieve(
+                request.text_query, top_k=request.top_k
+            )
+            fusion_list.append(
+                [FusionRequest(id=x.snippet_id, score=x.score) for x in query_embedding]
+            )
+
+        if len(fusion_list) == 0:
+            return []
+
+        # Combine all results together with RRF if required
+        final_results = reciprocal_rank_fusion(
+            rankings=fusion_list,
+            k=60,
+        )
+
+        # Only keep top_k results
+        final_results = final_results[: request.top_k]
+
+        # Get snippets from database (up to top_k)
+        search_results = await self.repository.list_snippets_by_ids(
+            [x.id for x in final_results]
+        )
+
+        return [
+            SearchResult(
+                id=snippet.id,
+                uri=file.uri,
+                content=snippet.content,
+                original_scores=fr.original_scores,
+            )
+            for (file, snippet), fr in zip(search_results, final_results, strict=True)
+        ]
+
     async def _create_snippets(
         self,
         index_id: int,

@@ -174,6 +284,7 @@ class IndexService:
 
         """
         files = await self.repository.files_for_index(index_id)
+        self.log.info("Creating snippets for files", index_id=index_id)
         for file in tqdm(files, total=len(files), leave=False):
             # Skip unsupported file types
             if file.mime_type in MIME_BLACKLIST:

@@ -195,4 +306,4 @@ class IndexService:
                     file_id=file.id,
                     content=snippet.text,
                 )
-                await self.repository.
+                await self.repository.add_snippet_or_update_content(s)
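
The shape of the new hybrid-search API, as a request-construction sketch (SearchRequest is a plain pydantic model; building the IndexService itself takes the factory wiring shown in kodit/mcp.py below; the query values are illustrative):

from kodit.indexing.indexing_service import SearchRequest

request = SearchRequest(
    keywords=["bm25", "tokenize"],          # routed to the keyword (BM25) index
    code_query="def tokenize(text):",       # routed to the code embedding index
    text_query="how documents are ranked",  # routed to the text embedding index
    top_k=10,
)
# service.search(request) runs each configured retriever, fuses the ranked
# lists with reciprocal_rank_fusion, trims to top_k, and returns SearchResult
# entries whose original_scores carry the per-retriever scores.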
kodit/mcp.py
CHANGED

@@ -12,11 +12,15 @@ from pydantic import Field
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from kodit._version import version
+from kodit.bm25.keyword_search_factory import keyword_search_factory
 from kodit.config import AppContext
 from kodit.database import Database
-from kodit.embedding.
-from kodit.
-from kodit.
+from kodit.embedding.embedding_factory import embedding_factory
+from kodit.enrichment.enrichment_factory import enrichment_factory
+from kodit.indexing.indexing_repository import IndexRepository
+from kodit.indexing.indexing_service import IndexService, SearchRequest, SearchResult
+from kodit.source.source_repository import SourceRepository
+from kodit.source.source_service import SourceService
 
 
 @dataclass

@@ -122,29 +126,38 @@ async def search(
 
     mcp_context: MCPContext = ctx.request_context.lifespan_context
 
-
-
-
+    source_repository = SourceRepository(mcp_context.session)
+    source_service = SourceService(
+        mcp_context.app_context.get_clone_dir(), source_repository
     )
-
-
-
-
-
-
-
-
-
-
-
+    repository = IndexRepository(mcp_context.session)
+    service = IndexService(
+        repository=repository,
+        source_service=source_service,
+        keyword_search_provider=keyword_search_factory(
+            mcp_context.app_context, mcp_context.session
+        ),
+        code_search_service=embedding_factory(
+            task_name="code",
+            app_context=mcp_context.app_context,
+            session=mcp_context.session,
+        ),
+        text_search_service=embedding_factory(
+            task_name="text",
+            app_context=mcp_context.app_context,
+            session=mcp_context.session,
+        ),
+        enrichment_service=enrichment_factory(mcp_context.app_context),
     )
 
     search_request = SearchRequest(
         keywords=keywords,
         code_query="\n".join(related_file_contents),
+        text_query=user_intent,
     )
+
     log.debug("Searching for snippets")
-    snippets = await
+    snippets = await service.search(request=search_request)
 
     log.debug("Fusing output")
     output = output_fusion(snippets=snippets)

kodit/snippets/languages/go.scm
ADDED

@@ -0,0 +1,26 @@
+(function_declaration
+  name: (identifier) @function.name
+  body: (block) @function.body
+) @function.def
+
+(method_declaration
+  name: (field_identifier) @method.name
+  body: (block) @method.body
+) @method.def
+
+(import_declaration
+  (import_spec
+    path: (interpreted_string_literal) @import.name
+  )
+) @import.statement
+
+(identifier) @ident
+
+(parameter_declaration
+  name: (identifier) @param.name
+)
+
+(package_clause "package" (package_identifier) @name.definition.module)
+
+;; Exclude comments from being captured
+(comment) @comment