kodit-0.4.1-py3-none-any.whl → kodit-0.4.2-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic. Click here for more details.

Files changed (42)
  1. kodit/_version.py +2 -2
  2. kodit/app.py +4 -2
  3. kodit/application/factories/code_indexing_factory.py +54 -7
  4. kodit/application/factories/reporting_factory.py +27 -0
  5. kodit/application/services/auto_indexing_service.py +16 -4
  6. kodit/application/services/code_indexing_application_service.py +115 -133
  7. kodit/application/services/indexing_worker_service.py +18 -20
  8. kodit/application/services/queue_service.py +12 -14
  9. kodit/application/services/reporting.py +86 -0
  10. kodit/application/services/sync_scheduler.py +21 -20
  11. kodit/cli.py +14 -18
  12. kodit/config.py +24 -1
  13. kodit/database.py +2 -1
  14. kodit/domain/protocols.py +9 -1
  15. kodit/domain/services/bm25_service.py +1 -6
  16. kodit/domain/services/index_service.py +22 -58
  17. kodit/domain/value_objects.py +57 -9
  18. kodit/infrastructure/api/v1/dependencies.py +23 -10
  19. kodit/infrastructure/cloning/git/working_copy.py +36 -7
  20. kodit/infrastructure/embedding/embedding_factory.py +8 -3
  21. kodit/infrastructure/embedding/embedding_providers/litellm_embedding_provider.py +48 -55
  22. kodit/infrastructure/git/git_utils.py +3 -2
  23. kodit/infrastructure/mappers/index_mapper.py +1 -0
  24. kodit/infrastructure/reporting/__init__.py +1 -0
  25. kodit/infrastructure/reporting/log_progress.py +65 -0
  26. kodit/infrastructure/reporting/tdqm_progress.py +73 -0
  27. kodit/infrastructure/sqlalchemy/embedding_repository.py +47 -68
  28. kodit/infrastructure/sqlalchemy/entities.py +28 -2
  29. kodit/infrastructure/sqlalchemy/index_repository.py +274 -236
  30. kodit/infrastructure/sqlalchemy/task_repository.py +55 -39
  31. kodit/infrastructure/sqlalchemy/unit_of_work.py +59 -0
  32. kodit/mcp.py +10 -2
  33. {kodit-0.4.1.dist-info → kodit-0.4.2.dist-info}/METADATA +1 -1
  34. {kodit-0.4.1.dist-info → kodit-0.4.2.dist-info}/RECORD +37 -36
  35. kodit/domain/interfaces.py +0 -27
  36. kodit/infrastructure/ui/__init__.py +0 -1
  37. kodit/infrastructure/ui/progress.py +0 -170
  38. kodit/infrastructure/ui/spinner.py +0 -74
  39. kodit/reporting.py +0 -78
  40. {kodit-0.4.1.dist-info → kodit-0.4.2.dist-info}/WHEEL +0 -0
  41. {kodit-0.4.1.dist-info → kodit-0.4.2.dist-info}/entry_points.txt +0 -0
  42. {kodit-0.4.1.dist-info → kodit-0.4.2.dist-info}/licenses/LICENSE +0 -0
@@ -3,6 +3,7 @@
3
3
  import tempfile
4
4
 
5
5
  import git
6
+ import git.cmd
6
7
  import structlog
7
8
 
8
9
 
@@ -19,10 +20,10 @@ def is_valid_clone_target(target: str) -> bool:
19
20
  """
20
21
  with tempfile.TemporaryDirectory() as temp_dir:
21
22
  try:
22
- git.Repo.clone_from(target, temp_dir)
23
+ git.cmd.Git(temp_dir).ls_remote(target)
23
24
  except git.GitCommandError as e:
24
25
  structlog.get_logger(__name__).warning(
25
- "Failed to clone git repository",
26
+ "Failed to list git repository",
26
27
  target=target,
27
28
  error=e,
28
29
  )
@@ -15,6 +15,7 @@ from kodit.domain.value_objects import (
15
15
  from kodit.infrastructure.sqlalchemy import entities as db_entities
16
16
 
17
17
 
18
+ # TODO(Phil): Make this a pure mapper without any DB access # noqa: TD003, FIX002
18
19
  class IndexMapper:
19
20
  """Mapper for converting between domain Index aggregate and database entities."""
20
21
 
@@ -0,0 +1 @@
1
+ """Reporting infrastructure."""
@@ -0,0 +1,65 @@
1
+ """Log progress using structlog."""
2
+
3
+ import time
4
+ from datetime import UTC, datetime
5
+
6
+ import structlog
7
+
8
+ from kodit.config import ReportingConfig
9
+ from kodit.domain.protocols import ReportingModule
10
+ from kodit.domain.value_objects import Progress, ProgressState, ReportingState
11
+
12
+
13
+ class LoggingReportingModule(ReportingModule):
14
+ """Logging reporting module."""
15
+
16
+ def __init__(self, config: ReportingConfig) -> None:
17
+ """Initialize the logging reporting module."""
18
+ self.config = config
19
+ self._log = structlog.get_logger(__name__)
20
+ self._last_log_time: datetime = datetime.now(UTC)
21
+
22
+ def on_change(self, step: Progress) -> None:
23
+ """On step changed."""
24
+ current_time = datetime.now(UTC)
25
+ time_since_last_log = current_time - self._last_log_time
26
+
27
+ if (
28
+ step.state != ReportingState.IN_PROGRESS
29
+ or time_since_last_log >= self.config.log_time_interval
30
+ ):
31
+ self._log.info(
32
+ step.name,
33
+ state=step.state,
34
+ message=step.message,
35
+ completion_percent=step.completion_percent,
36
+ )
37
+ self._last_log_time = current_time
38
+
39
+
40
+ class LogProgress(Progress):
41
+ """Log progress using structlog with time-based throttling."""
42
+
43
+ def __init__(self, config: ReportingConfig | None = None) -> None:
44
+ """Initialize the log progress."""
45
+ self.log = structlog.get_logger()
46
+ self.config = config or ReportingConfig()
47
+ self.last_log_time: float = 0
48
+
49
+ def on_update(self, state: ProgressState) -> None:
50
+ """Log the progress with time-based throttling."""
51
+ current_time = time.time()
52
+ time_since_last_log = current_time - self.last_log_time
53
+
54
+ if time_since_last_log >= self.config.log_time_interval.total_seconds():
55
+ self.log.info(
56
+ "Progress...",
57
+ operation=state.operation,
58
+ percentage=state.percentage,
59
+ message=state.message,
60
+ )
61
+ self.last_log_time = current_time
62
+
63
+ def on_complete(self) -> None:
64
+ """Log the completion."""
65
+ self.log.info("Completed")
@@ -0,0 +1,73 @@
1
+ """TQDM progress."""
2
+
3
+ from tqdm import tqdm
4
+
5
+ from kodit.config import ReportingConfig
6
+ from kodit.domain.protocols import ReportingModule
7
+ from kodit.domain.value_objects import Progress, ProgressState, ReportingState
8
+
9
+
10
+ class TQDMReportingModule(ReportingModule):
11
+ """TQDM reporting module."""
12
+
13
+ def __init__(self, config: ReportingConfig) -> None:
14
+ """Initialize the TQDM reporting module."""
15
+ self.config = config
16
+ self.pbar = tqdm()
17
+
18
+ def on_change(self, step: Progress) -> None:
19
+ """On step changed."""
20
+ if step.state == ReportingState.COMPLETED:
21
+ self.pbar.close()
22
+ return
23
+
24
+ self.pbar.set_description(step.message)
25
+ self.pbar.refresh()
26
+ # Update description if message is provided
27
+ if step.message:
28
+ # Fix the event message to a specific size so it's not jumping around
29
+ # If it's too small, add spaces
30
+ # If it's too large, truncate
31
+ if len(step.message) < 30:
32
+ self.pbar.set_description(step.message + " " * (30 - len(step.message)))
33
+ else:
34
+ self.pbar.set_description(step.message[-30:])
35
+ else:
36
+ self.pbar.set_description(step.name)
37
+
38
+
39
+ class TQDMProgress(Progress):
40
+ """TQDM-based progress callback implementation."""
41
+
42
+ def __init__(self, config: ReportingConfig | None = None) -> None:
43
+ """Initialize with a TQDM progress bar."""
44
+ self.config = config or ReportingConfig()
45
+ self.pbar = tqdm()
46
+
47
+ def on_update(self, state: ProgressState) -> None:
48
+ """Update the TQDM progress bar."""
49
+ # Update total if it changes
50
+ if state.total != self.pbar.total:
51
+ self.pbar.total = state.total
52
+
53
+ # Update the progress bar
54
+ self.pbar.n = state.current
55
+ self.pbar.refresh()
56
+
57
+ # Update description if message is provided
58
+ if state.message:
59
+ # Fix the event message to a specific size so it's not jumping around
60
+ # If it's too small, add spaces
61
+ # If it's too large, truncate
62
+ if len(state.message) < 30:
63
+ self.pbar.set_description(
64
+ state.message + " " * (30 - len(state.message))
65
+ )
66
+ else:
67
+ self.pbar.set_description(state.message[-30:])
68
+ else:
69
+ self.pbar.set_description(state.operation)
70
+
71
+ def on_complete(self) -> None:
72
+ """Complete the progress bar."""
73
+ self.pbar.close()
@@ -1,85 +1,64 @@
1
1
  """SQLAlchemy implementation of embedding repository."""
2
2
 
3
+ from collections.abc import Callable
4
+
3
5
  import numpy as np
4
6
  from sqlalchemy import select
5
7
  from sqlalchemy.ext.asyncio import AsyncSession
6
8
 
7
9
  from kodit.infrastructure.sqlalchemy.entities import Embedding, EmbeddingType
10
+ from kodit.infrastructure.sqlalchemy.unit_of_work import SqlAlchemyUnitOfWork
8
11
 
9
12
 
10
- class SqlAlchemyEmbeddingRepository:
11
- """SQLAlchemy implementation of embedding repository."""
12
-
13
- def __init__(self, session: AsyncSession) -> None:
14
- """Initialize the SQLAlchemy embedding repository.
15
-
16
- Args:
17
- session: The SQLAlchemy async session to use for database operations
18
-
19
- """
20
- self.session = session
13
+ def create_embedding_repository(
14
+ session_factory: Callable[[], AsyncSession],
15
+ ) -> "SqlAlchemyEmbeddingRepository":
16
+ """Create an embedding repository."""
17
+ uow = SqlAlchemyUnitOfWork(session_factory=session_factory)
18
+ return SqlAlchemyEmbeddingRepository(uow)
21
19
 
22
- async def create_embedding(self, embedding: Embedding) -> Embedding:
23
- """Create a new embedding record in the database.
24
20
 
