kodit-0.1.3-py3-none-any.whl → kodit-0.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic. Click here for more details.

kodit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.1.3'
21
- __version_tuple__ = version_tuple = (0, 1, 3)
20
+ __version__ = version = '0.1.5'
21
+ __version_tuple__ = version_tuple = (0, 1, 5)
kodit/alembic/env.py CHANGED
@@ -66,8 +66,6 @@ async def run_async_migrations() -> None:
66
66
  prefix="sqlalchemy.",
67
67
  poolclass=pool.NullPool,
68
68
  )
69
- log = structlog.get_logger(__name__)
70
- log.debug("Running migrations on %s", connectable.url)
71
69
 
72
70
  async with connectable.connect() as connection:
73
71
  await connection.run_sync(do_run_migrations)
kodit/app.py CHANGED
@@ -5,14 +5,10 @@ from fastapi import FastAPI
5
5
 
6
6
  from kodit.mcp import mcp
7
7
  from kodit.middleware import logging_middleware
8
- from kodit.sse import create_sse_server
9
8
 
10
- app = FastAPI(title="kodit API")
11
-
12
- # Get the SSE routes from the Starlette app hosting the MCP server
13
- sse_app = create_sse_server(mcp)
14
- for route in sse_app.routes:
15
- app.router.routes.append(route)
9
+ # See https://gofastmcp.com/deployment/asgi#fastapi-integration
10
+ mcp_app = mcp.sse_app()
11
+ app = FastAPI(title="kodit API", lifespan=mcp_app.router.lifespan_context)
16
12
 
17
13
  # Add middleware
18
14
  app.middleware("http")(logging_middleware)
@@ -22,4 +18,8 @@ app.add_middleware(CorrelationIdMiddleware)
22
18
  @app.get("/")
23
19
  async def root() -> dict[str, str]:
24
20
  """Return a welcome message for the kodit API."""
25
- return {"message": "Welcome to kodit API"}
21
+ return {"message": "Hello, World!"}
22
+
23
+
24
+ # Add mcp routes last, otherwise previous routes aren't added
25
+ app.mount("", mcp_app)
kodit/bm25/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """BM25 module."""
kodit/bm25/bm25.py ADDED
@@ -0,0 +1,71 @@
1
+ """BM25 service."""
2
+
3
+ import bm25s
4
+ import Stemmer
5
+ import structlog
6
+ from bm25s.tokenization import Tokenized
7
+
8
+ from kodit.config import Config
9
+
10
+
11
+ class BM25Service:
12
+ """Service for BM25."""
13
+
14
+ def __init__(self, config: Config) -> None:
15
+ """Initialize the BM25 service."""
16
+ self.log = structlog.get_logger(__name__)
17
+ self.index_path = config.get_data_dir() / "bm25s_index"
18
+ try:
19
+ self.log.debug("Loading BM25 index")
20
+ self.retriever = bm25s.BM25.load(self.index_path, mmap=True)
21
+ except FileNotFoundError:
22
+ self.log.debug("BM25 index not found, creating new index")
23
+ self.retriever = bm25s.BM25()
24
+
25
+ self.stemmer = Stemmer.Stemmer("english")
26
+
27
+ def _tokenize(self, corpus: list[str]) -> list[list[str]] | Tokenized:
28
+ return bm25s.tokenize(
29
+ corpus,
30
+ stopwords="en",
31
+ stemmer=self.stemmer,
32
+ return_ids=False,
33
+ show_progress=True,
34
+ )
35
+
36
+ def index(self, corpus: list[str]) -> None:
37
+ """Index a new corpus."""
38
+ self.log.debug("Indexing corpus")
39
+ vocab = self._tokenize(corpus)
40
+ self.retriever = bm25s.BM25()
41
+ self.retriever.index(vocab)
42
+ self.retriever.save(self.index_path)
43
+
44
+ def retrieve(
45
+ self, doc_ids: list[int], query: str, top_k: int = 2
46
+ ) -> list[tuple[int, float]]:
47
+ """Retrieve from the index."""
48
+ if top_k == 0:
49
+ self.log.warning("Top k is 0, returning empty list")
50
+ return []
51
+ if len(doc_ids) == 0:
52
+ self.log.warning("No documents to retrieve from, returning empty list")
53
+ return []
54
+
55
+ top_k = min(top_k, len(doc_ids))
56
+ self.log.debug(
57
+ "Retrieving from index", query=query, top_k=top_k, num_docs=len(doc_ids)
58
+ )
59
+
60
+ query_tokens = self._tokenize([query])
61
+
62
+ self.log.debug("Query tokens", query_tokens=query_tokens)
63
+
64
+ results, scores = self.retriever.retrieve(
65
+ query_tokens=query_tokens, corpus=doc_ids, k=top_k
66
+ )
67
+ self.log.debug("Raw results", results=results, scores=scores)
68
+ return [
69
+ (int(result), float(score))
70
+ for result, score in zip(results[0], scores[0], strict=False)
71
+ ]
kodit/cli.py CHANGED
@@ -1,41 +1,74 @@
1
1
  """Command line interface for kodit."""
