kodit 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kodit might be problematic.

Files changed (29)
  1. kodit/_version.py +2 -2
  2. kodit/app.py +6 -0
  3. kodit/cli.py +8 -2
  4. kodit/embedding/embedding_factory.py +11 -0
  5. kodit/embedding/embedding_provider/embedding_provider.py +42 -14
  6. kodit/embedding/embedding_provider/hash_embedding_provider.py +16 -7
  7. kodit/embedding/embedding_provider/local_embedding_provider.py +43 -11
  8. kodit/embedding/embedding_provider/openai_embedding_provider.py +18 -22
  9. kodit/embedding/local_vector_search_service.py +46 -13
  10. kodit/embedding/vector_search_service.py +18 -1
  11. kodit/embedding/vectorchord_vector_search_service.py +63 -16
  12. kodit/enrichment/enrichment_factory.py +3 -0
  13. kodit/enrichment/enrichment_provider/enrichment_provider.py +21 -1
  14. kodit/enrichment/enrichment_provider/local_enrichment_provider.py +39 -28
  15. kodit/enrichment/enrichment_provider/openai_enrichment_provider.py +25 -27
  16. kodit/enrichment/enrichment_service.py +19 -7
  17. kodit/indexing/indexing_service.py +50 -23
  18. kodit/log.py +126 -24
  19. kodit/migrations/versions/9e53ea8bb3b0_add_authors.py +103 -0
  20. kodit/source/source_factories.py +356 -0
  21. kodit/source/source_models.py +17 -5
  22. kodit/source/source_repository.py +49 -20
  23. kodit/source/source_service.py +41 -218
  24. {kodit-0.2.2.dist-info → kodit-0.2.4.dist-info}/METADATA +2 -2
  25. {kodit-0.2.2.dist-info → kodit-0.2.4.dist-info}/RECORD +28 -27
  26. kodit/migrations/versions/42e836b21102_add_authors.py +0 -64
  27. {kodit-0.2.2.dist-info → kodit-0.2.4.dist-info}/WHEEL +0 -0
  28. {kodit-0.2.2.dist-info → kodit-0.2.4.dist-info}/entry_points.txt +0 -0
  29. {kodit-0.2.2.dist-info → kodit-0.2.4.dist-info}/licenses/LICENSE +0 -0
kodit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.2.2'
- __version_tuple__ = version_tuple = (0, 2, 2)
+ __version__ = version = '0.2.4'
+ __version_tuple__ = version_tuple = (0, 2, 4)
kodit/app.py CHANGED
@@ -21,6 +21,12 @@ async def root() -> dict[str, str]:
      return {"message": "Hello, World!"}


+ @app.get("/healthz")
+ async def healthz() -> dict[str, str]:
+     """Return a health check for the kodit API."""
+     return {"status": "ok"}
+
+
  # Add mcp routes last, otherwise previous routes aren't added
  app.mount("", mcp_app)

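The new /healthz route gives deployments a simple liveness probe alongside the MCP routes. A minimal sketch of polling it from Python, assuming a locally running `kodit serve`; the host and port below are illustrative, not taken from this diff:

    import json
    from urllib.request import urlopen

    # Hypothetical probe against a local kodit server; adjust the URL to your deployment.
    with urlopen("http://localhost:8080/healthz", timeout=5) as resp:
        payload = json.load(resp)

    assert payload == {"status": "ok"}, f"unexpected health payload: {payload}"
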
kodit/cli.py CHANGED
@@ -81,6 +81,7 @@ async def index(
      )

      if not sources:
+         log_event("kodit.cli.index.list")
          # No source specified, list all indexes
          indexes = await service.list_indexes()
          headers: list[str | Cell] = [
@@ -108,7 +109,8 @@ async def index(
          msg = "File indexing is not implemented yet"
          raise click.UsageError(msg)

-     # Index directory
+     # Index source
+     log_event("kodit.cli.index.create")
      s = await source_service.create(source)
      index = await service.create(s.id)
      await service.run(index.id)
@@ -134,6 +136,7 @@ async def code(

      This works best if your query is code.
      """
+     log_event("kodit.cli.search.code")
      source_repository = SourceRepository(session)
      source_service = SourceService(app_context.get_clone_dir(), source_repository)
      repository = IndexRepository(session)
@@ -177,6 +180,7 @@ async def keyword(
      top_k: int,
  ) -> None:
      """Search for snippets using keyword search."""
+     log_event("kodit.cli.search.keyword")
      source_repository = SourceRepository(session)
      source_service = SourceService(app_context.get_clone_dir(), source_repository)
      repository = IndexRepository(session)
@@ -223,6 +227,7 @@ async def text(

      This works best if your query is text.
      """
+     log_event("kodit.cli.search.text")
      source_repository = SourceRepository(session)
      source_service = SourceService(app_context.get_clone_dir(), source_repository)
      repository = IndexRepository(session)
@@ -270,6 +275,7 @@ async def hybrid(  # noqa: PLR0913
      text: str,
  ) -> None:
      """Search for snippets using hybrid search."""
+     log_event("kodit.cli.search.hybrid")
      source_repository = SourceRepository(session)
      source_service = SourceService(app_context.get_clone_dir(), source_repository)
      repository = IndexRepository(session)
@@ -321,7 +327,7 @@ def serve(
      """Start the kodit server, which hosts the MCP server and the kodit API."""
      log = structlog.get_logger(__name__)
      log.info("Starting kodit server", host=host, port=port)
-     log_event("kodit_server_started")
+     log_event("kodit.cli.serve")

      # Configure uvicorn with graceful shutdown
      config = uvicorn.Config(
kodit/embedding/embedding_factory.py CHANGED
@@ -3,6 +3,7 @@
  from sqlalchemy.ext.asyncio import AsyncSession

  from kodit.config import AppContext, Endpoint
+ from kodit.embedding.embedding_models import EmbeddingType
  from kodit.embedding.embedding_provider.local_embedding_provider import (
      CODE,
      LocalEmbeddingProvider,
@@ -19,6 +20,7 @@ from kodit.embedding.vectorchord_vector_search_service import (
      TaskName,
      VectorChordVectorSearchService,
  )
+ from kodit.log import log_event


  def _get_endpoint_configuration(app_context: AppContext) -> Endpoint | None:
@@ -34,6 +36,7 @@ def embedding_factory(
      endpoint = _get_endpoint_configuration(app_context)

      if endpoint and endpoint.type == "openai":
+         log_event("kodit.embedding", {"provider": "openai"})
          from openai import AsyncOpenAI

          embedding_provider = OpenAIEmbeddingProvider(
@@ -44,14 +47,22 @@
              model_name=endpoint.model or "text-embedding-3-small",
          )
      else:
+         log_event("kodit.embedding", {"provider": "local"})
          embedding_provider = LocalEmbeddingProvider(CODE)

      if app_context.default_search.provider == "vectorchord":
+         log_event("kodit.database", {"provider": "vectorchord"})
          return VectorChordVectorSearchService(task_name, session, embedding_provider)
      if app_context.default_search.provider == "sqlite":
+         log_event("kodit.database", {"provider": "sqlite"})
+         if task_name == "code":
+             embedding_type = EmbeddingType.CODE
+         elif task_name == "text":
+             embedding_type = EmbeddingType.TEXT
          return LocalVectorSearchService(
              embedding_repository=embedding_repository,
              embedding_provider=embedding_provider,
+             embedding_type=embedding_type,
          )

      msg = f"Invalid semantic search provider: {app_context.default_search.provider}"
kodit/embedding/embedding_provider/embedding_provider.py CHANGED
@@ -1,6 +1,8 @@
  """Embedding provider."""

  from abc import ABC, abstractmethod
+ from collections.abc import AsyncGenerator
+ from dataclasses import dataclass

  import structlog
  import tiktoken
@@ -10,11 +12,29 @@ OPENAI_MAX_EMBEDDING_SIZE = 8192
  Vector = list[float]


+ @dataclass
+ class EmbeddingRequest:
+     """Embedding request."""
+
+     id: int
+     text: str
+
+
+ @dataclass
+ class EmbeddingResponse:
+     """Embedding response."""
+
+     id: int
+     embedding: Vector
+
+
  class EmbeddingProvider(ABC):
      """Embedding provider."""

      @abstractmethod
-     async def embed(self, data: list[str]) -> list[Vector]:
+     def embed(
+         self, data: list[EmbeddingRequest]
+     ) -> AsyncGenerator[list[EmbeddingResponse], None]:
          """Embed a list of strings.

          The embedding provider is responsible for embedding a list of strings into a
@@ -25,13 +45,13 @@ class EmbeddingProvider(ABC):

  def split_sub_batches(
      encoding: tiktoken.Encoding,
-     data: list[str],
+     data: list[EmbeddingRequest],
      max_context_window: int = OPENAI_MAX_EMBEDDING_SIZE,
- ) -> list[list[str]]:
+ ) -> list[list[EmbeddingRequest]]:
      """Split a list of strings into smaller sub-batches."""
      log = structlog.get_logger(__name__)
      result = []
-     data_to_process = [s for s in data if s.strip()]  # Filter out empty strings
+     data_to_process = [s for s in data if s.text.strip()]  # Filter out empty strings

      while data_to_process:
          next_batch = []
@@ -39,18 +59,26 @@ def split_sub_batches(

          while data_to_process:
              next_item = data_to_process[0]
-             item_tokens = len(encoding.encode(next_item))
+             item_tokens = len(encoding.encode(next_item.text, disallowed_special=()))

              if item_tokens > max_context_window:
-                 # Loop around trying to truncate the snippet until it fits in the max
-                 # embedding size
-                 while item_tokens > max_context_window:
-                     next_item = next_item[:-1]
-                     item_tokens = len(encoding.encode(next_item))
-
-                 data_to_process[0] = next_item
-
-                 log.warning("Truncated snippet", snippet=next_item)
+                 # Optimise truncation by operating on tokens directly instead of
+                 # removing one character at a time and repeatedly re-encoding.
+                 tokens = encoding.encode(next_item.text, disallowed_special=())
+                 if len(tokens) > max_context_window:
+                     # Keep only the first *max_context_window* tokens.
+                     tokens = tokens[:max_context_window]
+                     # Convert back to text. This requires only one decode call and
+                     # guarantees that the resulting string fits the token budget.
+                     next_item.text = encoding.decode(tokens)
+                     item_tokens = max_context_window  # We know the exact size now
+
+                 data_to_process[0] = next_item
+
+                 log.warning(
+                     "Truncated snippet because it was too long to embed",
+                     snippet=next_item.text[:100] + "...",
+                 )

              if current_tokens + item_tokens > max_context_window:
                  break
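The embed contract has changed from returning one flat list of vectors to an async generator that yields batches of EmbeddingResponse objects keyed by the originating request id. A minimal sketch of a conforming provider and its consumer, assuming only the request/response shapes shown above; FakeProvider is illustrative, not part of kodit:

    import asyncio
    from collections.abc import AsyncGenerator

    from kodit.embedding.embedding_provider.embedding_provider import (
        EmbeddingRequest,
        EmbeddingResponse,
    )

    class FakeProvider:
        async def embed(
            self, data: list[EmbeddingRequest]
        ) -> AsyncGenerator[list[EmbeddingResponse], None]:
            # Yield one single-item batch per request; real providers batch by token count.
            for item in data:
                yield [EmbeddingResponse(id=item.id, embedding=[float(len(item.text))])]

    async def main() -> None:
        requests = [EmbeddingRequest(id=1, text="def add(a, b): return a + b")]
        async for batch in FakeProvider().embed(requests):
            for response in batch:
                print(response.id, response.embedding)

    asyncio.run(main())
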
kodit/embedding/embedding_provider/hash_embedding_provider.py CHANGED
@@ -3,10 +3,12 @@
  import asyncio
  import hashlib
  import math
- from collections.abc import Generator, Sequence
+ from collections.abc import AsyncGenerator, Generator, Sequence

  from kodit.embedding.embedding_provider.embedding_provider import (
      EmbeddingProvider,
+     EmbeddingRequest,
+     EmbeddingResponse,
      Vector,
  )

@@ -31,27 +33,34 @@ class HashEmbeddingProvider(EmbeddingProvider):
          self.dim = dim
          self.batch_size = batch_size

-     async def embed(self, data: list[str]) -> list[Vector]:
+     async def embed(
+         self, data: list[EmbeddingRequest]
+     ) -> AsyncGenerator[list[EmbeddingResponse], None]:
          """Embed every string in *data*, preserving order.

          Work is sliced into *batch_size* chunks and scheduled concurrently
          (still CPU-bound, but enough to cooperate with an asyncio loop).
          """
          if not data:
-             return []
+             yield []

          async def _embed_chunk(chunk: Sequence[str]) -> list[Vector]:
              return [self._string_to_vector(text) for text in chunk]

          tasks = [
              asyncio.create_task(_embed_chunk(chunk))
-             for chunk in self._chunked(data, self.batch_size)
+             for chunk in self._chunked([i.text for i in data], self.batch_size)
          ]

-         vectors: list[Vector] = []
          for task in tasks:
-             vectors.extend(await task)
-         return vectors
+             result = await task
+             yield [
+                 EmbeddingResponse(
+                     id=item.id,
+                     embedding=embedding,
+                 )
+                 for item, embedding in zip(data, result, strict=True)
+             ]

      @staticmethod
      def _chunked(seq: Sequence[str], size: int) -> Generator[Sequence[str], None, None]:
kodit/embedding/embedding_provider/local_embedding_provider.py CHANGED
@@ -3,20 +3,24 @@
  from __future__ import annotations

  import os
+ from time import time
  from typing import TYPE_CHECKING

  import structlog
- import tiktoken
- from tqdm import tqdm

  from kodit.embedding.embedding_provider.embedding_provider import (
      EmbeddingProvider,
-     Vector,
+     EmbeddingRequest,
+     EmbeddingResponse,
      split_sub_batches,
  )

  if TYPE_CHECKING:
+     from collections.abc import AsyncGenerator
+
      from sentence_transformers import SentenceTransformer
+     from tiktoken import Encoding
+

  TINY = "tiny"
  CODE = "code"
@@ -36,8 +40,22 @@ class LocalEmbeddingProvider(EmbeddingProvider):
          """Initialize the local embedder."""
          self.log = structlog.get_logger(__name__)
          self.model_name = COMMON_EMBEDDING_MODELS.get(model_name, model_name)
+         self.encoding_name = "text-embedding-3-small"
          self.embedding_model = None
-         self.encoding = tiktoken.encoding_for_model("text-embedding-3-small")
+         self.encoding = None
+
+     def _encoding(self) -> Encoding:
+         if self.encoding is None:
+             from tiktoken import encoding_for_model
+
+             start_time = time()
+             self.encoding = encoding_for_model(self.encoding_name)
+             self.log.debug(
+                 "Encoding loaded",
+                 model_name=self.encoding_name,
+                 duration=time() - start_time,
+             )
+         return self.encoding

      def _model(self) -> SentenceTransformer:
          """Get the embedding model."""
@@ -45,20 +63,34 @@ class LocalEmbeddingProvider(EmbeddingProvider):
          os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Avoid warnings
          from sentence_transformers import SentenceTransformer

+         start_time = time()
          self.embedding_model = SentenceTransformer(
              self.model_name,
              trust_remote_code=True,
          )
+         self.log.debug(
+             "Model loaded",
+             model_name=self.model_name,
+             duration=time() - start_time,
+         )
          return self.embedding_model

-     async def embed(self, data: list[str]) -> list[Vector]:
+     async def embed(
+         self, data: list[EmbeddingRequest]
+     ) -> AsyncGenerator[list[EmbeddingResponse], None]:
          """Embed a list of strings."""
          model = self._model()

-         batched_data = split_sub_batches(self.encoding, data)
+         batched_data = split_sub_batches(self._encoding(), data)

-         results: list[Vector] = []
-         for batch in tqdm(batched_data, total=len(batched_data), leave=False):
-             embeddings = model.encode(batch, show_progress_bar=False, batch_size=4)
-             results.extend([[float(x) for x in embedding] for embedding in embeddings])
-         return results
+         for batch in batched_data:
+             embeddings = model.encode(
+                 [i.text for i in batch], show_progress_bar=False, batch_size=4
+             )
+             yield [
+                 EmbeddingResponse(
+                     id=item.id,
+                     embedding=[float(x) for x in embedding],
+                 )
+                 for item, embedding in zip(batch, embeddings, strict=True)
+             ]
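Both the tiktoken encoding and the SentenceTransformer model are now imported and constructed lazily, with a debug log of the load time. A generic sketch of the same lazy-initialisation pattern outside kodit; every name here is illustrative:

    import time
    from collections.abc import Callable
    from typing import Generic, TypeVar

    T = TypeVar("T")

    class Lazy(Generic[T]):
        """Build an expensive resource on first use and record how long it took."""

        def __init__(self, loader: Callable[[], T]) -> None:
            self._loader = loader
            self._value: T | None = None

        def get(self) -> T:
            if self._value is None:
                start = time.time()
                self._value = self._loader()
                print(f"loaded in {time.time() - start:.2f}s")
            return self._value
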
kodit/embedding/embedding_provider/openai_embedding_provider.py CHANGED
@@ -1,6 +1,7 @@
  """OpenAI embedding service."""

  import asyncio
+ from collections.abc import AsyncGenerator

  import structlog
  import tiktoken
@@ -8,7 +9,8 @@ from openai import AsyncOpenAI

  from kodit.embedding.embedding_provider.embedding_provider import (
      EmbeddingProvider,
-     Vector,
+     EmbeddingRequest,
+     EmbeddingResponse,
      split_sub_batches,
  )

@@ -31,7 +33,9 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
              "text-embedding-3-small"
          )  # Sensible default

-     async def embed(self, data: list[str]) -> list[Vector]:
+     async def embed(
+         self, data: list[EmbeddingRequest]
+     ) -> AsyncGenerator[list[EmbeddingResponse], None]:
          """Embed a list of documents."""
          # First split the list into a list of list where each sublist has fewer than
          # max tokens.
@@ -40,38 +44,30 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
          # Process batches in parallel with a semaphore to limit concurrent requests
          sem = asyncio.Semaphore(OPENAI_NUM_PARALLEL_TASKS)

-         # Create a list of tuples with a temporary id for each batch
-         # We need to do this so that we can return the results in the same order as the
-         # input data
-         input_data = [(i, batch) for i, batch in enumerate(batched_data)]
-
          async def process_batch(
-             data: tuple[int, list[str]],
-         ) -> tuple[int, list[Vector]]:
-             batch_id, batch = data
+             data: list[EmbeddingRequest],
+         ) -> list[EmbeddingResponse]:
              async with sem:
                  try:
                      response = await self.openai_client.embeddings.create(
                          model=self.model_name,
-                         input=batch,
+                         input=[i.text for i in data],
                      )
-                     return batch_id, [
-                         [float(x) for x in embedding.embedding]
-                         for embedding in response.data
+                     return [
+                         EmbeddingResponse(
+                             id=item.id,
+                             embedding=embedding.embedding,
+                         )
+                         for item, embedding in zip(data, response.data, strict=True)
                      ]
                  except Exception as e:
                      self.log.exception("Error embedding batch", error=str(e))
-                     return batch_id, []
+                     return []

          # Create tasks for all batches
-         tasks = [process_batch(batch) for batch in input_data]
+         tasks = [process_batch(batch) for batch in batched_data]

          # Process all batches and yield results as they complete
-         results: list[tuple[int, list[Vector]]] = []
          for task in asyncio.as_completed(tasks):
              result = await task
-             results.append(result)
-
-         # Output in the same order as the input data
-         ordered_results = [result for _, result in sorted(results, key=lambda x: x[0])]
-         return [item for sublist in ordered_results for item in sublist]
+             yield result
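Because each EmbeddingResponse now carries the originating request id, batches can be yielded in completion order instead of being re-sorted at the end. A small self-contained sketch of the same semaphore-plus-as_completed pattern; the sleep stands in for the real embeddings API call:

    import asyncio

    async def run_batches(batches: list[list[str]], limit: int = 10) -> None:
        # Bound concurrency with a semaphore and consume results as they finish.
        sem = asyncio.Semaphore(limit)

        async def process(batch: list[str]) -> list[str]:
            async with sem:
                await asyncio.sleep(0.01)  # stand-in for the embeddings request
                return [text.upper() for text in batch]

        for task in asyncio.as_completed([process(b) for b in batches]):
            print(await task)

    asyncio.run(run_batches([["a", "b"], ["c"]]))
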
kodit/embedding/local_vector_search_service.py CHANGED
@@ -1,12 +1,18 @@
  """Local vector search."""

+ from collections.abc import AsyncGenerator
+
  import structlog
  import tiktoken

  from kodit.embedding.embedding_models import Embedding, EmbeddingType
- from kodit.embedding.embedding_provider.embedding_provider import EmbeddingProvider
+ from kodit.embedding.embedding_provider.embedding_provider import (
+     EmbeddingProvider,
+     EmbeddingRequest,
+ )
  from kodit.embedding.embedding_repository import EmbeddingRepository
  from kodit.embedding.vector_search_service import (
+     IndexResult,
      VectorSearchRequest,
      VectorSearchResponse,
      VectorSearchService,
@@ -20,35 +26,62 @@ class LocalVectorSearchService(VectorSearchService):
          self,
          embedding_repository: EmbeddingRepository,
          embedding_provider: EmbeddingProvider,
+         embedding_type: EmbeddingType = EmbeddingType.CODE,
      ) -> None:
          """Initialize the local embedder."""
          self.log = structlog.get_logger(__name__)
          self.embedding_repository = embedding_repository
          self.embedding_provider = embedding_provider
          self.encoding = tiktoken.encoding_for_model("text-embedding-3-small")
+         self.embedding_type = embedding_type

-     async def index(self, data: list[VectorSearchRequest]) -> None:
+     async def index(
+         self, data: list[VectorSearchRequest]
+     ) -> AsyncGenerator[list[IndexResult], None]:
          """Embed a list of documents."""
          if not data or len(data) == 0:
-             self.log.warning("Embedding data is empty, skipping embedding")
              return

-         embeddings = await self.embedding_provider.embed([i.text for i in data])
-         for i, x in zip(data, embeddings, strict=False):
-             await self.embedding_repository.create_embedding(
-                 Embedding(
-                     snippet_id=i.snippet_id,
-                     embedding=[float(y) for y in x],
-                     type=EmbeddingType.CODE,
+         requests = [EmbeddingRequest(id=doc.snippet_id, text=doc.text) for doc in data]
+
+         async for batch in self.embedding_provider.embed(requests):
+             for result in batch:
+                 await self.embedding_repository.create_embedding(
+                     Embedding(
+                         snippet_id=result.id,
+                         embedding=result.embedding,
+                         type=self.embedding_type,
+                     )
                  )
-             )
+                 yield [IndexResult(snippet_id=result.id)]

      async def retrieve(self, query: str, top_k: int = 10) -> list[VectorSearchResponse]:
          """Query the embedding model."""
-         embedding = (await self.embedding_provider.embed([query]))[0]
+         # Build a single-item request and collect its embedding.
+         req = EmbeddingRequest(id=0, text=query)
+         embedding_vec: list[float] | None = None
+         async for batch in self.embedding_provider.embed([req]):
+             if batch:
+                 embedding_vec = [float(v) for v in batch[0].embedding]
+                 break
+
+         if not embedding_vec:
+             return []
+
          results = await self.embedding_repository.list_semantic_results(
-             EmbeddingType.CODE, [float(x) for x in embedding], top_k
+             self.embedding_type, embedding_vec, top_k
          )
          return [
              VectorSearchResponse(snippet_id, score) for snippet_id, score in results
          ]
+
+     async def has_embedding(
+         self, snippet_id: int, embedding_type: EmbeddingType
+     ) -> bool:
+         """Check if a snippet has an embedding."""
+         return (
+             await self.embedding_repository.get_embedding_by_snippet_id_and_type(
+                 snippet_id, embedding_type
+             )
+             is not None
+         )
kodit/embedding/vector_search_service.py CHANGED
@@ -1,8 +1,11 @@
  """Embedding service."""

  from abc import ABC, abstractmethod
+ from collections.abc import AsyncGenerator
  from typing import NamedTuple

+ from kodit.embedding.embedding_models import EmbeddingType
+

  class VectorSearchResponse(NamedTuple):
      """Embedding result."""
@@ -18,11 +21,19 @@ class VectorSearchRequest(NamedTuple):
      text: str


+ class IndexResult(NamedTuple):
+     """Result of indexing."""
+
+     snippet_id: int
+
+
  class VectorSearchService(ABC):
      """Semantic search service interface."""

      @abstractmethod
-     async def index(self, data: list[VectorSearchRequest]) -> None:
+     def index(
+         self, data: list[VectorSearchRequest]
+     ) -> AsyncGenerator[list[IndexResult], None]:
          """Embed a list of documents.

          The embedding service accepts a massive list of id,strings to embed. Behind the
@@ -36,3 +47,9 @@ class VectorSearchService(ABC):
      @abstractmethod
      async def retrieve(self, query: str, top_k: int = 10) -> list[VectorSearchResponse]:
          """Query the embedding model."""
+
+     @abstractmethod
+     async def has_embedding(
+         self, snippet_id: int, embedding_type: EmbeddingType
+     ) -> bool:
+         """Check if a snippet has an embedding."""
+ """Check if a snippet has an embedding."""