kodit 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kodit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.1.0'
21
- __version_tuple__ = version_tuple = (0, 1, 0)
20
+ __version__ = version = '0.1.2'
21
+ __version_tuple__ = version_tuple = (0, 1, 2)
kodit/alembic/README ADDED
@@ -0,0 +1 @@
1
+ Generic single-database configuration with an async dbapi.
@@ -0,0 +1 @@
1
+ """Database configuration for kodit."""
kodit/alembic/env.py ADDED
@@ -0,0 +1,86 @@
1
+ # ruff: noqa: F401
2
+ """Alembic environment file for kodit."""
3
+
4
+ import asyncio
5
+
6
+ import structlog
7
+ from alembic import context
8
+ from sqlalchemy import pool
9
+ from sqlalchemy.engine import Connection
10
+ from sqlalchemy.ext.asyncio import async_engine_from_config
11
+
12
+ import kodit.indexing.models
13
+ import kodit.sources.models
14
+ from kodit.database import Base
15
+
16
# Alembic Config object: gives access to the option values Alembic is
# currently running with (normally those of the .ini file in use).
config = context.config

# The stock Alembic env.py would set up Python logging from the config file
# here. That step is skipped on purpose so the existing logging
# configuration is preserved.

# MetaData of the models; needed for 'autogenerate' support.
target_metadata = Base.metadata
28
+
29
+
30
def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    The Alembic context is configured from the ``sqlalchemy.url`` option
    alone, without creating an Engine, so no DBAPI has to be available.
    Calls to context.execute() made while migrating emit the given string to
    the script output.
    """
    offline_options = {
        "url": config.get_main_option("sqlalchemy.url"),
        "target_metadata": target_metadata,
        "literal_binds": True,
        "dialect_opts": {"paramstyle": "named"},
    }
    context.configure(**offline_options)

    transaction = context.begin_transaction()
    with transaction:
        context.run_migrations()
52
+
53
+
54
def do_run_migrations(connection: Connection) -> None:
    """Configure Alembic with a live connection and run the migrations.

    Args:
        connection: Synchronous SQLAlchemy connection the migrations run on.

    """
    context.configure(target_metadata=target_metadata, connection=connection)

    transaction = context.begin_transaction()
    with transaction:
        context.run_migrations()
60
+
61
+
62
async def run_async_migrations() -> None:
    """Run migrations in 'async' mode.

    Builds an async engine from the ``sqlalchemy.``-prefixed options of the
    active config section, opens a connection and runs the synchronous
    migration routine on it. The engine is always disposed afterwards, even
    when a migration fails.
    """
    connectable = async_engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        # NullPool: connections are not kept around after use.
        poolclass=pool.NullPool,
    )
    log = structlog.get_logger(__name__)
    log.debug("Running migrations on %s", connectable.url)

    try:
        async with connectable.connect() as connection:
            # Alembic's migration API is synchronous; run_sync bridges to it.
            await connection.run_sync(do_run_migrations)
    finally:
        # Dispose in ``finally`` so a failing migration cannot leave the
        # engine undisposed.
        await connectable.dispose()
76
+
77
+
78
def run_migrations_online() -> None:
    """Run migrations in 'online' mode by driving the async runner to completion."""
    migration_coroutine = run_async_migrations()
    asyncio.run(migration_coroutine)
81
+
82
+
83
# Pick the migration runner that matches the mode Alembic was invoked in.
if not context.is_offline_mode():
    run_migrations_online()
else:
    run_migrations_offline()
@@ -0,0 +1,30 @@
1
+ # ruff: noqa
2
+ """${message}
3
+
4
+ Revision ID: ${up_revision}
5
+ Revises: ${down_revision | comma,n}
6
+ Create Date: ${create_date}
7
+
8
+ """
9
+
10
+ from typing import Sequence, Union
11
+
12
+ from alembic import op
13
+ import sqlalchemy as sa
14
+ ${imports if imports else ""}
15
+
16
+ # revision identifiers, used by Alembic.
17
+ revision: str = ${repr(up_revision)}
18
+ down_revision: Union[str, None] = ${repr(down_revision)}
19
+ branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
20
+ depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
21
+
22
+
23
+ def upgrade() -> None:
24
+ """Upgrade schema."""
25
+ ${upgrades if upgrades else "pass"}
26
+
27
+
28
+ def downgrade() -> None:
29
+ """Downgrade schema."""
30
+ ${downgrades if downgrades else "pass"}
@@ -0,0 +1,82 @@
1
+ # ruff: noqa
2
+ """initial
3
+
4
+ Revision ID: 85155663351e
5
+ Revises:
6
+ Create Date: 2025-05-08 13:45:16.687162
7
+
8
+ """
9
+
10
+ from typing import Sequence, Union
11
+
12
+ from alembic import op
13
+ import sqlalchemy as sa
14
+
15
+
16
+ # revision identifiers, used by Alembic.
17
+ revision: str = '85155663351e'
18
+ down_revision: Union[str, None] = None
19
+ branch_labels: Union[str, Sequence[str], None] = None
20
+ depends_on: Union[str, Sequence[str], None] = None
21
+
22
+
23
def upgrade() -> None:
    """Upgrade schema.

    Creates the ``sources``, ``files``, ``indexes`` and ``snippets`` tables
    together with their indexes, parents before children.
    """

    def common_columns() -> list:
        # Built anew on every call so each table gets its own Column objects.
        return [
            sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
            sa.Column("created_at", sa.DateTime(), nullable=False),
            sa.Column("updated_at", sa.DateTime(), nullable=False),
        ]

    op.create_table(
        "sources",
        sa.Column("uri", sa.String(length=1024), nullable=False),
        sa.Column("cloned_path", sa.String(length=1024), nullable=False),
        *common_columns(),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(op.f("ix_sources_uri"), "sources", ["uri"], unique=True)

    op.create_table(
        "files",
        sa.Column("source_id", sa.Integer(), nullable=False),
        sa.Column("mime_type", sa.String(length=255), nullable=False),
        sa.Column("uri", sa.String(length=1024), nullable=False),
        sa.Column("cloned_path", sa.String(length=1024), nullable=False),
        sa.Column("sha256", sa.String(length=64), nullable=False),
        sa.Column("size_bytes", sa.Integer(), nullable=False),
        *common_columns(),
        sa.ForeignKeyConstraint(["source_id"], ["sources.id"]),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(op.f("ix_files_sha256"), "files", ["sha256"], unique=False)

    op.create_table(
        "indexes",
        sa.Column("source_id", sa.Integer(), nullable=False),
        *common_columns(),
        sa.ForeignKeyConstraint(["source_id"], ["sources.id"]),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        op.f("ix_indexes_source_id"), "indexes", ["source_id"], unique=True
    )

    op.create_table(
        "snippets",
        sa.Column("file_id", sa.Integer(), nullable=False),
        sa.Column("index_id", sa.Integer(), nullable=False),
        sa.Column("content", sa.UnicodeText(), nullable=False),
        *common_columns(),
        sa.ForeignKeyConstraint(["file_id"], ["files.id"]),
        sa.ForeignKeyConstraint(["index_id"], ["indexes.id"]),
        sa.PrimaryKeyConstraint("id"),
    )
70
+
71
+
72
def downgrade() -> None:
    """Downgrade schema.

    Drops every table created by upgrade() in reverse order, removing each
    table's index before the table itself.
    """
    # ``snippets`` has no index of its own, so it is dropped directly.
    op.drop_table("snippets")

    indexed_tables = (
        ("indexes", "ix_indexes_source_id"),
        ("files", "ix_files_sha256"),
        ("sources", "ix_sources_uri"),
    )
    for table_name, index_name in indexed_tables:
        op.drop_index(op.f(index_name), table_name=table_name)
        op.drop_table(table_name)
@@ -0,0 +1 @@
1
+ """Alembic migrations for kodit."""
kodit/cli.py CHANGED
@@ -6,8 +6,17 @@ import click
6
6
  import structlog
7
7
  import uvicorn
8
8
  from dotenv import dotenv_values
9
+ from pytable_formatter import Table
10
+ from sqlalchemy.ext.asyncio import AsyncSession
9
11
 
12
+ from kodit.database import configure_database, with_session
13
+ from kodit.indexing.repository import IndexRepository
14
+ from kodit.indexing.service import IndexService
10
15
  from kodit.logging import LogFormat, configure_logging, disable_posthog, log_event
16
+ from kodit.retreival.repository import RetrievalRepository
17
+ from kodit.retreival.service import RetrievalRequest, RetrievalService
18
+ from kodit.sources.repository import SourceRepository
19
+ from kodit.sources.service import SourceService
11
20
 
12
21
  env_vars = dict(dotenv_values())
13
22
  os.environ.update(env_vars)
@@ -26,6 +35,119 @@ def cli(
26
35
  configure_logging(log_level, log_format)
27
36
  if disable_telemetry:
28
37
  disable_posthog()
38
+ configure_database()
39
+
40
+
41
# Command group; its subcommands are registered with @sources.command.
@cli.group(name="sources")
def sources() -> None:
    """Manage code sources."""
44
+
45
+
46
@sources.command(name="list")
@with_session
async def list_sources(session: AsyncSession) -> None:
    """List all code sources."""
    service = SourceService(SourceRepository(session))

    # One table row per source, in the order the service returns them.
    rows = [
        [item.id, item.created_at, item.uri] for item in await service.list_sources()
    ]
    column_titles = ["ID", "Created At", "URI"]

    click.echo(Table(headers=column_titles, data=rows))
61
+
62
+
63
@sources.command(name="create")
@click.argument("uri")
@with_session
async def create_source(session: AsyncSession, uri: str) -> None:
    """Add a new code source."""
    service = SourceService(SourceRepository(session))
    created = await service.create(uri)
    # Report the identifier of the new source back to the user.
    click.echo(f"Source created: {created.id}")
72
+
73
+
74
# Command group; its subcommands are registered with @indexes.command.
@cli.group(name="indexes")
def indexes() -> None:
    """Manage indexes."""
77
+
78
+
79
@indexes.command(name="create")
@click.argument("source_id", type=int)
@with_session
async def create_index(session: AsyncSession, source_id: int) -> None:
    """Create an index for a source."""
    # ``type=int`` makes click convert the argument, so ``source_id`` really
    # is the int the annotation promises; non-numeric input is rejected by
    # click with a usage error before any database work starts.
    source_repository = SourceRepository(session)
    source_service = SourceService(source_repository)
    repository = IndexRepository(session)
    service = IndexService(repository, source_service)
    index = await service.create(source_id)
    click.echo(f"Index created: {index.id}")
90
+
91
+
92
@indexes.command(name="list")
@with_session
async def list_indexes(session: AsyncSession) -> None:
    """List all indexes."""
    source_service = SourceService(SourceRepository(session))
    service = IndexService(IndexRepository(session), source_service)

    column_titles = ["ID", "Created At", "Updated At", "Source URI", "Num Snippets"]

    # One table row per index, in the order the service returns them.
    rows = []
    for entry in await service.list_indexes():
        rows.append(
            [
                entry.id,
                entry.created_at,
                entry.updated_at,
                entry.source_uri,
                entry.num_snippets,
            ]
        )

    click.echo(Table(headers=column_titles, data=rows))
124
+
125
+
126
@indexes.command(name="run")
@click.argument("index_id", type=int)
@with_session
async def run_index(session: AsyncSession, index_id: int) -> None:
    """Run an index."""
    # ``type=int`` makes click convert the argument, so ``index_id`` really
    # is the int the annotation promises; non-numeric input is rejected by
    # click with a usage error before any database work starts.
    source_repository = SourceRepository(session)
    source_service = SourceService(source_repository)
    repository = IndexRepository(session)
    service = IndexService(repository, source_service)
    await service.run(index_id)
136
+
137
+
138
@cli.command(name="retrieve")
@click.argument("query")
@with_session
async def retrieve(session: AsyncSession, query: str) -> None:
    """Retrieve snippets from the database."""
    service = RetrievalService(RetrievalRepository(session))
    request = RetrievalRequest(query=query)

    # Print each hit as: URI line, content, then a blank separator line.
    for hit in await service.retrieve(request):
        click.echo(f"{hit.uri}")
        click.echo(hit.content)
        click.echo()
29
151
 
30
152
 
31
153
  @cli.command()
kodit/config.py ADDED
@@ -0,0 +1,5 @@
1
+ """Configuration for the kodit project."""
2
+
3
+ from pathlib import Path
4
+
5
# Per-user data directory: ".kodit" inside the current user's home directory.
DATA_DIR = Path.home().joinpath(".kodit")
kodit/database.py ADDED
@@ -0,0 +1,87 @@
1
+ """Database configuration for kodit."""
2
+
3
+ import asyncio
4
+ from collections.abc import AsyncGenerator, Callable
5
+ from datetime import UTC, datetime
6
+ from functools import wraps
7
+ from typing import Any, TypeVar
8
+
9
+ from alembic import command
10
+ from alembic.config import Config
11
+ from sqlalchemy import DateTime
12
+ from sqlalchemy.ext.asyncio import (
13
+ AsyncAttrs,
14
+ AsyncSession,
15
+ async_sessionmaker,
16
+ create_async_engine,
17
+ )
18
+ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
19
+
20
+ from kodit.config import DATA_DIR
21
+
22
# Async SQLAlchemy URL of the SQLite database file inside the data directory.
DB_URL = f"sqlite+aiosqlite:///{DATA_DIR}/kodit.db"

# Make sure the data directory exists (no error if it is already there).
DATA_DIR.mkdir(exist_ok=True)

# Async engine for the file-based SQLite database, with SQL echo turned off.
engine = create_async_engine(DB_URL, echo=False)

# Factory for AsyncSession objects bound to the engine; instances are not
# expired when a session commits.
async_session_factory = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
37
+
38
+
39
class Base(AsyncAttrs, DeclarativeBase):
    """Declarative base class that every kodit ORM model derives from."""
41
+
42
+
43
def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(UTC)


class CommonMixin:
    """Mixin with the columns common to all models: id and timestamps."""

    # Auto-incrementing integer primary key.
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    # Defaults to the current UTC time.
    created_at: Mapped[datetime] = mapped_column(DateTime, default=_utc_now)
    # Defaults to the current UTC time and is refreshed on update.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=_utc_now,
        onupdate=_utc_now,
    )
53
+
54
+
55
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session and close it when the consumer is finished.

    Yields:
        A new AsyncSession created by the module's session factory.

    """
    # The session's async context manager already closes it on exit, so the
    # extra try/finally with an explicit close() was redundant.
    async with async_session_factory() as session:
        yield session
