kodit 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kodit might be problematic. Click here for more details.

kodit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.1.15'
21
- __version_tuple__ = version_tuple = (0, 1, 15)
20
+ __version__ = version = '0.1.17'
21
+ __version_tuple__ = version_tuple = (0, 1, 17)
kodit/cli.py CHANGED
@@ -17,11 +17,10 @@ from kodit.config import (
17
17
  with_session,
18
18
  )
19
19
  from kodit.embedding.embedding_factory import embedding_factory
20
+ from kodit.enrichment.enrichment_factory import enrichment_factory
20
21
  from kodit.indexing.indexing_repository import IndexRepository
21
- from kodit.indexing.indexing_service import IndexService
22
+ from kodit.indexing.indexing_service import IndexService, SearchRequest
22
23
  from kodit.log import configure_logging, configure_telemetry, log_event
23
- from kodit.search.search_repository import SearchRepository
24
- from kodit.search.search_service import SearchRequest, SearchService
25
24
  from kodit.source.source_repository import SourceRepository
26
25
  from kodit.source.source_service import SourceService
27
26
 
@@ -72,9 +71,13 @@ async def index(
72
71
  repository=repository,
73
72
  source_service=source_service,
74
73
  keyword_search_provider=keyword_search_factory(app_context, session),
75
- vector_search_service=embedding_factory(
76
- app_context=app_context, session=session
74
+ code_search_service=embedding_factory(
75
+ task_name="code", app_context=app_context, session=session
77
76
  ),
77
+ text_search_service=embedding_factory(
78
+ task_name="text", app_context=app_context, session=session
79
+ ),
80
+ enrichment_service=enrichment_factory(app_context),
78
81
  )
79
82
 
80
83
  if not sources:
@@ -131,11 +134,20 @@ async def code(
131
134
 
132
135
  This works best if your query is code.
133
136
  """
134
- repository = SearchRepository(session)
135
- service = SearchService(
136
- repository,
137
+ source_repository = SourceRepository(session)
138
+ source_service = SourceService(app_context.get_clone_dir(), source_repository)
139
+ repository = IndexRepository(session)
140
+ service = IndexService(
141
+ repository=repository,
142
+ source_service=source_service,
137
143
  keyword_search_provider=keyword_search_factory(app_context, session),
138
- embedding_service=embedding_factory(app_context=app_context, session=session),
144
+ code_search_service=embedding_factory(
145
+ task_name="code", app_context=app_context, session=session
146
+ ),
147
+ text_search_service=embedding_factory(
148
+ task_name="text", app_context=app_context, session=session
149
+ ),
150
+ enrichment_service=enrichment_factory(app_context),
139
151
  )
140
152
 
141
153
  snippets = await service.search(SearchRequest(code_query=query, top_k=top_k))
@@ -147,6 +159,7 @@ async def code(
147
159
  for snippet in snippets:
148
160
  click.echo("-" * 80)
149
161
  click.echo(f"{snippet.uri}")
162
+ click.echo(f"Original scores: {snippet.original_scores}")
150
163
  click.echo(snippet.content)
151
164
  click.echo("-" * 80)
152
165
  click.echo()
@@ -164,11 +177,20 @@ async def keyword(
164
177
  top_k: int,
165
178
  ) -> None:
166
179
  """Search for snippets using keyword search."""
167
- repository = SearchRepository(session)
168
- service = SearchService(
169
- repository,
180
+ source_repository = SourceRepository(session)
181
+ source_service = SourceService(app_context.get_clone_dir(), source_repository)
182
+ repository = IndexRepository(session)
183
+ service = IndexService(
184
+ repository=repository,
185
+ source_service=source_service,
170
186
  keyword_search_provider=keyword_search_factory(app_context, session),
171
- embedding_service=embedding_factory(app_context=app_context, session=session),
187
+ code_search_service=embedding_factory(
188
+ task_name="code", app_context=app_context, session=session
189
+ ),
190
+ text_search_service=embedding_factory(
191
+ task_name="text", app_context=app_context, session=session
192
+ ),
193
+ enrichment_service=enrichment_factory(app_context),
172
194
  )
173
195
 
174
196
  snippets = await service.search(SearchRequest(keywords=keywords, top_k=top_k))
@@ -180,6 +202,53 @@ async def keyword(
180
202
  for snippet in snippets:
181
203
  click.echo("-" * 80)
182
204
  click.echo(f"{snippet.uri}")
205
+ click.echo(f"Original scores: {snippet.original_scores}")
206
+ click.echo(snippet.content)
207
+ click.echo("-" * 80)
208
+ click.echo()
209
+
210
+
211
+ @search.command()
212
+ @click.argument("query")
213
+ @click.option("--top-k", default=10, help="Number of snippets to retrieve")
214
+ @with_app_context
215
+ @with_session
216
+ async def text(
217
+ session: AsyncSession,
218
+ app_context: AppContext,
219
+ query: str,
220
+ top_k: int,
221
+ ) -> None:
222
+ """Search for snippets using semantic text search.
223
+
224
+ This works best if your query is text.
225
+ """
226
+ source_repository = SourceRepository(session)
227
+ source_service = SourceService(app_context.get_clone_dir(), source_repository)
228
+ repository = IndexRepository(session)
229
+ service = IndexService(
230
+ repository=repository,
231
+ source_service=source_service,
232
+ keyword_search_provider=keyword_search_factory(app_context, session),
233
+ code_search_service=embedding_factory(
234
+ task_name="code", app_context=app_context, session=session
235
+ ),
236
+ text_search_service=embedding_factory(
237
+ task_name="text", app_context=app_context, session=session
238
+ ),
239
+ enrichment_service=enrichment_factory(app_context),
240
+ )
241
+
242
+ snippets = await service.search(SearchRequest(text_query=query, top_k=top_k))
243
+
244
+ if len(snippets) == 0:
245
+ click.echo("No snippets found")
246
+ return
247
+
248
+ for snippet in snippets:
249
+ click.echo("-" * 80)
250
+ click.echo(f"{snippet.uri}")
251
+ click.echo(f"Original scores: {snippet.original_scores}")
183
252
  click.echo(snippet.content)
184
253
  click.echo("-" * 80)
185
254
  click.echo()
@@ -189,28 +258,44 @@ async def keyword(
189
258
  @click.option("--top-k", default=10, help="Number of snippets to retrieve")
190
259
  @click.option("--keywords", required=True, help="Comma separated list of keywords")
191
260
  @click.option("--code", required=True, help="Semantic code search query")
261
+ @click.option("--text", required=True, help="Semantic text search query")
192
262
  @with_app_context
193
263
  @with_session
194
- async def hybrid(
264
+ async def hybrid( # noqa: PLR0913
195
265
  session: AsyncSession,
196
266
  app_context: AppContext,
197
267
  top_k: int,
198
268
  keywords: str,
199
269
  code: str,
270
+ text: str,
200
271
  ) -> None:
201
272
  """Search for snippets using hybrid search."""
202
- repository = SearchRepository(session)
203
- service = SearchService(
204
- repository,
273
+ source_repository = SourceRepository(session)
274
+ source_service = SourceService(app_context.get_clone_dir(), source_repository)
275
+ repository = IndexRepository(session)
276
+ service = IndexService(
277
+ repository=repository,
278
+ source_service=source_service,
205
279
  keyword_search_provider=keyword_search_factory(app_context, session),
206
- embedding_service=embedding_factory(app_context=app_context, session=session),
280
+ code_search_service=embedding_factory(
281
+ task_name="code", app_context=app_context, session=session
282
+ ),
283
+ text_search_service=embedding_factory(
284
+ task_name="text", app_context=app_context, session=session
285
+ ),
286
+ enrichment_service=enrichment_factory(app_context),
207
287
  )
208
288
 
209
289
  # Parse keywords into a list of strings
210
290
  keywords_list = [k.strip().lower() for k in keywords.split(",")]
211
291
 
212
292
  snippets = await service.search(
213
- SearchRequest(keywords=keywords_list, code_query=code, top_k=top_k)
293
+ SearchRequest(
294
+ text_query=text,
295
+ keywords=keywords_list,
296
+ code_query=code,
297
+ top_k=top_k,
298
+ )
214
299
  )
215
300
 
216
301
  if len(snippets) == 0:
@@ -220,6 +305,7 @@ async def hybrid(
220
305
  for snippet in snippets:
221
306
  click.echo("-" * 80)
222
307
  click.echo(f"{snippet.uri}")
308
+ click.echo(f"Original scores: {snippet.original_scores}")
223
309
  click.echo(snippet.content)
224
310
  click.echo("-" * 80)
225
311
  click.echo()
@@ -21,7 +21,7 @@ from kodit.embedding.vectorchord_vector_search_service import (
21
21
 
22
22
 
23
23
  def embedding_factory(
24
- app_context: AppContext, session: AsyncSession
24
+ task_name: str, app_context: AppContext, session: AsyncSession
25
25
  ) -> VectorSearchService:
26
26
  """Create an embedding service."""
27
27
  embedding_repository = EmbeddingRepository(session=session)
@@ -33,7 +33,7 @@ def embedding_factory(
33
33
  embedding_provider = LocalEmbeddingProvider(CODE)
34
34
 
35
35
  if app_context.default_search.provider == "vectorchord":
36
- return VectorChordVectorSearchService(session, embedding_provider)
36
+ return VectorChordVectorSearchService(task_name, session, embedding_provider)
37
37
  if app_context.default_search.provider == "sqlite":
38
38
  return LocalVectorSearchService(
39
39
  embedding_repository=embedding_repository,
@@ -38,8 +38,15 @@ def split_sub_batches(encoding: tiktoken.Encoding, data: list[str]) -> list[list
38
38
  item_tokens = len(encoding.encode(next_item))
39
39
 
40
40
  if item_tokens > OPENAI_MAX_EMBEDDING_SIZE:
41
- log.warning("Skipping too long snippet", snippet=data_to_process.pop(0))
42
- continue
41
+ # Loop around trying to truncate the snippet until it fits in the max
42
+ # embedding size
43
+ while item_tokens > OPENAI_MAX_EMBEDDING_SIZE:
44
+ next_item = next_item[:-1]
45
+ item_tokens = len(encoding.encode(next_item))
46
+
47
+ data_to_process[0] = next_item
48
+
49
+ log.warning("Truncated snippet", snippet=next_item)
43
50
 
44
51
  if current_tokens + item_tokens > OPENAI_MAX_EMBEDDING_SIZE:
45
52
  break
@@ -38,26 +38,38 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
38
38
  # Process batches in parallel with a semaphore to limit concurrent requests
39
39
  sem = asyncio.Semaphore(OPENAI_NUM_PARALLEL_TASKS)
40
40
 
41
- async def process_batch(batch: list[str]) -> list[Vector]:
41
+ # Create a list of tuples with a temporary id for each batch
42
+ # We need to do this so that we can return the results in the same order as the
43
+ # input data
44
+ input_data = [(i, batch) for i, batch in enumerate(batched_data)]
45
+
46
+ async def process_batch(
47
+ data: tuple[int, list[str]],
48
+ ) -> tuple[int, list[Vector]]:
49
+ batch_id, batch = data
42
50
  async with sem:
43
51
  try:
44
52
  response = await self.openai_client.embeddings.create(
45
53
  model=self.model_name,
46
54
  input=batch,
47
55
  )
48
- return [
56
+ return batch_id, [
49
57
  [float(x) for x in embedding.embedding]
50
58
  for embedding in response.data
51
59
  ]
52
60
  except Exception as e:
53
61
  self.log.exception("Error embedding batch", error=str(e))
54
- return []
62
+ return batch_id, []
55
63
 
56
64
  # Create tasks for all batches
57
- tasks = [process_batch(batch) for batch in batched_data]
65
+ tasks = [process_batch(batch) for batch in input_data]
58
66
 
59
67
  # Process all batches and yield results as they complete
60
- results: list[Vector] = []
68
+ results: list[tuple[int, list[Vector]]] = []
61
69
  for task in asyncio.as_completed(tasks):
62
- results.extend(await task)
63
- return results
70
+ result = await task
71
+ results.append(result)
72
+
73
+ # Output in the same order as the input data
74
+ ordered_results = [result for _, result in sorted(results, key=lambda x: x[0])]
75
+ return [item for sublist in ordered_results for item in sublist]
@@ -12,23 +12,20 @@ from kodit.embedding.vector_search_service import (
12
12
  VectorSearchService,
13
13
  )
14
14
 
15
- TABLE_NAME = "vectorchord_embeddings"
16
- INDEX_NAME = f"{TABLE_NAME}_idx"
17
-
18
15
  # SQL Queries
19
16
  CREATE_VCHORD_EXTENSION = """
20
17
  CREATE EXTENSION IF NOT EXISTS vchord CASCADE;
21
18
  """
22
19
 
23
- CHECK_VCHORD_EMBEDDING_DIMENSION = f"""
20
+ CHECK_VCHORD_EMBEDDING_DIMENSION = """
24
21
  SELECT a.atttypmod as dimension
25
22
  FROM pg_attribute a
26
23
  JOIN pg_class c ON a.attrelid = c.oid
27
24
  WHERE c.relname = '{TABLE_NAME}'
28
25
  AND a.attname = 'embedding';
29
- """ # noqa: S608
26
+ """
30
27
 
31
- CREATE_VCHORD_INDEX = f"""
28
+ CREATE_VCHORD_INDEX = """
32
29
  CREATE INDEX IF NOT EXISTS {INDEX_NAME}
33
30
  ON {TABLE_NAME}
34
31
  USING vchordrq (embedding vector_l2_ops) WITH (options = $$
@@ -38,21 +35,21 @@ lists = []
38
35
  $$);
39
36
  """
40
37
 
41
- INSERT_QUERY = f"""
38
+ INSERT_QUERY = """
42
39
  INSERT INTO {TABLE_NAME} (snippet_id, embedding)
43
40
  VALUES (:snippet_id, :embedding)
44
41
  ON CONFLICT (snippet_id) DO UPDATE
45
42
  SET embedding = EXCLUDED.embedding
46
- """ # noqa: S608
43
+ """
47
44
 
48
45
  # Note that <=> in vectorchord is cosine distance
49
46
  # So scores go from 0 (similar) to 2 (opposite)
50
- SEARCH_QUERY = f"""
47
+ SEARCH_QUERY = """
51
48
  SELECT snippet_id, embedding <=> :query as score
52
49
  FROM {TABLE_NAME}
53
50
  ORDER BY score ASC
54
51
  LIMIT :top_k;
55
- """ # noqa: S608
52
+ """
56
53
 
57
54
 
58
55
  class VectorChordVectorSearchService(VectorSearchService):
@@ -60,6 +57,7 @@ class VectorChordVectorSearchService(VectorSearchService):
60
57
 
61
58
  def __init__(
62
59
  self,
60
+ task_name: str,
63
61
  session: AsyncSession,
64
62
  embedding_provider: EmbeddingProvider,
65
63
  ) -> None:
@@ -67,6 +65,8 @@ class VectorChordVectorSearchService(VectorSearchService):
67
65
  self.embedding_provider = embedding_provider
68
66
  self._session = session
69
67
  self._initialized = False
68
+ self.table_name = f"vectorchord_{task_name}_embeddings"
69
+ self.index_name = f"{self.table_name}_idx"
70
70
 
71
71
  async def _initialize(self) -> None:
72
72
  """Initialize the VectorChord environment."""
@@ -88,15 +88,23 @@ class VectorChordVectorSearchService(VectorSearchService):
88
88
  vector_dim = (await self.embedding_provider.embed(["dimension"]))[0]
89
89
  await self._session.execute(
90
90
  text(
91
- f"""CREATE TABLE IF NOT EXISTS {TABLE_NAME} (
91
+ f"""CREATE TABLE IF NOT EXISTS {self.table_name} (
92
92
  id SERIAL PRIMARY KEY,
93
93
  snippet_id INT NOT NULL UNIQUE,
94
94
  embedding VECTOR({len(vector_dim)}) NOT NULL
95
95
  );"""
96
96
  )
97
97
  )
98
- await self._session.execute(text(CREATE_VCHORD_INDEX))
99
- result = await self._session.execute(text(CHECK_VCHORD_EMBEDDING_DIMENSION))
98
+ await self._session.execute(
99
+ text(
100
+ CREATE_VCHORD_INDEX.format(
101
+ TABLE_NAME=self.table_name, INDEX_NAME=self.index_name
102
+ )
103
+ )
104
+ )
105
+ result = await self._session.execute(
106
+ text(CHECK_VCHORD_EMBEDDING_DIMENSION.format(TABLE_NAME=self.table_name))
107
+ )
100
108
  vector_dim_from_db = result.scalar_one()
101
109
  if vector_dim_from_db != len(vector_dim):
102
110
  msg = (
@@ -123,7 +131,7 @@ class VectorChordVectorSearchService(VectorSearchService):
123
131
  embeddings = await self.embedding_provider.embed([doc.text for doc in data])
124
132
  # Execute inserts
125
133
  await self._execute(
126
- text(INSERT_QUERY),
134
+ text(INSERT_QUERY.format(TABLE_NAME=self.table_name)),
127
135
  [
128
136
  {"snippet_id": doc.snippet_id, "embedding": str(embedding)}
129
137
  for doc, embedding in zip(data, embeddings, strict=True)
@@ -134,8 +142,11 @@ class VectorChordVectorSearchService(VectorSearchService):
134
142
  async def retrieve(self, query: str, top_k: int = 10) -> list[VectorSearchResponse]:
135
143
  """Query the embedding model."""
136
144
  embedding = await self.embedding_provider.embed([query])
145
+ if len(embedding) == 0 or len(embedding[0]) == 0:
146
+ return []
137
147
  result = await self._execute(
138
- text(SEARCH_QUERY), {"query": str(embedding[0]), "top_k": top_k}
148
+ text(SEARCH_QUERY.format(TABLE_NAME=self.table_name)),
149
+ {"query": str(embedding[0]), "top_k": top_k},
139
150
  )
140
151
  rows = result.mappings().all()
141
152
 
@@ -0,0 +1 @@
1
+ """Enrichment."""
@@ -0,0 +1,23 @@
1
+ """Enrichment factory."""
2
+
3
+ from kodit.config import AppContext
4
+ from kodit.enrichment.enrichment_provider.local_enrichment_provider import (
5
+ LocalEnrichmentProvider,
6
+ )
7
+ from kodit.enrichment.enrichment_provider.openai_enrichment_provider import (
8
+ OpenAIEnrichmentProvider,
9
+ )
10
+ from kodit.enrichment.enrichment_service import (
11
+ EnrichmentService,
12
+ LLMEnrichmentService,
13
+ )
14
+
15
+
16
+ def enrichment_factory(app_context: AppContext) -> EnrichmentService:
17
+ """Create an enrichment service."""
18
+ openai_client = app_context.get_default_openai_client()
19
+ if openai_client is not None:
20
+ enrichment_provider = OpenAIEnrichmentProvider(openai_client=openai_client)
21
+ return LLMEnrichmentService(enrichment_provider)
22
+
23
+ return LLMEnrichmentService(LocalEnrichmentProvider())
@@ -0,0 +1 @@
1
+ """Enrichment provider."""
@@ -0,0 +1,16 @@
1
+ """Enrichment provider."""
2
+
3
+ from abc import ABC, abstractmethod
4
+
5
+ ENRICHMENT_SYSTEM_PROMPT = """
6
+ You are a professional software developer. You will be given a snippet of code.
7
+ Please provide a concise explanation of the code.
8
+ """
9
+
10
+
11
+ class EnrichmentProvider(ABC):
12
+ """Enrichment provider."""
13
+
14
+ @abstractmethod
15
+ async def enrich(self, data: list[str]) -> list[str]:
16
+ """Enrich a list of strings."""
@@ -0,0 +1,63 @@
1
+ """Local enrichment provider."""
2
+
3
+ import os
4
+
5
+ import structlog
6
+ from transformers.models.auto.modeling_auto import AutoModelForCausalLM
7
+ from transformers.models.auto.tokenization_auto import AutoTokenizer
8
+
9
+ from kodit.enrichment.enrichment_provider.enrichment_provider import (
10
+ ENRICHMENT_SYSTEM_PROMPT,
11
+ EnrichmentProvider,
12
+ )
13
+
14
+
15
+ class LocalEnrichmentProvider(EnrichmentProvider):
16
+ """Local enrichment provider."""
17
+
18
+ def __init__(self, model_name: str = "Qwen/Qwen3-0.6B") -> None:
19
+ """Initialize the local enrichment provider."""
20
+ self.log = structlog.get_logger(__name__)
21
+ self.model_name = model_name
22
+ self.model = None
23
+ self.tokenizer = None
24
+
25
+ async def enrich(self, data: list[str]) -> list[str]:
26
+ """Enrich a list of strings."""
27
+ if self.tokenizer is None:
28
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
29
+ if self.model is None:
30
+ os.environ["TOKENIZERS_PARALLELISM"] = "false" # Avoid warnings
31
+ self.model = AutoModelForCausalLM.from_pretrained(
32
+ self.model_name,
33
+ torch_dtype="auto",
34
+ trust_remote_code=True,
35
+ )
36
+
37
+ results = []
38
+ for snippet in data:
39
+ # prepare the model input
40
+ messages = [
41
+ {"role": "system", "content": ENRICHMENT_SYSTEM_PROMPT},
42
+ {"role": "user", "content": snippet},
43
+ ]
44
+ text = self.tokenizer.apply_chat_template(
45
+ messages,
46
+ tokenize=False,
47
+ add_generation_prompt=True,
48
+ enable_thinking=False,
49
+ )
50
+ model_inputs = self.tokenizer([text], return_tensors="pt").to(
51
+ self.model.device
52
+ )
53
+
54
+ # conduct text completion
55
+ generated_ids = self.model.generate(**model_inputs, max_new_tokens=32768)
56
+ output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
57
+ content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip(
58
+ "\n"
59
+ )
60
+
61
+ results.append(content)
62
+
63
+ return results
@@ -0,0 +1,77 @@
1
+ """OpenAI enrichment provider."""
2
+
3
+ import asyncio
4
+
5
+ import structlog
6
+ import tiktoken
7
+ from openai import AsyncOpenAI
8
+ from tqdm import tqdm
9
+
10
+ from kodit.enrichment.enrichment_provider.enrichment_provider import (
11
+ ENRICHMENT_SYSTEM_PROMPT,
12
+ EnrichmentProvider,
13
+ )
14
+
15
+ OPENAI_NUM_PARALLEL_TASKS = 10
16
+
17
+
18
+ class OpenAIEnrichmentProvider(EnrichmentProvider):
19
+ """OpenAI enrichment provider."""
20
+
21
+ def __init__(
22
+ self,
23
+ openai_client: AsyncOpenAI,
24
+ model_name: str = "gpt-4o-mini",
25
+ ) -> None:
26
+ """Initialize the OpenAI enrichment provider."""
27
+ self.log = structlog.get_logger(__name__)
28
+ self.openai_client = openai_client
29
+ self.model_name = model_name
30
+ self.encoding = tiktoken.encoding_for_model(model_name)
31
+
32
+ async def enrich(self, data: list[str]) -> list[str]:
33
+ """Enrich a list of documents."""
34
+ # Process snippets in parallel with a semaphore to limit concurrent requests
35
+ sem = asyncio.Semaphore(OPENAI_NUM_PARALLEL_TASKS)
36
+
37
+ # Create a list of tuples with a temporary id for each snippet
38
+ # We need to do this so that we can return the results in the same order as the
39
+ # input data
40
+ input_data = [(i, snippet) for i, snippet in enumerate(data)]
41
+
42
+ async def process_data(data: tuple[int, str]) -> tuple[int, str]:
43
+ snippet_id, snippet = data
44
+ if not snippet:
45
+ return snippet_id, ""
46
+ async with sem:
47
+ try:
48
+ response = await self.openai_client.chat.completions.create(
49
+ model=self.model_name,
50
+ messages=[
51
+ {
52
+ "role": "system",
53
+ "content": ENRICHMENT_SYSTEM_PROMPT,
54
+ },
55
+ {"role": "user", "content": snippet},
56
+ ],
57
+ )
58
+ return snippet_id, response.choices[0].message.content or ""
59
+ except Exception as e:
60
+ self.log.exception("Error enriching data", error=str(e))
61
+ return snippet_id, ""
62
+
63
+ # Create tasks for all data
64
+ tasks = [process_data(snippet) for snippet in input_data]
65
+
66
+ # Process all data and collect results as they complete
67
+ results: list[tuple[int, str]] = []
68
+ for task in tqdm(
69
+ asyncio.as_completed(tasks),
70
+ total=len(tasks),
71
+ leave=False,
72
+ ):
73
+ result = await task
74
+ results.append(result)
75
+
76
+ # Output in the same order as the input data
77
+ return [result for _, result in sorted(results, key=lambda x: x[0])]
@@ -0,0 +1,33 @@
1
+ """Enrichment service."""
2
+
3
+ from abc import ABC, abstractmethod
4
+
5
+ from kodit.enrichment.enrichment_provider.enrichment_provider import EnrichmentProvider
6
+
7
+
8
+ class EnrichmentService(ABC):
9
+ """Enrichment service."""
10
+
11
+ @abstractmethod
12
+ async def enrich(self, data: list[str]) -> list[str]:
13
+ """Enrich a list of strings."""
14
+
15
+
16
+ class NullEnrichmentService(EnrichmentService):
17
+ """Null enrichment service."""
18
+
19
+ async def enrich(self, data: list[str]) -> list[str]:
20
+ """Enrich a list of strings."""
21
+ return [""] * len(data)
22
+
23
+
24
+ class LLMEnrichmentService(EnrichmentService):
25
+ """Enrichment service using an LLM."""
26
+
27
+ def __init__(self, enrichment_provider: EnrichmentProvider) -> None:
28
+ """Initialize the enrichment service."""
29
+ self.enrichment_provider = enrichment_provider
30
+
31
+ async def enrich(self, data: list[str]) -> list[str]:
32
+ """Enrich a list of strings."""
33
+ return await self.enrichment_provider.enrich(data)