kodit-0.4.2-py3-none-any.whl → kodit-0.5.0-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- kodit/_version.py +2 -2
- kodit/app.py +59 -24
- kodit/application/factories/reporting_factory.py +16 -7
- kodit/application/factories/server_factory.py +311 -0
- kodit/application/services/code_search_application_service.py +144 -0
- kodit/application/services/commit_indexing_application_service.py +543 -0
- kodit/application/services/indexing_worker_service.py +13 -46
- kodit/application/services/queue_service.py +24 -3
- kodit/application/services/reporting.py +70 -54
- kodit/application/services/sync_scheduler.py +15 -31
- kodit/cli.py +2 -763
- kodit/cli_utils.py +2 -9
- kodit/config.py +3 -96
- kodit/database.py +38 -1
- kodit/domain/entities/__init__.py +276 -0
- kodit/domain/entities/git.py +190 -0
- kodit/domain/factories/__init__.py +1 -0
- kodit/domain/factories/git_repo_factory.py +76 -0
- kodit/domain/protocols.py +270 -46
- kodit/domain/services/bm25_service.py +5 -1
- kodit/domain/services/embedding_service.py +3 -0
- kodit/domain/services/git_repository_service.py +429 -0
- kodit/domain/services/git_service.py +300 -0
- kodit/domain/services/task_status_query_service.py +19 -0
- kodit/domain/value_objects.py +113 -147
- kodit/infrastructure/api/client/__init__.py +0 -2
- kodit/infrastructure/api/v1/__init__.py +0 -4
- kodit/infrastructure/api/v1/dependencies.py +105 -44
- kodit/infrastructure/api/v1/routers/__init__.py +0 -6
- kodit/infrastructure/api/v1/routers/commits.py +271 -0
- kodit/infrastructure/api/v1/routers/queue.py +2 -2
- kodit/infrastructure/api/v1/routers/repositories.py +282 -0
- kodit/infrastructure/api/v1/routers/search.py +31 -14
- kodit/infrastructure/api/v1/schemas/__init__.py +0 -24
- kodit/infrastructure/api/v1/schemas/commit.py +96 -0
- kodit/infrastructure/api/v1/schemas/context.py +2 -0
- kodit/infrastructure/api/v1/schemas/repository.py +128 -0
- kodit/infrastructure/api/v1/schemas/search.py +12 -9
- kodit/infrastructure/api/v1/schemas/snippet.py +58 -0
- kodit/infrastructure/api/v1/schemas/tag.py +31 -0
- kodit/infrastructure/api/v1/schemas/task_status.py +41 -0
- kodit/infrastructure/bm25/local_bm25_repository.py +16 -4
- kodit/infrastructure/bm25/vectorchord_bm25_repository.py +68 -52
- kodit/infrastructure/cloning/git/git_python_adaptor.py +467 -0
- kodit/infrastructure/cloning/git/working_copy.py +10 -3
- kodit/infrastructure/embedding/embedding_factory.py +3 -2
- kodit/infrastructure/embedding/local_vector_search_repository.py +1 -1
- kodit/infrastructure/embedding/vectorchord_vector_search_repository.py +111 -84
- kodit/infrastructure/enrichment/litellm_enrichment_provider.py +19 -26
- kodit/infrastructure/enrichment/local_enrichment_provider.py +41 -30
- kodit/infrastructure/indexing/fusion_service.py +1 -1
- kodit/infrastructure/mappers/git_mapper.py +193 -0
- kodit/infrastructure/mappers/snippet_mapper.py +106 -0
- kodit/infrastructure/mappers/task_mapper.py +5 -44
- kodit/infrastructure/mappers/task_status_mapper.py +85 -0
- kodit/infrastructure/reporting/db_progress.py +23 -0
- kodit/infrastructure/reporting/log_progress.py +13 -38
- kodit/infrastructure/reporting/telemetry_progress.py +21 -0
- kodit/infrastructure/slicing/slicer.py +32 -31
- kodit/infrastructure/sqlalchemy/embedding_repository.py +43 -23
- kodit/infrastructure/sqlalchemy/entities.py +428 -131
- kodit/infrastructure/sqlalchemy/git_branch_repository.py +263 -0
- kodit/infrastructure/sqlalchemy/git_commit_repository.py +337 -0
- kodit/infrastructure/sqlalchemy/git_repository.py +252 -0
- kodit/infrastructure/sqlalchemy/git_tag_repository.py +257 -0
- kodit/infrastructure/sqlalchemy/snippet_v2_repository.py +484 -0
- kodit/infrastructure/sqlalchemy/task_repository.py +29 -23
- kodit/infrastructure/sqlalchemy/task_status_repository.py +91 -0
- kodit/infrastructure/sqlalchemy/unit_of_work.py +10 -14
- kodit/mcp.py +12 -26
- kodit/migrations/env.py +1 -1
- kodit/migrations/versions/04b80f802e0c_foreign_key_review.py +100 -0
- kodit/migrations/versions/7f15f878c3a1_add_new_git_entities.py +690 -0
- kodit/migrations/versions/b9cd1c3fd762_add_task_status.py +77 -0
- kodit/migrations/versions/f9e5ef5e688f_add_git_commits_number.py +43 -0
- kodit/py.typed +0 -0
- kodit/utils/dump_openapi.py +7 -4
- kodit/utils/path_utils.py +29 -0
- {kodit-0.4.2.dist-info → kodit-0.5.0.dist-info}/METADATA +3 -3
- kodit-0.5.0.dist-info/RECORD +137 -0
- kodit/application/factories/code_indexing_factory.py +0 -193
- kodit/application/services/auto_indexing_service.py +0 -103
- kodit/application/services/code_indexing_application_service.py +0 -393
- kodit/domain/entities.py +0 -323
- kodit/domain/services/index_query_service.py +0 -70
- kodit/domain/services/index_service.py +0 -267
- kodit/infrastructure/api/client/index_client.py +0 -57
- kodit/infrastructure/api/v1/routers/indexes.py +0 -119
- kodit/infrastructure/api/v1/schemas/index.py +0 -101
- kodit/infrastructure/bm25/bm25_factory.py +0 -28
- kodit/infrastructure/cloning/__init__.py +0 -1
- kodit/infrastructure/cloning/metadata.py +0 -98
- kodit/infrastructure/mappers/index_mapper.py +0 -345
- kodit/infrastructure/reporting/tdqm_progress.py +0 -73
- kodit/infrastructure/slicing/language_detection_service.py +0 -18
- kodit/infrastructure/sqlalchemy/index_repository.py +0 -646
- kodit-0.4.2.dist-info/RECORD +0 -119
- {kodit-0.4.2.dist-info → kodit-0.5.0.dist-info}/WHEEL +0 -0
- {kodit-0.4.2.dist-info → kodit-0.5.0.dist-info}/entry_points.txt +0 -0
- {kodit-0.4.2.dist-info → kodit-0.5.0.dist-info}/licenses/LICENSE +0 -0
kodit/infrastructure/embedding/vectorchord_vector_search_repository.py (+111 -84)

```diff
@@ -1,10 +1,10 @@
 """VectorChord vector search repository implementation."""

-from collections.abc import AsyncGenerator
-from typing import
+from collections.abc import AsyncGenerator, Callable
+from typing import Literal

 import structlog
-from sqlalchemy import
+from sqlalchemy import text
 from sqlalchemy.ext.asyncio import AsyncSession

 from kodit.domain.services.embedding_service import (
@@ -19,6 +19,7 @@ from kodit.domain.value_objects import (
     SearchResult,
 )
 from kodit.infrastructure.sqlalchemy.entities import EmbeddingType
+from kodit.infrastructure.sqlalchemy.unit_of_work import SqlAlchemyUnitOfWork

 # SQL Queries
 CREATE_VCHORD_EXTENSION = """
@@ -72,6 +73,10 @@ CHECK_VCHORD_EMBEDDING_EXISTS = """
 SELECT EXISTS(SELECT 1 FROM {TABLE_NAME} WHERE snippet_id = :snippet_id)
 """

+CHECK_VCHORD_EMBEDDING_EXISTS_MULTIPLE = """
+SELECT snippet_id FROM {TABLE_NAME} WHERE snippet_id = ANY(:snippet_ids)
+"""
+
 TaskName = Literal["code", "text"]


@@ -80,8 +85,8 @@ class VectorChordVectorSearchRepository(VectorSearchRepository):

     def __init__(
         self,
+        session_factory: Callable[[], AsyncSession],
         task_name: TaskName,
-        session: AsyncSession,
         embedding_provider: EmbeddingProvider,
     ) -> None:
         """Initialize the VectorChord vector search repository.
@@ -93,7 +98,7 @@

         """
         self.embedding_provider = embedding_provider
-        self.
+        self.session_factory = session_factory
         self._initialized = False
         self.table_name = f"vectorchord_{task_name}_embeddings"
         self.index_name = f"{self.table_name}_idx"
@@ -111,12 +116,12 @@

     async def _create_extensions(self) -> None:
         """Create the necessary extensions."""
-
-
+        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
+            await session.execute(text(CREATE_VCHORD_EXTENSION))

     async def _create_tables(self) -> None:
         """Create the necessary tables."""
-        req = EmbeddingRequest(snippet_id=0, text="dimension")
+        req = EmbeddingRequest(snippet_id="0", text="dimension")
         vector_dim: list[float] | None = None
         async for batch in self.embedding_provider.embed([req]):
             if batch:
@@ -125,79 +130,85 @@
         if vector_dim is None:
             msg = "Failed to obtain embedding dimension from provider"
             raise RuntimeError(msg)
-
-
-
-
-
-
-
+        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
+            await session.execute(
+                text(
+                    f"""CREATE TABLE IF NOT EXISTS {self.table_name} (
+                        id SERIAL PRIMARY KEY,
+                        snippet_id VARCHAR(255) NOT NULL UNIQUE,
+                        embedding VECTOR({len(vector_dim)}) NOT NULL
+                    );"""
+                )
             )
-
-
-
-
-
+            await session.execute(
+                text(
+                    CREATE_VCHORD_INDEX.format(
+                        TABLE_NAME=self.table_name, INDEX_NAME=self.index_name
+                    )
                 )
             )
-
-
-
-
-        vector_dim_from_db = result.scalar_one()
-        if vector_dim_from_db != len(vector_dim):
-            msg = (
-                f"Embedding vector dimension does not match database, "
-                f"please delete your index: {vector_dim_from_db} != {len(vector_dim)}"
+            result = await session.execute(
+                text(
+                    CHECK_VCHORD_EMBEDDING_DIMENSION.format(TABLE_NAME=self.table_name)
+                )
             )
-
-
-
-
-
-
-
-        if not self._initialized:
-            await self._initialize()
-        return await self._session.execute(query, param_list)
-
-    async def _commit(self) -> None:
-        """Commit the session."""
-        await self._session.commit()
+            vector_dim_from_db = result.scalar_one()
+            if vector_dim_from_db != len(vector_dim):
+                msg = (
+                    f"Embedding vector dimension does not match database, please "
+                    f"delete your index: {vector_dim_from_db} != {len(vector_dim)}"
+                )
+                raise ValueError(msg)

     async def index_documents(
         self, request: IndexRequest
     ) -> AsyncGenerator[list[IndexResult], None]:
         """Index documents for vector search."""
+        if not self._initialized:
+            await self._initialize()
+
         if not request.documents:
             yield []

+        # Search for existing embeddings
+        existing_ids = await self._get_existing_ids(
+            [doc.snippet_id for doc in request.documents]
+        )
+        new_documents = [
+            doc for doc in request.documents if doc.snippet_id not in existing_ids
+        ]
+        if not new_documents:
+            self.log.info("No new documents to index")
+            return
+
         # Convert to embedding requests
-
+        embedding_requests = [
             EmbeddingRequest(snippet_id=doc.snippet_id, text=doc.text)
-            for doc in
+            for doc in new_documents
         ]

-        async for batch in self.embedding_provider.embed(
-
-
-
-
-
-
-
-
-
-
-
+        async for batch in self.embedding_provider.embed(embedding_requests):
+            async with SqlAlchemyUnitOfWork(self.session_factory) as session:
+                await session.execute(
+                    text(INSERT_QUERY.format(TABLE_NAME=self.table_name)),
+                    [
+                        {
+                            "snippet_id": result.snippet_id,
+                            "embedding": str(result.embedding),
+                        }
+                        for result in batch
+                    ],
+                )
+            yield [IndexResult(snippet_id=result.snippet_id) for result in batch]

     async def search(self, request: SearchRequest) -> list[SearchResult]:
         """Search documents using vector similarity."""
+        if not self._initialized:
+            await self._initialize()
         if not request.query or not request.query.strip():
             return []

-        req = EmbeddingRequest(snippet_id=0, text=request.query)
+        req = EmbeddingRequest(snippet_id="0", text=request.query)
         embedding_vec: list[float] | None = None
         async for batch in self.embedding_provider.embed([req]):
             if batch:
@@ -207,39 +218,55 @@
         if not embedding_vec:
             return []

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
+            # Use filtered query if snippet_ids are provided
+            if request.snippet_ids is not None:
+                result = await session.execute(
+                    text(SEARCH_QUERY_WITH_FILTER.format(TABLE_NAME=self.table_name)),
+                    {
+                        "query": str(embedding_vec),
+                        "top_k": request.top_k,
+                        "snippet_ids": request.snippet_ids,
+                    },
+                )
+            else:
+                result = await session.execute(
+                    text(SEARCH_QUERY.format(TABLE_NAME=self.table_name)),
+                    {"query": str(embedding_vec), "top_k": request.top_k},
+                )

-
+            rows = result.mappings().all()

-
-
-
-
+            return [
+                SearchResult(snippet_id=row["snippet_id"], score=row["score"])
+                for row in rows
+            ]

     async def has_embedding(
         self, snippet_id: int, embedding_type: EmbeddingType
     ) -> bool:
         """Check if a snippet has an embedding."""
+        if not self._initialized:
+            await self._initialize()
         # For VectorChord, we check if the snippet exists in the table
         # Note: embedding_type is ignored since VectorChord uses separate
         # tables per task
         # ruff: noqa: ARG002
-
-
-
-
-
+        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
+            result = await session.execute(
+                text(CHECK_VCHORD_EMBEDDING_EXISTS.format(TABLE_NAME=self.table_name)),
+                {"snippet_id": snippet_id},
+            )
+            return bool(result.scalar())
+
+    async def _get_existing_ids(self, snippet_ids: list[str]) -> set[str]:
+        async with SqlAlchemyUnitOfWork(self.session_factory) as session:
+            result = await session.execute(
+                text(
+                    CHECK_VCHORD_EMBEDDING_EXISTS_MULTIPLE.format(
+                        TABLE_NAME=self.table_name
+                    )
+                ),
+                {"snippet_ids": snippet_ids},
+            )
+            return {row[0] for row in result.fetchall()}
```
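The core change in this file: the repository no longer holds a single long-lived `AsyncSession`; it takes a `session_factory` and opens a short-lived unit of work around each operation, with initialization deferred to first use. A minimal sketch of that session-per-operation pattern, assuming a unit of work that commits on success and rolls back on error; the `UnitOfWork` class and `count_embeddings` helper below are illustrative stand-ins, not kodit's actual `SqlAlchemyUnitOfWork`:

```python
from collections.abc import Callable

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession


class UnitOfWork:
    """Illustrative stand-in for a SQLAlchemy unit of work."""

    def __init__(self, session_factory: Callable[[], AsyncSession]) -> None:
        self._session_factory = session_factory
        self._session: AsyncSession | None = None

    async def __aenter__(self) -> AsyncSession:
        # Open a fresh session scoped to this one operation.
        self._session = self._session_factory()
        return self._session

    async def __aexit__(self, exc_type, exc, tb) -> None:
        assert self._session is not None
        try:
            # Commit on success, roll back if the block raised.
            if exc_type is None:
                await self._session.commit()
            else:
                await self._session.rollback()
        finally:
            await self._session.close()


async def count_embeddings(session_factory: Callable[[], AsyncSession]) -> int:
    # Each call owns its session, so concurrent callers never share one.
    async with UnitOfWork(session_factory) as session:
        result = await session.execute(
            text("SELECT COUNT(*) FROM vectorchord_code_embeddings")
        )
        return result.scalar_one()
```

The same factory can be handed to several repositories without them contending for a shared session, which is what the `__init__` signature change above enables.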
kodit/infrastructure/enrichment/litellm_enrichment_provider.py (+19 -26)

```diff
@@ -128,32 +128,25 @@ class LiteLLMEnrichmentProvider(EnrichmentProvider):
                     snippet_id=request.snippet_id,
                     text="",
                 )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                )
-            except Exception as e:
-                self.log.exception("Error enriching request", error=str(e))
-                return EnrichmentResponse(
-                    snippet_id=request.snippet_id,
-                    text="",
-                )
+            messages = [
+                {
+                    "role": "system",
+                    "content": ENRICHMENT_SYSTEM_PROMPT,
+                },
+                {"role": "user", "content": request.text},
+            ]
+            response = await self._call_chat_completion(messages)
+            content = (
+                response.get("choices", [{}])[0]
+                .get("message", {})
+                .get("content", "")
+            )
+            # Remove thinking tags from the response
+            cleaned_content = clean_thinking_tags(content or "")
+            return EnrichmentResponse(
+                snippet_id=request.snippet_id,
+                text=cleaned_content,
+            )

         # Create tasks for all requests
         tasks = [process_request(request) for request in requests]
```
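The rewritten request handler drops the broad try/except and instead reads the completion defensively with chained `.get()` calls before stripping thinking tags. A sketch of that post-processing, assuming `clean_thinking_tags` is a simple regex strip over `<think>...</think>` blocks (kodit's actual helper may differ):

```python
import re

# Assumed re-implementation of clean_thinking_tags: drop the
# <think>...</think> blocks some reasoning models prepend to answers.
_THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)


def clean_thinking_tags(content: str) -> str:
    return _THINK_RE.sub("", content).strip()


# The .get() chain tolerates responses missing "choices" or "message"
# instead of raising KeyError; missing content falls back to "".
response = {"choices": [{"message": {"content": "<think>plan</think>Summary."}}]}
content = response.get("choices", [{}])[0].get("message", {}).get("content", "")
print(clean_thinking_tags(content or ""))  # -> "Summary."
```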
kodit/infrastructure/enrichment/local_enrichment_provider.py (+41 -30)

```diff
@@ -1,7 +1,9 @@
 """Local enrichment provider implementation."""

+import asyncio
 import os
 from collections.abc import AsyncGenerator
+from typing import Any

 import structlog
 import tiktoken
@@ -60,23 +62,26 @@ class LocalEnrichmentProvider(EnrichmentProvider):
             self.log.warning("No valid requests for enrichment")
             return

-
-
-
-        from transformers.models.auto.tokenization_auto import AutoTokenizer
-
-        if self.tokenizer is None:
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                self.model_name, padding_side="left"
-            )
-        if self.model is None:
-            os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Avoid warnings
-            self.model = AutoModelForCausalLM.from_pretrained(
-                self.model_name,
-                torch_dtype="auto",
-                trust_remote_code=True,
-                device_map="auto",
+        def _init_model() -> None:
+            from transformers.models.auto.modeling_auto import (
+                AutoModelForCausalLM,
             )
+            from transformers.models.auto.tokenization_auto import AutoTokenizer
+
+            if self.tokenizer is None:
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    self.model_name, padding_side="left"
+                )
+            if self.model is None:
+                os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Avoid warnings
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.model_name,
+                    torch_dtype="auto",
+                    trust_remote_code=True,
+                    device_map="auto",
+                )
+
+        await asyncio.to_thread(_init_model)

         # Prepare prompts
         prompts = [
@@ -96,20 +101,26 @@ class LocalEnrichmentProvider(EnrichmentProvider):
         ]

         for prompt in prompts:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+            def process_prompt(prompt: dict[str, Any]) -> str:
+                model_inputs = self.tokenizer(  # type: ignore[misc]
+                    prompt["text"],
+                    return_tensors="pt",
+                    padding=True,
+                    truncation=True,
+                ).to(self.model.device)  # type: ignore[attr-defined]
+                generated_ids = self.model.generate(  # type: ignore[attr-defined]
+                    **model_inputs, max_new_tokens=self.context_window
+                )
+                input_ids = model_inputs["input_ids"][0]
+                output_ids = generated_ids[0][len(input_ids) :].tolist()
+                return self.tokenizer.decode(  # type: ignore[attr-defined]
+                    output_ids, skip_special_tokens=True
+                ).strip(  # type: ignore[attr-defined]
+                    "\n"
+                )
+
+            content = await asyncio.to_thread(process_prompt, prompt)
             # Remove thinking tags from the response
             cleaned_content = clean_thinking_tags(content)
             yield EnrichmentResponse(
```
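Both the model load and per-prompt generation are synchronous Hugging Face calls, so the diff wraps them in `asyncio.to_thread` to keep the event loop responsive. A minimal sketch of the idea, with `blocking_generate` standing in for the real tokenizer/`model.generate()` work:

```python
import asyncio
import time


def blocking_generate(prompt: str) -> str:
    # Stand-in for tokenizer()/model.generate(), which block the thread.
    time.sleep(0.5)
    return prompt.upper()


async def main() -> None:
    # While a worker thread runs the blocking call, the event loop can
    # still schedule other coroutines (here, a concurrent sleep).
    ticker = asyncio.create_task(asyncio.sleep(0.1))
    result = await asyncio.to_thread(blocking_generate, "hello")
    await ticker
    print(result)  # -> "HELLO"


asyncio.run(main())
```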
kodit/infrastructure/mappers/git_mapper.py (+193 -0, new file)

```diff
@@ -0,0 +1,193 @@
+"""Mapping between domain Git entities and SQLAlchemy entities."""
+
+from collections import defaultdict
+from pathlib import Path
+
+from pydantic import AnyUrl
+
+import kodit.domain.entities.git as domain_git_entities
+from kodit.infrastructure.sqlalchemy import entities as db_entities
+
+
+class GitMapper:
+    """Mapper for converting between domain Git entities and database entities."""
+
+    def to_domain_commits(
+        self,
+        db_commits: list[db_entities.GitCommit],
+        db_commit_files: list[db_entities.GitCommitFile],
+    ) -> list[domain_git_entities.GitCommit]:
+        """Convert SQLAlchemy GitCommit to domain GitCommit."""
+        commit_files_map = defaultdict(list)
+        for file in db_commit_files:
+            commit_files_map[file.commit_sha].append(file.blob_sha)
+
+        commit_domain_files_map = defaultdict(list)
+        for file in db_commit_files:
+            commit_domain_files_map[file.commit_sha].append(
+                domain_git_entities.GitFile(
+                    created_at=file.created_at,
+                    blob_sha=file.blob_sha,
+                    path=file.path,
+                    mime_type=file.mime_type,
+                    size=file.size,
+                    extension=file.extension,
+                )
+            )
+
+        domain_commits = []
+        for db_commit in db_commits:
+            domain_commit = domain_git_entities.GitCommit(
+                created_at=db_commit.created_at,
+                updated_at=db_commit.updated_at,
+                commit_sha=db_commit.commit_sha,
+                date=db_commit.date,
+                message=db_commit.message,
+                parent_commit_sha=db_commit.parent_commit_sha,
+                files=commit_domain_files_map[db_commit.commit_sha],
+                author=db_commit.author,
+            )
+            domain_commits.append(domain_commit)
+        return domain_commits
+
+    def to_domain_branches(
+        self,
+        db_branches: list[db_entities.GitBranch],
+        domain_commits: list[domain_git_entities.GitCommit],
+    ) -> list[domain_git_entities.GitBranch]:
+        """Convert SQLAlchemy GitBranch to domain GitBranch."""
+        commit_map = {commit.commit_sha: commit for commit in domain_commits}
+        domain_branches = []
+        for db_branch in db_branches:
+            if db_branch.head_commit_sha not in commit_map:
+                raise ValueError(
+                    f"Commit {db_branch.head_commit_sha} for "
+                    f"branch {db_branch.name} not found in commits: {commit_map.keys()}"
+                )
+            domain_branch = domain_git_entities.GitBranch(
+                repo_id=db_branch.repo_id,
+                name=db_branch.name,
+                created_at=db_branch.created_at,
+                updated_at=db_branch.updated_at,
+                head_commit=commit_map[db_branch.head_commit_sha],
+            )
+            domain_branches.append(domain_branch)
+        return domain_branches
+
+    def to_domain_tags(
+        self,
+        db_tags: list[db_entities.GitTag],
+        domain_commits: list[domain_git_entities.GitCommit],
+    ) -> list[domain_git_entities.GitTag]:
+        """Convert SQLAlchemy GitTag to domain GitTag."""
+        commit_map = {commit.commit_sha: commit for commit in domain_commits}
+        domain_tags = []
+        for db_tag in db_tags:
+            if db_tag.target_commit_sha not in commit_map:
+                raise ValueError(
+                    f"Commit {db_tag.target_commit_sha} for tag {db_tag.name} not found"
+                )
+            domain_tag = domain_git_entities.GitTag(
+                created_at=db_tag.created_at,
+                updated_at=db_tag.updated_at,
+                repo_id=db_tag.repo_id,
+                name=db_tag.name,
+                target_commit=commit_map[db_tag.target_commit_sha],
+            )
+            domain_tags.append(domain_tag)
+        return domain_tags
+
+    def to_domain_tracking_branch(
+        self,
+        db_tracking_branch: db_entities.GitTrackingBranch | None,
+        db_tracking_branch_entity: db_entities.GitBranch | None,
+        domain_commits: list[domain_git_entities.GitCommit],
+    ) -> domain_git_entities.GitBranch | None:
+        """Convert SQLAlchemy GitTrackingBranch to domain GitBranch."""
+        if db_tracking_branch is None or db_tracking_branch_entity is None:
+            return None
+
+        commit_map = {commit.commit_sha: commit for commit in domain_commits}
+        if db_tracking_branch_entity.head_commit_sha not in commit_map:
+            raise ValueError(
+                f"Commit {db_tracking_branch_entity.head_commit_sha} for "
+                f"tracking branch {db_tracking_branch.name} not found"
+            )
+
+        return domain_git_entities.GitBranch(
+            repo_id=db_tracking_branch_entity.repo_id,
+            name=db_tracking_branch_entity.name,
+            created_at=db_tracking_branch_entity.created_at,
+            updated_at=db_tracking_branch_entity.updated_at,
+            head_commit=commit_map[db_tracking_branch_entity.head_commit_sha],
+        )
+
+    def to_domain_git_repo(  # noqa: PLR0913
+        self,
+        db_repo: db_entities.GitRepo,
+        db_tracking_branch_entity: db_entities.GitBranch | None,
+        db_commits: list[db_entities.GitCommit],
+        db_tags: list[db_entities.GitTag],
+        db_commit_files: list[db_entities.GitCommitFile],
+        db_tracking_branch: db_entities.GitTrackingBranch | None,
+    ) -> domain_git_entities.GitRepo:
+        """Convert SQLAlchemy GitRepo to domain GitRepo."""
+        # Build commits needed for tags and tracking branch
+        domain_commits = self.to_domain_commits(
+            db_commits=db_commits, db_commit_files=db_commit_files
+        )
+        self.to_domain_tags(
+            db_tags=db_tags, domain_commits=domain_commits
+        )
+        tracking_branch = self.to_domain_tracking_branch(
+            db_tracking_branch=db_tracking_branch,
+            db_tracking_branch_entity=db_tracking_branch_entity,
+            domain_commits=domain_commits,
+        )
+
+        from kodit.domain.factories.git_repo_factory import GitRepoFactory
+
+        return GitRepoFactory.create_from_components(
+            repo_id=db_repo.id,
+            created_at=db_repo.created_at,
+            updated_at=db_repo.updated_at,
+            sanitized_remote_uri=AnyUrl(db_repo.sanitized_remote_uri),
+            remote_uri=AnyUrl(db_repo.remote_uri),
+            tracking_branch=tracking_branch,
+            cloned_path=Path(db_repo.cloned_path) if db_repo.cloned_path else None,
+            last_scanned_at=db_repo.last_scanned_at,
+            num_commits=db_repo.num_commits,
+            num_branches=db_repo.num_branches,
+            num_tags=db_repo.num_tags,
+        )
+
+    def to_domain_commit_index(
+        self,
+        db_commit_index: db_entities.CommitIndex,
+        snippets: list[domain_git_entities.SnippetV2],
+    ) -> domain_git_entities.CommitIndex:
+        """Convert SQLAlchemy CommitIndex to domain CommitIndex."""
+        return domain_git_entities.CommitIndex(
+            commit_sha=db_commit_index.commit_sha,
+            created_at=db_commit_index.created_at,
+            updated_at=db_commit_index.updated_at,
+            snippets=snippets,
+            status=domain_git_entities.IndexStatus(db_commit_index.status),
+            indexed_at=db_commit_index.indexed_at,
+            error_message=db_commit_index.error_message,
+            files_processed=db_commit_index.files_processed,
+            processing_time_seconds=float(db_commit_index.processing_time_seconds),
+        )
+
+    def from_domain_commit_index(
+        self, domain_commit_index: domain_git_entities.CommitIndex
+    ) -> db_entities.CommitIndex:
+        """Convert domain CommitIndex to SQLAlchemy CommitIndex."""
+        return db_entities.CommitIndex(
+            commit_sha=domain_commit_index.commit_sha,
+            status=domain_commit_index.status,
+            indexed_at=domain_commit_index.indexed_at,
+            error_message=domain_commit_index.error_message,
+            files_processed=domain_commit_index.files_processed,
+            processing_time_seconds=domain_commit_index.processing_time_seconds,
+        )
```