2
2
 
3
3
  import os
4
+ import signal
5
+ from pathlib import Path
6
+ from typing import Any
4
7
 
5
8
  import click
6
9
  import structlog
7
10
  import uvicorn
8
- from dotenv import dotenv_values
9
11
  from pytable_formatter import Table
10
12
  from sqlalchemy.ext.asyncio import AsyncSession
11
13
 
12
- from kodit.database import configure_database, with_session
14
+ from kodit.config import (
15
+ DEFAULT_BASE_DIR,
16
+ DEFAULT_DB_URL,
17
+ DEFAULT_DISABLE_TELEMETRY,
18
+ DEFAULT_LOG_FORMAT,
19
+ DEFAULT_LOG_LEVEL,
20
+ get_config,
21
+ reset_config,
22
+ with_session,
23
+ )
13
24
  from kodit.indexing.repository import IndexRepository
14
25
  from kodit.indexing.service import IndexService
15
- from kodit.logging import LogFormat, configure_logging, disable_posthog, log_event
26
+ from kodit.logging import configure_logging, configure_telemetry, log_event
16
27
  from kodit.retreival.repository import RetrievalRepository
17
28
  from kodit.retreival.service import RetrievalRequest, RetrievalService
18
29
  from kodit.sources.repository import SourceRepository
19
30
  from kodit.sources.service import SourceService
20
31
 
21
- env_vars = dict(dotenv_values())
22
- os.environ.update(env_vars)
23
32
 