62
+
63
+
64
+ T = TypeVar("T")
65
+
66
+
67
def with_session(func: Callable[..., T]) -> Callable[..., T]:
    """Provide an async session to CLI commands.

    The returned synchronous wrapper opens a session, passes it to ``func``
    as the first positional argument ahead of the caller's arguments, awaits
    the result and returns it.

    Args:
        func: Function to wrap; its result is awaited, so it is expected to
            be a coroutine function.

    Returns:
        A synchronous callable that runs ``func`` via ``asyncio.run``.

    """

    async def _invoke(*args: Any, **kwargs: Any) -> T:
        # The session lives exactly as long as this single call.
        async with async_session_factory() as session:
            return await func(session, *args, **kwargs)

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> T:
        return asyncio.run(_invoke(*args, **kwargs))

    return wrapper
79
+
80
+
81
def configure_database() -> None:
    """Configure the database by initializing it and running any pending migrations.

    The Alembic environment is located relative to this module rather than
    the current working directory, so the migrations are found wherever the
    process was started from, including when kodit is installed from a wheel.
    """
    # Imported locally: nothing else in this module needs pathlib.
    from pathlib import Path

    def _escape(value: str) -> str:
        # Option values go through ConfigParser interpolation, so a raw
        # percent sign has to be doubled.
        return value.replace("%", "%%")

    # The migration scripts ship inside the package next to this file; the
    # former "src/kodit/alembic" path only resolved from a source checkout.
    script_location = Path(__file__).parent / "alembic"

    alembic_cfg = Config()
    alembic_cfg.set_main_option("script_location", _escape(str(script_location)))
    alembic_cfg.set_main_option("sqlalchemy.url", _escape(DB_URL))
    command.upgrade(alembic_cfg, "head")
