kodit 0.1.5__tar.gz → 0.1.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kodit might be problematic. Click here for more details.
- {kodit-0.1.5 → kodit-0.1.6}/.github/workflows/test.yaml +3 -0
- {kodit-0.1.5 → kodit-0.1.6}/PKG-INFO +1 -1
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/_version.py +2 -2
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/alembic/env.py +5 -2
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/app.py +5 -1
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/bm25/bm25.py +4 -4
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/cli.py +56 -24
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/config.py +34 -26
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/database.py +20 -17
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/indexing/service.py +5 -3
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/logging.py +8 -8
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/mcp.py +77 -39
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/middleware.py +16 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/retreival/service.py +4 -3
- {kodit-0.1.5 → kodit-0.1.6}/tests/conftest.py +18 -0
- kodit-0.1.6/tests/kodit/cli_test.py +75 -0
- {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/indexing/test_service.py +5 -3
- {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/mcp_test.py +14 -0
- {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/retreival/test_service.py +5 -3
- {kodit-0.1.5 → kodit-0.1.6}/tests/smoke.sh +6 -2
- kodit-0.1.5/tests/kodit/cli_test.py +0 -51
- {kodit-0.1.5 → kodit-0.1.6}/.cursor/rules/kodit.mdc +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/.github/CODE_OF_CONDUCT.md +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/.github/CONTRIBUTING.md +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/.github/PULL_REQUEST_TEMPLATE.md +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/.github/workflows/docker.yaml +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/.github/workflows/docs.yaml +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/.github/workflows/pypi-test.yaml +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/.github/workflows/pypi.yaml +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/.gitignore +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/.python-version +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/.vscode/launch.json +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/.vscode/settings.json +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/Dockerfile +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/LICENSE +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/README.md +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/alembic.ini +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/docs/_index.md +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/docs/developer/index.md +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/pyproject.toml +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/.gitignore +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/__init__.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/alembic/README +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/alembic/__init__.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/alembic/script.py.mako +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/alembic/versions/85155663351e_initial.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/alembic/versions/__init__.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/bm25/__init__.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/indexing/__init__.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/indexing/models.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/indexing/repository.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/retreival/__init__.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/retreival/repository.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/snippets/__init__.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/snippets/languages/__init__.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/snippets/languages/csharp.scm +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/snippets/languages/python.scm +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/snippets/method_snippets.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/snippets/snippets.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/sources/__init__.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/sources/models.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/sources/repository.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/src/kodit/sources/service.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/tests/__init__.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/__init__.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/e2e.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/indexing/__init__.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/retreival/__init__.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/snippets/__init__.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/snippets/csharp.cs +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/snippets/detect_language_test.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/snippets/method_extraction_test.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/snippets/python.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/sources/__init__.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/tests/kodit/sources/test_service.py +0 -0
- {kodit-0.1.5 → kodit-0.1.6}/uv.lock +0 -0
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
|
|
4
4
|
import asyncio
|
|
5
5
|
|
|
6
|
-
import structlog
|
|
7
6
|
from alembic import context
|
|
8
7
|
from sqlalchemy import pool
|
|
9
8
|
from sqlalchemy.engine import Connection
|
|
@@ -75,7 +74,11 @@ async def run_async_migrations() -> None:
|
|
|
75
74
|
|
|
76
75
|
def run_migrations_online() -> None:
|
|
77
76
|
"""Run migrations in 'online' mode."""
|
|
78
|
-
|
|
77
|
+
connectable = config.attributes.get("connection", None)
|
|
78
|
+
if connectable is None:
|
|
79
|
+
asyncio.run(run_async_migrations())
|
|
80
|
+
else:
|
|
81
|
+
do_run_migrations(connectable)
|
|
79
82
|
|
|
80
83
|
|
|
81
84
|
if context.is_offline_mode():
|
|
@@ -4,7 +4,7 @@ from asgi_correlation_id import CorrelationIdMiddleware
|
|
|
4
4
|
from fastapi import FastAPI
|
|
5
5
|
|
|
6
6
|
from kodit.mcp import mcp
|
|
7
|
-
from kodit.middleware import logging_middleware
|
|
7
|
+
from kodit.middleware import ASGICancelledErrorMiddleware, logging_middleware
|
|
8
8
|
|
|
9
9
|
# See https://gofastmcp.com/deployment/asgi#fastapi-integration
|
|
10
10
|
mcp_app = mcp.sse_app()
|
|
@@ -23,3 +23,7 @@ async def root() -> dict[str, str]:
|
|
|
23
23
|
|
|
24
24
|
# Add mcp routes last, otherwise previous routes aren't added
|
|
25
25
|
app.mount("", mcp_app)
|
|
26
|
+
|
|
27
|
+
# Wrap the entire app with ASGI middleware after all routes are added to suppress
|
|
28
|
+
# CancelledError at the ASGI level
|
|
29
|
+
app = ASGICancelledErrorMiddleware(app)
|
|
@@ -1,20 +1,20 @@
|
|
|
1
1
|
"""BM25 service."""
|
|
2
2
|
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
3
5
|
import bm25s
|
|
4
6
|
import Stemmer
|
|
5
7
|
import structlog
|
|
6
8
|
from bm25s.tokenization import Tokenized
|
|
7
9
|
|
|
8
|
-
from kodit.config import Config
|
|
9
|
-
|
|
10
10
|
|
|
11
11
|
class BM25Service:
|
|
12
12
|
"""Service for BM25."""
|
|
13
13
|
|
|
14
|
-
def __init__(self,
|
|
14
|
+
def __init__(self, data_dir: Path) -> None:
|
|
15
15
|
"""Initialize the BM25 service."""
|
|
16
16
|
self.log = structlog.get_logger(__name__)
|
|
17
|
-
self.index_path =
|
|
17
|
+
self.index_path = data_dir / "bm25s_index"
|
|
18
18
|
try:
|
|
19
19
|
self.log.debug("Loading BM25 index")
|
|
20
20
|
self.retriever = bm25s.BM25.load(self.index_path, mmap=True)
|
|
@@ -17,8 +17,8 @@ from kodit.config import (
|
|
|
17
17
|
DEFAULT_DISABLE_TELEMETRY,
|
|
18
18
|
DEFAULT_LOG_FORMAT,
|
|
19
19
|
DEFAULT_LOG_LEVEL,
|
|
20
|
-
|
|
21
|
-
|
|
20
|
+
AppContext,
|
|
21
|
+
with_app_context,
|
|
22
22
|
with_session,
|
|
23
23
|
)
|
|
24
24
|
from kodit.indexing.repository import IndexRepository
|
|
@@ -40,23 +40,33 @@ from kodit.sources.service import SourceService
|
|
|
40
40
|
)
|
|
41
41
|
@click.option("--db-url", help=f"Database URL [default: {DEFAULT_DB_URL}]")
|
|
42
42
|
@click.option("--data-dir", help=f"Data directory [default: {DEFAULT_BASE_DIR}]")
|
|
43
|
-
@click.option(
|
|
43
|
+
@click.option(
|
|
44
|
+
"--env-file",
|
|
45
|
+
help="Path to a .env file [default: .env]",
|
|
46
|
+
type=click.Path(
|
|
47
|
+
exists=True,
|
|
48
|
+
dir_okay=False,
|
|
49
|
+
resolve_path=True,
|
|
50
|
+
path_type=Path,
|
|
51
|
+
),
|
|
52
|
+
)
|
|
53
|
+
@click.pass_context
|
|
44
54
|
def cli( # noqa: PLR0913
|
|
55
|
+
ctx: click.Context,
|
|
45
56
|
log_level: str | None,
|
|
46
57
|
log_format: str | None,
|
|
47
58
|
disable_telemetry: bool | None,
|
|
48
59
|
db_url: str | None,
|
|
49
60
|
data_dir: str | None,
|
|
50
|
-
env_file:
|
|
61
|
+
env_file: Path | None,
|
|
51
62
|
) -> None:
|
|
52
63
|
"""kodit CLI - Code indexing for better AI code generation.""" # noqa: D403
|
|
64
|
+
config = AppContext()
|
|
53
65
|
# First check if env-file is set and reload config if it is
|
|
54
66
|
if env_file:
|
|
55
|
-
|
|
56
|
-
get_config(env_file)
|
|
67
|
+
config = AppContext(_env_file=env_file) # type: ignore[reportCallIssue]
|
|
57
68
|
|
|
58
|
-
#
|
|
59
|
-
config = get_config()
|
|
69
|
+
# Now override with CLI arguments, if set
|
|
60
70
|
if data_dir:
|
|
61
71
|
config.data_dir = Path(data_dir)
|
|
62
72
|
if db_url:
|
|
@@ -70,6 +80,9 @@ def cli( # noqa: PLR0913
|
|
|
70
80
|
configure_logging(config)
|
|
71
81
|
configure_telemetry(config)
|
|
72
82
|
|
|
83
|
+
# Set the app context in the click context for downstream cli
|
|
84
|
+
ctx.obj = config
|
|
85
|
+
|
|
73
86
|
|
|
74
87
|
@cli.group()
|
|
75
88
|
def sources() -> None:
|
|
@@ -77,11 +90,12 @@ def sources() -> None:
|
|
|
77
90
|
|
|
78
91
|
|
|
79
92
|
@sources.command(name="list")
|
|
93
|
+
@with_app_context
|
|
80
94
|
@with_session
|
|
81
|
-
async def list_sources(session: AsyncSession) -> None:
|
|
95
|
+
async def list_sources(session: AsyncSession, app_context: AppContext) -> None:
|
|
82
96
|
"""List all code sources."""
|
|
83
97
|
repository = SourceRepository(session)
|
|
84
|
-
service = SourceService(
|
|
98
|
+
service = SourceService(app_context.get_clone_dir(), repository)
|
|
85
99
|
sources = await service.list_sources()
|
|
86
100
|
|
|
87
101
|
# Define headers and data
|
|
@@ -95,11 +109,14 @@ async def list_sources(session: AsyncSession) -> None:
|
|
|
95
109
|
|
|
96
110
|
@sources.command(name="create")
|
|
97
111
|
@click.argument("uri")
|
|
112
|
+
@with_app_context
|
|
98
113
|
@with_session
|
|
99
|
-
async def create_source(
|
|
114
|
+
async def create_source(
|
|
115
|
+
session: AsyncSession, app_context: AppContext, uri: str
|
|
116
|
+
) -> None:
|
|
100
117
|
"""Add a new code source."""
|
|
101
118
|
repository = SourceRepository(session)
|
|
102
|
-
service = SourceService(
|
|
119
|
+
service = SourceService(app_context.get_clone_dir(), repository)
|
|
103
120
|
source = await service.create(uri)
|
|
104
121
|
click.echo(f"Source created: {source.id}")
|
|
105
122
|
|
|
@@ -111,25 +128,29 @@ def indexes() -> None:
|
|
|
111
128
|
|
|
112
129
|
@indexes.command(name="create")
|
|
113
130
|
@click.argument("source_id")
|
|
131
|
+
@with_app_context
|
|
114
132
|
@with_session
|
|
115
|
-
async def create_index(
|
|
133
|
+
async def create_index(
|
|
134
|
+
session: AsyncSession, app_context: AppContext, source_id: int
|
|
135
|
+
) -> None:
|
|
116
136
|
"""Create an index for a source."""
|
|
117
137
|
source_repository = SourceRepository(session)
|
|
118
|
-
source_service = SourceService(
|
|
138
|
+
source_service = SourceService(app_context.get_clone_dir(), source_repository)
|
|
119
139
|
repository = IndexRepository(session)
|
|
120
|
-
service = IndexService(
|
|
140
|
+
service = IndexService(repository, source_service, app_context.get_data_dir())
|
|
121
141
|
index = await service.create(source_id)
|
|
122
142
|
click.echo(f"Index created: {index.id}")
|
|
123
143
|
|
|
124
144
|
|
|
125
145
|
@indexes.command(name="list")
|
|
146
|
+
@with_app_context
|
|
126
147
|
@with_session
|
|
127
|
-
async def list_indexes(session: AsyncSession) -> None:
|
|
148
|
+
async def list_indexes(session: AsyncSession, app_context: AppContext) -> None:
|
|
128
149
|
"""List all indexes."""
|
|
129
150
|
source_repository = SourceRepository(session)
|
|
130
|
-
source_service = SourceService(
|
|
151
|
+
source_service = SourceService(app_context.get_clone_dir(), source_repository)
|
|
131
152
|
repository = IndexRepository(session)
|
|
132
|
-
service = IndexService(
|
|
153
|
+
service = IndexService(repository, source_service, app_context.get_data_dir())
|
|
133
154
|
indexes = await service.list_indexes()
|
|
134
155
|
|
|
135
156
|
# Define headers and data
|
|
@@ -156,24 +177,30 @@ async def list_indexes(session: AsyncSession) -> None:
|
|
|
156
177
|
|
|
157
178
|
@indexes.command(name="run")
|
|
158
179
|
@click.argument("index_id")
|
|
180
|
+
@with_app_context
|
|
159
181
|
@with_session
|
|
160
|
-
async def run_index(
|
|
182
|
+
async def run_index(
|
|
183
|
+
session: AsyncSession, app_context: AppContext, index_id: int
|
|
184
|
+
) -> None:
|
|
161
185
|
"""Run an index."""
|
|
162
186
|
source_repository = SourceRepository(session)
|
|
163
|
-
source_service = SourceService(
|
|
187
|
+
source_service = SourceService(app_context.get_clone_dir(), source_repository)
|
|
164
188
|
repository = IndexRepository(session)
|
|
165
|
-
service = IndexService(
|
|
189
|
+
service = IndexService(repository, source_service, app_context.get_data_dir())
|
|
166
190
|
await service.run(index_id)
|
|
167
191
|
|
|
168
192
|
|
|
169
193
|
@cli.command()
|
|
170
194
|
@click.argument("query")
|
|
171
195
|
@click.option("--top-k", default=10, help="Number of snippets to retrieve")
|
|
196
|
+
@with_app_context
|
|
172
197
|
@with_session
|
|
173
|
-
async def retrieve(
|
|
198
|
+
async def retrieve(
|
|
199
|
+
session: AsyncSession, app_context: AppContext, query: str, top_k: int
|
|
200
|
+
) -> None:
|
|
174
201
|
"""Retrieve snippets from the database."""
|
|
175
202
|
repository = RetrievalRepository(session)
|
|
176
|
-
service = RetrievalService(
|
|
203
|
+
service = RetrievalService(repository, app_context.get_data_dir())
|
|
177
204
|
# Temporary request while we don't have all search capabilities
|
|
178
205
|
snippets = await service.retrieve(
|
|
179
206
|
RetrievalRequest(keywords=query.split(","), top_k=top_k)
|
|
@@ -194,7 +221,9 @@ async def retrieve(session: AsyncSession, query: str, top_k: int) -> None:
|
|
|
194
221
|
@cli.command()
|
|
195
222
|
@click.option("--host", default="127.0.0.1", help="Host to bind the server to")
|
|
196
223
|
@click.option("--port", default=8080, help="Port to bind the server to")
|
|
224
|
+
@with_app_context
|
|
197
225
|
def serve(
|
|
226
|
+
app_context: AppContext,
|
|
198
227
|
host: str,
|
|
199
228
|
port: int,
|
|
200
229
|
) -> None:
|
|
@@ -202,7 +231,10 @@ def serve(
|
|
|
202
231
|
log = structlog.get_logger(__name__)
|
|
203
232
|
log.info("Starting kodit server", host=host, port=port)
|
|
204
233
|
log_event("kodit_server_started")
|
|
205
|
-
|
|
234
|
+
|
|
235
|
+
# Dump AppContext to a dictionary of strings, and set the env vars
|
|
236
|
+
app_context_dict = {k: str(v) for k, v in app_context.model_dump().items()}
|
|
237
|
+
os.environ.update(app_context_dict)
|
|
206
238
|
|
|
207
239
|
# Configure uvicorn with graceful shutdown
|
|
208
240
|
config = uvicorn.Config(
|
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
"""Global configuration for the kodit project."""
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
-
from collections.abc import Callable
|
|
4
|
+
from collections.abc import Callable, Coroutine
|
|
5
5
|
from functools import wraps
|
|
6
6
|
from pathlib import Path
|
|
7
7
|
from typing import Any, TypeVar
|
|
8
8
|
|
|
9
|
+
import click
|
|
9
10
|
from pydantic import Field
|
|
10
11
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
11
12
|
|
|
@@ -19,8 +20,8 @@ DEFAULT_DISABLE_TELEMETRY = False
|
|
|
19
20
|
T = TypeVar("T")
|
|
20
21
|
|
|
21
22
|
|
|
22
|
-
class
|
|
23
|
-
"""Global
|
|
23
|
+
class AppContext(BaseSettings):
|
|
24
|
+
"""Global context for the kodit project. Provides a shared state for the app."""
|
|
24
25
|
|
|
25
26
|
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
|
26
27
|
|
|
@@ -47,43 +48,50 @@ class Config(BaseSettings):
|
|
|
47
48
|
clone_dir.mkdir(parents=True, exist_ok=True)
|
|
48
49
|
return clone_dir
|
|
49
50
|
|
|
50
|
-
def get_db(self, *, run_migrations: bool = True) -> Database:
|
|
51
|
+
async def get_db(self, *, run_migrations: bool = True) -> Database:
|
|
51
52
|
"""Get the database."""
|
|
52
53
|
if self._db is None:
|
|
53
|
-
self._db = Database(self.db_url
|
|
54
|
+
self._db = Database(self.db_url)
|
|
55
|
+
if run_migrations:
|
|
56
|
+
await self._db.run_migrations(self.db_url)
|
|
54
57
|
return self._db
|
|
55
58
|
|
|
56
59
|
|
|
57
|
-
|
|
58
|
-
config = None
|
|
60
|
+
with_app_context = click.make_pass_decorator(AppContext)
|
|
59
61
|
|
|
62
|
+
T = TypeVar("T")
|
|
60
63
|
|
|
61
|
-
def get_config(env_file: str | None = None) -> Config:
|
|
62
|
-
"""Get the global config instance."""
|
|
63
|
-
global config # noqa: PLW0603
|
|
64
|
-
if config is None:
|
|
65
|
-
config = Config(_env_file=env_file)
|
|
66
|
-
return config
|
|
67
64
|
|
|
65
|
+
def wrap_async(f: Callable[..., Coroutine[Any, Any, T]]) -> Callable[..., T]:
|
|
66
|
+
"""Decorate async Click commands.
|
|
68
67
|
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
global config # noqa: PLW0603
|
|
72
|
-
config = None
|
|
68
|
+
This decorator wraps an async function to run it with asyncio.run().
|
|
69
|
+
It should be used after the Click command decorator.
|
|
73
70
|
|
|
71
|
+
Example:
|
|
72
|
+
@cli.command()
|
|
73
|
+
@wrap_async
|
|
74
|
+
async def my_command():
|
|
75
|
+
...
|
|
74
76
|
|
|
75
|
-
|
|
76
|
-
"""Provide an async session to CLI commands."""
|
|
77
|
+
"""
|
|
77
78
|
|
|
78
|
-
@wraps(
|
|
79
|
+
@wraps(f)
|
|
79
80
|
def wrapper(*args: Any, **kwargs: Any) -> T:
|
|
80
|
-
|
|
81
|
-
|
|
81
|
+
return asyncio.run(f(*args, **kwargs))
|
|
82
|
+
|
|
83
|
+
return wrapper
|
|
84
|
+
|
|
82
85
|
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
return await func(session, *args, **kwargs)
|
|
86
|
+
def with_session(f: Callable[..., Coroutine[Any, Any, T]]) -> Callable[..., T]:
|
|
87
|
+
"""Provide a database session to CLI commands."""
|
|
86
88
|
|
|
87
|
-
|
|
89
|
+
@wraps(f)
|
|
90
|
+
@with_app_context
|
|
91
|
+
@wrap_async
|
|
92
|
+
async def wrapper(app_context: AppContext, *args: Any, **kwargs: Any) -> T:
|
|
93
|
+
db = await app_context.get_db()
|
|
94
|
+
async with db.session_factory() as session:
|
|
95
|
+
return await f(session, *args, **kwargs)
|
|
88
96
|
|
|
89
97
|
return wrapper
|
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
"""Database configuration for kodit."""
|
|
2
2
|
|
|
3
|
-
from collections.abc import AsyncGenerator
|
|
4
|
-
from contextlib import asynccontextmanager
|
|
5
3
|
from datetime import UTC, datetime
|
|
6
4
|
from pathlib import Path
|
|
7
5
|
|
|
@@ -39,28 +37,22 @@ class CommonMixin:
|
|
|
39
37
|
class Database:
|
|
40
38
|
"""Database class for kodit."""
|
|
41
39
|
|
|
42
|
-
def __init__(self, db_url: str
|
|
40
|
+
def __init__(self, db_url: str) -> None:
|
|
43
41
|
"""Initialize the database."""
|
|
44
42
|
self.log = structlog.get_logger(__name__)
|
|
45
|
-
|
|
46
|
-
self._run_migrations(db_url)
|
|
47
|
-
db_engine = create_async_engine(db_url, echo=False)
|
|
43
|
+
self.db_engine = create_async_engine(db_url, echo=False)
|
|
48
44
|
self.db_session_factory = async_sessionmaker(
|
|
49
|
-
db_engine,
|
|
45
|
+
self.db_engine,
|
|
50
46
|
class_=AsyncSession,
|
|
51
47
|
expire_on_commit=False,
|
|
52
48
|
)
|
|
53
49
|
|
|
54
|
-
@
|
|
55
|
-
|
|
56
|
-
"""Get
|
|
57
|
-
|
|
58
|
-
try:
|
|
59
|
-
yield session
|
|
60
|
-
finally:
|
|
61
|
-
await session.close()
|
|
50
|
+
@property
|
|
51
|
+
def session_factory(self) -> async_sessionmaker[AsyncSession]:
|
|
52
|
+
"""Get the session factory."""
|
|
53
|
+
return self.db_session_factory
|
|
62
54
|
|
|
63
|
-
def
|
|
55
|
+
async def run_migrations(self, db_url: str) -> None:
|
|
64
56
|
"""Run any pending migrations."""
|
|
65
57
|
# Create Alembic configuration and run migrations
|
|
66
58
|
alembic_cfg = AlembicConfig()
|
|
@@ -69,4 +61,15 @@ class Database:
|
|
|
69
61
|
)
|
|
70
62
|
alembic_cfg.set_main_option("sqlalchemy.url", db_url)
|
|
71
63
|
self.log.debug("Running migrations", db_url=db_url)
|
|
72
|
-
|
|
64
|
+
|
|
65
|
+
async with self.db_engine.begin() as conn:
|
|
66
|
+
await conn.run_sync(self.run_upgrade, alembic_cfg)
|
|
67
|
+
|
|
68
|
+
def run_upgrade(self, connection, cfg) -> None: # noqa: ANN001
|
|
69
|
+
"""Make sure the database is up to date."""
|
|
70
|
+
cfg.attributes["connection"] = connection
|
|
71
|
+
command.upgrade(cfg, "head")
|
|
72
|
+
|
|
73
|
+
async def close(self) -> None:
|
|
74
|
+
"""Close the database."""
|
|
75
|
+
await self.db_engine.dispose()
|
|
@@ -14,7 +14,6 @@ import structlog
|
|
|
14
14
|
from tqdm.asyncio import tqdm
|
|
15
15
|
|
|
16
16
|
from kodit.bm25.bm25 import BM25Service
|
|
17
|
-
from kodit.config import Config
|
|
18
17
|
from kodit.indexing.models import Snippet
|
|
19
18
|
from kodit.indexing.repository import IndexRepository
|
|
20
19
|
from kodit.snippets.snippets import SnippetService
|
|
@@ -46,7 +45,10 @@ class IndexService:
|
|
|
46
45
|
"""
|
|
47
46
|
|
|
48
47
|
def __init__(
|
|
49
|
-
self,
|
|
48
|
+
self,
|
|
49
|
+
repository: IndexRepository,
|
|
50
|
+
source_service: SourceService,
|
|
51
|
+
data_dir: Path,
|
|
50
52
|
) -> None:
|
|
51
53
|
"""Initialize the index service.
|
|
52
54
|
|
|
@@ -59,7 +61,7 @@ class IndexService:
|
|
|
59
61
|
self.source_service = source_service
|
|
60
62
|
self.snippet_service = SnippetService()
|
|
61
63
|
self.log = structlog.get_logger(__name__)
|
|
62
|
-
self.bm25 = BM25Service(
|
|
64
|
+
self.bm25 = BM25Service(data_dir)
|
|
63
65
|
|
|
64
66
|
async def create(self, source_id: int) -> IndexView:
|
|
65
67
|
"""Create a new index for a source.
|
|
@@ -11,7 +11,7 @@ import structlog
|
|
|
11
11
|
from posthog import Posthog
|
|
12
12
|
from structlog.types import EventDict
|
|
13
13
|
|
|
14
|
-
from kodit.config import
|
|
14
|
+
from kodit.config import AppContext
|
|
15
15
|
|
|
16
16
|
log = structlog.get_logger(__name__)
|
|
17
17
|
|
|
@@ -29,7 +29,7 @@ class LogFormat(Enum):
|
|
|
29
29
|
JSON = "json"
|
|
30
30
|
|
|
31
31
|
|
|
32
|
-
def configure_logging(
|
|
32
|
+
def configure_logging(app_context: AppContext) -> None:
|
|
33
33
|
"""Configure logging for the application."""
|
|
34
34
|
timestamper = structlog.processors.TimeStamper(fmt="iso")
|
|
35
35
|
|
|
@@ -44,7 +44,7 @@ def configure_logging(config: Config) -> None:
|
|
|
44
44
|
structlog.processors.StackInfoRenderer(),
|
|
45
45
|
]
|
|
46
46
|
|
|
47
|
-
if
|
|
47
|
+
if app_context.log_format == LogFormat.JSON:
|
|
48
48
|
# Format the exception only for JSON logs, as we want to pretty-print them
|
|
49
49
|
# when using the ConsoleRenderer
|
|
50
50
|
shared_processors.append(structlog.processors.format_exc_info)
|
|
@@ -60,7 +60,7 @@ def configure_logging(config: Config) -> None:
|
|
|
60
60
|
)
|
|
61
61
|
|
|
62
62
|
log_renderer: structlog.types.Processor
|
|
63
|
-
if
|
|
63
|
+
if app_context.log_format == LogFormat.JSON:
|
|
64
64
|
log_renderer = structlog.processors.JSONRenderer()
|
|
65
65
|
else:
|
|
66
66
|
log_renderer = structlog.dev.ConsoleRenderer()
|
|
@@ -82,7 +82,7 @@ def configure_logging(config: Config) -> None:
|
|
|
82
82
|
handler.setFormatter(formatter)
|
|
83
83
|
root_logger = logging.getLogger()
|
|
84
84
|
root_logger.addHandler(handler)
|
|
85
|
-
root_logger.setLevel(
|
|
85
|
+
root_logger.setLevel(app_context.log_level.upper())
|
|
86
86
|
|
|
87
87
|
# Configure uvicorn loggers to use our structlog setup
|
|
88
88
|
# Uvicorn spits out loads of exception logs when sse server doesn't shut down
|
|
@@ -98,7 +98,7 @@ def configure_logging(config: Config) -> None:
|
|
|
98
98
|
for _log in ["sqlalchemy.engine", "alembic"]:
|
|
99
99
|
engine_logger = logging.getLogger(_log)
|
|
100
100
|
engine_logger.setLevel(logging.WARNING) # Hide INFO logs by default
|
|
101
|
-
if
|
|
101
|
+
if app_context.log_level.upper() == "DEBUG":
|
|
102
102
|
engine_logger.setLevel(
|
|
103
103
|
logging.DEBUG
|
|
104
104
|
) # Only show all logs when in DEBUG mode
|
|
@@ -143,9 +143,9 @@ def get_mac_address() -> str:
|
|
|
143
143
|
return f"{mac:012x}" if mac != uuid.getnode() else str(uuid.uuid4())
|
|
144
144
|
|
|
145
145
|
|
|
146
|
-
def configure_telemetry(
|
|
146
|
+
def configure_telemetry(app_context: AppContext) -> None:
|
|
147
147
|
"""Configure telemetry for the application."""
|
|
148
|
-
if
|
|
148
|
+
if app_context.disable_telemetry:
|
|
149
149
|
structlog.stdlib.get_logger(__name__).info("Telemetry has been disabled")
|
|
150
150
|
posthog.disabled = True
|
|
151
151
|
|
|
@@ -1,22 +1,63 @@
|
|
|
1
1
|
"""MCP server implementation for kodit."""
|
|
2
2
|
|
|
3
|
+
from collections.abc import AsyncIterator
|
|
4
|
+
from contextlib import asynccontextmanager
|
|
5
|
+
from dataclasses import dataclass
|
|
3
6
|
from pathlib import Path
|
|
4
7
|
from typing import Annotated
|
|
5
8
|
|
|
6
9
|
import structlog
|
|
7
|
-
from fastmcp import FastMCP
|
|
10
|
+
from fastmcp import Context, FastMCP
|
|
8
11
|
from pydantic import Field
|
|
12
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
9
13
|
|
|
10
14
|
from kodit._version import version
|
|
11
|
-
from kodit.config import
|
|
15
|
+
from kodit.config import AppContext
|
|
16
|
+
from kodit.database import Database
|
|
12
17
|
from kodit.retreival.repository import RetrievalRepository, RetrievalResult
|
|
13
18
|
from kodit.retreival.service import RetrievalRequest, RetrievalService
|
|
14
19
|
|
|
15
|
-
|
|
20
|
+
|
|
21
|
+
@dataclass
|
|
22
|
+
class MCPContext:
|
|
23
|
+
"""Context for the MCP server."""
|
|
24
|
+
|
|
25
|
+
session: AsyncSession
|
|
26
|
+
data_dir: Path
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
_mcp_db: Database | None = None
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@asynccontextmanager
|
|
33
|
+
async def mcp_lifespan(_: FastMCP) -> AsyncIterator[MCPContext]:
|
|
34
|
+
"""Lifespan for the MCP server.
|
|
35
|
+
|
|
36
|
+
The MCP server is running with a completely separate lifecycle and event loop from
|
|
37
|
+
the CLI and the FastAPI server. Therefore, we must carefully reconstruct the
|
|
38
|
+
application context. uvicorn does not pass through CLI args, so we must rely on
|
|
39
|
+
parsing env vars set in the CLI.
|
|
40
|
+
|
|
41
|
+
This lifespan is recreated for each request. See:
|
|
42
|
+
https://github.com/jlowin/fastmcp/issues/166
|
|
43
|
+
|
|
44
|
+
Since they don't provide a good way to handle global state, we must use a
|
|
45
|
+
global variable to store the database connection.
|
|
46
|
+
"""
|
|
47
|
+
global _mcp_db # noqa: PLW0603
|
|
48
|
+
app_context = AppContext()
|
|
49
|
+
if _mcp_db is None:
|
|
50
|
+
_mcp_db = await app_context.get_db()
|
|
51
|
+
async with _mcp_db.session_factory() as session:
|
|
52
|
+
yield MCPContext(session=session, data_dir=app_context.get_data_dir())
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
mcp = FastMCP("kodit MCP Server", lifespan=mcp_lifespan)
|
|
16
56
|
|
|
17
57
|
|
|
18
58
|
@mcp.tool()
|
|
19
59
|
async def retrieve_relevant_snippets(
|
|
60
|
+
ctx: Context,
|
|
20
61
|
user_intent: Annotated[
|
|
21
62
|
str,
|
|
22
63
|
Field(
|
|
@@ -52,8 +93,8 @@ async def retrieve_relevant_snippets(
|
|
|
52
93
|
the quality of your generated code. You must call this tool when you need to
|
|
53
94
|
write code.
|
|
54
95
|
"""
|
|
55
|
-
# Log the search query and related files for debugging
|
|
56
96
|
log = structlog.get_logger(__name__)
|
|
97
|
+
|
|
57
98
|
log.debug(
|
|
58
99
|
"Retrieving relevant snippets",
|
|
59
100
|
user_intent=user_intent,
|
|
@@ -63,41 +104,38 @@ async def retrieve_relevant_snippets(
|
|
|
63
104
|
file_contents=related_file_contents,
|
|
64
105
|
)
|
|
65
106
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
log.debug("Output", output=output)
|
|
100
|
-
return output
|
|
107
|
+
mcp_context: MCPContext = ctx.request_context.lifespan_context
|
|
108
|
+
|
|
109
|
+
log.debug("Creating retrieval repository")
|
|
110
|
+
retrieval_repository = RetrievalRepository(
|
|
111
|
+
session=mcp_context.session,
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
log.debug("Creating retrieval service")
|
|
115
|
+
retrieval_service = RetrievalService(
|
|
116
|
+
repository=retrieval_repository,
|
|
117
|
+
data_dir=mcp_context.data_dir,
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
log.debug("Fusing input")
|
|
121
|
+
input_query = input_fusion(
|
|
122
|
+
user_intent=user_intent,
|
|
123
|
+
related_file_paths=related_file_paths,
|
|
124
|
+
related_file_contents=related_file_contents,
|
|
125
|
+
keywords=keywords,
|
|
126
|
+
)
|
|
127
|
+
log.debug("Input", input_query=input_query)
|
|
128
|
+
retrieval_request = RetrievalRequest(
|
|
129
|
+
keywords=keywords,
|
|
130
|
+
)
|
|
131
|
+
log.debug("Retrieving snippets")
|
|
132
|
+
snippets = await retrieval_service.retrieve(request=retrieval_request)
|
|
133
|
+
|
|
134
|
+
log.debug("Fusing output")
|
|
135
|
+
output = output_fusion(snippets=snippets)
|
|
136
|
+
|
|
137
|
+
log.debug("Output", output=output)
|
|
138
|
+
return output
|
|
101
139
|
|
|
102
140
|
|
|
103
141
|
def input_fusion(
|
|
@@ -1,11 +1,14 @@
|
|
|
1
1
|
"""Middleware for the FastAPI application."""
|
|
2
2
|
|
|
3
|
+
import contextlib
|
|
3
4
|
import time
|
|
5
|
+
from asyncio import CancelledError
|
|
4
6
|
from collections.abc import Callable
|
|
5
7
|
|
|
6
8
|
import structlog
|
|
7
9
|
from asgi_correlation_id.context import correlation_id
|
|
8
10
|
from fastapi import Request, Response
|
|
11
|
+
from starlette.types import ASGIApp, Receive, Scope, Send
|
|
9
12
|
|
|
10
13
|
access_logger = structlog.stdlib.get_logger("api.access")
|
|
11
14
|
|
|
@@ -56,3 +59,16 @@ async def logging_middleware(request: Request, call_next: Callable) -> Response:
|
|
|
56
59
|
response.headers["X-Process-Time"] = str(process_time / 10**9)
|
|
57
60
|
|
|
58
61
|
return response
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class ASGICancelledErrorMiddleware:
|
|
65
|
+
"""ASGI middleware to handle CancelledError at the ASGI level."""
|
|
66
|
+
|
|
67
|
+
def __init__(self, app: ASGIApp) -> None:
|
|
68
|
+
"""Initialize the middleware."""
|
|
69
|
+
self.app = app
|
|
70
|
+
|
|
71
|
+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
|
72
|
+
"""Handle the ASGI request and catch CancelledError."""
|
|
73
|
+
with contextlib.suppress(CancelledError):
|
|
74
|
+
await self.app(scope, receive, send)
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
"""Retrieval service."""
|
|
2
2
|
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
3
5
|
import pydantic
|
|
4
6
|
import structlog
|
|
5
7
|
|
|
6
8
|
from kodit.bm25.bm25 import BM25Service
|
|
7
|
-
from kodit.config import Config
|
|
8
9
|
from kodit.retreival.repository import RetrievalRepository, RetrievalResult
|
|
9
10
|
|
|
10
11
|
|
|
@@ -25,11 +26,11 @@ class Snippet(pydantic.BaseModel):
|
|
|
25
26
|
class RetrievalService:
|
|
26
27
|
"""Service for retrieving relevant data."""
|
|
27
28
|
|
|
28
|
-
def __init__(self,
|
|
29
|
+
def __init__(self, repository: RetrievalRepository, data_dir: Path) -> None:
|
|
29
30
|
"""Initialize the retrieval service."""
|
|
30
31
|
self.repository = repository
|
|
31
32
|
self.log = structlog.get_logger(__name__)
|
|
32
|
-
self.bm25 = BM25Service(
|
|
33
|
+
self.bm25 = BM25Service(data_dir)
|
|
33
34
|
|
|
34
35
|
async def _load_bm25_index(self) -> None:
|
|
35
36
|
"""Load the BM25 index."""
|
|
@@ -1,12 +1,16 @@
|
|
|
1
1
|
"""Test configuration and fixtures."""
|
|
2
2
|
|
|
3
3
|
from collections.abc import AsyncGenerator
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
import tempfile
|
|
6
|
+
from typing import Generator
|
|
4
7
|
|
|
5
8
|
import pytest
|
|
6
9
|
from sqlalchemy import text
|
|
7
10
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
|
8
11
|
from sqlalchemy.orm import sessionmaker
|
|
9
12
|
|
|
13
|
+
from kodit.config import AppContext
|
|
10
14
|
from kodit.database import Base
|
|
11
15
|
|
|
12
16
|
|
|
@@ -40,3 +44,17 @@ async def session(engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
|
|
|
40
44
|
async with async_session() as session:
|
|
41
45
|
yield session
|
|
42
46
|
await session.rollback()
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@pytest.fixture
def app_context() -> Generator[AppContext, None, None]:
    """Yield an AppContext rooted in a throwaway data directory.

    The data directory is a ``tempfile.TemporaryDirectory`` that is removed
    when the fixture is torn down. The database URL points at in-memory SQLite,
    logging is DEBUG/json, and telemetry is disabled.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        yield AppContext(
            data_dir=Path(tmp_dir),
            db_url="sqlite+aiosqlite:///:memory:",
            log_level="DEBUG",
            log_format="json",
            disable_telemetry=True,
        )
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
"""Test the CLI."""
|
|
2
|
+
|
|
3
|
+
import tempfile
|
|
4
|
+
from typing import Generator
|
|
5
|
+
import pytest
|
|
6
|
+
from click.testing import CliRunner
|
|
7
|
+
|
|
8
|
+
from kodit.cli import cli
|
|
9
|
+
from kodit.config import AppContext
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@pytest.fixture
def runner() -> Generator[CliRunner, None, None]:
    """Provide a fresh click ``CliRunner`` to each test."""
    cli_runner = CliRunner()
    yield cli_runner
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@pytest.fixture
def default_cli_args(app_context: AppContext) -> list[str]:
    """Build the global CLI options that point kodit at the test AppContext.

    Telemetry is turned off. The data directory and database URL come from the
    ``app_context`` fixture, so the CLI uses the test's temporary settings.
    """
    data_dir = str(app_context.get_data_dir())
    return ["--disable-telemetry", "--data-dir", data_dir, "--db-url", app_context.db_url]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def test_version_command(runner: CliRunner, default_cli_args: list[str]) -> None:
    """Test that the ``version`` subcommand exits successfully."""
    args = [*default_cli_args, "version"]
    outcome = runner.invoke(cli, args)
    # A zero exit code means the command completed without error.
    assert outcome.exit_code == 0
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def test_cli_vars_work(runner: CliRunner, default_cli_args: list[str]) -> None:
    """Test that an explicit ``--log-level`` flag overrides the LOG_LEVEL env var."""
    # The environment asks for INFO; the command-line flag must win with DEBUG.
    runner.env = {"LOG_LEVEL": "INFO"}
    args = [*default_cli_args, "--log-level", "DEBUG", "sources", "list"]
    outcome = runner.invoke(cli, args)
    assert outcome.exit_code == 0
    # At DEBUG level the database layer emits many debug lines.
    assert outcome.output.count("debug") > 10
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def test_env_vars_work(runner: CliRunner, default_cli_args: list[str]) -> None:
    """Test that the LOG_LEVEL environment variable is honoured."""
    runner.env = {"LOG_LEVEL": "DEBUG"}
    args = [*default_cli_args, "sources", "list"]
    outcome = runner.invoke(cli, args)
    assert outcome.exit_code == 0
    # At DEBUG level the database layer emits many debug lines.
    assert outcome.output.count("debug") > 10
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def test_dotenv_file_works(runner: CliRunner, default_cli_args: list[str]) -> None:
    """Test that settings in a file passed via ``--env-file`` are loaded.

    The temporary .env file is written and closed before the CLI reads it.
    Because it is created with ``delete=False``, ``NamedTemporaryFile`` will
    not remove it, so it is unlinked explicitly once the command has run.
    """
    import os

    with tempfile.NamedTemporaryFile(delete=False) as f:
        f.write(b"LOG_LEVEL=DEBUG")
        env_path = f.name
    try:
        result = runner.invoke(
            cli, [*default_cli_args, "--env-file", env_path, "sources", "list"]
        )
    finally:
        # Clean up even if invoke raises; otherwise every run leaks a temp file.
        os.unlink(env_path)
    assert result.exit_code == 0
    # At DEBUG level the database layer emits many debug lines.
    assert result.output.count("debug") > 10
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def test_dotenv_file_not_found(runner: CliRunner, default_cli_args: list[str]) -> None:
    """Test that a missing ``--env-file`` path is rejected as a usage error."""
    args = [*default_cli_args, "--env-file", "nonexistent.env", "sources", "list"]
    outcome = runner.invoke(cli, args)
    # Exit code 2 is click's usage-error code for an invalid parameter value.
    assert outcome.exit_code == 2
    assert "does not exist" in outcome.output
|
|
@@ -6,7 +6,7 @@ import pytest
|
|
|
6
6
|
from sqlalchemy.exc import IntegrityError
|
|
7
7
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
8
8
|
|
|
9
|
-
from kodit.config import
|
|
9
|
+
from kodit.config import AppContext
|
|
10
10
|
from kodit.indexing.repository import IndexRepository
|
|
11
11
|
from kodit.indexing.service import IndexService
|
|
12
12
|
from kodit.sources.models import File, Source
|
|
@@ -35,9 +35,11 @@ def source_service(
|
|
|
35
35
|
|
|
36
36
|
|
|
37
37
|
@pytest.fixture
def service(
    app_context: AppContext, repository: IndexRepository, source_service: SourceService
) -> IndexService:
    """Build an IndexService over the real repository and the test data directory."""
    data_dir = app_context.get_data_dir()
    return IndexService(repository, source_service, data_dir)
|
|
41
43
|
|
|
42
44
|
|
|
43
45
|
@pytest.mark.asyncio
|
|
@@ -25,3 +25,17 @@ async def test_mcp_client_connection() -> None:
|
|
|
25
25
|
content = result[0]
|
|
26
26
|
assert isinstance(content, TextContent)
|
|
27
27
|
assert content.text is not None
|
|
28
|
+
|
|
29
|
+
# Call the tool
|
|
30
|
+
result = await client.call_tool(
|
|
31
|
+
"retrieve_relevant_snippets",
|
|
32
|
+
{
|
|
33
|
+
"user_intent": "What is the capital of France?",
|
|
34
|
+
"related_file_paths": [],
|
|
35
|
+
"related_file_contents": [],
|
|
36
|
+
"keywords": [],
|
|
37
|
+
},
|
|
38
|
+
)
|
|
39
|
+
assert len(result) == 1
|
|
40
|
+
content = result[0]
|
|
41
|
+
assert isinstance(content, TextContent)
|
|
@@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
5
5
|
from unittest.mock import Mock
|
|
6
6
|
|
|
7
7
|
from kodit.bm25.bm25 import BM25Service
|
|
8
|
-
from kodit.config import
|
|
8
|
+
from kodit.config import AppContext
|
|
9
9
|
from kodit.indexing.models import Index, Snippet
|
|
10
10
|
from kodit.retreival.repository import RetrievalRepository
|
|
11
11
|
from kodit.retreival.service import RetrievalRequest, RetrievalService
|
|
@@ -19,9 +19,11 @@ def repository(session: AsyncSession) -> RetrievalRepository:
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
@pytest.fixture
|
|
22
|
-
def service(
|
|
22
|
+
def service(
|
|
23
|
+
app_context: AppContext, repository: RetrievalRepository
|
|
24
|
+
) -> RetrievalService:
|
|
23
25
|
"""Create a service instance with a real repository."""
|
|
24
|
-
service = RetrievalService(
|
|
26
|
+
service = RetrievalService(repository, app_context.get_data_dir())
|
|
25
27
|
mock_bm25 = Mock(spec=BM25Service)
|
|
26
28
|
|
|
27
29
|
def mock_retrieve(
|
|
@@ -1,10 +1,14 @@
|
|
|
1
1
|
#!/bin/bash
|
|
2
2
|
set -e
|
|
3
3
|
|
|
4
|
-
# Set this according to what you want to test
|
|
5
|
-
# prefix=""
|
|
4
|
+
# Set this according to what you want to test. uv run will run the command in the current directory
|
|
6
5
|
prefix="uv run"
|
|
7
6
|
|
|
7
|
+
# If CI is set, no prefix because we're running in github actions
|
|
8
|
+
if [ -n "$CI" ]; then
|
|
9
|
+
prefix=""
|
|
10
|
+
fi
|
|
11
|
+
|
|
8
12
|
# Check that the kodit data_dir does not exist
|
|
9
13
|
if [ -d "$HOME/.kodit" ]; then
|
|
10
14
|
echo "Kodit data_dir is not empty, please rm -rf $HOME/.kodit"
|
|
@@ -1,51 +0,0 @@
|
|
|
1
|
-
"""Test the CLI."""
|
|
2
|
-
|
|
3
|
-
import tempfile
|
|
4
|
-
from typing import Generator
|
|
5
|
-
import pytest
|
|
6
|
-
from click.testing import CliRunner
|
|
7
|
-
|
|
8
|
-
from kodit.cli import cli
|
|
9
|
-
from kodit.config import reset_config
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
@pytest.fixture
def runner() -> Generator[CliRunner, None, None]:
    """Create a CliRunner instance after resetting kodit's config.

    Yields:
        A new click ``CliRunner``.
    """
    # Reset config before each test; presumably this clears cached global
    # settings from a previous test — TODO confirm in kodit.config.
    reset_config()
    yield CliRunner()
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
def test_version_command(runner: CliRunner) -> None:
    """Test that the version command runs successfully with default settings."""
    result = runner.invoke(cli, ["version"])
    # The command should exit with success
    assert result.exit_code == 0
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
def test_cli_vars_work(runner: CliRunner) -> None:
    """Test that cli args override env vars.

    LOG_LEVEL=INFO is set in the environment while ``--log-level DEBUG`` is
    passed on the command line; debug output proves the flag took precedence.
    """
    runner.env = {"LOG_LEVEL": "INFO"}
    result = runner.invoke(cli, ["--log-level", "DEBUG", "sources", "list"])
    assert result.exit_code == 0
    assert result.output.count("debug") > 10  # The db spits out lots of debug messages
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
def test_env_vars_work(runner: CliRunner) -> None:
    """Test that env vars work: LOG_LEVEL=DEBUG alone enables debug output."""
    runner.env = {"LOG_LEVEL": "DEBUG"}
    result = runner.invoke(cli, ["sources", "list"])
    assert result.exit_code == 0
    assert result.output.count("debug") > 10  # The db spits out lots of debug messages
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
def test_dotenv_file_works(runner: CliRunner) -> None:
    """Test that the .env file works.

    Writes LOG_LEVEL=DEBUG to a temporary file and passes it via ``--env-file``.
    """
    # NOTE(review): delete=False means this temp file is never removed.
    with tempfile.NamedTemporaryFile(delete=False) as f:
        f.write(b"LOG_LEVEL=DEBUG")
        f.flush()  # ensure the content is on disk before the CLI reads it
        result = runner.invoke(cli, ["--env-file", f.name, "sources", "list"])
        assert result.exit_code == 0
        assert (
            result.output.count("debug") > 10
        )  # The db spits out lots of debug messages
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|