kodit 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- kodit/_version.py +2 -2
- kodit/cli.py +105 -19
- kodit/embedding/embedding_factory.py +2 -2
- kodit/embedding/embedding_provider/embedding_provider.py +9 -2
- kodit/embedding/embedding_provider/openai_embedding_provider.py +19 -7
- kodit/embedding/vectorchord_vector_search_service.py +26 -15
- kodit/enrichment/__init__.py +1 -0
- kodit/enrichment/enrichment_factory.py +23 -0
- kodit/enrichment/enrichment_provider/__init__.py +1 -0
- kodit/enrichment/enrichment_provider/enrichment_provider.py +16 -0
- kodit/enrichment/enrichment_provider/local_enrichment_provider.py +63 -0
- kodit/enrichment/enrichment_provider/openai_enrichment_provider.py +77 -0
- kodit/enrichment/enrichment_service.py +33 -0
- kodit/indexing/fusion.py +67 -0
- kodit/indexing/indexing_repository.py +41 -23
- kodit/indexing/indexing_service.py +128 -8
- kodit/mcp.py +25 -16
- kodit/migrations/versions/c3f5137d30f5_index_all_the_things.py +44 -0
- kodit/snippets/languages/go.scm +26 -0
- kodit/source/source_models.py +4 -4
- kodit-0.1.17.dist-info/METADATA +152 -0
- {kodit-0.1.15.dist-info → kodit-0.1.17.dist-info}/RECORD +25 -18
- kodit/search/__init__.py +0 -1
- kodit/search/search_repository.py +0 -57
- kodit/search/search_service.py +0 -135
- kodit-0.1.15.dist-info/METADATA +0 -89
- {kodit-0.1.15.dist-info → kodit-0.1.17.dist-info}/WHEEL +0 -0
- {kodit-0.1.15.dist-info → kodit-0.1.17.dist-info}/entry_points.txt +0 -0
- {kodit-0.1.15.dist-info → kodit-0.1.17.dist-info}/licenses/LICENSE +0 -0
kodit/_version.py
CHANGED
kodit/cli.py
CHANGED
@@ -17,11 +17,10 @@ from kodit.config import (
     with_session,
 )
 from kodit.embedding.embedding_factory import embedding_factory
+from kodit.enrichment.enrichment_factory import enrichment_factory
 from kodit.indexing.indexing_repository import IndexRepository
-from kodit.indexing.indexing_service import IndexService
+from kodit.indexing.indexing_service import IndexService, SearchRequest
 from kodit.log import configure_logging, configure_telemetry, log_event
-from kodit.search.search_repository import SearchRepository
-from kodit.search.search_service import SearchRequest, SearchService
 from kodit.source.source_repository import SourceRepository
 from kodit.source.source_service import SourceService
 
@@ -72,9 +71,13 @@ async def index(
         repository=repository,
         source_service=source_service,
         keyword_search_provider=keyword_search_factory(app_context, session),
-
-            app_context=app_context, session=session
+        code_search_service=embedding_factory(
+            task_name="code", app_context=app_context, session=session
         ),
+        text_search_service=embedding_factory(
+            task_name="text", app_context=app_context, session=session
+        ),
+        enrichment_service=enrichment_factory(app_context),
     )
 
     if not sources:
@@ -131,11 +134,20 @@ async def code(
 
     This works best if your query is code.
     """
-
-
-
+    source_repository = SourceRepository(session)
+    source_service = SourceService(app_context.get_clone_dir(), source_repository)
+    repository = IndexRepository(session)
+    service = IndexService(
+        repository=repository,
+        source_service=source_service,
         keyword_search_provider=keyword_search_factory(app_context, session),
-
+        code_search_service=embedding_factory(
+            task_name="code", app_context=app_context, session=session
+        ),
+        text_search_service=embedding_factory(
+            task_name="text", app_context=app_context, session=session
+        ),
+        enrichment_service=enrichment_factory(app_context),
     )
 
     snippets = await service.search(SearchRequest(code_query=query, top_k=top_k))
@@ -147,6 +159,7 @@ async def code(
     for snippet in snippets:
         click.echo("-" * 80)
         click.echo(f"{snippet.uri}")
+        click.echo(f"Original scores: {snippet.original_scores}")
         click.echo(snippet.content)
         click.echo("-" * 80)
         click.echo()
@@ -164,11 +177,20 @@ async def keyword(
     top_k: int,
 ) -> None:
     """Search for snippets using keyword search."""
-
-
-
+    source_repository = SourceRepository(session)
+    source_service = SourceService(app_context.get_clone_dir(), source_repository)
+    repository = IndexRepository(session)
+    service = IndexService(
+        repository=repository,
+        source_service=source_service,
         keyword_search_provider=keyword_search_factory(app_context, session),
-
+        code_search_service=embedding_factory(
+            task_name="code", app_context=app_context, session=session
+        ),
+        text_search_service=embedding_factory(
+            task_name="text", app_context=app_context, session=session
+        ),
+        enrichment_service=enrichment_factory(app_context),
     )
 
     snippets = await service.search(SearchRequest(keywords=keywords, top_k=top_k))
@@ -180,6 +202,53 @@ async def keyword(
     for snippet in snippets:
         click.echo("-" * 80)
         click.echo(f"{snippet.uri}")
+        click.echo(f"Original scores: {snippet.original_scores}")
+        click.echo(snippet.content)
+        click.echo("-" * 80)
+        click.echo()
+
+
+@search.command()
+@click.argument("query")
+@click.option("--top-k", default=10, help="Number of snippets to retrieve")
+@with_app_context
+@with_session
+async def text(
+    session: AsyncSession,
+    app_context: AppContext,
+    query: str,
+    top_k: int,
+) -> None:
+    """Search for snippets using semantic text search.
+
+    This works best if your query is text.
+    """
+    source_repository = SourceRepository(session)
+    source_service = SourceService(app_context.get_clone_dir(), source_repository)
+    repository = IndexRepository(session)
+    service = IndexService(
+        repository=repository,
+        source_service=source_service,
+        keyword_search_provider=keyword_search_factory(app_context, session),
+        code_search_service=embedding_factory(
+            task_name="code", app_context=app_context, session=session
+        ),
+        text_search_service=embedding_factory(
+            task_name="text", app_context=app_context, session=session
+        ),
+        enrichment_service=enrichment_factory(app_context),
+    )
+
+    snippets = await service.search(SearchRequest(text_query=query, top_k=top_k))
+
+    if len(snippets) == 0:
+        click.echo("No snippets found")
+        return
+
+    for snippet in snippets:
+        click.echo("-" * 80)
+        click.echo(f"{snippet.uri}")
+        click.echo(f"Original scores: {snippet.original_scores}")
         click.echo(snippet.content)
         click.echo("-" * 80)
         click.echo()
@@ -189,28 +258,44 @@ async def keyword(
 @click.option("--top-k", default=10, help="Number of snippets to retrieve")
 @click.option("--keywords", required=True, help="Comma separated list of keywords")
 @click.option("--code", required=True, help="Semantic code search query")
+@click.option("--text", required=True, help="Semantic text search query")
 @with_app_context
 @with_session
-async def hybrid(
+async def hybrid(  # noqa: PLR0913
     session: AsyncSession,
     app_context: AppContext,
     top_k: int,
     keywords: str,
     code: str,
+    text: str,
 ) -> None:
     """Search for snippets using hybrid search."""
-
-
-
+    source_repository = SourceRepository(session)
+    source_service = SourceService(app_context.get_clone_dir(), source_repository)
+    repository = IndexRepository(session)
+    service = IndexService(
+        repository=repository,
+        source_service=source_service,
         keyword_search_provider=keyword_search_factory(app_context, session),
-
+        code_search_service=embedding_factory(
+            task_name="code", app_context=app_context, session=session
+        ),
+        text_search_service=embedding_factory(
+            task_name="text", app_context=app_context, session=session
+        ),
+        enrichment_service=enrichment_factory(app_context),
    )
 
     # Parse keywords into a list of strings
     keywords_list = [k.strip().lower() for k in keywords.split(",")]
 
     snippets = await service.search(
-        SearchRequest(
+        SearchRequest(
+            text_query=text,
+            keywords=keywords_list,
+            code_query=code,
+            top_k=top_k,
+        )
     )
 
     if len(snippets) == 0:
@@ -220,6 +305,7 @@ async def hybrid(
     for snippet in snippets:
         click.echo("-" * 80)
         click.echo(f"{snippet.uri}")
+        click.echo(f"Original scores: {snippet.original_scores}")
         click.echo(snippet.content)
         click.echo("-" * 80)
         click.echo()
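
Taken together, the cli.py changes add a `text` subcommand and extend `hybrid` with a required `--text` option. Assuming the installed entry point is `kodit` (the decorators place both commands under the `search` group), invocation would look like:

kodit search text "parse a yaml file" --top-k 5
kodit search hybrid --keywords "yaml,parse" --code "yaml.safe_load(f)" --text "parse a yaml file"
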

kodit/embedding/embedding_factory.py
CHANGED
@@ -21,7 +21,7 @@ from kodit.embedding.vectorchord_vector_search_service import (
 
 
 def embedding_factory(
-    app_context: AppContext, session: AsyncSession
+    task_name: str, app_context: AppContext, session: AsyncSession
 ) -> VectorSearchService:
     """Create an embedding service."""
     embedding_repository = EmbeddingRepository(session=session)
@@ -33,7 +33,7 @@ def embedding_factory(
         embedding_provider = LocalEmbeddingProvider(CODE)
 
     if app_context.default_search.provider == "vectorchord":
-        return VectorChordVectorSearchService(session, embedding_provider)
+        return VectorChordVectorSearchService(task_name, session, embedding_provider)
     if app_context.default_search.provider == "sqlite":
         return LocalVectorSearchService(
             embedding_repository=embedding_repository,
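
The factory's new `task_name` parameter is threaded through to the vector backend so that code and text embeddings live in separate stores. A sketch of the resulting wiring, assuming an `app_context` and `session` have been constructed elsewhere:

from kodit.embedding.embedding_factory import embedding_factory

# One vector search service per embedding task; with the vectorchord backend
# these map to the tables vectorchord_code_embeddings and
# vectorchord_text_embeddings respectively.
code_search = embedding_factory(task_name="code", app_context=app_context, session=session)
text_search = embedding_factory(task_name="text", app_context=app_context, session=session)
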

kodit/embedding/embedding_provider/embedding_provider.py
CHANGED
@@ -38,8 +38,15 @@ def split_sub_batches(encoding: tiktoken.Encoding, data: list[str]) -> list[list
         item_tokens = len(encoding.encode(next_item))
 
         if item_tokens > OPENAI_MAX_EMBEDDING_SIZE:
-
-
+            # Loop around trying to truncate the snippet until it fits in the max
+            # embedding size
+            while item_tokens > OPENAI_MAX_EMBEDDING_SIZE:
+                next_item = next_item[:-1]
+                item_tokens = len(encoding.encode(next_item))
+
+            data_to_process[0] = next_item
+
+            log.warning("Truncated snippet", snippet=next_item)
 
         if current_tokens + item_tokens > OPENAI_MAX_EMBEDDING_SIZE:
             break
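
The new branch truncates an oversized snippet one character at a time, re-encoding after each cut, instead of dropping it. A self-contained sketch of the same idea (the budget value here is a hypothetical stand-in for OPENAI_MAX_EMBEDDING_SIZE):

import tiktoken

MAX_TOKENS = 8192  # hypothetical budget

def truncate_to_budget(text: str, encoding: tiktoken.Encoding) -> str:
    # Drop trailing characters until the text encodes to at most MAX_TOKENS tokens.
    while len(encoding.encode(text)) > MAX_TOKENS:
        text = text[:-1]
    return text

encoding = tiktoken.get_encoding("cl100k_base")
snippet = "print('hello world')\n" * 10_000
print(len(encoding.encode(truncate_to_budget(snippet, encoding))))

A cheaper variant would encode once and slice the token list, e.g. encoding.decode(encoding.encode(text)[:MAX_TOKENS]), rather than re-encoding on every iteration.
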

kodit/embedding/embedding_provider/openai_embedding_provider.py
CHANGED
@@ -38,26 +38,38 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
         # Process batches in parallel with a semaphore to limit concurrent requests
         sem = asyncio.Semaphore(OPENAI_NUM_PARALLEL_TASKS)
 
-
+        # Create a list of tuples with a temporary id for each batch
+        # We need to do this so that we can return the results in the same order as the
+        # input data
+        input_data = [(i, batch) for i, batch in enumerate(batched_data)]
+
+        async def process_batch(
+            data: tuple[int, list[str]],
+        ) -> tuple[int, list[Vector]]:
+            batch_id, batch = data
             async with sem:
                 try:
                     response = await self.openai_client.embeddings.create(
                         model=self.model_name,
                         input=batch,
                     )
-                    return [
+                    return batch_id, [
                         [float(x) for x in embedding.embedding]
                         for embedding in response.data
                     ]
                 except Exception as e:
                     self.log.exception("Error embedding batch", error=str(e))
-                    return []
+                    return batch_id, []
 
         # Create tasks for all batches
-        tasks = [process_batch(batch) for batch in
+        tasks = [process_batch(batch) for batch in input_data]
 
         # Process all batches and yield results as they complete
-        results: list[Vector] = []
+        results: list[tuple[int, list[Vector]]] = []
         for task in asyncio.as_completed(tasks):
-
-
+            result = await task
+            results.append(result)
+
+        # Output in the same order as the input data
+        ordered_results = [result for _, result in sorted(results, key=lambda x: x[0])]
+        return [item for sublist in ordered_results for item in sublist]
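
The batch-id tuples exist because asyncio.as_completed yields results in completion order, not submission order; sorting on the temporary id afterwards restores the input order. A stripped-down, runnable sketch of the pattern:

import asyncio
import random

async def work(item: tuple[int, str]) -> tuple[int, str]:
    idx, value = item
    await asyncio.sleep(random.random() / 10)  # tasks finish in arbitrary order
    return idx, value.upper()

async def main() -> None:
    inputs = list(enumerate(["alpha", "beta", "gamma", "delta"]))
    results: list[tuple[int, str]] = []
    for task in asyncio.as_completed([work(item) for item in inputs]):
        results.append(await task)
    # Re-sort by the temporary id to restore input order.
    print([value for _, value in sorted(results, key=lambda x: x[0])])

asyncio.run(main())
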

kodit/embedding/vectorchord_vector_search_service.py
CHANGED
@@ -12,23 +12,20 @@ from kodit.embedding.vector_search_service import (
     VectorSearchService,
 )
 
-TABLE_NAME = "vectorchord_embeddings"
-INDEX_NAME = f"{TABLE_NAME}_idx"
-
 # SQL Queries
 CREATE_VCHORD_EXTENSION = """
 CREATE EXTENSION IF NOT EXISTS vchord CASCADE;
 """
 
-CHECK_VCHORD_EMBEDDING_DIMENSION =
+CHECK_VCHORD_EMBEDDING_DIMENSION = """
 SELECT a.atttypmod as dimension
 FROM pg_attribute a
 JOIN pg_class c ON a.attrelid = c.oid
 WHERE c.relname = '{TABLE_NAME}'
 AND a.attname = 'embedding';
-"""
+"""
 
-CREATE_VCHORD_INDEX =
+CREATE_VCHORD_INDEX = """
 CREATE INDEX IF NOT EXISTS {INDEX_NAME}
 ON {TABLE_NAME}
 USING vchordrq (embedding vector_l2_ops) WITH (options = $$
@@ -38,21 +35,21 @@ lists = []
 $$);
 """
 
-INSERT_QUERY =
+INSERT_QUERY = """
 INSERT INTO {TABLE_NAME} (snippet_id, embedding)
 VALUES (:snippet_id, :embedding)
 ON CONFLICT (snippet_id) DO UPDATE
 SET embedding = EXCLUDED.embedding
-"""
+"""
 
 # Note that <=> in vectorchord is cosine distance
 # So scores go from 0 (similar) to 2 (opposite)
-SEARCH_QUERY =
+SEARCH_QUERY = """
 SELECT snippet_id, embedding <=> :query as score
 FROM {TABLE_NAME}
 ORDER BY score ASC
 LIMIT :top_k;
-"""
+"""
 
 
 class VectorChordVectorSearchService(VectorSearchService):
@@ -60,6 +57,7 @@ class VectorChordVectorSearchService(VectorSearchService):
 
     def __init__(
         self,
+        task_name: str,
         session: AsyncSession,
         embedding_provider: EmbeddingProvider,
     ) -> None:
@@ -67,6 +65,8 @@ class VectorChordVectorSearchService(VectorSearchService):
         self.embedding_provider = embedding_provider
         self._session = session
         self._initialized = False
+        self.table_name = f"vectorchord_{task_name}_embeddings"
+        self.index_name = f"{self.table_name}_idx"
 
     async def _initialize(self) -> None:
         """Initialize the VectorChord environment."""
@@ -88,15 +88,23 @@ class VectorChordVectorSearchService(VectorSearchService):
         vector_dim = (await self.embedding_provider.embed(["dimension"]))[0]
         await self._session.execute(
             text(
-                f"""CREATE TABLE IF NOT EXISTS {
+                f"""CREATE TABLE IF NOT EXISTS {self.table_name} (
                 id SERIAL PRIMARY KEY,
                 snippet_id INT NOT NULL UNIQUE,
                 embedding VECTOR({len(vector_dim)}) NOT NULL
                 );"""
             )
         )
-        await self._session.execute(
-
+        await self._session.execute(
+            text(
+                CREATE_VCHORD_INDEX.format(
+                    TABLE_NAME=self.table_name, INDEX_NAME=self.index_name
+                )
+            )
+        )
+        result = await self._session.execute(
+            text(CHECK_VCHORD_EMBEDDING_DIMENSION.format(TABLE_NAME=self.table_name))
+        )
         vector_dim_from_db = result.scalar_one()
         if vector_dim_from_db != len(vector_dim):
             msg = (
@@ -123,7 +131,7 @@ class VectorChordVectorSearchService(VectorSearchService):
         embeddings = await self.embedding_provider.embed([doc.text for doc in data])
         # Execute inserts
         await self._execute(
-            text(INSERT_QUERY),
+            text(INSERT_QUERY.format(TABLE_NAME=self.table_name)),
             [
                 {"snippet_id": doc.snippet_id, "embedding": str(embedding)}
                 for doc, embedding in zip(data, embeddings, strict=True)
@@ -134,8 +142,11 @@ class VectorChordVectorSearchService(VectorSearchService):
     async def retrieve(self, query: str, top_k: int = 10) -> list[VectorSearchResponse]:
         """Query the embedding model."""
         embedding = await self.embedding_provider.embed([query])
+        if len(embedding) == 0 or len(embedding[0]) == 0:
+            return []
         result = await self._execute(
-            text(SEARCH_QUERY
+            text(SEARCH_QUERY.format(TABLE_NAME=self.table_name)),
+            {"query": str(embedding[0]), "top_k": top_k},
         )
         rows = result.mappings().all()
 
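
With the module-level TABLE_NAME and INDEX_NAME constants gone, each service instance derives its own names from task_name and renders the SQL templates with .format() before execution. A small runnable sketch of the naming scheme, reusing SEARCH_QUERY from the diff:

task_name = "code"
table_name = f"vectorchord_{task_name}_embeddings"  # vectorchord_code_embeddings
index_name = f"{table_name}_idx"                    # vectorchord_code_embeddings_idx

SEARCH_QUERY = """
SELECT snippet_id, embedding <=> :query as score
FROM {TABLE_NAME}
ORDER BY score ASC
LIMIT :top_k;
"""
print(SEARCH_QUERY.format(TABLE_NAME=table_name))
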

kodit/enrichment/__init__.py
ADDED
@@ -0,0 +1 @@
+"""Enrichment."""

kodit/enrichment/enrichment_factory.py
ADDED
@@ -0,0 +1,23 @@
+"""Embedding service."""
+
+from kodit.config import AppContext
+from kodit.enrichment.enrichment_provider.local_enrichment_provider import (
+    LocalEnrichmentProvider,
+)
+from kodit.enrichment.enrichment_provider.openai_enrichment_provider import (
+    OpenAIEnrichmentProvider,
+)
+from kodit.enrichment.enrichment_service import (
+    EnrichmentService,
+    LLMEnrichmentService,
+)
+
+
+def enrichment_factory(app_context: AppContext) -> EnrichmentService:
+    """Create an embedding service."""
+    openai_client = app_context.get_default_openai_client()
+    if openai_client is not None:
+        enrichment_provider = OpenAIEnrichmentProvider(openai_client=openai_client)
+        return LLMEnrichmentService(enrichment_provider)
+
+    return LLMEnrichmentService(LocalEnrichmentProvider())
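
The factory prefers OpenAI whenever a client is configured and otherwise falls back to the local model; note that NullEnrichmentService is never selected here. A usage sketch, assuming an app_context constructed elsewhere (e.g. via kodit.config):

import asyncio

from kodit.enrichment.enrichment_factory import enrichment_factory

async def main() -> None:
    enrichment = enrichment_factory(app_context)  # app_context assumed to exist
    # OpenAI-backed if app_context.get_default_openai_client() returns a client,
    # otherwise backed by the local Qwen/Qwen3-0.6B model.
    print(await enrichment.enrich(["def add(a, b):\n    return a + b"]))
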

kodit/enrichment/enrichment_provider/__init__.py
ADDED
@@ -0,0 +1 @@
+"""Enrichment provider."""

kodit/enrichment/enrichment_provider/enrichment_provider.py
ADDED
@@ -0,0 +1,16 @@
+"""Enrichment provider."""
+
+from abc import ABC, abstractmethod
+
+ENRICHMENT_SYSTEM_PROMPT = """
+You are a professional software developer. You will be given a snippet of code.
+Please provide a concise explanation of the code.
+"""
+
+
+class EnrichmentProvider(ABC):
+    """Enrichment provider."""
+
+    @abstractmethod
+    async def enrich(self, data: list[str]) -> list[str]:
+        """Enrich a list of strings."""
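
Any backend can plug in by subclassing this ABC. A minimal hypothetical provider that satisfies the interface without calling a model:

from kodit.enrichment.enrichment_provider.enrichment_provider import EnrichmentProvider

class EchoEnrichmentProvider(EnrichmentProvider):
    """Hypothetical provider that returns a canned explanation per snippet."""

    async def enrich(self, data: list[str]) -> list[str]:
        """Enrich a list of strings."""
        return [f"A code snippet of {len(snippet)} characters." for snippet in data]
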

kodit/enrichment/enrichment_provider/local_enrichment_provider.py
ADDED
@@ -0,0 +1,63 @@
+"""Local embedding service."""
+
+import os
+
+import structlog
+from transformers.models.auto.modeling_auto import AutoModelForCausalLM
+from transformers.models.auto.tokenization_auto import AutoTokenizer
+
+from kodit.enrichment.enrichment_provider.enrichment_provider import (
+    ENRICHMENT_SYSTEM_PROMPT,
+    EnrichmentProvider,
+)
+
+
+class LocalEnrichmentProvider(EnrichmentProvider):
+    """Local embedder."""
+
+    def __init__(self, model_name: str = "Qwen/Qwen3-0.6B") -> None:
+        """Initialize the local enrichment provider."""
+        self.log = structlog.get_logger(__name__)
+        self.model_name = model_name
+        self.model = None
+        self.tokenizer = None
+
+    async def enrich(self, data: list[str]) -> list[str]:
+        """Enrich a list of strings."""
+        if self.tokenizer is None:
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        if self.model is None:
+            os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Avoid warnings
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.model_name,
+                torch_dtype="auto",
+                trust_remote_code=True,
+            )
+
+        results = []
+        for snippet in data:
+            # prepare the model input
+            messages = [
+                {"role": "system", "content": ENRICHMENT_SYSTEM_PROMPT},
+                {"role": "user", "content": snippet},
+            ]
+            text = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+                enable_thinking=False,
+            )
+            model_inputs = self.tokenizer([text], return_tensors="pt").to(
+                self.model.device
+            )
+
+            # conduct text completion
+            generated_ids = self.model.generate(**model_inputs, max_new_tokens=32768)
+            output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
+            content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip(
+                "\n"
+            )
+
+            results.append(content)
+
+        return results
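
The tokenizer and model are loaded lazily on the first enrich call, so constructing the provider is cheap. A usage sketch (the first call downloads Qwen/Qwen3-0.6B from the Hugging Face hub):

import asyncio

from kodit.enrichment.enrichment_provider.local_enrichment_provider import (
    LocalEnrichmentProvider,
)

async def main() -> None:
    provider = LocalEnrichmentProvider()  # no model loaded yet
    [summary] = await provider.enrich(["def add(a, b):\n    return a + b"])
    print(summary)

asyncio.run(main())
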

kodit/enrichment/enrichment_provider/openai_enrichment_provider.py
ADDED
@@ -0,0 +1,77 @@
+"""OpenAI embedding service."""
+
+import asyncio
+
+import structlog
+import tiktoken
+from openai import AsyncOpenAI
+from tqdm import tqdm
+
+from kodit.enrichment.enrichment_provider.enrichment_provider import (
+    ENRICHMENT_SYSTEM_PROMPT,
+    EnrichmentProvider,
+)
+
+OPENAI_NUM_PARALLEL_TASKS = 10
+
+
+class OpenAIEnrichmentProvider(EnrichmentProvider):
+    """OpenAI enrichment provider."""
+
+    def __init__(
+        self,
+        openai_client: AsyncOpenAI,
+        model_name: str = "gpt-4o-mini",
+    ) -> None:
+        """Initialize the OpenAI enrichment provider."""
+        self.log = structlog.get_logger(__name__)
+        self.openai_client = openai_client
+        self.model_name = model_name
+        self.encoding = tiktoken.encoding_for_model(model_name)
+
+    async def enrich(self, data: list[str]) -> list[str]:
+        """Enrich a list of documents."""
+        # Process batches in parallel with a semaphore to limit concurrent requests
+        sem = asyncio.Semaphore(OPENAI_NUM_PARALLEL_TASKS)
+
+        # Create a list of tuples with a temporary id for each snippet
+        # We need to do this so that we can return the results in the same order as the
+        # input data
+        input_data = [(i, snippet) for i, snippet in enumerate(data)]
+
+        async def process_data(data: tuple[int, str]) -> tuple[int, str]:
+            snippet_id, snippet = data
+            if not snippet:
+                return snippet_id, ""
+            async with sem:
+                try:
+                    response = await self.openai_client.chat.completions.create(
+                        model=self.model_name,
+                        messages=[
+                            {
+                                "role": "system",
+                                "content": ENRICHMENT_SYSTEM_PROMPT,
+                            },
+                            {"role": "user", "content": snippet},
+                        ],
+                    )
+                    return snippet_id, response.choices[0].message.content or ""
+                except Exception as e:
+                    self.log.exception("Error enriching data", error=str(e))
+                    return snippet_id, ""
+
+        # Create tasks for all data
+        tasks = [process_data(snippet) for snippet in input_data]
+
+        # Process all data and yield results as they complete
+        results: list[tuple[int, str]] = []
+        for task in tqdm(
+            asyncio.as_completed(tasks),
+            total=len(tasks),
+            leave=False,
+        ):
+            result = await task
+            results.append(result)
+
+        # Output in the same order as the input data
+        return [result for _, result in sorted(results, key=lambda x: x[0])]

kodit/enrichment/enrichment_service.py
ADDED
@@ -0,0 +1,33 @@
+"""Enrichment service."""
+
+from abc import ABC, abstractmethod
+
+from kodit.enrichment.enrichment_provider.enrichment_provider import EnrichmentProvider
+
+
+class EnrichmentService(ABC):
+    """Enrichment service."""
+
+    @abstractmethod
+    async def enrich(self, data: list[str]) -> list[str]:
+        """Enrich a list of strings."""
+
+
+class NullEnrichmentService(EnrichmentService):
+    """Null enrichment service."""
+
+    async def enrich(self, data: list[str]) -> list[str]:
+        """Enrich a list of strings."""
+        return [""] * len(data)
+
+
+class LLMEnrichmentService(EnrichmentService):
+    """Enrichment service using an LLM."""
+
+    def __init__(self, enrichment_provider: EnrichmentProvider) -> None:
+        """Initialize the enrichment service."""
+        self.enrichment_provider = enrichment_provider
+
+    async def enrich(self, data: list[str]) -> list[str]:
+        """Enrich a list of strings."""
+        return await self.enrichment_provider.enrich(data)
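
NullEnrichmentService gives callers a same-shaped no-op, useful in tests or when enrichment should be switched off. A sketch of swapping implementations behind a hypothetical toggle:

from kodit.enrichment.enrichment_provider.local_enrichment_provider import (
    LocalEnrichmentProvider,
)
from kodit.enrichment.enrichment_service import (
    EnrichmentService,
    LLMEnrichmentService,
    NullEnrichmentService,
)

def make_enrichment(enabled: bool) -> EnrichmentService:
    # Both branches satisfy the EnrichmentService interface.
    if enabled:
        return LLMEnrichmentService(LocalEnrichmentProvider())
    return NullEnrichmentService()
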