kodit 0.3.15__py3-none-any.whl → 0.3.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kodit might be problematic. Click here for more details.
- kodit/_version.py +2 -2
- kodit/app.py +11 -2
- kodit/application/services/auto_indexing_service.py +16 -7
- kodit/application/services/code_indexing_application_service.py +22 -11
- kodit/application/services/indexing_worker_service.py +154 -0
- kodit/application/services/queue_service.py +52 -0
- kodit/application/services/sync_scheduler.py +10 -48
- kodit/cli.py +407 -148
- kodit/cli_utils.py +74 -0
- kodit/config.py +41 -3
- kodit/domain/entities.py +48 -1
- kodit/domain/protocols.py +29 -2
- kodit/domain/value_objects.py +13 -0
- kodit/infrastructure/api/client/__init__.py +14 -0
- kodit/infrastructure/api/client/base.py +100 -0
- kodit/infrastructure/api/client/exceptions.py +21 -0
- kodit/infrastructure/api/client/generated_endpoints.py +27 -0
- kodit/infrastructure/api/client/index_client.py +57 -0
- kodit/infrastructure/api/client/search_client.py +86 -0
- kodit/infrastructure/api/v1/dependencies.py +13 -0
- kodit/infrastructure/api/v1/routers/indexes.py +9 -4
- kodit/infrastructure/embedding/embedding_factory.py +5 -7
- kodit/infrastructure/embedding/embedding_providers/openai_embedding_provider.py +75 -13
- kodit/infrastructure/enrichment/enrichment_factory.py +5 -8
- kodit/infrastructure/enrichment/local_enrichment_provider.py +4 -1
- kodit/infrastructure/enrichment/openai_enrichment_provider.py +84 -16
- kodit/infrastructure/enrichment/utils.py +30 -0
- kodit/infrastructure/mappers/task_mapper.py +81 -0
- kodit/infrastructure/sqlalchemy/entities.py +35 -0
- kodit/infrastructure/sqlalchemy/index_repository.py +4 -4
- kodit/infrastructure/sqlalchemy/task_repository.py +81 -0
- kodit/middleware.py +1 -0
- kodit/migrations/versions/9cf0e87de578_add_queue.py +47 -0
- kodit/utils/generate_api_paths.py +135 -0
- {kodit-0.3.15.dist-info → kodit-0.3.17.dist-info}/METADATA +1 -1
- {kodit-0.3.15.dist-info → kodit-0.3.17.dist-info}/RECORD +39 -25
- {kodit-0.3.15.dist-info → kodit-0.3.17.dist-info}/WHEEL +0 -0
- {kodit-0.3.15.dist-info → kodit-0.3.17.dist-info}/entry_points.txt +0 -0
- {kodit-0.3.15.dist-info → kodit-0.3.17.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,11 +1,12 @@
|
|
|
1
|
-
"""OpenAI embedding provider implementation."""
|
|
1
|
+
"""OpenAI embedding provider implementation using httpx."""
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
4
|
from collections.abc import AsyncGenerator
|
|
5
|
+
from typing import Any
|
|
5
6
|
|
|
7
|
+
import httpx
|
|
6
8
|
import structlog
|
|
7
9
|
import tiktoken
|
|
8
|
-
from openai import AsyncOpenAI
|
|
9
10
|
from tiktoken import Encoding
|
|
10
11
|
|
|
11
12
|
from kodit.domain.services.embedding_service import EmbeddingProvider
|
|
@@ -22,29 +23,53 @@ OPENAI_NUM_PARALLEL_TASKS = 10 # Semaphore limit for concurrent OpenAI requests
|
|
|
22
23
|
|
|
23
24
|
|
|
24
25
|
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
|
25
|
-
"""OpenAI embedding provider that uses OpenAI's embedding API."""
|
|
26
|
+
"""OpenAI embedding provider that uses OpenAI's embedding API via httpx."""
|
|
26
27
|
|
|
27
|
-
def __init__(
|
|
28
|
+
def __init__( # noqa: PLR0913
|
|
28
29
|
self,
|
|
29
|
-
|
|
30
|
+
api_key: str | None = None,
|
|
31
|
+
base_url: str = "https://api.openai.com",
|
|
30
32
|
model_name: str = "text-embedding-3-small",
|
|
31
33
|
num_parallel_tasks: int = OPENAI_NUM_PARALLEL_TASKS,
|
|
34
|
+
socket_path: str | None = None,
|
|
35
|
+
timeout: float = 30.0,
|
|
32
36
|
) -> None:
|
|
33
37
|
"""Initialize the OpenAI embedding provider.
|
|
34
38
|
|
|
35
39
|
Args:
|
|
36
|
-
|
|
37
|
-
|
|
40
|
+
api_key: The OpenAI API key.
|
|
41
|
+
base_url: The base URL for the OpenAI API.
|
|
42
|
+
model_name: The model name to use for embeddings.
|
|
43
|
+
num_parallel_tasks: Maximum number of concurrent requests.
|
|
44
|
+
socket_path: Optional Unix socket path for local communication.
|
|
45
|
+
timeout: Request timeout in seconds.
|
|
38
46
|
|
|
39
47
|
"""
|
|
40
|
-
self.openai_client = openai_client
|
|
41
48
|
self.model_name = model_name
|
|
42
49
|
self.num_parallel_tasks = num_parallel_tasks
|
|
43
50
|
self.log = structlog.get_logger(__name__)
|
|
51
|
+
self.api_key = api_key
|
|
52
|
+
self.base_url = base_url
|
|
53
|
+
self.socket_path = socket_path
|
|
54
|
+
self.timeout = timeout
|
|
44
55
|
|
|
45
56
|
# Lazily initialised token encoding
|
|
46
57
|
self._encoding: Encoding | None = None
|
|
47
58
|
|
|
59
|
+
# Create httpx client with optional Unix socket support
|
|
60
|
+
if socket_path:
|
|
61
|
+
transport = httpx.AsyncHTTPTransport(uds=socket_path)
|
|
62
|
+
self.http_client = httpx.AsyncClient(
|
|
63
|
+
transport=transport,
|
|
64
|
+
base_url="http://localhost", # Base URL for Unix socket
|
|
65
|
+
timeout=timeout,
|
|
66
|
+
)
|
|
67
|
+
else:
|
|
68
|
+
self.http_client = httpx.AsyncClient(
|
|
69
|
+
base_url=base_url,
|
|
70
|
+
timeout=timeout,
|
|
71
|
+
)
|
|
72
|
+
|
|
48
73
|
# ---------------------------------------------------------------------
|
|
49
74
|
# Helper utilities
|
|
50
75
|
# ---------------------------------------------------------------------
|
|
@@ -76,6 +101,37 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
|
|
76
101
|
batch_size=BATCH_SIZE,
|
|
77
102
|
)
|
|
78
103
|
|
|
104
|
+
async def _call_embeddings_api(
|
|
105
|
+
self, texts: list[str]
|
|
106
|
+
) -> dict[str, Any]:
|
|
107
|
+
"""Call the embeddings API using httpx.
|
|
108
|
+
|
|
109
|
+
Args:
|
|
110
|
+
texts: The texts to embed.
|
|
111
|
+
|
|
112
|
+
Returns:
|
|
113
|
+
The API response as a dictionary.
|
|
114
|
+
|
|
115
|
+
"""
|
|
116
|
+
headers = {
|
|
117
|
+
"Content-Type": "application/json",
|
|
118
|
+
}
|
|
119
|
+
if self.api_key:
|
|
120
|
+
headers["Authorization"] = f"Bearer {self.api_key}"
|
|
121
|
+
|
|
122
|
+
data = {
|
|
123
|
+
"model": self.model_name,
|
|
124
|
+
"input": texts,
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
response = await self.http_client.post(
|
|
128
|
+
"/v1/embeddings",
|
|
129
|
+
json=data,
|
|
130
|
+
headers=headers,
|
|
131
|
+
)
|
|
132
|
+
response.raise_for_status()
|
|
133
|
+
return response.json()
|
|
134
|
+
|
|
79
135
|
async def embed(
|
|
80
136
|
self, data: list[EmbeddingRequest]
|
|
81
137
|
) -> AsyncGenerator[list[EmbeddingResponse], None]:
|
|
@@ -99,17 +155,17 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
|
|
99
155
|
) -> list[EmbeddingResponse]:
|
|
100
156
|
async with sem:
|
|
101
157
|
try:
|
|
102
|
-
response = await self.
|
|
103
|
-
|
|
104
|
-
input=[item.text for item in batch],
|
|
158
|
+
response = await self._call_embeddings_api(
|
|
159
|
+
[item.text for item in batch]
|
|
105
160
|
)
|
|
161
|
+
embeddings_data = response.get("data", [])
|
|
106
162
|
|
|
107
163
|
return [
|
|
108
164
|
EmbeddingResponse(
|
|
109
165
|
snippet_id=item.snippet_id,
|
|
110
|
-
embedding=
|
|
166
|
+
embedding=emb_data.get("embedding", []),
|
|
111
167
|
)
|
|
112
|
-
for item,
|
|
168
|
+
for item, emb_data in zip(batch, embeddings_data, strict=True)
|
|
113
169
|
]
|
|
114
170
|
except Exception as e:
|
|
115
171
|
self.log.exception("Error embedding batch", error=str(e))
|
|
@@ -119,3 +175,9 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
|
|
119
175
|
tasks = [_process_batch(batch) for batch in batched_data]
|
|
120
176
|
for task in asyncio.as_completed(tasks):
|
|
121
177
|
yield await task
|
|
178
|
+
|
|
179
|
+
async def close(self) -> None:
|
|
180
|
+
"""Close the HTTP client."""
|
|
181
|
+
if hasattr(self, "http_client"):
|
|
182
|
+
await self.http_client.aclose()
|
|
183
|
+
|
|
@@ -45,17 +45,14 @@ def enrichment_domain_service_factory(
|
|
|
45
45
|
enrichment_provider: EnrichmentProvider | None = None
|
|
46
46
|
if endpoint and endpoint.type == "openai":
|
|
47
47
|
log_event("kodit.enrichment", {"provider": "openai"})
|
|
48
|
-
|
|
49
|
-
|
|
48
|
+
# Use new httpx-based provider with socket support
|
|
50
49
|
enrichment_provider = OpenAIEnrichmentProvider(
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
base_url=endpoint.base_url or "https://api.openai.com/v1",
|
|
54
|
-
timeout=60,
|
|
55
|
-
max_retries=2,
|
|
56
|
-
),
|
|
50
|
+
api_key=endpoint.api_key,
|
|
51
|
+
base_url=endpoint.base_url or "https://api.openai.com/v1",
|
|
57
52
|
model_name=endpoint.model or "gpt-4o-mini",
|
|
58
53
|
num_parallel_tasks=endpoint.num_parallel_tasks or OPENAI_NUM_PARALLEL_TASKS,
|
|
54
|
+
socket_path=endpoint.socket_path,
|
|
55
|
+
timeout=endpoint.timeout or 30.0,
|
|
59
56
|
)
|
|
60
57
|
else:
|
|
61
58
|
log_event("kodit.enrichment", {"provider": "local"})
|
|
@@ -8,6 +8,7 @@ import tiktoken
|
|
|
8
8
|
|
|
9
9
|
from kodit.domain.services.enrichment_service import EnrichmentProvider
|
|
10
10
|
from kodit.domain.value_objects import EnrichmentRequest, EnrichmentResponse
|
|
11
|
+
from kodit.infrastructure.enrichment.utils import clean_thinking_tags
|
|
11
12
|
|
|
12
13
|
ENRICHMENT_SYSTEM_PROMPT = """
|
|
13
14
|
You are a professional software developer. You will be given a snippet of code.
|
|
@@ -109,7 +110,9 @@ class LocalEnrichmentProvider(EnrichmentProvider):
|
|
|
109
110
|
content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip( # type: ignore[attr-defined]
|
|
110
111
|
"\n"
|
|
111
112
|
)
|
|
113
|
+
# Remove thinking tags from the response
|
|
114
|
+
cleaned_content = clean_thinking_tags(content)
|
|
112
115
|
yield EnrichmentResponse(
|
|
113
116
|
snippet_id=prompt["id"],
|
|
114
|
-
text=
|
|
117
|
+
text=cleaned_content,
|
|
115
118
|
)
|
|
@@ -1,13 +1,15 @@
|
|
|
1
|
-
"""OpenAI enrichment provider implementation."""
|
|
1
|
+
"""OpenAI enrichment provider implementation using httpx."""
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
4
|
from collections.abc import AsyncGenerator
|
|
5
5
|
from typing import Any
|
|
6
6
|
|
|
7
|
+
import httpx
|
|
7
8
|
import structlog
|
|
8
9
|
|
|
9
10
|
from kodit.domain.services.enrichment_service import EnrichmentProvider
|
|
10
11
|
from kodit.domain.value_objects import EnrichmentRequest, EnrichmentResponse
|
|
12
|
+
from kodit.infrastructure.enrichment.utils import clean_thinking_tags
|
|
11
13
|
|
|
12
14
|
ENRICHMENT_SYSTEM_PROMPT = """
|
|
13
15
|
You are a professional software developer. You will be given a snippet of code.
|
|
@@ -18,26 +20,82 @@ Please provide a concise explanation of the code.
|
|
|
18
20
|
OPENAI_NUM_PARALLEL_TASKS = 40
|
|
19
21
|
|
|
20
22
|
|
|
23
|
+
|
|
21
24
|
class OpenAIEnrichmentProvider(EnrichmentProvider):
|
|
22
|
-
"""OpenAI enrichment provider implementation."""
|
|
25
|
+
"""OpenAI enrichment provider implementation using httpx."""
|
|
23
26
|
|
|
24
|
-
def __init__(
|
|
27
|
+
def __init__( # noqa: PLR0913
|
|
25
28
|
self,
|
|
26
|
-
|
|
29
|
+
api_key: str | None = None,
|
|
30
|
+
base_url: str = "https://api.openai.com",
|
|
27
31
|
model_name: str = "gpt-4o-mini",
|
|
28
32
|
num_parallel_tasks: int = OPENAI_NUM_PARALLEL_TASKS,
|
|
33
|
+
socket_path: str | None = None,
|
|
34
|
+
timeout: float = 30.0,
|
|
29
35
|
) -> None:
|
|
30
36
|
"""Initialize the OpenAI enrichment provider.
|
|
31
37
|
|
|
32
38
|
Args:
|
|
33
|
-
|
|
39
|
+
api_key: The OpenAI API key.
|
|
40
|
+
base_url: The base URL for the OpenAI API.
|
|
34
41
|
model_name: The model name to use for enrichment.
|
|
42
|
+
num_parallel_tasks: Maximum number of concurrent requests.
|
|
43
|
+
socket_path: Optional Unix socket path for local communication.
|
|
44
|
+
timeout: Request timeout in seconds.
|
|
35
45
|
|
|
36
46
|
"""
|
|
37
47
|
self.log = structlog.get_logger(__name__)
|
|
38
|
-
self.openai_client = openai_client
|
|
39
48
|
self.model_name = model_name
|
|
40
49
|
self.num_parallel_tasks = num_parallel_tasks
|
|
50
|
+
self.api_key = api_key
|
|
51
|
+
self.base_url = base_url
|
|
52
|
+
self.socket_path = socket_path
|
|
53
|
+
self.timeout = timeout
|
|
54
|
+
|
|
55
|
+
# Create httpx client with optional Unix socket support
|
|
56
|
+
if socket_path:
|
|
57
|
+
transport = httpx.AsyncHTTPTransport(uds=socket_path)
|
|
58
|
+
self.http_client = httpx.AsyncClient(
|
|
59
|
+
transport=transport,
|
|
60
|
+
base_url="http://localhost", # Base URL for Unix socket
|
|
61
|
+
timeout=timeout,
|
|
62
|
+
)
|
|
63
|
+
else:
|
|
64
|
+
self.http_client = httpx.AsyncClient(
|
|
65
|
+
base_url=base_url,
|
|
66
|
+
timeout=timeout,
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
async def _call_chat_completion(
|
|
70
|
+
self, messages: list[dict[str, str]]
|
|
71
|
+
) -> dict[str, Any]:
|
|
72
|
+
"""Call the chat completion API using httpx.
|
|
73
|
+
|
|
74
|
+
Args:
|
|
75
|
+
messages: The messages to send to the API.
|
|
76
|
+
|
|
77
|
+
Returns:
|
|
78
|
+
The API response as a dictionary.
|
|
79
|
+
|
|
80
|
+
"""
|
|
81
|
+
headers = {
|
|
82
|
+
"Content-Type": "application/json",
|
|
83
|
+
}
|
|
84
|
+
if self.api_key:
|
|
85
|
+
headers["Authorization"] = f"Bearer {self.api_key}"
|
|
86
|
+
|
|
87
|
+
data = {
|
|
88
|
+
"model": self.model_name,
|
|
89
|
+
"messages": messages,
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
response = await self.http_client.post(
|
|
93
|
+
"/v1/chat/completions",
|
|
94
|
+
json=data,
|
|
95
|
+
headers=headers,
|
|
96
|
+
)
|
|
97
|
+
response.raise_for_status()
|
|
98
|
+
return response.json()
|
|
41
99
|
|
|
42
100
|
async def enrich(
|
|
43
101
|
self, requests: list[EnrichmentRequest]
|
|
@@ -66,19 +124,24 @@ class OpenAIEnrichmentProvider(EnrichmentProvider):
|
|
|
66
124
|
text="",
|
|
67
125
|
)
|
|
68
126
|
try:
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
127
|
+
messages = [
|
|
128
|
+
{
|
|
129
|
+
"role": "system",
|
|
130
|
+
"content": ENRICHMENT_SYSTEM_PROMPT,
|
|
131
|
+
},
|
|
132
|
+
{"role": "user", "content": request.text},
|
|
133
|
+
]
|
|
134
|
+
response = await self._call_chat_completion(messages)
|
|
135
|
+
content = (
|
|
136
|
+
response.get("choices", [{}])[0]
|
|
137
|
+
.get("message", {})
|
|
138
|
+
.get("content", "")
|
|
78
139
|
)
|
|
140
|
+
# Remove thinking tags from the response
|
|
141
|
+
cleaned_content = clean_thinking_tags(content or "")
|
|
79
142
|
return EnrichmentResponse(
|
|
80
143
|
snippet_id=request.snippet_id,
|
|
81
|
-
text=
|
|
144
|
+
text=cleaned_content,
|
|
82
145
|
)
|
|
83
146
|
except Exception as e:
|
|
84
147
|
self.log.exception("Error enriching request", error=str(e))
|
|
@@ -93,3 +156,8 @@ class OpenAIEnrichmentProvider(EnrichmentProvider):
|
|
|
93
156
|
# Process all requests and yield results as they complete
|
|
94
157
|
for task in asyncio.as_completed(tasks):
|
|
95
158
|
yield await task
|
|
159
|
+
|
|
160
|
+
async def close(self) -> None:
|
|
161
|
+
"""Close the HTTP client."""
|
|
162
|
+
if hasattr(self, "http_client"):
|
|
163
|
+
await self.http_client.aclose()
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""Utility functions for enrichment processing."""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def clean_thinking_tags(content: str) -> str:
|
|
7
|
+
"""Remove <think>...</think> tags from content.
|
|
8
|
+
|
|
9
|
+
This utility handles thinking tags that may be produced by various AI models,
|
|
10
|
+
including both local and remote models. It safely removes thinking content
|
|
11
|
+
while preserving the actual response.
|
|
12
|
+
|
|
13
|
+
Args:
|
|
14
|
+
content: The content that may contain thinking tags.
|
|
15
|
+
|
|
16
|
+
Returns:
|
|
17
|
+
The content with thinking tags removed and cleaned up.
|
|
18
|
+
|
|
19
|
+
"""
|
|
20
|
+
if not content:
|
|
21
|
+
return content
|
|
22
|
+
|
|
23
|
+
# Remove thinking tags using regex with DOTALL flag to match across newlines
|
|
24
|
+
cleaned = re.sub(
|
|
25
|
+
r"<think>.*?</think>", "", content, flags=re.DOTALL | re.IGNORECASE
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
# Clean up extra whitespace that may be left behind
|
|
29
|
+
cleaned = re.sub(r"\n\s*\n\s*\n", "\n\n", cleaned)
|
|
30
|
+
return cleaned.strip() # Remove leading/trailing whitespace
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""Task mapper for the task queue."""
|
|
2
|
+
|
|
3
|
+
from typing import ClassVar
|
|
4
|
+
|
|
5
|
+
from kodit.domain.entities import Task
|
|
6
|
+
from kodit.domain.value_objects import TaskType
|
|
7
|
+
from kodit.infrastructure.sqlalchemy import entities as db_entities
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TaskTypeMapper:
|
|
11
|
+
"""Maps between domain QueuedTaskType and SQLAlchemy TaskType."""
|
|
12
|
+
|
|
13
|
+
# Map TaskType enum to QueuedTaskType
|
|
14
|
+
TASK_TYPE_MAPPING: ClassVar[dict[db_entities.TaskType, TaskType]] = {
|
|
15
|
+
db_entities.TaskType.INDEX_UPDATE: TaskType.INDEX_UPDATE,
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
@staticmethod
|
|
19
|
+
def to_domain_type(task_type: db_entities.TaskType) -> TaskType:
|
|
20
|
+
"""Convert SQLAlchemy TaskType to domain QueuedTaskType."""
|
|
21
|
+
if task_type not in TaskTypeMapper.TASK_TYPE_MAPPING:
|
|
22
|
+
raise ValueError(f"Unknown task type: {task_type}")
|
|
23
|
+
return TaskTypeMapper.TASK_TYPE_MAPPING[task_type]
|
|
24
|
+
|
|
25
|
+
@staticmethod
|
|
26
|
+
def from_domain_type(task_type: TaskType) -> db_entities.TaskType:
|
|
27
|
+
"""Convert domain QueuedTaskType to SQLAlchemy TaskType."""
|
|
28
|
+
if task_type not in TaskTypeMapper.TASK_TYPE_MAPPING.values():
|
|
29
|
+
raise ValueError(f"Unknown task type: {task_type}")
|
|
30
|
+
|
|
31
|
+
# Find value in TASK_TYPE_MAPPING
|
|
32
|
+
return next(
|
|
33
|
+
(
|
|
34
|
+
db_task_type
|
|
35
|
+
for db_task_type, domain_task_type in TaskTypeMapper.TASK_TYPE_MAPPING.items() # noqa: E501
|
|
36
|
+
if domain_task_type == task_type
|
|
37
|
+
)
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class TaskMapper:
|
|
42
|
+
"""Maps between domain QueuedTask and SQLAlchemy Task entities.
|
|
43
|
+
|
|
44
|
+
This mapper handles the conversion between the existing domain and
|
|
45
|
+
persistence layers without creating any new entities.
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
@staticmethod
|
|
49
|
+
def to_domain_task(record: db_entities.Task) -> Task:
|
|
50
|
+
"""Convert SQLAlchemy Task record to domain QueuedTask.
|
|
51
|
+
|
|
52
|
+
Since QueuedTask doesn't have status fields, we store processing
|
|
53
|
+
state in the payload.
|
|
54
|
+
"""
|
|
55
|
+
# Get the task type
|
|
56
|
+
task_type = TaskTypeMapper.to_domain_type(record.type)
|
|
57
|
+
|
|
58
|
+
# The dedup_key becomes the id in the domain entity
|
|
59
|
+
return Task(
|
|
60
|
+
id=record.dedup_key, # Use dedup_key as the unique identifier
|
|
61
|
+
type=task_type,
|
|
62
|
+
priority=record.priority,
|
|
63
|
+
payload=record.payload or {},
|
|
64
|
+
created_at=record.created_at,
|
|
65
|
+
updated_at=record.updated_at,
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
@staticmethod
|
|
69
|
+
def from_domain_task(task: Task) -> db_entities.Task:
|
|
70
|
+
"""Convert domain QueuedTask to SQLAlchemy Task record."""
|
|
71
|
+
if task.type not in TaskTypeMapper.TASK_TYPE_MAPPING.values():
|
|
72
|
+
raise ValueError(f"Unknown task type: {task.type}")
|
|
73
|
+
|
|
74
|
+
# Find value in TASK_TYPE_MAPPING
|
|
75
|
+
task_type = TaskTypeMapper.from_domain_type(task.type)
|
|
76
|
+
return db_entities.Task(
|
|
77
|
+
dedup_key=task.id,
|
|
78
|
+
type=task_type,
|
|
79
|
+
payload=task.payload,
|
|
80
|
+
priority=task.priority,
|
|
81
|
+
)
|
|
@@ -201,3 +201,38 @@ class Snippet(Base, CommonMixin):
|
|
|
201
201
|
self.index_id = index_id
|
|
202
202
|
self.content = content
|
|
203
203
|
self.summary = summary
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class TaskType(Enum):
|
|
207
|
+
"""Task type."""
|
|
208
|
+
|
|
209
|
+
INDEX_UPDATE = 1
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
class Task(Base, CommonMixin):
|
|
213
|
+
"""Queued tasks."""
|
|
214
|
+
|
|
215
|
+
__tablename__ = "tasks"
|
|
216
|
+
|
|
217
|
+
# dedup_key is used to deduplicate items in the queue
|
|
218
|
+
dedup_key: Mapped[str] = mapped_column(String(255), index=True)
|
|
219
|
+
# type represents what the task is meant to achieve
|
|
220
|
+
type: Mapped[TaskType] = mapped_column(SQLAlchemyEnum(TaskType), index=True)
|
|
221
|
+
# payload contains the task-specific payload data
|
|
222
|
+
payload: Mapped[dict] = mapped_column(JSON)
|
|
223
|
+
# priority is used to determine the order of the items in the queue
|
|
224
|
+
priority: Mapped[int] = mapped_column(Integer)
|
|
225
|
+
|
|
226
|
+
def __init__(
|
|
227
|
+
self,
|
|
228
|
+
dedup_key: str,
|
|
229
|
+
type: TaskType, # noqa: A002
|
|
230
|
+
payload: dict,
|
|
231
|
+
priority: int,
|
|
232
|
+
) -> None:
|
|
233
|
+
"""Initialize the queue item."""
|
|
234
|
+
super().__init__()
|
|
235
|
+
self.dedup_key = dedup_key
|
|
236
|
+
self.type = type
|
|
237
|
+
self.payload = payload
|
|
238
|
+
self.priority = priority
|
|
@@ -597,12 +597,12 @@ class SqlAlchemyIndexRepository(IndexRepository):
|
|
|
597
597
|
)
|
|
598
598
|
await self._session.execute(stmt)
|
|
599
599
|
|
|
600
|
+
# Delete the index
|
|
601
|
+
stmt = delete(db_entities.Index).where(db_entities.Index.id == index.id)
|
|
602
|
+
await self._session.execute(stmt)
|
|
603
|
+
|
|
600
604
|
# Delete the source
|
|
601
605
|
stmt = delete(db_entities.Source).where(
|
|
602
606
|
db_entities.Source.id == index.source.id
|
|
603
607
|
)
|
|
604
608
|
await self._session.execute(stmt)
|
|
605
|
-
|
|
606
|
-
# Delete the index
|
|
607
|
-
stmt = delete(db_entities.Index).where(db_entities.Index.id == index.id)
|
|
608
|
-
await self._session.execute(stmt)
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""Task repository for the task queue."""
|
|
2
|
+
|
|
3
|
+
import structlog
|
|
4
|
+
from sqlalchemy import select
|
|
5
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
6
|
+
|
|
7
|
+
from kodit.domain.entities import Task
|
|
8
|
+
from kodit.domain.protocols import TaskRepository
|
|
9
|
+
from kodit.domain.value_objects import TaskType
|
|
10
|
+
from kodit.infrastructure.mappers.task_mapper import TaskMapper, TaskTypeMapper
|
|
11
|
+
from kodit.infrastructure.sqlalchemy import entities as db_entities
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SqlAlchemyTaskRepository(TaskRepository):
|
|
15
|
+
"""Repository for task persistence using the existing Task entity."""
|
|
16
|
+
|
|
17
|
+
def __init__(self, session: AsyncSession) -> None:
|
|
18
|
+
"""Initialize the repository."""
|
|
19
|
+
self.session = session
|
|
20
|
+
self.log = structlog.get_logger(__name__)
|
|
21
|
+
|
|
22
|
+
async def add(
|
|
23
|
+
self,
|
|
24
|
+
task: Task,
|
|
25
|
+
) -> None:
|
|
26
|
+
"""Create a new task in the database."""
|
|
27
|
+
self.session.add(TaskMapper.from_domain_task(task))
|
|
28
|
+
|
|
29
|
+
async def get(self, task_id: str) -> Task | None:
|
|
30
|
+
"""Get a task by ID."""
|
|
31
|
+
stmt = select(db_entities.Task).where(db_entities.Task.dedup_key == task_id)
|
|
32
|
+
result = await self.session.execute(stmt)
|
|
33
|
+
db_task = result.scalar_one_or_none()
|
|
34
|
+
if not db_task:
|
|
35
|
+
return None
|
|
36
|
+
return TaskMapper.to_domain_task(db_task)
|
|
37
|
+
|
|
38
|
+
async def take(self) -> Task | None:
|
|
39
|
+
"""Take a task for processing and remove it from the database."""
|
|
40
|
+
stmt = (
|
|
41
|
+
select(db_entities.Task)
|
|
42
|
+
.order_by(db_entities.Task.priority.desc(), db_entities.Task.created_at)
|
|
43
|
+
.limit(1)
|
|
44
|
+
)
|
|
45
|
+
result = await self.session.execute(stmt)
|
|
46
|
+
db_task = result.scalar_one_or_none()
|
|
47
|
+
if not db_task:
|
|
48
|
+
return None
|
|
49
|
+
await self.session.delete(db_task)
|
|
50
|
+
return TaskMapper.to_domain_task(db_task)
|
|
51
|
+
|
|
52
|
+
async def update(self, task: Task) -> None:
|
|
53
|
+
"""Update a task in the database."""
|
|
54
|
+
stmt = select(db_entities.Task).where(db_entities.Task.dedup_key == task.id)
|
|
55
|
+
result = await self.session.execute(stmt)
|
|
56
|
+
db_task = result.scalar_one_or_none()
|
|
57
|
+
|
|
58
|
+
if not db_task:
|
|
59
|
+
raise ValueError(f"Task not found: {task.id}")
|
|
60
|
+
|
|
61
|
+
db_task.priority = task.priority
|
|
62
|
+
db_task.payload = task.payload
|
|
63
|
+
|
|
64
|
+
async def list(self, task_type: TaskType | None = None) -> list[Task]:
|
|
65
|
+
"""List tasks with optional status filter."""
|
|
66
|
+
stmt = select(db_entities.Task)
|
|
67
|
+
|
|
68
|
+
if task_type:
|
|
69
|
+
stmt = stmt.where(
|
|
70
|
+
db_entities.Task.type == TaskTypeMapper.from_domain_type(task_type)
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
stmt = stmt.order_by(
|
|
74
|
+
db_entities.Task.priority.desc(), db_entities.Task.created_at
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
result = await self.session.execute(stmt)
|
|
78
|
+
records = result.scalars().all()
|
|
79
|
+
|
|
80
|
+
# Convert to domain entities
|
|
81
|
+
return [TaskMapper.to_domain_task(record) for record in records]
|
kodit/middleware.py
CHANGED
|
@@ -53,6 +53,7 @@ async def logging_middleware(request: Request, call_next: Callable) -> Response:
|
|
|
53
53
|
"client_host": client_host,
|
|
54
54
|
"client_port": client_port,
|
|
55
55
|
},
|
|
56
|
+
headers=dict(request.headers),
|
|
56
57
|
network={"client": {"ip": client_host, "port": client_port}},
|
|
57
58
|
duration=process_time,
|
|
58
59
|
)
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# ruff: noqa
|
|
2
|
+
"""add queue
|
|
3
|
+
|
|
4
|
+
Revision ID: 9cf0e87de578
|
|
5
|
+
Revises: 4073b33f9436
|
|
6
|
+
Create Date: 2025-08-06 17:38:21.055235
|
|
7
|
+
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from typing import Sequence, Union
|
|
11
|
+
|
|
12
|
+
from alembic import op
|
|
13
|
+
import sqlalchemy as sa
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# revision identifiers, used by Alembic.
|
|
17
|
+
revision: str = '9cf0e87de578'
|
|
18
|
+
down_revision: Union[str, None] = '4073b33f9436'
|
|
19
|
+
branch_labels: Union[str, Sequence[str], None] = None
|
|
20
|
+
depends_on: Union[str, Sequence[str], None] = None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def upgrade() -> None:
|
|
24
|
+
"""Upgrade schema."""
|
|
25
|
+
# ### commands auto generated by Alembic - please adjust! ###
|
|
26
|
+
op.create_table('tasks',
|
|
27
|
+
sa.Column('dedup_key', sa.String(length=255), nullable=False),
|
|
28
|
+
sa.Column('type', sa.Enum('INDEX_UPDATE', name='tasktype'), nullable=False),
|
|
29
|
+
sa.Column('payload', sa.JSON(), nullable=False),
|
|
30
|
+
sa.Column('priority', sa.Integer(), nullable=False),
|
|
31
|
+
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
|
32
|
+
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
|
33
|
+
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
|
34
|
+
sa.PrimaryKeyConstraint('id')
|
|
35
|
+
)
|
|
36
|
+
op.create_index(op.f('ix_tasks_dedup_key'), 'tasks', ['dedup_key'], unique=False)
|
|
37
|
+
op.create_index(op.f('ix_tasks_type'), 'tasks', ['type'], unique=False)
|
|
38
|
+
# ### end Alembic commands ###
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def downgrade() -> None:
|
|
42
|
+
"""Downgrade schema."""
|
|
43
|
+
# ### commands auto generated by Alembic - please adjust! ###
|
|
44
|
+
op.drop_index(op.f('ix_tasks_type'), table_name='tasks')
|
|
45
|
+
op.drop_index(op.f('ix_tasks_dedup_key'), table_name='tasks')
|
|
46
|
+
op.drop_table('tasks')
|
|
47
|
+
# ### end Alembic commands ###
|