kodit 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic. Click here for more details.

Files changed (118)
  1. kodit/_version.py +2 -2
  2. kodit/application/__init__.py +1 -0
  3. kodit/application/commands/__init__.py +1 -0
  4. kodit/application/commands/snippet_commands.py +22 -0
  5. kodit/application/services/__init__.py +1 -0
  6. kodit/application/services/indexing_application_service.py +363 -0
  7. kodit/application/services/snippet_application_service.py +143 -0
  8. kodit/cli.py +105 -82
  9. kodit/database.py +0 -22
  10. kodit/domain/__init__.py +1 -0
  11. kodit/{source/source_models.py → domain/entities.py} +88 -19
  12. kodit/domain/enums.py +9 -0
  13. kodit/domain/interfaces.py +27 -0
  14. kodit/domain/repositories.py +95 -0
  15. kodit/domain/services/__init__.py +1 -0
  16. kodit/domain/services/bm25_service.py +124 -0
  17. kodit/domain/services/embedding_service.py +155 -0
  18. kodit/domain/services/enrichment_service.py +48 -0
  19. kodit/domain/services/ignore_service.py +45 -0
  20. kodit/domain/services/indexing_service.py +203 -0
  21. kodit/domain/services/snippet_extraction_service.py +89 -0
  22. kodit/domain/services/source_service.py +83 -0
  23. kodit/domain/value_objects.py +215 -0
  24. kodit/infrastructure/__init__.py +1 -0
  25. kodit/infrastructure/bm25/__init__.py +1 -0
  26. kodit/infrastructure/bm25/bm25_factory.py +28 -0
  27. kodit/{bm25/local_bm25.py → infrastructure/bm25/local_bm25_repository.py} +33 -22
  28. kodit/{bm25/vectorchord_bm25.py → infrastructure/bm25/vectorchord_bm25_repository.py} +40 -35
  29. kodit/infrastructure/cloning/__init__.py +1 -0
  30. kodit/infrastructure/cloning/folder/__init__.py +1 -0
  31. kodit/infrastructure/cloning/folder/factory.py +119 -0
  32. kodit/infrastructure/cloning/folder/working_copy.py +38 -0
  33. kodit/infrastructure/cloning/git/__init__.py +1 -0
  34. kodit/infrastructure/cloning/git/factory.py +133 -0
  35. kodit/infrastructure/cloning/git/working_copy.py +32 -0
  36. kodit/infrastructure/cloning/metadata.py +127 -0
  37. kodit/infrastructure/embedding/__init__.py +1 -0
  38. kodit/infrastructure/embedding/embedding_factory.py +87 -0
  39. kodit/infrastructure/embedding/embedding_providers/__init__.py +1 -0
  40. kodit/infrastructure/embedding/embedding_providers/batching.py +93 -0
  41. kodit/infrastructure/embedding/embedding_providers/hash_embedding_provider.py +79 -0
  42. kodit/infrastructure/embedding/embedding_providers/local_embedding_provider.py +129 -0
  43. kodit/infrastructure/embedding/embedding_providers/openai_embedding_provider.py +113 -0
  44. kodit/infrastructure/embedding/local_vector_search_repository.py +114 -0
  45. kodit/{embedding/vectorchord_vector_search_service.py → infrastructure/embedding/vectorchord_vector_search_repository.py} +98 -32
  46. kodit/infrastructure/enrichment/__init__.py +1 -0
  47. kodit/{enrichment → infrastructure/enrichment}/enrichment_factory.py +28 -12
  48. kodit/infrastructure/enrichment/legacy_enrichment_models.py +42 -0
  49. kodit/infrastructure/enrichment/local_enrichment_provider.py +115 -0
  50. kodit/infrastructure/enrichment/null_enrichment_provider.py +25 -0
  51. kodit/infrastructure/enrichment/openai_enrichment_provider.py +89 -0
  52. kodit/infrastructure/git/__init__.py +1 -0
  53. kodit/{source/git.py → infrastructure/git/git_utils.py} +10 -2
  54. kodit/infrastructure/ignore/__init__.py +1 -0
  55. kodit/{source/ignore.py → infrastructure/ignore/ignore_pattern_provider.py} +23 -6
  56. kodit/infrastructure/indexing/__init__.py +1 -0
  57. kodit/infrastructure/indexing/fusion_service.py +55 -0
  58. kodit/infrastructure/indexing/index_repository.py +296 -0
  59. kodit/infrastructure/indexing/indexing_factory.py +111 -0
  60. kodit/infrastructure/snippet_extraction/__init__.py +1 -0
  61. kodit/infrastructure/snippet_extraction/language_detection_service.py +39 -0
  62. kodit/infrastructure/snippet_extraction/snippet_extraction_factory.py +95 -0
  63. kodit/infrastructure/snippet_extraction/snippet_query_provider.py +45 -0
  64. kodit/{snippets/method_snippets.py → infrastructure/snippet_extraction/tree_sitter_snippet_extractor.py} +123 -61
  65. kodit/infrastructure/sqlalchemy/__init__.py +1 -0
  66. kodit/{embedding → infrastructure/sqlalchemy}/embedding_repository.py +40 -24
  67. kodit/infrastructure/sqlalchemy/file_repository.py +73 -0
  68. kodit/infrastructure/sqlalchemy/repository.py +121 -0
  69. kodit/infrastructure/sqlalchemy/snippet_repository.py +75 -0
  70. kodit/infrastructure/ui/__init__.py +1 -0
  71. kodit/infrastructure/ui/progress.py +127 -0
  72. kodit/{util → infrastructure/ui}/spinner.py +19 -4
  73. kodit/mcp.py +50 -28
  74. kodit/migrations/env.py +1 -4
  75. kodit/reporting.py +78 -0
  76. {kodit-0.2.3.dist-info → kodit-0.2.5.dist-info}/METADATA +1 -1
  77. kodit-0.2.5.dist-info/RECORD +99 -0
  78. kodit/bm25/__init__.py +0 -1
  79. kodit/bm25/keyword_search_factory.py +0 -17
  80. kodit/bm25/keyword_search_service.py +0 -34
  81. kodit/embedding/__init__.py +0 -1
  82. kodit/embedding/embedding_factory.py +0 -63
  83. kodit/embedding/embedding_models.py +0 -28
  84. kodit/embedding/embedding_provider/__init__.py +0 -1
  85. kodit/embedding/embedding_provider/embedding_provider.py +0 -64
  86. kodit/embedding/embedding_provider/hash_embedding_provider.py +0 -77
  87. kodit/embedding/embedding_provider/local_embedding_provider.py +0 -64
  88. kodit/embedding/embedding_provider/openai_embedding_provider.py +0 -77
  89. kodit/embedding/local_vector_search_service.py +0 -54
  90. kodit/embedding/vector_search_service.py +0 -38
  91. kodit/enrichment/__init__.py +0 -1
  92. kodit/enrichment/enrichment_provider/__init__.py +0 -1
  93. kodit/enrichment/enrichment_provider/enrichment_provider.py +0 -16
  94. kodit/enrichment/enrichment_provider/local_enrichment_provider.py +0 -92
  95. kodit/enrichment/enrichment_provider/openai_enrichment_provider.py +0 -81
  96. kodit/enrichment/enrichment_service.py +0 -33
  97. kodit/indexing/__init__.py +0 -1
  98. kodit/indexing/fusion.py +0 -67
  99. kodit/indexing/indexing_models.py +0 -43
  100. kodit/indexing/indexing_repository.py +0 -216
  101. kodit/indexing/indexing_service.py +0 -338
  102. kodit/snippets/__init__.py +0 -1
  103. kodit/snippets/languages/__init__.py +0 -53
  104. kodit/snippets/snippets.py +0 -50
  105. kodit/source/__init__.py +0 -1
  106. kodit/source/source_factories.py +0 -356
  107. kodit/source/source_repository.py +0 -169
  108. kodit/source/source_service.py +0 -150
  109. kodit/util/__init__.py +0 -1
  110. kodit-0.2.3.dist-info/RECORD +0 -71
  111. /kodit/{snippets → infrastructure/snippet_extraction}/languages/csharp.scm +0 -0
  112. /kodit/{snippets → infrastructure/snippet_extraction}/languages/go.scm +0 -0
  113. /kodit/{snippets → infrastructure/snippet_extraction}/languages/javascript.scm +0 -0
  114. /kodit/{snippets → infrastructure/snippet_extraction}/languages/python.scm +0 -0
  115. /kodit/{snippets → infrastructure/snippet_extraction}/languages/typescript.scm +0 -0
  116. {kodit-0.2.3.dist-info → kodit-0.2.5.dist-info}/WHEEL +0 -0
  117. {kodit-0.2.3.dist-info → kodit-0.2.5.dist-info}/entry_points.txt +0 -0
  118. {kodit-0.2.3.dist-info → kodit-0.2.5.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,79 @@
1
+ """Hash-based embedding provider for testing purposes."""
2
+
3
+ import hashlib
4
+ from collections.abc import AsyncGenerator
5
+
6
+ import structlog
7
+
8
+ from kodit.domain.services.embedding_service import EmbeddingProvider
9
+ from kodit.domain.value_objects import EmbeddingRequest, EmbeddingResponse
10
+
11
# Preset vector widths accepted by the hash embedding provider.
TINY: int = 64  # small vectors
CODE: int = 1536  # default width used by HashEmbeddingProvider
14
+
15
+
16
class HashEmbeddingProvider(EmbeddingProvider):
    """Deterministic, hash-derived embedding provider intended for tests."""

    def __init__(self, embedding_size: int = CODE) -> None:
        """Create the provider.

        Args:
            embedding_size: Number of dimensions in every generated vector.

        """
        self.embedding_size = embedding_size
        self.log = structlog.get_logger(__name__)

    def embed(
        self, data: list[EmbeddingRequest]
    ) -> AsyncGenerator[list[EmbeddingResponse], None]:
        """Return an async generator of hash-based embeddings, ten per batch.

        An empty (or falsy) ``data`` produces a generator that yields nothing.
        """
        # Emptiness is decided when embed() is called, not when the generator
        # is first iterated.
        nothing_to_embed = not data
        chunk_size = 10

        async def _stream() -> AsyncGenerator[list[EmbeddingResponse], None]:
            if nothing_to_embed:
                return
            for start in range(0, len(data), chunk_size):
                # One response per request, in request order.
                yield [
                    EmbeddingResponse(
                        snippet_id=item.snippet_id,
                        embedding=self._generate_embedding(item.text),
                    )
                    for item in data[start : start + chunk_size]
                ]

        return _stream()

    def _generate_embedding(self, text: str) -> list[float]:
        """Derive a repeatable vector from the SHA-256 digest of ``text``."""
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        digest_len = len(digest)
        # Digest bytes are reused cyclically; each byte (0..255) is mapped to
        # (byte - 128) / 128.0.
        return [
            (digest[position % digest_len] - 128) / 128.0
            for position in range(self.embedding_size)
        ]
@@ -0,0 +1,129 @@
1
+ """Local embedding provider implementation."""
2
+
3
+ import os
4
+ from collections.abc import AsyncGenerator
5
+ from time import time
6
+ from typing import TYPE_CHECKING
7
+
8
+ import structlog
9
+
10
+ from kodit.domain.services.embedding_service import EmbeddingProvider
11
+ from kodit.domain.value_objects import EmbeddingRequest, EmbeddingResponse
12
+
13
+ from .batching import split_sub_batches
14
+
15
+ if TYPE_CHECKING:
16
+ from sentence_transformers import SentenceTransformer
17
+ from tiktoken import Encoding
18
+
19
# Short preset names accepted in place of a full model identifier.
TINY: str = "tiny"
CODE: str = "code"
TEST: str = "test"

# Preset name -> model identifier handed to sentence-transformers.
COMMON_EMBEDDING_MODELS: dict[str, str] = {
    TINY: "ibm-granite/granite-embedding-30m-english",
    CODE: "flax-sentence-embeddings/st-codesearch-distilroberta-base",
    TEST: "minishlab/potion-base-4M",
}
29
+
30
+
31
class LocalEmbeddingProvider(EmbeddingProvider):
    """Local embedding provider that uses sentence-transformers."""

    def __init__(self, model_name: str = CODE) -> None:
        """Initialize the local embedding provider.

        Args:
            model_name: The model name to use for embeddings. Can be a preset
                ('tiny', 'code', 'test') or a full model name.

        """
        self.log = structlog.get_logger(__name__)
        # Presets are resolved through COMMON_EMBEDDING_MODELS; anything else
        # is passed through unchanged as a full model name.
        self.model_name = COMMON_EMBEDDING_MODELS.get(model_name, model_name)
        self.encoding_name = "text-embedding-3-small"
        # Both are loaded lazily on first use (see _model / _encoding).
        self.embedding_model: SentenceTransformer | None = None
        self.encoding: Encoding | None = None

    def _encoding(self) -> "Encoding":
        """Return the tiktoken encoding, loading and caching it on first use."""
        if self.encoding is None:
            from tiktoken import encoding_for_model

            start_time = time()
            self.encoding = encoding_for_model(self.encoding_name)
            self.log.debug(
                "Encoding loaded",
                model_name=self.encoding_name,
                duration=time() - start_time,
            )
        return self.encoding

    def _model(self) -> "SentenceTransformer":
        """Return the embedding model, loading and caching it on first use."""
        if self.embedding_model is None:
            os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Avoid warnings
            from sentence_transformers import SentenceTransformer

            start_time = time()
            self.embedding_model = SentenceTransformer(
                self.model_name,
                trust_remote_code=True,
            )
            self.log.debug(
                "Model loaded",
                model_name=self.model_name,
                duration=time() - start_time,
            )
        return self.embedding_model

    async def embed(
        self, data: list[EmbeddingRequest]
    ) -> AsyncGenerator[list[EmbeddingResponse], None]:
        """Embed a list of requests using the local model.

        Args:
            data: The embedding requests to process.

        Yields:
            One list of responses per sub-batch. For empty input a single
            empty list is yielded and nothing else.

        """
        if not data:
            yield []
            # Stop here: without this return the generator fell through and
            # loaded the model and encoding even though there was no work.
            return

        model = self._model()
        encoding = self._encoding()

        # Split into sub-batches based on token limits
        batched_data = self._split_sub_batches(encoding, data)

        for batch in batched_data:
            try:
                # Encode the texts using the model
                embeddings = model.encode(
                    [item.text for item in batch],
                    show_progress_bar=False,
                    batch_size=4,
                )

                # Convert to our response format
                responses = [
                    EmbeddingResponse(
                        snippet_id=item.snippet_id,
                        embedding=[float(x) for x in embedding],
                    )
                    for item, embedding in zip(batch, embeddings, strict=True)
                ]

                yield responses

            except Exception as e:
                self.log.exception("Error generating embeddings", error=str(e))
                # Return zero embeddings on error so the caller still gets one
                # response per request.
                # NOTE(review): 1536 is a fixed default and may not match the
                # output width of the configured model - confirm downstream
                # storage tolerates this.
                responses = [
                    EmbeddingResponse(
                        snippet_id=item.snippet_id,
                        embedding=[0.0] * 1536,  # Default embedding size
                    )
                    for item in batch
                ]
                yield responses

    def _split_sub_batches(
        self, encoding: "Encoding", data: list[EmbeddingRequest]
    ) -> list[list[EmbeddingRequest]]:
        """Proxy to the shared batching utility (kept for backward-compat)."""
        return split_sub_batches(encoding, data)
@@ -0,0 +1,113 @@
1
+ """OpenAI embedding provider implementation."""
2
+
3
+ import asyncio
4
+ from collections.abc import AsyncGenerator
5
+ from typing import Any
6
+
7
+ import structlog
8
+ import tiktoken
9
+ from tiktoken import Encoding
10
+
11
+ from kodit.domain.services.embedding_service import EmbeddingProvider
12
+ from kodit.domain.value_objects import EmbeddingRequest, EmbeddingResponse
13
+
14
+ from .batching import split_sub_batches
15
+
16
# Tuning knobs for the OpenAI embedding provider.
# Conservative token limit for the embedding model.
MAX_TOKENS = 8192
# Maximum number of items per API call (keeps existing test expectations).
BATCH_SIZE = 10
# Semaphore limit for concurrent OpenAI requests.
OPENAI_NUM_PARALLEL_TASKS = 25
22
+
23
+
24
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """OpenAI embedding provider that uses OpenAI's embedding API."""

    def __init__(
        self, openai_client: Any, model_name: str = "text-embedding-3-small"
    ) -> None:
        """Initialize the OpenAI embedding provider.

        Args:
            openai_client: The OpenAI client instance
            model_name: The model name to use for embeddings

        """
        self.openai_client = openai_client
        self.model_name = model_name
        self.log = structlog.get_logger(__name__)

        # Lazily initialised token encoding
        self._encoding: Encoding | None = None

    # ---------------------------------------------------------------------
    # Helper utilities
    # ---------------------------------------------------------------------

    def _get_encoding(self) -> "Encoding":
        """Return (and cache) the tiktoken encoding for the chosen model."""
        if self._encoding is None:
            self._encoding = tiktoken.encoding_for_model(self.model_name)
        return self._encoding

    def _split_sub_batches(
        self, encoding: "Encoding", data: list[EmbeddingRequest]
    ) -> list[list[EmbeddingRequest]]:
        """Proxy to the shared batching utility (kept for backward-compat)."""
        return split_sub_batches(
            encoding,
            data,
            max_tokens=MAX_TOKENS,
            batch_size=BATCH_SIZE,
        )

    async def embed(
        self, data: list[EmbeddingRequest]
    ) -> AsyncGenerator[list[EmbeddingResponse], None]:
        """Embed a list of requests using OpenAI's API.

        Args:
            data: The embedding requests to process.

        Yields:
            One list of responses per sub-batch, in completion order (not
            necessarily input order). For empty input a single empty list is
            yielded and nothing else.

        """
        if not data:
            yield []
            # Stop here: without this return the generator fell through to the
            # encoding lookup and batching although there was nothing to do.
            return

        encoding = self._get_encoding()

        # First, split by token limits (and max batch size)
        batched_data = self._split_sub_batches(encoding, data)

        # -----------------------------------------------------------------
        # Process batches concurrently (but bounded by a semaphore)
        # -----------------------------------------------------------------

        sem = asyncio.Semaphore(OPENAI_NUM_PARALLEL_TASKS)

        async def _process_batch(
            batch: list[EmbeddingRequest],
        ) -> list[EmbeddingResponse]:
            """Embed one batch, substituting zero vectors if the call fails."""
            async with sem:
                try:
                    response = await self.openai_client.embeddings.create(
                        model=self.model_name,
                        input=[item.text for item in batch],
                    )

                    return [
                        EmbeddingResponse(
                            snippet_id=item.snippet_id,
                            embedding=embedding.embedding,
                        )
                        for item, embedding in zip(batch, response.data, strict=True)
                    ]
                except Exception as e:
                    self.log.exception("Error embedding batch", error=str(e))
                    # Fall back to zero embeddings so pipeline can continue
                    return [
                        EmbeddingResponse(
                            snippet_id=item.snippet_id,
                            embedding=[0.0] * 1536,  # Default OpenAI dim
                        )
                        for item in batch
                    ]

        tasks = [_process_batch(batch) for batch in batched_data]
        # as_completed hands results back as each batch finishes.
        for task in asyncio.as_completed(tasks):
            yield await task
@@ -0,0 +1,114 @@
1
+ """Local vector search repository implementation."""
2
+
3
+ from collections.abc import AsyncGenerator
4
+
5
+ import structlog
6
+ import tiktoken
7
+
8
+ from kodit.domain.entities import Embedding, EmbeddingType
9
+ from kodit.domain.services.embedding_service import (
10
+ EmbeddingProvider,
11
+ VectorSearchRepository,
12
+ )
13
+ from kodit.domain.value_objects import (
14
+ EmbeddingRequest,
15
+ IndexResult,
16
+ VectorIndexRequest,
17
+ VectorSearchQueryRequest,
18
+ VectorSearchResult,
19
+ )
20
+ from kodit.infrastructure.sqlalchemy.embedding_repository import (
21
+ SqlAlchemyEmbeddingRepository,
22
+ )
23
+
24
+
25
class LocalVectorSearchRepository(VectorSearchRepository):
    """Vector search that stores and queries embeddings via SQLAlchemy."""

    def __init__(
        self,
        embedding_repository: SqlAlchemyEmbeddingRepository,
        embedding_provider: EmbeddingProvider,
        embedding_type: EmbeddingType = EmbeddingType.CODE,
    ) -> None:
        """Wire up the repository.

        Args:
            embedding_repository: Repository used to store and query embeddings.
            embedding_provider: Provider that turns text into embedding vectors.
            embedding_type: Embedding type written on index and used on search.

        """
        self.log = structlog.get_logger(__name__)
        self.embedding_repository = embedding_repository
        self.embedding_provider = embedding_provider
        self.encoding = tiktoken.encoding_for_model("text-embedding-3-small")
        self.embedding_type = embedding_type

    def index_documents(
        self, request: VectorIndexRequest
    ) -> AsyncGenerator[list[IndexResult], None]:
        """Embed and persist the request's documents, yielding per-batch results.

        When the request carries no documents the returned generator yields
        nothing.
        """
        if not request.documents:

            async def _no_results() -> AsyncGenerator[list[IndexResult], None]:
                return
                # Unreachable, but its presence makes this an async generator.
                yield []

            return _no_results()

        # One embedding request per document, built before iteration starts.
        embedding_requests = [
            EmbeddingRequest(snippet_id=document.snippet_id, text=document.text)
            for document in request.documents
        ]

        async def _stream() -> AsyncGenerator[list[IndexResult], None]:
            async for embedded in self.embedding_provider.embed(embedding_requests):
                indexed = []
                for item in embedded:
                    record = Embedding(
                        snippet_id=item.snippet_id,
                        embedding=item.embedding,
                        type=self.embedding_type,
                    )
                    await self.embedding_repository.create_embedding(record)
                    indexed.append(IndexResult(snippet_id=item.snippet_id))
                yield indexed

        return _stream()

    async def search(
        self, request: VectorSearchQueryRequest
    ) -> list[VectorSearchResult]:
        """Embed the query text and return the closest stored snippets."""
        # The query is embedded as a single-item request with snippet_id 0.
        query_request = EmbeddingRequest(snippet_id=0, text=request.query)
        query_vector: list[float] | None = None
        async for embedded in self.embedding_provider.embed([query_request]):
            if not embedded:
                continue
            # Only the first non-empty batch is needed.
            query_vector = [float(component) for component in embedded[0].embedding]
            break

        if not query_vector:
            return []

        matches = await self.embedding_repository.list_semantic_results(
            self.embedding_type, query_vector, request.top_k
        )
        return [
            VectorSearchResult(snippet_id=snippet_id, score=score)
            for snippet_id, score in matches
        ]

    async def has_embedding(
        self, snippet_id: int, embedding_type: EmbeddingType
    ) -> bool:
        """Report whether an embedding of the given type exists for a snippet."""
        existing = await self.embedding_repository.get_embedding_by_snippet_id_and_type(
            snippet_id, embedding_type
        )
        return existing is not None
@@ -1,16 +1,23 @@
1
- """Vectorchord vector search."""
1
+ """VectorChord vector search repository implementation."""
2
2
 
3
+ from collections.abc import AsyncGenerator
3
4
  from typing import Any, Literal
4
5
 
5
6
  import structlog
6
7
  from sqlalchemy import Result, TextClause, text
7
8
  from sqlalchemy.ext.asyncio import AsyncSession
8
9
 
9
- from kodit.embedding.embedding_provider.embedding_provider import EmbeddingProvider
10
- from kodit.embedding.vector_search_service import (
11
- VectorSearchRequest,
12
- VectorSearchResponse,
13
- VectorSearchService,
10
+ from kodit.domain.entities import EmbeddingType
11
+ from kodit.domain.services.embedding_service import (
12
+ EmbeddingProvider,
13
+ VectorSearchRepository,
14
+ )
15
+ from kodit.domain.value_objects import (
16
+ EmbeddingRequest,
17
+ IndexResult,
18
+ VectorIndexRequest,
19
+ VectorSearchQueryRequest,
20
+ VectorSearchResult,
14
21
  )
15
22
 
16
23
  # SQL Queries
@@ -52,11 +59,15 @@ ORDER BY score ASC
52
59
  LIMIT :top_k;
53
60
  """
54
61
 
62
+ CHECK_VCHORD_EMBEDDING_EXISTS = """
63
+ SELECT EXISTS(SELECT 1 FROM {TABLE_NAME} WHERE snippet_id = :snippet_id)
64
+ """
65
+
55
66
  TaskName = Literal["code", "text"]
56
67
 
57
68
 
58
- class VectorChordVectorSearchService(VectorSearchService):
59
- """VectorChord vector search."""
69
+ class VectorChordVectorSearchRepository(VectorSearchRepository):
70
+ """VectorChord vector search repository implementation."""
60
71
 
61
72
  def __init__(
62
73
  self,
@@ -64,7 +75,14 @@ class VectorChordVectorSearchService(VectorSearchService):
64
75
  session: AsyncSession,
65
76
  embedding_provider: EmbeddingProvider,
66
77
  ) -> None:
67
- """Initialize the VectorChord BM25."""
78
+ """Initialize the VectorChord vector search repository.
79
+
80
+ Args:
81
+ task_name: The task name (code or text)
82
+ session: The SQLAlchemy async session
83
+ embedding_provider: The embedding provider for generating embeddings
84
+
85
+ """
68
86
  self.embedding_provider = embedding_provider
69
87
  self._session = session
70
88
  self._initialized = False
@@ -89,7 +107,15 @@ class VectorChordVectorSearchService(VectorSearchService):
89
107
 
90
108
  async def _create_tables(self) -> None:
91
109
  """Create the necessary tables."""
92
- vector_dim = (await self.embedding_provider.embed(["dimension"]))[0]
110
+ req = EmbeddingRequest(snippet_id=0, text="dimension")
111
+ vector_dim: list[float] | None = None
112
+ async for batch in self.embedding_provider.embed([req]):
113
+ if batch:
114
+ vector_dim = batch[0].embedding
115
+ break
116
+ if vector_dim is None:
117
+ msg = "Failed to obtain embedding dimension from provider"
118
+ raise RuntimeError(msg)
93
119
  await self._session.execute(
94
120
  text(
95
121
  f"""CREATE TABLE IF NOT EXISTS {self.table_name} (
@@ -130,35 +156,75 @@ class VectorChordVectorSearchService(VectorSearchService):
130
156
  """Commit the session."""
131
157
  await self._session.commit()
132
158
 
133
- async def index(self, data: list[VectorSearchRequest]) -> None:
134
- """Embed a list of documents."""
135
- if not data or len(data) == 0:
136
- self.log.warning("Embedding data is empty, skipping embedding")
137
- return
138
-
139
- embeddings = await self.embedding_provider.embed([doc.text for doc in data])
140
- # Execute inserts
141
- await self._execute(
142
- text(INSERT_QUERY.format(TABLE_NAME=self.table_name)),
143
- [
144
- {"snippet_id": doc.snippet_id, "embedding": str(embedding)}
145
- for doc, embedding in zip(data, embeddings, strict=True)
146
- ],
147
- )
148
- await self._commit()
159
+ def index_documents(
160
+ self, request: VectorIndexRequest
161
+ ) -> AsyncGenerator[list[IndexResult], None]:
162
+ """Index documents for vector search."""
163
+ if not request.documents:
164
+
165
+ async def empty_generator() -> AsyncGenerator[list[IndexResult], None]:
166
+ if False:
167
+ yield []
149
168
 
150
- async def retrieve(self, query: str, top_k: int = 10) -> list[VectorSearchResponse]:
151
- """Query the embedding model."""
152
- embedding = await self.embedding_provider.embed([query])
153
- if len(embedding) == 0 or len(embedding[0]) == 0:
169
+ return empty_generator()
170
+
171
+ # Convert to embedding requests
172
+ requests = [
173
+ EmbeddingRequest(snippet_id=doc.snippet_id, text=doc.text)
174
+ for doc in request.documents
175
+ ]
176
+
177
+ async def _index_batches() -> AsyncGenerator[list[IndexResult], None]:
178
+ async for batch in self.embedding_provider.embed(requests):
179
+ await self._execute(
180
+ text(INSERT_QUERY.format(TABLE_NAME=self.table_name)),
181
+ [
182
+ {
183
+ "snippet_id": result.snippet_id,
184
+ "embedding": str(result.embedding),
185
+ }
186
+ for result in batch
187
+ ],
188
+ )
189
+ await self._commit()
190
+ yield [IndexResult(snippet_id=result.snippet_id) for result in batch]
191
+
192
+ return _index_batches()
193
+
194
+ async def search(
195
+ self, request: VectorSearchQueryRequest
196
+ ) -> list[VectorSearchResult]:
197
+ """Search documents using vector similarity."""
198
+ req = EmbeddingRequest(snippet_id=0, text=request.query)
199
+ embedding_vec: list[float] | None = None
200
+ async for batch in self.embedding_provider.embed([req]):
201
+ if batch:
202
+ embedding_vec = batch[0].embedding
203
+ break
204
+
205
+ if not embedding_vec:
154
206
  return []
155
207
  result = await self._execute(
156
208
  text(SEARCH_QUERY.format(TABLE_NAME=self.table_name)),
157
- {"query": str(embedding[0]), "top_k": top_k},
209
+ {"query": str(embedding_vec), "top_k": request.top_k},
158
210
  )
159
211
  rows = result.mappings().all()
160
212
 
161
213
  return [
162
- VectorSearchResponse(snippet_id=row["snippet_id"], score=row["score"])
214
+ VectorSearchResult(snippet_id=row["snippet_id"], score=row["score"])
163
215
  for row in rows
164
216
  ]
217
+
218
async def has_embedding(
    self, snippet_id: int, embedding_type: EmbeddingType
) -> bool:
    """Report whether this repository's table has a row for the snippet."""
    # embedding_type is intentionally unused: VectorChord keeps a separate
    # table per task, so the table name already selects the type.
    # ruff: noqa: ARG002
    statement = text(
        CHECK_VCHORD_EMBEDDING_EXISTS.format(TABLE_NAME=self.table_name)
    )
    outcome = await self._execute(statement, {"snippet_id": snippet_id})
    return bool(outcome.scalar())
@@ -0,0 +1 @@
1
+ """Infrastructure enrichment module."""
@@ -1,28 +1,42 @@
1
- """Embedding service."""
1
+ """Enrichment factory for creating enrichment domain services."""
2
2
 
3
3
  from kodit.config import AppContext, Endpoint
4
- from kodit.enrichment.enrichment_provider.local_enrichment_provider import (
4
+ from kodit.domain.services.enrichment_service import EnrichmentDomainService
5
+ from kodit.infrastructure.enrichment.local_enrichment_provider import (
5
6
  LocalEnrichmentProvider,
6
7
  )
7
- from kodit.enrichment.enrichment_provider.openai_enrichment_provider import (
8
+ from kodit.infrastructure.enrichment.openai_enrichment_provider import (
8
9
  OpenAIEnrichmentProvider,
9
10
  )
10
- from kodit.enrichment.enrichment_service import (
11
- EnrichmentService,
12
- LLMEnrichmentService,
13
- )
14
11
  from kodit.log import log_event
15
12
 
16
13
 
17
14
def _get_endpoint_configuration(app_context: AppContext) -> Endpoint | None:
    """Pick the endpoint configuration used for enrichment.

    Args:
        app_context: The application context.

    Returns:
        The enrichment endpoint when set, otherwise the default endpoint,
        otherwise None.

    """
    configured = app_context.enrichment_endpoint
    if configured:
        return configured
    # The default endpoint is only consulted when no enrichment endpoint is set.
    fallback = app_context.default_endpoint
    return fallback if fallback else None
20
25
 
21
26
 
22
- def enrichment_factory(app_context: AppContext) -> EnrichmentService:
23
- """Create an enrichment service."""
27
+ def create_enrichment_domain_service(
28
+ app_context: AppContext,
29
+ ) -> EnrichmentDomainService:
30
+ """Create an enrichment domain service.
31
+
32
+ Args:
33
+ app_context: The application context.
34
+
35
+ Returns:
36
+ An enrichment domain service instance.
37
+
38
+ """
24
39
  endpoint = _get_endpoint_configuration(app_context)
25
- endpoint = app_context.enrichment_endpoint or app_context.default_endpoint or None
26
40
 
27
41
  if endpoint and endpoint.type == "openai":
28
42
  log_event("kodit.enrichment", {"provider": "openai"})
@@ -32,6 +46,8 @@ def enrichment_factory(app_context: AppContext) -> EnrichmentService:
32
46
  openai_client=AsyncOpenAI(
33
47
  api_key=endpoint.api_key or "default",
34
48
  base_url=endpoint.base_url or "https://api.openai.com/v1",
49
+ timeout=60,
50
+ max_retries=2,
35
51
  ),
36
52
  model_name=endpoint.model or "gpt-4o-mini",
37
53
  )
@@ -39,4 +55,4 @@ def enrichment_factory(app_context: AppContext) -> EnrichmentService:
39
55
  log_event("kodit.enrichment", {"provider": "local"})
40
56
  enrichment_provider = LocalEnrichmentProvider()
41
57
 
42
- return LLMEnrichmentService(enrichment_provider=enrichment_provider)
58
+ return EnrichmentDomainService(enrichment_provider=enrichment_provider)