kodit 0.4.2__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kodit might be problematic. Click here for more details.
- kodit/_version.py +2 -2
- kodit/app.py +6 -1
- kodit/application/factories/code_indexing_factory.py +14 -12
- kodit/application/factories/reporting_factory.py +10 -5
- kodit/application/services/auto_indexing_service.py +28 -32
- kodit/application/services/code_indexing_application_service.py +43 -26
- kodit/application/services/indexing_worker_service.py +10 -12
- kodit/application/services/reporting.py +72 -54
- kodit/cli.py +68 -78
- kodit/config.py +2 -2
- kodit/domain/entities.py +99 -1
- kodit/domain/protocols.py +28 -3
- kodit/domain/services/index_service.py +11 -9
- kodit/domain/services/task_status_query_service.py +19 -0
- kodit/domain/value_objects.py +26 -29
- kodit/infrastructure/api/v1/dependencies.py +19 -4
- kodit/infrastructure/api/v1/routers/indexes.py +45 -0
- kodit/infrastructure/api/v1/schemas/task_status.py +39 -0
- kodit/infrastructure/cloning/git/working_copy.py +9 -2
- kodit/infrastructure/enrichment/local_enrichment_provider.py +41 -30
- kodit/infrastructure/mappers/task_status_mapper.py +85 -0
- kodit/infrastructure/reporting/db_progress.py +23 -0
- kodit/infrastructure/reporting/log_progress.py +5 -33
- kodit/infrastructure/reporting/tdqm_progress.py +10 -45
- kodit/infrastructure/sqlalchemy/entities.py +61 -0
- kodit/infrastructure/sqlalchemy/task_status_repository.py +79 -0
- kodit/mcp.py +6 -2
- kodit/migrations/env.py +0 -1
- kodit/migrations/versions/b9cd1c3fd762_add_task_status.py +77 -0
- {kodit-0.4.2.dist-info → kodit-0.4.3.dist-info}/METADATA +1 -1
- {kodit-0.4.2.dist-info → kodit-0.4.3.dist-info}/RECORD +34 -28
- {kodit-0.4.2.dist-info → kodit-0.4.3.dist-info}/WHEEL +0 -0
- {kodit-0.4.2.dist-info → kodit-0.4.3.dist-info}/entry_points.txt +0 -0
- {kodit-0.4.2.dist-info → kodit-0.4.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Working copy provider for git-based sources."""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import hashlib
|
|
4
5
|
import shutil
|
|
5
6
|
from pathlib import Path
|
|
@@ -39,7 +40,7 @@ class GitWorkingCopyProvider:
|
|
|
39
40
|
clone_path.mkdir(parents=True, exist_ok=True)
|
|
40
41
|
|
|
41
42
|
step_record = []
|
|
42
|
-
step.set_total(12)
|
|
43
|
+
await step.set_total(12)
|
|
43
44
|
|
|
44
45
|
def _clone_progress_callback(
|
|
45
46
|
a: int, _: str | float | None, __: str | float | None, _d: str
|
|
@@ -49,7 +50,13 @@ class GitWorkingCopyProvider:
|
|
|
49
50
|
|
|
50
51
|
# Git reports a really weird format. This is a quick hack to get some
|
|
51
52
|
# progress.
|
|
52
|
-
|
|
53
|
+
# Normally this would fail because the loop is already running,
|
|
54
|
+
# but in this case, this callback is called by some git sub-thread.
|
|
55
|
+
asyncio.run(
|
|
56
|
+
step.set_current(
|
|
57
|
+
len(step_record), f"Cloning repository ({step_record[-1]})"
|
|
58
|
+
)
|
|
59
|
+
)
|
|
53
60
|
|
|
54
61
|
try:
|
|
55
62
|
self.log.info(
|
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
"""Local enrichment provider implementation."""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import os
|
|
4
5
|
from collections.abc import AsyncGenerator
|
|
6
|
+
from typing import Any
|
|
5
7
|
|
|
6
8
|
import structlog
|
|
7
9
|
import tiktoken
|
|
@@ -60,23 +62,26 @@ class LocalEnrichmentProvider(EnrichmentProvider):
|
|
|
60
62
|
self.log.warning("No valid requests for enrichment")
|
|
61
63
|
return
|
|
62
64
|
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
|
67
|
-
|
|
68
|
-
if self.tokenizer is None:
|
|
69
|
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
70
|
-
self.model_name, padding_side="left"
|
|
71
|
-
)
|
|
72
|
-
if self.model is None:
|
|
73
|
-
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Avoid warnings
|
|
74
|
-
self.model = AutoModelForCausalLM.from_pretrained(
|
|
75
|
-
self.model_name,
|
|
76
|
-
torch_dtype="auto",
|
|
77
|
-
trust_remote_code=True,
|
|
78
|
-
device_map="auto",
|
|
65
|
+
def _init_model() -> None:
|
|
66
|
+
from transformers.models.auto.modeling_auto import (
|
|
67
|
+
AutoModelForCausalLM,
|
|
79
68
|
)
|
|
69
|
+
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
|
70
|
+
|
|
71
|
+
if self.tokenizer is None:
|
|
72
|
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
73
|
+
self.model_name, padding_side="left"
|
|
74
|
+
)
|
|
75
|
+
if self.model is None:
|
|
76
|
+
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Avoid warnings
|
|
77
|
+
self.model = AutoModelForCausalLM.from_pretrained(
|
|
78
|
+
self.model_name,
|
|
79
|
+
torch_dtype="auto",
|
|
80
|
+
trust_remote_code=True,
|
|
81
|
+
device_map="auto",
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
await asyncio.to_thread(_init_model)
|
|
80
85
|
|
|
81
86
|
# Prepare prompts
|
|
82
87
|
prompts = [
|
|
@@ -96,20 +101,26 @@ class LocalEnrichmentProvider(EnrichmentProvider):
|
|
|
96
101
|
]
|
|
97
102
|
|
|
98
103
|
for prompt in prompts:
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
104
|
+
|
|
105
|
+
def process_prompt(prompt: dict[str, Any]) -> str:
|
|
106
|
+
model_inputs = self.tokenizer( # type: ignore[misc]
|
|
107
|
+
prompt["text"],
|
|
108
|
+
return_tensors="pt",
|
|
109
|
+
padding=True,
|
|
110
|
+
truncation=True,
|
|
111
|
+
).to(self.model.device) # type: ignore[attr-defined]
|
|
112
|
+
generated_ids = self.model.generate( # type: ignore[attr-defined]
|
|
113
|
+
**model_inputs, max_new_tokens=self.context_window
|
|
114
|
+
)
|
|
115
|
+
input_ids = model_inputs["input_ids"][0]
|
|
116
|
+
output_ids = generated_ids[0][len(input_ids) :].tolist()
|
|
117
|
+
return self.tokenizer.decode( # type: ignore[attr-defined]
|
|
118
|
+
output_ids, skip_special_tokens=True
|
|
119
|
+
).strip( # type: ignore[attr-defined]
|
|
120
|
+
"\n"
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
content = await asyncio.to_thread(process_prompt, prompt)
|
|
113
124
|
# Remove thinking tags from the response
|
|
114
125
|
cleaned_content = clean_thinking_tags(content)
|
|
115
126
|
yield EnrichmentResponse(
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""Task status mapper."""
|
|
2
|
+
|
|
3
|
+
from kodit.domain import entities as domain_entities
|
|
4
|
+
from kodit.domain.value_objects import ReportingState, TaskOperation, TrackableType
|
|
5
|
+
from kodit.infrastructure.sqlalchemy import entities as db_entities
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TaskStatusMapper:
    """Translate task-status records between the domain and persistence layers.

    The domain object carries enums and a direct reference to its parent
    status; the database row stores plain strings and the parent's id.
    """

    @staticmethod
    def from_domain_task_status(
        task_status: domain_entities.TaskStatus,
    ) -> db_entities.TaskStatus:
        """Build a database ``TaskStatus`` row from a domain ``TaskStatus``.

        ``trackable_type`` is stored as its enum ``value`` (or ``None``), the
        parent is stored by its id, and ``state`` is stored as its enum
        ``value`` when it is a ``ReportingState`` (otherwise passed through).
        """
        trackable = task_status.trackable_type
        parent_status = task_status.parent
        raw_state = task_status.state
        if isinstance(raw_state, ReportingState):
            raw_state = raw_state.value

        return db_entities.TaskStatus(
            id=task_status.id,
            operation=task_status.operation,
            created_at=task_status.created_at,
            updated_at=task_status.updated_at,
            trackable_id=task_status.trackable_id,
            trackable_type=trackable.value if trackable else None,
            parent=parent_status.id if parent_status else None,
            state=raw_state,
            error=task_status.error,
            total=task_status.total,
            current=task_status.current,
            message=task_status.message,
        )

    @staticmethod
    def to_domain_task_status(
        db_status: db_entities.TaskStatus,
    ) -> domain_entities.TaskStatus:
        """Build a domain ``TaskStatus`` from a single database row.

        The parent link is always left as ``None`` here; callers that need it
        should use ``to_domain_task_status_with_hierarchy``. An empty stored
        error string is mapped to ``None``.
        """
        stored_type = db_status.trackable_type
        return domain_entities.TaskStatus(
            id=db_status.id,
            operation=TaskOperation(db_status.operation),
            state=ReportingState(db_status.state),
            created_at=db_status.created_at,
            updated_at=db_status.updated_at,
            trackable_id=db_status.trackable_id,
            trackable_type=TrackableType(stored_type) if stored_type else None,
            parent=None,  # rebuilt separately from the parent id
            error=db_status.error or None,
            total=db_status.total,
            current=db_status.current,
            message=db_status.message,
        )

    @staticmethod
    def to_domain_task_status_with_hierarchy(
        db_statuses: list[db_entities.TaskStatus],
    ) -> list[domain_entities.TaskStatus]:
        """Convert database rows to domain objects and relink parents.

        Every row is converted first; then each domain status whose row
        names a parent id present in the same batch gets that parent object
        attached. Parents outside the batch are left as ``None``.
        """
        converted = [
            TaskStatusMapper.to_domain_task_status(row) for row in db_statuses
        ]
        by_id = {status.id: status for status in converted}

        for row, status in zip(db_statuses, converted, strict=True):
            parent_id = row.parent
            if not parent_id:
                continue
            parent = by_id.get(parent_id)
            if parent is not None:
                status.parent = parent

        return converted
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""Log progress using structlog."""
|
|
2
|
+
|
|
3
|
+
import structlog
|
|
4
|
+
|
|
5
|
+
from kodit.config import ReportingConfig
|
|
6
|
+
from kodit.domain.entities import TaskStatus
|
|
7
|
+
from kodit.domain.protocols import ReportingModule, TaskStatusRepository
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class DBProgressReportingModule(ReportingModule):
    """Reporting module that persists each progress change as a task status."""

    def __init__(
        self, task_status_repository: TaskStatusRepository, config: ReportingConfig
    ) -> None:
        """Keep the repository used for persistence plus the reporting config.

        Args:
            task_status_repository: Repository whose ``save`` stores each update.
            config: Reporting configuration (stored; not read by this module).

        """
        self.task_status_repository = task_status_repository
        self.config = config
        self._log = structlog.get_logger(__name__)

    async def on_change(self, progress: TaskStatus) -> None:
        """Write the current state of *progress* through the repository."""
        repository = self.task_status_repository
        await repository.save(progress)
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
"""Log progress using structlog."""
|
|
2
2
|
|
|
3
|
-
import time
|
|
4
3
|
from datetime import UTC, datetime
|
|
5
4
|
|
|
6
5
|
import structlog
|
|
7
6
|
|
|
8
7
|
from kodit.config import ReportingConfig
|
|
8
|
+
from kodit.domain.entities import TaskStatus
|
|
9
9
|
from kodit.domain.protocols import ReportingModule
|
|
10
|
-
from kodit.domain.value_objects import
|
|
10
|
+
from kodit.domain.value_objects import ReportingState
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class LoggingReportingModule(ReportingModule):
|
|
@@ -19,47 +19,19 @@ class LoggingReportingModule(ReportingModule):
|
|
|
19
19
|
self._log = structlog.get_logger(__name__)
|
|
20
20
|
self._last_log_time: datetime = datetime.now(UTC)
|
|
21
21
|
|
|
22
|
-
def on_change(self,
|
|
22
|
+
async def on_change(self, progress: TaskStatus) -> None:
|
|
23
23
|
"""On step changed."""
|
|
24
24
|
current_time = datetime.now(UTC)
|
|
25
25
|
time_since_last_log = current_time - self._last_log_time
|
|
26
|
+
step = progress
|
|
26
27
|
|
|
27
28
|
if (
|
|
28
29
|
step.state != ReportingState.IN_PROGRESS
|
|
29
30
|
or time_since_last_log >= self.config.log_time_interval
|
|
30
31
|
):
|
|
31
32
|
self._log.info(
|
|
32
|
-
step.
|
|
33
|
+
step.operation,
|
|
33
34
|
state=step.state,
|
|
34
|
-
message=step.message,
|
|
35
35
|
completion_percent=step.completion_percent,
|
|
36
36
|
)
|
|
37
37
|
self._last_log_time = current_time
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
class LogProgress(Progress):
|
|
41
|
-
"""Log progress using structlog with time-based throttling."""
|
|
42
|
-
|
|
43
|
-
def __init__(self, config: ReportingConfig | None = None) -> None:
|
|
44
|
-
"""Initialize the log progress."""
|
|
45
|
-
self.log = structlog.get_logger()
|
|
46
|
-
self.config = config or ReportingConfig()
|
|
47
|
-
self.last_log_time: float = 0
|
|
48
|
-
|
|
49
|
-
def on_update(self, state: ProgressState) -> None:
|
|
50
|
-
"""Log the progress with time-based throttling."""
|
|
51
|
-
current_time = time.time()
|
|
52
|
-
time_since_last_log = current_time - self.last_log_time
|
|
53
|
-
|
|
54
|
-
if time_since_last_log >= self.config.log_time_interval.total_seconds():
|
|
55
|
-
self.log.info(
|
|
56
|
-
"Progress...",
|
|
57
|
-
operation=state.operation,
|
|
58
|
-
percentage=state.percentage,
|
|
59
|
-
message=state.message,
|
|
60
|
-
)
|
|
61
|
-
self.last_log_time = current_time
|
|
62
|
-
|
|
63
|
-
def on_complete(self) -> None:
|
|
64
|
-
"""Log the completion."""
|
|
65
|
-
self.log.info("Completed")
|
|
@@ -3,8 +3,9 @@
|
|
|
3
3
|
from tqdm import tqdm
|
|
4
4
|
|
|
5
5
|
from kodit.config import ReportingConfig
|
|
6
|
+
from kodit.domain.entities import TaskStatus
|
|
6
7
|
from kodit.domain.protocols import ReportingModule
|
|
7
|
-
from kodit.domain.value_objects import
|
|
8
|
+
from kodit.domain.value_objects import ReportingState
|
|
8
9
|
|
|
9
10
|
|
|
10
11
|
class TQDMReportingModule(ReportingModule):
|
|
@@ -15,59 +16,23 @@ class TQDMReportingModule(ReportingModule):
|
|
|
15
16
|
self.config = config
|
|
16
17
|
self.pbar = tqdm()
|
|
17
18
|
|
|
18
|
-
def on_change(self,
|
|
19
|
+
async def on_change(self, progress: TaskStatus) -> None:
|
|
19
20
|
"""On step changed."""
|
|
21
|
+
step = progress
|
|
20
22
|
if step.state == ReportingState.COMPLETED:
|
|
21
23
|
self.pbar.close()
|
|
22
24
|
return
|
|
23
25
|
|
|
24
|
-
self.pbar.set_description(step.
|
|
26
|
+
self.pbar.set_description(step.operation)
|
|
25
27
|
self.pbar.refresh()
|
|
26
28
|
# Update description if message is provided
|
|
27
|
-
if step.
|
|
29
|
+
if step.error:
|
|
28
30
|
# Fix the event message to a specific size so it's not jumping around
|
|
29
31
|
# If it's too small, add spaces
|
|
30
32
|
# If it's too large, truncate
|
|
31
|
-
if len(step.
|
|
32
|
-
self.pbar.set_description(step.
|
|
33
|
+
if len(step.error) < 30:
|
|
34
|
+
self.pbar.set_description(step.error + " " * (30 - len(step.error)))
|
|
33
35
|
else:
|
|
34
|
-
self.pbar.set_description(step.
|
|
36
|
+
self.pbar.set_description(step.error[-30:])
|
|
35
37
|
else:
|
|
36
|
-
self.pbar.set_description(step.
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
class TQDMProgress(Progress):
|
|
40
|
-
"""TQDM-based progress callback implementation."""
|
|
41
|
-
|
|
42
|
-
def __init__(self, config: ReportingConfig | None = None) -> None:
|
|
43
|
-
"""Initialize with a TQDM progress bar."""
|
|
44
|
-
self.config = config or ReportingConfig()
|
|
45
|
-
self.pbar = tqdm()
|
|
46
|
-
|
|
47
|
-
def on_update(self, state: ProgressState) -> None:
|
|
48
|
-
"""Update the TQDM progress bar."""
|
|
49
|
-
# Update total if it changes
|
|
50
|
-
if state.total != self.pbar.total:
|
|
51
|
-
self.pbar.total = state.total
|
|
52
|
-
|
|
53
|
-
# Update the progress bar
|
|
54
|
-
self.pbar.n = state.current
|
|
55
|
-
self.pbar.refresh()
|
|
56
|
-
|
|
57
|
-
# Update description if message is provided
|
|
58
|
-
if state.message:
|
|
59
|
-
# Fix the event message to a specific size so it's not jumping around
|
|
60
|
-
# If it's too small, add spaces
|
|
61
|
-
# If it's too large, truncate
|
|
62
|
-
if len(state.message) < 30:
|
|
63
|
-
self.pbar.set_description(
|
|
64
|
-
state.message + " " * (30 - len(state.message))
|
|
65
|
-
)
|
|
66
|
-
else:
|
|
67
|
-
self.pbar.set_description(state.message[-30:])
|
|
68
|
-
else:
|
|
69
|
-
self.pbar.set_description(state.operation)
|
|
70
|
-
|
|
71
|
-
def on_complete(self) -> None:
|
|
72
|
-
"""Complete the progress bar."""
|
|
73
|
-
self.pbar.close()
|
|
38
|
+
self.pbar.set_description(step.operation)
|
|
@@ -262,3 +262,64 @@ class Task(Base, CommonMixin):
|
|
|
262
262
|
self.type = type
|
|
263
263
|
self.payload = payload
|
|
264
264
|
self.priority = priority
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
class TaskStatus(Base):
    """Task status model.

    ORM mapping for the ``task_status`` table. ``parent`` is a nullable
    foreign key pointing at another row's ``id`` in this same table.
    """

    __tablename__ = "task_status"
    # String primary key (populated from the domain TaskStatus id by the mapper).
    id: Mapped[str] = mapped_column(
        String(255), primary_key=True, index=True, nullable=False
    )
    # Defaults to the current UTC time on insert.
    created_at: Mapped[datetime] = mapped_column(
        TZDateTime, nullable=False, default=lambda: datetime.now(UTC)
    )
    # Defaults to current UTC time; `onupdate` only applies when an UPDATE
    # does not set this column explicitly.
    updated_at: Mapped[datetime] = mapped_column(
        TZDateTime,
        nullable=False,
        default=lambda: datetime.now(UTC),
        onupdate=lambda: datetime.now(UTC),
    )
    operation: Mapped[str] = mapped_column(String(255), index=True, nullable=False)
    # Optional (id, type) pair identifying what this status tracks.
    trackable_id: Mapped[int | None] = mapped_column(Integer, index=True, nullable=True)
    trackable_type: Mapped[str | None] = mapped_column(
        String(255), index=True, nullable=True
    )
    # Self-referencing FK: id of the parent status row, if any.
    parent: Mapped[str | None] = mapped_column(
        ForeignKey("task_status.id"), index=True, nullable=True
    )
    # Text/numeric progress fields; empty string / 0 rather than NULL by default.
    message: Mapped[str] = mapped_column(UnicodeText, default="")
    state: Mapped[str] = mapped_column(String(255), default="")
    error: Mapped[str] = mapped_column(UnicodeText, default="")
    total: Mapped[int] = mapped_column(Integer, default=0)
    current: Mapped[int] = mapped_column(Integer, default=0)

    def __init__(  # noqa: PLR0913
        self,
        id: str,  # noqa: A002
        operation: str,
        created_at: datetime,
        updated_at: datetime,
        trackable_id: int | None,
        trackable_type: str | None,
        parent: str | None,
        state: str,
        error: str | None,
        total: int,
        current: int,
        message: str,
    ) -> None:
        """Initialize the task status.

        Every column is assigned explicitly. ``error`` may be ``None`` and
        is stored as ``""``; a falsy ``message`` is likewise stored as ``""``.
        """
        super().__init__()
        self.id = id
        self.operation = operation
        self.created_at = created_at
        self.updated_at = updated_at
        self.trackable_id = trackable_id
        self.trackable_type = trackable_type
        self.parent = parent
        self.state = state
        # The column holds text, so a missing error is normalized to "".
        self.error = error or ""
        self.total = total
        self.current = current
        self.message = message or ""
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
"""Task repository for the task queue."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
|
|
5
|
+
import structlog
|
|
6
|
+
from sqlalchemy import delete, select
|
|
7
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
8
|
+
|
|
9
|
+
from kodit.domain import entities as domain_entities
|
|
10
|
+
from kodit.domain.protocols import TaskStatusRepository
|
|
11
|
+
from kodit.infrastructure.mappers.task_status_mapper import TaskStatusMapper
|
|
12
|
+
from kodit.infrastructure.sqlalchemy import entities as db_entities
|
|
13
|
+
from kodit.infrastructure.sqlalchemy.unit_of_work import SqlAlchemyUnitOfWork
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def create_task_status_repository(
    session_factory: Callable[[], AsyncSession],
) -> TaskStatusRepository:
    """Create a task status repository.

    Args:
        session_factory: Callable returning a new ``AsyncSession``. It is
            wrapped in a ``SqlAlchemyUnitOfWork`` that the repository uses
            for every operation.

    Returns:
        A ``SqlAlchemyTaskStatusRepository`` backed by that unit of work.

    """
    uow = SqlAlchemyUnitOfWork(session_factory=session_factory)
    return SqlAlchemyTaskStatusRepository(uow)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SqlAlchemyTaskStatusRepository(TaskStatusRepository):
    """Repository for persisting TaskStatus entities."""

    def __init__(self, uow: SqlAlchemyUnitOfWork) -> None:
        """Initialize the repository.

        Args:
            uow: Unit of work entered with ``async with`` for every
                repository call; its ``session`` runs the queries.

        """
        self.uow = uow
        self.log = structlog.get_logger(__name__)
        self.mapper = TaskStatusMapper()

    async def save(self, status: domain_entities.TaskStatus) -> None:
        """Insert *status*, or update the existing row with the same id.

        On update every mutable column is copied onto the existing row,
        including ``message``. Previously ``message`` was skipped, so
        progress messages after the first save were silently dropped.
        ``created_at`` is intentionally left untouched on update.
        """
        async with self.uow:
            # Convert domain entity to database entity
            db_status = self.mapper.from_domain_task_status(status)
            stmt = select(db_entities.TaskStatus).where(
                db_entities.TaskStatus.id == db_status.id,
            )
            result = await self.uow.session.execute(stmt)
            existing = result.scalar_one_or_none()

            if existing is None:
                self.uow.session.add(db_status)
            else:
                # Update existing record with new values
                existing.operation = db_status.operation
                existing.state = db_status.state
                existing.error = db_status.error
                existing.total = db_status.total
                existing.current = db_status.current
                existing.message = db_status.message
                existing.updated_at = db_status.updated_at
                existing.parent = db_status.parent
                existing.trackable_id = db_status.trackable_id
                existing.trackable_type = db_status.trackable_type

    async def load_with_hierarchy(
        self, trackable_type: str, trackable_id: int
    ) -> list[domain_entities.TaskStatus]:
        """Load all statuses for one trackable, with parent links rebuilt.

        Only rows matching both ``trackable_type`` and ``trackable_id`` are
        loaded, so a parent outside that set is not attached.
        """
        async with self.uow:
            stmt = select(db_entities.TaskStatus).where(
                db_entities.TaskStatus.trackable_id == trackable_id,
                db_entities.TaskStatus.trackable_type == trackable_type,
            )
            result = await self.uow.session.execute(stmt)
            db_statuses = list(result.scalars().all())

            # Use mapper to convert and reconstruct hierarchy
            return self.mapper.to_domain_task_status_with_hierarchy(db_statuses)

    async def delete(self, status: domain_entities.TaskStatus) -> None:
        """Delete the row whose id matches *status*; no-op if it is absent."""
        async with self.uow:
            stmt = delete(db_entities.TaskStatus).where(
                db_entities.TaskStatus.id == status.id,
            )
            await self.uow.session.execute(stmt)
|
kodit/mcp.py
CHANGED
|
@@ -23,6 +23,9 @@ from kodit.domain.value_objects import (
|
|
|
23
23
|
MultiSearchResult,
|
|
24
24
|
SnippetSearchFilters,
|
|
25
25
|
)
|
|
26
|
+
from kodit.infrastructure.sqlalchemy.task_status_repository import (
|
|
27
|
+
create_task_status_repository,
|
|
28
|
+
)
|
|
26
29
|
|
|
27
30
|
# Global database connection for MCP server
|
|
28
31
|
_mcp_db: Database | None = None
|
|
@@ -179,9 +182,10 @@ def register_mcp_tools(mcp_server: FastMCP) -> None:
|
|
|
179
182
|
# Use the unified application service
|
|
180
183
|
service = create_code_indexing_application_service(
|
|
181
184
|
app_context=mcp_context.app_context,
|
|
182
|
-
session=mcp_context.session,
|
|
183
185
|
session_factory=mcp_context.session_factory,
|
|
184
|
-
operation=create_server_operation(
|
|
186
|
+
operation=create_server_operation(
|
|
187
|
+
create_task_status_repository(mcp_context.session_factory)
|
|
188
|
+
),
|
|
185
189
|
)
|
|
186
190
|
|
|
187
191
|
log.debug("Searching for snippets")
|
kodit/migrations/versions/b9cd1c3fd762_add_task_status.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
# ruff: noqa
|
|
2
|
+
"""add task status
|
|
3
|
+
|
|
4
|
+
Revision ID: b9cd1c3fd762
|
|
5
|
+
Revises: 9cf0e87de578
|
|
6
|
+
Create Date: 2025-09-05 13:41:29.645898
|
|
7
|
+
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from typing import Sequence, Union
|
|
11
|
+
|
|
12
|
+
from alembic import op
|
|
13
|
+
import sqlalchemy as sa
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# revision identifiers, used by Alembic.
# This revision (adds the task_status table) applies directly on top of
# 9cf0e87de578 and has no branch labels or extra dependencies.
revision: str = "b9cd1c3fd762"
down_revision: Union[str, None] = "9cf0e87de578"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def upgrade() -> None:
    """Upgrade schema.

    Creates the ``task_status`` table, including a self-referencing foreign
    key from ``parent`` to ``task_status.id``. Also creates non-unique
    indexes on ``id``, ``operation``, ``parent``, ``trackable_id`` and
    ``trackable_type``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "task_status",
        sa.Column("id", sa.String(length=255), nullable=False),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("operation", sa.String(length=255), nullable=False),
        sa.Column("trackable_id", sa.Integer(), nullable=True),
        sa.Column("trackable_type", sa.String(length=255), nullable=True),
        sa.Column("parent", sa.String(length=255), nullable=True),
        sa.Column("message", sa.UnicodeText(), nullable=False),
        sa.Column("state", sa.String(length=255), nullable=False),
        sa.Column("error", sa.UnicodeText(), nullable=False),
        sa.Column("total", sa.Integer(), nullable=False),
        sa.Column("current", sa.Integer(), nullable=False),
        # Self-referencing FK: a status may point at a parent status row.
        sa.ForeignKeyConstraint(
            ["parent"],
            ["task_status.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    # NOTE(review): "id" is already the primary key, so this extra index is
    # probably redundant; it mirrors `index=True` on the ORM id column.
    op.create_index(op.f("ix_task_status_id"), "task_status", ["id"], unique=False)
    op.create_index(
        op.f("ix_task_status_operation"), "task_status", ["operation"], unique=False
    )
    op.create_index(
        op.f("ix_task_status_parent"), "task_status", ["parent"], unique=False
    )
    op.create_index(
        op.f("ix_task_status_trackable_id"),
        "task_status",
        ["trackable_id"],
        unique=False,
    )
    op.create_index(
        op.f("ix_task_status_trackable_type"),
        "task_status",
        ["trackable_type"],
        unique=False,
    )
    # ### end Alembic commands ###
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def downgrade() -> None:
    """Downgrade schema.

    Drops the five ``task_status`` indexes created by ``upgrade`` in reverse
    order, then drops the table itself.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(op.f("ix_task_status_trackable_type"), table_name="task_status")
    op.drop_index(op.f("ix_task_status_trackable_id"), table_name="task_status")
    op.drop_index(op.f("ix_task_status_parent"), table_name="task_status")
    op.drop_index(op.f("ix_task_status_operation"), table_name="task_status")
    op.drop_index(op.f("ix_task_status_id"), table_name="task_status")
    # Table is dropped last, after its indexes.
    op.drop_table("task_status")
    # ### end Alembic commands ###
|