kodit-0.4.1-py3-none-any.whl → kodit-0.4.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kodit has been flagged as potentially problematic; see the package's registry page for more details.
- kodit/_version.py +2 -2
- kodit/app.py +9 -2
- kodit/application/factories/code_indexing_factory.py +62 -13
- kodit/application/factories/reporting_factory.py +32 -0
- kodit/application/services/auto_indexing_service.py +41 -33
- kodit/application/services/code_indexing_application_service.py +137 -138
- kodit/application/services/indexing_worker_service.py +26 -30
- kodit/application/services/queue_service.py +12 -14
- kodit/application/services/reporting.py +104 -0
- kodit/application/services/sync_scheduler.py +21 -20
- kodit/cli.py +71 -85
- kodit/config.py +26 -3
- kodit/database.py +2 -1
- kodit/domain/entities.py +99 -1
- kodit/domain/protocols.py +34 -1
- kodit/domain/services/bm25_service.py +1 -6
- kodit/domain/services/index_service.py +23 -57
- kodit/domain/services/task_status_query_service.py +19 -0
- kodit/domain/value_objects.py +53 -8
- kodit/infrastructure/api/v1/dependencies.py +40 -12
- kodit/infrastructure/api/v1/routers/indexes.py +45 -0
- kodit/infrastructure/api/v1/schemas/task_status.py +39 -0
- kodit/infrastructure/cloning/git/working_copy.py +43 -7
- kodit/infrastructure/embedding/embedding_factory.py +8 -3
- kodit/infrastructure/embedding/embedding_providers/litellm_embedding_provider.py +48 -55
- kodit/infrastructure/enrichment/local_enrichment_provider.py +41 -30
- kodit/infrastructure/git/git_utils.py +3 -2
- kodit/infrastructure/mappers/index_mapper.py +1 -0
- kodit/infrastructure/mappers/task_status_mapper.py +85 -0
- kodit/infrastructure/reporting/__init__.py +1 -0
- kodit/infrastructure/reporting/db_progress.py +23 -0
- kodit/infrastructure/reporting/log_progress.py +37 -0
- kodit/infrastructure/reporting/tdqm_progress.py +38 -0
- kodit/infrastructure/sqlalchemy/embedding_repository.py +47 -68
- kodit/infrastructure/sqlalchemy/entities.py +89 -2
- kodit/infrastructure/sqlalchemy/index_repository.py +274 -236
- kodit/infrastructure/sqlalchemy/task_repository.py +55 -39
- kodit/infrastructure/sqlalchemy/task_status_repository.py +79 -0
- kodit/infrastructure/sqlalchemy/unit_of_work.py +59 -0
- kodit/mcp.py +15 -3
- kodit/migrations/env.py +0 -1
- kodit/migrations/versions/b9cd1c3fd762_add_task_status.py +77 -0
- {kodit-0.4.1.dist-info → kodit-0.4.3.dist-info}/METADATA +1 -1
- {kodit-0.4.1.dist-info → kodit-0.4.3.dist-info}/RECORD +47 -40
- kodit/domain/interfaces.py +0 -27
- kodit/infrastructure/ui/__init__.py +0 -1
- kodit/infrastructure/ui/progress.py +0 -170
- kodit/infrastructure/ui/spinner.py +0 -74
- kodit/reporting.py +0 -78
- {kodit-0.4.1.dist-info → kodit-0.4.3.dist-info}/WHEEL +0 -0
- {kodit-0.4.1.dist-info → kodit-0.4.3.dist-info}/entry_points.txt +0 -0
- {kodit-0.4.1.dist-info → kodit-0.4.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -7,16 +7,15 @@ from typing import Any
|
|
|
7
7
|
import httpx
|
|
8
8
|
import litellm
|
|
9
9
|
import structlog
|
|
10
|
+
import tiktoken
|
|
10
11
|
from litellm import aembedding
|
|
11
12
|
|
|
12
13
|
from kodit.config import Endpoint
|
|
13
14
|
from kodit.domain.services.embedding_service import EmbeddingProvider
|
|
14
15
|
from kodit.domain.value_objects import EmbeddingRequest, EmbeddingResponse
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
BATCH_SIZE = 10 # Maximum number of items per API call
|
|
19
|
-
DEFAULT_NUM_PARALLEL_TASKS = 10 # Semaphore limit for concurrent requests
|
|
16
|
+
from kodit.infrastructure.embedding.embedding_providers.batching import (
|
|
17
|
+
split_sub_batches,
|
|
18
|
+
)
|
|
20
19
|
|
|
21
20
|
|
|
22
21
|
class LiteLLMEmbeddingProvider(EmbeddingProvider):
|
|
@@ -32,46 +31,36 @@ class LiteLLMEmbeddingProvider(EmbeddingProvider):
|
|
|
32
31
|
endpoint: The endpoint configuration containing all settings.
|
|
33
32
|
|
|
34
33
|
"""
|
|
35
|
-
self.
|
|
36
|
-
self.api_key = endpoint.api_key
|
|
37
|
-
self.base_url = endpoint.base_url
|
|
38
|
-
self.socket_path = endpoint.socket_path
|
|
39
|
-
self.num_parallel_tasks = (
|
|
40
|
-
endpoint.num_parallel_tasks or DEFAULT_NUM_PARALLEL_TASKS
|
|
41
|
-
)
|
|
42
|
-
self.timeout = endpoint.timeout or 30.0
|
|
43
|
-
self.extra_params = endpoint.extra_params or {}
|
|
34
|
+
self.endpoint = endpoint
|
|
44
35
|
self.log = structlog.get_logger(__name__)
|
|
36
|
+
self._encoding: tiktoken.Encoding | None = None
|
|
45
37
|
|
|
46
38
|
# Configure LiteLLM with custom HTTPX client for Unix socket support if needed
|
|
47
39
|
self._setup_litellm_client()
|
|
48
40
|
|
|
49
41
|
def _setup_litellm_client(self) -> None:
|
|
50
42
|
"""Set up LiteLLM with custom HTTPX client for Unix socket support."""
|
|
51
|
-
if self.socket_path:
|
|
43
|
+
if self.endpoint.socket_path:
|
|
52
44
|
# Create HTTPX client with Unix socket transport
|
|
53
|
-
transport = httpx.AsyncHTTPTransport(uds=self.socket_path)
|
|
45
|
+
transport = httpx.AsyncHTTPTransport(uds=self.endpoint.socket_path)
|
|
54
46
|
unix_client = httpx.AsyncClient(
|
|
55
47
|
transport=transport,
|
|
56
48
|
base_url="http://localhost", # Base URL for Unix socket
|
|
57
|
-
timeout=self.timeout,
|
|
49
|
+
timeout=self.endpoint.timeout,
|
|
58
50
|
)
|
|
59
51
|
# Set as LiteLLM's async client session
|
|
60
52
|
litellm.aclient_session = unix_client
|
|
61
53
|
|
|
62
54
|
def _split_sub_batches(
|
|
63
|
-
self, data: list[EmbeddingRequest]
|
|
55
|
+
self, encoding: tiktoken.Encoding, data: list[EmbeddingRequest]
|
|
64
56
|
) -> list[list[EmbeddingRequest]]:
|
|
65
|
-
"""
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
batch = data[i : i + BATCH_SIZE]
|
|
73
|
-
batches.append(batch)
|
|
74
|
-
return batches
|
|
57
|
+
"""Proxy to the shared batching utility (kept for backward-compat)."""
|
|
58
|
+
return split_sub_batches(
|
|
59
|
+
encoding,
|
|
60
|
+
data,
|
|
61
|
+
max_tokens=self.endpoint.max_tokens,
|
|
62
|
+
batch_size=self.endpoint.num_parallel_tasks,
|
|
63
|
+
)
|
|
75
64
|
|
|
76
65
|
async def _call_embeddings_api(self, texts: list[str]) -> Any:
|
|
77
66
|
"""Call the embeddings API using LiteLLM.
|
|
@@ -84,21 +73,21 @@ class LiteLLMEmbeddingProvider(EmbeddingProvider):
|
|
|
84
73
|
|
|
85
74
|
"""
|
|
86
75
|
kwargs = {
|
|
87
|
-
"model": self.
|
|
76
|
+
"model": self.endpoint.model,
|
|
88
77
|
"input": texts,
|
|
89
|
-
"timeout": self.timeout,
|
|
78
|
+
"timeout": self.endpoint.timeout,
|
|
90
79
|
}
|
|
91
80
|
|
|
92
81
|
# Add API key if provided
|
|
93
|
-
if self.api_key:
|
|
94
|
-
kwargs["api_key"] = self.api_key
|
|
82
|
+
if self.endpoint.api_key:
|
|
83
|
+
kwargs["api_key"] = self.endpoint.api_key
|
|
95
84
|
|
|
96
85
|
# Add base_url if provided
|
|
97
|
-
if self.base_url:
|
|
98
|
-
kwargs["api_base"] = self.base_url
|
|
86
|
+
if self.endpoint.base_url:
|
|
87
|
+
kwargs["api_base"] = self.endpoint.base_url
|
|
99
88
|
|
|
100
89
|
# Add extra parameters
|
|
101
|
-
kwargs.update(self.extra_params)
|
|
90
|
+
kwargs.update(self.endpoint.extra_params or {})
|
|
102
91
|
|
|
103
92
|
try:
|
|
104
93
|
# Use litellm's async embedding function
|
|
@@ -108,7 +97,7 @@ class LiteLLMEmbeddingProvider(EmbeddingProvider):
|
|
|
108
97
|
)
|
|
109
98
|
except Exception as e:
|
|
110
99
|
self.log.exception(
|
|
111
|
-
"LiteLLM embedding API error", error=str(e), model=self.
|
|
100
|
+
"LiteLLM embedding API error", error=str(e), model=self.endpoint.model
|
|
112
101
|
)
|
|
113
102
|
raise
|
|
114
103
|
|
|
@@ -121,32 +110,28 @@ class LiteLLMEmbeddingProvider(EmbeddingProvider):
|
|
|
121
110
|
return
|
|
122
111
|
|
|
123
112
|
# Split into batches
|
|
124
|
-
|
|
113
|
+
encoding = self._get_encoding()
|
|
114
|
+
batched_data = self._split_sub_batches(encoding, data)
|
|
125
115
|
|
|
126
116
|
# Process batches concurrently with semaphore
|
|
127
|
-
sem = asyncio.Semaphore(self.num_parallel_tasks)
|
|
117
|
+
sem = asyncio.Semaphore(self.endpoint.num_parallel_tasks or 10)
|
|
128
118
|
|
|
129
119
|
async def _process_batch(
|
|
130
120
|
batch: list[EmbeddingRequest],
|
|
131
121
|
) -> list[EmbeddingResponse]:
|
|
132
122
|
async with sem:
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
123
|
+
response = await self._call_embeddings_api(
|
|
124
|
+
[item.text for item in batch]
|
|
125
|
+
)
|
|
126
|
+
embeddings_data = response.get("data", [])
|
|
127
|
+
|
|
128
|
+
return [
|
|
129
|
+
EmbeddingResponse(
|
|
130
|
+
snippet_id=item.snippet_id,
|
|
131
|
+
embedding=emb_data.get("embedding", []),
|
|
136
132
|
)
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
return [
|
|
140
|
-
EmbeddingResponse(
|
|
141
|
-
snippet_id=item.snippet_id,
|
|
142
|
-
embedding=emb_data.get("embedding", []),
|
|
143
|
-
)
|
|
144
|
-
for item, emb_data in zip(batch, embeddings_data, strict=True)
|
|
145
|
-
]
|
|
146
|
-
except Exception as e:
|
|
147
|
-
self.log.exception("Error embedding batch", error=str(e))
|
|
148
|
-
# Return no embeddings for this batch if there was an error
|
|
149
|
-
return []
|
|
133
|
+
for item, emb_data in zip(batch, embeddings_data, strict=True)
|
|
134
|
+
]
|
|
150
135
|
|
|
151
136
|
tasks = [_process_batch(batch) for batch in batched_data]
|
|
152
137
|
for task in asyncio.as_completed(tasks):
|
|
@@ -155,9 +140,17 @@ class LiteLLMEmbeddingProvider(EmbeddingProvider):
|
|
|
155
140
|
async def close(self) -> None:
|
|
156
141
|
"""Close the provider and cleanup HTTPX client if using Unix sockets."""
|
|
157
142
|
if (
|
|
158
|
-
self.socket_path
|
|
143
|
+
self.endpoint.socket_path
|
|
159
144
|
and hasattr(litellm, "aclient_session")
|
|
160
145
|
and litellm.aclient_session
|
|
161
146
|
):
|
|
162
147
|
await litellm.aclient_session.aclose()
|
|
163
148
|
litellm.aclient_session = None
|
|
149
|
+
|
|
150
|
+
def _get_encoding(self) -> tiktoken.Encoding:
|
|
151
|
+
"""Return (and cache) the tiktoken encoding for the chosen model."""
|
|
152
|
+
if self._encoding is None:
|
|
153
|
+
self._encoding = tiktoken.get_encoding(
|
|
154
|
+
"o200k_base"
|
|
155
|
+
) # Reasonable default for most models, but might not be perfect.
|
|
156
|
+
return self._encoding
|
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
"""Local enrichment provider implementation."""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import os
|
|
4
5
|
from collections.abc import AsyncGenerator
|
|
6
|
+
from typing import Any
|
|
5
7
|
|
|
6
8
|
import structlog
|
|
7
9
|
import tiktoken
|
|
@@ -60,23 +62,26 @@ class LocalEnrichmentProvider(EnrichmentProvider):
|
|
|
60
62
|
self.log.warning("No valid requests for enrichment")
|
|
61
63
|
return
|
|
62
64
|
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
|
67
|
-
|
|
68
|
-
if self.tokenizer is None:
|
|
69
|
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
70
|
-
self.model_name, padding_side="left"
|
|
71
|
-
)
|
|
72
|
-
if self.model is None:
|
|
73
|
-
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Avoid warnings
|
|
74
|
-
self.model = AutoModelForCausalLM.from_pretrained(
|
|
75
|
-
self.model_name,
|
|
76
|
-
torch_dtype="auto",
|
|
77
|
-
trust_remote_code=True,
|
|
78
|
-
device_map="auto",
|
|
65
|
+
def _init_model() -> None:
|
|
66
|
+
from transformers.models.auto.modeling_auto import (
|
|
67
|
+
AutoModelForCausalLM,
|
|
79
68
|
)
|
|
69
|
+
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
|
70
|
+
|
|
71
|
+
if self.tokenizer is None:
|
|
72
|
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
73
|
+
self.model_name, padding_side="left"
|
|
74
|
+
)
|
|
75
|
+
if self.model is None:
|
|
76
|
+
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Avoid warnings
|
|
77
|
+
self.model = AutoModelForCausalLM.from_pretrained(
|
|
78
|
+
self.model_name,
|
|
79
|
+
torch_dtype="auto",
|
|
80
|
+
trust_remote_code=True,
|
|
81
|
+
device_map="auto",
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
await asyncio.to_thread(_init_model)
|
|
80
85
|
|
|
81
86
|
# Prepare prompts
|
|
82
87
|
prompts = [
|
|
@@ -96,20 +101,26 @@ class LocalEnrichmentProvider(EnrichmentProvider):
|
|
|
96
101
|
]
|
|
97
102
|
|
|
98
103
|
for prompt in prompts:
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
104
|
+
|
|
105
|
+
def process_prompt(prompt: dict[str, Any]) -> str:
|
|
106
|
+
model_inputs = self.tokenizer( # type: ignore[misc]
|
|
107
|
+
prompt["text"],
|
|
108
|
+
return_tensors="pt",
|
|
109
|
+
padding=True,
|
|
110
|
+
truncation=True,
|
|
111
|
+
).to(self.model.device) # type: ignore[attr-defined]
|
|
112
|
+
generated_ids = self.model.generate( # type: ignore[attr-defined]
|
|
113
|
+
**model_inputs, max_new_tokens=self.context_window
|
|
114
|
+
)
|
|
115
|
+
input_ids = model_inputs["input_ids"][0]
|
|
116
|
+
output_ids = generated_ids[0][len(input_ids) :].tolist()
|
|
117
|
+
return self.tokenizer.decode( # type: ignore[attr-defined]
|
|
118
|
+
output_ids, skip_special_tokens=True
|
|
119
|
+
).strip( # type: ignore[attr-defined]
|
|
120
|
+
"\n"
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
content = await asyncio.to_thread(process_prompt, prompt)
|
|
113
124
|
# Remove thinking tags from the response
|
|
114
125
|
cleaned_content = clean_thinking_tags(content)
|
|
115
126
|
yield EnrichmentResponse(
|
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
import tempfile
|
|
4
4
|
|
|
5
5
|
import git
|
|
6
|
+
import git.cmd
|
|
6
7
|
import structlog
|
|
7
8
|
|
|
8
9
|
|
|
@@ -19,10 +20,10 @@ def is_valid_clone_target(target: str) -> bool:
|
|
|
19
20
|
"""
|
|
20
21
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
21
22
|
try:
|
|
22
|
-
git.
|
|
23
|
+
git.cmd.Git(temp_dir).ls_remote(target)
|
|
23
24
|
except git.GitCommandError as e:
|
|
24
25
|
structlog.get_logger(__name__).warning(
|
|
25
|
-
"Failed to
|
|
26
|
+
"Failed to list git repository",
|
|
26
27
|
target=target,
|
|
27
28
|
error=e,
|
|
28
29
|
)
|
|
@@ -15,6 +15,7 @@ from kodit.domain.value_objects import (
|
|
|
15
15
|
from kodit.infrastructure.sqlalchemy import entities as db_entities
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
# TODO(Phil): Make this a pure mapper without any DB access # noqa: TD003, FIX002
|
|
18
19
|
class IndexMapper:
|
|
19
20
|
"""Mapper for converting between domain Index aggregate and database entities."""
|
|
20
21
|
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""Task status mapper."""
|
|
2
|
+
|
|
3
|
+
from kodit.domain import entities as domain_entities
|
|
4
|
+
from kodit.domain.value_objects import ReportingState, TaskOperation, TrackableType
|
|
5
|
+
from kodit.infrastructure.sqlalchemy import entities as db_entities
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TaskStatusMapper:
|
|
9
|
+
"""Mapper for converting between domain TaskStatus and database entities."""
|
|
10
|
+
|
|
11
|
+
@staticmethod
|
|
12
|
+
def from_domain_task_status(
|
|
13
|
+
task_status: domain_entities.TaskStatus,
|
|
14
|
+
) -> db_entities.TaskStatus:
|
|
15
|
+
"""Convert domain TaskStatus to database TaskStatus."""
|
|
16
|
+
return db_entities.TaskStatus(
|
|
17
|
+
id=task_status.id,
|
|
18
|
+
operation=task_status.operation,
|
|
19
|
+
created_at=task_status.created_at,
|
|
20
|
+
updated_at=task_status.updated_at,
|
|
21
|
+
trackable_id=task_status.trackable_id,
|
|
22
|
+
trackable_type=(
|
|
23
|
+
task_status.trackable_type.value if task_status.trackable_type else None
|
|
24
|
+
),
|
|
25
|
+
parent=task_status.parent.id if task_status.parent else None,
|
|
26
|
+
state=(
|
|
27
|
+
task_status.state.value
|
|
28
|
+
if isinstance(task_status.state, ReportingState)
|
|
29
|
+
else task_status.state
|
|
30
|
+
),
|
|
31
|
+
error=task_status.error,
|
|
32
|
+
total=task_status.total,
|
|
33
|
+
current=task_status.current,
|
|
34
|
+
message=task_status.message,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
@staticmethod
|
|
38
|
+
def to_domain_task_status(
|
|
39
|
+
db_status: db_entities.TaskStatus,
|
|
40
|
+
) -> domain_entities.TaskStatus:
|
|
41
|
+
"""Convert database TaskStatus to domain TaskStatus."""
|
|
42
|
+
return domain_entities.TaskStatus(
|
|
43
|
+
id=db_status.id,
|
|
44
|
+
operation=TaskOperation(db_status.operation),
|
|
45
|
+
state=ReportingState(db_status.state),
|
|
46
|
+
created_at=db_status.created_at,
|
|
47
|
+
updated_at=db_status.updated_at,
|
|
48
|
+
trackable_id=db_status.trackable_id,
|
|
49
|
+
trackable_type=(
|
|
50
|
+
TrackableType(db_status.trackable_type)
|
|
51
|
+
if db_status.trackable_type
|
|
52
|
+
else None
|
|
53
|
+
),
|
|
54
|
+
parent=None, # Parent relationships need to be reconstructed separately
|
|
55
|
+
error=db_status.error if db_status.error else None,
|
|
56
|
+
total=db_status.total,
|
|
57
|
+
current=db_status.current,
|
|
58
|
+
message=db_status.message,
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
@staticmethod
|
|
62
|
+
def to_domain_task_status_with_hierarchy(
|
|
63
|
+
db_statuses: list[db_entities.TaskStatus],
|
|
64
|
+
) -> list[domain_entities.TaskStatus]:
|
|
65
|
+
"""Convert database TaskStatus list to domain with parent-child hierarchy.
|
|
66
|
+
|
|
67
|
+
This method performs a two-pass conversion:
|
|
68
|
+
1. First pass: Convert all DB entities to domain entities
|
|
69
|
+
2. Second pass: Reconstruct parent-child relationships using ID mapping
|
|
70
|
+
"""
|
|
71
|
+
# First pass: Convert all database entities to domain entities
|
|
72
|
+
domain_statuses = [
|
|
73
|
+
TaskStatusMapper.to_domain_task_status(db_status)
|
|
74
|
+
for db_status in db_statuses
|
|
75
|
+
]
|
|
76
|
+
|
|
77
|
+
# Create ID-to-entity mapping for efficient parent lookup
|
|
78
|
+
id_to_entity = {status.id: status for status in domain_statuses}
|
|
79
|
+
|
|
80
|
+
# Second pass: Reconstruct parent-child relationships
|
|
81
|
+
for db_status, domain_status in zip(db_statuses, domain_statuses, strict=True):
|
|
82
|
+
if db_status.parent and db_status.parent in id_to_entity:
|
|
83
|
+
domain_status.parent = id_to_entity[db_status.parent]
|
|
84
|
+
|
|
85
|
+
return domain_statuses
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Reporting infrastructure."""
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""Log progress using structlog."""
|
|
2
|
+
|
|
3
|
+
import structlog
|
|
4
|
+
|
|
5
|
+
from kodit.config import ReportingConfig
|
|
6
|
+
from kodit.domain.entities import TaskStatus
|
|
7
|
+
from kodit.domain.protocols import ReportingModule, TaskStatusRepository
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class DBProgressReportingModule(ReportingModule):
|
|
11
|
+
"""Database progress reporting module."""
|
|
12
|
+
|
|
13
|
+
def __init__(
|
|
14
|
+
self, task_status_repository: TaskStatusRepository, config: ReportingConfig
|
|
15
|
+
) -> None:
|
|
16
|
+
"""Initialize the database progress reporting module."""
|
|
17
|
+
self.task_status_repository = task_status_repository
|
|
18
|
+
self.config = config
|
|
19
|
+
self._log = structlog.get_logger(__name__)
|
|
20
|
+
|
|
21
|
+
async def on_change(self, progress: TaskStatus) -> None:
|
|
22
|
+
"""On step changed - update task status in database."""
|
|
23
|
+
await self.task_status_repository.save(progress)
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""Log progress using structlog."""
|
|
2
|
+
|
|
3
|
+
from datetime import UTC, datetime
|
|
4
|
+
|
|
5
|
+
import structlog
|
|
6
|
+
|
|
7
|
+
from kodit.config import ReportingConfig
|
|
8
|
+
from kodit.domain.entities import TaskStatus
|
|
9
|
+
from kodit.domain.protocols import ReportingModule
|
|
10
|
+
from kodit.domain.value_objects import ReportingState
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class LoggingReportingModule(ReportingModule):
|
|
14
|
+
"""Logging reporting module."""
|
|
15
|
+
|
|
16
|
+
def __init__(self, config: ReportingConfig) -> None:
|
|
17
|
+
"""Initialize the logging reporting module."""
|
|
18
|
+
self.config = config
|
|
19
|
+
self._log = structlog.get_logger(__name__)
|
|
20
|
+
self._last_log_time: datetime = datetime.now(UTC)
|
|
21
|
+
|
|
22
|
+
async def on_change(self, progress: TaskStatus) -> None:
|
|
23
|
+
"""On step changed."""
|
|
24
|
+
current_time = datetime.now(UTC)
|
|
25
|
+
time_since_last_log = current_time - self._last_log_time
|
|
26
|
+
step = progress
|
|
27
|
+
|
|
28
|
+
if (
|
|
29
|
+
step.state != ReportingState.IN_PROGRESS
|
|
30
|
+
or time_since_last_log >= self.config.log_time_interval
|
|
31
|
+
):
|
|
32
|
+
self._log.info(
|
|
33
|
+
step.operation,
|
|
34
|
+
state=step.state,
|
|
35
|
+
completion_percent=step.completion_percent,
|
|
36
|
+
)
|
|
37
|
+
self._last_log_time = current_time
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""TQDM progress."""
|
|
2
|
+
|
|
3
|
+
from tqdm import tqdm
|
|
4
|
+
|
|
5
|
+
from kodit.config import ReportingConfig
|
|
6
|
+
from kodit.domain.entities import TaskStatus
|
|
7
|
+
from kodit.domain.protocols import ReportingModule
|
|
8
|
+
from kodit.domain.value_objects import ReportingState
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TQDMReportingModule(ReportingModule):
|
|
12
|
+
"""TQDM reporting module."""
|
|
13
|
+
|
|
14
|
+
def __init__(self, config: ReportingConfig) -> None:
|
|
15
|
+
"""Initialize the TQDM reporting module."""
|
|
16
|
+
self.config = config
|
|
17
|
+
self.pbar = tqdm()
|
|
18
|
+
|
|
19
|
+
async def on_change(self, progress: TaskStatus) -> None:
|
|
20
|
+
"""On step changed."""
|
|
21
|
+
step = progress
|
|
22
|
+
if step.state == ReportingState.COMPLETED:
|
|
23
|
+
self.pbar.close()
|
|
24
|
+
return
|
|
25
|
+
|
|
26
|
+
self.pbar.set_description(step.operation)
|
|
27
|
+
self.pbar.refresh()
|
|
28
|
+
# Update description if message is provided
|
|
29
|
+
if step.error:
|
|
30
|
+
# Fix the event message to a specific size so it's not jumping around
|
|
31
|
+
# If it's too small, add spaces
|
|
32
|
+
# If it's too large, truncate
|
|
33
|
+
if len(step.error) < 30:
|
|
34
|
+
self.pbar.set_description(step.error + " " * (30 - len(step.error)))
|
|
35
|
+
else:
|
|
36
|
+
self.pbar.set_description(step.error[-30:])
|
|
37
|
+
else:
|
|
38
|
+
self.pbar.set_description(step.operation)
|
|
@@ -1,85 +1,64 @@
|
|
|
1
1
|
"""SQLAlchemy implementation of embedding repository."""
|
|
2
2
|
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
|
|
3
5
|
import numpy as np
|
|
4
6
|
from sqlalchemy import select
|
|
5
7
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
6
8
|
|
|
7
9
|
from kodit.infrastructure.sqlalchemy.entities import Embedding, EmbeddingType
|
|
10
|
+
from kodit.infrastructure.sqlalchemy.unit_of_work import SqlAlchemyUnitOfWork
|
|
8
11
|
|
|
9
12
|
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
Args:
|
|
17
|
-
session: The SQLAlchemy async session to use for database operations
|
|
18
|
-
|
|
19
|
-
"""
|
|
20
|
-
self.session = session
|
|
13
|
+
def create_embedding_repository(
|
|
14
|
+
session_factory: Callable[[], AsyncSession],
|
|
15
|
+
) -> "SqlAlchemyEmbeddingRepository":
|
|
16
|
+
"""Create an embedding repository."""
|
|
17
|
+
uow = SqlAlchemyUnitOfWork(session_factory=session_factory)
|
|
18
|
+
return SqlAlchemyEmbeddingRepository(uow)
|
|
21
19
|
|
|
22
|
-
async def create_embedding(self, embedding: Embedding) -> Embedding:
|
|
23
|
-
"""Create a new embedding record in the database.
|
|
24
20
|
|
|
25
|
-
|
|
26
|
-
|
|
21
|
+
class SqlAlchemyEmbeddingRepository:
|
|
22
|
+
"""SQLAlchemy implementation of embedding repository."""
|
|
27
23
|
|
|
28
|
-
|
|
29
|
-
|
|
24
|
+
def __init__(self, uow: SqlAlchemyUnitOfWork) -> None:
|
|
25
|
+
"""Initialize the SQLAlchemy embedding repository."""
|
|
26
|
+
self.uow = uow
|
|
30
27
|
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
28
|
+
async def create_embedding(self, embedding: Embedding) -> None:
|
|
29
|
+
"""Create a new embedding record in the database."""
|
|
30
|
+
async with self.uow:
|
|
31
|
+
self.uow.session.add(embedding)
|
|
34
32
|
|
|
35
33
|
async def get_embedding_by_snippet_id_and_type(
|
|
36
34
|
self, snippet_id: int, embedding_type: EmbeddingType
|
|
37
35
|
) -> Embedding | None:
|
|
38
|
-
"""Get an embedding by its snippet ID and type.
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
"""
|
|
48
|
-
query = select(Embedding).where(
|
|
49
|
-
Embedding.snippet_id == snippet_id,
|
|
50
|
-
Embedding.type == embedding_type,
|
|
51
|
-
)
|
|
52
|
-
result = await self.session.execute(query)
|
|
53
|
-
return result.scalar_one_or_none()
|
|
36
|
+
"""Get an embedding by its snippet ID and type."""
|
|
37
|
+
async with self.uow:
|
|
38
|
+
query = select(Embedding).where(
|
|
39
|
+
Embedding.snippet_id == snippet_id,
|
|
40
|
+
Embedding.type == embedding_type,
|
|
41
|
+
)
|
|
42
|
+
result = await self.uow.session.execute(query)
|
|
43
|
+
return result.scalar_one_or_none()
|
|
54
44
|
|
|
55
45
|
async def list_embeddings_by_type(
|
|
56
46
|
self, embedding_type: EmbeddingType
|
|
57
47
|
) -> list[Embedding]:
|
|
58
|
-
"""List all embeddings of a given type.
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
Returns:
|
|
64
|
-
A list of Embedding instances
|
|
65
|
-
|
|
66
|
-
"""
|
|
67
|
-
query = select(Embedding).where(Embedding.type == embedding_type)
|
|
68
|
-
result = await self.session.execute(query)
|
|
69
|
-
return list(result.scalars())
|
|
48
|
+
"""List all embeddings of a given type."""
|
|
49
|
+
async with self.uow:
|
|
50
|
+
query = select(Embedding).where(Embedding.type == embedding_type)
|
|
51
|
+
result = await self.uow.session.execute(query)
|
|
52
|
+
return list(result.scalars())
|
|
70
53
|
|
|
71
54
|
async def delete_embeddings_by_snippet_id(self, snippet_id: int) -> None:
|
|
72
|
-
"""Delete all embeddings for a snippet.
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
result = await self.session.execute(query)
|
|
80
|
-
embeddings = result.scalars().all()
|
|
81
|
-
for embedding in embeddings:
|
|
82
|
-
await self.session.delete(embedding)
|
|
55
|
+
"""Delete all embeddings for a snippet."""
|
|
56
|
+
async with self.uow:
|
|
57
|
+
query = select(Embedding).where(Embedding.snippet_id == snippet_id)
|
|
58
|
+
result = await self.uow.session.execute(query)
|
|
59
|
+
embeddings = result.scalars().all()
|
|
60
|
+
for embedding in embeddings:
|
|
61
|
+
await self.uow.session.delete(embedding)
|
|
83
62
|
|
|
84
63
|
async def list_semantic_results(
|
|
85
64
|
self,
|
|
@@ -130,17 +109,17 @@ class SqlAlchemyEmbeddingRepository:
|
|
|
130
109
|
List of (snippet_id, embedding) tuples
|
|
131
110
|
|
|
132
111
|
"""
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
112
|
+
async with self.uow:
|
|
113
|
+
query = select(Embedding.snippet_id, Embedding.embedding).where(
|
|
114
|
+
Embedding.type == embedding_type
|
|
115
|
+
)
|
|
137
116
|
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
117
|
+
# Add snippet_ids filter if provided
|
|
118
|
+
if snippet_ids is not None:
|
|
119
|
+
query = query.where(Embedding.snippet_id.in_(snippet_ids))
|
|
141
120
|
|
|
142
|
-
|
|
143
|
-
|
|
121
|
+
rows = await self.uow.session.execute(query)
|
|
122
|
+
return [tuple(row) for row in rows.all()] # Convert Row objects to tuples
|
|
144
123
|
|
|
145
124
|
def _prepare_vectors(
|
|
146
125
|
self, embeddings: list[tuple[int, list[float]]], query_embedding: list[float]
|