kodit 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit has been flagged as potentially problematic; the original diff page linked to further details.

Files changed (42)
  1. kodit/_version.py +2 -2
  2. kodit/bm25/keyword_search_factory.py +17 -0
  3. kodit/bm25/keyword_search_service.py +34 -0
  4. kodit/bm25/{bm25.py → local_bm25.py} +40 -14
  5. kodit/bm25/vectorchord_bm25.py +193 -0
  6. kodit/cli.py +114 -25
  7. kodit/config.py +9 -2
  8. kodit/database.py +4 -2
  9. kodit/embedding/embedding_factory.py +44 -0
  10. kodit/embedding/embedding_provider/__init__.py +1 -0
  11. kodit/embedding/embedding_provider/embedding_provider.py +60 -0
  12. kodit/embedding/embedding_provider/hash_embedding_provider.py +77 -0
  13. kodit/embedding/embedding_provider/local_embedding_provider.py +58 -0
  14. kodit/embedding/embedding_provider/openai_embedding_provider.py +75 -0
  15. kodit/{search/search_repository.py → embedding/embedding_repository.py} +61 -33
  16. kodit/embedding/local_vector_search_service.py +50 -0
  17. kodit/embedding/vector_search_service.py +38 -0
  18. kodit/embedding/vectorchord_vector_search_service.py +154 -0
  19. kodit/enrichment/__init__.py +1 -0
  20. kodit/enrichment/enrichment_factory.py +23 -0
  21. kodit/enrichment/enrichment_provider/__init__.py +1 -0
  22. kodit/enrichment/enrichment_provider/enrichment_provider.py +16 -0
  23. kodit/enrichment/enrichment_provider/local_enrichment_provider.py +63 -0
  24. kodit/enrichment/enrichment_provider/openai_enrichment_provider.py +77 -0
  25. kodit/enrichment/enrichment_service.py +33 -0
  26. kodit/indexing/fusion.py +67 -0
  27. kodit/indexing/indexing_repository.py +44 -4
  28. kodit/indexing/indexing_service.py +142 -31
  29. kodit/mcp.py +31 -18
  30. kodit/snippets/languages/go.scm +26 -0
  31. kodit/source/source_service.py +9 -3
  32. kodit/util/__init__.py +1 -0
  33. kodit/util/spinner.py +59 -0
  34. {kodit-0.1.14.dist-info → kodit-0.1.16.dist-info}/METADATA +4 -1
  35. kodit-0.1.16.dist-info/RECORD +64 -0
  36. kodit/embedding/embedding.py +0 -203
  37. kodit/search/__init__.py +0 -1
  38. kodit/search/search_service.py +0 -147
  39. kodit-0.1.14.dist-info/RECORD +0 -44
  40. {kodit-0.1.14.dist-info → kodit-0.1.16.dist-info}/WHEEL +0 -0
  41. {kodit-0.1.14.dist-info → kodit-0.1.16.dist-info}/entry_points.txt +0 -0
  42. {kodit-0.1.14.dist-info → kodit-0.1.16.dist-info}/licenses/LICENSE +0 -0