24
-
25
- @click.group(context_settings={"auto_envvar_prefix": "KODIT", "show_default": True})
26
- @click.option("--log-level", default="INFO", help="Log level")
27
- @click.option("--log-format", default=LogFormat.PRETTY, help="Log format")
28
- @click.option("--disable-telemetry", is_flag=True, help="Disable telemetry")
29
- def cli(
30
- log_level: str,
31
- log_format: LogFormat,
32
- disable_telemetry: bool, # noqa: FBT001
33
+ @click.group(context_settings={"max_content_width": 100})
34
+ @click.option("--log-level", help=f"Log level [default: {DEFAULT_LOG_LEVEL}]")
35
+ @click.option("--log-format", help=f"Log format [default: {DEFAULT_LOG_FORMAT}]")
36
+ @click.option(
37
+ "--disable-telemetry",
38
+ is_flag=True,
39
+ help=f"Disable telemetry [default: {DEFAULT_DISABLE_TELEMETRY}]",
40
+ )
41
+ @click.option("--db-url", help=f"Database URL [default: {DEFAULT_DB_URL}]")
42
+ @click.option("--data-dir", help=f"Data directory [default: {DEFAULT_BASE_DIR}]")
43
+ @click.option("--env-file", help="Path to a .env file [default: .env]")
44
+ def cli( # noqa: PLR0913
45
+ log_level: str | None,
46
+ log_format: str | None,
47
+ disable_telemetry: bool | None,
48
+ db_url: str | None,
49
+ data_dir: str | None,
50
+ env_file: str | None,
33
51
  ) -> None:
34
52
  """kodit CLI - Code indexing for better AI code generation.""" # noqa: D403
35
- configure_logging(log_level, log_format)
53
+ # First check if env-file is set and reload config if it is
54
+ if env_file:
55
+ reset_config()
56
+ get_config(env_file)
57
+
58
+ # Override global config with cli args, if set
59
+ config = get_config()
60
+ if data_dir:
61
+ config.data_dir = Path(data_dir)
62
+ if db_url:
63
+ config.db_url = db_url
64
+ if log_level:
65
+ config.log_level = log_level
66
+ if log_format:
67
+ config.log_format = log_format
36
68
  if disable_telemetry:
37
- disable_posthog()
38
- configure_database()
69
+ config.disable_telemetry = disable_telemetry
70
+ configure_logging(config)
71
+ configure_telemetry(config)
39
72
 
40
73
 
41
74
  @cli.group()
@@ -48,7 +81,7 @@ def sources() -> None:
48
81
  async def list_sources(session: AsyncSession) -> None:
49
82
  """List all code sources."""
50
83
  repository = SourceRepository(session)
51
- service = SourceService(repository)
84
+ service = SourceService(get_config().get_clone_dir(), repository)
52
85
  sources = await service.list_sources()
53
86
 
54
87
  # Define headers and data
@@ -66,7 +99,7 @@ async def list_sources(session: AsyncSession) -> None:
66
99
  async def create_source(session: AsyncSession, uri: str) -> None:
67
100
  """Add a new code source."""
68
101
  repository = SourceRepository(session)
69
- service = SourceService(repository)
102
+ service = SourceService(get_config().get_clone_dir(), repository)
70
103
  source = await service.create(uri)
71
104
  click.echo(f"Source created: {source.id}")
72
105
 
@@ -82,9 +115,9 @@ def indexes() -> None:
82
115
  async def create_index(session: AsyncSession, source_id: int) -> None:
83
116
  """Create an index for a source."""
84
117
  source_repository = SourceRepository(session)
85
- source_service = SourceService(source_repository)
118
+ source_service = SourceService(get_config().get_clone_dir(), source_repository)
86
119
  repository = IndexRepository(session)
87
- service = IndexService(repository, source_service)
120
+ service = IndexService(get_config(), repository, source_service)
88
121
  index = await service.create(source_id)
89
122
  click.echo(f"Index created: {index.id}")
90
123
 
@@ -94,9 +127,9 @@ async def create_index(session: AsyncSession, source_id: int) -> None:
94
127
  async def list_indexes(session: AsyncSession) -> None:
95
128
  """List all indexes."""
96
129
  source_repository = SourceRepository(session)
97
- source_service = SourceService(source_repository)
130
+ source_service = SourceService(get_config().get_clone_dir(), source_repository)
98
131
  repository = IndexRepository(session)
99
- service = IndexService(repository, source_service)
132
+ service = IndexService(get_config(), repository, source_service)
100
133
  indexes = await service.list_indexes()
101
134
 
102
135
  # Define headers and data
@@ -104,7 +137,6 @@ async def list_indexes(session: AsyncSession) -> None:
104
137
  "ID",
105
138
  "Created At",
106
139
  "Updated At",
107
- "Source URI",
108
140
  "Num Snippets",
109
141
  ]
110
142
  data = [
@@ -112,7 +144,6 @@ async def list_indexes(session: AsyncSession) -> None:
112
144
  index.id,
113
145
  index.created_at,
114
146
  index.updated_at,
115
- index.source_uri,
116
147
  index.num_snippets,
117
148
  ]
118
149
  for index in indexes
@@ -129,48 +160,69 @@ async def list_indexes(session: AsyncSession) -> None:
129
160
  async def run_index(session: AsyncSession, index_id: int) -> None:
130
161
  """Run an index."""
131
162
  source_repository = SourceRepository(session)
132
- source_service = SourceService(source_repository)
163
+ source_service = SourceService(get_config().get_clone_dir(), source_repository)
133
164
  repository = IndexRepository(session)
134
- service = IndexService(repository, source_service)
165
+ service = IndexService(get_config(), repository, source_service)
135
166
  await service.run(index_id)
136
167
 
137
168
 
138
169
  @cli.command()
139
170
  @click.argument("query")