@@ -0,0 +1 @@
1
+ """Indexing package for managing code indexes and search functionality."""
@@ -0,0 +1,43 @@
1
+ """Index models for managing code indexes.
2
+
3
+ This module defines the SQLAlchemy models used for storing and managing code indexes,
4
+ including files and snippets. It provides the data structures for tracking indexed
5
+ files and their content.
6
+ """
7
+
8
+ from sqlalchemy import ForeignKey, UnicodeText
9
+ from sqlalchemy.orm import Mapped, mapped_column
10
+
11
+ from kodit.database import Base, CommonMixin
12
+
13
+
14
class Index(Base, CommonMixin):
    """ORM model for the ``indexes`` table: one index per source."""

    __tablename__ = "indexes"

    # Foreign key to sources.id; unique, so a source has at most one index.
    source_id: Mapped[int] = mapped_column(
        ForeignKey("sources.id"),
        index=True,
        unique=True,
    )

    def __init__(self, source_id: int) -> None:
        """Create an index record for the given source.

        Args:
            source_id: ID of the source this index belongs to.

        """
        super().__init__()
        self.source_id = source_id
27
+
28
+
29
class Snippet(Base, CommonMixin):
    """ORM model for the ``snippets`` table."""

    __tablename__ = "snippets"

    # Foreign key to files.id.
    file_id: Mapped[int] = mapped_column(ForeignKey("files.id"))
    # Foreign key to indexes.id.
    index_id: Mapped[int] = mapped_column(ForeignKey("indexes.id"))
    # Snippet text; defaults to the empty string.
    content: Mapped[str] = mapped_column(UnicodeText, default="")

    def __init__(self, file_id: int, index_id: int, content: str) -> None:
        """Create a snippet record.

        Args:
            file_id: ID of the file the snippet refers to.
            index_id: ID of the index the snippet refers to.
            content: Text of the snippet.

        """
        super().__init__()
        self.file_id, self.index_id, self.content = file_id, index_id, content
