kodit 0.1.14__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kodit might be problematic. Click here for more details.
- kodit/_version.py +2 -2
- kodit/bm25/keyword_search_factory.py +17 -0
- kodit/bm25/keyword_search_service.py +34 -0
- kodit/bm25/{bm25.py → local_bm25.py} +40 -14
- kodit/bm25/vectorchord_bm25.py +193 -0
- kodit/cli.py +14 -11
- kodit/config.py +9 -2
- kodit/database.py +4 -2
- kodit/embedding/embedding_factory.py +44 -0
- kodit/embedding/embedding_provider/__init__.py +1 -0
- kodit/embedding/embedding_provider/embedding_provider.py +53 -0
- kodit/embedding/embedding_provider/hash_embedding_provider.py +77 -0
- kodit/embedding/embedding_provider/local_embedding_provider.py +58 -0
- kodit/embedding/embedding_provider/openai_embedding_provider.py +63 -0
- kodit/embedding/embedding_repository.py +206 -0
- kodit/embedding/local_vector_search_service.py +50 -0
- kodit/embedding/vector_search_service.py +38 -0
- kodit/embedding/vectorchord_vector_search_service.py +145 -0
- kodit/indexing/indexing_repository.py +24 -4
- kodit/indexing/indexing_service.py +25 -30
- kodit/mcp.py +7 -3
- kodit/search/search_repository.py +0 -121
- kodit/search/search_service.py +12 -24
- kodit/source/source_service.py +9 -3
- kodit/util/__init__.py +1 -0
- kodit/util/spinner.py +59 -0
- {kodit-0.1.14.dist-info → kodit-0.1.15.dist-info}/METADATA +2 -1
- kodit-0.1.15.dist-info/RECORD +58 -0
- kodit/embedding/embedding.py +0 -203
- kodit-0.1.14.dist-info/RECORD +0 -44
- {kodit-0.1.14.dist-info → kodit-0.1.15.dist-info}/WHEEL +0 -0
- {kodit-0.1.14.dist-info → kodit-0.1.15.dist-info}/entry_points.txt +0 -0
- {kodit-0.1.14.dist-info → kodit-0.1.15.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"""Local embedding service."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
import structlog
|
|
6
|
+
import tiktoken
|
|
7
|
+
from sentence_transformers import SentenceTransformer
|
|
8
|
+
from tqdm import tqdm
|
|
9
|
+
|
|
10
|
+
from kodit.embedding.embedding_provider.embedding_provider import (
|
|
11
|
+
EmbeddingProvider,
|
|
12
|
+
Vector,
|
|
13
|
+
split_sub_batches,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
# Short aliases that may be passed instead of a full model identifier.
TINY: str = "tiny"
CODE: str = "code"
TEST: str = "test"

# Alias -> model identifier used when constructing the SentenceTransformer.
COMMON_EMBEDDING_MODELS: dict[str, str] = {
    TINY: "ibm-granite/granite-embedding-30m-english",
    CODE: "flax-sentence-embeddings/st-codesearch-distilroberta-base",
    TEST: "minishlab/potion-base-4M",
}
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class LocalEmbeddingProvider(EmbeddingProvider):
    """Embedding provider backed by a locally loaded SentenceTransformer model."""

    def __init__(self, model_name: str) -> None:
        """Set up the provider; the model itself is only loaded on first use.

        Args:
            model_name: Either a key of ``COMMON_EMBEDDING_MODELS`` or a model
                identifier, which is used unchanged.

        """
        self.log = structlog.get_logger(__name__)
        # Resolve short aliases; any other name falls through untouched.
        self.model_name = COMMON_EMBEDDING_MODELS.get(model_name, model_name)
        self.embedding_model = None
        # Tokenizer handed to split_sub_batches to size the batches.
        self.encoding = tiktoken.encoding_for_model("text-embedding-3-small")

    def _model(self) -> SentenceTransformer:
        """Return the cached model, constructing it on the first call."""
        if self.embedding_model is not None:
            return self.embedding_model
        # Avoid tokenizer parallelism warnings.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        self.embedding_model = SentenceTransformer(
            self.model_name,
            trust_remote_code=True,
            # Pin to CPU so accelerate and similar extras are not required.
            device="cpu",
        )
        return self.embedding_model

    async def embed(self, data: list[str]) -> list[Vector]:
        """Embed a list of strings.

        The input is split with ``split_sub_batches`` and each sub-batch is
        encoded in turn. Vectors are appended in the order the sub-batches are
        yielded (presumably input order -- confirm against split_sub_batches).

        Args:
            data: Texts to embed.

        Returns:
            One list of floats per embedded text.

        """
        model = self._model()
        sub_batches = split_sub_batches(self.encoding, data)

        vectors: list[Vector] = []
        progress = tqdm(sub_batches, total=len(sub_batches), leave=False)
        for sub_batch in progress:
            encoded = model.encode(sub_batch, show_progress_bar=False, batch_size=4)
            for row in encoded:
                vectors.append([float(value) for value in row])
        return vectors
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""OpenAI embedding service."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
|
|
5
|
+
import structlog
|
|
6
|
+
import tiktoken
|
|
7
|
+
from openai import AsyncOpenAI
|
|
8
|
+
|
|
9
|
+
from kodit.embedding.embedding_provider.embedding_provider import (
|
|
10
|
+
EmbeddingProvider,
|
|
11
|
+
Vector,
|
|
12
|
+
split_sub_batches,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
# Upper bound on embedding requests in flight at once (semaphore size).
OPENAI_NUM_PARALLEL_TASKS: int = 10
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """Embedding provider that calls the OpenAI embeddings API."""

    def __init__(
        self,
        openai_client: AsyncOpenAI,
        model_name: str = "text-embedding-3-small",
    ) -> None:
        """Initialize the OpenAI embedder.

        Args:
            openai_client: Async OpenAI client used for embedding requests.
            model_name: OpenAI embedding model name; also selects the tiktoken
                encoding used to size request batches.

        """
        self.log = structlog.get_logger(__name__)
        self.openai_client = openai_client
        self.model_name = model_name
        self.encoding = tiktoken.encoding_for_model(model_name)

    async def embed(self, data: list[str]) -> list[Vector]:
        """Embed a list of documents.

        The input is split into token-bounded sub-batches, which are sent
        concurrently (at most ``OPENAI_NUM_PARALLEL_TASKS`` at a time). Results
        are reassembled in sub-batch order, so vector ``i`` corresponds to
        ``data[i]``. This assumes split_sub_batches preserves input order.
        Callers zip the output with their inputs, so both order and length
        matter.

        Args:
            data: Texts to embed.

        Returns:
            One vector per input text, in input order.

        Raises:
            Exception: Any error from the OpenAI client is logged and re-raised.
                Previously a failed batch was silently dropped, which shifted
                every later vector onto the wrong document.

        """
        if not data:
            return []

        # Split the list into sub-lists that each stay under the token limit.
        batched_data = split_sub_batches(self.encoding, data)

        # Limit the number of concurrent requests.
        sem = asyncio.Semaphore(OPENAI_NUM_PARALLEL_TASKS)

        async def process_batch(batch: list[str]) -> list[Vector]:
            async with sem:
                try:
                    response = await self.openai_client.embeddings.create(
                        model=self.model_name,
                        input=batch,
                    )
                except Exception as e:
                    self.log.exception("Error embedding batch", error=str(e))
                    raise
            # Each returned item carries the position of its input in `batch`;
            # sort on it so the vectors line up with the batch.
            ordered = sorted(response.data, key=lambda item: item.index)
            return [[float(x) for x in item.embedding] for item in ordered]

        # gather() returns results in task order; as_completed() does not.
        per_batch = await asyncio.gather(
            *(process_batch(batch) for batch in batched_data)
        )
        return [vector for batch_vectors in per_batch for vector in batch_vectors]
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
"""Repository for managing embeddings."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from sqlalchemy import select
|
|
5
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
6
|
+
|
|
7
|
+
from kodit.embedding.embedding_models import Embedding, EmbeddingType
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class EmbeddingRepository:
    """Repository for managing embeddings.

    This class provides methods for creating and retrieving embeddings from the
    database. It handles the low-level database operations and transaction management.

    Args:
        session: The SQLAlchemy async session to use for database operations.

    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the embedding repository."""
        self.session = session

    async def create_embedding(self, embedding: Embedding) -> Embedding:
        """Create a new embedding record in the database.

        Commits the session after adding the record.

        Args:
            embedding: The Embedding instance to create.

        Returns:
            The created Embedding instance.

        """
        self.session.add(embedding)
        await self.session.commit()
        return embedding

    async def get_embedding_by_snippet_id_and_type(
        self, snippet_id: int, embedding_type: EmbeddingType
    ) -> Embedding | None:
        """Get an embedding by its snippet ID and type.

        Args:
            snippet_id: The ID of the snippet to get the embedding for.
            embedding_type: The type of embedding to get.

        Returns:
            The Embedding instance if found, None otherwise.

        """
        query = select(Embedding).where(
            Embedding.snippet_id == snippet_id,
            Embedding.type == embedding_type,
        )
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def list_embeddings_by_type(
        self, embedding_type: EmbeddingType
    ) -> list[Embedding]:
        """List all embeddings of a given type.

        Args:
            embedding_type: The type of embeddings to list.

        Returns:
            A list of Embedding instances.

        """
        query = select(Embedding).where(Embedding.type == embedding_type)
        result = await self.session.execute(query)
        return list(result.scalars())

    async def delete_embeddings_by_snippet_id(self, snippet_id: int) -> None:
        """Delete all embeddings for a snippet and commit.

        Args:
            snippet_id: The ID of the snippet to delete embeddings for.

        """
        query = select(Embedding).where(Embedding.snippet_id == snippet_id)
        result = await self.session.execute(query)
        embeddings = result.scalars().all()
        for embedding in embeddings:
            await self.session.delete(embedding)
        await self.session.commit()

    async def list_semantic_results(
        self, embedding_type: EmbeddingType, embedding: list[float], top_k: int = 10
    ) -> list[tuple[int, float]]:
        """List semantic results using cosine similarity.

        This implementation fetches all embeddings of the given type and computes
        cosine similarity in Python using NumPy.

        Args:
            embedding_type: The type of embeddings to search
            embedding: The query embedding vector
            top_k: Number of results to return; values <= 0 yield no results

        Returns:
            List of (snippet_id, similarity_score) tuples, highest similarity first

        Raises:
            ValueError: If stored embeddings have differing sizes or do not match
                the query embedding's size.

        """
        # Step 1: Fetch embeddings from database
        embeddings = await self._list_embedding_values(embedding_type)
        if not embeddings:
            return []

        # Step 2: Convert to numpy arrays (and validate shapes)
        stored_vecs, query_vec = self._prepare_vectors(embeddings, embedding)

        # Step 3: Compute similarities
        similarities = self._compute_similarities(stored_vecs, query_vec)

        # Step 4: Get top-k results
        return self._get_top_k_results(similarities, embeddings, top_k)

    async def _list_embedding_values(
        self, embedding_type: EmbeddingType
    ) -> list[tuple[int, list[float]]]:
        """List all embeddings of a given type from the database.

        Args:
            embedding_type: The type of embeddings to fetch

        Returns:
            List of (snippet_id, embedding) tuples

        """
        # Only select the columns that are needed for scoring.
        query = select(Embedding.snippet_id, Embedding.embedding).where(
            Embedding.type == embedding_type
        )
        rows = await self.session.execute(query)
        return [tuple(row) for row in rows.all()]  # Convert Row objects to tuples

    def _prepare_vectors(
        self, embeddings: list[tuple[int, list[float]]], query_embedding: list[float]
    ) -> tuple[np.ndarray, np.ndarray]:
        """Convert embeddings to numpy arrays and check their shapes agree.

        Args:
            embeddings: List of (snippet_id, embedding) tuples
            query_embedding: Query embedding vector

        Returns:
            Tuple of (stored_vectors, query_vector) as numpy arrays

        Raises:
            ValueError: If stored embeddings have different sizes, or their size
                differs from the query embedding's size.

        """
        try:
            stored_vecs = np.array([emb[1] for emb in embeddings])  # index 1: vector
        except ValueError as e:
            if "inhomogeneous" in str(e):
                msg = (
                    "The database has returned embeddings of different sizes. If you "
                    "have recently updated the embedding model, you will need to "
                    "delete your database and re-index your snippets."
                )
                raise ValueError(msg) from e
            raise

        query_vec = np.array(query_embedding)
        # A model change can leave stored vectors with a different dimension than
        # the query; report it clearly instead of failing inside np.dot.
        if stored_vecs.ndim != 2 or query_vec.shape != (stored_vecs.shape[1],):
            msg = (
                f"Query embedding shape {query_vec.shape} does not match stored "
                f"embeddings shape {stored_vecs.shape}. If you have recently "
                "updated the embedding model, you will need to delete your "
                "database and re-index your snippets."
            )
            raise ValueError(msg)
        return stored_vecs, query_vec

    def _compute_similarities(
        self, stored_vecs: np.ndarray, query_vec: np.ndarray
    ) -> np.ndarray:
        """Compute cosine similarities between stored vectors and query vector.

        Vectors with zero norm have no direction and are scored 0.0 rather
        than NaN. NaN sorts last in argsort, so after the descending reversal
        in _get_top_k_results it would rank first.

        Args:
            stored_vecs: 2-D array of stored embedding vectors, one per row
            query_vec: 1-D query embedding vector

        Returns:
            1-D array of similarity scores, one per stored vector

        """
        stored_norms = np.linalg.norm(stored_vecs, axis=1)
        query_norm = np.linalg.norm(query_vec)
        denominators = stored_norms * query_norm
        dots = np.dot(stored_vecs, query_vec)
        similarities = np.zeros_like(dots, dtype=float)
        np.divide(dots, denominators, out=similarities, where=denominators != 0)
        return similarities

    def _get_top_k_results(
        self,
        similarities: np.ndarray,
        embeddings: list[tuple[int, list[float]]],
        top_k: int,
    ) -> list[tuple[int, float]]:
        """Get top-k results by similarity score.

        Args:
            similarities: Array of similarity scores
            embeddings: List of (snippet_id, embedding) tuples
            top_k: Number of results to return; values <= 0 yield an empty list

        Returns:
            List of (snippet_id, similarity_score) tuples, highest first

        """
        # Guard: a negative slice bound would otherwise return almost everything.
        if top_k <= 0:
            return []
        top_indices = np.argsort(similarities)[::-1][:top_k]
        # Index 0 of each embeddings tuple is the snippet_id.
        return [(embeddings[i][0], float(similarities[i])) for i in top_indices]
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""Local vector search."""
|
|
2
|
+
|
|
3
|
+
import structlog
|
|
4
|
+
import tiktoken
|
|
5
|
+
|
|
6
|
+
from kodit.embedding.embedding_models import Embedding, EmbeddingType
|
|
7
|
+
from kodit.embedding.embedding_provider.embedding_provider import EmbeddingProvider
|
|
8
|
+
from kodit.embedding.embedding_repository import EmbeddingRepository
|
|
9
|
+
from kodit.embedding.vector_search_service import (
|
|
10
|
+
VectorSearchRequest,
|
|
11
|
+
VectorSearchResponse,
|
|
12
|
+
VectorSearchService,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class LocalVectorSearchService(VectorSearchService):
    """Vector search that stores embeddings via EmbeddingRepository.

    Similarity is computed by ``EmbeddingRepository.list_semantic_results``.
    """

    def __init__(
        self,
        embedding_repository: EmbeddingRepository,
        embedding_provider: EmbeddingProvider,
    ) -> None:
        """Initialize the local vector search service.

        Args:
            embedding_repository: Persists and queries stored embeddings.
            embedding_provider: Turns texts into vectors.

        """
        self.log = structlog.get_logger(__name__)
        self.embedding_repository = embedding_repository
        self.embedding_provider = embedding_provider
        # NOTE(review): not used by this class; kept so the attribute stays
        # available to any external code that reads it.
        self.encoding = tiktoken.encoding_for_model("text-embedding-3-small")

    async def index(self, data: list[VectorSearchRequest]) -> None:
        """Embed a list of documents and store one CODE embedding per snippet.

        Args:
            data: Snippet ids with the text to embed.

        Raises:
            ValueError: If the provider returns a different number of vectors
                than documents. The old non-strict zip instead silently dropped
                or misassigned embeddings.

        """
        if not data:
            return

        embeddings = await self.embedding_provider.embed([i.text for i in data])
        # Check up front: a strict zip alone would only raise after some rows
        # had already been written.
        if len(embeddings) != len(data):
            msg = (
                f"Embedding provider returned {len(embeddings)} vectors for "
                f"{len(data)} documents"
            )
            raise ValueError(msg)

        for request, vector in zip(data, embeddings, strict=True):
            await self.embedding_repository.create_embedding(
                Embedding(
                    snippet_id=request.snippet_id,
                    embedding=[float(y) for y in vector],
                    type=EmbeddingType.CODE,
                )
            )

    async def retrieve(self, query: str, top_k: int = 10) -> list[VectorSearchResponse]:
        """Embed ``query`` and return the ``top_k`` most similar CODE snippets.

        Scores are cosine similarities (higher is more similar).
        """
        embedding = (await self.embedding_provider.embed([query]))[0]
        results = await self.embedding_repository.list_semantic_results(
            EmbeddingType.CODE, [float(x) for x in embedding], top_k
        )
        return [
            VectorSearchResponse(snippet_id, score) for snippet_id, score in results
        ]
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""Embedding service."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import NamedTuple
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class VectorSearchResponse(NamedTuple):
    """A single vector-search hit: which snippet matched and its score."""

    # Identifier of the matching snippet.
    snippet_id: int
    # Backend-specific value: cosine similarity in the local backend, cosine
    # distance in the VectorChord backend. Compare only within one backend.
    score: float
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class VectorSearchRequest(NamedTuple):
    """A document to index: the snippet it belongs to and its text."""

    # Identifier of the snippet the resulting embedding is stored against.
    snippet_id: int
    # Raw text that will be embedded.
    text: str
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class VectorSearchService(ABC):
    """Interface shared by the semantic (vector) search backends."""

    @abstractmethod
    async def index(self, data: list[VectorSearchRequest]) -> None:
        """Embed and store a potentially very large list of documents.

        Implementations are free to batch and parallelize the work however
        suits their embedding backend. Every request carries its snippet id,
        because parallel processing may complete out of order and results must
        still be attributable to the right snippet.
        """
        ...

    @abstractmethod
    async def retrieve(self, query: str, top_k: int = 10) -> list[VectorSearchResponse]:
        """Return up to ``top_k`` snippets judged closest to ``query``."""
        ...
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""Vectorchord vector search."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from sqlalchemy import Result, TextClause, text
|
|
6
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
7
|
+
|
|
8
|
+
from kodit.embedding.embedding_provider.embedding_provider import EmbeddingProvider
|
|
9
|
+
from kodit.embedding.vector_search_service import (
|
|
10
|
+
VectorSearchRequest,
|
|
11
|
+
VectorSearchResponse,
|
|
12
|
+
VectorSearchService,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
TABLE_NAME = "vectorchord_embeddings"
INDEX_NAME = f"{TABLE_NAME}_idx"

# SQL Queries
CREATE_VCHORD_EXTENSION = """
CREATE EXTENSION IF NOT EXISTS vchord CASCADE;
"""

# Reads the declared dimension of the `embedding` column (pgvector keeps a
# vector column's dimension in atttypmod).
CHECK_VCHORD_EMBEDDING_DIMENSION = f"""
SELECT a.atttypmod as dimension
FROM pg_attribute a
JOIN pg_class c ON a.attrelid = c.oid
WHERE c.relname = '{TABLE_NAME}'
AND a.attname = 'embedding';
"""  # noqa: S608

# The operator class must match the operator used by SEARCH_QUERY (`<=>`, cosine
# distance). The previous `vector_l2_ops` index could not serve that query.
# NOTE(review): residual quantization is documented by VectorChord for L2
# distance only, so it is disabled here; confirm against the deployed version.
# Because of IF NOT EXISTS, existing databases keep their old L2 index until it
# is dropped.
CREATE_VCHORD_INDEX = f"""
CREATE INDEX IF NOT EXISTS {INDEX_NAME}
ON {TABLE_NAME}
USING vchordrq (embedding vector_cosine_ops) WITH (options = $$
residual_quantization = false
[build.internal]
lists = []
$$);
"""

# Upsert: re-indexing a snippet replaces its previous embedding.
INSERT_QUERY = f"""
INSERT INTO {TABLE_NAME} (snippet_id, embedding)
VALUES (:snippet_id, :embedding)
ON CONFLICT (snippet_id) DO UPDATE
SET embedding = EXCLUDED.embedding
"""  # noqa: S608

# Note that <=> in vectorchord is cosine distance
# So scores go from 0 (similar) to 2 (opposite)
SEARCH_QUERY = f"""
SELECT snippet_id, embedding <=> :query as score
FROM {TABLE_NAME}
ORDER BY score ASC
LIMIT :top_k;
"""  # noqa: S608
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class VectorChordVectorSearchService(VectorSearchService):
    """Vector search backed by the VectorChord PostgreSQL extension."""

    def __init__(
        self,
        session: AsyncSession,
        embedding_provider: EmbeddingProvider,
    ) -> None:
        """Initialize the VectorChord vector search service.

        The extension, table and index are created lazily by the first query,
        not here.

        Args:
            session: Async SQLAlchemy session used for all statements.
            embedding_provider: Turns texts into vectors.

        """
        self.embedding_provider = embedding_provider
        self._session = session
        self._initialized = False

    async def _initialize(self) -> None:
        """Create the extension, table and index; wrap failures in RuntimeError."""
        try:
            await self._create_extensions()
            await self._create_tables()
            self._initialized = True
        except Exception as e:
            msg = f"Failed to initialize VectorChord repository: {e}"
            raise RuntimeError(msg) from e

    async def _create_extensions(self) -> None:
        """Create the necessary extensions."""
        await self._session.execute(text(CREATE_VCHORD_EXTENSION))
        await self._commit()

    async def _create_tables(self) -> None:
        """Create the embeddings table and index, then verify the dimension.

        The column dimension is taken from the length of a probe embedding.

        Raises:
            ValueError: If an existing table's dimension differs from the
                current provider's.

        """
        probe_vector = (await self.embedding_provider.embed(["dimension"]))[0]
        dimension = len(probe_vector)
        await self._session.execute(
            text(
                f"""CREATE TABLE IF NOT EXISTS {TABLE_NAME} (
                id SERIAL PRIMARY KEY,
                snippet_id INT NOT NULL UNIQUE,
                embedding VECTOR({dimension}) NOT NULL
            );"""
            )
        )
        await self._session.execute(text(CREATE_VCHORD_INDEX))
        # The table may predate the current model; compare its stored dimension.
        result = await self._session.execute(text(CHECK_VCHORD_EMBEDDING_DIMENSION))
        vector_dim_from_db = result.scalar_one()
        if vector_dim_from_db != dimension:
            msg = (
                f"Embedding vector dimension does not match database, "
                f"please delete your index: {vector_dim_from_db} != {dimension}"
            )
            raise ValueError(msg)
        await self._commit()

    async def _execute(
        self, query: TextClause, param_list: list[Any] | dict[str, Any] | None = None
    ) -> Result:
        """Execute a query, initializing the database objects on first use."""
        if not self._initialized:
            await self._initialize()
        return await self._session.execute(query, param_list)

    async def _commit(self) -> None:
        """Commit the session."""
        await self._session.commit()

    async def index(self, data: list[VectorSearchRequest]) -> None:
        """Embed a list of documents and upsert their vectors.

        Empty input is a no-op. Without this guard the INSERT would receive an
        empty parameter list.

        Args:
            data: Snippet ids with the text to embed.

        """
        if not data:
            return
        embeddings = await self.embedding_provider.embed([doc.text for doc in data])
        # Vectors are passed in their string form, e.g. "[0.1, 0.2]".
        await self._execute(
            text(INSERT_QUERY),
            [
                {"snippet_id": doc.snippet_id, "embedding": str(embedding)}
                for doc, embedding in zip(data, embeddings, strict=True)
            ],
        )
        await self._commit()

    async def retrieve(self, query: str, top_k: int = 10) -> list[VectorSearchResponse]:
        """Return up to ``top_k`` snippets ordered by cosine distance (lower is closer)."""
        embedding = await self.embedding_provider.embed([query])
        result = await self._execute(
            text(SEARCH_QUERY), {"query": str(embedding[0]), "top_k": top_k}
        )
        rows = result.mappings().all()

        return [
            VectorSearchResponse(snippet_id=row["snippet_id"], score=row["score"])
            for row in rows
        ]
|
|
@@ -10,6 +10,7 @@ from typing import TypeVar
|
|
|
10
10
|
|
|
11
11
|
from sqlalchemy import delete, func, select
|
|
12
12
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
13
|
+
from sqlalchemy.orm.exc import MultipleResultsFound
|
|
13
14
|
|
|
14
15
|
from kodit.embedding.embedding_models import Embedding
|
|
15
16
|
from kodit.indexing.indexing_models import Index, Snippet
|
|
@@ -124,15 +125,34 @@ class IndexRepository:
|
|
|
124
125
|
index.updated_at = datetime.now(UTC)
|
|
125
126
|
await self.session.commit()
|
|
126
127
|
|
|
127
|
-
async def
|
|
128
|
-
"""Add a new snippet to the database.
|
|
128
|
+
async def add_snippet_or_update_content(self, snippet: Snippet) -> None:
    """Insert ``snippet``, or overwrite the content of the one already stored.

    A stored snippet counts as "the same" when both its ``file_id`` and its
    ``index_id`` match. The session is committed in either case.

    Args:
        snippet: The Snippet to add, or whose content is copied onto the match.

    Raises:
        ValueError: If more than one stored snippet matches, which signals a
            corrupted index.

    """
    lookup = select(Snippet).where(
        Snippet.file_id == snippet.file_id,
        Snippet.index_id == snippet.index_id,
    )
    rows = await self.session.execute(lookup)

    # Only the lookup can raise MultipleResultsFound, so keep the try narrow.
    try:
        current = rows.scalar_one_or_none()
    except MultipleResultsFound as err:
        msg = (
            f"Multiple snippets found for file_id {snippet.file_id}, this "
            "shouldn't happen. "
            "Please report this as a bug then delete your index and start again."
        )
        raise ValueError(msg) from err

    if current is None:
        self.session.add(snippet)
    else:
        current.content = snippet.content
    await self.session.commit()
|
|
136
156
|
|
|
137
157
|
async def delete_all_snippets(self, index_id: int) -> None:
|
|
138
158
|
"""Delete all snippets for an index.
|