kodit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.1.14'
21
- __version_tuple__ = version_tuple = (0, 1, 14)
20
+ __version__ = version = '0.1.16'
21
+ __version_tuple__ = version_tuple = (0, 1, 16)
@@ -0,0 +1,17 @@
1
+ """Factory for creating keyword search providers."""
2
+
3
+ from sqlalchemy.ext.asyncio import AsyncSession
4
+
5
+ from kodit.bm25.keyword_search_service import KeywordSearchProvider
6
+ from kodit.bm25.local_bm25 import BM25Service
7
+ from kodit.bm25.vectorchord_bm25 import VectorChordBM25
8
+ from kodit.config import AppContext
9
+
10
+
11
+ def keyword_search_factory(
12
+ app_context: AppContext, session: AsyncSession
13
+ ) -> KeywordSearchProvider:
14
+ """Create a keyword search provider."""
15
+ if app_context.default_search.provider == "vectorchord":
16
+ return VectorChordBM25(session=session)
17
+ return BM25Service(data_dir=app_context.get_data_dir())
@@ -0,0 +1,34 @@
1
+ """Keyword search service."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import NamedTuple
5
+
6
+
7
+ class BM25Document(NamedTuple):
8
+ """BM25 document."""
9
+
10
+ snippet_id: int
11
+ text: str
12
+
13
+
14
+ class BM25Result(NamedTuple):
15
+ """BM25 result."""
16
+
17
+ snippet_id: int
18
+ score: float
19
+
20
+
21
+ class KeywordSearchProvider(ABC):
22
+ """Interface for keyword search providers."""
23
+
24
+ @abstractmethod
25
+ async def index(self, corpus: list[BM25Document]) -> None:
26
+ """Index a new corpus."""
27
+
28
+ @abstractmethod
29
+ async def retrieve(self, query: str, top_k: int = 2) -> list[BM25Result]:
30
+ """Retrieve from the index."""
31
+
32
+ @abstractmethod
33
+ async def delete(self, snippet_ids: list[int]) -> None:
34
+ """Delete documents from the index."""
@@ -1,23 +1,36 @@
1
- """BM25 service."""
1
+ """Locally hosted BM25 service primarily for use with SQLite."""
2
2
 
3
+ import json
3
4
  from pathlib import Path
4
5
 
6
+ import aiofiles
5
7
  import bm25s
6
8
  import Stemmer
7
9
  import structlog
8
10
  from bm25s.tokenization import Tokenized
9
11
 
12
+ from kodit.bm25.keyword_search_service import (
13
+ BM25Document,
14
+ BM25Result,
15
+ KeywordSearchProvider,
16
+ )
10
17
 
11
- class BM25Service:
12
- """Service for BM25."""
18
+ SNIPPET_IDS_FILE = "snippet_ids.jsonl"
19
+
20
+
21
+ class BM25Service(KeywordSearchProvider):
22
+ """LocalBM25 service."""
13
23
 
14
24
  def __init__(self, data_dir: Path) -> None:
15
25
  """Initialize the BM25 service."""
16
26
  self.log = structlog.get_logger(__name__)
17
27
  self.index_path = data_dir / "bm25s_index"
28
+ self.snippet_ids: list[int] = []
18
29
  try:
19
30
  self.log.debug("Loading BM25 index")
20
31
  self.retriever = bm25s.BM25.load(self.index_path, mmap=True)
32
+ with Path(self.index_path / SNIPPET_IDS_FILE).open() as f:
33
+ self.snippet_ids = json.load(f)
21
34
  except FileNotFoundError:
22
35
  self.log.debug("BM25 index not found, creating new index")
23
36
  self.retriever = bm25s.BM25()
@@ -33,28 +46,34 @@ class BM25Service:
33
46
  show_progress=True,
34
47
  )
35
48
 
36
- def index(self, corpus: list[str]) -> None:
49
+ async def index(self, corpus: list[BM25Document]) -> None:
37
50
  """Index a new corpus."""
38
51
  self.log.debug("Indexing corpus")
39
- vocab = self._tokenize(corpus)
52
+ vocab = self._tokenize([doc.text for doc in corpus])
40
53
  self.retriever = bm25s.BM25()
41
54
  self.retriever.index(vocab, show_progress=False)
42
55
  self.retriever.save(self.index_path)
56
+ self.snippet_ids = self.snippet_ids + [doc.snippet_id for doc in corpus]
57
+ async with aiofiles.open(self.index_path / SNIPPET_IDS_FILE, "w") as f:
58
+ await f.write(json.dumps(self.snippet_ids))
43
59
 
44
- def retrieve(
45
- self, doc_ids: list[int], query: str, top_k: int = 2
46
- ) -> list[tuple[int, float]]:
60
+ async def retrieve(self, query: str, top_k: int = 2) -> list[BM25Result]:
47
61
  """Retrieve from the index."""
48
62
  if top_k == 0:
49
63
  self.log.warning("Top k is 0, returning empty list")
50
64
  return []
51
- if len(doc_ids) == 0:
52
- self.log.warning("No documents to retrieve from, returning empty list")
65
+
66
+ # Get the number of documents in the index
67
+ num_docs = self.retriever.scores["num_docs"]
68
+ if num_docs == 0:
53
69
  return []
54
70
 
55
- top_k = min(top_k, len(self.retriever.scores))
71
+ # Adjust top_k to not exceed corpus size
72
+ top_k = min(top_k, num_docs)
56
73
  self.log.debug(
57
- "Retrieving from index", query=query, top_k=top_k, num_docs=len(doc_ids)
74
+ "Retrieving from index",
75
+ query=query,
76
+ top_k=top_k,
58
77
  )
59
78
 
60
79
  query_tokens = self._tokenize([query])
@@ -62,10 +81,17 @@ class BM25Service:
62
81
  self.log.debug("Query tokens", query_tokens=query_tokens)
63
82
 
64
83
  results, scores = self.retriever.retrieve(
65
- query_tokens=query_tokens, corpus=doc_ids, k=top_k
84
+ query_tokens=query_tokens,
85
+ corpus=self.snippet_ids,
86
+ k=top_k,
66
87
  )
67
88
  self.log.debug("Raw results", results=results, scores=scores)
68
89
  return [
69
- (int(result), float(score))
90
+ BM25Result(snippet_id=int(result), score=float(score))
70
91
  for result, score in zip(results[0], scores[0], strict=False)
92
+ if score > 0.0
71
93
  ]
94
+
95
+ async def delete(self, snippet_ids: list[int]) -> None: # noqa: ARG002
96
+ """Delete documents from the index."""
97
+ self.log.warning("Deletion not supported for local BM25 index")
@@ -0,0 +1,193 @@
1
+ """VectorChord repository for document operations."""
2
+
3
+ from typing import Any
4
+
5
+ from sqlalchemy import Result, TextClause, bindparam, text
6
+ from sqlalchemy.ext.asyncio import AsyncSession
7
+
8
+ from kodit.bm25.keyword_search_service import (
9
+ BM25Document,
10
+ BM25Result,
11
+ KeywordSearchProvider,
12
+ )
13
+
14
+ TABLE_NAME = "vectorchord_bm25_documents"
15
+ INDEX_NAME = f"{TABLE_NAME}_idx"
16
+ TOKENIZER_NAME = "bert"
17
+
18
+ # SQL statements
19
+ CREATE_VCHORD_EXTENSION = "CREATE EXTENSION IF NOT EXISTS vchord CASCADE;"
20
+ CREATE_PG_TOKENIZER = "CREATE EXTENSION IF NOT EXISTS pg_tokenizer CASCADE;"
21
+ CREATE_VCHORD_BM25 = "CREATE EXTENSION IF NOT EXISTS vchord_bm25 CASCADE;"
22
+ SET_SEARCH_PATH = """
23
+ SET search_path TO
24
+ "$user", public, bm25_catalog, pg_catalog, information_schema, tokenizer_catalog;
25
+ """
26
+ CREATE_BM25_TABLE = f"""
27
+ CREATE TABLE IF NOT EXISTS {TABLE_NAME} (
28
+ id SERIAL PRIMARY KEY,
29
+ snippet_id BIGINT NOT NULL,
30
+ passage TEXT NOT NULL,
31
+ embedding bm25vector,
32
+ UNIQUE(snippet_id)
33
+ )
34
+ """
35
+
36
+ CREATE_BM25_INDEX = f"""
37
+ CREATE INDEX IF NOT EXISTS {INDEX_NAME}
38
+ ON {TABLE_NAME}
39
+ USING bm25 (embedding bm25_ops)
40
+ """
41
+ TOKENIZER_NAME_CHECK_QUERY = (
42
+ f"SELECT 1 FROM tokenizer_catalog.tokenizer WHERE name = '{TOKENIZER_NAME}'" # noqa: S608
43
+ )
44
+ LOAD_TOKENIZER = """
45
+ SELECT create_tokenizer('bert', $$
46
+ model = "llmlingua2"
47
+ pre_tokenizer = "unicode_segmentation" # Unicode Standard Annex #29
48
+ [[character_filters]]
49
+ to_lowercase = {} # convert all characters to lowercase
50
+ [[character_filters]]
51
+ unicode_normalization = "nfkd" # Unicode Normalization Form KD
52
+ [[token_filters]]
53
+ skip_non_alphanumeric = {} # remove non-alphanumeric tokens
54
+ [[token_filters]]
55
+ stopwords = "nltk_english" # remove stopwords using the nltk dictionary
56
+ [[token_filters]]
57
+ stemmer = "english_porter2" # stem tokens using the English Porter2 stemmer
58
+ $$)
59
+ """
60
+ INSERT_QUERY = f"""
61
+ INSERT INTO {TABLE_NAME} (snippet_id, passage)
62
+ VALUES (:snippet_id, :passage)
63
+ ON CONFLICT (snippet_id) DO UPDATE
64
+ SET passage = EXCLUDED.passage
65
+ """ # noqa: S608
66
+ UPDATE_QUERY = f"""
67
+ UPDATE {TABLE_NAME}
68
+ SET embedding = tokenize(passage, '{TOKENIZER_NAME}')
69
+ """ # noqa: S608
70
+ SEARCH_QUERY = f"""
71
+ SELECT
72
+ snippet_id,
73
+ embedding <&>
74
+ to_bm25query('{INDEX_NAME}', tokenize(:query_text, '{TOKENIZER_NAME}'))
75
+ AS bm25_score
76
+ FROM {TABLE_NAME}
77
+ ORDER BY bm25_score
78
+ LIMIT :limit
79
+ """ # noqa: S608
80
+ DELETE_QUERY = f"""
81
+ DELETE FROM {TABLE_NAME}
82
+ WHERE snippet_id IN :snippet_ids
83
+ """ # noqa: S608
84
+
85
+
86
+ class VectorChordBM25(KeywordSearchProvider):
87
+ """BM25 using VectorChord."""
88
+
89
+ def __init__(
90
+ self,
91
+ session: AsyncSession,
92
+ ) -> None:
93
+ """Initialize the VectorChord BM25."""
94
+ self.__session = session
95
+ self._initialized = False
96
+
97
+ async def _initialize(self) -> None:
98
+ """Initialize the VectorChord environment."""
99
+ try:
100
+ await self._create_extensions()
101
+ await self._create_tokenizer_if_not_exists()
102
+ await self._create_tables()
103
+ self._initialized = True
104
+ except Exception as e:
105
+ msg = f"Failed to initialize VectorChord repository: {e}"
106
+ raise RuntimeError(msg) from e
107
+
108
+ async def _create_extensions(self) -> None:
109
+ """Create the necessary extensions."""
110
+ await self.__session.execute(text(CREATE_VCHORD_EXTENSION))
111
+ await self.__session.execute(text(CREATE_PG_TOKENIZER))
112
+ await self.__session.execute(text(CREATE_VCHORD_BM25))
113
+ await self.__session.execute(text(SET_SEARCH_PATH))
114
+ await self._commit()
115
+
116
+ async def _create_tokenizer_if_not_exists(self) -> None:
117
+ """Create the tokenizer if it doesn't exist."""
118
+ # Check if tokenizer exists in the catalog
119
+ result = await self.__session.execute(text(TOKENIZER_NAME_CHECK_QUERY))
120
+ if result.scalar_one_or_none() is None:
121
+ # Tokenizer doesn't exist, create it
122
+ await self.__session.execute(text(LOAD_TOKENIZER))
123
+ await self._commit()
124
+
125
+ async def _create_tables(self) -> None:
126
+ """Create the necessary tables in the correct order."""
127
+ await self.__session.execute(text(CREATE_BM25_TABLE))
128
+ await self.__session.execute(text(CREATE_BM25_INDEX))
129
+ await self._commit()
130
+
131
+ async def _execute(
132
+ self, query: TextClause, param_list: list[Any] | dict[str, Any] | None = None
133
+ ) -> Result:
134
+ """Execute a query."""
135
+ if not self._initialized:
136
+ await self._initialize()
137
+ return await self.__session.execute(query, param_list)
138
+
139
+ async def _commit(self) -> None:
140
+ """Commit the session."""
141
+ await self.__session.commit()
142
+
143
+ async def index(self, corpus: list[BM25Document]) -> None:
144
+ """Index a new corpus."""
145
+ # Filter out any documents that don't have a snippet_id or text
146
+ corpus = [
147
+ doc
148
+ for doc in corpus
149
+ if doc.snippet_id is not None and doc.text is not None and doc.text != ""
150
+ ]
151
+
152
+ if not corpus:
153
+ return
154
+
155
+ # Execute inserts
156
+ await self._execute(
157
+ text(INSERT_QUERY),
158
+ [{"snippet_id": doc.snippet_id, "passage": doc.text} for doc in corpus],
159
+ )
160
+
161
+ # Tokenize the new documents with schema qualification
162
+ await self._execute(text(UPDATE_QUERY))
163
+ await self._commit()
164
+
165
+ async def delete(self, snippet_ids: list[int]) -> None:
166
+ """Delete documents from the index."""
167
+ await self._execute(
168
+ text(DELETE_QUERY).bindparams(bindparam("snippet_ids", expanding=True)),
169
+ {"snippet_ids": snippet_ids},
170
+ )
171
+ await self._commit()
172
+
173
+ async def retrieve(
174
+ self,
175
+ query: str,
176
+ top_k: int = 10,
177
+ ) -> list[BM25Result]:
178
+ """Search documents using BM25 similarity."""
179
+ if not query or query == "":
180
+ return []
181
+
182
+ sql = text(SEARCH_QUERY).bindparams(query_text=query, limit=top_k)
183
+ try:
184
+ result = await self._execute(sql)
185
+ rows = result.mappings().all()
186
+
187
+ return [
188
+ BM25Result(snippet_id=row["snippet_id"], score=row["bm25_score"])
189
+ for row in rows
190
+ ]
191
+ except Exception as e:
192
+ msg = f"Error during BM25 search: {e}"
193
+ raise RuntimeError(msg) from e
kodit/cli.py CHANGED
@@ -10,17 +10,17 @@ import uvicorn
10
10
  from pytable_formatter import Cell, Table
