kodit 0.1.5__tar.gz → 0.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic. Click here for more details.

Files changed (78) hide show
  1. {kodit-0.1.5 → kodit-0.1.6}/.github/workflows/test.yaml +3 -0
  2. {kodit-0.1.5 → kodit-0.1.6}/PKG-INFO +1 -1
  3. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/_version.py +2 -2
  4. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/alembic/env.py +5 -2
  5. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/app.py +5 -1
  6. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/bm25/bm25.py +4 -4
  7. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/cli.py +56 -24
  8. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/config.py +34 -26
  9. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/database.py +20 -17
  10. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/indexing/service.py +5 -3
  11. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/logging.py +8 -8
  12. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/mcp.py +77 -39
  13. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/middleware.py +16 -0
  14. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/retreival/service.py +4 -3
  15. {kodit-0.1.5 → kodit-0.1.6}/tests/conftest.py +18 -0
  16. kodit-0.1.6/tests/kodit/cli_test.py +75 -0
  17. {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/indexing/test_service.py +5 -3
  18. {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/mcp_test.py +14 -0
  19. {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/retreival/test_service.py +5 -3
  20. {kodit-0.1.5 → kodit-0.1.6}/tests/smoke.sh +6 -2
  21. kodit-0.1.5/tests/kodit/cli_test.py +0 -51
  22. {kodit-0.1.5 → kodit-0.1.6}/.cursor/rules/kodit.mdc +0 -0
  23. {kodit-0.1.5 → kodit-0.1.6}/.github/CODE_OF_CONDUCT.md +0 -0
  24. {kodit-0.1.5 → kodit-0.1.6}/.github/CONTRIBUTING.md +0 -0
  25. {kodit-0.1.5 → kodit-0.1.6}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
  26. {kodit-0.1.5 → kodit-0.1.6}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
  27. {kodit-0.1.5 → kodit-0.1.6}/.github/PULL_REQUEST_TEMPLATE.md +0 -0
  28. {kodit-0.1.5 → kodit-0.1.6}/.github/workflows/docker.yaml +0 -0
  29. {kodit-0.1.5 → kodit-0.1.6}/.github/workflows/docs.yaml +0 -0
  30. {kodit-0.1.5 → kodit-0.1.6}/.github/workflows/pypi-test.yaml +0 -0
  31. {kodit-0.1.5 → kodit-0.1.6}/.github/workflows/pypi.yaml +0 -0
  32. {kodit-0.1.5 → kodit-0.1.6}/.gitignore +0 -0
  33. {kodit-0.1.5 → kodit-0.1.6}/.python-version +0 -0
  34. {kodit-0.1.5 → kodit-0.1.6}/.vscode/launch.json +0 -0
  35. {kodit-0.1.5 → kodit-0.1.6}/.vscode/settings.json +0 -0
  36. {kodit-0.1.5 → kodit-0.1.6}/Dockerfile +0 -0
  37. {kodit-0.1.5 → kodit-0.1.6}/LICENSE +0 -0
  38. {kodit-0.1.5 → kodit-0.1.6}/README.md +0 -0
  39. {kodit-0.1.5 → kodit-0.1.6}/alembic.ini +0 -0
  40. {kodit-0.1.5 → kodit-0.1.6}/docs/_index.md +0 -0
  41. {kodit-0.1.5 → kodit-0.1.6}/docs/developer/index.md +0 -0
  42. {kodit-0.1.5 → kodit-0.1.6}/pyproject.toml +0 -0
  43. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/.gitignore +0 -0
  44. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/__init__.py +0 -0
  45. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/alembic/README +0 -0
  46. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/alembic/__init__.py +0 -0
  47. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/alembic/script.py.mako +0 -0
  48. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/alembic/versions/85155663351e_initial.py +0 -0
  49. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/alembic/versions/__init__.py +0 -0
  50. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/bm25/__init__.py +0 -0
  51. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/indexing/__init__.py +0 -0
  52. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/indexing/models.py +0 -0
  53. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/indexing/repository.py +0 -0
  54. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/retreival/__init__.py +0 -0
  55. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/retreival/repository.py +0 -0
  56. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/snippets/__init__.py +0 -0
  57. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/snippets/languages/__init__.py +0 -0
  58. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/snippets/languages/csharp.scm +0 -0
  59. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/snippets/languages/python.scm +0 -0
  60. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/snippets/method_snippets.py +0 -0
  61. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/snippets/snippets.py +0 -0
  62. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/sources/__init__.py +0 -0
  63. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/sources/models.py +0 -0
  64. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/sources/repository.py +0 -0
  65. {kodit-0.1.5 → kodit-0.1.6}/src/kodit/sources/service.py +0 -0
  66. {kodit-0.1.5 → kodit-0.1.6}/tests/__init__.py +0 -0
  67. {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/__init__.py +0 -0
  68. {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/e2e.py +0 -0
  69. {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/indexing/__init__.py +0 -0
  70. {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/retreival/__init__.py +0 -0
  71. {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/snippets/__init__.py +0 -0
  72. {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/snippets/csharp.cs +0 -0
  73. {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/snippets/detect_language_test.py +0 -0
  74. {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/snippets/method_extraction_test.py +0 -0
  75. {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/snippets/python.py +0 -0
  76. {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/sources/__init__.py +0 -0
  77. {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/sources/test_service.py +0 -0
  78. {kodit-0.1.5 → kodit-0.1.6}/uv.lock +0 -0
@@ -100,5 +100,8 @@ jobs:
100
100
  - name: Run simple version command test
101
101
  run: kodit version
102
102
 
103
+ - name: Delete kodit data_dir
104
+ run: rm -rf ${HOME}/.kodit
105
+
103
106
  - name: Run smoke test
104
107
  run: ./tests/smoke.sh
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kodit
3
- Version: 0.1.5
3
+ Version: 0.1.6
4
4
  Summary: Code indexing for better AI code generation
5
5
  Project-URL: Homepage, https://docs.helixml.tech/kodit/
6
6
  Project-URL: Documentation, https://docs.helixml.tech/kodit/
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.1.5'
21
- __version_tuple__ = version_tuple = (0, 1, 5)
20
+ __version__ = version = '0.1.6'
21
+ __version_tuple__ = version_tuple = (0, 1, 6)
@@ -3,7 +3,6 @@
3
3
 
4
4
  import asyncio
5
5
 
6
- import structlog
7
6
  from alembic import context
8
7
  from sqlalchemy import pool
9
8
  from sqlalchemy.engine import Connection
@@ -75,7 +74,11 @@ async def run_async_migrations() -> None:
75
74
 
76
75
  def run_migrations_online() -> None:
77
76
  """Run migrations in 'online' mode."""
78
- asyncio.run(run_async_migrations())
77
+ connectable = config.attributes.get("connection", None)
78
+ if connectable is None:
79
+ asyncio.run(run_async_migrations())
80
+ else:
81
+ do_run_migrations(connectable)
79
82
 
80
83
 
81
84
  if context.is_offline_mode():
@@ -4,7 +4,7 @@ from asgi_correlation_id import CorrelationIdMiddleware
4
4
  from fastapi import FastAPI
5
5
 
6
6
  from kodit.mcp import mcp
7
- from kodit.middleware import logging_middleware
7
+ from kodit.middleware import ASGICancelledErrorMiddleware, logging_middleware
8
8
 
9
9
  # See https://gofastmcp.com/deployment/asgi#fastapi-integration
10
10
  mcp_app = mcp.sse_app()
@@ -23,3 +23,7 @@ async def root() -> dict[str, str]:
23
23
 
24
24
  # Add mcp routes last, otherwise previous routes aren't added
25
25
  app.mount("", mcp_app)
26
+
27
+ # Wrap the entire app with ASGI middleware after all routes are added to suppress
28
+ # CancelledError at the ASGI level
29
+ app = ASGICancelledErrorMiddleware(app)
@@ -1,20 +1,20 @@
1
1
  """BM25 service."""
2
2
 
3
+ from pathlib import Path
4
+
3
5
  import bm25s
4
6
  import Stemmer
5
7
  import structlog
6
8
  from bm25s.tokenization import Tokenized
7
9
 
8
- from kodit.config import Config
9
-
10
10
 
11
11
  class BM25Service:
12
12
  """Service for BM25."""
13
13
 
14
- def __init__(self, config: Config) -> None:
14
+ def __init__(self, data_dir: Path) -> None:
15
15
  """Initialize the BM25 service."""
16
16
  self.log = structlog.get_logger(__name__)
17
- self.index_path = config.get_data_dir() / "bm25s_index"
17
+ self.index_path = data_dir / "bm25s_index"
18
18
  try:
19
19
  self.log.debug("Loading BM25 index")
20
20
  self.retriever = bm25s.BM25.load(self.index_path, mmap=True)
@@ -17,8 +17,8 @@ from kodit.config import (
17
17
  DEFAULT_DISABLE_TELEMETRY,
18
18
  DEFAULT_LOG_FORMAT,
19
19
  DEFAULT_LOG_LEVEL,
20
- get_config,
21
- reset_config,
20
+ AppContext,
21
+ with_app_context,
22
22
  with_session,
23
23
  )
24
24
  from kodit.indexing.repository import IndexRepository
@@ -40,23 +40,33 @@ from kodit.sources.service import SourceService
40
40
  )
41
41
  @click.option("--db-url", help=f"Database URL [default: {DEFAULT_DB_URL}]")
42
42
  @click.option("--data-dir", help=f"Data directory [default: {DEFAULT_BASE_DIR}]")
43
- @click.option("--env-file", help="Path to a .env file [default: .env]")
43
+ @click.option(
44
+ "--env-file",
45
+ help="Path to a .env file [default: .env]",
46
+ type=click.Path(
47
+ exists=True,
48
+ dir_okay=False,
49
+ resolve_path=True,
50
+ path_type=Path,
51
+ ),
52
+ )
53
+ @click.pass_context
44
54
  def cli( # noqa: PLR0913
55
+ ctx: click.Context,
45
56
  log_level: str | None,
46
57
  log_format: str | None,
47
58
  disable_telemetry: bool | None,
48
59
  db_url: str | None,
49
60
  data_dir: str | None,
50
- env_file: str | None,
61
+ env_file: Path | None,
51
62
  ) -> None:
52
63
  """kodit CLI - Code indexing for better AI code generation.""" # noqa: D403
64
+ config = AppContext()
53
65
  # First check if env-file is set and reload config if it is
54
66
  if env_file:
55
- reset_config()
56
- get_config(env_file)
67
+ config = AppContext(_env_file=env_file) # type: ignore[reportCallIssue]
57
68
 
58
- # Override global config with cli args, if set
59
- config = get_config()
69
+ # Now override with CLI arguments, if set
60
70
  if data_dir:
61
71
  config.data_dir = Path(data_dir)
62
72
  if db_url:
@@ -70,6 +80,9 @@ def cli( # noqa: PLR0913
70
80
  configure_logging(config)
71
81
  configure_telemetry(config)
72
82
 
83
+ # Set the app context in the click context for downstream cli
84
+ ctx.obj = config
85
+
73
86
 
74
87
  @cli.group()
75
88
  def sources() -> None:
@@ -77,11 +90,12 @@ def sources() -> None:
77
90
 
78
91
 
79
92
  @sources.command(name="list")
93
+ @with_app_context
80
94
  @with_session
81
- async def list_sources(session: AsyncSession) -> None:
95
+ async def list_sources(session: AsyncSession, app_context: AppContext) -> None:
82
96
  """List all code sources."""
83
97
  repository = SourceRepository(session)
84
- service = SourceService(get_config().get_clone_dir(), repository)
98
+ service = SourceService(app_context.get_clone_dir(), repository)
85
99
  sources = await service.list_sources()
86
100
 
87
101
  # Define headers and data
@@ -95,11 +109,14 @@ async def list_sources(session: AsyncSession) -> None:
95
109
 
96
110
  @sources.command(name="create")
97
111
  @click.argument("uri")
112
+ @with_app_context
98
113
  @with_session
99
- async def create_source(session: AsyncSession, uri: str) -> None:
114
+ async def create_source(
115
+ session: AsyncSession, app_context: AppContext, uri: str
116
+ ) -> None:
100
117
  """Add a new code source."""
101
118
  repository = SourceRepository(session)
102
- service = SourceService(get_config().get_clone_dir(), repository)
119
+ service = SourceService(app_context.get_clone_dir(), repository)
103
120
  source = await service.create(uri)
104
121
  click.echo(f"Source created: {source.id}")
105
122
 
@@ -111,25 +128,29 @@ def indexes() -> None:
111
128
 
112
129
  @indexes.command(name="create")
113
130
  @click.argument("source_id")
131
+ @with_app_context
114
132
  @with_session
115
- async def create_index(session: AsyncSession, source_id: int) -> None:
133
+ async def create_index(
134
+ session: AsyncSession, app_context: AppContext, source_id: int
135
+ ) -> None:
116
136
  """Create an index for a source."""
117
137
  source_repository = SourceRepository(session)
118
- source_service = SourceService(get_config().get_clone_dir(), source_repository)
138
+ source_service = SourceService(app_context.get_clone_dir(), source_repository)
119
139
  repository = IndexRepository(session)
120
- service = IndexService(get_config(), repository, source_service)
140
+ service = IndexService(repository, source_service, app_context.get_data_dir())
121
141
  index = await service.create(source_id)
122
142
  click.echo(f"Index created: {index.id}")
123
143
 
124
144
 
125
145
  @indexes.command(name="list")
146
+ @with_app_context
126
147
  @with_session
127
- async def list_indexes(session: AsyncSession) -> None:
148
+ async def list_indexes(session: AsyncSession, app_context: AppContext) -> None:
128
149
  """List all indexes."""
129
150
  source_repository = SourceRepository(session)
130
- source_service = SourceService(get_config().get_clone_dir(), source_repository)
151
+ source_service = SourceService(app_context.get_clone_dir(), source_repository)
131
152
  repository = IndexRepository(session)
132
- service = IndexService(get_config(), repository, source_service)
153
+ service = IndexService(repository, source_service, app_context.get_data_dir())
133
154
  indexes = await service.list_indexes()
134
155
 
135
156
  # Define headers and data
@@ -156,24 +177,30 @@ async def list_indexes(session: AsyncSession) -> None:
156
177
 
157
178
  @indexes.command(name="run")
158
179
  @click.argument("index_id")
180
+ @with_app_context
159
181
  @with_session
160
- async def run_index(session: AsyncSession, index_id: int) -> None:
182
+ async def run_index(
183
+ session: AsyncSession, app_context: AppContext, index_id: int
184
+ ) -> None:
161
185
  """Run an index."""
162
186
  source_repository = SourceRepository(session)
163
- source_service = SourceService(get_config().get_clone_dir(), source_repository)
187
+ source_service = SourceService(app_context.get_clone_dir(), source_repository)
164
188
  repository = IndexRepository(session)
165
- service = IndexService(get_config(), repository, source_service)
189
+ service = IndexService(repository, source_service, app_context.get_data_dir())
166
190
  await service.run(index_id)
167
191
 
168
192
 
169
193
  @cli.command()
170
194
  @click.argument("query")
171
195
  @click.option("--top-k", default=10, help="Number of snippets to retrieve")
196
+ @with_app_context
172
197
  @with_session
173
- async def retrieve(session: AsyncSession, query: str, top_k: int) -> None:
198
+ async def retrieve(
199
+ session: AsyncSession, app_context: AppContext, query: str, top_k: int
200
+ ) -> None:
174
201
  """Retrieve snippets from the database."""
175
202
  repository = RetrievalRepository(session)
176
- service = RetrievalService(get_config(), repository)
203
+ service = RetrievalService(repository, app_context.get_data_dir())
177
204
  # Temporary request while we don't have all search capabilities
178
205
  snippets = await service.retrieve(
179
206
  RetrievalRequest(keywords=query.split(","), top_k=top_k)
@@ -194,7 +221,9 @@ async def retrieve(session: AsyncSession, query: str, top_k: int) -> None:
194
221
  @cli.command()
195
222
  @click.option("--host", default="127.0.0.1", help="Host to bind the server to")
196
223
  @click.option("--port", default=8080, help="Port to bind the server to")
224
+ @with_app_context
197
225
  def serve(
226
+ app_context: AppContext,
198
227
  host: str,
199
228
  port: int,
200
229
  ) -> None:
@@ -202,7 +231,10 @@ def serve(
202
231
  log = structlog.get_logger(__name__)
203
232
  log.info("Starting kodit server", host=host, port=port)
204
233
  log_event("kodit_server_started")
205
- os.environ["HELLO"] = "WORLD"
234
+
235
+ # Dump AppContext to a dictionary of strings, and set the env vars
236
+ app_context_dict = {k: str(v) for k, v in app_context.model_dump().items()}
237
+ os.environ.update(app_context_dict)
206
238
 
207
239
  # Configure uvicorn with graceful shutdown
208
240
  config = uvicorn.Config(
@@ -1,11 +1,12 @@
1
1
  """Global configuration for the kodit project."""
2
2
 
3
3
  import asyncio
4
- from collections.abc import Callable
4
+ from collections.abc import Callable, Coroutine
5
5
  from functools import wraps
6
6
  from pathlib import Path
7
7
  from typing import Any, TypeVar
8
8
 
9
+ import click
9
10
  from pydantic import Field
10
11
  from pydantic_settings import BaseSettings, SettingsConfigDict
11
12
 
@@ -19,8 +20,8 @@ DEFAULT_DISABLE_TELEMETRY = False
19
20
  T = TypeVar("T")
20
21
 
21
22
 
22
- class Config(BaseSettings):
23
- """Global configuration for the kodit project."""
23
+ class AppContext(BaseSettings):
24
+ """Global context for the kodit project. Provides a shared state for the app."""
24
25
 
25
26
  model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
26
27
 
@@ -47,43 +48,50 @@ class Config(BaseSettings):
47
48
  clone_dir.mkdir(parents=True, exist_ok=True)
48
49
  return clone_dir
49
50
 
50
- def get_db(self, *, run_migrations: bool = True) -> Database:
51
+ async def get_db(self, *, run_migrations: bool = True) -> Database:
51
52
  """Get the database."""
52
53
  if self._db is None:
53
- self._db = Database(self.db_url, run_migrations=run_migrations)
54
+ self._db = Database(self.db_url)
55
+ if run_migrations:
56
+ await self._db.run_migrations(self.db_url)
54
57
  return self._db
55
58
 
56
59
 
57
- # Global config instance for mcp Apps
58
- config = None
60
+ with_app_context = click.make_pass_decorator(AppContext)
59
61
 
62
+ T = TypeVar("T")
60
63
 
61
- def get_config(env_file: str | None = None) -> Config:
62
- """Get the global config instance."""
63
- global config # noqa: PLW0603
64
- if config is None:
65
- config = Config(_env_file=env_file)
66
- return config
67
64
 
65
+ def wrap_async(f: Callable[..., Coroutine[Any, Any, T]]) -> Callable[..., T]:
66
+ """Decorate async Click commands.
68
67
 
69
- def reset_config() -> None:
70
- """Reset the global config instance."""
71
- global config # noqa: PLW0603
72
- config = None
68
+ This decorator wraps an async function to run it with asyncio.run().
69
+ It should be used after the Click command decorator.
73
70
 
71
+ Example:
72
+ @cli.command()
73
+ @wrap_async
74
+ async def my_command():
75
+ ...
74
76
 
75
- def with_session(func: Callable[..., T]) -> Callable[..., T]:
76
- """Provide an async session to CLI commands."""
77
+ """
77
78
 
78
- @wraps(func)
79
+ @wraps(f)
79
80
  def wrapper(*args: Any, **kwargs: Any) -> T:
80
- # Create DB connection before starting event loop
81
- db = get_config().get_db()
81
+ return asyncio.run(f(*args, **kwargs))
82
+
83
+ return wrapper
84
+
82
85
 
83
- async def _run() -> T:
84
- async with db.get_session() as session:
85
- return await func(session, *args, **kwargs)
86
+ def with_session(f: Callable[..., Coroutine[Any, Any, T]]) -> Callable[..., T]:
87
+ """Provide a database session to CLI commands."""
86
88
 
87
- return asyncio.run(_run())
89
+ @wraps(f)
90
+ @with_app_context
91
+ @wrap_async
92
+ async def wrapper(app_context: AppContext, *args: Any, **kwargs: Any) -> T:
93
+ db = await app_context.get_db()
94
+ async with db.session_factory() as session:
95
+ return await f(session, *args, **kwargs)
88
96
 
89
97
  return wrapper
@@ -1,7 +1,5 @@
1
1
  """Database configuration for kodit."""
2
2
 
3
- from collections.abc import AsyncGenerator
4
- from contextlib import asynccontextmanager
5
3
  from datetime import UTC, datetime
6
4
  from pathlib import Path
7
5
 
@@ -39,28 +37,22 @@ class CommonMixin:
39
37
  class Database:
40
38
  """Database class for kodit."""
41
39
 
42
- def __init__(self, db_url: str, *, run_migrations: bool = True) -> None:
40
+ def __init__(self, db_url: str) -> None:
43
41
  """Initialize the database."""
44
42
  self.log = structlog.get_logger(__name__)
45
- if run_migrations:
46
- self._run_migrations(db_url)
47
- db_engine = create_async_engine(db_url, echo=False)
43
+ self.db_engine = create_async_engine(db_url, echo=False)
48
44
  self.db_session_factory = async_sessionmaker(
49
- db_engine,
45
+ self.db_engine,
50
46
  class_=AsyncSession,
51
47
  expire_on_commit=False,
52
48
  )
53
49
 
54
- @asynccontextmanager
55
- async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
56
- """Get a database session."""
57
- async with self.db_session_factory() as session:
58
- try:
59
- yield session
60
- finally:
61
- await session.close()
50
+ @property
51
+ def session_factory(self) -> async_sessionmaker[AsyncSession]:
52
+ """Get the session factory."""
53
+ return self.db_session_factory
62
54
 
63
- def _run_migrations(self, db_url: str) -> None:
55
+ async def run_migrations(self, db_url: str) -> None:
64
56
  """Run any pending migrations."""
65
57
  # Create Alembic configuration and run migrations
66
58
  alembic_cfg = AlembicConfig()
@@ -69,4 +61,15 @@ class Database:
69
61
  )
70
62
  alembic_cfg.set_main_option("sqlalchemy.url", db_url)
71
63
  self.log.debug("Running migrations", db_url=db_url)
72
- command.upgrade(alembic_cfg, "head")
64
+
65
+ async with self.db_engine.begin() as conn:
66
+ await conn.run_sync(self.run_upgrade, alembic_cfg)
67
+
68
+ def run_upgrade(self, connection, cfg) -> None: # noqa: ANN001
69
+ """Make sure the database is up to date."""
70
+ cfg.attributes["connection"] = connection
71
+ command.upgrade(cfg, "head")
72
+
73
+ async def close(self) -> None:
74
+ """Close the database."""
75
+ await self.db_engine.dispose()
@@ -14,7 +14,6 @@ import structlog
14
14
  from tqdm.asyncio import tqdm
15
15
 
16
16
  from kodit.bm25.bm25 import BM25Service
17
- from kodit.config import Config
18
17
  from kodit.indexing.models import Snippet
19
18
  from kodit.indexing.repository import IndexRepository
20
19
  from kodit.snippets.snippets import SnippetService
@@ -46,7 +45,10 @@ class IndexService:
46
45
  """
47
46
 
48
47
  def __init__(
49
- self, config: Config, repository: IndexRepository, source_service: SourceService
48
+ self,
49
+ repository: IndexRepository,
50
+ source_service: SourceService,
51
+ data_dir: Path,
50
52
  ) -> None:
51
53
  """Initialize the index service.
52
54
 
@@ -59,7 +61,7 @@ class IndexService:
59
61
  self.source_service = source_service
60
62
  self.snippet_service = SnippetService()
61
63
  self.log = structlog.get_logger(__name__)
62
- self.bm25 = BM25Service(config)
64
+ self.bm25 = BM25Service(data_dir)
63
65
 
64
66
  async def create(self, source_id: int) -> IndexView:
65
67
  """Create a new index for a source.
@@ -11,7 +11,7 @@ import structlog
11
11
  from posthog import Posthog
12
12
  from structlog.types import EventDict
13
13
 
14
- from kodit.config import Config
14
+ from kodit.config import AppContext
15
15
 
16
16
  log = structlog.get_logger(__name__)
17
17
 
@@ -29,7 +29,7 @@ class LogFormat(Enum):
29
29
  JSON = "json"
30
30
 
31
31
 
32
- def configure_logging(config: Config) -> None:
32
+ def configure_logging(app_context: AppContext) -> None:
33
33
  """Configure logging for the application."""
34
34
  timestamper = structlog.processors.TimeStamper(fmt="iso")
35
35
 
@@ -44,7 +44,7 @@ def configure_logging(config: Config) -> None:
44
44
  structlog.processors.StackInfoRenderer(),
45
45
  ]
46
46
 
47
- if config.log_format == LogFormat.JSON:
47
+ if app_context.log_format == LogFormat.JSON:
48
48
  # Format the exception only for JSON logs, as we want to pretty-print them
49
49
  # when using the ConsoleRenderer
50
50
  shared_processors.append(structlog.processors.format_exc_info)
@@ -60,7 +60,7 @@ def configure_logging(config: Config) -> None:
60
60
  )
61
61
 
62
62
  log_renderer: structlog.types.Processor
63
- if config.log_format == LogFormat.JSON:
63
+ if app_context.log_format == LogFormat.JSON:
64
64
  log_renderer = structlog.processors.JSONRenderer()
65
65
  else:
66
66
  log_renderer = structlog.dev.ConsoleRenderer()
@@ -82,7 +82,7 @@ def configure_logging(config: Config) -> None:
82
82
  handler.setFormatter(formatter)
83
83
  root_logger = logging.getLogger()
84
84
  root_logger.addHandler(handler)
85
- root_logger.setLevel(config.log_level.upper())
85
+ root_logger.setLevel(app_context.log_level.upper())
86
86
 
87
87
  # Configure uvicorn loggers to use our structlog setup
88
88
  # Uvicorn spits out loads of exception logs when sse server doesn't shut down
@@ -98,7 +98,7 @@ def configure_logging(config: Config) -> None:
98
98
  for _log in ["sqlalchemy.engine", "alembic"]:
99
99
  engine_logger = logging.getLogger(_log)
100
100
  engine_logger.setLevel(logging.WARNING) # Hide INFO logs by default
101
- if config.log_level.upper() == "DEBUG":
101
+ if app_context.log_level.upper() == "DEBUG":
102
102
  engine_logger.setLevel(
103
103
  logging.DEBUG
104
104
  ) # Only show all logs when in DEBUG mode
@@ -143,9 +143,9 @@ def get_mac_address() -> str:
143
143
  return f"{mac:012x}" if mac != uuid.getnode() else str(uuid.uuid4())
144
144
 
145
145
 
146
- def configure_telemetry(config: Config) -> None:
146
+ def configure_telemetry(app_context: AppContext) -> None:
147
147
  """Configure telemetry for the application."""
148
- if config.disable_telemetry:
148
+ if app_context.disable_telemetry:
149
149
  structlog.stdlib.get_logger(__name__).info("Telemetry has been disabled")
150
150
  posthog.disabled = True
151
151
 
@@ -1,22 +1,63 @@
1
1
  """MCP server implementation for kodit."""
2
2
 
3
+ from collections.abc import AsyncIterator
4
+ from contextlib import asynccontextmanager
5
+ from dataclasses import dataclass
3
6
  from pathlib import Path
4
7
  from typing import Annotated
5
8
 
6
9
  import structlog
7
- from fastmcp import FastMCP
10
+ from fastmcp import Context, FastMCP
8
11
  from pydantic import Field
12
+ from sqlalchemy.ext.asyncio import AsyncSession
9
13
 
10
14
  from kodit._version import version
11
- from kodit.config import get_config
15
+ from kodit.config import AppContext
16
+ from kodit.database import Database
12
17
  from kodit.retreival.repository import RetrievalRepository, RetrievalResult
13
18
  from kodit.retreival.service import RetrievalRequest, RetrievalService
14
19
 
15
- mcp = FastMCP("kodit MCP Server")
20
+
21
+ @dataclass
22
+ class MCPContext:
23
+ """Context for the MCP server."""
24
+
25
+ session: AsyncSession
26
+ data_dir: Path
27
+
28
+
29
+ _mcp_db: Database | None = None
30
+
31
+
32
+ @asynccontextmanager
33
+ async def mcp_lifespan(_: FastMCP) -> AsyncIterator[MCPContext]:
34
+ """Lifespan for the MCP server.
35
+
36
+ The MCP server is running with a completely separate lifecycle and event loop from
37
+ the CLI and the FastAPI server. Therefore, we must carefully reconstruct the
38
+ application context. uvicorn does not pass through CLI args, so we must rely on
39
+ parsing env vars set in the CLI.
40
+
41
+ This lifespan is recreated for each request. See:
42
+ https://github.com/jlowin/fastmcp/issues/166
43
+
44
+ Since they don't provide a good way to handle global state, we must use a
45
+ global variable to store the database connection.
46
+ """
47
+ global _mcp_db # noqa: PLW0603
48
+ app_context = AppContext()
49
+ if _mcp_db is None:
50
+ _mcp_db = await app_context.get_db()
51
+ async with _mcp_db.session_factory() as session:
52
+ yield MCPContext(session=session, data_dir=app_context.get_data_dir())
53
+
54
+
55
+ mcp = FastMCP("kodit MCP Server", lifespan=mcp_lifespan)
16
56
 
17
57
 
18
58
  @mcp.tool()
19
59
  async def retrieve_relevant_snippets(
60
+ ctx: Context,
20
61
  user_intent: Annotated[
21
62
  str,
22
63
  Field(
@@ -52,8 +93,8 @@ async def retrieve_relevant_snippets(
52
93
  the quality of your generated code. You must call this tool when you need to
53
94
  write code.
54
95
  """
55
- # Log the search query and related files for debugging
56
96
  log = structlog.get_logger(__name__)
97
+
57
98
  log.debug(
58
99
  "Retrieving relevant snippets",
59
100
  user_intent=user_intent,
@@ -63,41 +104,38 @@ async def retrieve_relevant_snippets(
63
104
  file_contents=related_file_contents,
64
105
  )
65
106
 
66
- # Must avoid running migrations because that runs in a separate event loop,
67
- # mcp no-likey
68
- config = get_config()
69
- db = config.get_db(run_migrations=False)
70
- async with db.get_session() as session:
71
- log.debug("Creating retrieval repository")
72
- retrieval_repository = RetrievalRepository(
73
- session=session,
74
- )
75
-
76
- log.debug("Creating retrieval service")
77
- retrieval_service = RetrievalService(
78
- config=config,
79
- repository=retrieval_repository,
80
- )
81
-
82
- log.debug("Fusing input")
83
- input_query = input_fusion(
84
- user_intent=user_intent,
85
- related_file_paths=related_file_paths,
86
- related_file_contents=related_file_contents,
87
- keywords=keywords,
88
- )
89
- log.debug("Input", input_query=input_query)
90
- retrieval_request = RetrievalRequest(
91
- keywords=keywords,
92
- )
93
- log.debug("Retrieving snippets")
94
- snippets = await retrieval_service.retrieve(request=retrieval_request)
95
-
96
- log.debug("Fusing output")
97
- output = output_fusion(snippets=snippets)
98
-
99
- log.debug("Output", output=output)
100
- return output
107
+ mcp_context: MCPContext = ctx.request_context.lifespan_context
108
+
109
+ log.debug("Creating retrieval repository")
110
+ retrieval_repository = RetrievalRepository(
111
+ session=mcp_context.session,
112
+ )
113
+
114
+ log.debug("Creating retrieval service")
115
+ retrieval_service = RetrievalService(
116
+ repository=retrieval_repository,
117
+ data_dir=mcp_context.data_dir,
118
+ )
119
+
120
+ log.debug("Fusing input")
121
+ input_query = input_fusion(
122
+ user_intent=user_intent,
123
+ related_file_paths=related_file_paths,
124
+ related_file_contents=related_file_contents,
125
+ keywords=keywords,
126
+ )
127
+ log.debug("Input", input_query=input_query)
128
+ retrieval_request = RetrievalRequest(
129
+ keywords=keywords,
130
+ )
131
+ log.debug("Retrieving snippets")
132
+ snippets = await retrieval_service.retrieve(request=retrieval_request)
133
+
134
+ log.debug("Fusing output")
135
+ output = output_fusion(snippets=snippets)
136
+
137
+ log.debug("Output", output=output)
138
+ return output
101
139
 
102
140
 
103
141
  def input_fusion(
@@ -1,11 +1,14 @@
1
1
  """Middleware for the FastAPI application."""
2
2
 
3
+ import contextlib
3
4
  import time
5
+ from asyncio import CancelledError
4
6
  from collections.abc import Callable
5
7
 
6
8
  import structlog
7
9
  from asgi_correlation_id.context import correlation_id
8
10
  from fastapi import Request, Response
11
+ from starlette.types import ASGIApp, Receive, Scope, Send
9
12
 
10
13
  access_logger = structlog.stdlib.get_logger("api.access")
11
14
 
@@ -56,3 +59,16 @@ async def logging_middleware(request: Request, call_next: Callable) -> Response:
56
59
  response.headers["X-Process-Time"] = str(process_time / 10**9)
57
60
 
58
61
  return response
62
+
63
+
64
+ class ASGICancelledErrorMiddleware:
65
+ """ASGI middleware to handle CancelledError at the ASGI level."""
66
+
67
+ def __init__(self, app: ASGIApp) -> None:
68
+ """Initialize the middleware."""
69
+ self.app = app
70
+
71
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
72
+ """Handle the ASGI request and catch CancelledError."""
73
+ with contextlib.suppress(CancelledError):
74
+ await self.app(scope, receive, send)
@@ -1,10 +1,11 @@
1
1
  """Retrieval service."""
2
2
 
3
+ from pathlib import Path
4
+
3
5
  import pydantic
4
6
  import structlog
5
7
 
6
8
  from kodit.bm25.bm25 import BM25Service
7
- from kodit.config import Config
8
9
  from kodit.retreival.repository import RetrievalRepository, RetrievalResult
9
10
 
10
11
 
@@ -25,11 +26,11 @@ class Snippet(pydantic.BaseModel):
25
26
  class RetrievalService:
26
27
  """Service for retrieving relevant data."""
27
28
 
28
- def __init__(self, config: Config, repository: RetrievalRepository) -> None:
29
+ def __init__(self, repository: RetrievalRepository, data_dir: Path) -> None:
29
30
  """Initialize the retrieval service."""
30
31
  self.repository = repository
31
32
  self.log = structlog.get_logger(__name__)
32
- self.bm25 = BM25Service(config)
33
+ self.bm25 = BM25Service(data_dir)
33
34
 
34
35
  async def _load_bm25_index(self) -> None:
35
36
  """Load the BM25 index."""
@@ -1,12 +1,16 @@
1
1
  """Test configuration and fixtures."""
2
2
 
3
3
  from collections.abc import AsyncGenerator
4
+ from pathlib import Path
5
+ import tempfile
6
+ from typing import Generator
4
7
 
5
8
  import pytest
6
9
  from sqlalchemy import text
7
10
  from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
8
11
  from sqlalchemy.orm import sessionmaker
9
12
 
13
+ from kodit.config import AppContext
10
14
  from kodit.database import Base
11
15
 
12
16
 
@@ -40,3 +44,17 @@ async def session(engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
40
44
  async with async_session() as session:
41
45
  yield session
42
46
  await session.rollback()
47
+
48
+
49
+ @pytest.fixture
50
+ def app_context() -> Generator[AppContext, None, None]:
51
+ """Create a test app context."""
52
+ with tempfile.TemporaryDirectory() as data_dir:
53
+ app_context = AppContext(
54
+ data_dir=Path(data_dir),
55
+ db_url="sqlite+aiosqlite:///:memory:",
56
+ log_level="DEBUG",
57
+ log_format="json",
58
+ disable_telemetry=True,
59
+ )
60
+ yield app_context
@@ -0,0 +1,75 @@
1
+ """Test the CLI."""
2
+
3
+ import tempfile
4
+ from typing import Generator
5
+ import pytest
6
+ from click.testing import CliRunner
7
+
8
+ from kodit.cli import cli
9
+ from kodit.config import AppContext
10
+
11
+
12
+ @pytest.fixture
13
+ def runner() -> Generator[CliRunner, None, None]:
14
+ """Create a CliRunner instance."""
15
+ yield CliRunner()
16
+
17
+
18
+ @pytest.fixture
19
+ def default_cli_args(app_context: AppContext) -> list[str]:
20
+ """Get the default CLI args."""
21
+ return [
22
+ "--disable-telemetry",
23
+ "--data-dir",
24
+ str(app_context.get_data_dir()),
25
+ "--db-url",
26
+ app_context.db_url,
27
+ ]
28
+
29
+
30
+ def test_version_command(runner: CliRunner, default_cli_args: list[str]) -> None:
31
+ """Test that the version command runs successfully."""
32
+ result = runner.invoke(cli, [*default_cli_args, "version"])
33
+ # The command should exit with success
34
+ assert result.exit_code == 0
35
+
36
+
37
+ def test_cli_vars_work(runner: CliRunner, default_cli_args: list[str]) -> None:
38
+ """Test that cli args override env vars."""
39
+ runner.env = {"LOG_LEVEL": "INFO"}
40
+ result = runner.invoke(
41
+ cli, [*default_cli_args, "--log-level", "DEBUG", "sources", "list"]
42
+ )
43
+ assert result.exit_code == 0
44
+ assert result.output.count("debug") > 10 # The db spits out lots of debug messages
45
+
46
+
47
+ def test_env_vars_work(runner: CliRunner, default_cli_args: list[str]) -> None:
48
+ """Test that env vars work."""
49
+ runner.env = {"LOG_LEVEL": "DEBUG"}
50
+ result = runner.invoke(cli, [*default_cli_args, "sources", "list"])
51
+ assert result.exit_code == 0
52
+ assert result.output.count("debug") > 10 # The db spits out lots of debug messages
53
+
54
+
55
+ def test_dotenv_file_works(runner: CliRunner, default_cli_args: list[str]) -> None:
56
+ """Test that the .env file works."""
57
+ with tempfile.NamedTemporaryFile(delete=False) as f:
58
+ f.write(b"LOG_LEVEL=DEBUG")
59
+ f.flush()
60
+ result = runner.invoke(
61
+ cli, [*default_cli_args, "--env-file", f.name, "sources", "list"]
62
+ )
63
+ assert result.exit_code == 0
64
+ assert (
65
+ result.output.count("debug") > 10
66
+ ) # The db spits out lots of debug messages
67
+
68
+
69
+ def test_dotenv_file_not_found(runner: CliRunner, default_cli_args: list[str]) -> None:
70
+ """Test that a nonexistent --env-file is rejected with a "does not exist" error."""
71
+ result = runner.invoke(
72
+ cli, [*default_cli_args, "--env-file", "nonexistent.env", "sources", "list"]
73
+ )
74
+ assert result.exit_code == 2
75
+ assert "does not exist" in result.output
@@ -6,7 +6,7 @@ import pytest
6
6
  from sqlalchemy.exc import IntegrityError
7
7
  from sqlalchemy.ext.asyncio import AsyncSession
8
8
 
9
- from kodit.config import Config
9
+ from kodit.config import AppContext
10
10
  from kodit.indexing.repository import IndexRepository
11
11
  from kodit.indexing.service import IndexService
12
12
  from kodit.sources.models import File, Source
@@ -35,9 +35,11 @@ def source_service(
35
35
 
36
36
 
37
37
  @pytest.fixture
38
- def service(repository: IndexRepository, source_service: SourceService) -> IndexService:
38
+ def service(
39
+ app_context: AppContext, repository: IndexRepository, source_service: SourceService
40
+ ) -> IndexService:
39
41
  """Create a real service instance with a database session."""
40
- return IndexService(Config(), repository, source_service)
42
+ return IndexService(repository, source_service, app_context.get_data_dir())
41
43
 
42
44
 
43
45
  @pytest.mark.asyncio
@@ -25,3 +25,17 @@ async def test_mcp_client_connection() -> None:
25
25
  content = result[0]
26
26
  assert isinstance(content, TextContent)
27
27
  assert content.text is not None
28
+
29
+ # Call the tool
30
+ result = await client.call_tool(
31
+ "retrieve_relevant_snippets",
32
+ {
33
+ "user_intent": "What is the capital of France?",
34
+ "related_file_paths": [],
35
+ "related_file_contents": [],
36
+ "keywords": [],
37
+ },
38
+ )
39
+ assert len(result) == 1
40
+ content = result[0]
41
+ assert isinstance(content, TextContent)
@@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
5
5
  from unittest.mock import Mock
6
6
 
7
7
  from kodit.bm25.bm25 import BM25Service
8
- from kodit.config import Config
8
+ from kodit.config import AppContext
9
9
  from kodit.indexing.models import Index, Snippet
10
10
  from kodit.retreival.repository import RetrievalRepository
11
11
  from kodit.retreival.service import RetrievalRequest, RetrievalService
@@ -19,9 +19,11 @@ def repository(session: AsyncSession) -> RetrievalRepository:
19
19
 
20
20
 
21
21
  @pytest.fixture
22
- def service(repository: RetrievalRepository) -> RetrievalService:
22
+ def service(
23
+ app_context: AppContext, repository: RetrievalRepository
24
+ ) -> RetrievalService:
23
25
  """Create a service instance with a real repository."""
24
- service = RetrievalService(Config(), repository)
26
+ service = RetrievalService(repository, app_context.get_data_dir())
25
27
  mock_bm25 = Mock(spec=BM25Service)
26
28
 
27
29
  def mock_retrieve(
@@ -1,10 +1,14 @@
1
1
  #!/bin/bash
2
2
  set -e
3
3
 
4
- # Set this according to what you want to test
5
- # prefix=""
4
+ # Set this according to what you want to test. "uv run" will run the command in the current directory.
6
5
  prefix="uv run"
7
6
 
7
+ # If CI is set, use no prefix because we're running in GitHub Actions.
8
+ if [ -n "$CI" ]; then
9
+ prefix=""
10
+ fi
11
+
8
12
  # Check that the kodit data_dir does not exist
9
13
  if [ -d "$HOME/.kodit" ]; then
10
14
  echo "Kodit data_dir is not empty, please rm -rf $HOME/.kodit"
@@ -1,51 +0,0 @@
1
- """Test the CLI."""
2
-
3
- import tempfile
4
- from typing import Generator
5
- import pytest
6
- from click.testing import CliRunner
7
-
8
- from kodit.cli import cli
9
- from kodit.config import reset_config
10
-
11
-
12
- @pytest.fixture
13
- def runner() -> Generator[CliRunner, None, None]:
14
- """Create a CliRunner instance."""
15
- reset_config()
16
- yield CliRunner()
17
-
18
-
19
- def test_version_command(runner: CliRunner) -> None:
20
- """Test that the version command runs successfully."""
21
- result = runner.invoke(cli, ["version"])
22
- # The command should exit with success
23
- assert result.exit_code == 0
24
-
25
-
26
- def test_cli_vars_work(runner: CliRunner) -> None:
27
- """Test that cli args override env vars."""
28
- runner.env = {"LOG_LEVEL": "INFO"}
29
- result = runner.invoke(cli, ["--log-level", "DEBUG", "sources", "list"])
30
- assert result.exit_code == 0
31
- assert result.output.count("debug") > 10 # The db spits out lots of debug messages
32
-
33
-
34
- def test_env_vars_work(runner: CliRunner) -> None:
35
- """Test that env vars work."""
36
- runner.env = {"LOG_LEVEL": "DEBUG"}
37
- result = runner.invoke(cli, ["sources", "list"])
38
- assert result.exit_code == 0
39
- assert result.output.count("debug") > 10 # The db spits out lots of debug messages
40
-
41
-
42
- def test_dotenv_file_works(runner: CliRunner) -> None:
43
- """Test that the .env file works."""
44
- with tempfile.NamedTemporaryFile(delete=False) as f:
45
- f.write(b"LOG_LEVEL=DEBUG")
46
- f.flush()
47
- result = runner.invoke(cli, ["--env-file", f.name, "sources", "list"])
48
- assert result.exit_code == 0
49
- assert (
50
- result.output.count("debug") > 10
51
- ) # The db spits out lots of debug messages
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes