kodit 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic. Click here for more details.

Files changed (33) hide show
  1. kodit/_version.py +2 -2
  2. kodit/bm25/keyword_search_factory.py +17 -0
  3. kodit/bm25/keyword_search_service.py +34 -0
  4. kodit/bm25/{bm25.py → local_bm25.py} +40 -14
  5. kodit/bm25/vectorchord_bm25.py +193 -0
  6. kodit/cli.py +14 -11
  7. kodit/config.py +9 -2
  8. kodit/database.py +4 -2
  9. kodit/embedding/embedding_factory.py +44 -0
  10. kodit/embedding/embedding_provider/__init__.py +1 -0
  11. kodit/embedding/embedding_provider/embedding_provider.py +53 -0
  12. kodit/embedding/embedding_provider/hash_embedding_provider.py +77 -0
  13. kodit/embedding/embedding_provider/local_embedding_provider.py +58 -0
  14. kodit/embedding/embedding_provider/openai_embedding_provider.py +63 -0
  15. kodit/embedding/embedding_repository.py +206 -0
  16. kodit/embedding/local_vector_search_service.py +50 -0
  17. kodit/embedding/vector_search_service.py +38 -0
  18. kodit/embedding/vectorchord_vector_search_service.py +145 -0
  19. kodit/indexing/indexing_repository.py +24 -4
  20. kodit/indexing/indexing_service.py +25 -30
  21. kodit/mcp.py +28 -7
  22. kodit/search/search_repository.py +0 -121
  23. kodit/search/search_service.py +12 -24
  24. kodit/source/source_service.py +9 -3
  25. kodit/util/__init__.py +1 -0
  26. kodit/util/spinner.py +59 -0
  27. {kodit-0.1.13.dist-info → kodit-0.1.15.dist-info}/METADATA +2 -1
  28. kodit-0.1.15.dist-info/RECORD +58 -0
  29. kodit/embedding/embedding.py +0 -203
  30. kodit-0.1.13.dist-info/RECORD +0 -44
  31. {kodit-0.1.13.dist-info → kodit-0.1.15.dist-info}/WHEEL +0 -0
  32. {kodit-0.1.13.dist-info → kodit-0.1.15.dist-info}/entry_points.txt +0 -0
  33. {kodit-0.1.13.dist-info → kodit-0.1.15.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,58 @@
1
+ """Local embedding service."""
2
+
3
+ import os
4
+
5
+ import structlog
6
+ import tiktoken
7
+ from sentence_transformers import SentenceTransformer
8
+ from tqdm import tqdm
9
+
10
+ from kodit.embedding.embedding_provider.embedding_provider import (
11
+ EmbeddingProvider,
12
+ Vector,
13
+ split_sub_batches,
14
+ )
15
+
16
# Short aliases that may be given instead of a full model identifier.
TINY = "tiny"
CODE = "code"
TEST = "test"

# Alias -> model identifier. LocalEmbeddingProvider looks a requested name up
# here and falls back to the name itself when it is not an alias.
COMMON_EMBEDDING_MODELS: dict[str, str] = dict(
    (
        (TINY, "ibm-granite/granite-embedding-30m-english"),
        (CODE, "flax-sentence-embeddings/st-codesearch-distilroberta-base"),
        (TEST, "minishlab/potion-base-4M"),
    )
)
25
+
26
+
27
class LocalEmbeddingProvider(EmbeddingProvider):
    """Embedding provider backed by a locally loaded SentenceTransformer model."""

    def __init__(self, model_name: str) -> None:
        """Record the configuration; the model itself is loaded on first use.

        Args:
            model_name: Either a key of COMMON_EMBEDDING_MODELS or a model
                identifier, which is then used unchanged.

        """
        self.log = structlog.get_logger(__name__)
        self.model_name = COMMON_EMBEDDING_MODELS.get(model_name, model_name)
        self.embedding_model = None
        self.encoding = tiktoken.encoding_for_model("text-embedding-3-small")

    def _model(self) -> SentenceTransformer:
        """Return the SentenceTransformer, constructing it on the first call."""
        if self.embedding_model is not None:
            return self.embedding_model

        os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Avoid warnings
        self.embedding_model = SentenceTransformer(
            self.model_name,
            trust_remote_code=True,
            device="cpu",  # Force CPU so we don't have to install accelerate, etc.
        )
        return self.embedding_model

    async def embed(self, data: list[str]) -> list[Vector]:
        """Embed every string in *data* and return the vectors as float lists.

        The input is first divided into sub-batches by split_sub_batches, and
        each sub-batch is encoded in turn while a transient progress bar is shown.

        Args:
            data: The texts to embed.

        Returns:
            One list of floats per encoded row, in the order the sub-batches
            were processed.

        """
        encoder = self._model()
        sub_batches = split_sub_batches(self.encoding, data)

        vectors: list[Vector] = []
        for sub_batch in tqdm(sub_batches, total=len(sub_batches), leave=False):
            encoded = encoder.encode(sub_batch, show_progress_bar=False, batch_size=4)
            for row in encoded:
                vectors.append([float(value) for value in row])
        return vectors
@@ -0,0 +1,63 @@
1
+ """OpenAI embedding service."""
2
+
3
+ import asyncio
4
+
5
+ import structlog
6
+ import tiktoken
7
+ from openai import AsyncOpenAI
8
+
9
+ from kodit.embedding.embedding_provider.embedding_provider import (
10
+ EmbeddingProvider,
11
+ Vector,
12
+ split_sub_batches,
13
+ )
14
+
15
# Upper bound on the number of embedding requests in flight at the same time.
OPENAI_NUM_PARALLEL_TASKS = 10


class OpenAIEmbeddingProvider(EmbeddingProvider):
    """Embedding provider that calls the embeddings endpoint of an AsyncOpenAI client."""

    def __init__(
        self,
        openai_client: AsyncOpenAI,
        model_name: str = "text-embedding-3-small",
    ) -> None:
        """Initialize the OpenAI embedder.

        Args:
            openai_client: Client used to issue the embedding requests.
            model_name: Model requested from the API; also used to pick the
                tiktoken encoding that sizes the sub-batches.

        """
        self.log = structlog.get_logger(__name__)
        self.openai_client = openai_client
        self.model_name = model_name
        self.encoding = tiktoken.encoding_for_model(model_name)

    async def embed(self, data: list[str]) -> list[Vector]:
        """Embed a list of documents.

        The documents are split into sub-batches, which are sent concurrently
        (at most OPENAI_NUM_PARALLEL_TASKS at a time). The per-batch results are
        concatenated in the order the batches were created, not in the order
        the requests happen to finish, so the output lines up with the batches.

        Args:
            data: The texts to embed.

        Returns:
            The vectors of all batches that succeeded, batch by batch. A batch
            whose request fails is logged and contributes no vectors, so the
            result can be shorter than *data*; callers that pair vectors with
            inputs should check the length.

        """
        # First split the list into a list of lists where each sublist has fewer
        # than max tokens.
        batched_data = split_sub_batches(self.encoding, data)

        # Limit the number of concurrent requests.
        sem = asyncio.Semaphore(OPENAI_NUM_PARALLEL_TASKS)

        async def process_batch(batch: list[str]) -> list[Vector]:
            async with sem:
                try:
                    response = await self.openai_client.embeddings.create(
                        model=self.model_name,
                        input=batch,
                    )
                    return [
                        [float(x) for x in embedding.embedding]
                        for embedding in response.data
                    ]
                except Exception as e:
                    # Best effort: report the failure and carry on with the
                    # remaining batches.
                    self.log.exception("Error embedding batch", error=str(e))
                    return []

        # gather() returns results in submission order; as_completed() would
        # interleave batches by completion time and scramble the pairing
        # between inputs and vectors.
        per_batch = await asyncio.gather(
            *(process_batch(batch) for batch in batched_data)
        )
        return [vector for vectors in per_batch for vector in vectors]
@@ -0,0 +1,206 @@
1
+ """Repository for managing embeddings."""
2
+
3
+ import numpy as np
4
+ from sqlalchemy import select
5
+ from sqlalchemy.ext.asyncio import AsyncSession
6
+
7
+ from kodit.embedding.embedding_models import Embedding, EmbeddingType
8
+
9
+
10
class EmbeddingRepository:
    """Repository for managing embeddings.

    This class provides methods for creating and retrieving embeddings from the
    database. It handles the low-level database operations and transaction management.

    Args:
        session: The SQLAlchemy async session to use for database operations.

    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the embedding repository."""
        self.session = session

    async def create_embedding(self, embedding: Embedding) -> Embedding:
        """Create a new embedding record in the database.

        The session is committed before returning.

        Args:
            embedding: The Embedding instance to create.

        Returns:
            The created Embedding instance.

        """
        self.session.add(embedding)
        await self.session.commit()
        return embedding

    async def get_embedding_by_snippet_id_and_type(
        self, snippet_id: int, embedding_type: EmbeddingType
    ) -> Embedding | None:
        """Get an embedding by its snippet ID and type.

        Args:
            snippet_id: The ID of the snippet to get the embedding for.
            embedding_type: The type of embedding to get.

        Returns:
            The Embedding instance if found, None otherwise.

        """
        query = select(Embedding).where(
            Embedding.snippet_id == snippet_id,
            Embedding.type == embedding_type,
        )
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def list_embeddings_by_type(
        self, embedding_type: EmbeddingType
    ) -> list[Embedding]:
        """List all embeddings of a given type.

        Args:
            embedding_type: The type of embeddings to list.

        Returns:
            A list of Embedding instances.

        """
        query = select(Embedding).where(Embedding.type == embedding_type)
        result = await self.session.execute(query)
        return list(result.scalars())

    async def delete_embeddings_by_snippet_id(self, snippet_id: int) -> None:
        """Delete all embeddings for a snippet.

        The matching rows are loaded, deleted one by one and the session is
        committed once at the end.

        Args:
            snippet_id: The ID of the snippet to delete embeddings for.

        """
        query = select(Embedding).where(Embedding.snippet_id == snippet_id)
        result = await self.session.execute(query)
        for embedding in result.scalars().all():
            await self.session.delete(embedding)
        await self.session.commit()

    async def list_semantic_results(
        self, embedding_type: EmbeddingType, embedding: list[float], top_k: int = 10
    ) -> list[tuple[int, float]]:
        """List semantic results using cosine similarity.

        This implementation fetches all embeddings of the given type and computes
        cosine similarity in Python using NumPy.

        Args:
            embedding_type: The type of embeddings to search
            embedding: The query embedding vector
            top_k: Number of results to return

        Returns:
            List of (snippet_id, similarity_score) tuples, sorted by similarity
            with the most similar first. Empty when nothing is stored.

        Raises:
            ValueError: If the stored embeddings do not all have the same length.

        """
        # Step 1: Fetch embeddings from database
        embeddings = await self._list_embedding_values(embedding_type)
        if not embeddings:
            return []

        # Step 2: Convert to numpy arrays
        stored_vecs, query_vec = self._prepare_vectors(embeddings, embedding)

        # Step 3: Compute similarities
        similarities = self._compute_similarities(stored_vecs, query_vec)

        # Step 4: Get top-k results
        return self._get_top_k_results(similarities, embeddings, top_k)

    async def _list_embedding_values(
        self, embedding_type: EmbeddingType
    ) -> list[tuple[int, list[float]]]:
        """List all embeddings of a given type from the database.

        Args:
            embedding_type: The type of embeddings to fetch

        Returns:
            List of (snippet_id, embedding) tuples

        """
        # Only the two columns needed for scoring are selected.
        query = select(Embedding.snippet_id, Embedding.embedding).where(
            Embedding.type == embedding_type
        )
        rows = await self.session.execute(query)
        return [tuple(row) for row in rows.all()]  # Convert Row objects to tuples

    def _prepare_vectors(
        self, embeddings: list[tuple[int, list[float]]], query_embedding: list[float]
    ) -> tuple[np.ndarray, np.ndarray]:
        """Convert embeddings to numpy arrays.

        Args:
            embeddings: List of (snippet_id, embedding) tuples
            query_embedding: Query embedding vector

        Returns:
            Tuple of (stored_vectors, query_vector) as numpy arrays

        Raises:
            ValueError: If NumPy rejects the stored embeddings; when the cause
                is rows of different lengths the message explains how to recover.

        """
        try:
            stored_vecs = np.array(
                [emb[1] for emb in embeddings]
            )  # Use index 1 to get embedding
        except ValueError as e:
            if "inhomogeneous" in str(e):
                # Adjacent literals are concatenated verbatim, so each piece
                # carries its own separating space.
                msg = (
                    "The database has returned embeddings of different sizes. If you "
                    "have recently updated the embedding model, you will need to "
                    "delete your database and re-index your snippets."
                )
                raise ValueError(msg) from e
            raise

        query_vec = np.array(query_embedding)
        return stored_vecs, query_vec

    def _compute_similarities(
        self, stored_vecs: np.ndarray, query_vec: np.ndarray
    ) -> np.ndarray:
        """Compute cosine similarities between stored vectors and query vector.

        A pair in which either vector has zero norm has no defined cosine; it is
        scored 0.0 rather than NaN, because NaN would sort to the end in argsort
        and therefore rank first once the order is reversed.

        Args:
            stored_vecs: Array of stored embedding vectors
            query_vec: Query embedding vector

        Returns:
            Array of similarity scores

        """
        stored_norms = np.linalg.norm(stored_vecs, axis=1)
        query_norm = np.linalg.norm(query_vec)
        denominators = stored_norms * query_norm
        dot_products = np.dot(stored_vecs, query_vec)

        similarities = np.zeros_like(denominators, dtype=float)
        # Entries whose denominator is zero keep the 0.0 they were initialized with.
        np.divide(
            dot_products, denominators, out=similarities, where=denominators != 0
        )
        return similarities

    def _get_top_k_results(
        self,
        similarities: np.ndarray,
        embeddings: list[tuple[int, list[float]]],
        top_k: int,
    ) -> list[tuple[int, float]]:
        """Get top-k results by similarity score.

        Args:
            similarities: Array of similarity scores
            embeddings: List of (snippet_id, embedding) tuples
            top_k: Number of results to return

        Returns:
            List of (snippet_id, similarity_score) tuples

        """
        # argsort is ascending; reverse it so the highest score comes first.
        top_indices = np.argsort(similarities)[::-1][:top_k]
        return [
            (embeddings[i][0], float(similarities[i])) for i in top_indices
        ]  # Use index 0 to get snippet_id
@@ -0,0 +1,50 @@
1
+ """Local vector search."""
2
+
3
+ import structlog
4
+ import tiktoken
5
+
6
+ from kodit.embedding.embedding_models import Embedding, EmbeddingType
7
+ from kodit.embedding.embedding_provider.embedding_provider import EmbeddingProvider
8
+ from kodit.embedding.embedding_repository import EmbeddingRepository
9
+ from kodit.embedding.vector_search_service import (
10
+ VectorSearchRequest,
11
+ VectorSearchResponse,
12
+ VectorSearchService,
13
+ )
14
+
15
+
16
class LocalVectorSearchService(VectorSearchService):
    """Vector search that stores and queries embeddings through an EmbeddingRepository."""

    def __init__(
        self,
        embedding_repository: EmbeddingRepository,
        embedding_provider: EmbeddingProvider,
    ) -> None:
        """Initialize the local vector search service.

        Args:
            embedding_repository: Repository used to persist and search embeddings.
            embedding_provider: Provider used to turn text into vectors.

        """
        self.log = structlog.get_logger(__name__)
        self.embedding_repository = embedding_repository
        self.embedding_provider = embedding_provider
        self.encoding = tiktoken.encoding_for_model("text-embedding-3-small")

    async def index(self, data: list[VectorSearchRequest]) -> None:
        """Embed a list of documents and store one CODE embedding per document.

        Args:
            data: The (snippet_id, text) requests to embed.

        Raises:
            ValueError: If the provider does not return exactly one vector per
                request. Nothing is stored in that case, because pairing the
                vectors positionally would attach them to the wrong snippets.

        """
        embeddings = await self.embedding_provider.embed([i.text for i in data])
        if len(embeddings) != len(data):
            msg = (
                f"Embedding provider returned {len(embeddings)} vectors for "
                f"{len(data)} documents"
            )
            raise ValueError(msg)

        for request, vector in zip(data, embeddings, strict=True):
            await self.embedding_repository.create_embedding(
                Embedding(
                    snippet_id=request.snippet_id,
                    embedding=[float(y) for y in vector],
                    type=EmbeddingType.CODE,
                )
            )

    async def retrieve(self, query: str, top_k: int = 10) -> list[VectorSearchResponse]:
        """Embed *query* and return the best-matching stored CODE embeddings.

        Args:
            query: The search text.
            top_k: Maximum number of results to return.

        Returns:
            One VectorSearchResponse per result produced by the repository's
            list_semantic_results, in the order it returns them.

        """
        embedding = (await self.embedding_provider.embed([query]))[0]
        results = await self.embedding_repository.list_semantic_results(
            EmbeddingType.CODE, [float(x) for x in embedding], top_k
        )
        return [
            VectorSearchResponse(snippet_id, score) for snippet_id, score in results
        ]
@@ -0,0 +1,38 @@
1
+ """Embedding service."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import NamedTuple
5
+
6
+
7
class VectorSearchResponse(NamedTuple):
    """A single search hit: a snippet and the score assigned to it."""

    # Identifier of the snippet that matched.
    snippet_id: int
    # Score reported by the search backend for this snippet.
    score: float
12
+
13
+
14
class VectorSearchRequest(NamedTuple):
    """A single document to index: a snippet identifier and its text."""

    # Identifier the resulting embedding is stored under.
    snippet_id: int
    # Text that gets embedded.
    text: str
19
+
20
+
21
class VectorSearchService(ABC):
    """Abstract interface every semantic (vector) search backend implements."""

    @abstractmethod
    async def index(self, data: list[VectorSearchRequest]) -> None:
        """Embed and store a list of documents.

        Implementations are handed a potentially very large list of
        (snippet_id, text) requests. How the work is batched and parallelized
        is left to the implementation and its embedding service.

        Each request carries its snippet id because parallel processing may
        deliver results out of order.
        """

    @abstractmethod
    async def retrieve(self, query: str, top_k: int = 10) -> list[VectorSearchResponse]:
        """Search the indexed documents for *query*, returning up to *top_k* hits."""
@@ -0,0 +1,145 @@
1
+ """Vectorchord vector search."""
2
+
3
+ from typing import Any
4
+
5
+ from sqlalchemy import Result, TextClause, text
6
+ from sqlalchemy.ext.asyncio import AsyncSession
7
+
8
+ from kodit.embedding.embedding_provider.embedding_provider import EmbeddingProvider
9
+ from kodit.embedding.vector_search_service import (
10
+ VectorSearchRequest,
11
+ VectorSearchResponse,
12
+ VectorSearchService,
13
+ )
14
+
15
# Table that holds one vector per snippet, and the index created on it.
TABLE_NAME = "vectorchord_embeddings"
INDEX_NAME = f"{TABLE_NAME}_idx"

# SQL Queries

# Installs the vchord extension when it is not already present.
CREATE_VCHORD_EXTENSION = """
CREATE EXTENSION IF NOT EXISTS vchord CASCADE;
"""

# Reads the type modifier of the `embedding` column from the catalog; the
# service compares it with the length of a freshly computed embedding.
CHECK_VCHORD_EMBEDDING_DIMENSION = f"""
SELECT a.atttypmod as dimension
FROM pg_attribute a
JOIN pg_class c ON a.attrelid = c.oid
WHERE c.relname = '{TABLE_NAME}'
AND a.attname = 'embedding';
"""  # noqa: S608

# NOTE(review): the index is declared with vector_l2_ops while SEARCH_QUERY
# orders by the cosine operator <=> -- confirm this index can serve that query.
CREATE_VCHORD_INDEX = f"""
CREATE INDEX IF NOT EXISTS {INDEX_NAME}
ON {TABLE_NAME}
USING vchordrq (embedding vector_l2_ops) WITH (options = $$
residual_quantization = true
[build.internal]
lists = []
$$);
"""

# Upsert keyed on snippet_id: an existing row gets its embedding replaced.
INSERT_QUERY = f"""
INSERT INTO {TABLE_NAME} (snippet_id, embedding)
VALUES (:snippet_id, :embedding)
ON CONFLICT (snippet_id) DO UPDATE
SET embedding = EXCLUDED.embedding
"""  # noqa: S608

# Note that <=> in vectorchord is cosine distance
# So scores go from 0 (similar) to 2 (opposite)
# Rows are therefore returned closest-first (ascending score).
SEARCH_QUERY = f"""
SELECT snippet_id, embedding <=> :query as score
FROM {TABLE_NAME}
ORDER BY score ASC
LIMIT :top_k;
"""  # noqa: S608
56
+
57
+
58
class VectorChordVectorSearchService(VectorSearchService):
    """VectorChord vector search."""

    def __init__(
        self,
        session: AsyncSession,
        embedding_provider: EmbeddingProvider,
    ) -> None:
        """Initialize the VectorChord vector search service.

        No database work happens here; the extension, table and index are
        created on the first query issued through _execute.

        Args:
            session: Async session used for every statement.
            embedding_provider: Provider used to turn text into vectors.

        """
        self.embedding_provider = embedding_provider
        self._session = session
        self._initialized = False

    async def _initialize(self) -> None:
        """Initialize the VectorChord environment.

        Raises:
            RuntimeError: If creating the extension or the tables fails; the
                original exception is attached as the cause.

        """
        try:
            await self._create_extensions()
            await self._create_tables()
            self._initialized = True
        except Exception as e:
            msg = f"Failed to initialize VectorChord repository: {e}"
            raise RuntimeError(msg) from e

    async def _create_extensions(self) -> None:
        """Create the necessary extensions."""
        await self._session.execute(text(CREATE_VCHORD_EXTENSION))
        await self._commit()

    async def _create_tables(self) -> None:
        """Create the table and index, then verify the stored vector dimension.

        Raises:
            ValueError: If the dimension recorded for the existing column
                differs from the dimension the provider currently produces.

        """
        # Embed a throwaway string purely to learn the provider's vector length.
        probe_vector = (await self.embedding_provider.embed(["dimension"]))[0]
        dimension = len(probe_vector)
        await self._session.execute(
            text(
                f"""CREATE TABLE IF NOT EXISTS {TABLE_NAME} (
                    id SERIAL PRIMARY KEY,
                    snippet_id INT NOT NULL UNIQUE,
                    embedding VECTOR({dimension}) NOT NULL
                );"""
            )
        )
        await self._session.execute(text(CREATE_VCHORD_INDEX))
        result = await self._session.execute(text(CHECK_VCHORD_EMBEDDING_DIMENSION))
        vector_dim_from_db = result.scalar_one()
        if vector_dim_from_db != dimension:
            msg = (
                f"Embedding vector dimension does not match database, "
                f"please delete your index: {vector_dim_from_db} != {dimension}"
            )
            raise ValueError(msg)
        await self._commit()

    async def _execute(
        self, query: TextClause, param_list: list[Any] | dict[str, Any] | None = None
    ) -> Result:
        """Execute a query, initializing the environment first if needed."""
        if not self._initialized:
            await self._initialize()
        return await self._session.execute(query, param_list)

    async def _commit(self) -> None:
        """Commit the session."""
        await self._session.commit()

    async def index(self, data: list[VectorSearchRequest]) -> None:
        """Embed a list of documents and upsert one row per document.

        Args:
            data: The (snippet_id, text) requests to embed. An empty list is a
                no-op.

        Raises:
            ValueError: If the provider does not return exactly one vector per
                request (raised by the strict zip).

        """
        if not data:
            # With no rows the INSERT would be executed without values for its
            # bind parameters, so there is nothing useful to send.
            return

        embeddings = await self.embedding_provider.embed([doc.text for doc in data])
        # Execute inserts
        await self._execute(
            text(INSERT_QUERY),
            [
                {"snippet_id": doc.snippet_id, "embedding": str(embedding)}
                for doc, embedding in zip(data, embeddings, strict=True)
            ],
        )
        await self._commit()

    async def retrieve(self, query: str, top_k: int = 10) -> list[VectorSearchResponse]:
        """Embed *query* and return the closest stored snippets.

        Args:
            query: The search text.
            top_k: Maximum number of rows to return.

        Returns:
            One VectorSearchResponse per row of SEARCH_QUERY, in the order the
            database returns them; `score` is the value of the query's `score`
            column.

        """
        embedding = await self.embedding_provider.embed([query])
        result = await self._execute(
            text(SEARCH_QUERY), {"query": str(embedding[0]), "top_k": top_k}
        )
        rows = result.mappings().all()

        return [
            VectorSearchResponse(snippet_id=row["snippet_id"], score=row["score"])
            for row in rows
        ]
@@ -10,6 +10,7 @@ from typing import TypeVar
10
10
 
11
11
  from sqlalchemy import delete, func, select
12
12
  from sqlalchemy.ext.asyncio import AsyncSession
13
+ from sqlalchemy.orm.exc import MultipleResultsFound
13
14
 
14
15
  from kodit.embedding.embedding_models import Embedding
15
16
  from kodit.indexing.indexing_models import Index, Snippet
@@ -124,15 +125,34 @@ class IndexRepository:
124
125
  index.updated_at = datetime.now(UTC)
125
126
  await self.session.commit()
126
127
 
127
- async def add_snippet(self, snippet: Snippet) -> None:
128
- """Add a new snippet to the database.
128
async def add_snippet_or_update_content(self, snippet: Snippet) -> None:
    """Insert *snippet*, or overwrite the content of the one already stored.

    A stored snippet is looked up by the (file_id, index_id) pair of the
    given snippet. When one exists only its content is replaced; otherwise
    the given snippet is added. The session is committed in both cases.

    Args:
        snippet: The Snippet instance to add.

    Raises:
        ValueError: If more than one stored snippet matches the lookup.

    """
    lookup = select(Snippet).where(
        Snippet.file_id == snippet.file_id,
        Snippet.index_id == snippet.index_id,
    )
    rows = await self.session.execute(lookup)
    try:
        stored = rows.scalar_one_or_none()

        if not stored:
            self.session.add(snippet)
        else:
            stored.content = snippet.content

        await self.session.commit()
    except MultipleResultsFound as e:
        msg = (
            f"Multiple snippets found for file_id {snippet.file_id}, this "
            "shouldn't happen. "
            "Please report this as a bug then delete your index and start again."
        )
        raise ValueError(msg) from e
136
156
 
137
157
  async def delete_all_snippets(self, index_id: int) -> None:
138
158
  """Delete all snippets for an index.