11
11
  from sqlalchemy.ext.asyncio import AsyncSession
12
12
 
13
+ from kodit.bm25.keyword_search_factory import keyword_search_factory
13
14
  from kodit.config import (
14
15
  AppContext,
15
16
  with_app_context,
16
17
  with_session,
17
18
  )
18
- from kodit.embedding.embedding import embedding_factory
19
+ from kodit.embedding.embedding_factory import embedding_factory
20
+ from kodit.enrichment.enrichment_factory import enrichment_factory
19
21
  from kodit.indexing.indexing_repository import IndexRepository
20
- from kodit.indexing.indexing_service import IndexService
22
+ from kodit.indexing.indexing_service import IndexService, SearchRequest
21
23
  from kodit.log import configure_logging, configure_telemetry, log_event
22
- from kodit.search.search_repository import SearchRepository
23
- from kodit.search.search_service import SearchRequest, SearchService
24
24
  from kodit.source.source_repository import SourceRepository
25
25
  from kodit.source.source_service import SourceService
26
26
 
@@ -68,10 +68,16 @@ async def index(
68
68
  source_service = SourceService(app_context.get_clone_dir(), source_repository)
69
69
  repository = IndexRepository(session)
70
70
  service = IndexService(
71
- repository,
72
- source_service,
73
- app_context.get_data_dir(),
74
- embedding_service=embedding_factory(app_context.get_default_openai_client()),
71
+ repository=repository,
72
+ source_service=source_service,
73
+ keyword_search_provider=keyword_search_factory(app_context, session),
74
+ code_search_service=embedding_factory(
75
+ task_name="code", app_context=app_context, session=session
76
+ ),
77
+ text_search_service=embedding_factory(
78
+ task_name="text", app_context=app_context, session=session
79
+ ),
80
+ enrichment_service=enrichment_factory(app_context),
75
81
  )
76
82
 
77
83
  if not sources:
@@ -128,11 +134,20 @@ async def code(
128
134
 
129
135
  This works best if your query is code.
130
136
  """
131
- repository = SearchRepository(session)
132
- service = SearchService(
133
- repository,
134
- app_context.get_data_dir(),
135
- embedding_service=embedding_factory(app_context.get_default_openai_client()),
137
+ source_repository = SourceRepository(session)
138
+ source_service = SourceService(app_context.get_clone_dir(), source_repository)
139
+ repository = IndexRepository(session)
140
+ service = IndexService(
141
+ repository=repository,
142
+ source_service=source_service,
143
+ keyword_search_provider=keyword_search_factory(app_context, session),
144
+ code_search_service=embedding_factory(
145
+ task_name="code", app_context=app_context, session=session
146
+ ),
147
+ text_search_service=embedding_factory(
148
+ task_name="text", app_context=app_context, session=session
149
+ ),
150
+ enrichment_service=enrichment_factory(app_context),
136
151
  )
137
152
 
138
153
  snippets = await service.search(SearchRequest(code_query=query, top_k=top_k))
@@ -144,6 +159,7 @@ async def code(
144
159
  for snippet in snippets:
145
160
  click.echo("-" * 80)
146
161
  click.echo(f"{snippet.uri}")
162
+ click.echo(f"Original scores: {snippet.original_scores}")
147
163
  click.echo(snippet.content)
148
164
  click.echo("-" * 80)
149
165
  click.echo()
@@ -161,11 +177,20 @@ async def keyword(
161
177
  top_k: int,
162
178
  ) -> None:
163
179
  """Search for snippets using keyword search."""
164
- repository = SearchRepository(session)
165
- service = SearchService(
166
- repository,
167
- app_context.get_data_dir(),
168
- embedding_service=embedding_factory(app_context.get_default_openai_client()),
180
+ source_repository = SourceRepository(session)
181
+ source_service = SourceService(app_context.get_clone_dir(), source_repository)
182
+ repository = IndexRepository(session)
183
+ service = IndexService(
184
+ repository=repository,
185
+ source_service=source_service,
186
+ keyword_search_provider=keyword_search_factory(app_context, session),
187
+ code_search_service=embedding_factory(
188
+ task_name="code", app_context=app_context, session=session
189
+ ),
190
+ text_search_service=embedding_factory(
191
+ task_name="text", app_context=app_context, session=session
192
+ ),
193
+ enrichment_service=enrichment_factory(app_context),
169
194
  )
170
195
 
171
196
  snippets = await service.search(SearchRequest(keywords=keywords, top_k=top_k))
@@ -177,6 +202,53 @@ async def keyword(
177
202
  for snippet in snippets:
178
203
  click.echo("-" * 80)
179
204
  click.echo(f"{snippet.uri}")
205
+ click.echo(f"Original scores: {snippet.original_scores}")
206
+ click.echo(snippet.content)
207
+ click.echo("-" * 80)
208
+ click.echo()
209
+
210
+
211
+ @search.command()
212
+ @click.argument("query")
213
+ @click.option("--top-k", default=10, help="Number of snippets to retrieve")
214
+ @with_app_context
215
+ @with_session
216
+ async def text(
217
+ session: AsyncSession,
218
+ app_context: AppContext,
219
+ query: str,
220
+ top_k: int,
221
+ ) -> None:
222
+ """Search for snippets using semantic text search.
223
+
224
+ This works best if your query is text.
225
+ """
226
+ source_repository = SourceRepository(session)
227
+ source_service = SourceService(app_context.get_clone_dir(), source_repository)
228
+ repository = IndexRepository(session)
229
+ service = IndexService(
230
+ repository=repository,
231
+ source_service=source_service,
232
+ keyword_search_provider=keyword_search_factory(app_context, session),
233
+ code_search_service=embedding_factory(
234
+ task_name="code", app_context=app_context, session=session
235
+ ),
236
+ text_search_service=embedding_factory(
237
+ task_name="text", app_context=app_context, session=session
238
+ ),
239
+ enrichment_service=enrichment_factory(app_context),
240
+ )
241
+
242
+ snippets = await service.search(SearchRequest(text_query=query, top_k=top_k))
243
+
244
+ if len(snippets) == 0:
245
+ click.echo("No snippets found")
246
+ return
247
+
248
+ for snippet in snippets:
249
+ click.echo("-" * 80)
250
+ click.echo(f"{snippet.uri}")
251
+ click.echo(f"Original scores: {snippet.original_scores}")
180
252
  click.echo(snippet.content)
181
253
  click.echo("-" * 80)
182
254
  click.echo()
@@ -186,28 +258,44 @@ async def keyword(
186
258
  @click.option("--top-k", default=10, help="Number of snippets to retrieve")
187
259
  @click.option("--keywords", required=True, help="Comma separated list of keywords")
188
260
  @click.option("--code", required=True, help="Semantic code search query")
261
+ @click.option("--text", required=True, help="Semantic text search query")
189
262
  @with_app_context
190
263
  @with_session
191
- async def hybrid(
264
+ async def hybrid( # noqa: PLR0913
192
265
  session: AsyncSession,
193
266
  app_context: AppContext,
194
267
  top_k: int,
195
268
  keywords: str,
196
269
  code: str,
270
+ text: str,
197
271
  ) -> None:
198
272
  """Search for snippets using hybrid search."""
199
- repository = SearchRepository(session)
200
- service = SearchService(
201
- repository,
202
- app_context.get_data_dir(),
203
- embedding_service=embedding_factory(app_context.get_default_openai_client()),
273
+ source_repository = SourceRepository(session)
274
+ source_service = SourceService(app_context.get_clone_dir(), source_repository)
275
+ repository = IndexRepository(session)
276
+ service = IndexService(
277
+ repository=repository,
278
+ source_service=source_service,
279
+ keyword_search_provider=keyword_search_factory(app_context, session),
280
+ code_search_service=embedding_factory(
281
+ task_name="code", app_context=app_context, session=session
282
+ ),
283
+ text_search_service=embedding_factory(
284
+ task_name="text", app_context=app_context, session=session
285
+ ),
286
+ enrichment_service=enrichment_factory(app_context),
204
287
  )
205
288
 
206
289
  # Parse keywords into a list of strings
207
290
  keywords_list = [k.strip().lower() for k in keywords.split(",")]
208
291
 
209
292
  snippets = await service.search(
210
- SearchRequest(keywords=keywords_list, code_query=code, top_k=top_k)
293
+ SearchRequest(
294
+ text_query=text,
295
+ keywords=keywords_list,
296
+ code_query=code,
297
+ top_k=top_k,
298
+ )
211
299
  )
212
300
 
213
301
  if len(snippets) == 0:
@@ -217,6 +305,7 @@ async def hybrid(
217
305
  for snippet in snippets:
218
306
  click.echo("-" * 80)
219
307
  click.echo(f"{snippet.uri}")
308
+ click.echo(f"Original scores: {snippet.original_scores}")
220
309
  click.echo(snippet.content)
221
310
  click.echo("-" * 80)
222
311
  click.echo()
kodit/config.py CHANGED
@@ -12,14 +12,12 @@ from pydantic import BaseModel, Field
12
12
  from pydantic_settings import BaseSettings, SettingsConfigDict
13
13
 
14
14
  from kodit.database import Database
15
- from kodit.embedding.embedding import TINY
16
15
 
17
16
  DEFAULT_BASE_DIR = Path.home() / ".kodit"
18
17
  DEFAULT_DB_URL = f"sqlite+aiosqlite:///{DEFAULT_BASE_DIR}/kodit.db"
19
18
  DEFAULT_LOG_LEVEL = "INFO"
20
19
  DEFAULT_LOG_FORMAT = "pretty"
21
20
  DEFAULT_DISABLE_TELEMETRY = False
22
- DEFAULT_EMBEDDING_MODEL_NAME = TINY
23
21
  T = TypeVar("T")
24
22
 
25
23
 
@@ -31,6 +29,12 @@ class Endpoint(BaseModel):
31
29
  base_url: str | None = None
32
30
 
33
31
 
32
+ class Search(BaseModel):
33
+ """Search provides configuration for a search engine."""
34
+
35
+ provider: Literal["sqlite", "vectorchord"] = Field(default="sqlite")
36
+
37
+
34
38
  class AppContext(BaseSettings):
35
39
  """Global context for the kodit project. Provides a shared state for the app."""
36
40
 
@@ -57,6 +61,9 @@ class AppContext(BaseSettings):
57
61
  "(can be overridden by task-specific configuration)."
58
62
  ),
59
63
  )
64
+ default_search: Search = Field(
65
+ default=Search(),
66
+ )
60
67
  _db: Database | None = None
61
68
 
62
69
  def model_post_init(self, _: Any) -> None:
kodit/database.py CHANGED
@@ -27,10 +27,12 @@ class CommonMixin:
27
27
 
28
28
  id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
29
29
  created_at: Mapped[datetime] = mapped_column(
30
- DateTime, default=lambda: datetime.now(UTC)
30
+ DateTime(timezone=True), default=lambda: datetime.now(UTC)
31
31
  )
32
32
  updated_at: Mapped[datetime] = mapped_column(
33
- DateTime, default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC)
33
+ DateTime(timezone=True),
34
+ default=lambda: datetime.now(UTC),
35
+ onupdate=lambda: datetime.now(UTC),
34
36
  )
35
37
 
36
38
 
@@ -0,0 +1,44 @@
1
+ """Embedding service."""
2
+
3
+ from sqlalchemy.ext.asyncio import AsyncSession
4
+
5
+ from kodit.config import AppContext
6
+ from kodit.embedding.embedding_provider.local_embedding_provider import (
7
+ CODE,
8
+ LocalEmbeddingProvider,
9
+ )
10
+ from kodit.embedding.embedding_provider.openai_embedding_provider import (
11
+ OpenAIEmbeddingProvider,
12
+ )
13
+ from kodit.embedding.embedding_repository import EmbeddingRepository
14
+ from kodit.embedding.local_vector_search_service import LocalVectorSearchService
15
+ from kodit.embedding.vector_search_service import (
16
+ VectorSearchService,
17
+ )
18
+ from kodit.embedding.vectorchord_vector_search_service import (
19
+ VectorChordVectorSearchService,
20
+ )
21
+
22
+
23
+ def embedding_factory(
24
+ task_name: str, app_context: AppContext, session: AsyncSession
25
+ ) -> VectorSearchService:
26
+ """Create an embedding service."""
27
+ embedding_repository = EmbeddingRepository(session=session)
28
+ embedding_provider = None
29
+ openai_client = app_context.get_default_openai_client()
30
+ if openai_client is not None:
31
+ embedding_provider = OpenAIEmbeddingProvider(openai_client=openai_client)
32
+ else:
33
+ embedding_provider = LocalEmbeddingProvider(CODE)
34
+
35
+ if app_context.default_search.provider == "vectorchord":
36
+ return VectorChordVectorSearchService(task_name, session, embedding_provider)
37
+ if app_context.default_search.provider == "sqlite":
38
+ return LocalVectorSearchService(
39
+ embedding_repository=embedding_repository,
40
+ embedding_provider=embedding_provider,
41
+ )
42
+
43
+ msg = f"Invalid semantic search provider: {app_context.default_search.provider}"
44
+ raise ValueError(msg)
@@ -0,0 +1 @@
1
+ """Embedding module."""