25
- Args:
26
- embedding: The Embedding instance to create
21
+ class SqlAlchemyEmbeddingRepository:
22
+ """SQLAlchemy implementation of embedding repository."""
27
23
 
28
- Returns:
29
- The created Embedding instance
24
+ def __init__(self, uow: SqlAlchemyUnitOfWork) -> None:
25
+ """Initialize the SQLAlchemy embedding repository."""
26
+ self.uow = uow
30
27
 
31
- """
32
- self.session.add(embedding)
33
- return embedding
28
+ async def create_embedding(self, embedding: Embedding) -> None:
29
+ """Create a new embedding record in the database."""
30
+ async with self.uow:
31
+ self.uow.session.add(embedding)
34
32
 
35
33
  async def get_embedding_by_snippet_id_and_type(
36
34
  self, snippet_id: int, embedding_type: EmbeddingType
37
35
  ) -> Embedding | None:
38
- """Get an embedding by its snippet ID and type.
39
-
40
- Args:
41
- snippet_id: The ID of the snippet to get the embedding for
42
- embedding_type: The type of embedding to get
43
-
44
- Returns:
45
- The Embedding instance if found, None otherwise
46
-
47
- """
48
- query = select(Embedding).where(
49
- Embedding.snippet_id == snippet_id,
50
- Embedding.type == embedding_type,
51
- )
52
- result = await self.session.execute(query)
53
- return result.scalar_one_or_none()
36
+ """Get an embedding by its snippet ID and type."""
37
+ async with self.uow:
38
+ query = select(Embedding).where(
39
+ Embedding.snippet_id == snippet_id,
40
+ Embedding.type == embedding_type,
41
+ )
42
+ result = await self.uow.session.execute(query)
43
+ return result.scalar_one_or_none()
54
44
 
55
45
  async def list_embeddings_by_type(
56
46
  self, embedding_type: EmbeddingType
57
47
  ) -> list[Embedding]:
58
- """List all embeddings of a given type.
59
-
60
- Args:
61
- embedding_type: The type of embeddings to list
62
-
63
- Returns:
64
- A list of Embedding instances
65
-
66
- """
67
- query = select(Embedding).where(Embedding.type == embedding_type)
68
- result = await self.session.execute(query)
69
- return list(result.scalars())
48
+ """List all embeddings of a given type."""
49
+ async with self.uow:
50
+ query = select(Embedding).where(Embedding.type == embedding_type)
51
+ result = await self.uow.session.execute(query)
52
+ return list(result.scalars())
70
53
 
71
54
  async def delete_embeddings_by_snippet_id(self, snippet_id: int) -> None:
72
- """Delete all embeddings for a snippet.
73
-
74
- Args:
75
- snippet_id: The ID of the snippet to delete embeddings for
76
-
77
- """
78
- query = select(Embedding).where(Embedding.snippet_id == snippet_id)
79
- result = await self.session.execute(query)
80
- embeddings = result.scalars().all()
81
- for embedding in embeddings:
82
- await self.session.delete(embedding)
55
+ """Delete all embeddings for a snippet."""
56
+ async with self.uow:
57
+ query = select(Embedding).where(Embedding.snippet_id == snippet_id)
58
+ result = await self.uow.session.execute(query)
59
+ embeddings = result.scalars().all()
60
+ for embedding in embeddings:
61
+ await self.uow.session.delete(embedding)
83
62
 
84
63
  async def list_semantic_results(
85
64
  self,
@@ -130,17 +109,17 @@ class SqlAlchemyEmbeddingRepository:
130
109
  List of (snippet_id, embedding) tuples
131
110
 
132
111
  """