171
+ @click.option("--top-k", default=10, help="Number of snippets to retrieve")
140
172
  @with_session
141
- async def retrieve(session: AsyncSession, query: str) -> None:
173
+ async def retrieve(session: AsyncSession, query: str, top_k: int) -> None:
142
174
  """Retrieve snippets from the database."""
143
175
  repository = RetrievalRepository(session)
144
- service = RetrievalService(repository)
145
- snippets = await service.retrieve(RetrievalRequest(query=query))
176
+ service = RetrievalService(get_config(), repository)
177
+ # Temporary request while we don't have all search capabilities
178
+ snippets = await service.retrieve(
179
+ RetrievalRequest(keywords=query.split(","), top_k=top_k)
180
+ )
181
+
182
+ if len(snippets) == 0:
183
+ click.echo("No snippets found")
184
+ return
146
185
 
147
186
  for snippet in snippets:
187
+ click.echo("-" * 80)
148
188
  click.echo(f"{snippet.uri}")
149
189
  click.echo(snippet.content)
190
+ click.echo("-" * 80)
150
191
  click.echo()
151
192
 
152
193
 
153
194
  @cli.command()
154
195
  @click.option("--host", default="127.0.0.1", help="Host to bind the server to")
155
196
  @click.option("--port", default=8080, help="Port to bind the server to")
156
- @click.option("--reload", is_flag=True, help="Enable auto-reload for development")
157
197
  def serve(
158
198
  host: str,
159
199
  port: int,
160
- reload: bool, # noqa: FBT001
161
200
  ) -> None:
162
201
  """Start the kodit server, which hosts the MCP server and the kodit API."""
163
202
  log = structlog.get_logger(__name__)
164
- log.info("Starting kodit server", host=host, port=port, reload=reload)
203
+ log.info("Starting kodit server", host=host, port=port)
165
204
  log_event("kodit_server_started")
