kodit 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kodit might be problematic. Click here for more details.
- kodit/_version.py +2 -2
- kodit/alembic/env.py +5 -4
- kodit/app.py +13 -9
- kodit/bm25/__init__.py +1 -0
- kodit/bm25/bm25.py +71 -0
- kodit/cli.py +124 -38
- kodit/config.py +94 -2
- kodit/database.py +41 -57
- kodit/indexing/repository.py +11 -0
- kodit/indexing/service.py +28 -16
- kodit/logging.py +20 -18
- kodit/mcp.py +84 -34
- kodit/middleware.py +16 -0
- kodit/retreival/repository.py +32 -0
- kodit/retreival/service.py +42 -3
- kodit/snippets/__init__.py +1 -0
- kodit/snippets/languages/__init__.py +53 -0
- kodit/snippets/languages/csharp.scm +12 -0
- kodit/snippets/languages/python.scm +22 -0
- kodit/snippets/method_snippets.py +120 -0
- kodit/snippets/snippets.py +48 -0
- kodit/sources/service.py +3 -5
- {kodit-0.1.4.dist-info → kodit-0.1.6.dist-info}/METADATA +6 -2
- kodit-0.1.6.dist-info/RECORD +40 -0
- kodit/sse.py +0 -61
- kodit-0.1.4.dist-info/RECORD +0 -33
- {kodit-0.1.4.dist-info → kodit-0.1.6.dist-info}/WHEEL +0 -0
- {kodit-0.1.4.dist-info → kodit-0.1.6.dist-info}/entry_points.txt +0 -0
- {kodit-0.1.4.dist-info → kodit-0.1.6.dist-info}/licenses/LICENSE +0 -0
kodit/_version.py
CHANGED
kodit/alembic/env.py
CHANGED
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
|
|
4
4
|
import asyncio
|
|
5
5
|
|
|
6
|
-
import structlog
|
|
7
6
|
from alembic import context
|
|
8
7
|
from sqlalchemy import pool
|
|
9
8
|
from sqlalchemy.engine import Connection
|
|
@@ -66,8 +65,6 @@ async def run_async_migrations() -> None:
|
|
|
66
65
|
prefix="sqlalchemy.",
|
|
67
66
|
poolclass=pool.NullPool,
|
|
68
67
|
)
|
|
69
|
-
log = structlog.get_logger(__name__)
|
|
70
|
-
log.debug("Running migrations on %s", connectable.url)
|
|
71
68
|
|
|
72
69
|
async with connectable.connect() as connection:
|
|
73
70
|
await connection.run_sync(do_run_migrations)
|
|
@@ -77,7 +74,11 @@ async def run_async_migrations() -> None:
|
|
|
77
74
|
|
|
78
75
|
def run_migrations_online() -> None:
|
|
79
76
|
"""Run migrations in 'online' mode."""
|
|
80
|
-
|
|
77
|
+
connectable = config.attributes.get("connection", None)
|
|
78
|
+
if connectable is None:
|
|
79
|
+
asyncio.run(run_async_migrations())
|
|
80
|
+
else:
|
|
81
|
+
do_run_migrations(connectable)
|
|
81
82
|
|
|
82
83
|
|
|
83
84
|
if context.is_offline_mode():
|
kodit/app.py
CHANGED
|
@@ -4,15 +4,11 @@ from asgi_correlation_id import CorrelationIdMiddleware
|
|
|
4
4
|
from fastapi import FastAPI
|
|
5
5
|
|
|
6
6
|
from kodit.mcp import mcp
|
|
7
|
-
from kodit.middleware import logging_middleware
|
|
8
|
-
from kodit.sse import create_sse_server
|
|
7
|
+
from kodit.middleware import ASGICancelledErrorMiddleware, logging_middleware
|
|
9
8
|
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
sse_app = create_sse_server(mcp)
|
|
14
|
-
for route in sse_app.routes:
|
|
15
|
-
app.router.routes.append(route)
|
|
9
|
+
# See https://gofastmcp.com/deployment/asgi#fastapi-integration
|
|
10
|
+
mcp_app = mcp.sse_app()
|
|
11
|
+
app = FastAPI(title="kodit API", lifespan=mcp_app.router.lifespan_context)
|
|
16
12
|
|
|
17
13
|
# Add middleware
|
|
18
14
|
app.middleware("http")(logging_middleware)
|
|
@@ -22,4 +18,12 @@ app.add_middleware(CorrelationIdMiddleware)
|
|
|
22
18
|
@app.get("/")
|
|
23
19
|
async def root() -> dict[str, str]:
|
|
24
20
|
"""Return a welcome message for the kodit API."""
|
|
25
|
-
return {"message": "
|
|
21
|
+
return {"message": "Hello, World!"}
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Add mcp routes last, otherwise previous routes aren't added
|
|
25
|
+
app.mount("", mcp_app)
|
|
26
|
+
|
|
27
|
+
# Wrap the entire app with ASGI middleware after all routes are added to suppress
|
|
28
|
+
# CancelledError at the ASGI level
|
|
29
|
+
app = ASGICancelledErrorMiddleware(app)
|
kodit/bm25/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""BM25 module."""
|
kodit/bm25/bm25.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
"""BM25 service."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import bm25s
|
|
6
|
+
import Stemmer
|
|
7
|
+
import structlog
|
|
8
|
+
from bm25s.tokenization import Tokenized
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BM25Service:
|
|
12
|
+
"""Service for BM25."""
|
|
13
|
+
|
|
14
|
+
def __init__(self, data_dir: Path) -> None:
|
|
15
|
+
"""Initialize the BM25 service."""
|
|
16
|
+
self.log = structlog.get_logger(__name__)
|
|
17
|
+
self.index_path = data_dir / "bm25s_index"
|
|
18
|
+
try:
|
|
19
|
+
self.log.debug("Loading BM25 index")
|
|
20
|
+
self.retriever = bm25s.BM25.load(self.index_path, mmap=True)
|
|
21
|
+
except FileNotFoundError:
|
|
22
|
+
self.log.debug("BM25 index not found, creating new index")
|
|
23
|
+
self.retriever = bm25s.BM25()
|
|
24
|
+
|
|
25
|
+
self.stemmer = Stemmer.Stemmer("english")
|
|
26
|
+
|
|
27
|
+
def _tokenize(self, corpus: list[str]) -> list[list[str]] | Tokenized:
|
|
28
|
+
return bm25s.tokenize(
|
|
29
|
+
corpus,
|
|
30
|
+
stopwords="en",
|
|
31
|
+
stemmer=self.stemmer,
|
|
32
|
+
return_ids=False,
|
|
33
|
+
show_progress=True,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
def index(self, corpus: list[str]) -> None:
|
|
37
|
+
"""Index a new corpus."""
|
|
38
|
+
self.log.debug("Indexing corpus")
|
|
39
|
+
vocab = self._tokenize(corpus)
|
|
40
|
+
self.retriever = bm25s.BM25()
|
|
41
|
+
self.retriever.index(vocab)
|
|
42
|
+
self.retriever.save(self.index_path)
|
|
43
|
+
|
|
44
|
+
def retrieve(
|
|
45
|
+
self, doc_ids: list[int], query: str, top_k: int = 2
|
|
46
|
+
) -> list[tuple[int, float]]:
|
|
47
|
+
"""Retrieve from the index."""
|
|
48
|
+
if top_k == 0:
|
|
49
|
+
self.log.warning("Top k is 0, returning empty list")
|
|
50
|
+
return []
|
|
51
|
+
if len(doc_ids) == 0:
|
|
52
|
+
self.log.warning("No documents to retrieve from, returning empty list")
|
|
53
|
+
return []
|
|
54
|
+
|
|
55
|
+
top_k = min(top_k, len(doc_ids))
|
|
56
|
+
self.log.debug(
|
|
57
|
+
"Retrieving from index", query=query, top_k=top_k, num_docs=len(doc_ids)
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
query_tokens = self._tokenize([query])
|
|
61
|
+
|
|
62
|
+
self.log.debug("Query tokens", query_tokens=query_tokens)
|
|
63
|
+
|
|
64
|
+
results, scores = self.retriever.retrieve(
|
|
65
|
+
query_tokens=query_tokens, corpus=doc_ids, k=top_k
|
|
66
|
+
)
|
|
67
|
+
self.log.debug("Raw results", results=results, scores=scores)
|
|
68
|
+
return [
|
|
69
|
+
(int(result), float(score))
|
|
70
|
+
for result, score in zip(results[0], scores[0], strict=False)
|
|
71
|
+
]
|
kodit/cli.py
CHANGED
|
@@ -1,41 +1,87 @@
|
|
|
1
1
|
"""Command line interface for kodit."""
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
|
+
import signal
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
4
7
|
|
|
5
8
|
import click
|
|
6
9
|
import structlog
|
|
7
10
|
import uvicorn
|
|
8
|
-
from dotenv import dotenv_values
|
|
9
11
|
from pytable_formatter import Table
|
|
10
12
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
11
13
|
|
|
12
|
-
from kodit.
|
|
14
|
+
from kodit.config import (
|
|
15
|
+
DEFAULT_BASE_DIR,
|
|
16
|
+
DEFAULT_DB_URL,
|
|
17
|
+
DEFAULT_DISABLE_TELEMETRY,
|
|
18
|
+
DEFAULT_LOG_FORMAT,
|
|
19
|
+
DEFAULT_LOG_LEVEL,
|
|
20
|
+
AppContext,
|
|
21
|
+
with_app_context,
|
|
22
|
+
with_session,
|
|
23
|
+
)
|
|
13
24
|
from kodit.indexing.repository import IndexRepository
|
|
14
25
|
from kodit.indexing.service import IndexService
|
|
15
|
-
from kodit.logging import
|
|
26
|
+
from kodit.logging import configure_logging, configure_telemetry, log_event
|
|
16
27
|
from kodit.retreival.repository import RetrievalRepository
|
|
17
28
|
from kodit.retreival.service import RetrievalRequest, RetrievalService
|
|
18
29
|
from kodit.sources.repository import SourceRepository
|
|
19
30
|
from kodit.sources.service import SourceService
|
|
20
31
|
|
|
21
|
-
env_vars = dict(dotenv_values())
|
|
22
|
-
os.environ.update(env_vars)
|
|
23
32
|
|
|
24
|
-
|
|
25
|
-
@click.
|
|
26
|
-
@click.option("--log-
|
|
27
|
-
@click.option(
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
+
@click.group(context_settings={"max_content_width": 100})
|
|
34
|
+
@click.option("--log-level", help=f"Log level [default: {DEFAULT_LOG_LEVEL}]")
|
|
35
|
+
@click.option("--log-format", help=f"Log format [default: {DEFAULT_LOG_FORMAT}]")
|
|
36
|
+
@click.option(
|
|
37
|
+
"--disable-telemetry",
|
|
38
|
+
is_flag=True,
|
|
39
|
+
help=f"Disable telemetry [default: {DEFAULT_DISABLE_TELEMETRY}]",
|
|
40
|
+
)
|
|
41
|
+
@click.option("--db-url", help=f"Database URL [default: {DEFAULT_DB_URL}]")
|
|
42
|
+
@click.option("--data-dir", help=f"Data directory [default: {DEFAULT_BASE_DIR}]")
|
|
43
|
+
@click.option(
|
|
44
|
+
"--env-file",
|
|
45
|
+
help="Path to a .env file [default: .env]",
|
|
46
|
+
type=click.Path(
|
|
47
|
+
exists=True,
|
|
48
|
+
dir_okay=False,
|
|
49
|
+
resolve_path=True,
|
|
50
|
+
path_type=Path,
|
|
51
|
+
),
|
|
52
|
+
)
|
|
53
|
+
@click.pass_context
|
|
54
|
+
def cli( # noqa: PLR0913
|
|
55
|
+
ctx: click.Context,
|
|
56
|
+
log_level: str | None,
|
|
57
|
+
log_format: str | None,
|
|
58
|
+
disable_telemetry: bool | None,
|
|
59
|
+
db_url: str | None,
|
|
60
|
+
data_dir: str | None,
|
|
61
|
+
env_file: Path | None,
|
|
33
62
|
) -> None:
|
|
34
63
|
"""kodit CLI - Code indexing for better AI code generation.""" # noqa: D403
|
|
35
|
-
|
|
64
|
+
config = AppContext()
|
|
65
|
+
# First check if env-file is set and reload config if it is
|
|
66
|
+
if env_file:
|
|
67
|
+
config = AppContext(_env_file=env_file) # type: ignore[reportCallIssue]
|
|
68
|
+
|
|
69
|
+
# Now override with CLI arguments, if set
|
|
70
|
+
if data_dir:
|
|
71
|
+
config.data_dir = Path(data_dir)
|
|
72
|
+
if db_url:
|
|
73
|
+
config.db_url = db_url
|
|
74
|
+
if log_level:
|
|
75
|
+
config.log_level = log_level
|
|
76
|
+
if log_format:
|
|
77
|
+
config.log_format = log_format
|
|
36
78
|
if disable_telemetry:
|
|
37
|
-
|
|
38
|
-
|
|
79
|
+
config.disable_telemetry = disable_telemetry
|
|
80
|
+
configure_logging(config)
|
|
81
|
+
configure_telemetry(config)
|
|
82
|
+
|
|
83
|
+
# Set the app context in the click context for downstream cli
|
|
84
|
+
ctx.obj = config
|
|
39
85
|
|
|
40
86
|
|
|
41
87
|
@cli.group()
|
|
@@ -44,11 +90,12 @@ def sources() -> None:
|
|
|
44
90
|
|
|
45
91
|
|
|
46
92
|
@sources.command(name="list")
|
|
93
|
+
@with_app_context
|
|
47
94
|
@with_session
|
|
48
|
-
async def list_sources(session: AsyncSession) -> None:
|
|
95
|
+
async def list_sources(session: AsyncSession, app_context: AppContext) -> None:
|
|
49
96
|
"""List all code sources."""
|
|
50
97
|
repository = SourceRepository(session)
|
|
51
|
-
service = SourceService(repository)
|
|
98
|
+
service = SourceService(app_context.get_clone_dir(), repository)
|
|
52
99
|
sources = await service.list_sources()
|
|
53
100
|
|
|
54
101
|
# Define headers and data
|
|
@@ -62,11 +109,14 @@ async def list_sources(session: AsyncSession) -> None:
|
|
|
62
109
|
|
|
63
110
|
@sources.command(name="create")
|
|
64
111
|
@click.argument("uri")
|
|
112
|
+
@with_app_context
|
|
65
113
|
@with_session
|
|
66
|
-
async def create_source(
|
|
114
|
+
async def create_source(
|
|
115
|
+
session: AsyncSession, app_context: AppContext, uri: str
|
|
116
|
+
) -> None:
|
|
67
117
|
"""Add a new code source."""
|
|
68
118
|
repository = SourceRepository(session)
|
|
69
|
-
service = SourceService(repository)
|
|
119
|
+
service = SourceService(app_context.get_clone_dir(), repository)
|
|
70
120
|
source = await service.create(uri)
|
|
71
121
|
click.echo(f"Source created: {source.id}")
|
|
72
122
|
|
|
@@ -78,25 +128,29 @@ def indexes() -> None:
|
|
|
78
128
|
|
|
79
129
|
@indexes.command(name="create")
|
|
80
130
|
@click.argument("source_id")
|
|
131
|
+
@with_app_context
|
|
81
132
|
@with_session
|
|
82
|
-
async def create_index(
|
|
133
|
+
async def create_index(
|
|
134
|
+
session: AsyncSession, app_context: AppContext, source_id: int
|
|
135
|
+
) -> None:
|
|
83
136
|
"""Create an index for a source."""
|
|
84
137
|
source_repository = SourceRepository(session)
|
|
85
|
-
source_service = SourceService(source_repository)
|
|
138
|
+
source_service = SourceService(app_context.get_clone_dir(), source_repository)
|
|
86
139
|
repository = IndexRepository(session)
|
|
87
|
-
service = IndexService(repository, source_service)
|
|
140
|
+
service = IndexService(repository, source_service, app_context.get_data_dir())
|
|
88
141
|
index = await service.create(source_id)
|
|
89
142
|
click.echo(f"Index created: {index.id}")
|
|
90
143
|
|
|
91
144
|
|
|
92
145
|
@indexes.command(name="list")
|
|
146
|
+
@with_app_context
|
|
93
147
|
@with_session
|
|
94
|
-
async def list_indexes(session: AsyncSession) -> None:
|
|
148
|
+
async def list_indexes(session: AsyncSession, app_context: AppContext) -> None:
|
|
95
149
|
"""List all indexes."""
|
|
96
150
|
source_repository = SourceRepository(session)
|
|
97
|
-
source_service = SourceService(source_repository)
|
|
151
|
+
source_service = SourceService(app_context.get_clone_dir(), source_repository)
|
|
98
152
|
repository = IndexRepository(session)
|
|
99
|
-
service = IndexService(repository, source_service)
|
|
153
|
+
service = IndexService(repository, source_service, app_context.get_data_dir())
|
|
100
154
|
indexes = await service.list_indexes()
|
|
101
155
|
|
|
102
156
|
# Define headers and data
|
|
@@ -123,52 +177,84 @@ async def list_indexes(session: AsyncSession) -> None:
|
|
|
123
177
|
|
|
124
178
|
@indexes.command(name="run")
|
|
125
179
|
@click.argument("index_id")
|
|
180
|
+
@with_app_context
|
|
126
181
|
@with_session
|
|
127
|
-
async def run_index(
|
|
182
|
+
async def run_index(
|
|
183
|
+
session: AsyncSession, app_context: AppContext, index_id: int
|
|
184
|
+
) -> None:
|
|
128
185
|
"""Run an index."""
|
|
129
186
|
source_repository = SourceRepository(session)
|
|
130
|
-
source_service = SourceService(source_repository)
|
|
187
|
+
source_service = SourceService(app_context.get_clone_dir(), source_repository)
|
|
131
188
|
repository = IndexRepository(session)
|
|
132
|
-
service = IndexService(repository, source_service)
|
|
189
|
+
service = IndexService(repository, source_service, app_context.get_data_dir())
|
|
133
190
|
await service.run(index_id)
|
|
134
191
|
|
|
135
192
|
|
|
136
193
|
@cli.command()
|
|
137
194
|
@click.argument("query")
|
|
195
|
+
@click.option("--top-k", default=10, help="Number of snippets to retrieve")
|
|
196
|
+
@with_app_context
|
|
138
197
|
@with_session
|
|
139
|
-
async def retrieve(
|
|
198
|
+
async def retrieve(
|
|
199
|
+
session: AsyncSession, app_context: AppContext, query: str, top_k: int
|
|
200
|
+
) -> None:
|
|
140
201
|
"""Retrieve snippets from the database."""
|
|
141
202
|
repository = RetrievalRepository(session)
|
|
142
|
-
service = RetrievalService(repository)
|
|
143
|
-
|
|
203
|
+
service = RetrievalService(repository, app_context.get_data_dir())
|
|
204
|
+
# Temporary request while we don't have all search capabilities
|
|
205
|
+
snippets = await service.retrieve(
|
|
206
|
+
RetrievalRequest(keywords=query.split(","), top_k=top_k)
|
|
207
|
+
)
|
|
208
|
+
|
|
209
|
+
if len(snippets) == 0:
|
|
210
|
+
click.echo("No snippets found")
|
|
211
|
+
return
|
|
144
212
|
|
|
145
213
|
for snippet in snippets:
|
|
214
|
+
click.echo("-" * 80)
|
|
146
215
|
click.echo(f"{snippet.uri}")
|
|
147
216
|
click.echo(snippet.content)
|
|
217
|
+
click.echo("-" * 80)
|
|
148
218
|
click.echo()
|
|
149
219
|
|
|
150
220
|
|
|
151
221
|
@cli.command()
|
|
152
222
|
@click.option("--host", default="127.0.0.1", help="Host to bind the server to")
|
|
153
223
|
@click.option("--port", default=8080, help="Port to bind the server to")
|
|
154
|
-
@
|
|
224
|
+
@with_app_context
|
|
155
225
|
def serve(
|
|
226
|
+
app_context: AppContext,
|
|
156
227
|
host: str,
|
|
157
228
|
port: int,
|
|
158
|
-
reload: bool, # noqa: FBT001
|
|
159
229
|
) -> None:
|
|
160
230
|
"""Start the kodit server, which hosts the MCP server and the kodit API."""
|
|
161
231
|
log = structlog.get_logger(__name__)
|
|
162
|
-
log.info("Starting kodit server", host=host, port=port
|
|
232
|
+
log.info("Starting kodit server", host=host, port=port)
|
|
163
233
|
log_event("kodit_server_started")
|
|
164
|
-
|
|
234
|
+
|
|
235
|
+
# Dump AppContext to a dictionary of strings, and set the env vars
|
|
236
|
+
app_context_dict = {k: str(v) for k, v in app_context.model_dump().items()}
|
|
237
|
+
os.environ.update(app_context_dict)
|
|
238
|
+
|
|
239
|
+
# Configure uvicorn with graceful shutdown
|
|
240
|
+
config = uvicorn.Config(
|
|
165
241
|
"kodit.app:app",
|
|
166
242
|
host=host,
|
|
167
243
|
port=port,
|
|
168
|
-
reload=
|
|
244
|
+
reload=False,
|
|
169
245
|
log_config=None, # Setting to None forces uvicorn to use our structlog setup
|
|
170
246
|
access_log=False, # Using own middleware for access logging
|
|
247
|
+
timeout_graceful_shutdown=0, # The mcp server does not shutdown cleanly, force
|
|
171
248
|
)
|
|
249
|
+
server = uvicorn.Server(config)
|
|
250
|
+
|
|
251
|
+
def handle_sigint(signum: int, frame: Any) -> None:
|
|
252
|
+
"""Handle SIGINT (Ctrl+C)."""
|
|
253
|
+
log.info("Received shutdown signal, force killing MCP connections")
|
|
254
|
+
server.handle_exit(signum, frame)
|
|
255
|
+
|
|
256
|
+
signal.signal(signal.SIGINT, handle_sigint)
|
|
257
|
+
server.run()
|
|
172
258
|
|
|
173
259
|
|
|
174
260
|
@cli.command()
|
kodit/config.py
CHANGED
|
@@ -1,5 +1,97 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Global configuration for the kodit project."""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
4
|
+
from collections.abc import Callable, Coroutine
|
|
5
|
+
from functools import wraps
|
|
3
6
|
from pathlib import Path
|
|
7
|
+
from typing import Any, TypeVar
|
|
4
8
|
|
|
5
|
-
|
|
9
|
+
import click
|
|
10
|
+
from pydantic import Field
|
|
11
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
12
|
+
|
|
13
|
+
from kodit.database import Database
|
|
14
|
+
|
|
15
|
+
DEFAULT_BASE_DIR = Path.home() / ".kodit"
|
|
16
|
+
DEFAULT_DB_URL = f"sqlite+aiosqlite:///{DEFAULT_BASE_DIR}/kodit.db"
|
|
17
|
+
DEFAULT_LOG_LEVEL = "INFO"
|
|
18
|
+
DEFAULT_LOG_FORMAT = "pretty"
|
|
19
|
+
DEFAULT_DISABLE_TELEMETRY = False
|
|
20
|
+
T = TypeVar("T")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class AppContext(BaseSettings):
|
|
24
|
+
"""Global context for the kodit project. Provides a shared state for the app."""
|
|
25
|
+
|
|
26
|
+
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
|
27
|
+
|
|
28
|
+
data_dir: Path = Field(default=DEFAULT_BASE_DIR)
|
|
29
|
+
db_url: str = Field(default=DEFAULT_DB_URL)
|
|
30
|
+
log_level: str = Field(default=DEFAULT_LOG_LEVEL)
|
|
31
|
+
log_format: str = Field(default=DEFAULT_LOG_FORMAT)
|
|
32
|
+
disable_telemetry: bool = Field(default=DEFAULT_DISABLE_TELEMETRY)
|
|
33
|
+
_db: Database | None = None
|
|
34
|
+
|
|
35
|
+
def model_post_init(self, _: Any) -> None:
|
|
36
|
+
"""Post-initialization hook."""
|
|
37
|
+
# Call this to ensure the data dir exists for the default db location
|
|
38
|
+
self.get_data_dir()
|
|
39
|
+
|
|
40
|
+
def get_data_dir(self) -> Path:
|
|
41
|
+
"""Get the data directory."""
|
|
42
|
+
self.data_dir.mkdir(parents=True, exist_ok=True)
|
|
43
|
+
return self.data_dir
|
|
44
|
+
|
|
45
|
+
def get_clone_dir(self) -> Path:
|
|
46
|
+
"""Get the clone directory."""
|
|
47
|
+
clone_dir = self.get_data_dir() / "clones"
|
|
48
|
+
clone_dir.mkdir(parents=True, exist_ok=True)
|
|
49
|
+
return clone_dir
|
|
50
|
+
|
|
51
|
+
async def get_db(self, *, run_migrations: bool = True) -> Database:
|
|
52
|
+
"""Get the database."""
|
|
53
|
+
if self._db is None:
|
|
54
|
+
self._db = Database(self.db_url)
|
|
55
|
+
if run_migrations:
|
|
56
|
+
await self._db.run_migrations(self.db_url)
|
|
57
|
+
return self._db
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
with_app_context = click.make_pass_decorator(AppContext)
|
|
61
|
+
|
|
62
|
+
T = TypeVar("T")
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def wrap_async(f: Callable[..., Coroutine[Any, Any, T]]) -> Callable[..., T]:
|
|
66
|
+
"""Decorate async Click commands.
|
|
67
|
+
|
|
68
|
+
This decorator wraps an async function to run it with asyncio.run().
|
|
69
|
+
It should be used after the Click command decorator.
|
|
70
|
+
|
|
71
|
+
Example:
|
|
72
|
+
@cli.command()
|
|
73
|
+
@wrap_async
|
|
74
|
+
async def my_command():
|
|
75
|
+
...
|
|
76
|
+
|
|
77
|
+
"""
|
|
78
|
+
|
|
79
|
+
@wraps(f)
|
|
80
|
+
def wrapper(*args: Any, **kwargs: Any) -> T:
|
|
81
|
+
return asyncio.run(f(*args, **kwargs))
|
|
82
|
+
|
|
83
|
+
return wrapper
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def with_session(f: Callable[..., Coroutine[Any, Any, T]]) -> Callable[..., T]:
|
|
87
|
+
"""Provide a database session to CLI commands."""
|
|
88
|
+
|
|
89
|
+
@wraps(f)
|
|
90
|
+
@with_app_context
|
|
91
|
+
@wrap_async
|
|
92
|
+
async def wrapper(app_context: AppContext, *args: Any, **kwargs: Any) -> T:
|
|
93
|
+
db = await app_context.get_db()
|
|
94
|
+
async with db.session_factory() as session:
|
|
95
|
+
return await f(session, *args, **kwargs)
|
|
96
|
+
|
|
97
|
+
return wrapper
|
kodit/database.py
CHANGED
|
@@ -1,15 +1,11 @@
|
|
|
1
1
|
"""Database configuration for kodit."""
|
|
2
2
|
|
|
3
|
-
import asyncio
|
|
4
|
-
from collections.abc import AsyncGenerator, Callable
|
|
5
|
-
from contextlib import asynccontextmanager
|
|
6
3
|
from datetime import UTC, datetime
|
|
7
|
-
from functools import wraps
|
|
8
4
|
from pathlib import Path
|
|
9
|
-
from typing import Any, TypeVar
|
|
10
5
|
|
|
6
|
+
import structlog
|
|
11
7
|
from alembic import command
|
|
12
|
-
from alembic.config import Config
|
|
8
|
+
from alembic.config import Config as AlembicConfig
|
|
13
9
|
from sqlalchemy import DateTime
|
|
14
10
|
from sqlalchemy.ext.asyncio import (
|
|
15
11
|
AsyncAttrs,
|
|
@@ -20,23 +16,6 @@ from sqlalchemy.ext.asyncio import (
|
|
|
20
16
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
|
21
17
|
|
|
22
18
|
from kodit import alembic
|
|
23
|
-
from kodit.config import DATA_DIR
|
|
24
|
-
|
|
25
|
-
# Constants
|
|
26
|
-
DB_URL = f"sqlite+aiosqlite:///{DATA_DIR}/kodit.db"
|
|
27
|
-
|
|
28
|
-
# Create data directory if it doesn't exist
|
|
29
|
-
DATA_DIR.mkdir(exist_ok=True)
|
|
30
|
-
|
|
31
|
-
# Create async engine with file-based SQLite
|
|
32
|
-
engine = create_async_engine(DB_URL, echo=False)
|
|
33
|
-
|
|
34
|
-
# Create async session factory
|
|
35
|
-
async_session_factory = async_sessionmaker(
|
|
36
|
-
engine,
|
|
37
|
-
class_=AsyncSession,
|
|
38
|
-
expire_on_commit=False,
|
|
39
|
-
)
|
|
40
19
|
|
|
41
20
|
|
|
42
21
|
class Base(AsyncAttrs, DeclarativeBase):
|
|
@@ -55,37 +34,42 @@ class CommonMixin:
|
|
|
55
34
|
)
|
|
56
35
|
|
|
57
36
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
37
|
+
class Database:
|
|
38
|
+
"""Database class for kodit."""
|
|
39
|
+
|
|
40
|
+
def __init__(self, db_url: str) -> None:
|
|
41
|
+
"""Initialize the database."""
|
|
42
|
+
self.log = structlog.get_logger(__name__)
|
|
43
|
+
self.db_engine = create_async_engine(db_url, echo=False)
|
|
44
|
+
self.db_session_factory = async_sessionmaker(
|
|
45
|
+
self.db_engine,
|
|
46
|
+
class_=AsyncSession,
|
|
47
|
+
expire_on_commit=False,
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
@property
|
|
51
|
+
def session_factory(self) -> async_sessionmaker[AsyncSession]:
|
|
52
|
+
"""Get the session factory."""
|
|
53
|
+
return self.db_session_factory
|
|
54
|
+
|
|
55
|
+
async def run_migrations(self, db_url: str) -> None:
|
|
56
|
+
"""Run any pending migrations."""
|
|
57
|
+
# Create Alembic configuration and run migrations
|
|
58
|
+
alembic_cfg = AlembicConfig()
|
|
59
|
+
alembic_cfg.set_main_option(
|
|
60
|
+
"script_location", str(Path(alembic.__file__).parent)
|
|
61
|
+
)
|
|
62
|
+
alembic_cfg.set_main_option("sqlalchemy.url", db_url)
|
|
63
|
+
self.log.debug("Running migrations", db_url=db_url)
|
|
64
|
+
|
|
65
|
+
async with self.db_engine.begin() as conn:
|
|
66
|
+
await conn.run_sync(self.run_upgrade, alembic_cfg)
|
|
67
|
+
|
|
68
|
+
def run_upgrade(self, connection, cfg) -> None: # noqa: ANN001
|
|
69
|
+
"""Make sure the database is up to date."""
|
|
70
|
+
cfg.attributes["connection"] = connection
|
|
71
|
+
command.upgrade(cfg, "head")
|
|
72
|
+
|
|
73
|
+
async def close(self) -> None:
|
|
74
|
+
"""Close the database."""
|
|
75
|
+
await self.db_engine.dispose()
|
kodit/indexing/repository.py
CHANGED
|
@@ -130,3 +130,14 @@ class IndexRepository:
|
|
|
130
130
|
query = select(Snippet).where(Snippet.index_id == index_id)
|
|
131
131
|
result = await self.session.execute(query)
|
|
132
132
|
return list(result.scalars())
|
|
133
|
+
|
|
134
|
+
async def get_all_snippets(self) -> list[Snippet]:
|
|
135
|
+
"""Get all snippets.
|
|
136
|
+
|
|
137
|
+
Returns:
|
|
138
|
+
A list of all snippets.
|
|
139
|
+
|
|
140
|
+
"""
|
|
141
|
+
query = select(Snippet).order_by(Snippet.id)
|
|
142
|
+
result = await self.session.execute(query)
|
|
143
|
+
return list(result.scalars())
|