kodit-0.1.5-py3-none-any.whl → kodit-0.1.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic; see the package's registry page for more details.

kodit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.1.5'
21
- __version_tuple__ = version_tuple = (0, 1, 5)
20
+ __version__ = version = '0.1.7'
21
+ __version_tuple__ = version_tuple = (0, 1, 7)
kodit/alembic/env.py CHANGED
@@ -3,7 +3,6 @@
3
3
 
4
4
  import asyncio
5
5
 
6
- import structlog
7
6
  from alembic import context
8
7
  from sqlalchemy import pool
9
8
  from sqlalchemy.engine import Connection
@@ -75,7 +74,11 @@ async def run_async_migrations() -> None:
75
74
 
76
75
  def run_migrations_online() -> None:
77
76
  """Run migrations in 'online' mode."""
78
- asyncio.run(run_async_migrations())
77
+ connectable = config.attributes.get("connection", None)
78
+ if connectable is None:
79
+ asyncio.run(run_async_migrations())
80
+ else:
81
+ do_run_migrations(connectable)
79
82
 
80
83
 
81
84
  if context.is_offline_mode():
kodit/app.py CHANGED
@@ -4,7 +4,7 @@ from asgi_correlation_id import CorrelationIdMiddleware
4
4
  from fastapi import FastAPI
5
5
 
6
6
  from kodit.mcp import mcp
7
- from kodit.middleware import logging_middleware
7
+ from kodit.middleware import ASGICancelledErrorMiddleware, logging_middleware
8
8
 
9
9
  # See https://gofastmcp.com/deployment/asgi#fastapi-integration
10
10
  mcp_app = mcp.sse_app()
@@ -23,3 +23,7 @@ async def root() -> dict[str, str]:
23
23
 
24
24
  # Add mcp routes last, otherwise previous routes aren't added
25
25
  app.mount("", mcp_app)
26
+
27
+ # Wrap the entire app with ASGI middleware after all routes are added to suppress
28
+ # CancelledError at the ASGI level
29
+ app = ASGICancelledErrorMiddleware(app)
kodit/bm25/bm25.py CHANGED
@@ -1,20 +1,20 @@
1
1
  """BM25 service."""
2
2
 
3
+ from pathlib import Path
4
+
3
5
  import bm25s
4
6
  import Stemmer
5
7
  import structlog
6
8
  from bm25s.tokenization import Tokenized
7
9
 
8
- from kodit.config import Config
9
-
10
10
 
11
11
  class BM25Service:
12
12
  """Service for BM25."""
13
13
 
14
- def __init__(self, config: Config) -> None:
14
+ def __init__(self, data_dir: Path) -> None:
15
15
  """Initialize the BM25 service."""
16
16
  self.log = structlog.get_logger(__name__)
17
- self.index_path = config.get_data_dir() / "bm25s_index"
17
+ self.index_path = data_dir / "bm25s_index"
18
18
  try:
19
19
  self.log.debug("Loading BM25 index")
20
20
  self.retriever = bm25s.BM25.load(self.index_path, mmap=True)
kodit/cli.py CHANGED
@@ -8,7 +8,7 @@ from typing import Any
8
8
  import click
9
9
  import structlog
10
10
  import uvicorn
11
- from pytable_formatter import Table
11
+ from pytable_formatter import Cell, Table
12
12
  from sqlalchemy.ext.asyncio import AsyncSession
13
13
 
14
14
  from kodit.config import (
@@ -17,8 +17,8 @@ from kodit.config import (
17
17
  DEFAULT_DISABLE_TELEMETRY,
18
18
  DEFAULT_LOG_FORMAT,
19
19
  DEFAULT_LOG_LEVEL,
20
- get_config,
21
- reset_config,
20
+ AppContext,
21
+ with_app_context,
22
22
  with_session,
23
23
  )
24
24
  from kodit.indexing.repository import IndexRepository
@@ -40,23 +40,33 @@ from kodit.sources.service import SourceService
40
40
  )
41
41
  @click.option("--db-url", help=f"Database URL [default: {DEFAULT_DB_URL}]")
42
42
  @click.option("--data-dir", help=f"Data directory [default: {DEFAULT_BASE_DIR}]")
43
- @click.option("--env-file", help="Path to a .env file [default: .env]")
43
+ @click.option(
44
+ "--env-file",
45
+ help="Path to a .env file [default: .env]",
46
+ type=click.Path(
47
+ exists=True,
48
+ dir_okay=False,
49
+ resolve_path=True,
50
+ path_type=Path,
51
+ ),
52
+ )
53
+ @click.pass_context
44
54
  def cli( # noqa: PLR0913
55
+ ctx: click.Context,
45
56
  log_level: str | None,
46
57
  log_format: str | None,
47
58
  disable_telemetry: bool | None,
48
59
  db_url: str | None,
49
60
  data_dir: str | None,
50
- env_file: str | None,
61
+ env_file: Path | None,
51
62
  ) -> None:
52
63
  """kodit CLI - Code indexing for better AI code generation.""" # noqa: D403
64
+ config = AppContext()
53
65
  # First check if env-file is set and reload config if it is
54
66
  if env_file:
55
- reset_config()
56
- get_config(env_file)
67
+ config = AppContext(_env_file=env_file) # type: ignore[reportCallIssue]
57
68
 
