kodit 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kodit might be problematic. Click here for more details.
- kodit/_version.py +2 -2
- kodit/application/__init__.py +1 -0
- kodit/application/commands/__init__.py +1 -0
- kodit/application/commands/snippet_commands.py +22 -0
- kodit/application/services/__init__.py +1 -0
- kodit/application/services/indexing_application_service.py +363 -0
- kodit/application/services/snippet_application_service.py +143 -0
- kodit/cli.py +105 -82
- kodit/database.py +0 -22
- kodit/domain/__init__.py +1 -0
- kodit/{source/source_models.py → domain/entities.py} +88 -19
- kodit/domain/enums.py +9 -0
- kodit/domain/interfaces.py +27 -0
- kodit/domain/repositories.py +95 -0
- kodit/domain/services/__init__.py +1 -0
- kodit/domain/services/bm25_service.py +124 -0
- kodit/domain/services/embedding_service.py +155 -0
- kodit/domain/services/enrichment_service.py +48 -0
- kodit/domain/services/ignore_service.py +45 -0
- kodit/domain/services/indexing_service.py +203 -0
- kodit/domain/services/snippet_extraction_service.py +89 -0
- kodit/domain/services/source_service.py +83 -0
- kodit/domain/value_objects.py +215 -0
- kodit/infrastructure/__init__.py +1 -0
- kodit/infrastructure/bm25/__init__.py +1 -0
- kodit/infrastructure/bm25/bm25_factory.py +28 -0
- kodit/{bm25/local_bm25.py → infrastructure/bm25/local_bm25_repository.py} +33 -22
- kodit/{bm25/vectorchord_bm25.py → infrastructure/bm25/vectorchord_bm25_repository.py} +40 -35
- kodit/infrastructure/cloning/__init__.py +1 -0
- kodit/infrastructure/cloning/folder/__init__.py +1 -0
- kodit/infrastructure/cloning/folder/factory.py +119 -0
- kodit/infrastructure/cloning/folder/working_copy.py +38 -0
- kodit/infrastructure/cloning/git/__init__.py +1 -0
- kodit/infrastructure/cloning/git/factory.py +133 -0
- kodit/infrastructure/cloning/git/working_copy.py +32 -0
- kodit/infrastructure/cloning/metadata.py +127 -0
- kodit/infrastructure/embedding/__init__.py +1 -0
- kodit/infrastructure/embedding/embedding_factory.py +87 -0
- kodit/infrastructure/embedding/embedding_providers/__init__.py +1 -0
- kodit/infrastructure/embedding/embedding_providers/batching.py +93 -0
- kodit/infrastructure/embedding/embedding_providers/hash_embedding_provider.py +79 -0
- kodit/infrastructure/embedding/embedding_providers/local_embedding_provider.py +129 -0
- kodit/infrastructure/embedding/embedding_providers/openai_embedding_provider.py +113 -0
- kodit/infrastructure/embedding/local_vector_search_repository.py +114 -0
- kodit/{embedding/vectorchord_vector_search_service.py → infrastructure/embedding/vectorchord_vector_search_repository.py} +65 -46
- kodit/infrastructure/enrichment/__init__.py +1 -0
- kodit/{enrichment → infrastructure/enrichment}/enrichment_factory.py +28 -12
- kodit/infrastructure/enrichment/legacy_enrichment_models.py +42 -0
- kodit/{enrichment/enrichment_provider → infrastructure/enrichment}/local_enrichment_provider.py +38 -26
- kodit/infrastructure/enrichment/null_enrichment_provider.py +25 -0
- kodit/infrastructure/enrichment/openai_enrichment_provider.py +89 -0
- kodit/infrastructure/git/__init__.py +1 -0
- kodit/{source/git.py → infrastructure/git/git_utils.py} +10 -2
- kodit/infrastructure/ignore/__init__.py +1 -0
- kodit/{source/ignore.py → infrastructure/ignore/ignore_pattern_provider.py} +23 -6
- kodit/infrastructure/indexing/__init__.py +1 -0
- kodit/infrastructure/indexing/fusion_service.py +55 -0
- kodit/infrastructure/indexing/index_repository.py +296 -0
- kodit/infrastructure/indexing/indexing_factory.py +111 -0
- kodit/infrastructure/snippet_extraction/__init__.py +1 -0
- kodit/infrastructure/snippet_extraction/language_detection_service.py +39 -0
- kodit/infrastructure/snippet_extraction/snippet_extraction_factory.py +95 -0
- kodit/infrastructure/snippet_extraction/snippet_query_provider.py +45 -0
- kodit/{snippets/method_snippets.py → infrastructure/snippet_extraction/tree_sitter_snippet_extractor.py} +123 -61
- kodit/infrastructure/sqlalchemy/__init__.py +1 -0
- kodit/{embedding → infrastructure/sqlalchemy}/embedding_repository.py +40 -24
- kodit/infrastructure/sqlalchemy/file_repository.py +73 -0
- kodit/infrastructure/sqlalchemy/repository.py +121 -0
- kodit/infrastructure/sqlalchemy/snippet_repository.py +75 -0
- kodit/infrastructure/ui/__init__.py +1 -0
- kodit/infrastructure/ui/progress.py +127 -0
- kodit/{util → infrastructure/ui}/spinner.py +19 -4
- kodit/mcp.py +50 -28
- kodit/migrations/env.py +1 -4
- kodit/reporting.py +78 -0
- {kodit-0.2.4.dist-info → kodit-0.2.5.dist-info}/METADATA +1 -1
- kodit-0.2.5.dist-info/RECORD +99 -0
- kodit/bm25/__init__.py +0 -1
- kodit/bm25/keyword_search_factory.py +0 -17
- kodit/bm25/keyword_search_service.py +0 -34
- kodit/embedding/__init__.py +0 -1
- kodit/embedding/embedding_factory.py +0 -69
- kodit/embedding/embedding_models.py +0 -28
- kodit/embedding/embedding_provider/__init__.py +0 -1
- kodit/embedding/embedding_provider/embedding_provider.py +0 -92
- kodit/embedding/embedding_provider/hash_embedding_provider.py +0 -86
- kodit/embedding/embedding_provider/local_embedding_provider.py +0 -96
- kodit/embedding/embedding_provider/openai_embedding_provider.py +0 -73
- kodit/embedding/local_vector_search_service.py +0 -87
- kodit/embedding/vector_search_service.py +0 -55
- kodit/enrichment/__init__.py +0 -1
- kodit/enrichment/enrichment_provider/__init__.py +0 -1
- kodit/enrichment/enrichment_provider/enrichment_provider.py +0 -36
- kodit/enrichment/enrichment_provider/openai_enrichment_provider.py +0 -79
- kodit/enrichment/enrichment_service.py +0 -45
- kodit/indexing/__init__.py +0 -1
- kodit/indexing/fusion.py +0 -67
- kodit/indexing/indexing_models.py +0 -43
- kodit/indexing/indexing_repository.py +0 -216
- kodit/indexing/indexing_service.py +0 -344
- kodit/snippets/__init__.py +0 -1
- kodit/snippets/languages/__init__.py +0 -53
- kodit/snippets/snippets.py +0 -50
- kodit/source/__init__.py +0 -1
- kodit/source/source_factories.py +0 -356
- kodit/source/source_repository.py +0 -169
- kodit/source/source_service.py +0 -150
- kodit/util/__init__.py +0 -1
- kodit-0.2.4.dist-info/RECORD +0 -71
- /kodit/{snippets → infrastructure/snippet_extraction}/languages/csharp.scm +0 -0
- /kodit/{snippets → infrastructure/snippet_extraction}/languages/go.scm +0 -0
- /kodit/{snippets → infrastructure/snippet_extraction}/languages/javascript.scm +0 -0
- /kodit/{snippets → infrastructure/snippet_extraction}/languages/python.scm +0 -0
- /kodit/{snippets → infrastructure/snippet_extraction}/languages/typescript.scm +0 -0
- {kodit-0.2.4.dist-info → kodit-0.2.5.dist-info}/WHEEL +0 -0
- {kodit-0.2.4.dist-info → kodit-0.2.5.dist-info}/entry_points.txt +0 -0
- {kodit-0.2.4.dist-info → kodit-0.2.5.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
"""Hash-based embedding provider for testing purposes."""
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
from collections.abc import AsyncGenerator
|
|
5
|
+
|
|
6
|
+
import structlog
|
|
7
|
+
|
|
8
|
+
from kodit.domain.services.embedding_service import EmbeddingProvider
|
|
9
|
+
from kodit.domain.value_objects import EmbeddingRequest, EmbeddingResponse
|
|
10
|
+
|
|
11
|
+
# Constants for different embedding sizes
|
|
12
|
+
TINY = 64
|
|
13
|
+
CODE = 1536
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class HashEmbeddingProvider(EmbeddingProvider):
    """Deterministic, hash-derived embedding provider intended for tests.

    No model is involved: every vector is computed from the SHA-256 digest of
    the request text, so identical texts always produce identical vectors.
    """

    def __init__(self, embedding_size: int = CODE) -> None:
        """Create the provider.

        Args:
            embedding_size: Number of dimensions of every generated vector.

        """
        self.log = structlog.get_logger(__name__)
        self.embedding_size = embedding_size

    def embed(
        self, data: list[EmbeddingRequest]
    ) -> AsyncGenerator[list[EmbeddingResponse], None]:
        """Return an async generator of hashed embeddings, ten requests at a time.

        Args:
            data: The requests to embed.

        Returns:
            An async generator yielding one list of responses per batch. For
            empty input the generator finishes without yielding anything.

        """
        chunk_len = 10
        # Emptiness is decided when embed() is called, not when iteration starts.
        nothing_to_embed = not data

        async def _stream() -> AsyncGenerator[list[EmbeddingResponse], None]:
            if nothing_to_embed:
                return
            for start in range(0, len(data), chunk_len):
                yield [
                    EmbeddingResponse(
                        snippet_id=item.snippet_id,
                        embedding=self._generate_embedding(item.text),
                    )
                    for item in data[start : start + chunk_len]
                ]

        return _stream()

    def _generate_embedding(self, text: str) -> list[float]:
        """Build the deterministic vector for ``text``.

        The UTF-8 text is hashed with SHA-256 and the digest bytes are reused
        cyclically, one byte per dimension, each mapped from 0..255 onto
        ``(byte - 128) / 128.0``.
        """
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        span = len(digest)
        return [
            (digest[dim % span] - 128) / 128.0 for dim in range(self.embedding_size)
        ]
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
"""Local embedding provider implementation."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from collections.abc import AsyncGenerator
|
|
5
|
+
from time import time
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
import structlog
|
|
9
|
+
|
|
10
|
+
from kodit.domain.services.embedding_service import EmbeddingProvider
|
|
11
|
+
from kodit.domain.value_objects import EmbeddingRequest, EmbeddingResponse
|
|
12
|
+
|
|
13
|
+
from .batching import split_sub_batches
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from sentence_transformers import SentenceTransformer
|
|
17
|
+
from tiktoken import Encoding
|
|
18
|
+
|
|
19
|
+
# Constants for different embedding models
|
|
20
|
+
TINY = "tiny"
|
|
21
|
+
CODE = "code"
|
|
22
|
+
TEST = "test"
|
|
23
|
+
|
|
24
|
+
COMMON_EMBEDDING_MODELS = {
|
|
25
|
+
TINY: "ibm-granite/granite-embedding-30m-english",
|
|
26
|
+
CODE: "flax-sentence-embeddings/st-codesearch-distilroberta-base",
|
|
27
|
+
TEST: "minishlab/potion-base-4M",
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class LocalEmbeddingProvider(EmbeddingProvider):
    """Local embedding provider that uses sentence-transformers.

    The sentence-transformers model and the tiktoken encoding are both loaded
    lazily on first use and cached on the instance.
    """

    # Vector length used for zero-vector fallbacks when the loaded model cannot
    # report its own output dimension.
    _FALLBACK_EMBEDDING_SIZE = 1536

    def __init__(self, model_name: str = CODE) -> None:
        """Initialize the local embedding provider.

        Args:
            model_name: The model name to use for embeddings. Can be a preset
                ('tiny', 'code', 'test') or a full model name.

        """
        self.log = structlog.get_logger(__name__)
        # Presets are resolved through COMMON_EMBEDDING_MODELS; any other value
        # is passed through unchanged as a full model name.
        self.model_name = COMMON_EMBEDDING_MODELS.get(model_name, model_name)
        # Model name handed to tiktoken to obtain the encoding that is passed
        # to the batching helper.
        self.encoding_name = "text-embedding-3-small"
        self.embedding_model: SentenceTransformer | None = None
        self.encoding: Encoding | None = None

    def _encoding(self) -> "Encoding":
        """Return the tiktoken encoding, loading and caching it on first use."""
        if self.encoding is None:
            # Imported lazily so that importing this module stays cheap.
            from tiktoken import encoding_for_model

            start_time = time()
            self.encoding = encoding_for_model(self.encoding_name)
            self.log.debug(
                "Encoding loaded",
                model_name=self.encoding_name,
                duration=time() - start_time,
            )
        return self.encoding

    def _model(self) -> "SentenceTransformer":
        """Return the embedding model, loading and caching it on first use."""
        if self.embedding_model is None:
            os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Avoid warnings
            from sentence_transformers import SentenceTransformer

            start_time = time()
            # NOTE(security): trust_remote_code=True lets code shipped in the
            # model repository run locally. Only pass model names that come
            # from a trusted source.
            self.embedding_model = SentenceTransformer(
                self.model_name,
                trust_remote_code=True,
            )
            self.log.debug(
                "Model loaded",
                model_name=self.model_name,
                duration=time() - start_time,
            )
        return self.embedding_model

    def _zero_embedding_size(self, model: "SentenceTransformer") -> int:
        """Return the vector length to use for zero-vector fallbacks.

        The model is asked for its own output dimension so that fallback
        vectors have the same length as real ones; the class-level default is
        used only when the model cannot provide a usable value.
        """
        try:
            size = model.get_sentence_embedding_dimension()
        except Exception:  # noqa: BLE001 - best effort inside an error path
            size = None
        if isinstance(size, int) and size > 0:
            return size
        return self._FALLBACK_EMBEDDING_SIZE

    async def embed(
        self, data: list[EmbeddingRequest]
    ) -> AsyncGenerator[list[EmbeddingResponse], None]:
        """Embed a list of requests using the local model.

        Args:
            data: The requests to embed.

        Yields:
            One list of responses per sub-batch. Empty input yields a single
            empty list. A sub-batch whose encoding fails yields zero vectors
            for its items so the caller can keep going.

        """
        if not data:
            yield []
            # Stop here: without this return the model and the encoding would
            # be loaded even though there is nothing to embed.
            return

        model = self._model()
        encoding = self._encoding()

        # Split into sub-batches based on token limits
        batched_data = self._split_sub_batches(encoding, data)

        for batch in batched_data:
            try:
                # Encode the texts using the model
                embeddings = model.encode(
                    [item.text for item in batch],
                    show_progress_bar=False,
                    batch_size=4,
                )

                # Convert to our response format
                responses = [
                    EmbeddingResponse(
                        snippet_id=item.snippet_id,
                        embedding=[float(x) for x in embedding],
                    )
                    for item, embedding in zip(batch, embeddings, strict=True)
                ]
            except Exception as e:
                self.log.exception("Error generating embeddings", error=str(e))
                # Return zero embeddings on error, sized to match the model's
                # real output instead of a hard-coded dimension.
                zero_size = self._zero_embedding_size(model)
                responses = [
                    EmbeddingResponse(
                        snippet_id=item.snippet_id,
                        embedding=[0.0] * zero_size,
                    )
                    for item in batch
                ]

            yield responses

    def _split_sub_batches(
        self, encoding: "Encoding", data: list[EmbeddingRequest]
    ) -> list[list[EmbeddingRequest]]:
        """Proxy to the shared batching utility (kept for backward-compat)."""
        return split_sub_batches(encoding, data)
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""OpenAI embedding provider implementation."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from collections.abc import AsyncGenerator
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import structlog
|
|
8
|
+
import tiktoken
|
|
9
|
+
from tiktoken import Encoding
|
|
10
|
+
|
|
11
|
+
from kodit.domain.services.embedding_service import EmbeddingProvider
|
|
12
|
+
from kodit.domain.value_objects import EmbeddingRequest, EmbeddingResponse
|
|
13
|
+
|
|
14
|
+
from .batching import split_sub_batches
|
|
15
|
+
|
|
16
|
+
# Constants
|
|
17
|
+
MAX_TOKENS = 8192 # Conservative token limit for the embedding model
|
|
18
|
+
BATCH_SIZE = (
|
|
19
|
+
10 # Maximum number of items per API call (keeps existing test expectations)
|
|
20
|
+
)
|
|
21
|
+
OPENAI_NUM_PARALLEL_TASKS = 25 # Semaphore limit for concurrent OpenAI requests
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """OpenAI embedding provider that uses OpenAI's embedding API.

    Requests are split into sub-batches and sent concurrently, bounded by a
    semaphore of ``OPENAI_NUM_PARALLEL_TASKS``.
    """

    def __init__(
        self, openai_client: Any, model_name: str = "text-embedding-3-small"
    ) -> None:
        """Initialize the OpenAI embedding provider.

        Args:
            openai_client: The OpenAI client instance
            model_name: The model name to use for embeddings

        """
        self.openai_client = openai_client
        self.model_name = model_name
        self.log = structlog.get_logger(__name__)

        # Lazily initialised token encoding
        self._encoding: Encoding | None = None

    # ---------------------------------------------------------------------
    # Helper utilities
    # ---------------------------------------------------------------------

    def _get_encoding(self) -> "Encoding":
        """Return (and cache) the tiktoken encoding for the chosen model.

        tiktoken raises ``KeyError`` for model names it does not know. Because
        ``model_name`` is caller-supplied, that case falls back to the
        ``cl100k_base`` encoding instead of aborting the whole embed call.
        """
        if self._encoding is None:
            try:
                self._encoding = tiktoken.encoding_for_model(self.model_name)
            except KeyError:
                self.log.warning(
                    "Unknown model for tiktoken, falling back to cl100k_base",
                    model_name=self.model_name,
                )
                self._encoding = tiktoken.get_encoding("cl100k_base")
        return self._encoding

    def _split_sub_batches(
        self, encoding: "Encoding", data: list[EmbeddingRequest]
    ) -> list[list[EmbeddingRequest]]:
        """Proxy to the shared batching utility (kept for backward-compat)."""
        return split_sub_batches(
            encoding,
            data,
            max_tokens=MAX_TOKENS,
            batch_size=BATCH_SIZE,
        )

    async def embed(
        self, data: list[EmbeddingRequest]
    ) -> AsyncGenerator[list[EmbeddingResponse], None]:
        """Embed a list of requests using OpenAI's API.

        Args:
            data: The requests to embed.

        Yields:
            One list of responses per sub-batch, in completion order (not
            necessarily input order). Empty input yields a single empty list.
            A sub-batch whose API call fails yields zero vectors for its items.

        """
        if not data:
            yield []
            # Stop here: without this return the encoding would still be
            # loaded and the batching machinery run for nothing.
            return

        encoding = self._get_encoding()

        # First, split by token limits (and max batch size)
        batched_data = self._split_sub_batches(encoding, data)

        # -----------------------------------------------------------------
        # Process batches concurrently (but bounded by a semaphore)
        # -----------------------------------------------------------------

        sem = asyncio.Semaphore(OPENAI_NUM_PARALLEL_TASKS)

        async def _process_batch(
            batch: list[EmbeddingRequest],
        ) -> list[EmbeddingResponse]:
            """Embed one sub-batch, substituting zero vectors on failure."""
            async with sem:
                try:
                    response = await self.openai_client.embeddings.create(
                        model=self.model_name,
                        input=[item.text for item in batch],
                    )

                    # strict=True makes a length mismatch between the request
                    # and the API response raise instead of silently
                    # mis-pairing snippet ids and vectors.
                    return [
                        EmbeddingResponse(
                            snippet_id=item.snippet_id,
                            embedding=embedding.embedding,
                        )
                        for item, embedding in zip(batch, response.data, strict=True)
                    ]
                except Exception as e:
                    self.log.exception("Error embedding batch", error=str(e))
                    # Fall back to zero embeddings so pipeline can continue
                    # NOTE(review): 1536 is assumed to match the configured
                    # model's dimension - confirm for non-default models.
                    return [
                        EmbeddingResponse(
                            snippet_id=item.snippet_id,
                            embedding=[0.0] * 1536,  # Default OpenAI dim
                        )
                        for item in batch
                    ]

        tasks = [_process_batch(batch) for batch in batched_data]
        for task in asyncio.as_completed(tasks):
            yield await task
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
"""Local vector search repository implementation."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import AsyncGenerator
|
|
4
|
+
|
|
5
|
+
import structlog
|
|
6
|
+
import tiktoken
|
|
7
|
+
|
|
8
|
+
from kodit.domain.entities import Embedding, EmbeddingType
|
|
9
|
+
from kodit.domain.services.embedding_service import (
|
|
10
|
+
EmbeddingProvider,
|
|
11
|
+
VectorSearchRepository,
|
|
12
|
+
)
|
|
13
|
+
from kodit.domain.value_objects import (
|
|
14
|
+
EmbeddingRequest,
|
|
15
|
+
IndexResult,
|
|
16
|
+
VectorIndexRequest,
|
|
17
|
+
VectorSearchQueryRequest,
|
|
18
|
+
VectorSearchResult,
|
|
19
|
+
)
|
|
20
|
+
from kodit.infrastructure.sqlalchemy.embedding_repository import (
|
|
21
|
+
SqlAlchemyEmbeddingRepository,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class LocalVectorSearchRepository(VectorSearchRepository):
    """Vector search that stores and queries embeddings via the embedding repository.

    Vectors are produced by the injected embedding provider and persisted or
    looked up through the injected SQLAlchemy embedding repository.
    """

    def __init__(
        self,
        embedding_repository: SqlAlchemyEmbeddingRepository,
        embedding_provider: EmbeddingProvider,
        embedding_type: EmbeddingType = EmbeddingType.CODE,
    ) -> None:
        """Wire up the repository.

        Args:
            embedding_repository: Repository used to persist and query embeddings.
            embedding_provider: Provider that turns text into vectors.
            embedding_type: Type tag stored with, and used to query, embeddings.

        """
        self.embedding_repository = embedding_repository
        self.embedding_provider = embedding_provider
        self.embedding_type = embedding_type
        self.encoding = tiktoken.encoding_for_model("text-embedding-3-small")
        self.log = structlog.get_logger(__name__)

    def index_documents(
        self, request: VectorIndexRequest
    ) -> AsyncGenerator[list[IndexResult], None]:
        """Embed and persist the request's documents, batch by batch.

        Args:
            request: The documents to index.

        Returns:
            An async generator yielding one list of index results per batch
            produced by the embedding provider. With no documents it finishes
            without yielding and the provider is never called.

        """
        has_documents = bool(request.documents)
        pending = (
            [
                EmbeddingRequest(snippet_id=doc.snippet_id, text=doc.text)
                for doc in request.documents
            ]
            if has_documents
            else []
        )

        async def _persist_and_report() -> AsyncGenerator[list[IndexResult], None]:
            if not has_documents:
                return
            async for embedded in self.embedding_provider.embed(pending):
                indexed: list[IndexResult] = []
                for item in embedded:
                    record = Embedding(
                        snippet_id=item.snippet_id,
                        embedding=item.embedding,
                        type=self.embedding_type,
                    )
                    await self.embedding_repository.create_embedding(record)
                    indexed.append(IndexResult(snippet_id=item.snippet_id))
                yield indexed

        return _persist_and_report()

    async def search(
        self, request: VectorSearchQueryRequest
    ) -> list[VectorSearchResult]:
        """Return the snippets most similar to the query text.

        Args:
            request: The query text and the number of results wanted.

        Returns:
            Search results built from the repository's (snippet id, score)
            pairs, or an empty list when no query vector could be produced.

        """
        # The query is embedded as a one-item request with placeholder id 0.
        probe = EmbeddingRequest(snippet_id=0, text=request.query)
        query_vector: list[float] | None = None
        async for embedded in self.embedding_provider.embed([probe]):
            if not embedded:
                continue
            # Only the first non-empty batch is needed.
            query_vector = [float(component) for component in embedded[0].embedding]
            break

        if not query_vector:
            return []

        matches = await self.embedding_repository.list_semantic_results(
            self.embedding_type, query_vector, request.top_k
        )
        return [
            VectorSearchResult(snippet_id=matched_id, score=similarity)
            for matched_id, similarity in matches
        ]

    async def has_embedding(
        self, snippet_id: int, embedding_type: EmbeddingType
    ) -> bool:
        """Report whether an embedding of the given type exists for the snippet."""
        stored = await self.embedding_repository.get_embedding_by_snippet_id_and_type(
            snippet_id, embedding_type
        )
        return stored is not None
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""VectorChord vector search repository implementation."""
|
|
2
2
|
|
|
3
3
|
from collections.abc import AsyncGenerator
|
|
4
4
|
from typing import Any, Literal
|
|
@@ -7,16 +7,17 @@ import structlog
|
|
|
7
7
|
from sqlalchemy import Result, TextClause, text
|
|
8
8
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
9
9
|
|
|
10
|
-
from kodit.
|
|
11
|
-
from kodit.
|
|
10
|
+
from kodit.domain.entities import EmbeddingType
|
|
11
|
+
from kodit.domain.services.embedding_service import (
|
|
12
12
|
EmbeddingProvider,
|
|
13
|
-
|
|
13
|
+
VectorSearchRepository,
|
|
14
14
|
)
|
|
15
|
-
from kodit.
|
|
15
|
+
from kodit.domain.value_objects import (
|
|
16
|
+
EmbeddingRequest,
|
|
16
17
|
IndexResult,
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
18
|
+
VectorIndexRequest,
|
|
19
|
+
VectorSearchQueryRequest,
|
|
20
|
+
VectorSearchResult,
|
|
20
21
|
)
|
|
21
22
|
|
|
22
23
|
# SQL Queries
|
|
@@ -65,8 +66,8 @@ SELECT EXISTS(SELECT 1 FROM {TABLE_NAME} WHERE snippet_id = :snippet_id)
|
|
|
65
66
|
TaskName = Literal["code", "text"]
|
|
66
67
|
|
|
67
68
|
|
|
68
|
-
class
|
|
69
|
-
"""VectorChord vector search."""
|
|
69
|
+
class VectorChordVectorSearchRepository(VectorSearchRepository):
|
|
70
|
+
"""VectorChord vector search repository implementation."""
|
|
70
71
|
|
|
71
72
|
def __init__(
|
|
72
73
|
self,
|
|
@@ -74,7 +75,14 @@ class VectorChordVectorSearchService(VectorSearchService):
|
|
|
74
75
|
session: AsyncSession,
|
|
75
76
|
embedding_provider: EmbeddingProvider,
|
|
76
77
|
) -> None:
|
|
77
|
-
"""Initialize the VectorChord
|
|
78
|
+
"""Initialize the VectorChord vector search repository.
|
|
79
|
+
|
|
80
|
+
Args:
|
|
81
|
+
task_name: The task name (code or text)
|
|
82
|
+
session: The SQLAlchemy async session
|
|
83
|
+
embedding_provider: The embedding provider for generating embeddings
|
|
84
|
+
|
|
85
|
+
"""
|
|
78
86
|
self.embedding_provider = embedding_provider
|
|
79
87
|
self._session = session
|
|
80
88
|
self._initialized = False
|
|
@@ -99,7 +107,7 @@ class VectorChordVectorSearchService(VectorSearchService):
|
|
|
99
107
|
|
|
100
108
|
async def _create_tables(self) -> None:
|
|
101
109
|
"""Create the necessary tables."""
|
|
102
|
-
req = EmbeddingRequest(
|
|
110
|
+
req = EmbeddingRequest(snippet_id=0, text="dimension")
|
|
103
111
|
vector_dim: list[float] | None = None
|
|
104
112
|
async for batch in self.embedding_provider.embed([req]):
|
|
105
113
|
if batch:
|
|
@@ -148,37 +156,46 @@ class VectorChordVectorSearchService(VectorSearchService):
|
|
|
148
156
|
"""Commit the session."""
|
|
149
157
|
await self._session.commit()
|
|
150
158
|
|
|
151
|
-
|
|
152
|
-
self,
|
|
159
|
+
def index_documents(
|
|
160
|
+
self, request: VectorIndexRequest
|
|
153
161
|
) -> AsyncGenerator[list[IndexResult], None]:
|
|
154
|
-
"""
|
|
155
|
-
if not
|
|
156
|
-
self.log.warning("Embedding data is empty, skipping embedding")
|
|
157
|
-
return
|
|
158
|
-
|
|
159
|
-
requests = [EmbeddingRequest(id=doc.snippet_id, text=doc.text) for doc in data]
|
|
160
|
-
|
|
161
|
-
async for batch in self.embedding_provider.embed(requests):
|
|
162
|
-
await self._execute(
|
|
163
|
-
text(INSERT_QUERY.format(TABLE_NAME=self.table_name)),
|
|
164
|
-
[
|
|
165
|
-
{
|
|
166
|
-
"snippet_id": result.id,
|
|
167
|
-
"embedding": str(result.embedding),
|
|
168
|
-
}
|
|
169
|
-
for result in batch
|
|
170
|
-
],
|
|
171
|
-
)
|
|
172
|
-
await self._commit()
|
|
173
|
-
yield [IndexResult(snippet_id=result.id) for result in batch]
|
|
162
|
+
"""Index documents for vector search."""
|
|
163
|
+
if not request.documents:
|
|
174
164
|
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
165
|
+
async def empty_generator() -> AsyncGenerator[list[IndexResult], None]:
|
|
166
|
+
if False:
|
|
167
|
+
yield []
|
|
168
|
+
|
|
169
|
+
return empty_generator()
|
|
180
170
|
|
|
181
|
-
|
|
171
|
+
# Convert to embedding requests
|
|
172
|
+
requests = [
|
|
173
|
+
EmbeddingRequest(snippet_id=doc.snippet_id, text=doc.text)
|
|
174
|
+
for doc in request.documents
|
|
175
|
+
]
|
|
176
|
+
|
|
177
|
+
async def _index_batches() -> AsyncGenerator[list[IndexResult], None]:
|
|
178
|
+
async for batch in self.embedding_provider.embed(requests):
|
|
179
|
+
await self._execute(
|
|
180
|
+
text(INSERT_QUERY.format(TABLE_NAME=self.table_name)),
|
|
181
|
+
[
|
|
182
|
+
{
|
|
183
|
+
"snippet_id": result.snippet_id,
|
|
184
|
+
"embedding": str(result.embedding),
|
|
185
|
+
}
|
|
186
|
+
for result in batch
|
|
187
|
+
],
|
|
188
|
+
)
|
|
189
|
+
await self._commit()
|
|
190
|
+
yield [IndexResult(snippet_id=result.snippet_id) for result in batch]
|
|
191
|
+
|
|
192
|
+
return _index_batches()
|
|
193
|
+
|
|
194
|
+
async def search(
|
|
195
|
+
self, request: VectorSearchQueryRequest
|
|
196
|
+
) -> list[VectorSearchResult]:
|
|
197
|
+
"""Search documents using vector similarity."""
|
|
198
|
+
req = EmbeddingRequest(snippet_id=0, text=request.query)
|
|
182
199
|
embedding_vec: list[float] | None = None
|
|
183
200
|
async for batch in self.embedding_provider.embed([req]):
|
|
184
201
|
if batch:
|
|
@@ -189,23 +206,25 @@ class VectorChordVectorSearchService(VectorSearchService):
|
|
|
189
206
|
return []
|
|
190
207
|
result = await self._execute(
|
|
191
208
|
text(SEARCH_QUERY.format(TABLE_NAME=self.table_name)),
|
|
192
|
-
{"query": str(embedding_vec), "top_k": top_k},
|
|
209
|
+
{"query": str(embedding_vec), "top_k": request.top_k},
|
|
193
210
|
)
|
|
194
211
|
rows = result.mappings().all()
|
|
195
212
|
|
|
196
213
|
return [
|
|
197
|
-
|
|
214
|
+
VectorSearchResult(snippet_id=row["snippet_id"], score=row["score"])
|
|
198
215
|
for row in rows
|
|
199
216
|
]
|
|
200
217
|
|
|
201
218
|
async def has_embedding(
|
|
202
|
-
self,
|
|
203
|
-
snippet_id: int,
|
|
204
|
-
embedding_type: EmbeddingType, # noqa: ARG002
|
|
219
|
+
self, snippet_id: int, embedding_type: EmbeddingType
|
|
205
220
|
) -> bool:
|
|
206
221
|
"""Check if a snippet has an embedding."""
|
|
222
|
+
# For VectorChord, we check if the snippet exists in the table
|
|
223
|
+
# Note: embedding_type is ignored since VectorChord uses separate
|
|
224
|
+
# tables per task
|
|
225
|
+
# ruff: noqa: ARG002
|
|
207
226
|
result = await self._execute(
|
|
208
227
|
text(CHECK_VCHORD_EMBEDDING_EXISTS.format(TABLE_NAME=self.table_name)),
|
|
209
228
|
{"snippet_id": snippet_id},
|
|
210
229
|
)
|
|
211
|
-
return result.
|
|
230
|
+
return bool(result.scalar())
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Infrastructure enrichment module."""
|
|
@@ -1,28 +1,42 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Enrichment factory for creating enrichment domain services."""
|
|
2
2
|
|
|
3
3
|
from kodit.config import AppContext, Endpoint
|
|
4
|
-
from kodit.
|
|
4
|
+
from kodit.domain.services.enrichment_service import EnrichmentDomainService
|
|
5
|
+
from kodit.infrastructure.enrichment.local_enrichment_provider import (
|
|
5
6
|
LocalEnrichmentProvider,
|
|
6
7
|
)
|
|
7
|
-
from kodit.enrichment.
|
|
8
|
+
from kodit.infrastructure.enrichment.openai_enrichment_provider import (
|
|
8
9
|
OpenAIEnrichmentProvider,
|
|
9
10
|
)
|
|
10
|
-
from kodit.enrichment.enrichment_service import (
|
|
11
|
-
EnrichmentService,
|
|
12
|
-
LLMEnrichmentService,
|
|
13
|
-
)
|
|
14
11
|
from kodit.log import log_event
|
|
15
12
|
|
|
16
13
|
|
|
17
14
|
def _get_endpoint_configuration(app_context: AppContext) -> Endpoint | None:
|
|
18
|
-
"""Get the endpoint configuration for the enrichment service.
|
|
15
|
+
"""Get the endpoint configuration for the enrichment service.
|
|
16
|
+
|
|
17
|
+
Args:
|
|
18
|
+
app_context: The application context.
|
|
19
|
+
|
|
20
|
+
Returns:
|
|
21
|
+
The endpoint configuration or None.
|
|
22
|
+
|
|
23
|
+
"""
|
|
19
24
|
return app_context.enrichment_endpoint or app_context.default_endpoint or None
|
|
20
25
|
|
|
21
26
|
|
|
22
|
-
def
|
|
23
|
-
|
|
27
|
+
def create_enrichment_domain_service(
|
|
28
|
+
app_context: AppContext,
|
|
29
|
+
) -> EnrichmentDomainService:
|
|
30
|
+
"""Create an enrichment domain service.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
app_context: The application context.
|
|
34
|
+
|
|
35
|
+
Returns:
|
|
36
|
+
An enrichment domain service instance.
|
|
37
|
+
|
|
38
|
+
"""
|
|
24
39
|
endpoint = _get_endpoint_configuration(app_context)
|
|
25
|
-
endpoint = app_context.enrichment_endpoint or app_context.default_endpoint or None
|
|
26
40
|
|
|
27
41
|
if endpoint and endpoint.type == "openai":
|
|
28
42
|
log_event("kodit.enrichment", {"provider": "openai"})
|
|
@@ -32,6 +46,8 @@ def enrichment_factory(app_context: AppContext) -> EnrichmentService:
|
|
|
32
46
|
openai_client=AsyncOpenAI(
|
|
33
47
|
api_key=endpoint.api_key or "default",
|
|
34
48
|
base_url=endpoint.base_url or "https://api.openai.com/v1",
|
|
49
|
+
timeout=60,
|
|
50
|
+
max_retries=2,
|
|
35
51
|
),
|
|
36
52
|
model_name=endpoint.model or "gpt-4o-mini",
|
|
37
53
|
)
|
|
@@ -39,4 +55,4 @@ def enrichment_factory(app_context: AppContext) -> EnrichmentService:
|
|
|
39
55
|
log_event("kodit.enrichment", {"provider": "local"})
|
|
40
56
|
enrichment_provider = LocalEnrichmentProvider()
|
|
41
57
|
|
|
42
|
-
return
|
|
58
|
+
return EnrichmentDomainService(enrichment_provider=enrichment_provider)
|