133
- # Only select the fields we need and use a more efficient query
134
- query = select(Embedding.snippet_id, Embedding.embedding).where(
135
- Embedding.type == embedding_type
136
- )
112
+ async with self.uow:
113
+ query = select(Embedding.snippet_id, Embedding.embedding).where(
114
+ Embedding.type == embedding_type
115
+ )
137
116
 
138
- # Add snippet_ids filter if provided
139
- if snippet_ids is not None:
140
- query = query.where(Embedding.snippet_id.in_(snippet_ids))
117
+ # Add snippet_ids filter if provided
118
+ if snippet_ids is not None:
119
+ query = query.where(Embedding.snippet_id.in_(snippet_ids))
141
120
 
142
- rows = await self.session.execute(query)
143
- return [tuple(row) for row in rows.all()] # Convert Row objects to tuples
121
+ rows = await self.uow.session.execute(query)
122
+ return [tuple(row) for row in rows.all()] # Convert Row objects to tuples
144
123
 
145
124
  def _prepare_vectors(
146
125
  self, embeddings: list[tuple[int, list[float]]], query_embedding: list[float]
@@ -2,6 +2,7 @@
2
2
 
3
3
  from datetime import UTC, datetime
4
4
  from enum import Enum
5
+ from typing import Any
5
6
 
6
7
  from git import Actor
7
8
  from sqlalchemy import (
@@ -9,6 +10,7 @@ from sqlalchemy import (
9
10
  ForeignKey,
10
11
  Integer,
11
12
  String,
13
+ TypeDecorator,
12
14
  UnicodeText,
13
15
  UniqueConstraint,
14
16
  )
@@ -18,6 +20,29 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
18
20
  from sqlalchemy.types import JSON
19
21
 
20
22
 
23
+ # See <https://docs.sqlalchemy.org/en/20/core/custom_types.html#store-timezone-aware-timestamps-as-timezone-naive-utc>
24
+ # And [this issue](https://github.com/sqlalchemy/sqlalchemy/issues/1985)
25
+ class TZDateTime(TypeDecorator):
26
+ """Timezone-aware datetime type."""
27
+
28
+ impl = DateTime
29
+ cache_ok = True
30
+
31
+ def process_bind_param(self, value: Any, dialect: Any) -> Any: # noqa: ARG002
32
+ """Process bind param."""
33
+ if value is not None:
34
+ if not value.tzinfo or value.tzinfo.utcoffset(value) is None:
35
+ raise TypeError("tzinfo is required")
36
+ value = value.astimezone(UTC).replace(tzinfo=None)
37
+ return value
38
+
39
+ def process_result_value(self, value: Any, dialect: Any) -> Any: # noqa: ARG002
40
+ """Process result value."""
41
+ if value is not None:
42
+ value = value.replace(tzinfo=UTC)
43
+ return value
44
+
45
+
21
46
  class Base(AsyncAttrs, DeclarativeBase):
22
47
  """Base class for all models."""
23
48
 
@@ -27,10 +52,11 @@ class CommonMixin:
27
52
 
28
53
  id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
29
54
  created_at: Mapped[datetime] = mapped_column(
30
- DateTime(timezone=True), default=lambda: datetime.now(UTC)
55
+ TZDateTime, nullable=False, default=lambda: datetime.now(UTC)
31
56
  )
32
57
  updated_at: Mapped[datetime] = mapped_column(
33
- DateTime(timezone=True),
58
+ TZDateTime,
59
+ nullable=False,
34
60
  default=lambda: datetime.now(UTC),
35
61
  onupdate=lambda: datetime.now(UTC),
36
62
  )