kodit 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kodit has been flagged as potentially problematic; see the package's registry page for more details.
- kodit/_version.py +2 -2
- kodit/alembic/env.py +0 -2
- kodit/app.py +8 -8
- kodit/bm25/__init__.py +1 -0
- kodit/bm25/bm25.py +71 -0
- kodit/cli.py +87 -35
- kodit/config.py +86 -2
- kodit/database.py +38 -55
- kodit/indexing/repository.py +11 -0
- kodit/indexing/service.py +26 -17
- kodit/logging.py +20 -18
- kodit/mcp.py +76 -5
- kodit/retreival/repository.py +32 -0
- kodit/retreival/service.py +41 -3
- kodit/snippets/__init__.py +1 -0
- kodit/snippets/languages/__init__.py +53 -0
- kodit/snippets/languages/csharp.scm +12 -0
- kodit/snippets/languages/python.scm +22 -0
- kodit/snippets/method_snippets.py +120 -0
- kodit/snippets/snippets.py +48 -0
- kodit/sources/service.py +3 -5
- {kodit-0.1.3.dist-info → kodit-0.1.5.dist-info}/METADATA +6 -2
- kodit-0.1.5.dist-info/RECORD +40 -0
- kodit/sse.py +0 -61
- kodit-0.1.3.dist-info/RECORD +0 -33
- {kodit-0.1.3.dist-info → kodit-0.1.5.dist-info}/WHEEL +0 -0
- {kodit-0.1.3.dist-info → kodit-0.1.5.dist-info}/entry_points.txt +0 -0
- {kodit-0.1.3.dist-info → kodit-0.1.5.dist-info}/licenses/LICENSE +0 -0
kodit/_version.py
CHANGED
kodit/alembic/env.py
CHANGED
|
@@ -66,8 +66,6 @@ async def run_async_migrations() -> None:
|
|
|
66
66
|
prefix="sqlalchemy.",
|
|
67
67
|
poolclass=pool.NullPool,
|
|
68
68
|
)
|
|
69
|
-
log = structlog.get_logger(__name__)
|
|
70
|
-
log.debug("Running migrations on %s", connectable.url)
|
|
71
69
|
|
|
72
70
|
async with connectable.connect() as connection:
|
|
73
71
|
await connection.run_sync(do_run_migrations)
|
kodit/app.py
CHANGED
|
@@ -5,14 +5,10 @@ from fastapi import FastAPI
|
|
|
5
5
|
|
|
6
6
|
from kodit.mcp import mcp
|
|
7
7
|
from kodit.middleware import logging_middleware
|
|
8
|
-
from kodit.sse import create_sse_server
|
|
9
8
|
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
sse_app = create_sse_server(mcp)
|
|
14
|
-
for route in sse_app.routes:
|
|
15
|
-
app.router.routes.append(route)
|
|
9
|
+
# See https://gofastmcp.com/deployment/asgi#fastapi-integration
|
|
10
|
+
mcp_app = mcp.sse_app()
|
|
11
|
+
app = FastAPI(title="kodit API", lifespan=mcp_app.router.lifespan_context)
|
|
16
12
|
|
|
17
13
|
# Add middleware
|
|
18
14
|
app.middleware("http")(logging_middleware)
|
|
@@ -22,4 +18,8 @@ app.add_middleware(CorrelationIdMiddleware)
|
|
|
22
18
|
@app.get("/")
|
|
23
19
|
async def root() -> dict[str, str]:
|
|
24
20
|
"""Return a welcome message for the kodit API."""
|
|
25
|
-
return {"message": "
|
|
21
|
+
return {"message": "Hello, World!"}
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Add mcp routes last, otherwise previous routes aren't added
|
|
25
|
+
app.mount("", mcp_app)
|
kodit/bm25/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""BM25 module."""
|
kodit/bm25/bm25.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
"""BM25 service."""
|
|
2
|
+
|
|
3
|
+
import bm25s
|
|
4
|
+
import Stemmer
|
|
5
|
+
import structlog
|
|
6
|
+
from bm25s.tokenization import Tokenized
|
|
7
|
+
|
|
8
|
+
from kodit.config import Config
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BM25Service:
    """Keyword search over snippet text using a persisted bm25s index.

    The index is stored under ``<data dir>/bm25s_index``. It is loaded
    (memory-mapped) on construction when present; otherwise an empty retriever
    is used until :meth:`index` is called.
    """

    def __init__(self, config: Config) -> None:
        """Load the on-disk BM25 index, or start with an empty one.

        Args:
            config: Global configuration; only its data directory is used.

        """
        self.log = structlog.get_logger(__name__)
        self.index_path = config.get_data_dir() / "bm25s_index"
        try:
            self.log.debug("Loading BM25 index")
            self.retriever = bm25s.BM25.load(self.index_path, mmap=True)
        except FileNotFoundError:
            self.log.debug("BM25 index not found, creating new index")
            self.retriever = bm25s.BM25()

        self.stemmer = Stemmer.Stemmer("english")

    def _tokenize(
        self, corpus: list[str], *, show_progress: bool = True
    ) -> list[list[str]] | Tokenized:
        """Tokenize ``corpus`` with English stopword removal and stemming.

        Args:
            corpus: Texts to tokenize.
            show_progress: Whether bm25s should display a progress bar.

        Returns:
            One list of string tokens per input text (``return_ids=False``).

        """
        return bm25s.tokenize(
            corpus,
            stopwords="en",
            stemmer=self.stemmer,
            return_ids=False,
            show_progress=show_progress,
        )

    def index(self, corpus: list[str]) -> None:
        """Rebuild the index from scratch over ``corpus`` and save it to disk.

        bm25s identifies documents by their position in ``corpus``, so the
        ``doc_ids`` later passed to :meth:`retrieve` must use the same order.

        Args:
            corpus: Document texts to index. An empty corpus is skipped and the
                current retriever is left untouched.

        """
        if not corpus:
            # Nothing to index; avoid asking bm25s to build over zero documents.
            self.log.warning("Empty corpus, skipping BM25 indexing")
            return
        self.log.debug("Indexing corpus")
        vocab = self._tokenize(corpus)
        self.retriever = bm25s.BM25()
        self.retriever.index(vocab)
        self.retriever.save(self.index_path)

    def retrieve(
        self, doc_ids: list[int], query: str, top_k: int = 2
    ) -> list[tuple[int, float]]:
        """Return up to ``top_k`` ``(doc_id, score)`` pairs matching ``query``.

        Args:
            doc_ids: Ids of the indexed documents, passed to bm25s as the
                corpus so results come back as these ids instead of positions.
                Assumes ``doc_ids[i]`` is the i-th document given to
                :meth:`index` — TODO confirm with callers.
            query: Free-text query, tokenized the same way as the corpus.
            top_k: Maximum number of results; non-positive values yield none.

        Returns:
            ``(doc_id, score)`` tuples in the order bm25s returns them.

        """
        if top_k <= 0:
            self.log.warning(
                "Top k is not positive, returning empty list", top_k=top_k
            )
            return []
        if len(doc_ids) == 0:
            self.log.warning("No documents to retrieve from, returning empty list")
            return []

        # bm25s cannot return more hits than there are documents.
        top_k = min(top_k, len(doc_ids))
        self.log.debug(
            "Retrieving from index", query=query, top_k=top_k, num_docs=len(doc_ids)
        )

        # A single query string does not warrant a progress bar.
        query_tokens = self._tokenize([query], show_progress=False)

        self.log.debug("Query tokens", query_tokens=query_tokens)

        results, scores = self.retriever.retrieve(
            query_tokens=query_tokens, corpus=doc_ids, k=top_k
        )
        self.log.debug("Raw results", results=results, scores=scores)
        # Exactly one query was submitted, so only row 0 of each matrix matters.
        return [
            (int(result), float(score))
            for result, score in zip(results[0], scores[0], strict=False)
        ]
|
kodit/cli.py
CHANGED
|
@@ -1,41 +1,74 @@
|
|
|
1
1
|
"""Command line interface for kodit."""
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
|
+
import signal
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
4
7
|
|
|
5
8
|
import click
|
|
6
9
|
import structlog
|
|
7
10
|
import uvicorn
|
|
8
|
-
from dotenv import dotenv_values
|
|
9
11
|
from pytable_formatter import Table
|
|
10
12
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
11
13
|
|
|
12
|
-
from kodit.
|
|
14
|
+
from kodit.config import (
|
|
15
|
+
DEFAULT_BASE_DIR,
|
|
16
|
+
DEFAULT_DB_URL,
|
|
17
|
+
DEFAULT_DISABLE_TELEMETRY,
|
|
18
|
+
DEFAULT_LOG_FORMAT,
|
|
19
|
+
DEFAULT_LOG_LEVEL,
|
|
20
|
+
get_config,
|
|
21
|
+
reset_config,
|
|
22
|
+
with_session,
|
|
23
|
+
)
|
|
13
24
|
from kodit.indexing.repository import IndexRepository
|
|
14
25
|
from kodit.indexing.service import IndexService
|
|
15
|
-
from kodit.logging import
|
|
26
|
+
from kodit.logging import configure_logging, configure_telemetry, log_event
|
|
16
27
|
from kodit.retreival.repository import RetrievalRepository
|
|
17
28
|
from kodit.retreival.service import RetrievalRequest, RetrievalService
|
|
18
29
|
from kodit.sources.repository import SourceRepository
|
|
19
30
|
from kodit.sources.service import SourceService
|
|
20
31
|
|
|
21
|
-
env_vars = dict(dotenv_values())
|
|
22
|
-
os.environ.update(env_vars)
|
|
23
32
|
|
|
24
|
-
|
|
25
|
-
@click.
|
|
26
|
-
@click.option("--log-
|
|
27
|
-
@click.option(
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
+
@click.group(context_settings={"max_content_width": 100})
|
|
34
|
+
@click.option("--log-level", help=f"Log level [default: {DEFAULT_LOG_LEVEL}]")
|
|
35
|
+
@click.option("--log-format", help=f"Log format [default: {DEFAULT_LOG_FORMAT}]")
|
|
36
|
+
@click.option(
|
|
37
|
+
"--disable-telemetry",
|
|
38
|
+
is_flag=True,
|
|
39
|
+
help=f"Disable telemetry [default: {DEFAULT_DISABLE_TELEMETRY}]",
|
|
40
|
+
)
|
|
41
|
+
@click.option("--db-url", help=f"Database URL [default: {DEFAULT_DB_URL}]")
|
|
42
|
+
@click.option("--data-dir", help=f"Data directory [default: {DEFAULT_BASE_DIR}]")
|
|
43
|
+
@click.option("--env-file", help="Path to a .env file [default: .env]")
|
|
44
|
+
def cli( # noqa: PLR0913
|
|
45
|
+
log_level: str | None,
|
|
46
|
+
log_format: str | None,
|
|
47
|
+
disable_telemetry: bool | None,
|
|
48
|
+
db_url: str | None,
|
|
49
|
+
data_dir: str | None,
|
|
50
|
+
env_file: str | None,
|
|
33
51
|
) -> None:
|
|
34
52
|
"""kodit CLI - Code indexing for better AI code generation.""" # noqa: D403
|
|
35
|
-
|
|
53
|
+
# First check if env-file is set and reload config if it is
|
|
54
|
+
if env_file:
|
|
55
|
+
reset_config()
|
|
56
|
+
get_config(env_file)
|
|
57
|
+
|
|
58
|
+
# Override global config with cli args, if set
|
|
59
|
+
config = get_config()
|
|
60
|
+
if data_dir:
|
|
61
|
+
config.data_dir = Path(data_dir)
|
|
62
|
+
if db_url:
|
|
63
|
+
config.db_url = db_url
|
|
64
|
+
if log_level:
|
|
65
|
+
config.log_level = log_level
|
|
66
|
+
if log_format:
|
|
67
|
+
config.log_format = log_format
|
|
36
68
|
if disable_telemetry:
|
|
37
|
-
|
|
38
|
-
|
|
69
|
+
config.disable_telemetry = disable_telemetry
|
|
70
|
+
configure_logging(config)
|
|
71
|
+
configure_telemetry(config)
|
|
39
72
|
|
|
40
73
|
|
|
41
74
|
@cli.group()
|
|
@@ -48,7 +81,7 @@ def sources() -> None:
|
|
|
48
81
|
async def list_sources(session: AsyncSession) -> None:
|
|
49
82
|
"""List all code sources."""
|
|
50
83
|
repository = SourceRepository(session)
|
|
51
|
-
service = SourceService(repository)
|
|
84
|
+
service = SourceService(get_config().get_clone_dir(), repository)
|
|
52
85
|
sources = await service.list_sources()
|
|
53
86
|
|
|
54
87
|
# Define headers and data
|
|
@@ -66,7 +99,7 @@ async def list_sources(session: AsyncSession) -> None:
|
|
|
66
99
|
async def create_source(session: AsyncSession, uri: str) -> None:
|
|
67
100
|
"""Add a new code source."""
|
|
68
101
|
repository = SourceRepository(session)
|
|
69
|
-
service = SourceService(repository)
|
|
102
|
+
service = SourceService(get_config().get_clone_dir(), repository)
|
|
70
103
|
source = await service.create(uri)
|
|
71
104
|
click.echo(f"Source created: {source.id}")
|
|
72
105
|
|
|
@@ -82,9 +115,9 @@ def indexes() -> None:
|
|
|
82
115
|
async def create_index(session: AsyncSession, source_id: int) -> None:
|
|
83
116
|
"""Create an index for a source."""
|
|
84
117
|
source_repository = SourceRepository(session)
|
|
85
|
-
source_service = SourceService(source_repository)
|
|
118
|
+
source_service = SourceService(get_config().get_clone_dir(), source_repository)
|
|
86
119
|
repository = IndexRepository(session)
|
|
87
|
-
service = IndexService(repository, source_service)
|
|
120
|
+
service = IndexService(get_config(), repository, source_service)
|
|
88
121
|
index = await service.create(source_id)
|
|
89
122
|
click.echo(f"Index created: {index.id}")
|
|
90
123
|
|
|
@@ -94,9 +127,9 @@ async def create_index(session: AsyncSession, source_id: int) -> None:
|
|
|
94
127
|
async def list_indexes(session: AsyncSession) -> None:
|
|
95
128
|
"""List all indexes."""
|
|
96
129
|
source_repository = SourceRepository(session)
|
|
97
|
-
source_service = SourceService(source_repository)
|
|
130
|
+
source_service = SourceService(get_config().get_clone_dir(), source_repository)
|
|
98
131
|
repository = IndexRepository(session)
|
|
99
|
-
service = IndexService(repository, source_service)
|
|
132
|
+
service = IndexService(get_config(), repository, source_service)
|
|
100
133
|
indexes = await service.list_indexes()
|
|
101
134
|
|
|
102
135
|
# Define headers and data
|
|
@@ -104,7 +137,6 @@ async def list_indexes(session: AsyncSession) -> None:
|
|
|
104
137
|
"ID",
|
|
105
138
|
"Created At",
|
|
106
139
|
"Updated At",
|
|
107
|
-
"Source URI",
|
|
108
140
|
"Num Snippets",
|
|
109
141
|
]
|
|
110
142
|
data = [
|
|
@@ -112,7 +144,6 @@ async def list_indexes(session: AsyncSession) -> None:
|
|
|
112
144
|
index.id,
|
|
113
145
|
index.created_at,
|
|
114
146
|
index.updated_at,
|
|
115
|
-
index.source_uri,
|
|
116
147
|
index.num_snippets,
|
|
117
148
|
]
|
|
118
149
|
for index in indexes
|
|
@@ -129,48 +160,69 @@ async def list_indexes(session: AsyncSession) -> None:
|
|
|
129
160
|
async def run_index(session: AsyncSession, index_id: int) -> None:
|
|
130
161
|
"""Run an index."""
|
|
131
162
|
source_repository = SourceRepository(session)
|
|
132
|
-
source_service = SourceService(source_repository)
|
|
163
|
+
source_service = SourceService(get_config().get_clone_dir(), source_repository)
|
|
133
164
|
repository = IndexRepository(session)
|
|
134
|
-
service = IndexService(repository, source_service)
|
|
165
|
+
service = IndexService(get_config(), repository, source_service)
|
|
135
166
|
await service.run(index_id)
|
|
136
167
|
|
|
137
168
|
|
|
138
169
|
@cli.command()
|
|
139
170
|
@click.argument("query")
|
|
171
|
+
@click.option("--top-k", default=10, help="Number of snippets to retrieve")
|
|
140
172
|
@with_session
|
|
141
|
-
async def retrieve(session: AsyncSession, query: str) -> None:
|
|
173
|
+
async def retrieve(session: AsyncSession, query: str, top_k: int) -> None:
|
|
142
174
|
"""Retrieve snippets from the database."""
|
|
143
175
|
repository = RetrievalRepository(session)
|
|
144
|
-
service = RetrievalService(repository)
|
|
145
|
-
|
|
176
|
+
service = RetrievalService(get_config(), repository)
|
|
177
|
+
# Temporary request while we don't have all search capabilities
|
|
178
|
+
snippets = await service.retrieve(
|
|
179
|
+
RetrievalRequest(keywords=query.split(","), top_k=top_k)
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
if len(snippets) == 0:
|
|
183
|
+
click.echo("No snippets found")
|
|
184
|
+
return
|
|
146
185
|
|
|
147
186
|
for snippet in snippets:
|
|
187
|
+
click.echo("-" * 80)
|
|
148
188
|
click.echo(f"{snippet.uri}")
|
|
149
189
|
click.echo(snippet.content)
|
|
190
|
+
click.echo("-" * 80)
|
|
150
191
|
click.echo()
|
|
151
192
|
|
|
152
193
|
|
|
153
194
|
@cli.command()
|
|
154
195
|
@click.option("--host", default="127.0.0.1", help="Host to bind the server to")
|
|
155
196
|
@click.option("--port", default=8080, help="Port to bind the server to")
|
|
156
|
-
@click.option("--reload", is_flag=True, help="Enable auto-reload for development")
|
|
157
197
|
def serve(
|
|
158
198
|
host: str,
|
|
159
199
|
port: int,
|
|
160
|
-
reload: bool, # noqa: FBT001
|
|
161
200
|
) -> None:
|
|
162
201
|
"""Start the kodit server, which hosts the MCP server and the kodit API."""
|
|
163
202
|
log = structlog.get_logger(__name__)
|
|
164
|
-
log.info("Starting kodit server", host=host, port=port
|
|
203
|
+
log.info("Starting kodit server", host=host, port=port)
|
|
165
204
|
log_event("kodit_server_started")
|
|
166
|
-
|
|
205
|
+
os.environ["HELLO"] = "WORLD"
|
|
206
|
+
|
|
207
|
+
# Configure uvicorn with graceful shutdown
|
|
208
|
+
config = uvicorn.Config(
|
|
167
209
|
"kodit.app:app",
|
|
168
210
|
host=host,
|
|
169
211
|
port=port,
|
|
170
|
-
reload=
|
|
212
|
+
reload=False,
|
|
171
213
|
log_config=None, # Setting to None forces uvicorn to use our structlog setup
|
|
172
214
|
access_log=False, # Using own middleware for access logging
|
|
215
|
+
timeout_graceful_shutdown=0, # The mcp server does not shutdown cleanly, force
|
|
173
216
|
)
|
|
217
|
+
server = uvicorn.Server(config)
|
|
218
|
+
|
|
219
|
+
def handle_sigint(signum: int, frame: Any) -> None:
|
|
220
|
+
"""Handle SIGINT (Ctrl+C)."""
|
|
221
|
+
log.info("Received shutdown signal, force killing MCP connections")
|
|
222
|
+
server.handle_exit(signum, frame)
|
|
223
|
+
|
|
224
|
+
signal.signal(signal.SIGINT, handle_sigint)
|
|
225
|
+
server.run()
|
|
174
226
|
|
|
175
227
|
|
|
176
228
|
@cli.command()
|
kodit/config.py
CHANGED
|
@@ -1,5 +1,89 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Global configuration for the kodit project."""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from functools import wraps
|
|
3
6
|
from pathlib import Path
|
|
7
|
+
from typing import Any, TypeVar
|
|
4
8
|
|
|
5
|
-
|
|
9
|
+
from pydantic import Field
|
|
10
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
11
|
+
|
|
12
|
+
from kodit.database import Database
|
|
13
|
+
|
|
14
|
+
# Defaults for the Config fields below; the CLI options can override them.
DEFAULT_BASE_DIR = Path.home() / ".kodit"
# NOTE(review): derived from DEFAULT_BASE_DIR rather than Config.data_dir, so
# overriding the data directory alone does not relocate the database.
DEFAULT_DB_URL = f"sqlite+aiosqlite:///{DEFAULT_BASE_DIR}/kodit.db"
DEFAULT_LOG_LEVEL = "INFO"
DEFAULT_LOG_FORMAT = "pretty"
DEFAULT_DISABLE_TELEMETRY = False
# Generic return type used in decorator signatures in this module.
T = TypeVar("T")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class Config(BaseSettings):
    """Global configuration for the kodit project.

    Field values may come from environment variables or from the dotenv file
    named in ``model_config``; anything not supplied falls back to the
    module-level ``DEFAULT_*`` constants.
    """

    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

    data_dir: Path = Field(default=DEFAULT_BASE_DIR)
    db_url: str = Field(default=DEFAULT_DB_URL)
    log_level: str = Field(default=DEFAULT_LOG_LEVEL)
    log_format: str = Field(default=DEFAULT_LOG_FORMAT)
    disable_telemetry: bool = Field(default=DEFAULT_DISABLE_TELEMETRY)
    # Lazily built Database cached by get_db(); the leading underscore makes
    # this a pydantic private attribute rather than a settings field.
    _db: Database | None = None

    def model_post_init(self, _: Any) -> None:
        """Create the data directory eagerly so the default SQLite path is usable."""
        self.get_data_dir()

    def get_data_dir(self) -> Path:
        """Return ``data_dir``, creating it (with parents) if it is missing."""
        directory = self.data_dir
        directory.mkdir(parents=True, exist_ok=True)
        return directory

    def get_clone_dir(self) -> Path:
        """Return the ``clones`` subdirectory of the data dir, creating it if needed."""
        clones = self.get_data_dir().joinpath("clones")
        clones.mkdir(parents=True, exist_ok=True)
        return clones

    def get_db(self, *, run_migrations: bool = True) -> Database:
        """Return the cached Database, constructing it on the first call.

        ``run_migrations`` is only honoured on that first call; subsequent
        calls return the already-built instance unchanged.
        """
        if self._db is not None:
            return self._db
        self._db = Database(self.db_url, run_migrations=run_migrations)
        return self._db
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# Global config instance for mcp Apps: the process-wide Config singleton,
# created lazily by get_config() and cleared again by reset_config().
config: Config | None = None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def get_config(env_file: str | None = None) -> Config:
    """Return the global Config, creating it on first use.

    Args:
        env_file: Optional dotenv path used only when the config is first
            created. When omitted, the ``env_file`` from ``Config.model_config``
            (``.env``) applies.

    Returns:
        The shared Config instance.

    """
    global config  # noqa: PLW0603
    if config is None:
        if env_file is None:
            # Passing ``_env_file=None`` explicitly would *disable* dotenv
            # loading in pydantic-settings, so omit it to keep the ``.env`` default.
            config = Config()
        else:
            config = Config(_env_file=env_file)
    return config
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def reset_config() -> None:
    """Drop the cached global Config so the next get_config() call rebuilds it."""
    global config  # noqa: PLW0603
    config = None
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def with_session(func: Callable[..., T]) -> Callable[..., T]:
    """Turn an async ``func(session, *args, **kwargs)`` into a synchronous callable.

    The wrapper opens a session from the global config's Database and passes
    it as the first positional argument. It drives the coroutine with
    ``asyncio.run`` and returns its result. This lets Click commands, which
    are invoked synchronously, be written as coroutines.
    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> T:
        # Obtain the Database before asyncio.run starts an event loop; building
        # it may run migrations (get_db defaults to run_migrations=True).
        database = get_config().get_db()

        async def _with_open_session() -> T:
            async with database.get_session() as session:
                outcome = await func(session, *args, **kwargs)
            return outcome

        return asyncio.run(_with_open_session())

    return wrapper
|
kodit/database.py
CHANGED
|
@@ -1,14 +1,13 @@
|
|
|
1
1
|
"""Database configuration for kodit."""
|
|
2
2
|
|
|
3
|
-
import
|
|
4
|
-
from
|
|
3
|
+
from collections.abc import AsyncGenerator
|
|
4
|
+
from contextlib import asynccontextmanager
|
|
5
5
|
from datetime import UTC, datetime
|
|
6
|
-
from functools import wraps
|
|
7
6
|
from pathlib import Path
|
|
8
|
-
from typing import Any, TypeVar
|
|
9
7
|
|
|
8
|
+
import structlog
|
|
10
9
|
from alembic import command
|
|
11
|
-
from alembic.config import Config
|
|
10
|
+
from alembic.config import Config as AlembicConfig
|
|
12
11
|
from sqlalchemy import DateTime
|
|
13
12
|
from sqlalchemy.ext.asyncio import (
|
|
14
13
|
AsyncAttrs,
|
|
@@ -19,23 +18,6 @@ from sqlalchemy.ext.asyncio import (
|
|
|
19
18
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
|
20
19
|
|
|
21
20
|
from kodit import alembic
|
|
22
|
-
from kodit.config import DATA_DIR
|
|
23
|
-
|
|
24
|
-
# Constants
|
|
25
|
-
DB_URL = f"sqlite+aiosqlite:///{DATA_DIR}/kodit.db"
|
|
26
|
-
|
|
27
|
-
# Create data directory if it doesn't exist
|
|
28
|
-
DATA_DIR.mkdir(exist_ok=True)
|
|
29
|
-
|
|
30
|
-
# Create async engine with file-based SQLite
|
|
31
|
-
engine = create_async_engine(DB_URL, echo=False)
|
|
32
|
-
|
|
33
|
-
# Create async session factory
|
|
34
|
-
async_session_factory = async_sessionmaker(
|
|
35
|
-
engine,
|
|
36
|
-
class_=AsyncSession,
|
|
37
|
-
expire_on_commit=False,
|
|
38
|
-
)
|
|
39
21
|
|
|
40
22
|
|
|
41
23
|
class Base(AsyncAttrs, DeclarativeBase):
|
|
@@ -54,36 +36,37 @@ class CommonMixin:
|
|
|
54
36
|
)
|
|
55
37
|
|
|
56
38
|
|
|
57
|
-
|
|
58
|
-
"""
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
@
|
|
73
|
-
def
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
39
|
+
class Database:
    """Own the SQLAlchemy async session factory for kodit.

    On construction it optionally upgrades the schema to the latest Alembic
    revision before creating the engine.
    """

    def __init__(self, db_url: str, *, run_migrations: bool = True) -> None:
        """Set up the engine and session factory, migrating first if requested.

        Args:
            db_url: SQLAlchemy URL of the database.
            run_migrations: When true, run ``alembic upgrade head`` against
                ``db_url`` before the engine is created.

        """
        self.log = structlog.get_logger(__name__)
        if run_migrations:
            self._run_migrations(db_url)
        # expire_on_commit=False keeps loaded ORM attributes readable after commit.
        self.db_session_factory = async_sessionmaker(
            create_async_engine(db_url, echo=False),
            class_=AsyncSession,
            expire_on_commit=False,
        )

    @asynccontextmanager
    async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield a new session and close it once the caller's block exits."""
        session_context = self.db_session_factory()
        async with session_context as session:
            try:
                yield session
            finally:
                await session.close()

    def _run_migrations(self, db_url: str) -> None:
        """Upgrade the database at ``db_url`` to the ``head`` Alembic revision."""
        cfg = AlembicConfig()
        options = (
            ("script_location", str(Path(alembic.__file__).parent)),
            ("sqlalchemy.url", db_url),
        )
        for option_name, option_value in options:
            cfg.set_main_option(option_name, option_value)
        self.log.debug("Running migrations", db_url=db_url)
        command.upgrade(cfg, "head")
|
kodit/indexing/repository.py
CHANGED
|
@@ -130,3 +130,14 @@ class IndexRepository:
|
|
|
130
130
|
query = select(Snippet).where(Snippet.index_id == index_id)
|
|
131
131
|
result = await self.session.execute(query)
|
|
132
132
|
return list(result.scalars())
|
|
133
|
+
|
|
134
|
+
async def get_all_snippets(self) -> list[Snippet]:
    """Fetch every snippet, across all indexes, ordered by ascending id.

    Returns:
        A list of all snippets.

    """
    statement = select(Snippet).order_by(Snippet.id)
    scalars = (await self.session.execute(statement)).scalars()
    return list(scalars)
|
kodit/indexing/service.py
CHANGED
|
@@ -7,24 +7,21 @@ index management.
|
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
9
|
from datetime import datetime
|
|
10
|
+
from pathlib import Path
|
|
10
11
|
|
|
11
|
-
import aiofiles
|
|
12
12
|
import pydantic
|
|
13
13
|
import structlog
|
|
14
14
|
from tqdm.asyncio import tqdm
|
|
15
15
|
|
|
16
|
+
from kodit.bm25.bm25 import BM25Service
|
|
17
|
+
from kodit.config import Config
|
|
16
18
|
from kodit.indexing.models import Snippet
|
|
17
19
|
from kodit.indexing.repository import IndexRepository
|
|
20
|
+
from kodit.snippets.snippets import SnippetService
|
|
18
21
|
from kodit.sources.service import SourceService
|
|
19
22
|
|
|
20
|
-
# List of MIME types that are
|
|
21
|
-
|
|
22
|
-
"text/plain",
|
|
23
|
-
"text/markdown",
|
|
24
|
-
"text/x-python",
|
|
25
|
-
"text/x-shellscript",
|
|
26
|
-
"text/x-sql",
|
|
27
|
-
]
|
|
23
|
+
# List of MIME types that are blacklisted from being indexed
|
|
24
|
+
MIME_BLACKLIST = ["unknown/unknown"]
|
|
28
25
|
|
|
29
26
|
|
|
30
27
|
class IndexView(pydantic.BaseModel):
|
|
@@ -37,7 +34,6 @@ class IndexView(pydantic.BaseModel):
|
|
|
37
34
|
id: int
|
|
38
35
|
created_at: datetime
|
|
39
36
|
updated_at: datetime | None = None
|
|
40
|
-
source_uri: str | None = None
|
|
41
37
|
num_snippets: int | None = None
|
|
42
38
|
|
|
43
39
|
|
|
@@ -50,7 +46,7 @@ class IndexService:
|
|
|
50
46
|
"""
|
|
51
47
|
|
|
52
48
|
def __init__(
|
|
53
|
-
self, repository: IndexRepository, source_service: SourceService
|
|
49
|
+
self, config: Config, repository: IndexRepository, source_service: SourceService
|
|
54
50
|
) -> None:
|
|
55
51
|
"""Initialize the index service.
|
|
56
52
|
|
|
@@ -61,7 +57,9 @@ class IndexService:
|
|
|
61
57
|
"""
|
|
62
58
|
self.repository = repository
|
|
63
59
|
self.source_service = source_service
|
|
60
|
+
self.snippet_service = SnippetService()
|
|
64
61
|
self.log = structlog.get_logger(__name__)
|
|
62
|
+
self.bm25 = BM25Service(config)
|
|
65
63
|
|
|
66
64
|
async def create(self, source_id: int) -> IndexView:
|
|
67
65
|
"""Create a new index for a source.
|
|
@@ -120,6 +118,10 @@ class IndexService:
|
|
|
120
118
|
# Create snippets for supported file types
|
|
121
119
|
await self._create_snippets(index_id)
|
|
122
120
|
|
|
121
|
+
# Update BM25 index
|
|
122
|
+
snippets = await self.repository.get_all_snippets()
|
|
123
|
+
self.bm25.index([snippet.content for snippet in snippets])
|
|
124
|
+
|
|
123
125
|
# Update index timestamp
|
|
124
126
|
await self.repository.update_index_timestamp(index)
|
|
125
127
|
|
|
@@ -138,16 +140,23 @@ class IndexService:
|
|
|
138
140
|
files = await self.repository.files_for_index(index_id)
|
|
139
141
|
for file in tqdm(files, total=len(files)):
|
|
140
142
|
# Skip unsupported file types
|
|
141
|
-
if file.mime_type
|
|
143
|
+
if file.mime_type in MIME_BLACKLIST:
|
|
142
144
|
self.log.debug("Skipping mime type", mime_type=file.mime_type)
|
|
143
145
|
continue
|
|
144
146
|
|
|
145
147
|
# Create snippet from file content
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
148
|
+
try:
|
|
149
|
+
snippets = self.snippet_service.snippets_for_file(
|
|
150
|
+
Path(file.cloned_path)
|
|
151
|
+
)
|
|
152
|
+
except ValueError as e:
|
|
153
|
+
self.log.debug("Skipping file", file=file.cloned_path, error=e)
|
|
154
|
+
continue
|
|
155
|
+
|
|
156
|
+
for snippet in snippets:
|
|
157
|
+
s = Snippet(
|
|
149
158
|
index_id=index_id,
|
|
150
159
|
file_id=file.id,
|
|
151
|
-
content=
|
|
160
|
+
content=snippet.text,
|
|
152
161
|
)
|
|
153
|
-
await self.repository.add_snippet(
|
|
162
|
+
await self.repository.add_snippet(s)
|