58
- # Override global config with cli args, if set
59
- config = get_config()
69
+ # Now override with CLI arguments, if set
60
70
  if data_dir:
61
71
  config.data_dir = Path(data_dir)
62
72
  if db_url:
@@ -70,110 +80,76 @@ def cli( # noqa: PLR0913
70
80
  configure_logging(config)
71
81
  configure_telemetry(config)
72
82
 
73
-
74
- @cli.group()
75
- def sources() -> None:
76
- """Manage code sources."""
77
-
78
-
79
- @sources.command(name="list")
80
- @with_session
81
- async def list_sources(session: AsyncSession) -> None:
82
- """List all code sources."""
83
- repository = SourceRepository(session)
84
- service = SourceService(get_config().get_clone_dir(), repository)
85
- sources = await service.list_sources()
86
-
87
- # Define headers and data
88
- headers = ["ID", "Created At", "URI"]
89
- data = [[source.id, source.created_at, source.uri] for source in sources]
90
-
91
- # Create and display the table
92
- table = Table(headers=headers, data=data)
93
- click.echo(table)
94
-
95
-
96
- @sources.command(name="create")
97
- @click.argument("uri")
98
- @with_session
99
- async def create_source(session: AsyncSession, uri: str) -> None:
100
- """Add a new code source."""
101
- repository = SourceRepository(session)
102
- service = SourceService(get_config().get_clone_dir(), repository)
103
- source = await service.create(uri)
104
- click.echo(f"Source created: {source.id}")
105
-
106
-
107
- @cli.group()
108
- def indexes() -> None:
109
- """Manage indexes."""
110
-
111
-
112
- @indexes.command(name="create")
113
- @click.argument("source_id")
114
- @with_session
115
- async def create_index(session: AsyncSession, source_id: int) -> None:
116
- """Create an index for a source."""
117
- source_repository = SourceRepository(session)
118
- source_service = SourceService(get_config().get_clone_dir(), source_repository)
119
- repository = IndexRepository(session)
120
- service = IndexService(get_config(), repository, source_service)
121
- index = await service.create(source_id)
122
- click.echo(f"Index created: {index.id}")
83
+ # Set the app context in the click context for downstream cli
84
+ ctx.obj = config
123
85
 
124
86
 
125
- @indexes.command(name="list")
87
+ @cli.command()
88
+ @click.argument("sources", nargs=-1)
89
+ @with_app_context
126
90
  @with_session
127
- async def list_indexes(session: AsyncSession) -> None:
128
- """List all indexes."""
91
+ async def index(
92
+ session: AsyncSession,
93
+ app_context: AppContext,
94
+ sources: list[str],
95
+ ) -> None:
96
+ """List indexes, or index data sources."""
129
97
  source_repository = SourceRepository(session)
130
- source_service = SourceService(get_config().get_clone_dir(), source_repository)
98
+ source_service = SourceService(app_context.get_clone_dir(), source_repository)
131
99
  repository = IndexRepository(session)
132
- service = IndexService(get_config(), repository, source_service)
133
- indexes = await service.list_indexes()
134
-
135
- # Define headers and data
136
- headers = [
137
- "ID",
138
- "Created At",
139
- "Updated At",
140
- "Num Snippets",
141
- ]
142
- data = [
143
- [
144
- index.id,
145
- index.created_at,
146
- index.updated_at,
147
- index.num_snippets,
100
+ service = IndexService(repository, source_service, app_context.get_data_dir())
101
+
102
+ if not sources:
103
+ # No source specified, list all indexes
104
+ indexes = await service.list_indexes()
105
+ headers: list[str | Cell] = [
106
+ "ID",
107
+ "Created At",
108
+ "Updated At",
109
+ "Source",
110
+ "Num Snippets",
148
111
  ]
149
- for index in indexes
150
- ]
151
-
152
- # Create and display the table
153
- table = Table(headers=headers, data=data)
154
- click.echo(table)
155
-
156
-
157
- @indexes.command(name="run")
158
- @click.argument("index_id")
159
- @with_session
160
- async def run_index(session: AsyncSession, index_id: int) -> None:
161
- """Run an index."""
162
- source_repository = SourceRepository(session)
163
- source_service = SourceService(get_config().get_clone_dir(), source_repository)
164
- repository = IndexRepository(session)
165
- service = IndexService(get_config(), repository, source_service)
166
- await service.run(index_id)
112
+ data = [
113
+ [
114
+ index.id,
115
+ index.created_at,
116
+ index.updated_at,
117
+ index.source,
118
+ index.num_snippets,
119
+ ]
120
+ for index in indexes
121
+ ]
122
+ click.echo(Table(headers=headers, data=data))
123
+ return
124
+ # Handle source indexing
125
+ for source in sources:
126
+ if source.startswith("https://"):
127
+ msg = "Web or git indexing is not implemented yet"
128
+ raise click.UsageError(msg)
129
+ if source.startswith("git"):
130
+ msg = "Git indexing is not implemented yet"
131
+ raise click.UsageError(msg)
132
+ if Path(source).is_file():
133
+ msg = "File indexing is not implemented yet"
134
+ raise click.UsageError(msg)
135
+
136
+ # Index directory
137
+ s = await source_service.create(source)
138
+ index = await service.create(s.id)
139
+ await service.run(index.id)
167
140
 
168
141
 
169
142
  @cli.command()
170
143
  @click.argument("query")
171
144
  @click.option("--top-k", default=10, help="Number of snippets to retrieve")
145
+ @with_app_context
172
146
  @with_session
173
- async def retrieve(session: AsyncSession, query: str, top_k: int) -> None:
147
+ async def retrieve(
148
+ session: AsyncSession, app_context: AppContext, query: str, top_k: int
149
+ ) -> None:
174
150
  """Retrieve snippets from the database."""
175
151
  repository = RetrievalRepository(session)
176
- service = RetrievalService(get_config(), repository)
152
+ service = RetrievalService(repository, app_context.get_data_dir())
177
153
  # Temporary request while we don't have all search capabilities
178
154
  snippets = await service.retrieve(
179
155
  RetrievalRequest(keywords=query.split(","), top_k=top_k)
@@ -194,7 +170,9 @@ async def retrieve(session: AsyncSession, query: str, top_k: int) -> None:
194
170
  @cli.command()
195
171
  @click.option("--host", default="127.0.0.1", help="Host to bind the server to")
196
172
  @click.option("--port", default=8080, help="Port to bind the server to")
173
+ @with_app_context
197
174
  def serve(
175
+ app_context: AppContext,
198
176
  host: str,
199
177
  port: int,
200
178
  ) -> None:
@@ -202,7 +180,10 @@ def serve(
202
180
  log = structlog.get_logger(__name__)
203
181
  log.info("Starting kodit server", host=host, port=port)
204
182
  log_event("kodit_server_started")
205
- os.environ["HELLO"] = "WORLD"
183
+
184
+ # Dump AppContext to a dictionary of strings, and set the env vars
185
+ app_context_dict = {k: str(v) for k, v in app_context.model_dump().items()}
186
+ os.environ.update(app_context_dict)
206
187
 
207
188
  # Configure uvicorn with graceful shutdown
208
189
  config = uvicorn.Config(
kodit/config.py CHANGED
@@ -1,11 +1,12 @@
1
1
  """Global configuration for the kodit project."""
2
2
 
3
3
  import asyncio
4
- from collections.abc import Callable
4
+ from collections.abc import Callable, Coroutine
5
5
  from functools import wraps
6
6
  from pathlib import Path
7
7
  from typing import Any, TypeVar
8
8
 
9
+ import click
9
10
  from pydantic import Field
10
11
  from pydantic_settings import BaseSettings, SettingsConfigDict
11
12
 
@@ -19,8 +20,8 @@ DEFAULT_DISABLE_TELEMETRY = False
19
20
  T = TypeVar("T")
20
21
 
21
22
 
22
- class Config(BaseSettings):
23
- """Global configuration for the kodit project."""
23
+ class AppContext(BaseSettings):
24
+ """Global context for the kodit project. Provides a shared state for the app."""
24
25
 
25
26
  model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
26
27
 
@@ -47,43 +48,50 @@ class Config(BaseSettings):
47
48
  clone_dir.mkdir(parents=True, exist_ok=True)
48
49
  return clone_dir
49
50
 
50
- def get_db(self, *, run_migrations: bool = True) -> Database:
51
+ async def get_db(self, *, run_migrations: bool = True) -> Database:
51
52
  """Get the database."""
52
53
  if self._db is None:
53
- self._db = Database(self.db_url, run_migrations=run_migrations)
54
+ self._db = Database(self.db_url)
55
+ if run_migrations:
56
+ await self._db.run_migrations(self.db_url)
54
57
  return self._db
55
58
 
56
59
 
57
- # Global config instance for mcp Apps
58
- config = None
60
+ with_app_context = click.make_pass_decorator(AppContext)
59
61
 
62
+ T = TypeVar("T")
60
63
 
61
- def get_config(env_file: str | None = None) -> Config:
62
- """Get the global config instance."""
63
- global config # noqa: PLW0603
64
- if config is None:
65
- config = Config(_env_file=env_file)
66
- return config
67
64
 
65
+ def wrap_async(f: Callable[..., Coroutine[Any, Any, T]]) -> Callable[..., T]:
66
+ """Decorate async Click commands.
68
67
 
69
- def reset_config() -> None:
70
- """Reset the global config instance."""
71
- global config # noqa: PLW0603
72
- config = None
68
+ This decorator wraps an async function to run it with asyncio.run().
69
+ It should be used after the Click command decorator.
73
70
 
71
+ Example:
72
+ @cli.command()
73
+ @wrap_async
74
+ async def my_command():
75
+ ...
74
76
 
75
- def with_session(func: Callable[..., T]) -> Callable[..., T]:
76
- """Provide an async session to CLI commands."""
77
+ """
77
78
 
78
- @wraps(func)
79
+ @wraps(f)
79
80
  def wrapper(*args: Any, **kwargs: Any) -> T:
80
- # Create DB connection before starting event loop
81
- db = get_config().get_db()
81
+ return asyncio.run(f(*args, **kwargs))
82
+
83
+ return wrapper
84
+
82
85
 
83
- async def _run() -> T:
84
- async with db.get_session() as session:
85
- return await func(session, *args, **kwargs)
86
+ def with_session(f: Callable[..., Coroutine[Any, Any, T]]) -> Callable[..., T]:
87
+ """Provide a database session to CLI commands."""
86
88
 
87
- return asyncio.run(_run())
89
+ @wraps(f)
90
+ @with_app_context
91
+ @wrap_async
92
+ async def wrapper(app_context: AppContext, *args: Any, **kwargs: Any) -> T:
93
+ db = await app_context.get_db()
94
+ async with db.session_factory() as session:
95
+ return await f(session, *args, **kwargs)
88
96
 
89
97
  return wrapper
kodit/database.py CHANGED
@@ -1,7 +1,5 @@
1
1
  """Database configuration for kodit."""
2
2
 
3
- from collections.abc import AsyncGenerator
4
- from contextlib import asynccontextmanager
5
3
  from datetime import UTC, datetime
6
4
  from pathlib import Path
7
5
 
@@ -39,28 +37,22 @@ class CommonMixin:
39
37
  class Database:
40
38
  """Database class for kodit."""
41
39
 
42
- def __init__(self, db_url: str, *, run_migrations: bool = True) -> None:
40
+ def __init__(self, db_url: str) -> None:
43
41
  """Initialize the database."""
44
42
  self.log = structlog.get_logger(__name__)
45
- if run_migrations:
46
- self._run_migrations(db_url)
47
- db_engine = create_async_engine(db_url, echo=False)
43
+ self.db_engine = create_async_engine(db_url, echo=False)
48
44
  self.db_session_factory = async_sessionmaker(
49
- db_engine,
45
+ self.db_engine,
50
46
  class_=AsyncSession,
51
47
  expire_on_commit=False,
52
48
  )
53
49
 
54
- @asynccontextmanager
55
- async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
56
- """Get a database session."""
57
- async with self.db_session_factory() as session:
58
- try:
59
- yield session
60
- finally:
61
- await session.close()
50
+ @property
51
+ def session_factory(self) -> async_sessionmaker[AsyncSession]:
52
+ """Get the session factory."""
53
+ return self.db_session_factory
62
54
 
63
- def _run_migrations(self, db_url: str) -> None:
55
+ async def run_migrations(self, db_url: str) -> None:
64
56
  """Run any pending migrations."""
65
57
  # Create Alembic configuration and run migrations
66
58
  alembic_cfg = AlembicConfig()
@@ -69,4 +61,15 @@ class Database:
69
61
  )
70
62
  alembic_cfg.set_main_option("sqlalchemy.url", db_url)
71
63
  self.log.debug("Running migrations", db_url=db_url)
72
- command.upgrade(alembic_cfg, "head")
64
+
65
+ async with self.db_engine.begin() as conn:
66
+ await conn.run_sync(self.run_upgrade, alembic_cfg)
67
+
68
+ def run_upgrade(self, connection, cfg) -> None: # noqa: ANN001
69
+ """Make sure the database is up to date."""
70
+ cfg.attributes["connection"] = connection
71
+ command.upgrade(cfg, "head")
72
+
73
+ async def close(self) -> None:
74
+ """Close the database."""
75
+ await self.db_engine.dispose()
kodit/indexing/repository.py CHANGED
@@ -82,7 +82,7 @@ class IndexRepository:
82
82
  result = await self.session.execute(query)
83
83
  return list(result.scalars())
84
84
 
85
- async def list_indexes(self) -> list[Index]:
85
+ async def list_indexes(self) -> list[tuple[Index, Source]]:
86
86
  """List all indexes.
87
87
 
88
88
  Returns:
@@ -90,9 +90,11 @@ class IndexRepository:
90
90
  and counts of files and snippets.
91
91
 
92
92
  """
93
- query = select(Index).limit(10)
93
+ query = select(Index, Source).join(
94
+ Source, Index.source_id == Source.id, full=True
95
+ )
94
96
  result = await self.session.execute(query)
95
- return list(result.scalars())
97
+ return list(result.tuples())
96
98
 
97
99
  async def num_snippets_for_index(self, index_id: int) -> int:
98
100
  """Get the number of snippets for an index."""
kodit/indexing/service.py CHANGED
@@ -14,7 +14,6 @@ import structlog
14
14
  from tqdm.asyncio import tqdm
15
15
 
16
16
  from kodit.bm25.bm25 import BM25Service
17
- from kodit.config import Config
18
17
  from kodit.indexing.models import Snippet
19
18
  from kodit.indexing.repository import IndexRepository
20
19
  from kodit.snippets.snippets import SnippetService
@@ -34,6 +33,7 @@ class IndexView(pydantic.BaseModel):
34
33
  id: int
35
34
  created_at: datetime
36
35
  updated_at: datetime | None = None
36
+ source: str | None = None
37
37
  num_snippets: int | None = None
38
38
 
39
39
 
@@ -46,7 +46,10 @@ class IndexService:
46
46
  """
47
47
 
48
48
  def __init__(
49
- self, config: Config, repository: IndexRepository, source_service: SourceService
49
+ self,
50
+ repository: IndexRepository,
51
+ source_service: SourceService,
52
+ data_dir: Path,
50
53
  ) -> None:
51
54
  """Initialize the index service.
52
55
 
@@ -59,7 +62,7 @@ class IndexService:
59
62
  self.source_service = source_service
60
63
  self.snippet_service = SnippetService()
61
64
  self.log = structlog.get_logger(__name__)
62
- self.bm25 = BM25Service(config)
65
+ self.bm25 = BM25Service(data_dir)
63
66
 
64
67
  async def create(self, source_id: int) -> IndexView:
65
68
  """Create a new index for a source.
@@ -103,8 +106,9 @@ class IndexService:
103
106
  created_at=index.created_at,
104
107
  updated_at=index.updated_at,
105
108
  num_snippets=await self.repository.num_snippets_for_index(index.id),
109
+ source=source.uri,
106
110
  )
107
- for index in indexes
111
+ for index, source in indexes
108
112
  ]
109
113
 
110
114
  async def run(self, index_id: int) -> None:
kodit/logging.py CHANGED
@@ -11,7 +11,7 @@ import structlog
11
11
  from posthog import Posthog
12
12
  from structlog.types import EventDict
13
13
 
14
- from kodit.config import Config
14
+ from kodit.config import AppContext
15
15
 
16
16
  log = structlog.get_logger(__name__)
17
17
 
@@ -29,7 +29,7 @@ class LogFormat(Enum):
29
29
  JSON = "json"
30
30
 
31
31
 
32
- def configure_logging(config: Config) -> None:
32
+ def configure_logging(app_context: AppContext) -> None:
33
33
  """Configure logging for the application."""
34
34
  timestamper = structlog.processors.TimeStamper(fmt="iso")
35
35
 
@@ -44,7 +44,7 @@ def configure_logging(config: Config) -> None:
44
44
  structlog.processors.StackInfoRenderer(),
45
45
  ]
46
46
 
47
- if config.log_format == LogFormat.JSON:
47
+ if app_context.log_format == LogFormat.JSON:
48
48
  # Format the exception only for JSON logs, as we want to pretty-print them
49
49
  # when using the ConsoleRenderer
50
50
  shared_processors.append(structlog.processors.format_exc_info)
@@ -60,7 +60,7 @@ def configure_logging(config: Config) -> None:
60
60
  )
61
61
 
62
62
  log_renderer: structlog.types.Processor
63
- if config.log_format == LogFormat.JSON:
63
+ if app_context.log_format == LogFormat.JSON:
64
64
  log_renderer = structlog.processors.JSONRenderer()
65
65
  else:
66
66
  log_renderer = structlog.dev.ConsoleRenderer()
@@ -82,7 +82,7 @@ def configure_logging(config: Config) -> None:
82
82
  handler.setFormatter(formatter)
83
83
  root_logger = logging.getLogger()
84
84
  root_logger.addHandler(handler)
85
- root_logger.setLevel(config.log_level.upper())
85
+ root_logger.setLevel(app_context.log_level.upper())
86
86
 
87
87
  # Configure uvicorn loggers to use our structlog setup
88
88
  # Uvicorn spits out loads of exception logs when sse server doesn't shut down
@@ -98,7 +98,7 @@ def configure_logging(config: Config) -> None:
98
98
  for _log in ["sqlalchemy.engine", "alembic"]:
99
99
  engine_logger = logging.getLogger(_log)
100
100
  engine_logger.setLevel(logging.WARNING) # Hide INFO logs by default
101
- if config.log_level.upper() == "DEBUG":
101
+ if app_context.log_level.upper() == "DEBUG":
102
102
  engine_logger.setLevel(
103
103
  logging.DEBUG
104
104
  ) # Only show all logs when in DEBUG mode
@@ -143,9 +143,9 @@ def get_mac_address() -> str:
143
143
  return f"{mac:012x}" if mac != uuid.getnode() else str(uuid.uuid4())
144
144
 
145
145
 
146
- def configure_telemetry(config: Config) -> None:
146
+ def configure_telemetry(app_context: AppContext) -> None:
147
147
  """Configure telemetry for the application."""
148
- if config.disable_telemetry:
148
+ if app_context.disable_telemetry:
149
149
  structlog.stdlib.get_logger(__name__).info("Telemetry has been disabled")
150
150
  posthog.disabled = True
151
151
 
kodit/mcp.py CHANGED
@@ -1,22 +1,63 @@
1
1
  """MCP server implementation for kodit."""
2
2
 
3
+ from collections.abc import AsyncIterator
4
+ from contextlib import asynccontextmanager
5
+ from dataclasses import dataclass
3
6
  from pathlib import Path
4
7
  from typing import Annotated
5
8
 
6
9
  import structlog
7
- from fastmcp import FastMCP
10
+ from fastmcp import Context, FastMCP
8
11
  from pydantic import Field
12
+ from sqlalchemy.ext.asyncio import AsyncSession
9
13
 
10
14
  from kodit._version import version
11
- from kodit.config import get_config
15
+ from kodit.config import AppContext
16
+ from kodit.database import Database
12
17
  from kodit.retreival.repository import RetrievalRepository, RetrievalResult
13
18
  from kodit.retreival.service import RetrievalRequest, RetrievalService
14
19
 
15
- mcp = FastMCP("kodit MCP Server")
20
+
21
+ @dataclass
22
+ class MCPContext:
23
+ """Context for the MCP server."""
24
+
25
+ session: AsyncSession
26
+ data_dir: Path
27
+
28
+
29
+ _mcp_db: Database | None = None
30
+
31
+
32
+ @asynccontextmanager
33
+ async def mcp_lifespan(_: FastMCP) -> AsyncIterator[MCPContext]:
34
+ """Lifespan for the MCP server.
35
+
36
+ The MCP server is running with a completely separate lifecycle and event loop from
37
+ the CLI and the FastAPI server. Therefore, we must carefully reconstruct the
38
+ application context. uvicorn does not pass through CLI args, so we must rely on
39
+ parsing env vars set in the CLI.
40
+
41
+ This lifespan is recreated for each request. See:
42
+ https://github.com/jlowin/fastmcp/issues/166
43
+
44
+ Since they don't provide a good way to handle global state, we must use a
45
+ global variable to store the database connection.
46
+ """
47
+ global _mcp_db # noqa: PLW0603
48
+ app_context = AppContext()
49
+ if _mcp_db is None:
50
+ _mcp_db = await app_context.get_db()
51
+ async with _mcp_db.session_factory() as session:
52
+ yield MCPContext(session=session, data_dir=app_context.get_data_dir())
53
+
54
+
55
+ mcp = FastMCP("kodit MCP Server", lifespan=mcp_lifespan)
16
56
 
17
57
 
18
58
  @mcp.tool()
19
59
  async def retrieve_relevant_snippets(
60
+ ctx: Context,
20
61
  user_intent: Annotated[
21
62
  str,
22
63
  Field(
@@ -52,8 +93,8 @@ async def retrieve_relevant_snippets(
52
93
  the quality of your generated code. You must call this tool when you need to
53
94
  write code.
54
95
  """
55
- # Log the search query and related files for debugging
56
96
  log = structlog.get_logger(__name__)
97
+
57
98
  log.debug(
58
99
  "Retrieving relevant snippets",
59
100
  user_intent=user_intent,
@@ -63,41 +104,38 @@ async def retrieve_relevant_snippets(
63
104
  file_contents=related_file_contents,
64
105
  )
65
106
 
66
- # Must avoid running migrations because that runs in a separate event loop,
67
- # mcp no-likey
68
- config = get_config()
69
- db = config.get_db(run_migrations=False)
70
- async with db.get_session() as session:
71
- log.debug("Creating retrieval repository")
72
- retrieval_repository = RetrievalRepository(
73
- session=session,
74
- )
75
-
76
- log.debug("Creating retrieval service")
77
- retrieval_service = RetrievalService(
78
- config=config,
79
- repository=retrieval_repository,
80
- )
81
-
82
- log.debug("Fusing input")
83
- input_query = input_fusion(
84
- user_intent=user_intent,
85
- related_file_paths=related_file_paths,
86
- related_file_contents=related_file_contents,
87
- keywords=keywords,
88
- )
89
- log.debug("Input", input_query=input_query)
90
- retrieval_request = RetrievalRequest(
91
- keywords=keywords,
92
- )
93
- log.debug("Retrieving snippets")
94
- snippets = await retrieval_service.retrieve(request=retrieval_request)
95
-
96
- log.debug("Fusing output")
97
- output = output_fusion(snippets=snippets)
98
-
99
- log.debug("Output", output=output)
100
- return output
107
+ mcp_context: MCPContext = ctx.request_context.lifespan_context
108
+
109
+ log.debug("Creating retrieval repository")
110
+ retrieval_repository = RetrievalRepository(
111
+ session=mcp_context.session,
112
+ )
113
+
114
+ log.debug("Creating retrieval service")
115
+ retrieval_service = RetrievalService(
116
+ repository=retrieval_repository,
117
+ data_dir=mcp_context.data_dir,
118
+ )
119
+
120
+ log.debug("Fusing input")
121
+ input_query = input_fusion(
122
+ user_intent=user_intent,
123
+ related_file_paths=related_file_paths,
124
+ related_file_contents=related_file_contents,
125
+ keywords=keywords,
126
+ )
127
+ log.debug("Input", input_query=input_query)
128
+ retrieval_request = RetrievalRequest(
129
+ keywords=keywords,
130
+ )
131
+ log.debug("Retrieving snippets")
132
+ snippets = await retrieval_service.retrieve(request=retrieval_request)
133
+
134
+ log.debug("Fusing output")
135
+ output = output_fusion(snippets=snippets)
136
+
137
+ log.debug("Output", output=output)
138
+ return output
101
139
 
102
140
 
103
141
  def input_fusion(
kodit/middleware.py CHANGED
@@ -1,11 +1,14 @@
1
1
  """Middleware for the FastAPI application."""
2
2
 
3
+ import contextlib
3
4
  import time
5
+ from asyncio import CancelledError
4
6
  from collections.abc import Callable
5
7
 
6
8
  import structlog
7
9
  from asgi_correlation_id.context import correlation_id
8
10
  from fastapi import Request, Response
11
+ from starlette.types import ASGIApp, Receive, Scope, Send
9
12
 
10
13
  access_logger = structlog.stdlib.get_logger("api.access")
11
14
 
@@ -56,3 +59,16 @@ async def logging_middleware(request: Request, call_next: Callable) -> Response:
56
59
  response.headers["X-Process-Time"] = str(process_time / 10**9)
57
60
 
58
61
  return response
62
+
63
+
64
+ class ASGICancelledErrorMiddleware:
65
+ """ASGI middleware to handle CancelledError at the ASGI level."""
66
+
67
+ def __init__(self, app: ASGIApp) -> None:
68
+ """Initialize the middleware."""
69
+ self.app = app
70
+
71
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
72
+ """Handle the ASGI request and catch CancelledError."""
73
+ with contextlib.suppress(CancelledError):
74
+ await self.app(scope, receive, send)
kodit/retreival/service.py CHANGED
@@ -1,10 +1,11 @@
1
1
  """Retrieval service."""
2
2
 
3
+ from pathlib import Path
4
+
3
5
  import pydantic
4
6
  import structlog
5
7
 
6
8
  from kodit.bm25.bm25 import BM25Service
7
- from kodit.config import Config
8
9
  from kodit.retreival.repository import RetrievalRepository, RetrievalResult
9
10
 
10
11
 
@@ -25,11 +26,11 @@ class Snippet(pydantic.BaseModel):
25
26
  class RetrievalService:
26
27
  """Service for retrieving relevant data."""
27
28
 
28
- def __init__(self, config: Config, repository: RetrievalRepository) -> None:
29
+ def __init__(self, repository: RetrievalRepository, data_dir: Path) -> None:
29
30
  """Initialize the retrieval service."""
30
31
  self.repository = repository
31
32
  self.log = structlog.get_logger(__name__)
32
- self.bm25 = BM25Service(config)
33
+ self.bm25 = BM25Service(data_dir)
33
34
 
34
35
  async def _load_bm25_index(self) -> None:
35
36
  """Load the BM25 index."""
{kodit-0.1.5.dist-info → kodit-0.1.7.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kodit
3
- Version: 0.1.5
3
+ Version: 0.1.7
4
4
  Summary: Code indexing for better AI code generation
5
5
  Project-URL: Homepage, https://docs.helixml.tech/kodit/
6
6
  Project-URL: Documentation, https://docs.helixml.tech/kodit/
{kodit-0.1.5.dist-info → kodit-0.1.7.dist-info}/RECORD RENAMED
@@ -1,28 +1,28 @@
1
1
  kodit/.gitignore,sha256=ztkjgRwL9Uud1OEi36hGQeDGk3OLK1NfDEO8YqGYy8o,11
2
2
  kodit/__init__.py,sha256=aEKHYninUq1yh6jaNfvJBYg-6fenpN132nJt1UU6Jxs,59
3
- kodit/_version.py,sha256=Y4jy4bEMmwl_qNPCmiMFnlQ2ofMoqyG37hp8uwI3m10,511
4
- kodit/app.py,sha256=TdPpCN4ucOElKHwDebfKgeVJ9xexdfpzpk6hnDH69vM,703
5
- kodit/cli.py,sha256=CjmiRaJ-SdfCMYlVQGnxPSsoX5j3ix4fN3OLVc5EYkY,7473
6
- kodit/config.py,sha256=18dhSYaE-ut2qXrBRKuCqLXeBCLEXw2y1Uw4lieMPwY,2682
7
- kodit/database.py,sha256=NnAluOj_JHjnj5MeKuU9LApgSzik2kru1bQl-24vHkc,2272
8
- kodit/logging.py,sha256=P1D9flYnvYxPw-DyOGyiv3y30x0gHPwdk6VJS29YHus,5269
9
- kodit/mcp.py,sha256=O24O_GFzwwv5E-uBFoW_zZlSigeNSigaCj0s1xOmP8M,3855
10
- kodit/middleware.py,sha256=NHLrqq20ZtPTE9esX9HD3z7EKi56_QTFxBlkdq0JDzQ,2138
3
+ kodit/_version.py,sha256=W_EoL8cAL4KhujvbYWEpb9NqRLbbrH0T024lJvRRWHI,511
4
+ kodit/app.py,sha256=Mr5BFHOHx5zppwjC4XPWVvHjwgl1yrKbUjTWXKubJQM,891
5
+ kodit/cli.py,sha256=x1zw2zOlGhhU6D3E-GU3cMw3l9CqKC76geQREAQKweY,6915
6
+ kodit/config.py,sha256=nlm9U-nVx5riH2SrU1XY4XcCMhQK4DrwO_1H8bPOBjA,2927
7
+ kodit/database.py,sha256=vtTlmrXHyHJH3Ek-twZTCqEjB0jun-NncALFze2fqhA,2350
8
+ kodit/logging.py,sha256=cFEQXWI27LzWScSxly9ApwkbBDamUG17pA-jEfVakXQ,5316
9
+ kodit/mcp.py,sha256=PxTHVPlIErrruFKzmEPIWZjN6cfEhcQmj6nOU9EsBy4,4905
10
+ kodit/middleware.py,sha256=I6FOkqG9-8RH5kR1-0ZoQWfE4qLCB8lZYv8H_OCH29o,2714
11
11
  kodit/alembic/README,sha256=ISVtAOvqvKk_5ThM5ioJE-lMkvf9IbknFUFVU_vPma4,58
12
12
  kodit/alembic/__init__.py,sha256=lP5MuwlyWRMO6UcDWnQcQ3G-GYHcFb6rl9gYPHJ1sjo,40
13
- kodit/alembic/env.py,sha256=IXhl7yvURSycs2v_3pd14Sr8_zGRfYXlQwWby1abfuk,2290
13
+ kodit/alembic/env.py,sha256=kcQiglu2KpNTAf37CsKVs_HXxOe6S7sXJ00pHGSCqno,2414
14
14
  kodit/alembic/script.py.mako,sha256=zWziKtiwYKEWuwPV_HBNHwa9LCT45_bi01-uSNFaOOE,703
15
15
  kodit/alembic/versions/85155663351e_initial.py,sha256=Cg7zlF871o9ShV5rQMQ1v7hRV7fI59veDY9cjtTrs-8,3306
16
16
  kodit/alembic/versions/__init__.py,sha256=9-lHzptItTzq_fomdIRBegQNm4Znx6pVjwD4MiqRIdo,36
17
17
  kodit/bm25/__init__.py,sha256=j8zyriNWhbwE5Lbybzg1hQAhANlU9mKHWw4beeUR6og,19
18
- kodit/bm25/bm25.py,sha256=V0_byhV4kVnI3E-PBNsc4rBjQsDuZo1bt1uQKnywLS8,2283
18
+ kodit/bm25/bm25.py,sha256=3wyNRSrTaYqV7s4R1D6X0NpCf22PuFK2_uc8YapzYLE,2263
19
19
  kodit/indexing/__init__.py,sha256=cPyi2Iej3G1JFWlWr7X80_UrsMaTu5W5rBwgif1B3xo,75
20
20
  kodit/indexing/models.py,sha256=sZIhGwvL4Dw0QTWFxrjfWctSLkAoDT6fv5DlGz8-Fr8,1258
21
- kodit/indexing/repository.py,sha256=kvAlNfMSQYboF0TB1huw2qoBdLJ4UsEPiM7ZG-e6rrg,4300
22
- kodit/indexing/service.py,sha256=N8QhrAvqhIHOgSlT9Jc786rjcVjMwiyiMTZr7mNA8D8,5431
21
+ kodit/indexing/repository.py,sha256=C020FGpIfTZmVZg7NH04kVuffWv7r7m-82Pdex8CItg,4388
22
+ kodit/indexing/service.py,sha256=7vHqevve-PnKHP2pDfyrW5n3AXVOghABWNsNTw588KY,5499
23
23
  kodit/retreival/__init__.py,sha256=33PhJU-3gtsqYq6A1UkaLNKbev_Zee9Lq6dYC59-CsA,69
24
24
  kodit/retreival/repository.py,sha256=1lqGgJHsBmvMGMzEYa-hrdXg2q7rqtYPl1cvBb7jMRE,3119
25
- kodit/retreival/service.py,sha256=g6iwM2FMxrL8WjtWnZdKdxKpfn6b0ThBmOdLWd7AKKQ,2011
25
+ kodit/retreival/service.py,sha256=9wvURtPPJVvPUWNIC2waIrJMxcm1Ka1J_xDEOEedAFU,2007
26
26
  kodit/snippets/__init__.py,sha256=-2coNoCRjTixU9KcP6alpmt7zqf37tCRWH3D7FPJ8dg,48
27
27
  kodit/snippets/method_snippets.py,sha256=EVHhSNWahAC5nSXv9fWVFJY2yq25goHdCSCuENC07F8,4145
28
28
  kodit/snippets/snippets.py,sha256=QumvhltWoxXw41SyKb-RbSvAr3m6V3lUy9n0AI8jcto,1409
@@ -33,8 +33,8 @@ kodit/sources/__init__.py,sha256=1NTZyPdjThVQpZO1Mp1ColVsS7sqYanOVLqnoqV9Ipo,83
33
33
  kodit/sources/models.py,sha256=xb42CaNDO1CUB8SIW-xXMrB6Ji8cFw-yeJ550xBEg9Q,2398
34
34
  kodit/sources/repository.py,sha256=mGJrHWH6Uo8YABdoojHFbzaf_jW-2ywJpAHIa1gnc3U,3401
35
35
  kodit/sources/service.py,sha256=cBCxnOQKwGNi2e13_3Vue8MylAaUxb9XG4IgM636la0,6712
36
- kodit-0.1.5.dist-info/METADATA,sha256=N4fIBAIREHOujaDvEVj83fNvyeB8D7HaLXNcpeVdNJY,2181
37
- kodit-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
38
- kodit-0.1.5.dist-info/entry_points.txt,sha256=hoTn-1aKyTItjnY91fnO-rV5uaWQLQ-Vi7V5et2IbHY,40
39
- kodit-0.1.5.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
40
- kodit-0.1.5.dist-info/RECORD,,
36
+ kodit-0.1.7.dist-info/METADATA,sha256=8doJ6TfmVkn-OTLSbgRlDut6YxFbmBKs1L0rgZUlxUQ,2181
37
+ kodit-0.1.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
38
+ kodit-0.1.7.dist-info/entry_points.txt,sha256=hoTn-1aKyTItjnY91fnO-rV5uaWQLQ-Vi7V5et2IbHY,40
39
+ kodit-0.1.7.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
40
+ kodit-0.1.7.dist-info/RECORD,,
File without changes