kodit 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kodit might be problematic.
- kodit/_version.py +2 -2
- kodit/bm25/keyword_search_factory.py +17 -0
- kodit/bm25/keyword_search_service.py +34 -0
- kodit/bm25/{bm25.py → local_bm25.py} +40 -14
- kodit/bm25/vectorchord_bm25.py +193 -0
- kodit/cli.py +114 -25
- kodit/config.py +9 -2
- kodit/database.py +4 -2
- kodit/embedding/embedding_factory.py +44 -0
- kodit/embedding/embedding_provider/__init__.py +1 -0
- kodit/embedding/embedding_provider/embedding_provider.py +60 -0
- kodit/embedding/embedding_provider/hash_embedding_provider.py +77 -0
- kodit/embedding/embedding_provider/local_embedding_provider.py +58 -0
- kodit/embedding/embedding_provider/openai_embedding_provider.py +75 -0
- kodit/{search/search_repository.py → embedding/embedding_repository.py} +61 -33
- kodit/embedding/local_vector_search_service.py +50 -0
- kodit/embedding/vector_search_service.py +38 -0
- kodit/embedding/vectorchord_vector_search_service.py +154 -0
- kodit/enrichment/__init__.py +1 -0
- kodit/enrichment/enrichment_factory.py +23 -0
- kodit/enrichment/enrichment_provider/__init__.py +1 -0
- kodit/enrichment/enrichment_provider/enrichment_provider.py +16 -0
- kodit/enrichment/enrichment_provider/local_enrichment_provider.py +63 -0
- kodit/enrichment/enrichment_provider/openai_enrichment_provider.py +77 -0
- kodit/enrichment/enrichment_service.py +33 -0
- kodit/indexing/fusion.py +67 -0
- kodit/indexing/indexing_repository.py +44 -4
- kodit/indexing/indexing_service.py +142 -31
- kodit/mcp.py +31 -18
- kodit/snippets/languages/go.scm +26 -0
- kodit/source/source_service.py +9 -3
- kodit/util/__init__.py +1 -0
- kodit/util/spinner.py +59 -0
- {kodit-0.1.14.dist-info → kodit-0.1.16.dist-info}/METADATA +4 -1
- kodit-0.1.16.dist-info/RECORD +64 -0
- kodit/embedding/embedding.py +0 -203
- kodit/search/__init__.py +0 -1
- kodit/search/search_service.py +0 -147
- kodit-0.1.14.dist-info/RECORD +0 -44
- {kodit-0.1.14.dist-info → kodit-0.1.16.dist-info}/WHEEL +0 -0
- {kodit-0.1.14.dist-info → kodit-0.1.16.dist-info}/entry_points.txt +0 -0
- {kodit-0.1.14.dist-info → kodit-0.1.16.dist-info}/licenses/LICENSE +0 -0
kodit/embedding/embedding_provider/embedding_provider.py
@@ -0,0 +1,60 @@
+"""Embedding provider."""
+
+from abc import ABC, abstractmethod
+
+import structlog
+import tiktoken
+
+OPENAI_MAX_EMBEDDING_SIZE = 8192
+
+Vector = list[float]
+
+
+class EmbeddingProvider(ABC):
+    """Embedding provider."""
+
+    @abstractmethod
+    async def embed(self, data: list[str]) -> list[Vector]:
+        """Embed a list of strings.
+
+        The embedding provider is responsible for embedding a list of strings into a
+        list of vectors. The embedding provider is responsible for splitting the list of
+        strings into smaller sub-batches and embedding them in parallel.
+        """
+
+
+def split_sub_batches(encoding: tiktoken.Encoding, data: list[str]) -> list[list[str]]:
+    """Split a list of strings into smaller sub-batches."""
+    log = structlog.get_logger(__name__)
+    result = []
+    data_to_process = [s for s in data if s.strip()]  # Filter out empty strings
+
+    while data_to_process:
+        next_batch = []
+        current_tokens = 0
+
+        while data_to_process:
+            next_item = data_to_process[0]
+            item_tokens = len(encoding.encode(next_item))
+
+            if item_tokens > OPENAI_MAX_EMBEDDING_SIZE:
+                # Loop around trying to truncate the snippet until it fits in the max
+                # embedding size
+                while item_tokens > OPENAI_MAX_EMBEDDING_SIZE:
+                    next_item = next_item[:-1]
+                    item_tokens = len(encoding.encode(next_item))
+
+                data_to_process[0] = next_item
+
+                log.warning("Truncated snippet", snippet=next_item)
+
+            if current_tokens + item_tokens > OPENAI_MAX_EMBEDDING_SIZE:
+                break
+
+            next_batch.append(data_to_process.pop(0))
+            current_tokens += item_tokens
+
+        if next_batch:
+            result.append(next_batch)
+
+    return result
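The batching contract above is easiest to see in use. A minimal sketch, assuming kodit is installed and using the cl100k_base encoding (the same one the bundled providers obtain via tiktoken.encoding_for_model):

import tiktoken

from kodit.embedding.embedding_provider.embedding_provider import split_sub_batches

encoding = tiktoken.get_encoding("cl100k_base")
snippets = ["def add(a, b):\n    return a + b", "", "SELECT count(*) FROM users;"]
batches = split_sub_batches(encoding, snippets)

# The empty string is filtered out, and no batch exceeds the 8192-token cap.
assert all(
    sum(len(encoding.encode(s)) for s in batch) <= 8192 for batch in batches
)

Items that individually exceed the cap are truncated character by character until they fit, with a warning logged for each truncated snippet.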
kodit/embedding/embedding_provider/hash_embedding_provider.py
@@ -0,0 +1,77 @@
+"""Hash embedding provider, for use in tests only."""
+
+import asyncio
+import hashlib
+import math
+from collections.abc import Generator, Sequence
+
+from kodit.embedding.embedding_provider.embedding_provider import (
+    EmbeddingProvider,
+    Vector,
+)
+
+
+class HashEmbeddingProvider(EmbeddingProvider):
+    """A minimal test-time embedding provider.
+
+    • Zero third-party dependencies (uses only std-lib)
+    • Distinguishes strings by hashing with SHA-256
+    • Maps the digest to a fixed-size float vector, then ℓ₂-normalises
+    • Splits work into small asynchronous chunks for speed in event loops
+    """
+
+    def __init__(self, dim: int = 16, batch_size: int = 64) -> None:
+        """Initialize the hash embedding provider."""
+        if dim <= 0:
+            msg = f"dim must be > 0, got {dim}"
+            raise ValueError(msg)
+        if batch_size <= 0:
+            msg = f"batch_size must be > 0, got {batch_size}"
+            raise ValueError(msg)
+        self.dim = dim
+        self.batch_size = batch_size
+
+    async def embed(self, data: list[str]) -> list[Vector]:
+        """Embed every string in *data*, preserving order.
+
+        Work is sliced into *batch_size* chunks and scheduled concurrently
+        (still CPU-bound, but enough to cooperate with an asyncio loop).
+        """
+        if not data:
+            return []
+
+        async def _embed_chunk(chunk: Sequence[str]) -> list[Vector]:
+            return [self._string_to_vector(text) for text in chunk]
+
+        tasks = [
+            asyncio.create_task(_embed_chunk(chunk))
+            for chunk in self._chunked(data, self.batch_size)
+        ]
+
+        vectors: list[Vector] = []
+        for task in tasks:
+            vectors.extend(await task)
+        return vectors
+
+    @staticmethod
+    def _chunked(seq: Sequence[str], size: int) -> Generator[Sequence[str], None, None]:
+        """Yield successive *size*-sized slices from *seq*."""
+        for i in range(0, len(seq), size):
+            yield seq[i : i + size]
+
+    def _string_to_vector(self, text: str) -> Vector:
+        """Deterministically convert *text* to a normalised float vector."""
+        digest = hashlib.sha256(text.encode("utf-8")).digest()
+
+        # Build the vector from 4-byte windows of the digest.
+        vec = [
+            int.from_bytes(
+                digest[(i * 4) % len(digest) : (i * 4) % len(digest) + 4], "big"
+            )
+            / 0xFFFFFFFF
+            for i in range(self.dim)
+        ]
+
+        # ℓ₂-normalise so magnitudes are comparable.
+        norm = math.sqrt(sum(x * x for x in vec)) or 1.0
+        return [x / norm for x in vec]
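A minimal sketch of the guarantees this test provider makes, deterministic vectors of unit length, assuming only that kodit is importable:

import asyncio

from kodit.embedding.embedding_provider.hash_embedding_provider import (
    HashEmbeddingProvider,
)


async def main() -> None:
    provider = HashEmbeddingProvider(dim=16)
    a, b = await provider.embed(["hello", "hello"])
    assert a == b  # SHA-256 hashing is deterministic
    assert abs(sum(x * x for x in a) - 1.0) < 1e-9  # vectors are l2-normalised


asyncio.run(main())

Because the vectors carry no semantic signal, this provider is only useful where tests need stable, cheap embeddings.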
kodit/embedding/embedding_provider/local_embedding_provider.py
@@ -0,0 +1,58 @@
+"""Local embedding service."""
+
+import os
+
+import structlog
+import tiktoken
+from sentence_transformers import SentenceTransformer
+from tqdm import tqdm
+
+from kodit.embedding.embedding_provider.embedding_provider import (
+    EmbeddingProvider,
+    Vector,
+    split_sub_batches,
+)
+
+TINY = "tiny"
+CODE = "code"
+TEST = "test"
+
+COMMON_EMBEDDING_MODELS = {
+    TINY: "ibm-granite/granite-embedding-30m-english",
+    CODE: "flax-sentence-embeddings/st-codesearch-distilroberta-base",
+    TEST: "minishlab/potion-base-4M",
+}
+
+
+class LocalEmbeddingProvider(EmbeddingProvider):
+    """Local embedder."""
+
+    def __init__(self, model_name: str) -> None:
+        """Initialize the local embedder."""
+        self.log = structlog.get_logger(__name__)
+        self.model_name = COMMON_EMBEDDING_MODELS.get(model_name, model_name)
+        self.embedding_model = None
+        self.encoding = tiktoken.encoding_for_model("text-embedding-3-small")
+
+    def _model(self) -> SentenceTransformer:
+        """Get the embedding model."""
+        if self.embedding_model is None:
+            os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Avoid warnings
+            self.embedding_model = SentenceTransformer(
+                self.model_name,
+                trust_remote_code=True,
+                device="cpu",  # Force CPU so we don't have to install accelerate, etc.
+            )
+        return self.embedding_model
+
+    async def embed(self, data: list[str]) -> list[Vector]:
+        """Embed a list of strings."""
+        model = self._model()
+
+        batched_data = split_sub_batches(self.encoding, data)
+
+        results: list[Vector] = []
+        for batch in tqdm(batched_data, total=len(batched_data), leave=False):
+            embeddings = model.encode(batch, show_progress_bar=False, batch_size=4)
+            results.extend([[float(x) for x in embedding] for embedding in embeddings])
+        return results
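A short sketch of how model names resolve, assuming kodit and sentence-transformers are installed (the second model id is just an illustrative Hugging Face id):

from kodit.embedding.embedding_provider.local_embedding_provider import (
    LocalEmbeddingProvider,
)

# Aliases resolve through COMMON_EMBEDDING_MODELS.
provider = LocalEmbeddingProvider("code")
assert (
    provider.model_name
    == "flax-sentence-embeddings/st-codesearch-distilroberta-base"
)

# Unknown names pass through unchanged, so any SentenceTransformer id works.
custom = LocalEmbeddingProvider("sentence-transformers/all-MiniLM-L6-v2")

Note that the model itself is loaded lazily: nothing is downloaded until the first embed call triggers _model().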
kodit/embedding/embedding_provider/openai_embedding_provider.py
@@ -0,0 +1,75 @@
+"""OpenAI embedding service."""
+
+import asyncio
+
+import structlog
+import tiktoken
+from openai import AsyncOpenAI
+
+from kodit.embedding.embedding_provider.embedding_provider import (
+    EmbeddingProvider,
+    Vector,
+    split_sub_batches,
+)
+
+OPENAI_NUM_PARALLEL_TASKS = 10
+
+
+class OpenAIEmbeddingProvider(EmbeddingProvider):
+    """OpenAI embedder."""
+
+    def __init__(
+        self,
+        openai_client: AsyncOpenAI,
+        model_name: str = "text-embedding-3-small",
+    ) -> None:
+        """Initialize the OpenAI embedder."""
+        self.log = structlog.get_logger(__name__)
+        self.openai_client = openai_client
+        self.model_name = model_name
+        self.encoding = tiktoken.encoding_for_model(model_name)
+
+    async def embed(self, data: list[str]) -> list[Vector]:
+        """Embed a list of documents."""
+        # First split the list into a list of list where each sublist has fewer than
+        # max tokens.
+        batched_data = split_sub_batches(self.encoding, data)
+
+        # Process batches in parallel with a semaphore to limit concurrent requests
+        sem = asyncio.Semaphore(OPENAI_NUM_PARALLEL_TASKS)
+
+        # Create a list of tuples with a temporary id for each batch
+        # We need to do this so that we can return the results in the same order as the
+        # input data
+        input_data = [(i, batch) for i, batch in enumerate(batched_data)]
+
+        async def process_batch(
+            data: tuple[int, list[str]],
+        ) -> tuple[int, list[Vector]]:
+            batch_id, batch = data
+            async with sem:
+                try:
+                    response = await self.openai_client.embeddings.create(
+                        model=self.model_name,
+                        input=batch,
+                    )
+                    return batch_id, [
+                        [float(x) for x in embedding.embedding]
+                        for embedding in response.data
+                    ]
+                except Exception as e:
+                    self.log.exception("Error embedding batch", error=str(e))
+                    return batch_id, []
+
+        # Create tasks for all batches
+        tasks = [process_batch(batch) for batch in input_data]
+
+        # Process all batches and yield results as they complete
+        results: list[tuple[int, list[Vector]]] = []
+        for task in asyncio.as_completed(tasks):
+            result = await task
+            results.append(result)
+
+        # Output in the same order as the input data
+        ordered_results = [result for _, result in sorted(results, key=lambda x: x[0])]
+        return [item for sublist in ordered_results for item in sublist]
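A minimal wiring sketch, assuming an OPENAI_API_KEY is set in the environment:

import asyncio

from openai import AsyncOpenAI

from kodit.embedding.embedding_provider.openai_embedding_provider import (
    OpenAIEmbeddingProvider,
)


async def main() -> None:
    client = AsyncOpenAI()  # picks up OPENAI_API_KEY from the environment
    provider = OpenAIEmbeddingProvider(client)
    vectors = await provider.embed(["def add(a, b):\n    return a + b"])
    print(len(vectors), len(vectors[0]))  # 1 1536 for text-embedding-3-small


asyncio.run(main())

One design consequence worth noting: a batch that fails is logged and replaced with an empty list, so the output can be shorter than the input while still preserving the order of the batches that succeeded.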
kodit/{search/search_repository.py → embedding/embedding_repository.py}
@@ -1,62 +1,90 @@
-"""Repository for
-
-from typing import TypeVar
+"""Repository for managing embeddings."""
 
 import numpy as np
-from sqlalchemy import
-    select,
-)
+from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from kodit.embedding.embedding_models import Embedding, EmbeddingType
-from kodit.indexing.indexing_models import Snippet
-from kodit.source.source_models import File
 
-T = TypeVar("T")
 
+class EmbeddingRepository:
+    """Repository for managing embeddings.
+
+    This class provides methods for creating and retrieving embeddings from the
+    database. It handles the low-level database operations and transaction management.
+
+    Args:
+        session: The SQLAlchemy async session to use for database operations.
 
-
-"""Repository for searching for relevant snippets."""
+    """
 
     def __init__(self, session: AsyncSession) -> None:
-        """Initialize the
+        """Initialize the embedding repository."""
+        self.session = session
+
+    async def create_embedding(self, embedding: Embedding) -> Embedding:
+        """Create a new embedding record in the database.
 
         Args:
-
+            embedding: The Embedding instance to create.
+
+        Returns:
+            The created Embedding instance.
 
         """
-        self.session
+        self.session.add(embedding)
+        await self.session.commit()
+        return embedding
 
-    async def
-
+    async def get_embedding_by_snippet_id_and_type(
+        self, snippet_id: int, embedding_type: EmbeddingType
+    ) -> Embedding | None:
+        """Get an embedding by its snippet ID and type.
+
+        Args:
+            snippet_id: The ID of the snippet to get the embedding for.
+            embedding_type: The type of embedding to get.
 
         Returns:
-
+            The Embedding instance if found, None otherwise.
 
         """
-        query = select(
-
-
+        query = select(Embedding).where(
+            Embedding.snippet_id == snippet_id,
+            Embedding.type == embedding_type,
+        )
+        result = await self.session.execute(query)
+        return result.scalar_one_or_none()
 
-    async def
-
+    async def list_embeddings_by_type(
+        self, embedding_type: EmbeddingType
+    ) -> list[Embedding]:
+        """List all embeddings of a given type.
+
+        Args:
+            embedding_type: The type of embeddings to list.
 
         Returns:
-            A list of
+            A list of Embedding instances.
 
         """
-        query = (
-
-
-            .join(File, Snippet.file_id == File.id)
-        )
-        rows = await self.session.execute(query)
+        query = select(Embedding).where(Embedding.type == embedding_type)
+        result = await self.session.execute(query)
+        return list(result.scalars())
 
-
-
+    async def delete_embeddings_by_snippet_id(self, snippet_id: int) -> None:
+        """Delete all embeddings for a snippet.
 
-
-
+        Args:
+            snippet_id: The ID of the snippet to delete embeddings for.
+
+        """
+        query = select(Embedding).where(Embedding.snippet_id == snippet_id)
+        result = await self.session.execute(query)
+        embeddings = result.scalars().all()
+        for embedding in embeddings:
+            await self.session.delete(embedding)
+        await self.session.commit()
 
     async def list_semantic_results(
         self, embedding_type: EmbeddingType, embedding: list[float], top_k: int = 10
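A round-trip sketch of the new repository API; in practice the AsyncSession comes from kodit's own database module, and the Embedding fields shown match how LocalVectorSearchService constructs them below:

from sqlalchemy.ext.asyncio import AsyncSession

from kodit.embedding.embedding_models import Embedding, EmbeddingType
from kodit.embedding.embedding_repository import EmbeddingRepository


async def roundtrip(session: AsyncSession) -> None:
    repo = EmbeddingRepository(session)
    await repo.create_embedding(
        Embedding(snippet_id=1, embedding=[0.1, 0.2, 0.3], type=EmbeddingType.CODE)
    )
    found = await repo.get_embedding_by_snippet_id_and_type(1, EmbeddingType.CODE)
    assert found is not None
    await repo.delete_embeddings_by_snippet_id(1)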
kodit/embedding/local_vector_search_service.py
@@ -0,0 +1,50 @@
+"""Local vector search."""
+
+import structlog
+import tiktoken
+
+from kodit.embedding.embedding_models import Embedding, EmbeddingType
+from kodit.embedding.embedding_provider.embedding_provider import EmbeddingProvider
+from kodit.embedding.embedding_repository import EmbeddingRepository
+from kodit.embedding.vector_search_service import (
+    VectorSearchRequest,
+    VectorSearchResponse,
+    VectorSearchService,
+)
+
+
+class LocalVectorSearchService(VectorSearchService):
+    """Local vector search."""
+
+    def __init__(
+        self,
+        embedding_repository: EmbeddingRepository,
+        embedding_provider: EmbeddingProvider,
+    ) -> None:
+        """Initialize the local embedder."""
+        self.log = structlog.get_logger(__name__)
+        self.embedding_repository = embedding_repository
+        self.embedding_provider = embedding_provider
+        self.encoding = tiktoken.encoding_for_model("text-embedding-3-small")
+
+    async def index(self, data: list[VectorSearchRequest]) -> None:
+        """Embed a list of documents."""
+        embeddings = await self.embedding_provider.embed([i.text for i in data])
+        for i, x in zip(data, embeddings, strict=False):
+            await self.embedding_repository.create_embedding(
+                Embedding(
+                    snippet_id=i.snippet_id,
+                    embedding=[float(y) for y in x],
+                    type=EmbeddingType.CODE,
+                )
+            )
+
+    async def retrieve(self, query: str, top_k: int = 10) -> list[VectorSearchResponse]:
+        """Query the embedding model."""
+        embedding = (await self.embedding_provider.embed([query]))[0]
+        results = await self.embedding_repository.list_semantic_results(
+            EmbeddingType.CODE, [float(x) for x in embedding], top_k
+        )
+        return [
+            VectorSearchResponse(snippet_id, score) for snippet_id, score in results
+        ]
kodit/embedding/vector_search_service.py
@@ -0,0 +1,38 @@
+"""Embedding service."""
+
+from abc import ABC, abstractmethod
+from typing import NamedTuple
+
+
+class VectorSearchResponse(NamedTuple):
+    """Embedding result."""
+
+    snippet_id: int
+    score: float
+
+
+class VectorSearchRequest(NamedTuple):
+    """Input for embedding."""
+
+    snippet_id: int
+    text: str
+
+
+class VectorSearchService(ABC):
+    """Semantic search service interface."""
+
+    @abstractmethod
+    async def index(self, data: list[VectorSearchRequest]) -> None:
+        """Embed a list of documents.
+
+        The embedding service accepts a massive list of id,strings to embed. Behind the
+        scenes it batches up requests and parallelizes them for performance according to
+        the specifics of the embedding service.
+
+        The id reference is required because the parallelization may return results out
+        of order.
+        """
+
+    @abstractmethod
+    async def retrieve(self, query: str, top_k: int = 10) -> list[VectorSearchResponse]:
+        """Query the embedding model."""
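Both LocalVectorSearchService and VectorChordVectorSearchService implement this interface, so callers can be written against the contract alone. A sketch:

from kodit.embedding.vector_search_service import (
    VectorSearchRequest,
    VectorSearchService,
)


async def build_and_query(service: VectorSearchService) -> None:
    await service.index(
        [
            VectorSearchRequest(snippet_id=1, text="def add(a, b):\n    return a + b"),
            VectorSearchRequest(snippet_id=2, text="SELECT count(*) FROM users;"),
        ]
    )
    for hit in await service.retrieve("function that adds two numbers", top_k=5):
        print(hit.snippet_id, hit.score)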
kodit/embedding/vectorchord_vector_search_service.py
@@ -0,0 +1,154 @@
+"""Vectorchord vector search."""
+
+from typing import Any
+
+from sqlalchemy import Result, TextClause, text
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from kodit.embedding.embedding_provider.embedding_provider import EmbeddingProvider
+from kodit.embedding.vector_search_service import (
+    VectorSearchRequest,
+    VectorSearchResponse,
+    VectorSearchService,
+)
+
+# SQL Queries
+CREATE_VCHORD_EXTENSION = """
+CREATE EXTENSION IF NOT EXISTS vchord CASCADE;
+"""
+
+CHECK_VCHORD_EMBEDDING_DIMENSION = """
+SELECT a.atttypmod as dimension
+FROM pg_attribute a
+JOIN pg_class c ON a.attrelid = c.oid
+WHERE c.relname = '{TABLE_NAME}'
+AND a.attname = 'embedding';
+"""
+
+CREATE_VCHORD_INDEX = """
+CREATE INDEX IF NOT EXISTS {INDEX_NAME}
+ON {TABLE_NAME}
+USING vchordrq (embedding vector_l2_ops) WITH (options = $$
+residual_quantization = true
+[build.internal]
+lists = []
+$$);
+"""
+
+INSERT_QUERY = """
+INSERT INTO {TABLE_NAME} (snippet_id, embedding)
+VALUES (:snippet_id, :embedding)
+ON CONFLICT (snippet_id) DO UPDATE
+SET embedding = EXCLUDED.embedding
+"""
+
+# Note that <=> in vectorchord is cosine distance
+# So scores go from 0 (similar) to 2 (opposite)
+SEARCH_QUERY = """
+SELECT snippet_id, embedding <=> :query as score
+FROM {TABLE_NAME}
+ORDER BY score ASC
+LIMIT :top_k;
+"""
+
+
+class VectorChordVectorSearchService(VectorSearchService):
+    """VectorChord vector search."""
+
+    def __init__(
+        self,
+        task_name: str,
+        session: AsyncSession,
+        embedding_provider: EmbeddingProvider,
+    ) -> None:
+        """Initialize the VectorChord BM25."""
+        self.embedding_provider = embedding_provider
+        self._session = session
+        self._initialized = False
+        self.table_name = f"vectorchord_{task_name}_embeddings"
+        self.index_name = f"{self.table_name}_idx"
+
+    async def _initialize(self) -> None:
+        """Initialize the VectorChord environment."""
+        try:
+            await self._create_extensions()
+            await self._create_tables()
+            self._initialized = True
+        except Exception as e:
+            msg = f"Failed to initialize VectorChord repository: {e}"
+            raise RuntimeError(msg) from e
+
+    async def _create_extensions(self) -> None:
+        """Create the necessary extensions."""
+        await self._session.execute(text(CREATE_VCHORD_EXTENSION))
+        await self._commit()
+
+    async def _create_tables(self) -> None:
+        """Create the necessary tables."""
+        vector_dim = (await self.embedding_provider.embed(["dimension"]))[0]
+        await self._session.execute(
+            text(
+                f"""CREATE TABLE IF NOT EXISTS {self.table_name} (
+                id SERIAL PRIMARY KEY,
+                snippet_id INT NOT NULL UNIQUE,
+                embedding VECTOR({len(vector_dim)}) NOT NULL
+            );"""
+            )
+        )
+        await self._session.execute(
+            text(
+                CREATE_VCHORD_INDEX.format(
+                    TABLE_NAME=self.table_name, INDEX_NAME=self.index_name
+                )
+            )
+        )
+        result = await self._session.execute(
+            text(CHECK_VCHORD_EMBEDDING_DIMENSION.format(TABLE_NAME=self.table_name))
+        )
+        vector_dim_from_db = result.scalar_one()
+        if vector_dim_from_db != len(vector_dim):
+            msg = (
+                f"Embedding vector dimension does not match database, "
+                f"please delete your index: {vector_dim_from_db} != {len(vector_dim)}"
+            )
+            raise ValueError(msg)
+        await self._commit()
+
+    async def _execute(
+        self, query: TextClause, param_list: list[Any] | dict[str, Any] | None = None
+    ) -> Result:
+        """Execute a query."""
+        if not self._initialized:
+            await self._initialize()
+        return await self._session.execute(query, param_list)
+
+    async def _commit(self) -> None:
+        """Commit the session."""
+        await self._session.commit()
+
+    async def index(self, data: list[VectorSearchRequest]) -> None:
+        """Embed a list of documents."""
+        embeddings = await self.embedding_provider.embed([doc.text for doc in data])
+        # Execute inserts
+        await self._execute(
+            text(INSERT_QUERY.format(TABLE_NAME=self.table_name)),
+            [
+                {"snippet_id": doc.snippet_id, "embedding": str(embedding)}
+                for doc, embedding in zip(data, embeddings, strict=True)
+            ],
+        )
+        await self._commit()
+
+    async def retrieve(self, query: str, top_k: int = 10) -> list[VectorSearchResponse]:
+        """Query the embedding model."""
+        embedding = await self.embedding_provider.embed([query])
+        result = await self._execute(
+            text(SEARCH_QUERY.format(TABLE_NAME=self.table_name)),
+            {"query": str(embedding[0]), "top_k": top_k},
+        )
+        rows = result.mappings().all()
+
+        return [
+            VectorSearchResponse(snippet_id=row["snippet_id"], score=row["score"])
+            for row in rows
+        ]
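A hypothetical wiring sketch (the DSN is illustrative; any Postgres with the vchord extension available should work, and kodit's factories normally do this construction):

from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

from kodit.embedding.embedding_provider.hash_embedding_provider import (
    HashEmbeddingProvider,
)
from kodit.embedding.vectorchord_vector_search_service import (
    VectorChordVectorSearchService,
)

engine = create_async_engine("postgresql+asyncpg://kodit:kodit@localhost/kodit")
Session = async_sessionmaker(engine)


async def search(query: str) -> None:
    async with Session() as session:
        service = VectorChordVectorSearchService(
            "code", session, HashEmbeddingProvider()
        )
        # Scores are cosine distances: 0 is most similar, 2 is opposite.
        for hit in await service.retrieve(query, top_k=5):
            print(hit.snippet_id, hit.score)

The extension, table, and index are created lazily on the first call that goes through _execute, and the table's vector dimension is checked against the provider's output so a provider swap fails loudly instead of silently corrupting scores.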
kodit/enrichment/__init__.py
@@ -0,0 +1 @@
+"""Enrichment."""
kodit/enrichment/enrichment_factory.py
@@ -0,0 +1,23 @@
+"""Embedding service."""
+
+from kodit.config import AppContext
+from kodit.enrichment.enrichment_provider.local_enrichment_provider import (
+    LocalEnrichmentProvider,
+)
+from kodit.enrichment.enrichment_provider.openai_enrichment_provider import (
+    OpenAIEnrichmentProvider,
+)
+from kodit.enrichment.enrichment_service import (
+    EnrichmentService,
+    LLMEnrichmentService,
+)
+
+
+def enrichment_factory(app_context: AppContext) -> EnrichmentService:
+    """Create an embedding service."""
+    openai_client = app_context.get_default_openai_client()
+    if openai_client is not None:
+        enrichment_provider = OpenAIEnrichmentProvider(openai_client=openai_client)
+        return LLMEnrichmentService(enrichment_provider)
+
+    return LLMEnrichmentService(LocalEnrichmentProvider())
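A short usage sketch; the no-argument AppContext() construction is an assumption (kodit normally builds its settings from the environment):

from kodit.config import AppContext
from kodit.enrichment.enrichment_factory import enrichment_factory

app_context = AppContext()  # assumed: settings come from the environment
service = enrichment_factory(app_context)
# OpenAI-backed when a default OpenAI client is configured, local otherwise.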
kodit/enrichment/enrichment_provider/__init__.py
@@ -0,0 +1 @@
+"""Enrichment provider."""