kodit-0.1.4-py3-none-any.whl → kodit-0.1.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic. Click here for more details.

kodit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.1.4'
21
- __version_tuple__ = version_tuple = (0, 1, 4)
20
+ __version__ = version = '0.1.6'
21
+ __version_tuple__ = version_tuple = (0, 1, 6)
kodit/alembic/env.py CHANGED
@@ -3,7 +3,6 @@
3
3
 
4
4
  import asyncio
5
5
 
6
- import structlog
7
6
  from alembic import context
8
7
  from sqlalchemy import pool
9
8
  from sqlalchemy.engine import Connection
@@ -66,8 +65,6 @@ async def run_async_migrations() -> None:
66
65
  prefix="sqlalchemy.",
67
66
  poolclass=pool.NullPool,
68
67
  )
69
- log = structlog.get_logger(__name__)
70
- log.debug("Running migrations on %s", connectable.url)
71
68
 
72
69
  async with connectable.connect() as connection:
73
70
  await connection.run_sync(do_run_migrations)
@@ -77,7 +74,11 @@ async def run_async_migrations() -> None:
77
74
 
78
75
  def run_migrations_online() -> None:
79
76
  """Run migrations in 'online' mode."""
80
- asyncio.run(run_async_migrations())
77
+ connectable = config.attributes.get("connection", None)
78
+ if connectable is None:
79
+ asyncio.run(run_async_migrations())
80
+ else:
81
+ do_run_migrations(connectable)
81
82
 
82
83
 
83
84
  if context.is_offline_mode():
kodit/app.py CHANGED
@@ -4,15 +4,11 @@ from asgi_correlation_id import CorrelationIdMiddleware
4
4
  from fastapi import FastAPI
5
5
 
6
6
  from kodit.mcp import mcp
7
- from kodit.middleware import logging_middleware
8
- from kodit.sse import create_sse_server
7
+ from kodit.middleware import ASGICancelledErrorMiddleware, logging_middleware
9
8
 
10
- app = FastAPI(title="kodit API")
11
-
12
- # Get the SSE routes from the Starlette app hosting the MCP server
13
- sse_app = create_sse_server(mcp)
14
- for route in sse_app.routes:
15
- app.router.routes.append(route)
9
+ # See https://gofastmcp.com/deployment/asgi#fastapi-integration
10
+ mcp_app = mcp.sse_app()
11
+ app = FastAPI(title="kodit API", lifespan=mcp_app.router.lifespan_context)
16
12
 
17
13
  # Add middleware
18
14
  app.middleware("http")(logging_middleware)
@@ -22,4 +18,12 @@ app.add_middleware(CorrelationIdMiddleware)
22
18
  @app.get("/")
23
19
  async def root() -> dict[str, str]:
24
20
  """Return a welcome message for the kodit API."""
