kodit 0.1.15__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic. Click here for more details.

kodit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.1.15'
21
- __version_tuple__ = version_tuple = (0, 1, 15)
20
+ __version__ = version = '0.1.16'
21
+ __version_tuple__ = version_tuple = (0, 1, 16)
kodit/cli.py CHANGED
@@ -17,11 +17,10 @@ from kodit.config import (
17
17
  with_session,
18
18
  )
19
19
  from kodit.embedding.embedding_factory import embedding_factory
20
+ from kodit.enrichment.enrichment_factory import enrichment_factory
20
21
  from kodit.indexing.indexing_repository import IndexRepository
21
- from kodit.indexing.indexing_service import IndexService
22
+ from kodit.indexing.indexing_service import IndexService, SearchRequest
22
23
  from kodit.log import configure_logging, configure_telemetry, log_event
23
- from kodit.search.search_repository import SearchRepository
24
- from kodit.search.search_service import SearchRequest, SearchService
25
24
  from kodit.source.source_repository import SourceRepository
26
25
  from kodit.source.source_service import SourceService
27
26
 
@@ -72,9 +71,13 @@ async def index(
72
71
  repository=repository,
73
72
  source_service=source_service,
74
73
  keyword_search_provider=keyword_search_factory(app_context, session),
75
- vector_search_service=embedding_factory(
76
- app_context=app_context, session=session
74
+ code_search_service=embedding_factory(
75
+ task_name="code", app_context=app_context, session=session
77
76
  ),
77
+ text_search_service=embedding_factory(
78
+ task_name="text", app_context=app_context, session=session
79
+ ),
80
+ enrichment_service=enrichment_factory(app_context),
78
81
  )
79
82
 
80
83
  if not sources:
@@ -131,11 +134,20 @@ async def code(
131
134
 
132
135
  This works best if your query is code.
133
136
  """
134
- repository = SearchRepository(session)
135
- service = SearchService(
136
- repository,
137
+ source_repository = SourceRepository(session)
138
+ source_service = SourceService(app_context.get_clone_dir(), source_repository)
139
+ repository = IndexRepository(session)
140
+ service = IndexService(
141
+ repository=repository,
142
+ source_service=source_service,
137
143
  keyword_search_provider=keyword_search_factory(app_context, session),
138
- embedding_service=embedding_factory(app_context=app_context, session=session),
144
+ code_search_service=embedding_factory(
145
+ task_name="code", app_context=app_context, session=session
146
+ ),
147
+ text_search_service=embedding_factory(
148
+ task_name="text", app_context=app_context, session=session
149
+ ),
150
+ enrichment_service=enrichment_factory(app_context),
139
151
  )
140
152
 
141
153
  snippets = await service.search(SearchRequest(code_query=query, top_k=top_k))
@@ -147,6 +159,7 @@ async def code(
147
159
  for snippet in snippets:
148
160
  click.echo("-" * 80)
149
161
  click.echo(f"{snippet.uri}")
162
+ click.echo(f"Original scores: {snippet.original_scores}")
150
163
  click.echo(snippet.content)
151
164
  click.echo("-" * 80)
152
165
  click.echo()
@@ -164,11 +177,20 @@ async def keyword(
164
177
  top_k: int,
165
178
  ) -> None:
166
179
  """Search for snippets using keyword search."""
167
- repository = SearchRepository(session)
168
- service = SearchService(
169
- repository,
180
+ source_repository = SourceRepository(session)
181
+ source_service = SourceService(app_context.get_clone_dir(), source_repository)
182
+ repository = IndexRepository(session)
183
+ service = IndexService(
184
+ repository=repository,
185
+ source_service=source_service,
170
186
  keyword_search_provider=keyword_search_factory(app_context, session),
171
- embedding_service=embedding_factory(app_context=app_context, session=session),
187
+ code_search_service=embedding_factory(
188
+ task_name="code", app_context=app_context, session=session
189
+ ),
190
+ text_search_service=embedding_factory(
191
+ task_name="text", app_context=app_context, session=session
192
+ ),
193
+ enrichment_service=enrichment_factory(app_context),
172
194
  )
173
195
 
174
196
  snippets = await service.search(SearchRequest(keywords=keywords, top_k=top_k))
@@ -180,6 +202,53 @@ async def keyword(
180
202
  for snippet in snippets:
181
203
  click.echo("-" * 80)
182
204
  click.echo(f"{snippet.uri}")
205
+ click.echo(f"Original scores: {snippet.original_scores}")
206
+ click.echo(snippet.content)
207
+ click.echo("-" * 80)
208
+ click.echo()
209
+
210
+
211
+ @search.command()
212
+ @click.argument("query")
213
+ @click.option("--top-k", default=10, help="Number of snippets to retrieve")
214
+ @with_app_context
215
+ @with_session
216
+ async def text(
217
+ session: AsyncSession,
218
+ app_context: AppContext,
219
+ query: str,
220
+ top_k: int,
221
+ ) -> None:
222
+ """Search for snippets using semantic text search.
223
+
224
+ This works best if your query is text.
225
+ """
226
+ source_repository = SourceRepository(session)
227
+ source_service = SourceService(app_context.get_clone_dir(), source_repository)
228
+ repository = IndexRepository(session)
229
+ service = IndexService(
230
+ repository=repository,
231
+ source_service=source_service,
232
+ keyword_search_provider=keyword_search_factory(app_context, session),
233
+ code_search_service=embedding_factory(
234
+ task_name="code", app_context=app_context, session=session
235
+ ),
236
+ text_search_service=embedding_factory(
237
+ task_name="text", app_context=app_context, session=session
238
+ ),
239
+ enrichment_service=enrichment_factory(app_context),
240
+ )
241
+
242
+ snippets = await service.search(SearchRequest(text_query=query, top_k=top_k))
243
+
244
+ if len(snippets) == 0:
245
+ click.echo("No snippets found")
246
+ return
247
+
248
+ for snippet in snippets:
249
+ click.echo("-" * 80)
250
+ click.echo(f"{snippet.uri}")
251
+ click.echo(f"Original scores: {snippet.original_scores}")
183
252
  click.echo(snippet.content)
184
253
  click.echo("-" * 80)
185
254
  click.echo()
@@ -189,28 +258,44 @@ async def keyword(
189
258
  @click.option("--top-k", default=10, help="Number of snippets to retrieve")
190
259
  @click.option("--keywords", required=True, help="Comma separated list of keywords")
191
260
  @click.option("--code", required=True, help="Semantic code search query")
261
+ @click.option("--text", required=True, help="Semantic text search query")
192
262
  @with_app_context
193
263
  @with_session
194
- async def hybrid(
264
+ async def hybrid( # noqa: PLR0913
195
265
  session: AsyncSession,
196
266
  app_context: AppContext,
197
267
  top_k: int,
198
268
  keywords: str,
199
269
  code: str,
270
+ text: str,
200
271
  ) -> None:
201
272
  """Search for snippets using hybrid search."""
202
- repository = SearchRepository(session)
203
- service = SearchService(
204
- repository,
273
+ source_repository = SourceRepository(session)
274
+ source_service = SourceService(app_context.get_clone_dir(), source_repository)
275
+ repository = IndexRepository(session)
276
+ service = IndexService(
277
+ repository=repository,
278
+ source_service=source_service,
205
279
  keyword_search_provider=keyword_search_factory(app_context, session),
206
- embedding_service=embedding_factory(app_context=app_context, session=session),
280
+ code_search_service=embedding_factory(
281
+ task_name="code", app_context=app_context, session=session
282
+ ),
283
+ text_search_service=embedding_factory(
284
+ task_name="text", app_context=app_context, session=session
285
+ ),
286
+ enrichment_service=enrichment_factory(app_context),
207
287
  )
208
288
 
209
289
  # Parse keywords into a list of strings
210
290
  keywords_list = [k.strip().lower() for k in keywords.split(",")]
211
291
 
212
292
  snippets = await service.search(
213
- SearchRequest(keywords=keywords_list, code_query=code, top_k=top_k)
293
+ SearchRequest(
294
+ text_query=text,
295
+ keywords=keywords_list,
296
+ code_query=code,
297
+ top_k=top_k,
298
+ )
214
299
  )
215
300
 
216
301
  if len(snippets) == 0:
@@ -220,6 +305,7 @@ async def hybrid(
220
305
  for snippet in snippets:
221
306
  click.echo("-" * 80)
222
307
  click.echo(f"{snippet.uri}")
308
+ click.echo(f"Original scores: {snippet.original_scores}")
223
309
  click.echo(snippet.content)
224
310
  click.echo("-" * 80)
225
311
  click.echo()
@@ -21,7 +21,7 @@ from kodit.embedding.vectorchord_vector_search_service import (
21
21
 
22
22
 
23
23
  def embedding_factory(
24
- app_context: AppContext, session: AsyncSession
24
+ task_name: str, app_context: AppContext, session: AsyncSession
25
25
  ) -> VectorSearchService:
26
26
  """Create an embedding service."""
27
27
  embedding_repository = EmbeddingRepository(session=session)
@@ -33,7 +33,7 @@ def embedding_factory(
33
33
  embedding_provider = LocalEmbeddingProvider(CODE)
34
34
 
35
35
  if app_context.default_search.provider == "vectorchord":
36
- return VectorChordVectorSearchService(session, embedding_provider)
36
+ return VectorChordVectorSearchService(task_name, session, embedding_provider)
37
37
  if app_context.default_search.provider == "sqlite":
38
38
  return LocalVectorSearchService(
39
39
  embedding_repository=embedding_repository,
@@ -38,8 +38,15 @@ def split_sub_batches(encoding: tiktoken.Encoding, data: list[str]) -> list[list
38
38
  item_tokens = len(encoding.encode(next_item))
39
39
 
40
40
  if item_tokens > OPENAI_MAX_EMBEDDING_SIZE:
41
- log.warning("Skipping too long snippet", snippet=data_to_process.pop(0))
42
- continue
41
+ # Loop around trying to truncate the snippet until it fits in the max
42
+ # embedding size
43
+ while item_tokens > OPENAI_MAX_EMBEDDING_SIZE:
44
+ next_item = next_item[:-1]
45
+ item_tokens = len(encoding.encode(next_item))
46
+
47
+ data_to_process[0] = next_item
48
+
49
+ log.warning("Truncated snippet", snippet=next_item)
43
50
 
44
51
  if current_tokens + item_tokens > OPENAI_MAX_EMBEDDING_SIZE:
45
52
  break
@@ -38,26 +38,38 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
38
38
  # Process batches in parallel with a semaphore to limit concurrent requests
39
39
  sem = asyncio.Semaphore(OPENAI_NUM_PARALLEL_TASKS)
40
40
 
41
- async def process_batch(batch: list[str]) -> list[Vector]:
41
+ # Create a list of tuples with a temporary id for each batch
42
+ # We need to do this so that we can return the results in the same order as the
43
+ # input data
44
+ input_data = [(i, batch) for i, batch in enumerate(batched_data)]
45
+
46
+ async def process_batch(
47
+ data: tuple[int, list[str]],
48
+ ) -> tuple[int, list[Vector]]:
49
+ batch_id, batch = data
42
50
  async with sem:
43
51
  try:
44
52
  response = await self.openai_client.embeddings.create(
45
53
  model=self.model_name,
46
54
  input=batch,
47
55
  )
48
- return [
56
+ return batch_id, [
49
57
  [float(x) for x in embedding.embedding]
50
58
  for embedding in response.data
51
59
  ]
52
60
  except Exception as e:
53
61
  self.log.exception("Error embedding batch", error=str(e))
54
- return []
62
+ return batch_id, []
55
63
 
56
64
  # Create tasks for all batches
57
- tasks = [process_batch(batch) for batch in batched_data]
65
+ tasks = [process_batch(batch) for batch in input_data]
58
66
 
59
67
  # Process all batches and yield results as they complete
60
- results: list[Vector] = []
68
+ results: list[tuple[int, list[Vector]]] = []
61
69
  for task in asyncio.as_completed(tasks):
62
- results.extend(await task)
63
- return results
70
+ result = await task
71
+ results.append(result)
72
+
73
+ # Output in the same order as the input data
74
+ ordered_results = [result for _, result in sorted(results, key=lambda x: x[0])]
75
+ return [item for sublist in ordered_results for item in sublist]
@@ -12,23 +12,20 @@ from kodit.embedding.vector_search_service import (
12
12
  VectorSearchService,
13
13
  )
14
14
 
15
- TABLE_NAME = "vectorchord_embeddings"
16
- INDEX_NAME = f"{TABLE_NAME}_idx"
17
-
18
15
  # SQL Queries
19
16
  CREATE_VCHORD_EXTENSION = """
20
17
  CREATE EXTENSION IF NOT EXISTS vchord CASCADE;
21
18
  """
22
19
 
23
- CHECK_VCHORD_EMBEDDING_DIMENSION = f"""
20
+ CHECK_VCHORD_EMBEDDING_DIMENSION = """
24
21
  SELECT a.atttypmod as dimension
25
22
  FROM pg_attribute a
26
23
  JOIN pg_class c ON a.attrelid = c.oid
27
24
  WHERE c.relname = '{TABLE_NAME}'
28
25
  AND a.attname = 'embedding';
29
- """ # noqa: S608
26
+ """
30
27
 
31
- CREATE_VCHORD_INDEX = f"""
28
+ CREATE_VCHORD_INDEX = """
32
29
  CREATE INDEX IF NOT EXISTS {INDEX_NAME}
33
30
  ON {TABLE_NAME}
34
31
  USING vchordrq (embedding vector_l2_ops) WITH (options = $$
@@ -38,21 +35,21 @@ lists = []
38
35
  $$);
39
36
  """
40
37
 
41
- INSERT_QUERY = f"""
38
+ INSERT_QUERY = """
42
39
  INSERT INTO {TABLE_NAME} (snippet_id, embedding)
43
40
  VALUES (:snippet_id, :embedding)
44
41
  ON CONFLICT (snippet_id) DO UPDATE
45
42
  SET embedding = EXCLUDED.embedding
46
- """ # noqa: S608
43
+ """
47
44
 
48
45
  # Note that <=> in vectorchord is cosine distance
49
46
  # So scores go from 0 (similar) to 2 (opposite)
50
- SEARCH_QUERY = f"""
47
+ SEARCH_QUERY = """
51
48
  SELECT snippet_id, embedding <=> :query as score
52
49
  FROM {TABLE_NAME}
53
50
  ORDER BY score ASC
54
51
  LIMIT :top_k;
55
- """ # noqa: S608
52
+ """
56
53
 
57
54
 
58
55
  class VectorChordVectorSearchService(VectorSearchService):
@@ -60,6 +57,7 @@ class VectorChordVectorSearchService(VectorSearchService):
60
57
 
61
58
  def __init__(
62
59
  self,
60
+ task_name: str,
63
61
  session: AsyncSession,
64
62
  embedding_provider: EmbeddingProvider,
65
63
  ) -> None:
@@ -67,6 +65,8 @@ class VectorChordVectorSearchService(VectorSearchService):
67
65
  self.embedding_provider = embedding_provider
68
66
  self._session = session
69
67
  self._initialized = False
68
+ self.table_name = f"vectorchord_{task_name}_embeddings"
69
+ self.index_name = f"{self.table_name}_idx"
70
70
 
71
71
  async def _initialize(self) -> None:
72
72
  """Initialize the VectorChord environment."""
@@ -88,15 +88,23 @@ class VectorChordVectorSearchService(VectorSearchService):
88
88
  vector_dim = (await self.embedding_provider.embed(["dimension"]))[0]
89
89
  await self._session.execute(
90
90
  text(
91
- f"""CREATE TABLE IF NOT EXISTS {TABLE_NAME} (
91
+ f"""CREATE TABLE IF NOT EXISTS {self.table_name} (
92
92
  id SERIAL PRIMARY KEY,
93
93
  snippet_id INT NOT NULL UNIQUE,
94
94
  embedding VECTOR({len(vector_dim)}) NOT NULL
95
95
  );"""
96
96
  )
97
97
  )
98
- await self._session.execute(text(CREATE_VCHORD_INDEX))
99
- result = await self._session.execute(text(CHECK_VCHORD_EMBEDDING_DIMENSION))
98
+ await self._session.execute(
99
+ text(
100
+ CREATE_VCHORD_INDEX.format(
101
+ TABLE_NAME=self.table_name, INDEX_NAME=self.index_name
102
+ )
103
+ )
104
+ )
105
+ result = await self._session.execute(
106
+ text(CHECK_VCHORD_EMBEDDING_DIMENSION.format(TABLE_NAME=self.table_name))
107
+ )
100
108
  vector_dim_from_db = result.scalar_one()
101
109
  if vector_dim_from_db != len(vector_dim):
102
110
  msg = (
@@ -123,7 +131,7 @@ class VectorChordVectorSearchService(VectorSearchService):
123
131
  embeddings = await self.embedding_provider.embed([doc.text for doc in data])
124
132
  # Execute inserts
125
133
  await self._execute(
126
- text(INSERT_QUERY),
134
+ text(INSERT_QUERY.format(TABLE_NAME=self.table_name)),
127
135
  [
128
136
  {"snippet_id": doc.snippet_id, "embedding": str(embedding)}
129
137
  for doc, embedding in zip(data, embeddings, strict=True)
@@ -135,7 +143,8 @@ class VectorChordVectorSearchService(VectorSearchService):
135
143
  """Query the embedding model."""
136
144
  embedding = await self.embedding_provider.embed([query])
137
145
  result = await self._execute(
138
- text(SEARCH_QUERY), {"query": str(embedding[0]), "top_k": top_k}
146
+ text(SEARCH_QUERY.format(TABLE_NAME=self.table_name)),
147
+ {"query": str(embedding[0]), "top_k": top_k},
139
148
  )
140
149
  rows = result.mappings().all()
141
150
 
@@ -0,0 +1 @@
1
+ """Enrichment."""
@@ -0,0 +1,23 @@
1
+ """Enrichment factory."""
2
+
3
+ from kodit.config import AppContext
4
+ from kodit.enrichment.enrichment_provider.local_enrichment_provider import (
5
+ LocalEnrichmentProvider,
6
+ )
7
+ from kodit.enrichment.enrichment_provider.openai_enrichment_provider import (
8
+ OpenAIEnrichmentProvider,
9
+ )
10
+ from kodit.enrichment.enrichment_service import (
11
+ EnrichmentService,
12
+ LLMEnrichmentService,
13
+ )
14
+
15
+
16
+ def enrichment_factory(app_context: AppContext) -> EnrichmentService:
17
+ """Create an enrichment service."""
18
+ openai_client = app_context.get_default_openai_client()
19
+ if openai_client is not None:
20
+ enrichment_provider = OpenAIEnrichmentProvider(openai_client=openai_client)
21
+ return LLMEnrichmentService(enrichment_provider)
22
+
23
+ return LLMEnrichmentService(LocalEnrichmentProvider())
@@ -0,0 +1 @@
1
+ """Enrichment provider."""
@@ -0,0 +1,16 @@
1
+ """Enrichment provider."""
2
+
3
+ from abc import ABC, abstractmethod
4
+
5
+ ENRICHMENT_SYSTEM_PROMPT = """
6
+ You are a professional software developer. You will be given a snippet of code.
7
+ Please provide a concise explanation of the code.
8
+ """
9
+
10
+
11
+ class EnrichmentProvider(ABC):
12
+ """Enrichment provider."""
13
+
14
+ @abstractmethod
15
+ async def enrich(self, data: list[str]) -> list[str]:
16
+ """Enrich a list of strings."""
@@ -0,0 +1,63 @@
1
+ """Local enrichment service."""
2
+
3
+ import os
4
+
5
+ import structlog
6
+ from transformers.models.auto.modeling_auto import AutoModelForCausalLM
7
+ from transformers.models.auto.tokenization_auto import AutoTokenizer
8
+
9
+ from kodit.enrichment.enrichment_provider.enrichment_provider import (
10
+ ENRICHMENT_SYSTEM_PROMPT,
11
+ EnrichmentProvider,
12
+ )
13
+
14
+
15
+ class LocalEnrichmentProvider(EnrichmentProvider):
16
+ """Local embedder."""
17
+
18
+ def __init__(self, model_name: str = "Qwen/Qwen3-0.6B") -> None:
19
+ """Initialize the local enrichment provider."""
20
+ self.log = structlog.get_logger(__name__)
21
+ self.model_name = model_name
22
+ self.model = None
23
+ self.tokenizer = None
24
+
25
+ async def enrich(self, data: list[str]) -> list[str]:
26
+ """Enrich a list of strings."""
27
+ if self.tokenizer is None:
28
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
29
+ if self.model is None:
30
+ os.environ["TOKENIZERS_PARALLELISM"] = "false" # Avoid warnings
31
+ self.model = AutoModelForCausalLM.from_pretrained(
32
+ self.model_name,
33
+ torch_dtype="auto",
34
+ trust_remote_code=True,
35
+ )
36
+
37
+ results = []
38
+ for snippet in data:
39
+ # prepare the model input
40
+ messages = [
41
+ {"role": "system", "content": ENRICHMENT_SYSTEM_PROMPT},
42
+ {"role": "user", "content": snippet},
43
+ ]
44
+ text = self.tokenizer.apply_chat_template(
45
+ messages,
46
+ tokenize=False,
47
+ add_generation_prompt=True,
48
+ enable_thinking=False,
49
+ )
50
+ model_inputs = self.tokenizer([text], return_tensors="pt").to(
51
+ self.model.device
52
+ )
53
+
54
+ # conduct text completion
55
+ generated_ids = self.model.generate(**model_inputs, max_new_tokens=32768)
56
+ output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
57
+ content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip(
58
+ "\n"
59
+ )
60
+
61
+ results.append(content)
62
+
63
+ return results
@@ -0,0 +1,77 @@
1
+ """OpenAI enrichment service."""
2
+
3
+ import asyncio
4
+
5
+ import structlog
6
+ import tiktoken
7
+ from openai import AsyncOpenAI
8
+ from tqdm import tqdm
9
+
10
+ from kodit.enrichment.enrichment_provider.enrichment_provider import (
11
+ ENRICHMENT_SYSTEM_PROMPT,
12
+ EnrichmentProvider,
13
+ )
14
+
15
+ OPENAI_NUM_PARALLEL_TASKS = 10
16
+
17
+
18
+ class OpenAIEnrichmentProvider(EnrichmentProvider):
19
+ """OpenAI enrichment provider."""
20
+
21
+ def __init__(
22
+ self,
23
+ openai_client: AsyncOpenAI,
24
+ model_name: str = "gpt-4o-mini",
25
+ ) -> None:
26
+ """Initialize the OpenAI enrichment provider."""
27
+ self.log = structlog.get_logger(__name__)
28
+ self.openai_client = openai_client
29
+ self.model_name = model_name
30
+ self.encoding = tiktoken.encoding_for_model(model_name)
31
+
32
+ async def enrich(self, data: list[str]) -> list[str]:
33
+ """Enrich a list of documents."""
34
+ # Process batches in parallel with a semaphore to limit concurrent requests
35
+ sem = asyncio.Semaphore(OPENAI_NUM_PARALLEL_TASKS)
36
+
37
+ # Create a list of tuples with a temporary id for each snippet
38
+ # We need to do this so that we can return the results in the same order as the
39
+ # input data
40
+ input_data = [(i, snippet) for i, snippet in enumerate(data)]
41
+
42
+ async def process_data(data: tuple[int, str]) -> tuple[int, str]:
43
+ snippet_id, snippet = data
44
+ if not snippet:
45
+ return snippet_id, ""
46
+ async with sem:
47
+ try:
48
+ response = await self.openai_client.chat.completions.create(
49
+ model=self.model_name,
50
+ messages=[
51
+ {
52
+ "role": "system",
53
+ "content": ENRICHMENT_SYSTEM_PROMPT,
54
+ },
55
+ {"role": "user", "content": snippet},
56
+ ],
57
+ )
58
+ return snippet_id, response.choices[0].message.content or ""
59
+ except Exception as e:
60
+ self.log.exception("Error enriching data", error=str(e))
61
+ return snippet_id, ""
62
+
63
+ # Create tasks for all data
64
+ tasks = [process_data(snippet) for snippet in input_data]
65
+
66
+ # Process all data and yield results as they complete
67
+ results: list[tuple[int, str]] = []
68
+ for task in tqdm(
69
+ asyncio.as_completed(tasks),
70
+ total=len(tasks),
71
+ leave=False,
72
+ ):
73
+ result = await task
74
+ results.append(result)
75
+
76
+ # Output in the same order as the input data
77
+ return [result for _, result in sorted(results, key=lambda x: x[0])]
@@ -0,0 +1,33 @@
1
+ """Enrichment service."""
2
+
3
+ from abc import ABC, abstractmethod
4
+
5
+ from kodit.enrichment.enrichment_provider.enrichment_provider import EnrichmentProvider
6
+
7
+
8
+ class EnrichmentService(ABC):
9
+ """Enrichment service."""
10
+
11
+ @abstractmethod
12
+ async def enrich(self, data: list[str]) -> list[str]:
13
+ """Enrich a list of strings."""
14
+
15
+
16
+ class NullEnrichmentService(EnrichmentService):
17
+ """Null enrichment service."""
18
+
19
+ async def enrich(self, data: list[str]) -> list[str]:
20
+ """Enrich a list of strings."""
21
+ return [""] * len(data)
22
+
23
+
24
+ class LLMEnrichmentService(EnrichmentService):
25
+ """Enrichment service using an LLM."""
26
+
27
+ def __init__(self, enrichment_provider: EnrichmentProvider) -> None:
28
+ """Initialize the enrichment service."""
29
+ self.enrichment_provider = enrichment_provider
30
+
31
+ async def enrich(self, data: list[str]) -> list[str]:
32
+ """Enrich a list of strings."""
33
+ return await self.enrichment_provider.enrich(data)
@@ -0,0 +1,67 @@
1
+ """Fusion functions for combining search results."""
2
+
3
+ from collections import defaultdict
4
+ from dataclasses import dataclass
5
+
6
+
7
+ @dataclass
8
+ class FusionResult:
9
+ """Result of a fusion operation."""
10
+
11
+ id: int
12
+ score: float
13
+ original_scores: list[float]
14
+
15
+
16
+ @dataclass
17
+ class FusionRequest:
18
+ """Request for an RRF operation."""
19
+
20
+ id: int
21
+ score: float
22
+
23
+
24
+ def reciprocal_rank_fusion(
25
+ rankings: list[list[FusionRequest]], k: float = 60
26
+ ) -> list[FusionResult]:
27
+ """RRF prioritises results that are present in all results.
28
+
29
+ Args:
30
+ rankings: List of rankers, each containing a list of document ids. Top of the
31
+ list is considered to be the best result.
32
+ k: Parameter for RRF.
33
+
34
+ Returns:
35
+ Dictionary of ids and their scores.
36
+
37
+ """
38
+ scores = {}
39
+ for ranker in rankings:
40
+ for rank in ranker:
41
+ scores[rank.id] = float(0)
42
+
43
+ for ranker in rankings:
44
+ for i, rank in enumerate(ranker):
45
+ scores[rank.id] += 1.0 / (k + i)
46
+
47
+ # Create a list of tuples of ids and their scores
48
+ results = [(rank, scores[rank]) for rank in scores]
49
+
50
+ # Sort results by score
51
+ results.sort(key=lambda x: x[1], reverse=True)
52
+
53
+ # Create a map of original scores to ids
54
+ original_scores_to_ids = defaultdict(list)
55
+ for ranker in rankings:
56
+ for rank in ranker:
57
+ original_scores_to_ids[rank.id].append(rank.score)
58
+
59
+ # Rebuild a list of final results with their original scores
60
+ return [
61
+ FusionResult(
62
+ id=result[0],
63
+ score=result[1],
64
+ original_scores=original_scores_to_ids[result[0]],
65
+ )
66
+ for result in results
67
+ ]
@@ -196,3 +196,23 @@ class IndexRepository:
196
196
  """
197
197
  self.session.add(embedding)
198
198
  await self.session.commit()
199
+
200
+ async def list_snippets_by_ids(self, ids: list[int]) -> list[tuple[File, Snippet]]:
201
+ """List snippets by IDs.
202
+
203
+ Returns:
204
+ A list of snippets in the same order as the input IDs.
205
+
206
+ """
207
+ query = (
208
+ select(Snippet, File)
209
+ .where(Snippet.id.in_(ids))
210
+ .join(File, Snippet.file_id == File.id)
211
+ )
212
+ rows = await self.session.execute(query)
213
+
214
+ # Create a dictionary for O(1) lookup of results by ID
215
+ id_to_result = {snippet.id: (file, snippet) for snippet, file in rows.all()}
216
+
217
+ # Return results in the same order as input IDs
218
+ return [id_to_result[i] for i in ids]
@@ -13,11 +13,17 @@ import pydantic
13
13
  import structlog
14
14
  from tqdm.asyncio import tqdm
15
15
 
16
- from kodit.bm25.keyword_search_service import BM25Document, KeywordSearchProvider
16
+ from kodit.bm25.keyword_search_service import (
17
+ BM25Document,
18
+ BM25Result,
19
+ KeywordSearchProvider,
20
+ )
17
21
  from kodit.embedding.vector_search_service import (
18
22
  VectorSearchRequest,
19
23
  VectorSearchService,
20
24
  )
25
+ from kodit.enrichment.enrichment_service import EnrichmentService
26
+ from kodit.indexing.fusion import FusionRequest, reciprocal_rank_fusion
21
27
  from kodit.indexing.indexing_models import Snippet
22
28
  from kodit.indexing.indexing_repository import IndexRepository
23
29
  from kodit.snippets.snippets import SnippetService
@@ -42,6 +48,28 @@ class IndexView(pydantic.BaseModel):
42
48
  num_snippets: int | None = None
43
49
 
44
50
 
51
+ class SearchRequest(pydantic.BaseModel):
52
+ """Request for a search."""
53
+
54
+ text_query: str | None = None
55
+ code_query: str | None = None
56
+ keywords: list[str] | None = None
57
+ top_k: int = 10
58
+
59
+
60
+ class SearchResult(pydantic.BaseModel):
61
+ """Data transfer object for search results.
62
+
63
+ This model represents a single search result, containing both the file path
64
+ and the matching snippet content.
65
+ """
66
+
67
+ id: int
68
+ uri: str
69
+ content: str
70
+ original_scores: list[float]
71
+
72
+
45
73
  class IndexService:
46
74
  """Service for managing code indexes.
47
75
 
@@ -50,12 +78,14 @@ class IndexService:
50
78
  IndexRepository), and provides a clean API for index management.
51
79
  """
52
80
 
53
- def __init__(
81
+ def __init__( # noqa: PLR0913
54
82
  self,
55
83
  repository: IndexRepository,
56
84
  source_service: SourceService,
57
85
  keyword_search_provider: KeywordSearchProvider,
58
- vector_search_service: VectorSearchService,
86
+ code_search_service: VectorSearchService,
87
+ text_search_service: VectorSearchService,
88
+ enrichment_service: EnrichmentService,
59
89
  ) -> None:
60
90
  """Initialize the index service.
61
91
 
@@ -69,7 +99,9 @@ class IndexService:
69
99
  self.snippet_service = SnippetService()
70
100
  self.log = structlog.get_logger(__name__)
71
101
  self.keyword_search_provider = keyword_search_provider
72
- self.code_search_service = vector_search_service
102
+ self.code_search_service = code_search_service
103
+ self.text_search_service = text_search_service
104
+ self.enrichment_service = enrichment_service
73
105
 
74
106
  async def create(self, source_id: int) -> IndexView:
75
107
  """Create a new index for a source.
@@ -152,9 +184,93 @@ class IndexService:
152
184
  ]
153
185
  )
154
186
 
187
+ self.log.info("Enriching snippets")
188
+ enriched_contents = await self.enrichment_service.enrich(
189
+ [snippet.content for snippet in snippets]
190
+ )
191
+
192
+ self.log.info("Creating semantic text index")
193
+ with Spinner():
194
+ await self.text_search_service.index(
195
+ [
196
+ VectorSearchRequest(snippet.id, enriched_content)
197
+ for snippet, enriched_content in zip(
198
+ snippets, enriched_contents, strict=True
199
+ )
200
+ ]
201
+ )
202
+ # Add the enriched text back to the snippets and write to the database
203
+ for snippet, enriched_content in zip(
204
+ snippets, enriched_contents, strict=True
205
+ ):
206
+ snippet.content = (
207
+ enriched_content + "\n\n```\n" + snippet.content + "\n```"
208
+ )
209
+ await self.repository.add_snippet_or_update_content(snippet)
210
+
155
211
  # Update index timestamp
156
212
  await self.repository.update_index_timestamp(index)
157
213
 
214
+ async def search(self, request: SearchRequest) -> list[SearchResult]:
215
+ """Search for relevant data."""
216
+ fusion_list: list[list[FusionRequest]] = []
217
+ if request.keywords:
218
+ # Gather results for each keyword
219
+ result_ids: list[BM25Result] = []
220
+ for keyword in request.keywords:
221
+ results = await self.keyword_search_provider.retrieve(
222
+ keyword, request.top_k
223
+ )
224
+ result_ids.extend(results)
225
+
226
+ fusion_list.append(
227
+ [FusionRequest(id=x.snippet_id, score=x.score) for x in result_ids]
228
+ )
229
+
230
+ # Compute embedding for semantic query
231
+ if request.code_query:
232
+ query_embedding = await self.code_search_service.retrieve(
233
+ request.code_query, top_k=request.top_k
234
+ )
235
+ fusion_list.append(
236
+ [FusionRequest(id=x.snippet_id, score=x.score) for x in query_embedding]
237
+ )
238
+
239
+ if request.text_query:
240
+ query_embedding = await self.text_search_service.retrieve(
241
+ request.text_query, top_k=request.top_k
242
+ )
243
+ fusion_list.append(
244
+ [FusionRequest(id=x.snippet_id, score=x.score) for x in query_embedding]
245
+ )
246
+
247
+ if len(fusion_list) == 0:
248
+ return []
249
+
250
+ # Combine all results together with RFF if required
251
+ final_results = reciprocal_rank_fusion(
252
+ rankings=fusion_list,
253
+ k=60,
254
+ )
255
+
256
+ # Only keep top_k results
257
+ final_results = final_results[: request.top_k]
258
+
259
+ # Get snippets from database (up to top_k)
260
+ search_results = await self.repository.list_snippets_by_ids(
261
+ [x.id for x in final_results]
262
+ )
263
+
264
+ return [
265
+ SearchResult(
266
+ id=snippet.id,
267
+ uri=file.uri,
268
+ content=snippet.content,
269
+ original_scores=fr.original_scores,
270
+ )
271
+ for (file, snippet), fr in zip(search_results, final_results, strict=True)
272
+ ]
273
+
158
274
  async def _create_snippets(
159
275
  self,
160
276
  index_id: int,
kodit/mcp.py CHANGED
@@ -16,8 +16,11 @@ from kodit.bm25.keyword_search_factory import keyword_search_factory
16
16
  from kodit.config import AppContext
17
17
  from kodit.database import Database
18
18
  from kodit.embedding.embedding_factory import embedding_factory
19
- from kodit.search.search_repository import SearchRepository
20
- from kodit.search.search_service import SearchRequest, SearchResult, SearchService
19
+ from kodit.enrichment.enrichment_factory import enrichment_factory
20
+ from kodit.indexing.indexing_repository import IndexRepository
21
+ from kodit.indexing.indexing_service import IndexService, SearchRequest, SearchResult
22
+ from kodit.source.source_repository import SourceRepository
23
+ from kodit.source.source_service import SourceService
21
24
 
22
25
 
23
26
  @dataclass
@@ -123,32 +126,38 @@ async def search(
123
126
 
124
127
  mcp_context: MCPContext = ctx.request_context.lifespan_context
125
128
 
126
- log.debug("Creating search repository")
127
- search_repository = SearchRepository(
128
- session=mcp_context.session,
129
+ source_repository = SourceRepository(mcp_context.session)
130
+ source_service = SourceService(
131
+ mcp_context.app_context.get_clone_dir(), source_repository
129
132
  )
130
-
131
- log.debug("Creating embedding service")
132
- embedding_service = embedding_factory(
133
- app_context=mcp_context.app_context, session=mcp_context.session
134
- )
135
-
136
- log.debug("Creating search service")
137
- search_service = SearchService(
138
- repository=search_repository,
133
+ repository = IndexRepository(mcp_context.session)
134
+ service = IndexService(
135
+ repository=repository,
136
+ source_service=source_service,
139
137
  keyword_search_provider=keyword_search_factory(
138
+ mcp_context.app_context, mcp_context.session
139
+ ),
140
+ code_search_service=embedding_factory(
141
+ task_name="code",
140
142
  app_context=mcp_context.app_context,
141
143
  session=mcp_context.session,
142
144
  ),
143
- embedding_service=embedding_service,
145
+ text_search_service=embedding_factory(
146
+ task_name="text",
147
+ app_context=mcp_context.app_context,
148
+ session=mcp_context.session,
149
+ ),
150
+ enrichment_service=enrichment_factory(mcp_context.app_context),
144
151
  )
145
152
 
146
153
  search_request = SearchRequest(
147
154
  keywords=keywords,
148
155
  code_query="\n".join(related_file_contents),
156
+ text_query=user_intent,
149
157
  )
158
+
150
159
  log.debug("Searching for snippets")
151
- snippets = await search_service.search(request=search_request)
160
+ snippets = await service.search(request=search_request)
152
161
 
153
162
  log.debug("Fusing output")
154
163
  output = output_fusion(snippets=snippets)
@@ -0,0 +1,26 @@
1
+ (function_declaration
2
+ name: (identifier) @function.name
3
+ body: (block) @function.body
4
+ ) @function.def
5
+
6
+ (method_declaration
7
+ name: (field_identifier) @method.name
8
+ body: (block) @method.body
9
+ ) @method.def
10
+
11
+ (import_declaration
12
+ (import_spec
13
+ path: (interpreted_string_literal) @import.name
14
+ )
15
+ ) @import.statement
16
+
17
+ (identifier) @ident
18
+
19
+ (parameter_declaration
20
+ name: (identifier) @param.name
21
+ )
22
+
23
+ (package_clause "package" (package_identifier) @name.definition.module)
24
+
25
+ ;; Exclude comments from being captured
26
+ (comment) @comment
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kodit
3
- Version: 0.1.15
3
+ Version: 0.1.16
4
4
  Summary: Code indexing for better AI code generation
5
5
  Project-URL: Homepage, https://docs.helixml.tech/kodit/
6
6
  Project-URL: Documentation, https://docs.helixml.tech/kodit/
@@ -15,6 +15,7 @@ Keywords: ai,indexing,mcp,rag
15
15
  Classifier: Development Status :: 2 - Pre-Alpha
16
16
  Classifier: Intended Audience :: Developers
17
17
  Classifier: Programming Language :: Python :: 3.12
18
+ Classifier: Programming Language :: Python :: 3.13
18
19
  Classifier: Topic :: Software Development :: Code Generators
19
20
  Requires-Python: >=3.12
20
21
  Requires-Dist: aiofiles>=24.1.0
@@ -42,6 +43,7 @@ Requires-Dist: sqlalchemy[asyncio]>=2.0.40
42
43
  Requires-Dist: structlog>=25.3.0
43
44
  Requires-Dist: tdqm>=0.0.1
44
45
  Requires-Dist: tiktoken>=0.9.0
46
+ Requires-Dist: transformers>=4.51.3
45
47
  Requires-Dist: tree-sitter-language-pack>=0.7.3
46
48
  Requires-Dist: tree-sitter>=0.24.0
47
49
  Requires-Dist: uritools>=5.0.0
@@ -1,12 +1,12 @@
1
1
  kodit/.gitignore,sha256=ztkjgRwL9Uud1OEi36hGQeDGk3OLK1NfDEO8YqGYy8o,11
2
2
  kodit/__init__.py,sha256=aEKHYninUq1yh6jaNfvJBYg-6fenpN132nJt1UU6Jxs,59
3
- kodit/_version.py,sha256=OX-WIjJlMaFvqRmCfLtOYEOYoiov9NdOA089N36rG-g,513
3
+ kodit/_version.py,sha256=VYJNWHISWEW-KD_clKUYcTY_Z30r993Sjws4URJIL0g,513
4
4
  kodit/app.py,sha256=Mr5BFHOHx5zppwjC4XPWVvHjwgl1yrKbUjTWXKubJQM,891
5
- kodit/cli.py,sha256=wAaMZQs-h6hyashWB3DBR2GIf496vfHmepcXhpa7-eM,8085
5
+ kodit/cli.py,sha256=i7eEt0FdIQGEfXKFte-8fBcZZGE8BPXBp40aGwJDQGI,11323
6
6
  kodit/config.py,sha256=2W2u5J8j-Mbt-C4xzOuK-PeuDCx0S_rnCXPhBwvfLT4,4353
7
7
  kodit/database.py,sha256=WB1KpVxUYPgiJGU0gJa2hqytYB8wJEJ5z3WayhWzNMU,2403
8
8
  kodit/log.py,sha256=HU1OmuxO4FcVw61k4WW7Y4WM7BrDaeplw1PcBHhuIZY,5434
9
- kodit/mcp.py,sha256=HA3R7YG0Al1A6MjSCSIi0hEGXG3WP7tix-N5AROasCM,5278
9
+ kodit/mcp.py,sha256=QruyPskWB0_x59pkfj5BBeXuR13GMny5TAZEa2j4U9s,5752
10
10
  kodit/middleware.py,sha256=I6FOkqG9-8RH5kR1-0ZoQWfE4qLCB8lZYv8H_OCH29o,2714
11
11
  kodit/bm25/__init__.py,sha256=j8zyriNWhbwE5Lbybzg1hQAhANlU9mKHWw4beeUR6og,19
12
12
  kodit/bm25/keyword_search_factory.py,sha256=rp-wx3DJsc2KlELK1V337EyeYvmwnMQwUqOo1WVPSmg,631
@@ -14,21 +14,29 @@ kodit/bm25/keyword_search_service.py,sha256=aBbWQKgQmi2re3EIHdXFS00n7Wj3b2D0pZsL
14
14
  kodit/bm25/local_bm25.py,sha256=AAbFhbQDqyL3d7jsPL7W4HsLxdoYctaDsREUXOLy6jM,3260
15
15
  kodit/bm25/vectorchord_bm25.py,sha256=_nGrkUReYLLV-L8RIuIVLwjuhSYZl9T532n5OVf0kWs,6393
16
16
  kodit/embedding/__init__.py,sha256=h9NXzDA1r-K23nvBajBV-RJzHJN0p3UJ7UQsmdnOoRw,24
17
- kodit/embedding/embedding_factory.py,sha256=qzoxBS3scR-ABd-u9215uGES7c6clYy2DiKcSDQivnA,1603
17
+ kodit/embedding/embedding_factory.py,sha256=UGnFRyyQXazSUOwyW4Hg7Vq2-kfAoDj9lD4CTLu8x04,1630
18
18
  kodit/embedding/embedding_models.py,sha256=rN90vSs86dYiqoawcp8E9jtwY31JoJXYfaDlsJK7uqc,656
19
19
  kodit/embedding/embedding_repository.py,sha256=-ux3scpBzel8c0pMH9fNOEsSXFIzl-IfgaWrkTb1szo,6907
20
20
  kodit/embedding/local_vector_search_service.py,sha256=hkF0qlfzjyGt400qIX9Mr6B7b7i8WvYIYWN2Z2C_pcs,1907
21
21
  kodit/embedding/vector_search_service.py,sha256=pQJ129QjGrAWOXzqkywmgtDRpy8_gtzYgkivyqF9Vrs,1009
22
- kodit/embedding/vectorchord_vector_search_service.py,sha256=OsVeM3gpoT8Ihzh-kEIzBm3xh_a4D-sErPvsQSKCME8,4732
22
+ kodit/embedding/vectorchord_vector_search_service.py,sha256=KSs0IMFHHIllwq2d3A0LGqGGZDqO1Ht6K-BCfBBWW0Y,5051
23
23
  kodit/embedding/embedding_provider/__init__.py,sha256=h9NXzDA1r-K23nvBajBV-RJzHJN0p3UJ7UQsmdnOoRw,24
24
- kodit/embedding/embedding_provider/embedding_provider.py,sha256=NKs4nriup47R8xRciP07NE1-eZE9RPHklS7VH910UZ4,1537
24
+ kodit/embedding/embedding_provider/embedding_provider.py,sha256=Tf3bwUsUMzAgoyLFM5qBtOLqPp1qr03TzrwGczkDvy0,1835
25
25
  kodit/embedding/embedding_provider/hash_embedding_provider.py,sha256=nAhlhh8j8PqqCCbhVl26Y8ntFBm2vJBCtB4X04g5Wwg,2638
26
26
  kodit/embedding/embedding_provider/local_embedding_provider.py,sha256=4ER-UPq506Y0TWU6qcs0nUqw6bSKQkSrdog-DhNQWM8,1906
27
- kodit/embedding/embedding_provider/openai_embedding_provider.py,sha256=bmUpegDgaF5Qj9uWcj1az4ADA2cKHUjraaMjGGPr83U,2076
27
+ kodit/embedding/embedding_provider/openai_embedding_provider.py,sha256=V_jdUXiaGdslplwxMlfgFc4_hAVS2eaJXMTs2C7RiLI,2666
28
+ kodit/enrichment/__init__.py,sha256=vBEolHpKaHUhfINX0dSGyAPlvgpLNAer9YzFtdvCB24,18
29
+ kodit/enrichment/enrichment_factory.py,sha256=vKjkUTdhj74IW2S4GENDWdWMJx6BwUSZjJGDC0i7DSk,787
30
+ kodit/enrichment/enrichment_service.py,sha256=87Sd3gGbEMJYb_wVrHG8L1yGIZmQNR7foUS4_y94azI,977
31
+ kodit/enrichment/enrichment_provider/__init__.py,sha256=klf8iuLVWX4iRz-DZQauFFNAoJC5CByczh48TBZPW-o,27
32
+ kodit/enrichment/enrichment_provider/enrichment_provider.py,sha256=E0H5rq3OENM0yYbA8K_3nSnj5lUHCpoIOqpWLo-2MVU,413
33
+ kodit/enrichment/enrichment_provider/local_enrichment_provider.py,sha256=bR6HR1gH7wtZdMLOwaKdASjvllRo1FlNW9GyZC11zAM,2164
34
+ kodit/enrichment/enrichment_provider/openai_enrichment_provider.py,sha256=gYuFTAeIVdQNlCUvNSPgRoiRwCvRD0C8419h8ubyABA,2725
28
35
  kodit/indexing/__init__.py,sha256=cPyi2Iej3G1JFWlWr7X80_UrsMaTu5W5rBwgif1B3xo,75
36
+ kodit/indexing/fusion.py,sha256=TZb4fPAedXdEUXzwzOofW98QIOymdbclBOP1KOijuEk,1674
29
37
  kodit/indexing/indexing_models.py,sha256=6NX9HVcj6Pu9ePwHC7n-PWSyAgukpJq0nCNmUIigtbo,1282
30
- kodit/indexing/indexing_repository.py,sha256=4RJ3zY8p6QxHrYW7dDjru_w94Eu19v2gQ4mdlTgcXvY,6331
31
- kodit/indexing/indexing_service.py,sha256=T_dxOzNW_0OCpR4Fha1hHuNkmtLcDMZwL6t5xeu5VXQ,6613
38
+ kodit/indexing/indexing_repository.py,sha256=GYHoACUWYKQdVTwP7tfik_TMUD1WUK76nywH88eCSwg,7006
39
+ kodit/indexing/indexing_service.py,sha256=tKcZpi0pzsmF6OpqnqF0Q5HfSXxi5iLTysrVSou4JiQ,10579
32
40
  kodit/migrations/README,sha256=ISVtAOvqvKk_5ThM5ioJE-lMkvf9IbknFUFVU_vPma4,58
33
41
  kodit/migrations/__init__.py,sha256=lP5MuwlyWRMO6UcDWnQcQ3G-GYHcFb6rl9gYPHJ1sjo,40
34
42
  kodit/migrations/env.py,sha256=w1M7OZh-ZeR2dPHS0ByXAUxQjfZQ8xIzMseWuzLDTWw,2469
@@ -36,14 +44,12 @@ kodit/migrations/script.py.mako,sha256=zWziKtiwYKEWuwPV_HBNHwa9LCT45_bi01-uSNFaO
36
44
  kodit/migrations/versions/7c3bbc2ab32b_add_embeddings_table.py,sha256=-61qol9PfQKILCDQRA5jEaats9aGZs9Wdtp-j-38SF4,1644
37
45
  kodit/migrations/versions/85155663351e_initial.py,sha256=Cg7zlF871o9ShV5rQMQ1v7hRV7fI59veDY9cjtTrs-8,3306
38
46
  kodit/migrations/versions/__init__.py,sha256=9-lHzptItTzq_fomdIRBegQNm4Znx6pVjwD4MiqRIdo,36
39
- kodit/search/__init__.py,sha256=4QbdjbrlhNKMovmuKHxJnUeZT7KNjTTFU0GdnuwUHdQ,36
40
- kodit/search/search_repository.py,sha256=6q0k7JMTM_7hPK2TSA30CykGbc5N16kCL7HTjlbai0w,1563
41
- kodit/search/search_service.py,sha256=-XlbP_50e1dKFJ5jBvex5FjLnffW43LcwQV_SeYNFB0,3944
42
47
  kodit/snippets/__init__.py,sha256=-2coNoCRjTixU9KcP6alpmt7zqf37tCRWH3D7FPJ8dg,48
43
48
  kodit/snippets/method_snippets.py,sha256=EVHhSNWahAC5nSXv9fWVFJY2yq25goHdCSCuENC07F8,4145
44
49
  kodit/snippets/snippets.py,sha256=mwN0bM1Msu8ZeEsUHyQ7tx3Hj3vZsm8G7Wu4eWSkLY8,1539
45
50
  kodit/snippets/languages/__init__.py,sha256=Bj5KKZSls2MQ8ZY1S_nHg447MgGZW-2WZM-oq6vjwwA,1187
46
51
  kodit/snippets/languages/csharp.scm,sha256=gbBN4RiV1FBuTJF6orSnDFi8H9JwTw-d4piLJYsWUsc,222
52
+ kodit/snippets/languages/go.scm,sha256=SEX9mTOrhP2KiQW7oflDKkd21u5dK56QbJ4LvTDxY8A,533
47
53
  kodit/snippets/languages/python.scm,sha256=ee85R9PBzwye3IMTE7-iVoKWd_ViU3EJISTyrFGrVeo,429
48
54
  kodit/source/__init__.py,sha256=1NTZyPdjThVQpZO1Mp1ColVsS7sqYanOVLqnoqV9Ipo,83
49
55
  kodit/source/source_models.py,sha256=xb42CaNDO1CUB8SIW-xXMrB6Ji8cFw-yeJ550xBEg9Q,2398
@@ -51,8 +57,8 @@ kodit/source/source_repository.py,sha256=0EksMpoLzdkfe8S4eeCm4Sf7TuxsOzOzaF4BBsM
51
57
  kodit/source/source_service.py,sha256=u_GaH07ewakThQJRfT8O_yZ54A52qLtJuM1bF3xUT2A,9633
52
58
  kodit/util/__init__.py,sha256=bPu6CtqDWCRGU7VgW2_aiQrCBi8G89FS6k1PjvDajJ0,37
53
59
  kodit/util/spinner.py,sha256=R9bzrHtBiIH6IfLbmsIVHL53s8vg-tqW4lwGGALu4dw,1932
54
- kodit-0.1.15.dist-info/METADATA,sha256=8E-bw8L-Df5Hdt16R5IWkyw7uUAr13CwYfcEyFExaPw,2380
55
- kodit-0.1.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
56
- kodit-0.1.15.dist-info/entry_points.txt,sha256=hoTn-1aKyTItjnY91fnO-rV5uaWQLQ-Vi7V5et2IbHY,40
57
- kodit-0.1.15.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
58
- kodit-0.1.15.dist-info/RECORD,,
60
+ kodit-0.1.16.dist-info/METADATA,sha256=1lR4ZSTiRBzUv9Gj8FPspv4GU2vWGQU6HSiffWgU2Do,2467
61
+ kodit-0.1.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
62
+ kodit-0.1.16.dist-info/entry_points.txt,sha256=hoTn-1aKyTItjnY91fnO-rV5uaWQLQ-Vi7V5et2IbHY,40
63
+ kodit-0.1.16.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
64
+ kodit-0.1.16.dist-info/RECORD,,
kodit/search/__init__.py DELETED
@@ -1 +0,0 @@
1
- """Search for relevant snippets."""
@@ -1,57 +0,0 @@
1
- """Repository for searching for relevant snippets."""
2
-
3
- from typing import TypeVar
4
-
5
- from sqlalchemy import (
6
- select,
7
- )
8
- from sqlalchemy.ext.asyncio import AsyncSession
9
-
10
- from kodit.indexing.indexing_models import Snippet
11
- from kodit.source.source_models import File
12
-
13
- T = TypeVar("T")
14
-
15
-
16
- class SearchRepository:
17
- """Repository for searching for relevant snippets."""
18
-
19
- def __init__(self, session: AsyncSession) -> None:
20
- """Initialize the search repository.
21
-
22
- Args:
23
- session: The SQLAlchemy async session to use for database operations.
24
-
25
- """
26
- self.session = session
27
-
28
- async def list_snippet_ids(self) -> list[int]:
29
- """List all snippet IDs.
30
-
31
- Returns:
32
- A list of all snippets.
33
-
34
- """
35
- query = select(Snippet.id)
36
- rows = await self.session.execute(query)
37
- return list(rows.scalars().all())
38
-
39
- async def list_snippets_by_ids(self, ids: list[int]) -> list[tuple[File, Snippet]]:
40
- """List snippets by IDs.
41
-
42
- Returns:
43
- A list of snippets in the same order as the input IDs.
44
-
45
- """
46
- query = (
47
- select(Snippet, File)
48
- .where(Snippet.id.in_(ids))
49
- .join(File, Snippet.file_id == File.id)
50
- )
51
- rows = await self.session.execute(query)
52
-
53
- # Create a dictionary for O(1) lookup of results by ID
54
- id_to_result = {snippet.id: (file, snippet) for snippet, file in rows.all()}
55
-
56
- # Return results in the same order as input IDs
57
- return [id_to_result[i] for i in ids]
@@ -1,135 +0,0 @@
1
- """Search service."""
2
-
3
- import pydantic
4
- import structlog
5
-
6
- from kodit.bm25.keyword_search_service import BM25Result, KeywordSearchProvider
7
- from kodit.embedding.vector_search_service import VectorSearchService
8
- from kodit.search.search_repository import SearchRepository
9
-
10
-
11
- class SearchRequest(pydantic.BaseModel):
12
- """Request for a search."""
13
-
14
- code_query: str | None = None
15
- keywords: list[str] | None = None
16
- top_k: int = 10
17
-
18
-
19
- class SearchResult(pydantic.BaseModel):
20
- """Data transfer object for search results.
21
-
22
- This model represents a single search result, containing both the file path
23
- and the matching snippet content.
24
- """
25
-
26
- id: int
27
- uri: str
28
- content: str
29
-
30
-
31
- class Snippet(pydantic.BaseModel):
32
- """Snippet model."""
33
-
34
- content: str
35
- file_path: str
36
-
37
-
38
- class SearchService:
39
- """Service for searching for relevant data."""
40
-
41
- def __init__(
42
- self,
43
- repository: SearchRepository,
44
- keyword_search_provider: KeywordSearchProvider,
45
- embedding_service: VectorSearchService,
46
- ) -> None:
47
- """Initialize the search service."""
48
- self.repository = repository
49
- self.log = structlog.get_logger(__name__)
50
- self.keyword_search_provider = keyword_search_provider
51
- self.code_embedding_service = embedding_service
52
-
53
- async def search(self, request: SearchRequest) -> list[SearchResult]:
54
- """Search for relevant data."""
55
- fusion_list = []
56
- if request.keywords:
57
- # Gather results for each keyword
58
- result_ids: list[BM25Result] = []
59
- for keyword in request.keywords:
60
- results = await self.keyword_search_provider.retrieve(
61
- keyword, request.top_k
62
- )
63
- result_ids.extend(results)
64
-
65
- # Sort results by score
66
- result_ids.sort(key=lambda x: x[1], reverse=True)
67
-
68
- self.log.debug("Search results (BM25)", results=result_ids)
69
-
70
- bm25_results = [x[0] for x in result_ids]
71
- fusion_list.append(bm25_results)
72
-
73
- # Compute embedding for semantic query
74
- semantic_results = []
75
- if request.code_query:
76
- query_embedding = await self.code_embedding_service.retrieve(
77
- request.code_query, top_k=request.top_k
78
- )
79
- semantic_results = [x.snippet_id for x in query_embedding]
80
- fusion_list.append(semantic_results)
81
-
82
- if len(fusion_list) == 0:
83
- return []
84
-
85
- # Combine all results together with RFF if required
86
- final_results = reciprocal_rank_fusion(fusion_list, k=60)
87
-
88
- # Extract ids from final results
89
- final_ids = [x[0] for x in final_results]
90
-
91
- # Get snippets from database (up to top_k)
92
- search_results = await self.repository.list_snippets_by_ids(
93
- final_ids[: request.top_k]
94
- )
95
-
96
- return [
97
- SearchResult(
98
- id=snippet.id,
99
- uri=file.uri,
100
- content=snippet.content,
101
- )
102
- for file, snippet in search_results
103
- ]
104
-
105
-
106
- def reciprocal_rank_fusion(
107
- rankings: list[list[int]], k: float = 60
108
- ) -> list[tuple[int, float]]:
109
- """RRF prioritises results that are present in all results.
110
-
111
- Args:
112
- rankings: List of rankers, each containing a list of document ids. Top of the
113
- list is considered to be the best result.
114
- k: Parameter for RRF.
115
-
116
- Returns:
117
- Dictionary of ids and their scores.
118
-
119
- """
120
- scores = {}
121
- for ranker in rankings:
122
- for rank in ranker:
123
- scores[rank] = float(0)
124
-
125
- for ranker in rankings:
126
- for i, rank in enumerate(ranker):
127
- scores[rank] += 1.0 / (k + i)
128
-
129
- # Create a list of tuples of ids and their scores
130
- results = [(rank, scores[rank]) for rank in scores]
131
-
132
- # Sort results by score
133
- results.sort(key=lambda x: x[1], reverse=True)
134
-
135
- return results
File without changes