kodit 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kodit might be problematic.

Files changed (42)
  1. kodit/_version.py +2 -2
  2. kodit/bm25/keyword_search_factory.py +17 -0
  3. kodit/bm25/keyword_search_service.py +34 -0
  4. kodit/bm25/{bm25.py → local_bm25.py} +40 -14
  5. kodit/bm25/vectorchord_bm25.py +193 -0
  6. kodit/cli.py +114 -25
  7. kodit/config.py +9 -2
  8. kodit/database.py +4 -2
  9. kodit/embedding/embedding_factory.py +44 -0
  10. kodit/embedding/embedding_provider/__init__.py +1 -0
  11. kodit/embedding/embedding_provider/embedding_provider.py +60 -0
  12. kodit/embedding/embedding_provider/hash_embedding_provider.py +77 -0
  13. kodit/embedding/embedding_provider/local_embedding_provider.py +58 -0
  14. kodit/embedding/embedding_provider/openai_embedding_provider.py +75 -0
  15. kodit/{search/search_repository.py → embedding/embedding_repository.py} +61 -33
  16. kodit/embedding/local_vector_search_service.py +50 -0
  17. kodit/embedding/vector_search_service.py +38 -0
  18. kodit/embedding/vectorchord_vector_search_service.py +154 -0
  19. kodit/enrichment/__init__.py +1 -0
  20. kodit/enrichment/enrichment_factory.py +23 -0
  21. kodit/enrichment/enrichment_provider/__init__.py +1 -0
  22. kodit/enrichment/enrichment_provider/enrichment_provider.py +16 -0
  23. kodit/enrichment/enrichment_provider/local_enrichment_provider.py +63 -0
  24. kodit/enrichment/enrichment_provider/openai_enrichment_provider.py +77 -0
  25. kodit/enrichment/enrichment_service.py +33 -0
  26. kodit/indexing/fusion.py +67 -0
  27. kodit/indexing/indexing_repository.py +44 -4
  28. kodit/indexing/indexing_service.py +142 -31
  29. kodit/mcp.py +31 -18
  30. kodit/snippets/languages/go.scm +26 -0
  31. kodit/source/source_service.py +9 -3
  32. kodit/util/__init__.py +1 -0
  33. kodit/util/spinner.py +59 -0
  34. {kodit-0.1.14.dist-info → kodit-0.1.16.dist-info}/METADATA +4 -1
  35. kodit-0.1.16.dist-info/RECORD +64 -0
  36. kodit/embedding/embedding.py +0 -203
  37. kodit/search/__init__.py +0 -1
  38. kodit/search/search_service.py +0 -147
  39. kodit-0.1.14.dist-info/RECORD +0 -44
  40. {kodit-0.1.14.dist-info → kodit-0.1.16.dist-info}/WHEEL +0 -0
  41. {kodit-0.1.14.dist-info → kodit-0.1.16.dist-info}/entry_points.txt +0 -0
  42. {kodit-0.1.14.dist-info → kodit-0.1.16.dist-info}/licenses/LICENSE +0 -0
kodit/embedding/embedding_provider/embedding_provider.py
@@ -0,0 +1,60 @@
+ """Embedding provider."""
+
+ from abc import ABC, abstractmethod
+
+ import structlog
+ import tiktoken
+
+ OPENAI_MAX_EMBEDDING_SIZE = 8192
+
+ Vector = list[float]
+
+
+ class EmbeddingProvider(ABC):
+     """Embedding provider."""
+
+     @abstractmethod
+     async def embed(self, data: list[str]) -> list[Vector]:
+         """Embed a list of strings.
+
+         The embedding provider is responsible for embedding a list of strings into a
+         list of vectors. The embedding provider is responsible for splitting the list of
+         strings into smaller sub-batches and embedding them in parallel.
+         """
+
+
+ def split_sub_batches(encoding: tiktoken.Encoding, data: list[str]) -> list[list[str]]:
+     """Split a list of strings into smaller sub-batches."""
+     log = structlog.get_logger(__name__)
+     result = []
+     data_to_process = [s for s in data if s.strip()]  # Filter out empty strings
+
+     while data_to_process:
+         next_batch = []
+         current_tokens = 0
+
+         while data_to_process:
+             next_item = data_to_process[0]
+             item_tokens = len(encoding.encode(next_item))
+
+             if item_tokens > OPENAI_MAX_EMBEDDING_SIZE:
+                 # Loop around trying to truncate the snippet until it fits in the max
+                 # embedding size
+                 while item_tokens > OPENAI_MAX_EMBEDDING_SIZE:
+                     next_item = next_item[:-1]
+                     item_tokens = len(encoding.encode(next_item))
+
+                 data_to_process[0] = next_item
+
+                 log.warning("Truncated snippet", snippet=next_item)
+
+             if current_tokens + item_tokens > OPENAI_MAX_EMBEDDING_SIZE:
+                 break
+
+             next_batch.append(data_to_process.pop(0))
+             current_tokens += item_tokens
+
+         if next_batch:
+             result.append(next_batch)
+
+     return result
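
For orientation, a minimal usage sketch (not part of the diff), assuming kodit 0.1.16 is installed; the sample snippets are illustrative:

    import tiktoken

    from kodit.embedding.embedding_provider.embedding_provider import split_sub_batches

    encoding = tiktoken.encoding_for_model("text-embedding-3-small")

    # Empty strings are filtered out; each sub-batch stays within the
    # 8192-token budget and input order is preserved.
    docs = ["def add(a, b):\n    return a + b", "", "print('hi')"]
    for batch in split_sub_batches(encoding, docs):
        print(len(batch), sum(len(encoding.encode(s)) for s in batch))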
kodit/embedding/embedding_provider/hash_embedding_provider.py
@@ -0,0 +1,77 @@
+ """Hash embedding provider, for use in tests only."""
+
+ import asyncio
+ import hashlib
+ import math
+ from collections.abc import Generator, Sequence
+
+ from kodit.embedding.embedding_provider.embedding_provider import (
+     EmbeddingProvider,
+     Vector,
+ )
+
+
+ class HashEmbeddingProvider(EmbeddingProvider):
+     """A minimal test-time embedding provider.
+
+     • Zero third-party dependencies (uses only std-lib)
+     • Distinguishes strings by hashing with SHA-256
+     • Maps the digest to a fixed-size float vector, then ℓ₂-normalises
+     • Splits work into small asynchronous chunks for speed in event loops
+     """
+
+     def __init__(self, dim: int = 16, batch_size: int = 64) -> None:
+         """Initialize the hash embedding provider."""
+         if dim <= 0:
+             msg = f"dim must be > 0, got {dim}"
+             raise ValueError(msg)
+         if batch_size <= 0:
+             msg = f"batch_size must be > 0, got {batch_size}"
+             raise ValueError(msg)
+         self.dim = dim
+         self.batch_size = batch_size
+
+     async def embed(self, data: list[str]) -> list[Vector]:
+         """Embed every string in *data*, preserving order.
+
+         Work is sliced into *batch_size* chunks and scheduled concurrently
+         (still CPU-bound, but enough to cooperate with an asyncio loop).
+         """
+         if not data:
+             return []
+
+         async def _embed_chunk(chunk: Sequence[str]) -> list[Vector]:
+             return [self._string_to_vector(text) for text in chunk]
+
+         tasks = [
+             asyncio.create_task(_embed_chunk(chunk))
+             for chunk in self._chunked(data, self.batch_size)
+         ]
+
+         vectors: list[Vector] = []
+         for task in tasks:
+             vectors.extend(await task)
+         return vectors
+
+     @staticmethod
+     def _chunked(seq: Sequence[str], size: int) -> Generator[Sequence[str], None, None]:
+         """Yield successive *size*-sized slices from *seq*."""
+         for i in range(0, len(seq), size):
+             yield seq[i : i + size]
+
+     def _string_to_vector(self, text: str) -> Vector:
+         """Deterministically convert *text* to a normalised float vector."""
+         digest = hashlib.sha256(text.encode("utf-8")).digest()
+
+         # Build the vector from 4-byte windows of the digest.
+         vec = [
+             int.from_bytes(
+                 digest[(i * 4) % len(digest) : (i * 4) % len(digest) + 4], "big"
+             )
+             / 0xFFFFFFFF
+             for i in range(self.dim)
+         ]
+
+         # ℓ₂-normalise so magnitudes are comparable.
+         norm = math.sqrt(sum(x * x for x in vec)) or 1.0
+         return [x / norm for x in vec]
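
A quick sketch of the test-only provider in action (not part of the diff): identical inputs map to identical unit-length vectors, which makes tests deterministic:

    import asyncio

    from kodit.embedding.embedding_provider.hash_embedding_provider import (
        HashEmbeddingProvider,
    )

    async def main() -> None:
        provider = HashEmbeddingProvider(dim=16)
        vectors = await provider.embed(["hello", "world", "hello"])
        assert vectors[0] == vectors[2]  # deterministic per input string
        assert len(vectors[0]) == 16     # fixed dimension, ℓ₂-normalised

    asyncio.run(main())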
kodit/embedding/embedding_provider/local_embedding_provider.py
@@ -0,0 +1,58 @@
+ """Local embedding service."""
+
+ import os
+
+ import structlog
+ import tiktoken
+ from sentence_transformers import SentenceTransformer
+ from tqdm import tqdm
+
+ from kodit.embedding.embedding_provider.embedding_provider import (
+     EmbeddingProvider,
+     Vector,
+     split_sub_batches,
+ )
+
+ TINY = "tiny"
+ CODE = "code"
+ TEST = "test"
+
+ COMMON_EMBEDDING_MODELS = {
+     TINY: "ibm-granite/granite-embedding-30m-english",
+     CODE: "flax-sentence-embeddings/st-codesearch-distilroberta-base",
+     TEST: "minishlab/potion-base-4M",
+ }
+
+
+ class LocalEmbeddingProvider(EmbeddingProvider):
+     """Local embedder."""
+
+     def __init__(self, model_name: str) -> None:
+         """Initialize the local embedder."""
+         self.log = structlog.get_logger(__name__)
+         self.model_name = COMMON_EMBEDDING_MODELS.get(model_name, model_name)
+         self.embedding_model = None
+         self.encoding = tiktoken.encoding_for_model("text-embedding-3-small")
+
+     def _model(self) -> SentenceTransformer:
+         """Get the embedding model."""
+         if self.embedding_model is None:
+             os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Avoid warnings
+             self.embedding_model = SentenceTransformer(
+                 self.model_name,
+                 trust_remote_code=True,
+                 device="cpu",  # Force CPU so we don't have to install accelerate, etc.
+             )
+         return self.embedding_model
+
+     async def embed(self, data: list[str]) -> list[Vector]:
+         """Embed a list of strings."""
+         model = self._model()
+
+         batched_data = split_sub_batches(self.encoding, data)
+
+         results: list[Vector] = []
+         for batch in tqdm(batched_data, total=len(batched_data), leave=False):
+             embeddings = model.encode(batch, show_progress_bar=False, batch_size=4)
+             results.extend([[float(x) for x in embedding] for embedding in embeddings])
+         return results
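
A minimal sketch of local embedding (not part of the diff); the "code" alias resolves via COMMON_EMBEDDING_MODELS above, and the model weights are downloaded on first use:

    import asyncio

    from kodit.embedding.embedding_provider.local_embedding_provider import (
        LocalEmbeddingProvider,
    )

    async def main() -> None:
        # Any name outside the alias table is passed to SentenceTransformer as-is.
        provider = LocalEmbeddingProvider("code")
        vectors = await provider.embed(["def add(a, b): return a + b"])
        print(len(vectors), len(vectors[0]))

    asyncio.run(main())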
kodit/embedding/embedding_provider/openai_embedding_provider.py
@@ -0,0 +1,75 @@
+ """OpenAI embedding service."""
+
+ import asyncio
+
+ import structlog
+ import tiktoken
+ from openai import AsyncOpenAI
+
+ from kodit.embedding.embedding_provider.embedding_provider import (
+     EmbeddingProvider,
+     Vector,
+     split_sub_batches,
+ )
+
+ OPENAI_NUM_PARALLEL_TASKS = 10
+
+
+ class OpenAIEmbeddingProvider(EmbeddingProvider):
+     """OpenAI embedder."""
+
+     def __init__(
+         self,
+         openai_client: AsyncOpenAI,
+         model_name: str = "text-embedding-3-small",
+     ) -> None:
+         """Initialize the OpenAI embedder."""
+         self.log = structlog.get_logger(__name__)
+         self.openai_client = openai_client
+         self.model_name = model_name
+         self.encoding = tiktoken.encoding_for_model(model_name)
+
+     async def embed(self, data: list[str]) -> list[Vector]:
+         """Embed a list of documents."""
+         # First split the list into a list of lists where each sublist has fewer than
+         # max tokens.
+         batched_data = split_sub_batches(self.encoding, data)
+
+         # Process batches in parallel with a semaphore to limit concurrent requests
+         sem = asyncio.Semaphore(OPENAI_NUM_PARALLEL_TASKS)
+
+         # Create a list of tuples with a temporary id for each batch
+         # We need to do this so that we can return the results in the same order as the
+         # input data
+         input_data = [(i, batch) for i, batch in enumerate(batched_data)]
+
+         async def process_batch(
+             data: tuple[int, list[str]],
+         ) -> tuple[int, list[Vector]]:
+             batch_id, batch = data
+             async with sem:
+                 try:
+                     response = await self.openai_client.embeddings.create(
+                         model=self.model_name,
+                         input=batch,
+                     )
+                     return batch_id, [
+                         [float(x) for x in embedding.embedding]
+                         for embedding in response.data
+                     ]
+                 except Exception as e:
+                     self.log.exception("Error embedding batch", error=str(e))
+                     return batch_id, []
+
+         # Create tasks for all batches
+         tasks = [process_batch(batch) for batch in input_data]
+
+         # Process all batches and collect results as they complete
+         results: list[tuple[int, list[Vector]]] = []
+         for task in asyncio.as_completed(tasks):
+             result = await task
+             results.append(result)
+
+         # Output in the same order as the input data
+         ordered_results = [result for _, result in sorted(results, key=lambda x: x[0])]
+         return [item for sublist in ordered_results for item in sublist]
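
A minimal sketch of the OpenAI path (not part of the diff), assuming OPENAI_API_KEY is set in the environment:

    import asyncio

    from openai import AsyncOpenAI

    from kodit.embedding.embedding_provider.openai_embedding_provider import (
        OpenAIEmbeddingProvider,
    )

    async def main() -> None:
        provider = OpenAIEmbeddingProvider(openai_client=AsyncOpenAI())
        vectors = await provider.embed(["fn main() {}", "def main(): ..."])
        print(len(vectors))  # 2 — results are re-ordered to match the input

    asyncio.run(main())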
kodit/{search/search_repository.py → embedding/embedding_repository.py}
@@ -1,62 +1,90 @@
- """Repository for searching for relevant snippets."""
-
- from typing import TypeVar
+ """Repository for managing embeddings."""

  import numpy as np
- from sqlalchemy import (
-     select,
- )
+ from sqlalchemy import select
  from sqlalchemy.ext.asyncio import AsyncSession

  from kodit.embedding.embedding_models import Embedding, EmbeddingType
- from kodit.indexing.indexing_models import Snippet
- from kodit.source.source_models import File

- T = TypeVar("T")

+ class EmbeddingRepository:
+     """Repository for managing embeddings.
+
+     This class provides methods for creating and retrieving embeddings from the
+     database. It handles the low-level database operations and transaction management.
+
+     Args:
+         session: The SQLAlchemy async session to use for database operations.

- class SearchRepository:
-     """Repository for searching for relevant snippets."""
+     """

      def __init__(self, session: AsyncSession) -> None:
-         """Initialize the search repository.
+         """Initialize the embedding repository."""
+         self.session = session
+
+     async def create_embedding(self, embedding: Embedding) -> Embedding:
+         """Create a new embedding record in the database.

          Args:
-             session: The SQLAlchemy async session to use for database operations.
+             embedding: The Embedding instance to create.
+
+         Returns:
+             The created Embedding instance.

          """
-         self.session = session
+         self.session.add(embedding)
+         await self.session.commit()
+         return embedding

-     async def list_snippet_ids(self) -> list[int]:
-         """List all snippet IDs.
+     async def get_embedding_by_snippet_id_and_type(
+         self, snippet_id: int, embedding_type: EmbeddingType
+     ) -> Embedding | None:
+         """Get an embedding by its snippet ID and type.
+
+         Args:
+             snippet_id: The ID of the snippet to get the embedding for.
+             embedding_type: The type of embedding to get.

          Returns:
-             A list of all snippets.
+             The Embedding instance if found, None otherwise.

          """
-         query = select(Snippet.id)
-         rows = await self.session.execute(query)
-         return list(rows.scalars().all())
+         query = select(Embedding).where(
+             Embedding.snippet_id == snippet_id,
+             Embedding.type == embedding_type,
+         )
+         result = await self.session.execute(query)
+         return result.scalar_one_or_none()

-     async def list_snippets_by_ids(self, ids: list[int]) -> list[tuple[File, Snippet]]:
-         """List snippets by IDs.
+     async def list_embeddings_by_type(
+         self, embedding_type: EmbeddingType
+     ) -> list[Embedding]:
+         """List all embeddings of a given type.
+
+         Args:
+             embedding_type: The type of embeddings to list.

          Returns:
-             A list of snippets in the same order as the input IDs.
+             A list of Embedding instances.

          """
-         query = (
-             select(Snippet, File)
-             .where(Snippet.id.in_(ids))
-             .join(File, Snippet.file_id == File.id)
-         )
-         rows = await self.session.execute(query)
+         query = select(Embedding).where(Embedding.type == embedding_type)
+         result = await self.session.execute(query)
+         return list(result.scalars())

-         # Create a dictionary for O(1) lookup of results by ID
-         id_to_result = {snippet.id: (file, snippet) for snippet, file in rows.all()}
+     async def delete_embeddings_by_snippet_id(self, snippet_id: int) -> None:
+         """Delete all embeddings for a snippet.

-         # Return results in the same order as input IDs
-         return [id_to_result[i] for i in ids]
+         Args:
+             snippet_id: The ID of the snippet to delete embeddings for.
+
+         """
+         query = select(Embedding).where(Embedding.snippet_id == snippet_id)
+         result = await self.session.execute(query)
+         embeddings = result.scalars().all()
+         for embedding in embeddings:
+             await self.session.delete(embedding)
+         await self.session.commit()

      async def list_semantic_results(
          self, embedding_type: EmbeddingType, embedding: list[float], top_k: int = 10
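
A minimal sketch of the renamed repository in use (not part of the diff), assuming an existing SQLAlchemy AsyncSession bound to the kodit database:

    from sqlalchemy.ext.asyncio import AsyncSession

    from kodit.embedding.embedding_models import Embedding, EmbeddingType
    from kodit.embedding.embedding_repository import EmbeddingRepository

    async def store_and_fetch(session: AsyncSession) -> None:
        repo = EmbeddingRepository(session)
        await repo.create_embedding(
            Embedding(snippet_id=1, embedding=[0.1, 0.2, 0.3], type=EmbeddingType.CODE)
        )
        found = await repo.get_embedding_by_snippet_id_and_type(1, EmbeddingType.CODE)
        assert found is not None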
kodit/embedding/local_vector_search_service.py
@@ -0,0 +1,50 @@
+ """Local vector search."""
+
+ import structlog
+ import tiktoken
+
+ from kodit.embedding.embedding_models import Embedding, EmbeddingType
+ from kodit.embedding.embedding_provider.embedding_provider import EmbeddingProvider
+ from kodit.embedding.embedding_repository import EmbeddingRepository
+ from kodit.embedding.vector_search_service import (
+     VectorSearchRequest,
+     VectorSearchResponse,
+     VectorSearchService,
+ )
+
+
+ class LocalVectorSearchService(VectorSearchService):
+     """Local vector search."""
+
+     def __init__(
+         self,
+         embedding_repository: EmbeddingRepository,
+         embedding_provider: EmbeddingProvider,
+     ) -> None:
+         """Initialize the local embedder."""
+         self.log = structlog.get_logger(__name__)
+         self.embedding_repository = embedding_repository
+         self.embedding_provider = embedding_provider
+         self.encoding = tiktoken.encoding_for_model("text-embedding-3-small")
+
+     async def index(self, data: list[VectorSearchRequest]) -> None:
+         """Embed a list of documents."""
+         embeddings = await self.embedding_provider.embed([i.text for i in data])
+         for i, x in zip(data, embeddings, strict=False):
+             await self.embedding_repository.create_embedding(
+                 Embedding(
+                     snippet_id=i.snippet_id,
+                     embedding=[float(y) for y in x],
+                     type=EmbeddingType.CODE,
+                 )
+             )
+
+     async def retrieve(self, query: str, top_k: int = 10) -> list[VectorSearchResponse]:
+         """Query the embedding model."""
+         embedding = (await self.embedding_provider.embed([query]))[0]
+         results = await self.embedding_repository.list_semantic_results(
+             EmbeddingType.CODE, [float(x) for x in embedding], top_k
+         )
+         return [
+             VectorSearchResponse(snippet_id, score) for snippet_id, score in results
+         ]
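
Tying the pieces together, a minimal sketch (not part of the diff); the repository and provider are assumed to be constructed as in the examples above:

    from kodit.embedding.embedding_provider.embedding_provider import EmbeddingProvider
    from kodit.embedding.embedding_repository import EmbeddingRepository
    from kodit.embedding.local_vector_search_service import LocalVectorSearchService
    from kodit.embedding.vector_search_service import VectorSearchRequest

    async def demo(repo: EmbeddingRepository, provider: EmbeddingProvider) -> None:
        service = LocalVectorSearchService(repo, provider)
        await service.index([VectorSearchRequest(1, "def add(a, b): return a + b")])
        for hit in await service.retrieve("addition function", top_k=5):
            print(hit.snippet_id, hit.score)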
kodit/embedding/vector_search_service.py
@@ -0,0 +1,38 @@
+ """Embedding service."""
+
+ from abc import ABC, abstractmethod
+ from typing import NamedTuple
+
+
+ class VectorSearchResponse(NamedTuple):
+     """Embedding result."""
+
+     snippet_id: int
+     score: float
+
+
+ class VectorSearchRequest(NamedTuple):
+     """Input for embedding."""
+
+     snippet_id: int
+     text: str
+
+
+ class VectorSearchService(ABC):
+     """Semantic search service interface."""
+
+     @abstractmethod
+     async def index(self, data: list[VectorSearchRequest]) -> None:
+         """Embed a list of documents.
+
+         The embedding service accepts a massive list of id,string pairs to embed.
+         Behind the scenes it batches up requests and parallelizes them for
+         performance according to the specifics of the embedding service.
+
+         The id reference is required because the parallelization may return results
+         out of order.
+         """
+
+     @abstractmethod
+     async def retrieve(self, query: str, top_k: int = 10) -> list[VectorSearchResponse]:
+         """Query the embedding model."""
kodit/embedding/vectorchord_vector_search_service.py
@@ -0,0 +1,154 @@
+ """Vectorchord vector search."""
+
+ from typing import Any
+
+ from sqlalchemy import Result, TextClause, text
+ from sqlalchemy.ext.asyncio import AsyncSession
+
+ from kodit.embedding.embedding_provider.embedding_provider import EmbeddingProvider
+ from kodit.embedding.vector_search_service import (
+     VectorSearchRequest,
+     VectorSearchResponse,
+     VectorSearchService,
+ )
+
+ # SQL Queries
+ CREATE_VCHORD_EXTENSION = """
+ CREATE EXTENSION IF NOT EXISTS vchord CASCADE;
+ """
+
+ CHECK_VCHORD_EMBEDDING_DIMENSION = """
+ SELECT a.atttypmod as dimension
+ FROM pg_attribute a
+ JOIN pg_class c ON a.attrelid = c.oid
+ WHERE c.relname = '{TABLE_NAME}'
+ AND a.attname = 'embedding';
+ """
+
+ CREATE_VCHORD_INDEX = """
+ CREATE INDEX IF NOT EXISTS {INDEX_NAME}
+ ON {TABLE_NAME}
+ USING vchordrq (embedding vector_l2_ops) WITH (options = $$
+ residual_quantization = true
+ [build.internal]
+ lists = []
+ $$);
+ """
+
+ INSERT_QUERY = """
+ INSERT INTO {TABLE_NAME} (snippet_id, embedding)
+ VALUES (:snippet_id, :embedding)
+ ON CONFLICT (snippet_id) DO UPDATE
+ SET embedding = EXCLUDED.embedding
+ """
+
+ # Note that <=> in vectorchord is cosine distance
+ # So scores go from 0 (similar) to 2 (opposite)
+ SEARCH_QUERY = """
+ SELECT snippet_id, embedding <=> :query as score
+ FROM {TABLE_NAME}
+ ORDER BY score ASC
+ LIMIT :top_k;
+ """
+
+
+ class VectorChordVectorSearchService(VectorSearchService):
+     """VectorChord vector search."""
+
+     def __init__(
+         self,
+         task_name: str,
+         session: AsyncSession,
+         embedding_provider: EmbeddingProvider,
+     ) -> None:
+         """Initialize the VectorChord vector search service."""
+         self.embedding_provider = embedding_provider
+         self._session = session
+         self._initialized = False
+         self.table_name = f"vectorchord_{task_name}_embeddings"
+         self.index_name = f"{self.table_name}_idx"
+
+     async def _initialize(self) -> None:
+         """Initialize the VectorChord environment."""
+         try:
+             await self._create_extensions()
+             await self._create_tables()
+             self._initialized = True
+         except Exception as e:
+             msg = f"Failed to initialize VectorChord repository: {e}"
+             raise RuntimeError(msg) from e
+
+     async def _create_extensions(self) -> None:
+         """Create the necessary extensions."""
+         await self._session.execute(text(CREATE_VCHORD_EXTENSION))
+         await self._commit()
+
+     async def _create_tables(self) -> None:
+         """Create the necessary tables."""
+         vector_dim = (await self.embedding_provider.embed(["dimension"]))[0]
+         await self._session.execute(
+             text(
+                 f"""CREATE TABLE IF NOT EXISTS {self.table_name} (
+                 id SERIAL PRIMARY KEY,
+                 snippet_id INT NOT NULL UNIQUE,
+                 embedding VECTOR({len(vector_dim)}) NOT NULL
+             );"""
+             )
+         )
+         await self._session.execute(
+             text(
+                 CREATE_VCHORD_INDEX.format(
+                     TABLE_NAME=self.table_name, INDEX_NAME=self.index_name
+                 )
+             )
+         )
+         result = await self._session.execute(
+             text(CHECK_VCHORD_EMBEDDING_DIMENSION.format(TABLE_NAME=self.table_name))
+         )
+         vector_dim_from_db = result.scalar_one()
+         if vector_dim_from_db != len(vector_dim):
+             msg = (
+                 f"Embedding vector dimension does not match database, "
+                 f"please delete your index: {vector_dim_from_db} != {len(vector_dim)}"
+             )
+             raise ValueError(msg)
+         await self._commit()
+
+     async def _execute(
+         self, query: TextClause, param_list: list[Any] | dict[str, Any] | None = None
+     ) -> Result:
+         """Execute a query."""
+         if not self._initialized:
+             await self._initialize()
+         return await self._session.execute(query, param_list)
+
+     async def _commit(self) -> None:
+         """Commit the session."""
+         await self._session.commit()
+
+     async def index(self, data: list[VectorSearchRequest]) -> None:
+         """Embed a list of documents."""
+         embeddings = await self.embedding_provider.embed([doc.text for doc in data])
+         # Execute inserts
+         await self._execute(
+             text(INSERT_QUERY.format(TABLE_NAME=self.table_name)),
+             [
+                 {"snippet_id": doc.snippet_id, "embedding": str(embedding)}
+                 for doc, embedding in zip(data, embeddings, strict=True)
+             ],
+         )
+         await self._commit()
+
+     async def retrieve(self, query: str, top_k: int = 10) -> list[VectorSearchResponse]:
+         """Query the embedding model."""
+         embedding = await self.embedding_provider.embed([query])
+         result = await self._execute(
+             text(SEARCH_QUERY.format(TABLE_NAME=self.table_name)),
+             {"query": str(embedding[0]), "top_k": top_k},
+         )
+         rows = result.mappings().all()
+
+         return [
+             VectorSearchResponse(snippet_id=row["snippet_id"], score=row["score"])
+             for row in rows
+         ]
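
A minimal sketch of the VectorChord path (not part of the diff); the connection URL is a placeholder, Postgres with the vchord extension is required, and the table and index are created lazily on first use:

    import asyncio

    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine

    from kodit.embedding.embedding_provider.hash_embedding_provider import (
        HashEmbeddingProvider,
    )
    from kodit.embedding.vector_search_service import VectorSearchRequest
    from kodit.embedding.vectorchord_vector_search_service import (
        VectorChordVectorSearchService,
    )

    async def main() -> None:
        engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/kodit")
        async with AsyncSession(engine) as session:
            service = VectorChordVectorSearchService(
                "code", session, HashEmbeddingProvider()
            )
            await service.index([VectorSearchRequest(1, "def add(a, b): return a + b")])
            print(await service.retrieve("addition", top_k=3))

    asyncio.run(main())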
kodit/enrichment/__init__.py
@@ -0,0 +1 @@
+ """Enrichment."""
kodit/enrichment/enrichment_factory.py
@@ -0,0 +1,23 @@
+ """Enrichment service factory."""
+
+ from kodit.config import AppContext
+ from kodit.enrichment.enrichment_provider.local_enrichment_provider import (
+     LocalEnrichmentProvider,
+ )
+ from kodit.enrichment.enrichment_provider.openai_enrichment_provider import (
+     OpenAIEnrichmentProvider,
+ )
+ from kodit.enrichment.enrichment_service import (
+     EnrichmentService,
+     LLMEnrichmentService,
+ )
+
+
+ def enrichment_factory(app_context: AppContext) -> EnrichmentService:
+     """Create an enrichment service."""
+     openai_client = app_context.get_default_openai_client()
+     if openai_client is not None:
+         enrichment_provider = OpenAIEnrichmentProvider(openai_client=openai_client)
+         return LLMEnrichmentService(enrichment_provider)
+
+     return LLMEnrichmentService(LocalEnrichmentProvider())
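
A minimal sketch of the factory's fallback behaviour (not part of the diff), assuming AppContext can be built from environment settings:

    from kodit.config import AppContext
    from kodit.enrichment.enrichment_factory import enrichment_factory

    # With an OpenAI client configured (e.g. via OPENAI_API_KEY) this returns an
    # LLMEnrichmentService wrapping OpenAIEnrichmentProvider; otherwise it falls
    # back to LocalEnrichmentProvider.
    service = enrichment_factory(AppContext())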
kodit/enrichment/enrichment_provider/__init__.py
@@ -0,0 +1 @@
+ """Enrichment provider."""