kodit 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, exactly as they appear in the public registry they were published to. It is provided for informational purposes only.

Potentially problematic release.


This version of kodit might be problematic. (The original page linked to further details at this point; the link is not preserved in this text.)

kodit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.1.9'
21
- __version_tuple__ = version_tuple = (0, 1, 9)
20
+ __version__ = version = '0.1.11'
21
+ __version_tuple__ = version_tuple = (0, 1, 11)
kodit/bm25/bm25.py CHANGED
@@ -38,7 +38,7 @@ class BM25Service:
38
38
  self.log.debug("Indexing corpus")
39
39
  vocab = self._tokenize(corpus)
40
40
  self.retriever = bm25s.BM25()
41
- self.retriever.index(vocab)
41
+ self.retriever.index(vocab, show_progress=False)
42
42
  self.retriever.save(self.index_path)
43
43
 
44
44
  def retrieve(
kodit/cli.py CHANGED
@@ -15,6 +15,7 @@ from kodit.config import (
15
15
  DEFAULT_BASE_DIR,
16
16
  DEFAULT_DB_URL,
17
17
  DEFAULT_DISABLE_TELEMETRY,
18
+ DEFAULT_EMBEDDING_MODEL_NAME,
18
19
  DEFAULT_LOG_FORMAT,
19
20
  DEFAULT_LOG_LEVEL,
20
21
  AppContext,
@@ -23,7 +24,7 @@ from kodit.config import (
23
24
  )
24
25
  from kodit.indexing.repository import IndexRepository
25
26
  from kodit.indexing.service import IndexService
26
- from kodit.logging import configure_logging, configure_telemetry, log_event
27
+ from kodit.log import configure_logging, configure_telemetry, log_event
27
28
  from kodit.retreival.repository import RetrievalRepository
28
29
  from kodit.retreival.service import RetrievalRequest, RetrievalService
29
30
  from kodit.sources.repository import SourceRepository
@@ -97,7 +98,12 @@ async def index(
97
98
  source_repository = SourceRepository(session)
98
99
  source_service = SourceService(app_context.get_clone_dir(), source_repository)
99
100
  repository = IndexRepository(session)
100
- service = IndexService(repository, source_service, app_context.get_data_dir())
101
+ service = IndexService(
102
+ repository,
103
+ source_service,
104
+ app_context.get_data_dir(),
105
+ embedding_model_name=DEFAULT_EMBEDDING_MODEL_NAME,
106
+ )
101
107
 
102
108
  if not sources:
103
109
  # No source specified, list all indexes
@@ -133,20 +139,106 @@ async def index(
133
139
  await service.run(index.id)
134
140
 
135
141
 
136
- @cli.command()
142
+ @cli.group()
143
+ def search() -> None:
144
+ """Search for snippets in the database."""
145
+
146
+
147
+ @search.command()
137
148
  @click.argument("query")
138
149
  @click.option("--top-k", default=10, help="Number of snippets to retrieve")
139
150
  @with_app_context
140
151
  @with_session
141
- async def retrieve(
142
- session: AsyncSession, app_context: AppContext, query: str, top_k: int
152
+ async def code(
153
+ session: AsyncSession,
154
+ app_context: AppContext,
155
+ query: str,
156
+ top_k: int,
157
+ ) -> None:
158
+ """Search for snippets using semantic code search.
159
+
160
+ This works best if your query is code.
161
+ """
162
+ repository = RetrievalRepository(session)
163
+ service = RetrievalService(
164
+ repository,
165
+ app_context.get_data_dir(),
166
+ embedding_model_name=DEFAULT_EMBEDDING_MODEL_NAME,
167
+ )
168
+
169
+ snippets = await service.retrieve(RetrievalRequest(code_query=query, top_k=top_k))
170
+
171
+ if len(snippets) == 0:
172
+ click.echo("No snippets found")
173
+ return
174
+
175
+ for snippet in snippets:
176
+ click.echo("-" * 80)
177
+ click.echo(f"{snippet.uri}")
178
+ click.echo(snippet.content)
179
+ click.echo("-" * 80)
180
+ click.echo()
181
+
182
+
183
+ @search.command()
184
+ @click.argument("keywords", nargs=-1)
185
+ @click.option("--top-k", default=10, help="Number of snippets to retrieve")
186
+ @with_app_context
187
+ @with_session
188
+ async def keyword(
189
+ session: AsyncSession,
190
+ app_context: AppContext,
191
+ keywords: list[str],
192
+ top_k: int,
143
193
  ) -> None:
144
- """Retrieve snippets from the database."""
194
+ """Search for snippets using keyword search."""
145
195
  repository = RetrievalRepository(session)
146
- service = RetrievalService(repository, app_context.get_data_dir())
147
- # Temporary request while we don't have all search capabilities
196
+ service = RetrievalService(
197
+ repository,
198
+ app_context.get_data_dir(),
199
+ embedding_model_name=DEFAULT_EMBEDDING_MODEL_NAME,
200
+ )
201
+
202
+ snippets = await service.retrieve(RetrievalRequest(keywords=keywords, top_k=top_k))
203
+
204
+ if len(snippets) == 0:
205
+ click.echo("No snippets found")
206
+ return
207
+
208
+ for snippet in snippets:
209
+ click.echo("-" * 80)
210
+ click.echo(f"{snippet.uri}")
211
+ click.echo(snippet.content)
212
+ click.echo("-" * 80)
213
+ click.echo()
214
+
215
+
216
+ @search.command()
217
+ @click.option("--top-k", default=10, help="Number of snippets to retrieve")
218
+ @click.option("--keywords", required=True, help="Comma separated list of keywords")
219
+ @click.option("--code", required=True, help="Semantic code search query")
220
+ @with_app_context
221
+ @with_session
222
+ async def hybrid(
223
+ session: AsyncSession,
224
+ app_context: AppContext,
225
+ top_k: int,
226
+ keywords: str,
227
+ code: str,
228
+ ) -> None:
229
+ """Search for snippets using hybrid search."""
230
+ repository = RetrievalRepository(session)
231
+ service = RetrievalService(
232
+ repository,
233
+ app_context.get_data_dir(),
234
+ embedding_model_name=DEFAULT_EMBEDDING_MODEL_NAME,
235
+ )
236
+
237
+ # Parse keywords into a list of strings
238
+ keywords_list = [k.strip().lower() for k in keywords.split(",")]
239
+
148
240
  snippets = await service.retrieve(
149
- RetrievalRequest(keywords=query.split(","), top_k=top_k)
241
+ RetrievalRequest(keywords=keywords_list, code_query=code, top_k=top_k)
150
242
  )
151
243
 
152
244
  if len(snippets) == 0:
kodit/config.py CHANGED
@@ -11,12 +11,14 @@ from pydantic import Field
11
11
  from pydantic_settings import BaseSettings, SettingsConfigDict
12
12
 
13
13
  from kodit.database import Database
14
+ from kodit.embedding.embedding import TINY
14
15
 
15
16
  DEFAULT_BASE_DIR = Path.home() / ".kodit"
16
17
  DEFAULT_DB_URL = f"sqlite+aiosqlite:///{DEFAULT_BASE_DIR}/kodit.db"
17
18
  DEFAULT_LOG_LEVEL = "INFO"
18
19
  DEFAULT_LOG_FORMAT = "pretty"
19
20
  DEFAULT_DISABLE_TELEMETRY = False
21
+ DEFAULT_EMBEDDING_MODEL_NAME = TINY
20
22
  T = TypeVar("T")
21
23
 
22
24
 
kodit/database.py CHANGED
@@ -15,7 +15,7 @@ from sqlalchemy.ext.asyncio import (
15
15
  )
16
16
  from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
17
17
 
18
- from kodit import alembic
18
+ from kodit import migrations
19
19
 
20
20
 
21
21
  class Base(AsyncAttrs, DeclarativeBase):
@@ -57,7 +57,7 @@ class Database:
57
57
  # Create Alembic configuration and run migrations
58
58
  alembic_cfg = AlembicConfig()
59
59
  alembic_cfg.set_main_option(
60
- "script_location", str(Path(alembic.__file__).parent)
60
+ "script_location", str(Path(migrations.__file__).parent)
61
61
  )
62
62
  alembic_cfg.set_main_option("sqlalchemy.url", db_url)
63
63
  self.log.debug("Running migrations", db_url=db_url)
@@ -0,0 +1 @@
1
+ """Embedding module."""
@@ -0,0 +1,52 @@
1
+ """Embedding service."""
2
+
3
+ import os
4
+ from collections.abc import Generator
5
+
6
+ import structlog
7
+ from sentence_transformers import SentenceTransformer
8
+
9
+ TINY = "tiny"
10
+ CODE = "code"
11
+ TEST = "test"
12
+
13
+ COMMON_EMBEDDING_MODELS = {
14
+ TINY: "ibm-granite/granite-embedding-30m-english",
15
+ CODE: "flax-sentence-embeddings/st-codesearch-distilroberta-base",
16
+ TEST: "minishlab/potion-base-4M",
17
+ }
18
+
19
+
20
+ class EmbeddingService:
21
+ """Service for embeddings."""
22
+
23
+ def __init__(self, model_name: str) -> None:
24
+ """Initialize the embedding service."""
25
+ self.log = structlog.get_logger(__name__)
26
+ self.model_name = COMMON_EMBEDDING_MODELS.get(model_name, model_name)
27
+ self.embedding_model = None
28
+
29
+ def _model(self) -> SentenceTransformer:
30
+ """Get the embedding model."""
31
+ if self.embedding_model is None:
32
+ os.environ["TOKENIZERS_PARALLELISM"] = "false" # Avoid warnings
33
+ self.embedding_model = SentenceTransformer(
34
+ self.model_name,
35
+ trust_remote_code=True,
36
+ device="cpu", # Force CPU so we don't have to install accelerate, etc.
37
+ )
38
+ return self.embedding_model
39
+
40
+ def embed(self, snippets: list[str]) -> Generator[list[float], None, None]:
41
+ """Embed a list of documents."""
42
+ model = self._model()
43
+ embeddings = model.encode(snippets, show_progress_bar=False, batch_size=4)
44
+ for embedding in embeddings:
45
+ yield [float(x) for x in embedding]
46
+
47
+ def query(self, query: list[str]) -> Generator[list[float], None, None]:
48
+ """Query the embedding model."""
49
+ model = self._model()
50
+ embeddings = model.encode(query, show_progress_bar=False, batch_size=4)
51
+ for embedding in embeddings:
52
+ yield [float(x) for x in embedding]
@@ -0,0 +1,28 @@
1
+ """Embedding models."""
2
+
3
+ from enum import Enum
4
+
5
+ from sqlalchemy import JSON, ForeignKey
6
+ from sqlalchemy import Enum as SQLAlchemyEnum
7
+ from sqlalchemy.orm import Mapped, mapped_column
8
+
9
+ from kodit.database import Base, CommonMixin
10
+
11
+
12
+ class EmbeddingType(Enum):
13
+ """Embedding type."""
14
+
15
+ CODE = 1
16
+ TEXT = 2
17
+
18
+
19
+ class Embedding(Base, CommonMixin):
20
+ """Embedding model."""
21
+
22
+ __tablename__ = "embeddings"
23
+
24
+ snippet_id: Mapped[int] = mapped_column(ForeignKey("snippets.id"), index=True)
25
+ type: Mapped[EmbeddingType] = mapped_column(
26
+ SQLAlchemyEnum(EmbeddingType), index=True
27
+ )
28
+ embedding: Mapped[list[float]] = mapped_column(JSON)
@@ -11,6 +11,7 @@ from typing import TypeVar
11
11
  from sqlalchemy import delete, func, select
12
12
  from sqlalchemy.ext.asyncio import AsyncSession
13
13
 
14
+ from kodit.embedding.models import Embedding
14
15
  from kodit.indexing.models import Index, Snippet
15
16
  from kodit.sources.models import File, Source
16
17
 
@@ -165,3 +166,13 @@ class IndexRepository:
165
166
  query = select(Snippet).order_by(Snippet.id)
166
167
  result = await self.session.execute(query)
167
168
  return list(result.scalars())
169
+
170
+ async def add_embedding(self, embedding: Embedding) -> None:
171
+ """Add a new embedding to the database.
172
+
173
+ Args:
174
+ embedding: The Embedding instance to add.
175
+
176
+ """
177
+ self.session.add(embedding)
178
+ await self.session.commit()
kodit/indexing/service.py CHANGED
@@ -14,6 +14,8 @@ import structlog
14
14
  from tqdm.asyncio import tqdm
15
15
 
16
16
  from kodit.bm25.bm25 import BM25Service
17
+ from kodit.embedding.embedding import EmbeddingService
18
+ from kodit.embedding.models import Embedding, EmbeddingType
17
19
  from kodit.indexing.models import Snippet
18
20
  from kodit.indexing.repository import IndexRepository
19
21
  from kodit.snippets.snippets import SnippetService
@@ -50,6 +52,7 @@ class IndexService:
50
52
  repository: IndexRepository,
51
53
  source_service: SourceService,
52
54
  data_dir: Path,
55
+ embedding_model_name: str,
53
56
  ) -> None:
54
57
  """Initialize the index service.
55
58
 
@@ -63,6 +66,7 @@ class IndexService:
63
66
  self.snippet_service = SnippetService()
64
67
  self.log = structlog.get_logger(__name__)
65
68
  self.bm25 = BM25Service(data_dir)
69
+ self.code_embedding_service = EmbeddingService(model_name=embedding_model_name)
66
70
 
67
71
  async def create(self, source_id: int) -> IndexView:
68
72
  """Create a new index for a source.
@@ -128,9 +132,26 @@ class IndexService:
128
132
  # Create snippets for supported file types
129
133
  await self._create_snippets(index_id)
130
134
 
131
- # Update BM25 index
132
135
  snippets = await self.repository.get_all_snippets()
133
- self.bm25.index([snippet.content for snippet in snippets])
136
+
137
+ self.log.info("Creating keyword index")
138
+ self.bm25.index(
139
+ [
140
+ snippet.content
141
+ for snippet in tqdm(snippets, total=len(snippets), leave=False)
142
+ ]
143
+ )
144
+
145
+ self.log.info("Creating semantic code index")
146
+ for snippet in tqdm(snippets, total=len(snippets), leave=False):
147
+ embedding = next(self.code_embedding_service.embed([snippet.content]))
148
+ await self.repository.add_embedding(
149
+ Embedding(
150
+ snippet_id=snippet.id,
151
+ embedding=embedding,
152
+ type=EmbeddingType.CODE,
153
+ )
154
+ )
134
155
 
135
156
  # Update index timestamp
136
157
  await self.repository.update_index_timestamp(index)
@@ -148,7 +169,7 @@ class IndexService:
148
169
 
149
170
  """
150
171
  files = await self.repository.files_for_index(index_id)
151
- for file in tqdm(files, total=len(files)):
172
+ for file in tqdm(files, total=len(files), leave=False):
152
173
  # Skip unsupported file types
153
174
  if file.mime_type in MIME_BLACKLIST:
154
175
  self.log.debug("Skipping mime type", mime_type=file.mime_type)
@@ -87,7 +87,13 @@ def configure_logging(app_context: AppContext) -> None:
87
87
  # Configure uvicorn loggers to use our structlog setup
88
88
  # Uvicorn spits out loads of exception logs when sse server doesn't shut down
89
89
  # gracefully, so we hide them unless in DEBUG mode
90
- for _log in ["uvicorn", "uvicorn.error", "uvicorn.access"]:
90
+ for _log in [
91
+ "uvicorn",
92
+ "uvicorn.error",
93
+ "uvicorn.access",
94
+ "bm25s",
95
+ "sentence_transformers.SentenceTransformer",
96
+ ]:
91
97
  if root_logger.getEffectiveLevel() == logging.DEBUG:
92
98
  logging.getLogger(_log).handlers.clear()
93
99
  logging.getLogger(_log).propagate = True
kodit/mcp.py CHANGED
@@ -12,7 +12,7 @@ from pydantic import Field
12
12
  from sqlalchemy.ext.asyncio import AsyncSession
13
13
 
14
14
  from kodit._version import version
15
- from kodit.config import AppContext
15
+ from kodit.config import DEFAULT_EMBEDDING_MODEL_NAME, AppContext
16
16
  from kodit.database import Database
17
17
  from kodit.retreival.repository import RetrievalRepository, RetrievalResult
18
18
  from kodit.retreival.service import RetrievalRequest, RetrievalService
@@ -115,18 +115,12 @@ async def retrieve_relevant_snippets(
115
115
  retrieval_service = RetrievalService(
116
116
  repository=retrieval_repository,
117
117
  data_dir=mcp_context.data_dir,
118
+ embedding_model_name=DEFAULT_EMBEDDING_MODEL_NAME,
118
119
  )
119
120
 
120
- log.debug("Fusing input")
121
- input_query = input_fusion(
122
- user_intent=user_intent,
123
- related_file_paths=related_file_paths,
124
- related_file_contents=related_file_contents,
125
- keywords=keywords,
126
- )
127
- log.debug("Input", input_query=input_query)
128
121
  retrieval_request = RetrievalRequest(
129
122
  keywords=keywords,
123
+ code_query="\n".join(related_file_contents),
130
124
  )
131
125
  log.debug("Retrieving snippets")
132
126
  snippets = await retrieval_service.retrieve(request=retrieval_request)
@@ -8,6 +8,7 @@ from sqlalchemy import pool
8
8
  from sqlalchemy.engine import Connection
9
9
  from sqlalchemy.ext.asyncio import async_engine_from_config
10
10
 
11
+ import kodit.embedding.models
11
12
  import kodit.indexing.models
12
13
  import kodit.sources.models
13
14
  from kodit.database import Base
@@ -0,0 +1,47 @@
1
+ # ruff: noqa
2
+ """add embeddings table
3
+
4
+ Revision ID: 7c3bbc2ab32b
5
+ Revises: 85155663351e
6
+ Create Date: 2025-05-23 17:23:09.924980
7
+
8
+ """
9
+
10
+ from typing import Sequence, Union
11
+
12
+ from alembic import op
13
+ import sqlalchemy as sa
14
+
15
+
16
+ # revision identifiers, used by Alembic.
17
+ revision: str = '7c3bbc2ab32b'
18
+ down_revision: Union[str, None] = '85155663351e'
19
+ branch_labels: Union[str, Sequence[str], None] = None
20
+ depends_on: Union[str, Sequence[str], None] = None
21
+
22
+
23
+ def upgrade() -> None:
24
+ """Upgrade schema."""
25
+ # ### commands auto generated by Alembic - please adjust! ###
26
+ op.create_table('embeddings',
27
+ sa.Column('snippet_id', sa.Integer(), nullable=False),
28
+ sa.Column('type', sa.Enum('CODE', 'TEXT', name='embeddingtype'), nullable=False),
29
+ sa.Column('embedding', sa.JSON(), nullable=False),
30
+ sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
31
+ sa.Column('created_at', sa.DateTime(), nullable=False),
32
+ sa.Column('updated_at', sa.DateTime(), nullable=False),
33
+ sa.ForeignKeyConstraint(['snippet_id'], ['snippets.id'], ),
34
+ sa.PrimaryKeyConstraint('id')
35
+ )
36
+ op.create_index(op.f('ix_embeddings_snippet_id'), 'embeddings', ['snippet_id'], unique=False)
37
+ op.create_index(op.f('ix_embeddings_type'), 'embeddings', ['type'], unique=False)
38
+ # ### end Alembic commands ###
39
+
40
+
41
+ def downgrade() -> None:
42
+ """Downgrade schema."""
43
+ # ### commands auto generated by Alembic - please adjust! ###
44
+ op.drop_index(op.f('ix_embeddings_type'), table_name='embeddings')
45
+ op.drop_index(op.f('ix_embeddings_snippet_id'), table_name='embeddings')
46
+ op.drop_table('embeddings')
47
+ # ### end Alembic commands ###
@@ -7,10 +7,14 @@ and their associated file information.
7
7
 
8
8
  from typing import TypeVar
9
9
 
10
+ import numpy as np
10
11
  import pydantic
11
- from sqlalchemy import select
12
+ from sqlalchemy import (
13
+ select,
14
+ )
12
15
  from sqlalchemy.ext.asyncio import AsyncSession
13
16
 
17
+ from kodit.embedding.models import Embedding, EmbeddingType
14
18
  from kodit.indexing.models import Snippet
15
19
  from kodit.sources.models import File
16
20
 
@@ -24,8 +28,10 @@ class RetrievalResult(pydantic.BaseModel):
24
28
  and the matching snippet content.
25
29
  """
26
30
 
31
+ id: int
27
32
  uri: str
28
33
  content: str
34
+ score: float
29
35
 
30
36
 
31
37
  class RetrievalRepository:
@@ -69,8 +75,10 @@ class RetrievalRepository:
69
75
 
70
76
  return [
71
77
  RetrievalResult(
78
+ id=snippet.id,
72
79
  uri=file.uri,
73
80
  content=snippet.content,
81
+ score=1.0,
74
82
  )
75
83
  for snippet, file in results
76
84
  ]
@@ -90,7 +98,7 @@ class RetrievalRepository:
90
98
  """List snippets by IDs.
91
99
 
92
100
  Returns:
93
- A list of snippets.
101
+ A list of snippets in the same order as the input IDs.
94
102
 
95
103
  """
96
104
  query = (
@@ -99,10 +107,125 @@ class RetrievalRepository:
99
107
  .join(File, Snippet.file_id == File.id)
100
108
  )
101
109
  rows = await self.session.execute(query)
102
- return [
103
- RetrievalResult(
110
+
111
+ # Create a dictionary for O(1) lookup of results by ID
112
+ id_to_result = {
113
+ snippet.id: RetrievalResult(
114
+ id=snippet.id,
104
115
  uri=file.uri,
105
116
  content=snippet.content,
117
+ score=1.0,
106
118
  )
107
119
  for snippet, file in rows.all()
108
- ]
120
+ }
121
+
122
+ # Return results in the same order as input IDs
123
+ return [id_to_result[i] for i in ids]
124
+
125
+ async def fetch_embeddings(
126
+ self, embedding_type: EmbeddingType
127
+ ) -> list[tuple[int, list[float]]]:
128
+ """Fetch all embeddings of a given type from the database.
129
+
130
+ Args:
131
+ embedding_type: The type of embeddings to fetch
132
+
133
+ Returns:
134
+ List of (snippet_id, embedding) tuples
135
+
136
+ """
137
+ # Only select the fields we need and use a more efficient query
138
+ query = select(Embedding.snippet_id, Embedding.embedding).where(
139
+ Embedding.type == embedding_type
140
+ )
141
+ rows = await self.session.execute(query)
142
+ return [tuple(row) for row in rows.all()] # Convert Row objects to tuples
143
+
144
+ def prepare_vectors(
145
+ self, embeddings: list[tuple[int, list[float]]], query_embedding: list[float]
146
+ ) -> tuple[np.ndarray, np.ndarray]:
147
+ """Convert embeddings to numpy arrays.
148
+
149
+ Args:
150
+ embeddings: List of (snippet_id, embedding) tuples
151
+ query_embedding: Query embedding vector
152
+
153
+ Returns:
154
+ Tuple of (stored_vectors, query_vector) as numpy arrays
155
+
156
+ """
157
+ stored_vecs = np.array(
158
+ [emb[1] for emb in embeddings]
159
+ ) # Use index 1 to get embedding
160
+ query_vec = np.array(query_embedding)
161
+ return stored_vecs, query_vec
162
+
163
+ def compute_similarities(
164
+ self, stored_vecs: np.ndarray, query_vec: np.ndarray
165
+ ) -> np.ndarray:
166
+ """Compute cosine similarities between stored vectors and query vector.
167
+
168
+ Args:
169
+ stored_vecs: Array of stored embedding vectors
170
+ query_vec: Query embedding vector
171
+
172
+ Returns:
173
+ Array of similarity scores
174
+
175
+ """
176
+ stored_norms = np.linalg.norm(stored_vecs, axis=1)
177
+ query_norm = np.linalg.norm(query_vec)
178
+ return np.dot(stored_vecs, query_vec) / (stored_norms * query_norm)
179
+
180
+ def get_top_k_results(
181
+ self,
182
+ similarities: np.ndarray,
183
+ embeddings: list[tuple[int, list[float]]],
184
+ top_k: int,
185
+ ) -> list[tuple[int, float]]:
186
+ """Get top-k results by similarity score.
187
+
188
+ Args:
189
+ similarities: Array of similarity scores
190
+ embeddings: List of (snippet_id, embedding) tuples
191
+ top_k: Number of results to return
192
+
193
+ Returns:
194
+ List of (snippet_id, similarity_score) tuples
195
+
196
+ """
197
+ top_indices = np.argsort(similarities)[::-1][:top_k]
198
+ return [
199
+ (embeddings[i][0], float(similarities[i])) for i in top_indices
200
+ ] # Use index 0 to get snippet_id
201
+
202
+ async def list_semantic_results(
203
+ self, embedding_type: EmbeddingType, embedding: list[float], top_k: int = 10
204
+ ) -> list[tuple[int, float]]:
205
+ """List semantic results using cosine similarity.
206
+
207
+ This implementation fetches all embeddings of the given type and computes
208
+ cosine similarity in Python using NumPy for better performance.
209
+
210
+ Args:
211
+ embedding_type: The type of embeddings to search
212
+ embedding: The query embedding vector
213
+ top_k: Number of results to return
214
+
215
+ Returns:
216
+ List of (snippet_id, similarity_score) tuples, sorted by similarity
217
+
218
+ """
219
+ # Step 1: Fetch embeddings from database
220
+ embeddings = await self.fetch_embeddings(embedding_type)
221
+ if not embeddings:
222
+ return []
223
+
224
+ # Step 2: Convert to numpy arrays
225
+ stored_vecs, query_vec = self.prepare_vectors(embeddings, embedding)
226
+
227
+ # Step 3: Compute similarities
228
+ similarities = self.compute_similarities(stored_vecs, query_vec)
229
+
230
+ # Step 4: Get top-k results
231
+ return self.get_top_k_results(similarities, embeddings, top_k)
@@ -6,13 +6,16 @@ import pydantic
6
6
  import structlog
7
7
 
8
8
  from kodit.bm25.bm25 import BM25Service
9
+ from kodit.embedding.embedding import EmbeddingService
10
+ from kodit.embedding.models import EmbeddingType
9
11
  from kodit.retreival.repository import RetrievalRepository, RetrievalResult
10
12
 
11
13
 
12
14
  class RetrievalRequest(pydantic.BaseModel):
13
15
  """Request for a retrieval."""
14
16
 
15
- keywords: list[str]
17
+ code_query: str | None = None
18
+ keywords: list[str] | None = None
16
19
  top_k: int = 10
17
20
 
18
21
 
@@ -26,44 +29,96 @@ class Snippet(pydantic.BaseModel):
26
29
  class RetrievalService:
27
30
  """Service for retrieving relevant data."""
28
31
 
29
- def __init__(self, repository: RetrievalRepository, data_dir: Path) -> None:
32
+ def __init__(
33
+ self,
34
+ repository: RetrievalRepository,
35
+ data_dir: Path,
36
+ embedding_model_name: str,
37
+ ) -> None:
30
38
  """Initialize the retrieval service."""
31
39
  self.repository = repository
32
40
  self.log = structlog.get_logger(__name__)
33
41
  self.bm25 = BM25Service(data_dir)
34
-
35
- async def _load_bm25_index(self) -> None:
36
- """Load the BM25 index."""
42
+ self.code_embedding_service = EmbeddingService(model_name=embedding_model_name)
37
43
 
38
44
  async def retrieve(self, request: RetrievalRequest) -> list[RetrievalResult]:
39
45
  """Retrieve relevant data."""
40
- snippet_ids = await self.repository.list_snippet_ids()
46
+ fusion_list = []
47
+ if request.keywords:
48
+ snippet_ids = await self.repository.list_snippet_ids()
49
+
50
+ # Gather results for each keyword
51
+ result_ids: list[tuple[int, float]] = []
52
+ for keyword in request.keywords:
53
+ results = self.bm25.retrieve(snippet_ids, keyword, request.top_k)
54
+ result_ids.extend(results)
55
+
56
+ # Sort results by score
57
+ result_ids.sort(key=lambda x: x[1], reverse=True)
58
+
59
+ self.log.debug("Retrieval results (BM25)", results=result_ids)
60
+
61
+ bm25_results = [x[0] for x in result_ids]
62
+ fusion_list.append(bm25_results)
63
+
64
+ # Compute embedding for semantic query
65
+ semantic_results = []
66
+ if request.code_query:
67
+ query_embedding = next(
68
+ self.code_embedding_service.query([request.code_query])
69
+ )
70
+
71
+ query_results = await self.repository.list_semantic_results(
72
+ EmbeddingType.CODE, query_embedding, top_k=request.top_k
73
+ )
74
+
75
+ # Sort results by score
76
+ query_results.sort(key=lambda x: x[1], reverse=True)
41
77
 
42
- # Gather results for each keyword
43
- result_ids: list[tuple[int, float]] = []
44
- for keyword in request.keywords:
45
- results = self.bm25.retrieve(snippet_ids, keyword, request.top_k)
46
- result_ids.extend(results)
78
+ # Extract the snippet ids from the query results
79
+ semantic_results = [x[0] for x in query_results]
80
+ fusion_list.append(semantic_results)
47
81
 
48
- if len(result_ids) == 0:
82
+ if len(fusion_list) == 0:
49
83
  return []
50
84
 
51
- # Sort results by score
52
- result_ids.sort(key=lambda x: x[1], reverse=True)
85
+ # Combine all results together with RFF if required
86
+ final_results = reciprocal_rank_fusion(fusion_list, k=60)
87
+
88
+ # Extract ids from final results
89
+ final_ids = [x[0] for x in final_results]
90
+
91
+ # Get snippets from database (up to top_k)
92
+ return await self.repository.list_snippets_by_ids(final_ids[: request.top_k])
93
+
94
+
95
+ def reciprocal_rank_fusion(
96
+ rankings: list[list[int]], k: float = 60
97
+ ) -> list[tuple[int, float]]:
98
+ """RRF prioritises results that are present in all results.
99
+
100
+ Args:
101
+ rankings: List of rankers, each containing a list of document ids. Top of the
102
+ list is considered to be the best result.
103
+ k: Parameter for RRF.
104
+
105
+ Returns:
106
+ Dictionary of ids and their scores.
107
+
108
+ """
109
+ scores = {}
110
+ for ranker in rankings:
111
+ for rank in ranker:
112
+ scores[rank] = float(0)
53
113
 
54
- self.log.debug(
55
- "Retrieval results",
56
- total_results=len(result_ids),
57
- max_score=result_ids[0][1],
58
- min_score=result_ids[-1][1],
59
- median_score=result_ids[len(result_ids) // 2][1],
60
- )
114
+ for ranker in rankings:
115
+ for i, rank in enumerate(ranker):
116
+ scores[rank] += 1.0 / (k + i)
61
117
 
62
- # Don't return zero score results
63
- result_ids = [x for x in result_ids if x[1] > 0]
118
+ # Create a list of tuples of ids and their scores
119
+ results = [(rank, scores[rank]) for rank in scores]
64
120
 
65
- # Build final list of doc ids up to top_k
66
- final_doc_ids = [x[0] for x in result_ids[: request.top_k]]
121
+ # Sort results by score
122
+ results.sort(key=lambda x: x[1], reverse=True)
67
123
 
68
- # Get snippets from database
69
- return await self.repository.list_snippets_by_ids(final_doc_ids)
124
+ return results
kodit/sources/service.py CHANGED
@@ -165,7 +165,7 @@ class SourceService:
165
165
  file_count = sum(1 for _ in clone_path.rglob("*") if _.is_file())
166
166
 
167
167
  # Process each file in the source directory
168
- for path in tqdm(clone_path.rglob("*"), total=file_count):
168
+ for path in tqdm(clone_path.rglob("*"), total=file_count, leave=False):
169
169
  await self._process_file(source.id, path.absolute())
170
170
 
171
171
  return SourceView(
@@ -212,7 +212,7 @@ class SourceService:
212
212
  file_count = sum(1 for _ in clone_path.rglob("*") if _.is_file())
213
213
 
214
214
  # Process each file in the source directory
215
- for path in tqdm(clone_path.rglob("*"), total=file_count):
215
+ for path in tqdm(clone_path.rglob("*"), total=file_count, leave=False):
216
216
  await self._process_file(source.id, path.absolute())
217
217
 
218
218
  return SourceView(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kodit
3
- Version: 0.1.9
3
+ Version: 0.1.11
4
4
  Summary: Code indexing for better AI code generation
5
5
  Project-URL: Homepage, https://docs.helixml.tech/kodit/
6
6
  Project-URL: Documentation, https://docs.helixml.tech/kodit/
@@ -29,11 +29,13 @@ Requires-Dist: dotenv>=0.9.9
29
29
  Requires-Dist: fastapi[standard]>=0.115.12
30
30
  Requires-Dist: fastmcp>=2.3.3
31
31
  Requires-Dist: gitpython>=3.1.44
32
+ Requires-Dist: hf-xet>=1.1.2
32
33
  Requires-Dist: httpx-retries>=0.3.2
33
34
  Requires-Dist: httpx>=0.28.1
34
35
  Requires-Dist: posthog>=4.0.1
35
36
  Requires-Dist: pydantic-settings>=2.9.1
36
37
  Requires-Dist: pytable-formatter>=0.1.1
38
+ Requires-Dist: sentence-transformers>=4.1.0
37
39
  Requires-Dist: sqlalchemy[asyncio]>=2.0.40
38
40
  Requires-Dist: structlog>=25.3.0
39
41
  Requires-Dist: tdqm>=0.0.1
@@ -0,0 +1,44 @@
1
+ kodit/.gitignore,sha256=ztkjgRwL9Uud1OEi36hGQeDGk3OLK1NfDEO8YqGYy8o,11
2
+ kodit/__init__.py,sha256=aEKHYninUq1yh6jaNfvJBYg-6fenpN132nJt1UU6Jxs,59
3
+ kodit/_version.py,sha256=xfwL5IZGNNwnNDAQtGFjpvlNxqYn3U9IM9B98Du9pJw,513
4
+ kodit/app.py,sha256=Mr5BFHOHx5zppwjC4XPWVvHjwgl1yrKbUjTWXKubJQM,891
5
+ kodit/cli.py,sha256=qEQy_Sd64cEV5KzYsKlGLyMxFQ4fFi-as4QO8CRrKYo,8978
6
+ kodit/config.py,sha256=hQshTMW_8jpk94zP-1JaxowgmW_LrT534ipHFaRUGMw,3006
7
+ kodit/database.py,sha256=kekSdyEATdb47jxzQemkSOXMNOwnUwmVVTpn9hYaDK8,2356
8
+ kodit/log.py,sha256=PhyzQktEyyHaNr78W0wmL-RSRuq311DQ-d0l-EKTGmQ,5417
9
+ kodit/mcp.py,sha256=qp16vRb0TY46-xQy179iWgYebr6Ju_Z91ZSzZnWPHuk,4771
10
+ kodit/middleware.py,sha256=I6FOkqG9-8RH5kR1-0ZoQWfE4qLCB8lZYv8H_OCH29o,2714
11
+ kodit/bm25/__init__.py,sha256=j8zyriNWhbwE5Lbybzg1hQAhANlU9mKHWw4beeUR6og,19
12
+ kodit/bm25/bm25.py,sha256=NtlcLrgqJja11qDGKz_U6tuYWaS9sfbyS-TcA__rBKs,2284
13
+ kodit/embedding/__init__.py,sha256=h9NXzDA1r-K23nvBajBV-RJzHJN0p3UJ7UQsmdnOoRw,24
14
+ kodit/embedding/embedding.py,sha256=X2Fa-eXhQwp__QFj9yxIhvlCAiYVQSaZ2y18ZtG5_1Y,1810
15
+ kodit/embedding/models.py,sha256=rN90vSs86dYiqoawcp8E9jtwY31JoJXYfaDlsJK7uqc,656
16
+ kodit/indexing/__init__.py,sha256=cPyi2Iej3G1JFWlWr7X80_UrsMaTu5W5rBwgif1B3xo,75
17
+ kodit/indexing/models.py,sha256=sZIhGwvL4Dw0QTWFxrjfWctSLkAoDT6fv5DlGz8-Fr8,1258
18
+ kodit/indexing/repository.py,sha256=eIaIbqNs9Z3XTVymZ5Zl5uPWveqiEXNo0JTa-y-Tl24,5430
19
+ kodit/indexing/service.py,sha256=hhQ_6vI7J7LnNgOLbsO4B07TOJvEePqqFviiqr3TL_M,6579
20
+ kodit/migrations/README,sha256=ISVtAOvqvKk_5ThM5ioJE-lMkvf9IbknFUFVU_vPma4,58
21
+ kodit/migrations/__init__.py,sha256=lP5MuwlyWRMO6UcDWnQcQ3G-GYHcFb6rl9gYPHJ1sjo,40
22
+ kodit/migrations/env.py,sha256=bzB6vod_tO-X2F_G671FwYSAn0pyhNw8M1kG4MgidO8,2444
23
+ kodit/migrations/script.py.mako,sha256=zWziKtiwYKEWuwPV_HBNHwa9LCT45_bi01-uSNFaOOE,703
24
+ kodit/migrations/versions/7c3bbc2ab32b_add_embeddings_table.py,sha256=-61qol9PfQKILCDQRA5jEaats9aGZs9Wdtp-j-38SF4,1644
25
+ kodit/migrations/versions/85155663351e_initial.py,sha256=Cg7zlF871o9ShV5rQMQ1v7hRV7fI59veDY9cjtTrs-8,3306
26
+ kodit/migrations/versions/__init__.py,sha256=9-lHzptItTzq_fomdIRBegQNm4Znx6pVjwD4MiqRIdo,36
27
+ kodit/retreival/__init__.py,sha256=33PhJU-3gtsqYq6A1UkaLNKbev_Zee9Lq6dYC59-CsA,69
28
+ kodit/retreival/repository.py,sha256=XHkkeUsnXSrrcthJOL9FXgivn5kkaPnC9Qci6ebwjZc,7294
29
+ kodit/retreival/service.py,sha256=gGp74jnqhyCDF5vKOrN2dJKDnhlfR4HZaxADSrjTb4s,3778
30
+ kodit/snippets/__init__.py,sha256=-2coNoCRjTixU9KcP6alpmt7zqf37tCRWH3D7FPJ8dg,48
31
+ kodit/snippets/method_snippets.py,sha256=EVHhSNWahAC5nSXv9fWVFJY2yq25goHdCSCuENC07F8,4145
32
+ kodit/snippets/snippets.py,sha256=QumvhltWoxXw41SyKb-RbSvAr3m6V3lUy9n0AI8jcto,1409
33
+ kodit/snippets/languages/__init__.py,sha256=Bj5KKZSls2MQ8ZY1S_nHg447MgGZW-2WZM-oq6vjwwA,1187
34
+ kodit/snippets/languages/csharp.scm,sha256=gbBN4RiV1FBuTJF6orSnDFi8H9JwTw-d4piLJYsWUsc,222
35
+ kodit/snippets/languages/python.scm,sha256=ee85R9PBzwye3IMTE7-iVoKWd_ViU3EJISTyrFGrVeo,429
36
+ kodit/sources/__init__.py,sha256=1NTZyPdjThVQpZO1Mp1ColVsS7sqYanOVLqnoqV9Ipo,83
37
+ kodit/sources/models.py,sha256=xb42CaNDO1CUB8SIW-xXMrB6Ji8cFw-yeJ550xBEg9Q,2398
38
+ kodit/sources/repository.py,sha256=mGJrHWH6Uo8YABdoojHFbzaf_jW-2ywJpAHIa1gnc3U,3401
39
+ kodit/sources/service.py,sha256=aV_qiqkU2kMBNPvye5_v4NnZiK-lJ64rQdmFtBtsQaY,9243
40
+ kodit-0.1.11.dist-info/METADATA,sha256=yUO645VYUiVrJMRtwNB71O-6qvC94nS7_ILQ8eQEvoY,2288
41
+ kodit-0.1.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
42
+ kodit-0.1.11.dist-info/entry_points.txt,sha256=hoTn-1aKyTItjnY91fnO-rV5uaWQLQ-Vi7V5et2IbHY,40
43
+ kodit-0.1.11.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
44
+ kodit-0.1.11.dist-info/RECORD,,
@@ -1,40 +0,0 @@
1
- kodit/.gitignore,sha256=ztkjgRwL9Uud1OEi36hGQeDGk3OLK1NfDEO8YqGYy8o,11
2
- kodit/__init__.py,sha256=aEKHYninUq1yh6jaNfvJBYg-6fenpN132nJt1UU6Jxs,59
3
- kodit/_version.py,sha256=bhntibG3PKk5Ai3XlSNEV8gj-ffItuKloY6vzWn6swo,511
4
- kodit/app.py,sha256=Mr5BFHOHx5zppwjC4XPWVvHjwgl1yrKbUjTWXKubJQM,891
5
- kodit/cli.py,sha256=bsfURvGKZzpHkChnTlatI0nXHV3KV_6vJnUJ2fQEAfM,6637
6
- kodit/config.py,sha256=nlm9U-nVx5riH2SrU1XY4XcCMhQK4DrwO_1H8bPOBjA,2927
7
- kodit/database.py,sha256=vtTlmrXHyHJH3Ek-twZTCqEjB0jun-NncALFze2fqhA,2350
8
- kodit/logging.py,sha256=cFEQXWI27LzWScSxly9ApwkbBDamUG17pA-jEfVakXQ,5316
9
- kodit/mcp.py,sha256=PxTHVPlIErrruFKzmEPIWZjN6cfEhcQmj6nOU9EsBy4,4905
10
- kodit/middleware.py,sha256=I6FOkqG9-8RH5kR1-0ZoQWfE4qLCB8lZYv8H_OCH29o,2714
11
- kodit/alembic/README,sha256=ISVtAOvqvKk_5ThM5ioJE-lMkvf9IbknFUFVU_vPma4,58
12
- kodit/alembic/__init__.py,sha256=lP5MuwlyWRMO6UcDWnQcQ3G-GYHcFb6rl9gYPHJ1sjo,40
13
- kodit/alembic/env.py,sha256=kcQiglu2KpNTAf37CsKVs_HXxOe6S7sXJ00pHGSCqno,2414
14
- kodit/alembic/script.py.mako,sha256=zWziKtiwYKEWuwPV_HBNHwa9LCT45_bi01-uSNFaOOE,703
15
- kodit/alembic/versions/85155663351e_initial.py,sha256=Cg7zlF871o9ShV5rQMQ1v7hRV7fI59veDY9cjtTrs-8,3306
16
- kodit/alembic/versions/__init__.py,sha256=9-lHzptItTzq_fomdIRBegQNm4Znx6pVjwD4MiqRIdo,36
17
- kodit/bm25/__init__.py,sha256=j8zyriNWhbwE5Lbybzg1hQAhANlU9mKHWw4beeUR6og,19
18
- kodit/bm25/bm25.py,sha256=3wyNRSrTaYqV7s4R1D6X0NpCf22PuFK2_uc8YapzYLE,2263
19
- kodit/indexing/__init__.py,sha256=cPyi2Iej3G1JFWlWr7X80_UrsMaTu5W5rBwgif1B3xo,75
20
- kodit/indexing/models.py,sha256=sZIhGwvL4Dw0QTWFxrjfWctSLkAoDT6fv5DlGz8-Fr8,1258
21
- kodit/indexing/repository.py,sha256=ZicLPXPKQxW6NnY_anmZ4nI1-FGkrJsqjg0NK-vvnTY,5117
22
- kodit/indexing/service.py,sha256=rLWYI70VytlJAyZtQC5Xpqtj9f3EzbivzgeM_1L9BUU,5751
23
- kodit/retreival/__init__.py,sha256=33PhJU-3gtsqYq6A1UkaLNKbev_Zee9Lq6dYC59-CsA,69
24
- kodit/retreival/repository.py,sha256=1lqGgJHsBmvMGMzEYa-hrdXg2q7rqtYPl1cvBb7jMRE,3119
25
- kodit/retreival/service.py,sha256=9wvURtPPJVvPUWNIC2waIrJMxcm1Ka1J_xDEOEedAFU,2007
26
- kodit/snippets/__init__.py,sha256=-2coNoCRjTixU9KcP6alpmt7zqf37tCRWH3D7FPJ8dg,48
27
- kodit/snippets/method_snippets.py,sha256=EVHhSNWahAC5nSXv9fWVFJY2yq25goHdCSCuENC07F8,4145
28
- kodit/snippets/snippets.py,sha256=QumvhltWoxXw41SyKb-RbSvAr3m6V3lUy9n0AI8jcto,1409
29
- kodit/snippets/languages/__init__.py,sha256=Bj5KKZSls2MQ8ZY1S_nHg447MgGZW-2WZM-oq6vjwwA,1187
30
- kodit/snippets/languages/csharp.scm,sha256=gbBN4RiV1FBuTJF6orSnDFi8H9JwTw-d4piLJYsWUsc,222
31
- kodit/snippets/languages/python.scm,sha256=ee85R9PBzwye3IMTE7-iVoKWd_ViU3EJISTyrFGrVeo,429
32
- kodit/sources/__init__.py,sha256=1NTZyPdjThVQpZO1Mp1ColVsS7sqYanOVLqnoqV9Ipo,83
33
- kodit/sources/models.py,sha256=xb42CaNDO1CUB8SIW-xXMrB6Ji8cFw-yeJ550xBEg9Q,2398
34
- kodit/sources/repository.py,sha256=mGJrHWH6Uo8YABdoojHFbzaf_jW-2ywJpAHIa1gnc3U,3401
35
- kodit/sources/service.py,sha256=hqAjGFVhvtePhMrK1Aprj__Mq2PLjVq8CsWMBoA3_Qw,9217
36
- kodit-0.1.9.dist-info/METADATA,sha256=MAqVxrLPrTV3Ihcix_3YHQNq9qyuD1OEavYHV76qli8,2214
37
- kodit-0.1.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
38
- kodit-0.1.9.dist-info/entry_points.txt,sha256=hoTn-1aKyTItjnY91fnO-rV5uaWQLQ-Vi7V5et2IbHY,40
39
- kodit-0.1.9.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
40
- kodit-0.1.9.dist-info/RECORD,,
File without changes
File without changes
File without changes
File without changes
File without changes