25
- return {"message": "Welcome to kodit API"}
21
+ return {"message": "Hello, World!"}
22
+
23
+
24
+ # Add mcp routes last, otherwise previous routes aren't added
25
+ app.mount("", mcp_app)
26
+
27
+ # Wrap the entire app with ASGI middleware after all routes are added to suppress
28
+ # CancelledError at the ASGI level
29
+ app = ASGICancelledErrorMiddleware(app)
kodit/bm25/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """BM25 module."""
kodit/bm25/bm25.py ADDED
@@ -0,0 +1,71 @@
1
+ """BM25 service."""
2
+
3
+ from pathlib import Path
4
+
5
+ import bm25s
6
+ import Stemmer
7
+ import structlog
8
+ from bm25s.tokenization import Tokenized
9
+
10
+
11
+ class BM25Service:
12
+ """Service for BM25."""
13
+
14
+ def __init__(self, data_dir: Path) -> None:
15
+ """Initialize the BM25 service."""
16
+ self.log = structlog.get_logger(__name__)
17
+ self.index_path = data_dir / "bm25s_index"
18
+ try:
19
+ self.log.debug("Loading BM25 index")
20
+ self.retriever = bm25s.BM25.load(self.index_path, mmap=True)
21
+ except FileNotFoundError:
22
+ self.log.debug("BM25 index not found, creating new index")
23
+ self.retriever = bm25s.BM25()
24
+
25
+ self.stemmer = Stemmer.Stemmer("english")
26
+
27
+ def _tokenize(self, corpus: list[str]) -> list[list[str]] | Tokenized:
28
+ return bm25s.tokenize(
29
+ corpus,
30
+ stopwords="en",
31
+ stemmer=self.stemmer,
32
+ return_ids=False,
33
+ show_progress=True,
34
+ )
35
+
36
+ def index(self, corpus: list[str]) -> None:
37
+ """Index a new corpus."""
38
+ self.log.debug("Indexing corpus")
39
+ vocab = self._tokenize(corpus)
40
+ self.retriever = bm25s.BM25()
41
+ self.retriever.index(vocab)
42
+ self.retriever.save(self.index_path)
43
+
44
+ def retrieve(
45
+ self, doc_ids: list[int], query: str, top_k: int = 2
46
+ ) -> list[tuple[int, float]]:
47
+ """Retrieve from the index."""
48
+ if top_k == 0:
49
+ self.log.warning("Top k is 0, returning empty list")
50
+ return []
51
+ if len(doc_ids) == 0:
52
+ self.log.warning("No documents to retrieve from, returning empty list")
53
+ return []
54
+
55
+ top_k = min(top_k, len(doc_ids))
56
+ self.log.debug(
57
+ "Retrieving from index", query=query, top_k=top_k, num_docs=len(doc_ids)
58
+ )
59
+
60
+ query_tokens = self._tokenize([query])
61
+
62
+ self.log.debug("Query tokens", query_tokens=query_tokens)
63
+
64
+ results, scores = self.retriever.retrieve(
65
+ query_tokens=query_tokens, corpus=doc_ids, k=top_k
66
+ )
67
+ self.log.debug("Raw results", results=results, scores=scores)
68
+ return [
69
+ (int(result), float(score))
70
+ for result, score in zip(results[0], scores[0], strict=False)
71
+ ]
kodit/cli.py CHANGED
@@ -1,41 +1,87 @@
1
1
  """Command line interface for kodit."""
2
2
 
3
3
  import os
4
+ import signal
5
+ from pathlib import Path
6
+ from typing import Any
4
7
 
5
8
  import click
6
9
  import structlog
7
10
  import uvicorn
8
- from dotenv import dotenv_values
9
11
  from pytable_formatter import Table
10
12
  from sqlalchemy.ext.asyncio import AsyncSession
11
13
 
12
- from kodit.database import configure_database, with_session
14
+ from kodit.config import (
15
+ DEFAULT_BASE_DIR,
16
+ DEFAULT_DB_URL,
17
+ DEFAULT_DISABLE_TELEMETRY,
18
+ DEFAULT_LOG_FORMAT,
19
+ DEFAULT_LOG_LEVEL,
20
+ AppContext,
21
+ with_app_context,
22
+ with_session,
23
+ )
13
24
  from kodit.indexing.repository import IndexRepository
14
25
  from kodit.indexing.service import IndexService
15
- from kodit.logging import LogFormat, configure_logging, disable_posthog, log_event
26
+ from kodit.logging import configure_logging, configure_telemetry, log_event
16
27
  from kodit.retreival.repository import RetrievalRepository
17
28
  from kodit.retreival.service import RetrievalRequest, RetrievalService
18
29
  from kodit.sources.repository import SourceRepository
19
30
  from kodit.sources.service import SourceService
20
31
 
21
- env_vars = dict(dotenv_values())
22
- os.environ.update(env_vars)
23
32
 