166
- uvicorn.run(
205
+ os.environ["HELLO"] = "WORLD"
206
+
207
+ # Configure uvicorn with graceful shutdown
208
+ config = uvicorn.Config(
167
209
  "kodit.app:app",
168
210
  host=host,
169
211
  port=port,
170
- reload=reload,
212
+ reload=False,
171
213
  log_config=None, # Setting to None forces uvicorn to use our structlog setup
172
214
  access_log=False, # Using own middleware for access logging
215
+ timeout_graceful_shutdown=0, # The mcp server does not shutdown cleanly, force
173
216
  )
217
+ server = uvicorn.Server(config)
218
+
219
+ def handle_sigint(signum: int, frame: Any) -> None:
220
+ """Handle SIGINT (Ctrl+C)."""
221
+ log.info("Received shutdown signal, force killing MCP connections")
222
+ server.handle_exit(signum, frame)
223
+
224
+ signal.signal(signal.SIGINT, handle_sigint)
225
+ server.run()
174
226
 
175
227
 
176
228
  @cli.command()
kodit/config.py CHANGED
@@ -1,5 +1,89 @@
1
- """Configuration for the kodit project."""
1
+ """Global configuration for the kodit project."""
2
2
 
3
+ import asyncio
4
+ from collections.abc import Callable
5
+ from functools import wraps
3
6
  from pathlib import Path
7
+ from typing import Any, TypeVar
4
8
 
5
- DATA_DIR = Path.home() / ".kodit"
9
+ from pydantic import Field
10
+ from pydantic_settings import BaseSettings, SettingsConfigDict
11
+
12
+ from kodit.database import Database
13
+
14
+ DEFAULT_BASE_DIR = Path.home() / ".kodit"
15
+ DEFAULT_DB_URL = f"sqlite+aiosqlite:///{DEFAULT_BASE_DIR}/kodit.db"
16
+ DEFAULT_LOG_LEVEL = "INFO"
17
+ DEFAULT_LOG_FORMAT = "pretty"
18
+ DEFAULT_DISABLE_TELEMETRY = False
19
+ T = TypeVar("T")
20
+
21
+
22
class Config(BaseSettings):
    """Global configuration for the kodit project.

    Values are resolved by pydantic-settings ``BaseSettings`` (for example
    from environment variables and the dotenv file configured below), falling
    back to the module-level ``DEFAULT_*`` constants.
    """

    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

    data_dir: Path = Field(default=DEFAULT_BASE_DIR)
    # NOTE(review): DEFAULT_DB_URL is built from DEFAULT_BASE_DIR, not from
    # data_dir, so overriding data_dir alone does not relocate the database.
    db_url: str = Field(default=DEFAULT_DB_URL)
    log_level: str = Field(default=DEFAULT_LOG_LEVEL)
    log_format: str = Field(default=DEFAULT_LOG_FORMAT)
    disable_telemetry: bool = Field(default=DEFAULT_DISABLE_TELEMETRY)
    # Lazily created by get_db(); the leading underscore makes pydantic treat
    # it as a private attribute rather than a setting.
    _db: Database | None = None

    def model_post_init(self, _: Any) -> None:
        """Create the data directory as soon as settings are loaded."""
        # The default database location lives inside the default data dir.
        self.get_data_dir()

    def get_data_dir(self) -> Path:
        """Return ``data_dir``, creating it (and any parents) if missing."""
        directory = self.data_dir
        directory.mkdir(parents=True, exist_ok=True)
        return directory

    def get_clone_dir(self) -> Path:
        """Return ``<data_dir>/clones``, creating it if missing."""
        clones = self.get_data_dir() / "clones"
        clones.mkdir(parents=True, exist_ok=True)
        return clones

    def get_db(self, *, run_migrations: bool = True) -> Database:
        """Return the cached ``Database``, building it on the first call.

        ``run_migrations`` only matters for that first call; later calls hand
        back the existing instance unchanged.
        """
        if self._db is not None:
            return self._db
        self._db = Database(self.db_url, run_migrations=run_migrations)
        return self._db
55
+
56
+
57
+ # Global config instance for mcp Apps
58
+ config = None
59
+
60
+
61
def get_config(env_file: str | None = None) -> Config:
    """Return the process-wide ``Config``, creating it on first use.

    Args:
        env_file: Path to a dotenv file to load when the config is first
            created. When omitted, the ``env_file`` declared in
            ``Config.model_config`` (``.env``) is used. Ignored if the config
            already exists; call ``reset_config()`` first to reload.

    Returns:
        The shared ``Config`` instance.

    """
    global config  # noqa: PLW0603
    if config is None:
        # pydantic-settings treats an explicit ``_env_file=None`` as "load no
        # dotenv file at all", which would silently skip the default ``.env``.
        # Only pass the override when a path was actually supplied.
        config = Config() if env_file is None else Config(_env_file=env_file)
    return config
67
+
68
+
69
def reset_config() -> None:
    """Drop the cached global config so the next ``get_config`` rebuilds it."""
    global config  # noqa: PLW0603
    # Clearing the module-level cache is all that is needed; get_config()
    # constructs a fresh Config on its next call.
    config = None
73
+
74
+
75
def with_session(func: Callable[..., T]) -> Callable[..., T]:
    """Turn an async CLI command into a sync one that receives a DB session.

    The decorated coroutine function is called with an ``AsyncSession`` as its
    first positional argument, followed by the CLI arguments. The returned
    wrapper is synchronous and drives the coroutine with ``asyncio.run``.
    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> T:
        # Resolve the Database before any event loop is started; building it
        # may run migrations (get_db defaults to run_migrations=True).
        database = get_config().get_db()

        async def _invoke() -> T:
            async with database.get_session() as session:
                return await func(session, *args, **kwargs)

        return asyncio.run(_invoke())

    return wrapper
kodit/database.py CHANGED
@@ -1,14 +1,13 @@
1
1
  """Database configuration for kodit."""
2
2
 
3
- import asyncio
4
- from collections.abc import AsyncGenerator, Callable
3
+ from collections.abc import AsyncGenerator
4
+ from contextlib import asynccontextmanager
5
5
  from datetime import UTC, datetime
6
- from functools import wraps
7
6
  from pathlib import Path
8
- from typing import Any, TypeVar
9
7
 
8
+ import structlog
10
9
  from alembic import command
11
- from alembic.config import Config
10
+ from alembic.config import Config as AlembicConfig
12
11
  from sqlalchemy import DateTime
13
12
  from sqlalchemy.ext.asyncio import (
14
13
  AsyncAttrs,
@@ -19,23 +18,6 @@ from sqlalchemy.ext.asyncio import (
19
18
  from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
20
19
 
21
20
  from kodit import alembic
22
- from kodit.config import DATA_DIR
23
-
24
- # Constants
25
- DB_URL = f"sqlite+aiosqlite:///{DATA_DIR}/kodit.db"
26
-
27
- # Create data directory if it doesn't exist
28
- DATA_DIR.mkdir(exist_ok=True)
29
-
30
- # Create async engine with file-based SQLite
31
- engine = create_async_engine(DB_URL, echo=False)
32
-
33
- # Create async session factory
34
- async_session_factory = async_sessionmaker(
35
- engine,
36
- class_=AsyncSession,
37
- expire_on_commit=False,
38
- )
39
21
 
40
22
 
41
23
  class Base(AsyncAttrs, DeclarativeBase):
@@ -54,36 +36,37 @@ class CommonMixin:
54
36
  )
55
37
 
56
38
 
57
- async def get_session() -> AsyncGenerator[AsyncSession, None]:
58
- """Get a database session."""
59
- async with async_session_factory() as session:
60
- try:
61
- yield session
62
- finally:
63
- await session.close()
64
-
65
-
66
- T = TypeVar("T")
67
-
68
-
69
- def with_session(func: Callable[..., T]) -> Callable[..., T]:
70
- """Provide an async session to CLI commands."""
71
-
72
- @wraps(func)
73
- def wrapper(*args: Any, **kwargs: Any) -> T:
74
- async def _run() -> T:
75
- async with async_session_factory() as session:
76
- return await func(session, *args, **kwargs)
77
-
78
- return asyncio.run(_run())
79
-
80
- return wrapper
81
-
82
-
83
- def configure_database() -> None:
84
- """Configure the database by initializing it and running any pending migrations."""
85
- # Create Alembic configuration and run migrations
86
- alembic_cfg = Config()
87
- alembic_cfg.set_main_option("script_location", str(Path(alembic.__file__).parent))
88
- alembic_cfg.set_main_option("sqlalchemy.url", DB_URL)
89
- command.upgrade(alembic_cfg, "head")
39
class Database:
    """Async SQLAlchemy engine and session factory for kodit.

    Can optionally upgrade the schema with Alembic before the engine is
    created.
    """

    def __init__(self, db_url: str, *, run_migrations: bool = True) -> None:
        """Initialize the database.

        Args:
            db_url: SQLAlchemy URL, used both for Alembic and the async engine.
            run_migrations: When true, upgrade the schema to ``head`` first.

        """
        self.log = structlog.get_logger(__name__)
        if run_migrations:
            self._run_migrations(db_url)
        self.db_session_factory = async_sessionmaker(
            create_async_engine(db_url, echo=False),
            class_=AsyncSession,
            expire_on_commit=False,
        )

    @asynccontextmanager
    async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield a fresh ``AsyncSession`` and close it when the caller exits."""
        async with self.db_session_factory() as session:
            try:
                yield session
            finally:
                # Explicit close on every exit path, including exceptions.
                await session.close()

    def _run_migrations(self, db_url: str) -> None:
        """Upgrade the schema at ``db_url`` to the Alembic ``head`` revision."""
        # Migration scripts ship inside the kodit.alembic package.
        script_location = str(Path(alembic.__file__).parent)
        alembic_cfg = AlembicConfig()
        alembic_cfg.set_main_option("script_location", script_location)
        alembic_cfg.set_main_option("sqlalchemy.url", db_url)
        self.log.debug("Running migrations", db_url=db_url)
        command.upgrade(alembic_cfg, "head")
@@ -130,3 +130,14 @@ class IndexRepository:
130
130
  query = select(Snippet).where(Snippet.index_id == index_id)
131
131
  result = await self.session.execute(query)
132
132
  return list(result.scalars())
133
+
134
async def get_all_snippets(self) -> list[Snippet]:
    """Fetch every snippet across all indexes.

    Returns:
        All snippets, ordered by ascending ``Snippet.id``.

    """
    statement = select(Snippet).order_by(Snippet.id)
    rows = await self.session.execute(statement)
    return list(rows.scalars())
kodit/indexing/service.py CHANGED
@@ -7,24 +7,21 @@ index management.
7
7
  """
8
8
 
9
9
  from datetime import datetime
10
+ from pathlib import Path
10
11
 
11
- import aiofiles
12
12
  import pydantic
13
13
  import structlog
14
14
  from tqdm.asyncio import tqdm
15
15
 
16
+ from kodit.bm25.bm25 import BM25Service
17
+ from kodit.config import Config
16
18
  from kodit.indexing.models import Snippet
17
19
  from kodit.indexing.repository import IndexRepository
20
+ from kodit.snippets.snippets import SnippetService
18
21
  from kodit.sources.service import SourceService
19
22
 
20
- # List of MIME types that are supported for indexing and snippet creation
21
- MIME_WHITELIST = [
22
- "text/plain",
23
- "text/markdown",
24
- "text/x-python",
25
- "text/x-shellscript",
26
- "text/x-sql",
27
- ]
23
+ # List of MIME types that are blacklisted from being indexed
24
+ MIME_BLACKLIST = ["unknown/unknown"]
28
25
 
29
26
 
30
27
  class IndexView(pydantic.BaseModel):
@@ -37,7 +34,6 @@ class IndexView(pydantic.BaseModel):
37
34
  id: int
38
35
  created_at: datetime
39
36
  updated_at: datetime | None = None
40
- source_uri: str | None = None
41
37
  num_snippets: int | None = None
42
38
 
43
39
 
@@ -50,7 +46,7 @@ class IndexService:
50
46
  """
51
47
 
52
48
  def __init__(
53
- self, repository: IndexRepository, source_service: SourceService
49
+ self, config: Config, repository: IndexRepository, source_service: SourceService
54
50
  ) -> None:
55
51
  """Initialize the index service.
56
52
 
@@ -61,7 +57,9 @@ class IndexService:
61
57
  """
62
58
  self.repository = repository
63
59
  self.source_service = source_service
60
+ self.snippet_service = SnippetService()
64
61
  self.log = structlog.get_logger(__name__)
62
+ self.bm25 = BM25Service(config)
65
63
 
66
64
  async def create(self, source_id: int) -> IndexView:
67
65
  """Create a new index for a source.
@@ -120,6 +118,10 @@ class IndexService:
120
118
  # Create snippets for supported file types
121
119
  await self._create_snippets(index_id)
122
120
 
121
+ # Update BM25 index
122
+ snippets = await self.repository.get_all_snippets()
123
+ self.bm25.index([snippet.content for snippet in snippets])
124
+
123
125
  # Update index timestamp
124
126
  await self.repository.update_index_timestamp(index)
125
127
 
@@ -138,16 +140,23 @@ class IndexService:
138
140
  files = await self.repository.files_for_index(index_id)
139
141
  for file in tqdm(files, total=len(files)):
140
142
  # Skip unsupported file types
141
- if file.mime_type not in MIME_WHITELIST:
143
+ if file.mime_type in MIME_BLACKLIST:
142
144
  self.log.debug("Skipping mime type", mime_type=file.mime_type)
143
145
  continue
144
146
 
145
147
  # Create snippet from file content
146
- async with aiofiles.open(file.cloned_path, "rb") as f:
147
- content = await f.read()
148
- snippet = Snippet(
148
+ try:
149
+ snippets = self.snippet_service.snippets_for_file(
150
+ Path(file.cloned_path)
151
+ )
152
+ except ValueError as e:
153
+ self.log.debug("Skipping file", file=file.cloned_path, error=e)
154
+ continue
155
+
156
+ for snippet in snippets:
157
+ s = Snippet(
149
158
  index_id=index_id,
150
159
  file_id=file.id,
151
- content=content.decode("utf-8"),
160
+ content=snippet.text,
152
161
  )
153
- await self.repository.add_snippet(snippet)
162
+ await self.repository.add_snippet(s)