@@ -0,0 +1,132 @@
1
+ """Repository for managing code indexes and their associated files and snippets.
2
+
3
+ This module provides the IndexRepository class which handles all database operations
4
+ related to code indexes, including creating indexes, managing files and snippets,
5
+ and retrieving index information with their associated metadata.
6
+ """
7
+
8
+ from datetime import UTC, datetime
9
+ from typing import TypeVar
10
+
11
+ from sqlalchemy import func, select
12
+ from sqlalchemy.ext.asyncio import AsyncSession
13
+
14
+ from kodit.indexing.models import Index, Snippet
15
+ from kodit.sources.models import File, Source
16
+
17
+ T = TypeVar("T")
18
+
19
+
20
class IndexRepository:
    """Repository for managing code indexes and their associated data.

    All database access for indexes goes through the AsyncSession handed to
    the constructor: creating and looking up indexes, finding the files that
    belong to an index, and storing, counting and listing its snippets.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the index repository.

        Args:
            session: The SQLAlchemy async session to use for database operations.

        """
        self.session = session

    async def create(self, source_id: int) -> Index:
        """Create a new index for a source and commit it.

        Args:
            source_id: The ID of the source to create an index for.

        Returns:
            The newly created Index instance.

        """
        index = Index(source_id=source_id)
        self.session.add(index)
        await self.session.commit()
        return index

    async def get_by_id(self, index_id: int) -> Index | None:
        """Get an index by its ID.

        Args:
            index_id: The ID of the index to retrieve.

        Returns:
            The Index instance if found, None otherwise.

        """
        query = select(Index).where(Index.id == index_id)
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def files_for_index(self, index_id: int) -> list[File]:
        """Get all files for an index.

        Files are matched through the source they share with the index.

        Args:
            index_id: The ID of the index to get files for.

        Returns:
            A list of File instances.

        """
        query = (
            select(File)
            .join(Source, File.source_id == Source.id)
            .join(Index, Index.source_id == Source.id)
            .where(Index.id == index_id)
        )
        result = await self.session.execute(query)
        return list(result.scalars())

    async def list_indexes(self, limit: int = 10) -> list[Index]:
        """List indexes, capped at ``limit`` rows.

        Args:
            limit: Maximum number of indexes to return. Defaults to 10, the
                cap that used to be hard-coded.

        Returns:
            A list of at most ``limit`` Index instances.

        """
        query = select(Index).limit(limit)
        result = await self.session.execute(query)
        return list(result.scalars())

    async def num_snippets_for_index(self, index_id: int) -> int:
        """Get the number of snippets for an index.

        Args:
            index_id: The ID of the index whose snippets are counted.

        Returns:
            The number of snippets whose ``index_id`` matches.

        """
        query = select(func.count()).where(Snippet.index_id == index_id)
        result = await self.session.execute(query)
        return result.scalar_one()

    async def update_index_timestamp(self, index: Index) -> None:
        """Set the index's ``updated_at`` to the current UTC time and commit.

        Args:
            index: The Index instance to update.

        """
        index.updated_at = datetime.now(UTC)
        await self.session.commit()

    async def add_snippet(self, snippet: Snippet) -> None:
        """Add a new snippet to the database and commit it.

        Args:
            snippet: The Snippet instance to add.

        """
        self.session.add(snippet)
        await self.session.commit()

    async def get_snippets_for_index(self, index_id: int) -> list[Snippet]:
        """Get all snippets for an index.

        Args:
            index_id: The ID of the index to get snippets for.

        Returns:
            A list of Snippet instances belonging to the index.

        """
        query = select(Snippet).where(Snippet.index_id == index_id)
        result = await self.session.execute(query)
        return list(result.scalars())