24
-
25
- @click.group(context_settings={"auto_envvar_prefix": "KODIT", "show_default": True})
26
- @click.option("--log-level", default="INFO", help="Log level")
27
- @click.option("--log-format", default=LogFormat.PRETTY, help="Log format")
28
- @click.option("--disable-telemetry", is_flag=True, help="Disable telemetry")
29
- def cli(
30
- log_level: str,
31
- log_format: LogFormat,
32
- disable_telemetry: bool, # noqa: FBT001
33
+ @click.group(context_settings={"max_content_width": 100})
34
+ @click.option("--log-level", help=f"Log level [default: {DEFAULT_LOG_LEVEL}]")
35
+ @click.option("--log-format", help=f"Log format [default: {DEFAULT_LOG_FORMAT}]")
36
+ @click.option(
37
+ "--disable-telemetry",
38
+ is_flag=True,
39
+ help=f"Disable telemetry [default: {DEFAULT_DISABLE_TELEMETRY}]",
40
+ )
41
+ @click.option("--db-url", help=f"Database URL [default: {DEFAULT_DB_URL}]")
42
+ @click.option("--data-dir", help=f"Data directory [default: {DEFAULT_BASE_DIR}]")
43
+ @click.option(
44
+ "--env-file",
45
+ help="Path to a .env file [default: .env]",
46
+ type=click.Path(
47
+ exists=True,
48
+ dir_okay=False,
49
+ resolve_path=True,
50
+ path_type=Path,
51
+ ),
52
+ )
53
+ @click.pass_context
54
+ def cli( # noqa: PLR0913
55
+ ctx: click.Context,
56
+ log_level: str | None,
57
+ log_format: str | None,
58
+ disable_telemetry: bool | None,
59
+ db_url: str | None,
60
+ data_dir: str | None,
61
+ env_file: Path | None,
33
62
  ) -> None:
34
63
  """kodit CLI - Code indexing for better AI code generation.""" # noqa: D403
35
- configure_logging(log_level, log_format)
64
+ config = AppContext()
65
+ # First check if env-file is set and reload config if it is
66
+ if env_file:
67
+ config = AppContext(_env_file=env_file) # type: ignore[reportCallIssue]
68
+
69
+ # Now override with CLI arguments, if set
70
+ if data_dir:
71
+ config.data_dir = Path(data_dir)
72
+ if db_url:
73
+ config.db_url = db_url
74
+ if log_level:
75
+ config.log_level = log_level
76
+ if log_format:
77
+ config.log_format = log_format
36
78
  if disable_telemetry:
37
- disable_posthog()
38
- configure_database()
79
+ config.disable_telemetry = disable_telemetry
80
+ configure_logging(config)
81
+ configure_telemetry(config)
82
+
83
+ # Set the app context in the click context for downstream cli
84
+ ctx.obj = config
39
85
 
40
86
 
41
87
  @cli.group()
@@ -44,11 +90,12 @@ def sources() -> None:
44
90
 
45
91
 
46
92
  @sources.command(name="list")
93
+ @with_app_context
47
94
  @with_session
48
- async def list_sources(session: AsyncSession) -> None:
95
+ async def list_sources(session: AsyncSession, app_context: AppContext) -> None:
49
96
  """List all code sources."""
50
97
  repository = SourceRepository(session)
51
- service = SourceService(repository)
98
+ service = SourceService(app_context.get_clone_dir(), repository)
52
99
  sources = await service.list_sources()
53
100
 
54
101
  # Define headers and data
@@ -62,11 +109,14 @@ async def list_sources(session: AsyncSession) -> None:
62
109
 
63
110
  @sources.command(name="create")
64
111
  @click.argument("uri")
112
+ @with_app_context
65
113
  @with_session
66
- async def create_source(session: AsyncSession, uri: str) -> None:
114
+ async def create_source(
115
+ session: AsyncSession, app_context: AppContext, uri: str
116
+ ) -> None:
67
117
  """Add a new code source."""
68
118
  repository = SourceRepository(session)
69
- service = SourceService(repository)
119
+ service = SourceService(app_context.get_clone_dir(), repository)
70
120
  source = await service.create(uri)
71
121
  click.echo(f"Source created: {source.id}")
72
122
 
@@ -78,25 +128,29 @@ def indexes() -> None:
78
128
 
79
129
  @indexes.command(name="create")
80
130
  @click.argument("source_id")
131
+ @with_app_context
81
132
  @with_session
82
- async def create_index(session: AsyncSession, source_id: int) -> None:
133
+ async def create_index(
134
+ session: AsyncSession, app_context: AppContext, source_id: int
135
+ ) -> None:
83
136
  """Create an index for a source."""
84
137
  source_repository = SourceRepository(session)
85
- source_service = SourceService(source_repository)
138
+ source_service = SourceService(app_context.get_clone_dir(), source_repository)
86
139
  repository = IndexRepository(session)
87
- service = IndexService(repository, source_service)
140
+ service = IndexService(repository, source_service, app_context.get_data_dir())
88
141
  index = await service.create(source_id)
89
142
  click.echo(f"Index created: {index.id}")
90
143
 
91
144
 
92
145
  @indexes.command(name="list")
146
+ @with_app_context
93
147
  @with_session
94
- async def list_indexes(session: AsyncSession) -> None:
148
+ async def list_indexes(session: AsyncSession, app_context: AppContext) -> None:
95
149
  """List all indexes."""
96
150
  source_repository = SourceRepository(session)
97
- source_service = SourceService(source_repository)
151
+ source_service = SourceService(app_context.get_clone_dir(), source_repository)
98
152
  repository = IndexRepository(session)
99
- service = IndexService(repository, source_service)
153
+ service = IndexService(repository, source_service, app_context.get_data_dir())
100
154
  indexes = await service.list_indexes()
101
155
 
102
156
  # Define headers and data
@@ -123,52 +177,84 @@ async def list_indexes(session: AsyncSession) -> None:
123
177
 
124
178
  @indexes.command(name="run")
125
179
  @click.argument("index_id")
180
+ @with_app_context
126
181
  @with_session
127
- async def run_index(session: AsyncSession, index_id: int) -> None:
182
+ async def run_index(
183
+ session: AsyncSession, app_context: AppContext, index_id: int
184
+ ) -> None:
128
185
  """Run an index."""
129
186
  source_repository = SourceRepository(session)
130
- source_service = SourceService(source_repository)
187
+ source_service = SourceService(app_context.get_clone_dir(), source_repository)
131
188
  repository = IndexRepository(session)
132
- service = IndexService(repository, source_service)
189
+ service = IndexService(repository, source_service, app_context.get_data_dir())
133
190
  await service.run(index_id)
134
191
 
135
192
 
136
193
  @cli.command()
137
194
  @click.argument("query")
195
+ @click.option("--top-k", default=10, help="Number of snippets to retrieve")
196
+ @with_app_context
138
197
  @with_session
139
- async def retrieve(session: AsyncSession, query: str) -> None:
198
+ async def retrieve(
199
+ session: AsyncSession, app_context: AppContext, query: str, top_k: int
200
+ ) -> None:
140
201
  """Retrieve snippets from the database."""
141
202
  repository = RetrievalRepository(session)
142
- service = RetrievalService(repository)
143
- snippets = await service.retrieve(RetrievalRequest(query=query))
203
+ service = RetrievalService(repository, app_context.get_data_dir())
204
+ # Temporary request while we don't have all search capabilities
205
+ snippets = await service.retrieve(
206
+ RetrievalRequest(keywords=query.split(","), top_k=top_k)
207
+ )
208
+
209
+ if len(snippets) == 0:
210
+ click.echo("No snippets found")
211
+ return
144
212
 
145
213
  for snippet in snippets:
214
+ click.echo("-" * 80)
146
215
  click.echo(f"{snippet.uri}")
147
216
  click.echo(snippet.content)
217
+ click.echo("-" * 80)
148
218
  click.echo()
149
219
 
150
220
 
151
221
  @cli.command()
152
222
  @click.option("--host", default="127.0.0.1", help="Host to bind the server to")
153
223
  @click.option("--port", default=8080, help="Port to bind the server to")
154
- @click.option("--reload", is_flag=True, help="Enable auto-reload for development")
224
+ @with_app_context
155
225
  def serve(
226
+ app_context: AppContext,
156
227
  host: str,
157
228
  port: int,
158
- reload: bool, # noqa: FBT001
159
229
  ) -> None:
160
230
  """Start the kodit server, which hosts the MCP server and the kodit API."""
161
231
  log = structlog.get_logger(__name__)
162
- log.info("Starting kodit server", host=host, port=port, reload=reload)
232
+ log.info("Starting kodit server", host=host, port=port)
163
233
  log_event("kodit_server_started")
164
- uvicorn.run(
234
+
235
+ # Dump AppContext to a dictionary of strings, and set the env vars
236
+ app_context_dict = {k: str(v) for k, v in app_context.model_dump().items()}
237
+ os.environ.update(app_context_dict)
238
+
239
+ # Configure uvicorn with graceful shutdown
240
+ config = uvicorn.Config(
165
241
  "kodit.app:app",
166
242
  host=host,
167
243
  port=port,
168
- reload=reload,
244
+ reload=False,
169
245
  log_config=None, # Setting to None forces uvicorn to use our structlog setup
170
246
  access_log=False, # Using own middleware for access logging
247
+ timeout_graceful_shutdown=0, # The mcp server does not shutdown cleanly, force
171
248
  )
249
+ server = uvicorn.Server(config)
250
+
251
+ def handle_sigint(signum: int, frame: Any) -> None:
252
+ """Handle SIGINT (Ctrl+C)."""
253
+ log.info("Received shutdown signal, force killing MCP connections")
254
+ server.handle_exit(signum, frame)
255
+
256
+ signal.signal(signal.SIGINT, handle_sigint)
257
+ server.run()
172
258
 
173
259
 
174
260
  @cli.command()
kodit/config.py CHANGED
@@ -1,5 +1,97 @@
1
- """Configuration for the kodit project."""
1
+ """Global configuration for the kodit project."""
2
2
 
3
+ import asyncio
4
+ from collections.abc import Callable, Coroutine
5
+ from functools import wraps
3
6
  from pathlib import Path
7
+ from typing import Any, TypeVar
4
8
 
5
- DATA_DIR = Path.home() / ".kodit"
9
+ import click
10
+ from pydantic import Field
11
+ from pydantic_settings import BaseSettings, SettingsConfigDict
12
+
13
+ from kodit.database import Database
14
+
15
+ DEFAULT_BASE_DIR = Path.home() / ".kodit"
16
+ DEFAULT_DB_URL = f"sqlite+aiosqlite:///{DEFAULT_BASE_DIR}/kodit.db"
17
+ DEFAULT_LOG_LEVEL = "INFO"
18
+ DEFAULT_LOG_FORMAT = "pretty"
19
+ DEFAULT_DISABLE_TELEMETRY = False
20
+ T = TypeVar("T")
21
+
22
+
23
+ class AppContext(BaseSettings):
24
+ """Global context for the kodit project. Provides a shared state for the app."""
25
+
26
+ model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
27
+
28
+ data_dir: Path = Field(default=DEFAULT_BASE_DIR)
29
+ db_url: str = Field(default=DEFAULT_DB_URL)
30
+ log_level: str = Field(default=DEFAULT_LOG_LEVEL)
31
+ log_format: str = Field(default=DEFAULT_LOG_FORMAT)
32
+ disable_telemetry: bool = Field(default=DEFAULT_DISABLE_TELEMETRY)
33
+ _db: Database | None = None
34
+
35
+ def model_post_init(self, _: Any) -> None:
36
+ """Post-initialization hook."""
37
+ # Call this to ensure the data dir exists for the default db location
38
+ self.get_data_dir()
39
+
40
+ def get_data_dir(self) -> Path:
41
+ """Get the data directory."""
42
+ self.data_dir.mkdir(parents=True, exist_ok=True)
43
+ return self.data_dir
44
+
45
+ def get_clone_dir(self) -> Path:
46
+ """Get the clone directory."""
47
+ clone_dir = self.get_data_dir() / "clones"
48
+ clone_dir.mkdir(parents=True, exist_ok=True)
49
+ return clone_dir
50
+
51
+ async def get_db(self, *, run_migrations: bool = True) -> Database:
52
+ """Get the database."""
53
+ if self._db is None:
54
+ self._db = Database(self.db_url)
55
+ if run_migrations:
56
+ await self._db.run_migrations(self.db_url)
57
+ return self._db
58
+
59
+
60
+ with_app_context = click.make_pass_decorator(AppContext)
61
+
62
+ T = TypeVar("T")
63
+
64
+
65
+ def wrap_async(f: Callable[..., Coroutine[Any, Any, T]]) -> Callable[..., T]:
66
+ """Decorate async Click commands.
67
+
68
+ This decorator wraps an async function to run it with asyncio.run().
69
+ It should be used after the Click command decorator.
70
+
71
+ Example:
72
+ @cli.command()
73
+ @wrap_async
74
+ async def my_command():
75
+ ...
76
+
77
+ """
78
+
79
+ @wraps(f)
80
+ def wrapper(*args: Any, **kwargs: Any) -> T:
81
+ return asyncio.run(f(*args, **kwargs))
82
+
83
+ return wrapper
84
+
85
+
86
+ def with_session(f: Callable[..., Coroutine[Any, Any, T]]) -> Callable[..., T]:
87
+ """Provide a database session to CLI commands."""
88
+
89
+ @wraps(f)
90
+ @with_app_context
91
+ @wrap_async
92
+ async def wrapper(app_context: AppContext, *args: Any, **kwargs: Any) -> T:
93
+ db = await app_context.get_db()
94
+ async with db.session_factory() as session:
95
+ return await f(session, *args, **kwargs)
96
+
97
+ return wrapper
kodit/database.py CHANGED
@@ -1,15 +1,11 @@
1
1
  """Database configuration for kodit."""
2
2
 
3
- import asyncio
4
- from collections.abc import AsyncGenerator, Callable
5
- from contextlib import asynccontextmanager
6
3
  from datetime import UTC, datetime
7
- from functools import wraps
8
4
  from pathlib import Path
9
- from typing import Any, TypeVar
10
5
 
6
+ import structlog
11
7
  from alembic import command
12
- from alembic.config import Config
8
+ from alembic.config import Config as AlembicConfig
13
9
  from sqlalchemy import DateTime
14
10
  from sqlalchemy.ext.asyncio import (
15
11
  AsyncAttrs,
@@ -20,23 +16,6 @@ from sqlalchemy.ext.asyncio import (
20
16
  from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
21
17
 
22
18
  from kodit import alembic
23
- from kodit.config import DATA_DIR
24
-
25
- # Constants
26
- DB_URL = f"sqlite+aiosqlite:///{DATA_DIR}/kodit.db"
27
-
28
- # Create data directory if it doesn't exist
29
- DATA_DIR.mkdir(exist_ok=True)
30
-
31
- # Create async engine with file-based SQLite
32
- engine = create_async_engine(DB_URL, echo=False)
33
-
34
- # Create async session factory
35
- async_session_factory = async_sessionmaker(
36
- engine,
37
- class_=AsyncSession,
38
- expire_on_commit=False,
39
- )
40
19
 
41
20
 
42
21
  class Base(AsyncAttrs, DeclarativeBase):
@@ -55,37 +34,42 @@ class CommonMixin:
55
34
  )
56
35
 
57
36
 
58
- @asynccontextmanager
59
- async def get_session() -> AsyncGenerator[AsyncSession, None]:
60
- """Get a database session."""
61
- async with async_session_factory() as session:
62
- try:
63
- yield session
64
- finally:
65
- await session.close()
66
-
67
-
68
- T = TypeVar("T")
69
-
70
-
71
- def with_session(func: Callable[..., T]) -> Callable[..., T]:
72
- """Provide an async session to CLI commands."""
73
-
74
- @wraps(func)
75
- def wrapper(*args: Any, **kwargs: Any) -> T:
76
- async def _run() -> T:
77
- async with async_session_factory() as session:
78
- return await func(session, *args, **kwargs)
79
-
80
- return asyncio.run(_run())
81
-
82
- return wrapper
83
-
84
-
85
- def configure_database() -> None:
86
- """Configure the database by initializing it and running any pending migrations."""
87
- # Create Alembic configuration and run migrations
88
- alembic_cfg = Config()
89
- alembic_cfg.set_main_option("script_location", str(Path(alembic.__file__).parent))
90
- alembic_cfg.set_main_option("sqlalchemy.url", DB_URL)
91
- command.upgrade(alembic_cfg, "head")
37
+ class Database:
38
+ """Database class for kodit."""
39
+
40
+ def __init__(self, db_url: str) -> None:
41
+ """Initialize the database."""
42
+ self.log = structlog.get_logger(__name__)
43
+ self.db_engine = create_async_engine(db_url, echo=False)
44
+ self.db_session_factory = async_sessionmaker(
45
+ self.db_engine,
46
+ class_=AsyncSession,
47
+ expire_on_commit=False,
48
+ )
49
+
50
+ @property
51
+ def session_factory(self) -> async_sessionmaker[AsyncSession]:
52
+ """Get the session factory."""
53
+ return self.db_session_factory
54
+
55
+ async def run_migrations(self, db_url: str) -> None:
56
+ """Run any pending migrations."""
57
+ # Create Alembic configuration and run migrations
58
+ alembic_cfg = AlembicConfig()
59
+ alembic_cfg.set_main_option(
60
+ "script_location", str(Path(alembic.__file__).parent)
61
+ )
62
+ alembic_cfg.set_main_option("sqlalchemy.url", db_url)
63
+ self.log.debug("Running migrations", db_url=db_url)
64
+
65
+ async with self.db_engine.begin() as conn:
66
+ await conn.run_sync(self.run_upgrade, alembic_cfg)
67
+
68
+ def run_upgrade(self, connection, cfg) -> None: # noqa: ANN001
69
+ """Make sure the database is up to date."""
70
+ cfg.attributes["connection"] = connection
71
+ command.upgrade(cfg, "head")
72
+
73
+ async def close(self) -> None:
74
+ """Close the database."""
75
+ await self.db_engine.dispose()
kodit/indexing/repository.py CHANGED
@@ -130,3 +130,14 @@ class IndexRepository:
130
130
  query = select(Snippet).where(Snippet.index_id == index_id)
131
131
  result = await self.session.execute(query)
132
132
  return list(result.scalars())
133
+
134
+ async def get_all_snippets(self) -> list[Snippet]:
135
+ """Get all snippets.
136
+
137
+ Returns:
138
+ A list of all snippets.
139
+
140
+ """
141
+ query = select(Snippet).order_by(Snippet.id)
142
+ result = await self.session.execute(query)
143
+ return list(result.scalars())