kodit 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kodit might be problematic. Click here for more details.
- kodit/_version.py +2 -2
- kodit/bm25/keyword_search_factory.py +17 -0
- kodit/bm25/keyword_search_service.py +34 -0
- kodit/bm25/{bm25.py → local_bm25.py} +40 -14
- kodit/bm25/vectorchord_bm25.py +193 -0
- kodit/cli.py +114 -25
- kodit/config.py +9 -2
- kodit/database.py +4 -2
- kodit/embedding/embedding_factory.py +44 -0
- kodit/embedding/embedding_provider/__init__.py +1 -0
- kodit/embedding/embedding_provider/embedding_provider.py +60 -0
- kodit/embedding/embedding_provider/hash_embedding_provider.py +77 -0
- kodit/embedding/embedding_provider/local_embedding_provider.py +58 -0
- kodit/embedding/embedding_provider/openai_embedding_provider.py +75 -0
- kodit/{search/search_repository.py → embedding/embedding_repository.py} +61 -33
- kodit/embedding/local_vector_search_service.py +50 -0
- kodit/embedding/vector_search_service.py +38 -0
- kodit/embedding/vectorchord_vector_search_service.py +154 -0
- kodit/enrichment/__init__.py +1 -0
- kodit/enrichment/enrichment_factory.py +23 -0
- kodit/enrichment/enrichment_provider/__init__.py +1 -0
- kodit/enrichment/enrichment_provider/enrichment_provider.py +16 -0
- kodit/enrichment/enrichment_provider/local_enrichment_provider.py +63 -0
- kodit/enrichment/enrichment_provider/openai_enrichment_provider.py +77 -0
- kodit/enrichment/enrichment_service.py +33 -0
- kodit/indexing/fusion.py +67 -0
- kodit/indexing/indexing_repository.py +44 -4
- kodit/indexing/indexing_service.py +142 -31
- kodit/mcp.py +31 -18
- kodit/snippets/languages/go.scm +26 -0
- kodit/source/source_service.py +9 -3
- kodit/util/__init__.py +1 -0
- kodit/util/spinner.py +59 -0
- {kodit-0.1.14.dist-info → kodit-0.1.16.dist-info}/METADATA +4 -1
- kodit-0.1.16.dist-info/RECORD +64 -0
- kodit/embedding/embedding.py +0 -203
- kodit/search/__init__.py +0 -1
- kodit/search/search_service.py +0 -147
- kodit-0.1.14.dist-info/RECORD +0 -44
- {kodit-0.1.14.dist-info → kodit-0.1.16.dist-info}/WHEEL +0 -0
- {kodit-0.1.14.dist-info → kodit-0.1.16.dist-info}/entry_points.txt +0 -0
- {kodit-0.1.14.dist-info → kodit-0.1.16.dist-info}/licenses/LICENSE +0 -0
kodit/_version.py
CHANGED
(The +2 -2 hunk for this file is missing from this extract.)
kodit/bm25/keyword_search_factory.py
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""Factory for creating keyword search providers."""
|
|
2
|
+
|
|
3
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
4
|
+
|
|
5
|
+
from kodit.bm25.keyword_search_service import KeywordSearchProvider
|
|
6
|
+
from kodit.bm25.local_bm25 import BM25Service
|
|
7
|
+
from kodit.bm25.vectorchord_bm25 import VectorChordBM25
|
|
8
|
+
from kodit.config import AppContext
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def keyword_search_factory(
    app_context: AppContext, session: AsyncSession
) -> KeywordSearchProvider:
    """Pick the keyword (BM25) search backend configured in ``app_context``.

    Args:
        app_context: Application settings; ``default_search.provider`` selects
            the backend and ``get_data_dir()`` locates the local index.
        session: Database session, only used by the VectorChord backend.

    Returns:
        ``VectorChordBM25`` bound to ``session`` when the provider is
        ``"vectorchord"``; otherwise a local ``BM25Service`` rooted in the
        application data directory.

    """
    if app_context.default_search.provider != "vectorchord":
        return BM25Service(data_dir=app_context.get_data_dir())
    return VectorChordBM25(session=session)
|
|
kodit/bm25/keyword_search_service.py
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Keyword search service."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import NamedTuple
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class BM25Document(NamedTuple):
    """BM25 document.

    Attributes:
        snippet_id: Identifier of the snippet this passage belongs to.
        text: Passage text handed to the keyword index.

    """

    snippet_id: int
    text: str


class BM25Result(NamedTuple):
    """BM25 result.

    Attributes:
        snippet_id: Identifier of the matching snippet.
        score: Relevance score reported by the provider.
            NOTE(review): scale/sign are provider-specific; confirm before
            comparing scores produced by different providers.

    """

    snippet_id: int
    score: float


class KeywordSearchProvider(ABC):
    """Abstract contract shared by every keyword (BM25) search backend."""

    @abstractmethod
    async def index(self, corpus: list[BM25Document]) -> None:
        """Index the given documents."""

    @abstractmethod
    async def retrieve(self, query: str, top_k: int = 2) -> list[BM25Result]:
        """Return at most ``top_k`` results matching ``query``."""

    @abstractmethod
    async def delete(self, snippet_ids: list[int]) -> None:
        """Remove the documents belonging to ``snippet_ids`` from the index."""
|
|
kodit/bm25/{bm25.py → local_bm25.py}
@@ -1,23 +1,36 @@
|
|
|
1
|
-
"""BM25 service."""
|
|
1
|
+
"""Locally hosted BM25 service primarily for use with SQLite."""
|
|
2
2
|
|
|
3
|
+
import json
|
|
3
4
|
from pathlib import Path
|
|
4
5
|
|
|
6
|
+
import aiofiles
|
|
5
7
|
import bm25s
|
|
6
8
|
import Stemmer
|
|
7
9
|
import structlog
|
|
8
10
|
from bm25s.tokenization import Tokenized
|
|
9
11
|
|
|
12
|
+
from kodit.bm25.keyword_search_service import (
|
|
13
|
+
BM25Document,
|
|
14
|
+
BM25Result,
|
|
15
|
+
KeywordSearchProvider,
|
|
16
|
+
)
|
|
10
17
|
|
|
11
|
-
|
|
12
|
-
|
|
18
|
+
SNIPPET_IDS_FILE = "snippet_ids.jsonl"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class BM25Service(KeywordSearchProvider):
|
|
22
|
+
"""LocalBM25 service."""
|
|
13
23
|
|
|
14
24
|
def __init__(self, data_dir: Path) -> None:
|
|
15
25
|
"""Initialize the BM25 service."""
|
|
16
26
|
self.log = structlog.get_logger(__name__)
|
|
17
27
|
self.index_path = data_dir / "bm25s_index"
|
|
28
|
+
self.snippet_ids: list[int] = []
|
|
18
29
|
try:
|
|
19
30
|
self.log.debug("Loading BM25 index")
|
|
20
31
|
self.retriever = bm25s.BM25.load(self.index_path, mmap=True)
|
|
32
|
+
with Path(self.index_path / SNIPPET_IDS_FILE).open() as f:
|
|
33
|
+
self.snippet_ids = json.load(f)
|
|
21
34
|
except FileNotFoundError:
|
|
22
35
|
self.log.debug("BM25 index not found, creating new index")
|
|
23
36
|
self.retriever = bm25s.BM25()
|
|
@@ -33,28 +46,34 @@ class BM25Service:
|
|
|
33
46
|
show_progress=True,
|
|
34
47
|
)
|
|
35
48
|
|
|
36
|
-
def index(self, corpus: list[
|
|
49
|
+
async def index(self, corpus: list[BM25Document]) -> None:
    """Index a new corpus, replacing the previous local BM25 index.

    A fresh ``bm25s.BM25`` retriever is built from ``corpus`` alone and saved
    to ``index_path``. The snippet IDs are persisted alongside it in
    ``SNIPPET_IDS_FILE``. ``retrieve`` passes ``self.snippet_ids`` to the
    retriever as its corpus, so position ``i`` in that list must be the ID of
    the ``i``-th document indexed here.

    Args:
        corpus: Documents to index; their order defines the position -> ID map.

    """
    self.log.debug("Indexing corpus")
    vocab = self._tokenize([doc.text for doc in corpus])
    self.retriever = bm25s.BM25()
    self.retriever.index(vocab, show_progress=False)
    self.retriever.save(self.index_path)
    # Replace rather than extend: the retriever built above contains only
    # ``corpus``, so keeping IDs from an earlier index would shift every
    # position and map results to the wrong snippets.
    self.snippet_ids = [doc.snippet_id for doc in corpus]
    async with aiofiles.open(self.index_path / SNIPPET_IDS_FILE, "w") as f:
        await f.write(json.dumps(self.snippet_ids))
|
|
43
59
|
|
|
44
|
-
def retrieve(
|
|
45
|
-
self, doc_ids: list[int], query: str, top_k: int = 2
|
|
46
|
-
) -> list[tuple[int, float]]:
|
|
60
|
+
async def retrieve(self, query: str, top_k: int = 2) -> list[BM25Result]:
|
|
47
61
|
"""Retrieve from the index."""
|
|
48
62
|
if top_k == 0:
|
|
49
63
|
self.log.warning("Top k is 0, returning empty list")
|
|
50
64
|
return []
|
|
51
|
-
|
|
52
|
-
|
|
65
|
+
|
|
66
|
+
# Get the number of documents in the index
|
|
67
|
+
num_docs = self.retriever.scores["num_docs"]
|
|
68
|
+
if num_docs == 0:
|
|
53
69
|
return []
|
|
54
70
|
|
|
55
|
-
|
|
71
|
+
# Adjust top_k to not exceed corpus size
|
|
72
|
+
top_k = min(top_k, num_docs)
|
|
56
73
|
self.log.debug(
|
|
57
|
-
"Retrieving from index",
|
|
74
|
+
"Retrieving from index",
|
|
75
|
+
query=query,
|
|
76
|
+
top_k=top_k,
|
|
58
77
|
)
|
|
59
78
|
|
|
60
79
|
query_tokens = self._tokenize([query])
|
|
@@ -62,10 +81,17 @@ class BM25Service:
|
|
|
62
81
|
self.log.debug("Query tokens", query_tokens=query_tokens)
|
|
63
82
|
|
|
64
83
|
results, scores = self.retriever.retrieve(
|
|
65
|
-
query_tokens=query_tokens,
|
|
84
|
+
query_tokens=query_tokens,
|
|
85
|
+
corpus=self.snippet_ids,
|
|
86
|
+
k=top_k,
|
|
66
87
|
)
|
|
67
88
|
self.log.debug("Raw results", results=results, scores=scores)
|
|
68
89
|
return [
|
|
69
|
-
(int(result), float(score))
|
|
90
|
+
BM25Result(snippet_id=int(result), score=float(score))
|
|
70
91
|
for result, score in zip(results[0], scores[0], strict=False)
|
|
92
|
+
if score > 0.0
|
|
71
93
|
]
|
|
94
|
+
|
|
95
|
+
async def delete(self, snippet_ids: list[int]) -> None:  # noqa: ARG002
    """Warn that deletion is unavailable; the local index is left untouched.

    Args:
        snippet_ids: Ignored by this backend.

    """
    self.log.warning("Deletion not supported for local BM25 index")
|
|
kodit/bm25/vectorchord_bm25.py
@@ -0,0 +1,193 @@
|
|
|
1
|
+
"""VectorChord repository for document operations."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from sqlalchemy import Result, TextClause, bindparam, text
|
|
6
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
7
|
+
|
|
8
|
+
from kodit.bm25.keyword_search_service import (
|
|
9
|
+
BM25Document,
|
|
10
|
+
BM25Result,
|
|
11
|
+
KeywordSearchProvider,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
TABLE_NAME = "vectorchord_bm25_documents"
INDEX_NAME = f"{TABLE_NAME}_idx"
TOKENIZER_NAME = "bert"

# SQL statements
CREATE_VCHORD_EXTENSION = "CREATE EXTENSION IF NOT EXISTS vchord CASCADE;"
CREATE_PG_TOKENIZER = "CREATE EXTENSION IF NOT EXISTS pg_tokenizer CASCADE;"
CREATE_VCHORD_BM25 = "CREATE EXTENSION IF NOT EXISTS vchord_bm25 CASCADE;"
SET_SEARCH_PATH = """
SET search_path TO
    "$user", public, bm25_catalog, pg_catalog, information_schema, tokenizer_catalog;
"""
# ``embedding`` is NULL until UPDATE_QUERY tokenizes the row's passage.
CREATE_BM25_TABLE = f"""
CREATE TABLE IF NOT EXISTS {TABLE_NAME} (
    id SERIAL PRIMARY KEY,
    snippet_id BIGINT NOT NULL,
    passage TEXT NOT NULL,
    embedding bm25vector,
    UNIQUE(snippet_id)
)
"""

CREATE_BM25_INDEX = f"""
CREATE INDEX IF NOT EXISTS {INDEX_NAME}
ON {TABLE_NAME}
USING bm25 (embedding bm25_ops)
"""
TOKENIZER_NAME_CHECK_QUERY = (
    f"SELECT 1 FROM tokenizer_catalog.tokenizer WHERE name = '{TOKENIZER_NAME}'"  # noqa: S608
)
# Tokenizer config (TOML) for create_tokenizer. The name comes from
# TOKENIZER_NAME so it always matches the tokenize() calls below; braces are
# doubled because this is an f-string.
LOAD_TOKENIZER = f"""
SELECT create_tokenizer('{TOKENIZER_NAME}', $$
model = "llmlingua2"
pre_tokenizer = "unicode_segmentation" # Unicode Standard Annex #29
[[character_filters]]
to_lowercase = {{}} # convert all characters to lowercase
[[character_filters]]
unicode_normalization = "nfkd" # Unicode Normalization Form KD
[[token_filters]]
skip_non_alphanumeric = {{}} # remove non-alphanumeric tokens
[[token_filters]]
stopwords = "nltk_english" # remove stopwords using the nltk dictionary
[[token_filters]]
stemmer = "english_porter2" # stem tokens using the English Porter2 stemmer
$$)
"""  # noqa: S608
# Upsert a passage. On conflict the stored embedding is cleared so the changed
# passage is re-tokenized by UPDATE_QUERY.
INSERT_QUERY = f"""
INSERT INTO {TABLE_NAME} (snippet_id, passage)
VALUES (:snippet_id, :passage)
ON CONFLICT (snippet_id) DO UPDATE
SET passage = EXCLUDED.passage, embedding = NULL
"""  # noqa: S608
# Tokenize only rows that still lack an embedding (new or just-updated
# passages) instead of re-tokenizing the whole table on every index call.
UPDATE_QUERY = f"""
UPDATE {TABLE_NAME}
SET embedding = tokenize(passage, '{TOKENIZER_NAME}')
WHERE embedding IS NULL
"""  # noqa: S608
SEARCH_QUERY = f"""
SELECT
    snippet_id,
    embedding <&>
        to_bm25query('{INDEX_NAME}', tokenize(:query_text, '{TOKENIZER_NAME}'))
        AS bm25_score
FROM {TABLE_NAME}
ORDER BY bm25_score
LIMIT :limit
"""  # noqa: S608
DELETE_QUERY = f"""
DELETE FROM {TABLE_NAME}
WHERE snippet_id IN :snippet_ids
"""  # noqa: S608
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class VectorChordBM25(KeywordSearchProvider):
    """BM25 keyword search backed by the VectorChord PostgreSQL extensions.

    Passages are stored in ``TABLE_NAME`` and tokenized inside the database with
    the ``TOKENIZER_NAME`` tokenizer. Queries are scored through the ``bm25``
    index ``INDEX_NAME``. Extensions, tokenizer, table and index are created
    lazily before the first statement routed through ``_execute``.
    """

    def __init__(
        self,
        session: AsyncSession,
    ) -> None:
        """Initialize the VectorChord BM25 provider.

        Args:
            session: Async SQLAlchemy session used for every statement. This
                class commits it after setup and after each write.

        """
        self.__session = session
        self._initialized = False

    async def _initialize(self) -> None:
        """Create the extensions, tokenizer, table and index.

        ``_initialized`` is only set on success, so a failed setup is retried
        on the next call.

        Raises:
            RuntimeError: If any setup step fails (original error chained).

        """
        try:
            await self._create_extensions()
            await self._create_tokenizer_if_not_exists()
            await self._create_tables()
            self._initialized = True
        except Exception as e:
            msg = f"Failed to initialize VectorChord repository: {e}"
            raise RuntimeError(msg) from e

    async def _create_extensions(self) -> None:
        """Create the VectorChord/tokenizer extensions and set the search path."""
        await self.__session.execute(text(CREATE_VCHORD_EXTENSION))
        await self.__session.execute(text(CREATE_PG_TOKENIZER))
        await self.__session.execute(text(CREATE_VCHORD_BM25))
        # NOTE(review): a plain SET applies to the current connection only; if
        # the pool hands this session a different connection after commit,
        # unqualified tokenize()/to_bm25query() may not resolve. Confirm.
        await self.__session.execute(text(SET_SEARCH_PATH))
        await self._commit()

    async def _create_tokenizer_if_not_exists(self) -> None:
        """Create the tokenizer unless the tokenizer catalog already has it."""
        result = await self.__session.execute(text(TOKENIZER_NAME_CHECK_QUERY))
        if result.scalar_one_or_none() is None:
            await self.__session.execute(text(LOAD_TOKENIZER))
            await self._commit()

    async def _create_tables(self) -> None:
        """Create the document table, then its BM25 index, if missing."""
        await self.__session.execute(text(CREATE_BM25_TABLE))
        await self.__session.execute(text(CREATE_BM25_INDEX))
        await self._commit()

    async def _execute(
        self, query: TextClause, param_list: list[Any] | dict[str, Any] | None = None
    ) -> Result:
        """Execute ``query``, running one-time initialization first if needed.

        Args:
            query: The SQL statement.
            param_list: A dict of bind parameters, or a list of dicts to run the
                statement once per entry.

        Returns:
            The SQLAlchemy result.

        """
        if not self._initialized:
            await self._initialize()
        return await self.__session.execute(query, param_list)

    async def _commit(self) -> None:
        """Commit the underlying session."""
        await self.__session.commit()

    async def index(self, corpus: list[BM25Document]) -> None:
        """Upsert ``corpus`` into the table and tokenize the stored passages.

        Documents without a snippet ID or with empty text are skipped. If
        nothing remains after filtering, no SQL is issued.
        """
        corpus = [doc for doc in corpus if doc.snippet_id is not None and doc.text]
        if not corpus:
            return

        await self._execute(
            text(INSERT_QUERY),
            [{"snippet_id": doc.snippet_id, "passage": doc.text} for doc in corpus],
        )
        # Compute the bm25vector embeddings inside the database.
        await self._execute(text(UPDATE_QUERY))
        await self._commit()

    async def delete(self, snippet_ids: list[int]) -> None:
        """Delete the rows for ``snippet_ids``; an empty list is a no-op."""
        if not snippet_ids:
            # Avoid triggering initialization and an empty commit for nothing.
            return
        await self._execute(
            text(DELETE_QUERY).bindparams(bindparam("snippet_ids", expanding=True)),
            {"snippet_ids": snippet_ids},
        )
        await self._commit()

    async def retrieve(
        self,
        query: str,
        top_k: int = 10,
    ) -> list[BM25Result]:
        """Return up to ``top_k`` snippets ranked by VectorChord BM25.

        Rows come back ordered ascending by the ``<&>`` value, as SEARCH_QUERY
        specifies. NOTE(review): ``score`` is that raw value, and its sign/scale
        may differ from other KeywordSearchProvider implementations. Confirm
        before comparing scores across providers.

        Returns an empty list for an empty query or a non-positive ``top_k``.
        Postgres would reject a negative LIMIT.

        Raises:
            RuntimeError: If initialization or the search query fails.

        """
        if not query or top_k <= 0:
            return []

        sql = text(SEARCH_QUERY).bindparams(query_text=query, limit=top_k)
        try:
            result = await self._execute(sql)
            rows = result.mappings().all()
        except Exception as e:
            msg = f"Error during BM25 search: {e}"
            raise RuntimeError(msg) from e

        return [
            BM25Result(snippet_id=row["snippet_id"], score=row["bm25_score"])
            for row in rows
        ]
|
kodit/cli.py
CHANGED
|
@@ -10,17 +10,17 @@ import uvicorn
|
|
|
10
10
|
from pytable_formatter import Cell, Table
|
|
11
11
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
12
12
|
|
|
13
|
+
from kodit.bm25.keyword_search_factory import keyword_search_factory
|
|
13
14
|
from kodit.config import (
|
|
14
15
|
AppContext,
|
|
15
16
|
with_app_context,
|
|
16
17
|
with_session,
|
|
17
18
|
)
|
|
18
|
-
from kodit.embedding.
|
|
19
|
+
from kodit.embedding.embedding_factory import embedding_factory
|
|
20
|
+
from kodit.enrichment.enrichment_factory import enrichment_factory
|
|
19
21
|
from kodit.indexing.indexing_repository import IndexRepository
|
|
20
|
-
from kodit.indexing.indexing_service import IndexService
|
|
22
|
+
from kodit.indexing.indexing_service import IndexService, SearchRequest
|
|
21
23
|
from kodit.log import configure_logging, configure_telemetry, log_event
|
|
22
|
-
from kodit.search.search_repository import SearchRepository
|
|
23
|
-
from kodit.search.search_service import SearchRequest, SearchService
|
|
24
24
|
from kodit.source.source_repository import SourceRepository
|
|
25
25
|
from kodit.source.source_service import SourceService
|
|
26
26
|
|
|
@@ -68,10 +68,16 @@ async def index(
|
|
|
68
68
|
source_service = SourceService(app_context.get_clone_dir(), source_repository)
|
|
69
69
|
repository = IndexRepository(session)
|
|
70
70
|
service = IndexService(
|
|
71
|
-
repository,
|
|
72
|
-
source_service,
|
|
73
|
-
app_context
|
|
74
|
-
|
|
71
|
+
repository=repository,
|
|
72
|
+
source_service=source_service,
|
|
73
|
+
keyword_search_provider=keyword_search_factory(app_context, session),
|
|
74
|
+
code_search_service=embedding_factory(
|
|
75
|
+
task_name="code", app_context=app_context, session=session
|
|
76
|
+
),
|
|
77
|
+
text_search_service=embedding_factory(
|
|
78
|
+
task_name="text", app_context=app_context, session=session
|
|
79
|
+
),
|
|
80
|
+
enrichment_service=enrichment_factory(app_context),
|
|
75
81
|
)
|
|
76
82
|
|
|
77
83
|
if not sources:
|
|
@@ -128,11 +134,20 @@ async def code(
|
|
|
128
134
|
|
|
129
135
|
This works best if your query is code.
|
|
130
136
|
"""
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
137
|
+
source_repository = SourceRepository(session)
|
|
138
|
+
source_service = SourceService(app_context.get_clone_dir(), source_repository)
|
|
139
|
+
repository = IndexRepository(session)
|
|
140
|
+
service = IndexService(
|
|
141
|
+
repository=repository,
|
|
142
|
+
source_service=source_service,
|
|
143
|
+
keyword_search_provider=keyword_search_factory(app_context, session),
|
|
144
|
+
code_search_service=embedding_factory(
|
|
145
|
+
task_name="code", app_context=app_context, session=session
|
|
146
|
+
),
|
|
147
|
+
text_search_service=embedding_factory(
|
|
148
|
+
task_name="text", app_context=app_context, session=session
|
|
149
|
+
),
|
|
150
|
+
enrichment_service=enrichment_factory(app_context),
|
|
136
151
|
)
|
|
137
152
|
|
|
138
153
|
snippets = await service.search(SearchRequest(code_query=query, top_k=top_k))
|
|
@@ -144,6 +159,7 @@ async def code(
|
|
|
144
159
|
for snippet in snippets:
|
|
145
160
|
click.echo("-" * 80)
|
|
146
161
|
click.echo(f"{snippet.uri}")
|
|
162
|
+
click.echo(f"Original scores: {snippet.original_scores}")
|
|
147
163
|
click.echo(snippet.content)
|
|
148
164
|
click.echo("-" * 80)
|
|
149
165
|
click.echo()
|
|
@@ -161,11 +177,20 @@ async def keyword(
|
|
|
161
177
|
top_k: int,
|
|
162
178
|
) -> None:
|
|
163
179
|
"""Search for snippets using keyword search."""
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
180
|
+
source_repository = SourceRepository(session)
|
|
181
|
+
source_service = SourceService(app_context.get_clone_dir(), source_repository)
|
|
182
|
+
repository = IndexRepository(session)
|
|
183
|
+
service = IndexService(
|
|
184
|
+
repository=repository,
|
|
185
|
+
source_service=source_service,
|
|
186
|
+
keyword_search_provider=keyword_search_factory(app_context, session),
|
|
187
|
+
code_search_service=embedding_factory(
|
|
188
|
+
task_name="code", app_context=app_context, session=session
|
|
189
|
+
),
|
|
190
|
+
text_search_service=embedding_factory(
|
|
191
|
+
task_name="text", app_context=app_context, session=session
|
|
192
|
+
),
|
|
193
|
+
enrichment_service=enrichment_factory(app_context),
|
|
169
194
|
)
|
|
170
195
|
|
|
171
196
|
snippets = await service.search(SearchRequest(keywords=keywords, top_k=top_k))
|
|
@@ -177,6 +202,53 @@ async def keyword(
|
|
|
177
202
|
for snippet in snippets:
|
|
178
203
|
click.echo("-" * 80)
|
|
179
204
|
click.echo(f"{snippet.uri}")
|
|
205
|
+
click.echo(f"Original scores: {snippet.original_scores}")
|
|
206
|
+
click.echo(snippet.content)
|
|
207
|
+
click.echo("-" * 80)
|
|
208
|
+
click.echo()
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
@search.command()
@click.argument("query")
@click.option("--top-k", default=10, help="Number of snippets to retrieve")
@with_app_context
@with_session
async def text(
    session: AsyncSession,
    app_context: AppContext,
    query: str,
    top_k: int,
) -> None:
    """Search for snippets using semantic text search.

    This works best if your query is text.
    """
    # Wire up the index service. The collaborators are built in the same order
    # as the other search commands.
    source_repository = SourceRepository(session)
    source_service = SourceService(app_context.get_clone_dir(), source_repository)
    index_repository = IndexRepository(session)
    keyword_provider = keyword_search_factory(app_context, session)
    code_search = embedding_factory(
        task_name="code", app_context=app_context, session=session
    )
    text_search = embedding_factory(
        task_name="text", app_context=app_context, session=session
    )
    enrichment = enrichment_factory(app_context)
    service = IndexService(
        repository=index_repository,
        source_service=source_service,
        keyword_search_provider=keyword_provider,
        code_search_service=code_search,
        text_search_service=text_search,
        enrichment_service=enrichment,
    )

    request = SearchRequest(text_query=query, top_k=top_k)
    snippets = await service.search(request)

    if len(snippets) == 0:
        click.echo("No snippets found")
        return

    divider = "-" * 80
    for snippet in snippets:
        click.echo(divider)
        click.echo(f"{snippet.uri}")
        click.echo(f"Original scores: {snippet.original_scores}")
        click.echo(snippet.content)
        click.echo(divider)
        click.echo()
|
|
@@ -186,28 +258,44 @@ async def keyword(
|
|
|
186
258
|
@click.option("--top-k", default=10, help="Number of snippets to retrieve")
|
|
187
259
|
@click.option("--keywords", required=True, help="Comma separated list of keywords")
|
|
188
260
|
@click.option("--code", required=True, help="Semantic code search query")
|
|
261
|
+
@click.option("--text", required=True, help="Semantic text search query")
|
|
189
262
|
@with_app_context
|
|
190
263
|
@with_session
|
|
191
|
-
async def hybrid(
|
|
264
|
+
async def hybrid( # noqa: PLR0913
|
|
192
265
|
session: AsyncSession,
|
|
193
266
|
app_context: AppContext,
|
|
194
267
|
top_k: int,
|
|
195
268
|
keywords: str,
|
|
196
269
|
code: str,
|
|
270
|
+
text: str,
|
|
197
271
|
) -> None:
|
|
198
272
|
"""Search for snippets using hybrid search."""
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
273
|
+
source_repository = SourceRepository(session)
|
|
274
|
+
source_service = SourceService(app_context.get_clone_dir(), source_repository)
|
|
275
|
+
repository = IndexRepository(session)
|
|
276
|
+
service = IndexService(
|
|
277
|
+
repository=repository,
|
|
278
|
+
source_service=source_service,
|
|
279
|
+
keyword_search_provider=keyword_search_factory(app_context, session),
|
|
280
|
+
code_search_service=embedding_factory(
|
|
281
|
+
task_name="code", app_context=app_context, session=session
|
|
282
|
+
),
|
|
283
|
+
text_search_service=embedding_factory(
|
|
284
|
+
task_name="text", app_context=app_context, session=session
|
|
285
|
+
),
|
|
286
|
+
enrichment_service=enrichment_factory(app_context),
|
|
204
287
|
)
|
|
205
288
|
|
|
206
289
|
# Parse keywords into a list of strings
|
|
207
290
|
keywords_list = [k.strip().lower() for k in keywords.split(",")]
|
|
208
291
|
|
|
209
292
|
snippets = await service.search(
|
|
210
|
-
SearchRequest(
|
|
293
|
+
SearchRequest(
|
|
294
|
+
text_query=text,
|
|
295
|
+
keywords=keywords_list,
|
|
296
|
+
code_query=code,
|
|
297
|
+
top_k=top_k,
|
|
298
|
+
)
|
|
211
299
|
)
|
|
212
300
|
|
|
213
301
|
if len(snippets) == 0:
|
|
@@ -217,6 +305,7 @@ async def hybrid(
|
|
|
217
305
|
for snippet in snippets:
|
|
218
306
|
click.echo("-" * 80)
|
|
219
307
|
click.echo(f"{snippet.uri}")
|
|
308
|
+
click.echo(f"Original scores: {snippet.original_scores}")
|
|
220
309
|
click.echo(snippet.content)
|
|
221
310
|
click.echo("-" * 80)
|
|
222
311
|
click.echo()
|
kodit/config.py
CHANGED
|
@@ -12,14 +12,12 @@ from pydantic import BaseModel, Field
|
|
|
12
12
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
13
13
|
|
|
14
14
|
from kodit.database import Database
|
|
15
|
-
from kodit.embedding.embedding import TINY
|
|
16
15
|
|
|
17
16
|
DEFAULT_BASE_DIR = Path.home() / ".kodit"
|
|
18
17
|
DEFAULT_DB_URL = f"sqlite+aiosqlite:///{DEFAULT_BASE_DIR}/kodit.db"
|
|
19
18
|
DEFAULT_LOG_LEVEL = "INFO"
|
|
20
19
|
DEFAULT_LOG_FORMAT = "pretty"
|
|
21
20
|
DEFAULT_DISABLE_TELEMETRY = False
|
|
22
|
-
DEFAULT_EMBEDDING_MODEL_NAME = TINY
|
|
23
21
|
T = TypeVar("T")
|
|
24
22
|
|
|
25
23
|
|
|
@@ -31,6 +29,12 @@ class Endpoint(BaseModel):
|
|
|
31
29
|
base_url: str | None = None
|
|
32
30
|
|
|
33
31
|
|
|
32
|
+
class Search(BaseModel):
    """Search provides configuration for a search engine."""

    # Backend for keyword and semantic search: "sqlite" selects the local
    # implementations, "vectorchord" the VectorChord PostgreSQL ones.
    provider: Literal["sqlite", "vectorchord"] = "sqlite"
|
|
36
|
+
|
|
37
|
+
|
|
34
38
|
class AppContext(BaseSettings):
|
|
35
39
|
"""Global context for the kodit project. Provides a shared state for the app."""
|
|
36
40
|
|
|
@@ -57,6 +61,9 @@ class AppContext(BaseSettings):
|
|
|
57
61
|
"(can be overridden by task-specific configuration)."
|
|
58
62
|
),
|
|
59
63
|
)
|
|
64
|
+
default_search: Search = Field(
|
|
65
|
+
default=Search(),
|
|
66
|
+
)
|
|
60
67
|
_db: Database | None = None
|
|
61
68
|
|
|
62
69
|
def model_post_init(self, _: Any) -> None:
|
kodit/database.py
CHANGED
|
@@ -27,10 +27,12 @@ class CommonMixin:
|
|
|
27
27
|
|
|
28
28
|
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
|
29
29
|
created_at: Mapped[datetime] = mapped_column(
|
|
30
|
-
DateTime, default=lambda: datetime.now(UTC)
|
|
30
|
+
DateTime(timezone=True), default=lambda: datetime.now(UTC)
|
|
31
31
|
)
|
|
32
32
|
updated_at: Mapped[datetime] = mapped_column(
|
|
33
|
-
DateTime
|
|
33
|
+
DateTime(timezone=True),
|
|
34
|
+
default=lambda: datetime.now(UTC),
|
|
35
|
+
onupdate=lambda: datetime.now(UTC),
|
|
34
36
|
)
|
|
35
37
|
|
|
36
38
|
|
|
kodit/embedding/embedding_factory.py
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""Embedding service."""
|
|
2
|
+
|
|
3
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
4
|
+
|
|
5
|
+
from kodit.config import AppContext
|
|
6
|
+
from kodit.embedding.embedding_provider.local_embedding_provider import (
|
|
7
|
+
CODE,
|
|
8
|
+
LocalEmbeddingProvider,
|
|
9
|
+
)
|
|
10
|
+
from kodit.embedding.embedding_provider.openai_embedding_provider import (
|
|
11
|
+
OpenAIEmbeddingProvider,
|
|
12
|
+
)
|
|
13
|
+
from kodit.embedding.embedding_repository import EmbeddingRepository
|
|
14
|
+
from kodit.embedding.local_vector_search_service import LocalVectorSearchService
|
|
15
|
+
from kodit.embedding.vector_search_service import (
|
|
16
|
+
VectorSearchService,
|
|
17
|
+
)
|
|
18
|
+
from kodit.embedding.vectorchord_vector_search_service import (
|
|
19
|
+
VectorChordVectorSearchService,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def embedding_factory(
|
|
24
|
+
task_name: str, app_context: AppContext, session: AsyncSession
|
|
25
|
+
) -> VectorSearchService:
|
|
26
|
+
"""Create an embedding service."""
|
|
27
|
+
embedding_repository = EmbeddingRepository(session=session)
|
|
28
|
+
embedding_provider = None
|
|
29
|
+
openai_client = app_context.get_default_openai_client()
|
|
30
|
+
if openai_client is not None:
|
|
31
|
+
embedding_provider = OpenAIEmbeddingProvider(openai_client=openai_client)
|
|
32
|
+
else:
|
|
33
|
+
embedding_provider = LocalEmbeddingProvider(CODE)
|
|
34
|
+
|
|
35
|
+
if app_context.default_search.provider == "vectorchord":
|
|
36
|
+
return VectorChordVectorSearchService(task_name, session, embedding_provider)
|
|
37
|
+
if app_context.default_search.provider == "sqlite":
|
|
38
|
+
return LocalVectorSearchService(
|
|
39
|
+
embedding_repository=embedding_repository,
|
|
40
|
+
embedding_provider=embedding_provider,
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
msg = f"Invalid semantic search provider: {app_context.default_search.provider}"
|
|
44
|
+
raise ValueError(msg)
|
|
kodit/embedding/embedding_provider/__init__.py
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Embedding module."""
|