kodit-0.4.1-py3-none-any.whl → kodit-0.4.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (42)
  1. kodit/_version.py +2 -2
  2. kodit/app.py +4 -2
  3. kodit/application/factories/code_indexing_factory.py +54 -7
  4. kodit/application/factories/reporting_factory.py +27 -0
  5. kodit/application/services/auto_indexing_service.py +16 -4
  6. kodit/application/services/code_indexing_application_service.py +115 -133
  7. kodit/application/services/indexing_worker_service.py +18 -20
  8. kodit/application/services/queue_service.py +12 -14
  9. kodit/application/services/reporting.py +86 -0
  10. kodit/application/services/sync_scheduler.py +21 -20
  11. kodit/cli.py +14 -18
  12. kodit/config.py +24 -1
  13. kodit/database.py +2 -1
  14. kodit/domain/protocols.py +9 -1
  15. kodit/domain/services/bm25_service.py +1 -6
  16. kodit/domain/services/index_service.py +22 -58
  17. kodit/domain/value_objects.py +57 -9
  18. kodit/infrastructure/api/v1/dependencies.py +23 -10
  19. kodit/infrastructure/cloning/git/working_copy.py +36 -7
  20. kodit/infrastructure/embedding/embedding_factory.py +8 -3
  21. kodit/infrastructure/embedding/embedding_providers/litellm_embedding_provider.py +48 -55
  22. kodit/infrastructure/git/git_utils.py +3 -2
  23. kodit/infrastructure/mappers/index_mapper.py +1 -0
  24. kodit/infrastructure/reporting/__init__.py +1 -0
  25. kodit/infrastructure/reporting/log_progress.py +65 -0
  26. kodit/infrastructure/reporting/tdqm_progress.py +73 -0
  27. kodit/infrastructure/sqlalchemy/embedding_repository.py +47 -68
  28. kodit/infrastructure/sqlalchemy/entities.py +28 -2
  29. kodit/infrastructure/sqlalchemy/index_repository.py +274 -236
  30. kodit/infrastructure/sqlalchemy/task_repository.py +55 -39
  31. kodit/infrastructure/sqlalchemy/unit_of_work.py +59 -0
  32. kodit/mcp.py +10 -2
  33. {kodit-0.4.1.dist-info → kodit-0.4.2.dist-info}/METADATA +1 -1
  34. {kodit-0.4.1.dist-info → kodit-0.4.2.dist-info}/RECORD +37 -36
  35. kodit/domain/interfaces.py +0 -27
  36. kodit/infrastructure/ui/__init__.py +0 -1
  37. kodit/infrastructure/ui/progress.py +0 -170
  38. kodit/infrastructure/ui/spinner.py +0 -74
  39. kodit/reporting.py +0 -78
  40. {kodit-0.4.1.dist-info → kodit-0.4.2.dist-info}/WHEEL +0 -0
  41. {kodit-0.4.1.dist-info → kodit-0.4.2.dist-info}/entry_points.txt +0 -0
  42. {kodit-0.4.1.dist-info → kodit-0.4.2.dist-info}/licenses/LICENSE +0 -0
kodit/domain/services/index_service.py

@@ -8,7 +8,8 @@ import structlog
 from pydantic import AnyUrl
 
 import kodit.domain.entities as domain_entities
-from kodit.domain.interfaces import ProgressCallback
+from kodit.application.factories.reporting_factory import create_noop_operation
+from kodit.application.services.reporting import ProgressTracker
 from kodit.domain.services.enrichment_service import EnrichmentDomainService
 from kodit.domain.value_objects import (
     EnrichmentIndexRequest,
@@ -21,7 +22,6 @@ from kodit.infrastructure.cloning.metadata import FileMetadataExtractor
 from kodit.infrastructure.git.git_utils import is_valid_clone_target
 from kodit.infrastructure.ignore.ignore_pattern_provider import GitIgnorePatternProvider
 from kodit.infrastructure.slicing.slicer import Slicer
-from kodit.reporting import Reporter
 from kodit.utils.path_utils import path_from_uri
 
 
@@ -58,27 +58,23 @@ class IndexDomainService:
     async def prepare_index(
         self,
         uri_or_path_like: str,  # Must include user/pass, etc
-        progress_callback: ProgressCallback | None = None,
+        step: ProgressTracker | None = None,
     ) -> domain_entities.WorkingCopy:
         """Prepare an index by scanning files and creating working copy."""
+        step = step or create_noop_operation()
+        self.log.info("Preparing index")
         sanitized_uri, source_type = self.sanitize_uri(uri_or_path_like)
-        reporter = Reporter(self.log, progress_callback)
         self.log.info("Preparing source", uri=str(sanitized_uri))
 
         if source_type == domain_entities.SourceType.FOLDER:
-            await reporter.start("prepare_index", 1, "Scanning source...")
             local_path = path_from_uri(str(sanitized_uri))
         elif source_type == domain_entities.SourceType.GIT:
             source_type = domain_entities.SourceType.GIT
             git_working_copy_provider = GitWorkingCopyProvider(self._clone_dir)
-            await reporter.start("prepare_index", 1, "Cloning source...")
-            local_path = await git_working_copy_provider.prepare(uri_or_path_like)
-            await reporter.done("prepare_index")
+            local_path = await git_working_copy_provider.prepare(uri_or_path_like, step)
         else:
             raise ValueError(f"Unsupported source: {uri_or_path_like}")
 
-        await reporter.done("prepare_index")
-
         return domain_entities.WorkingCopy(
             remote_uri=sanitized_uri,
             cloned_path=local_path,
@@ -89,9 +85,10 @@ class IndexDomainService:
     async def extract_snippets_from_index(
         self,
         index: domain_entities.Index,
-        progress_callback: ProgressCallback | None = None,
+        step: ProgressTracker | None = None,
     ) -> domain_entities.Index:
         """Extract code snippets from files in the index."""
+        step = step or create_noop_operation()
         file_count = len(index.source.working_copy.files)
 
         self.log.info(
@@ -127,40 +124,28 @@ class IndexDomainService:
             languages=lang_files_map.keys(),
         )
 
-        reporter = Reporter(self.log, progress_callback)
-        await reporter.start(
-            "extract_snippets",
-            len(lang_files_map.keys()),
-            "Extracting code snippets...",
-        )
-
         # Calculate snippets for each language
         slicer = Slicer()
+        step.set_total(len(lang_files_map.keys()))
         for i, (lang, lang_files) in enumerate(lang_files_map.items()):
-            await reporter.step(
-                "extract_snippets",
-                i,
-                len(lang_files_map.keys()),
-                f"Extracting code snippets for {lang}...",
-            )
+            step.set_current(i)
             s = slicer.extract_snippets(lang_files, language=lang)
             index.snippets.extend(s)
 
-        await reporter.done("extract_snippets")
         return index
 
     async def enrich_snippets_in_index(
         self,
         snippets: list[domain_entities.Snippet],
-        progress_callback: ProgressCallback | None = None,
+        reporting_step: ProgressTracker | None = None,
     ) -> list[domain_entities.Snippet]:
         """Enrich snippets with AI-generated summaries."""
+        reporting_step = reporting_step or create_noop_operation()
         if not snippets or len(snippets) == 0:
+            reporting_step.skip("No snippets to enrich")
             return snippets
 
-        reporter = Reporter(self.log, progress_callback)
-        await reporter.start("enrichment", len(snippets), "Enriching snippets...")
-
+        reporting_step.set_total(len(snippets))
        snippet_map = {snippet.id: snippet for snippet in snippets if snippet.id}
 
        enrichment_request = EnrichmentIndexRequest(
@@ -177,11 +162,8 @@ class IndexDomainService:
             snippet_map[result.snippet_id].add_summary(result.text)
 
             processed += 1
-            await reporter.step(
-                "enrichment", processed, len(snippets), "Enriching snippets..."
-            )
+            reporting_step.set_current(processed)
 
-        await reporter.done("enrichment")
         return list(snippet_map.values())
 
     def sanitize_uri(
@@ -207,15 +189,14 @@ class IndexDomainService:
     async def refresh_working_copy(
         self,
         working_copy: domain_entities.WorkingCopy,
-        progress_callback: ProgressCallback | None = None,
+        step: ProgressTracker | None = None,
     ) -> domain_entities.WorkingCopy:
         """Refresh the working copy."""
+        step = step or create_noop_operation()
         metadata_extractor = FileMetadataExtractor(working_copy.source_type)
-        reporter = Reporter(self.log, progress_callback)
-
         if working_copy.source_type == domain_entities.SourceType.GIT:
             git_working_copy_provider = GitWorkingCopyProvider(self._clone_dir)
-            await git_working_copy_provider.sync(str(working_copy.remote_uri))
+            await git_working_copy_provider.sync(str(working_copy.remote_uri), step)
 
         current_file_paths = working_copy.list_filesystem_paths(
             GitIgnorePatternProvider(working_copy.cloned_path)
@@ -241,19 +222,12 @@ class IndexDomainService:
 
         # Setup reporter
         processed = 0
-        await reporter.start(
-            "refresh_working_copy", num_files_to_process, "Refreshing working copy..."
-        )
+        step.set_total(num_files_to_process)
 
         # First check to see if any files have been deleted
         for file_path in deleted_file_paths:
             processed += 1
-            await reporter.step(
-                "refresh_working_copy",
-                processed,
-                num_files_to_process,
-                f"Deleted {file_path.name}",
-            )
+            step.set_current(processed)
             previous_files_map[
                 file_path
             ].file_processing_status = domain_entities.FileProcessingStatus.DELETED
@@ -261,12 +235,7 @@ class IndexDomainService:
         # Then check to see if there are any new files
         for file_path in new_file_paths:
             processed += 1
-            await reporter.step(
-                "refresh_working_copy",
-                processed,
-                num_files_to_process,
-                f"New {file_path.name}",
-            )
+            step.set_current(processed)
             try:
                 working_copy.files.append(
                     await metadata_extractor.extract(file_path=file_path)
@@ -278,12 +247,7 @@ class IndexDomainService:
         # Finally check if there are any modified files
         for file_path in modified_file_paths:
             processed += 1
-            await reporter.step(
-                "refresh_working_copy",
-                processed,
-                num_files_to_process,
-                f"Modified {file_path.name}",
-            )
+            step.set_current(processed)
             try:
                 previous_file = previous_files_map[file_path]
                 new_file = await metadata_extractor.extract(file_path=file_path)
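
Note: every public method above now normalizes a missing tracker via `step = step or create_noop_operation()`, a null-object pattern that removes per-call `if step is not None` checks. The factory body itself is not shown in this diff; a minimal sketch of what such a no-op tracker could look like, assuming only the `set_total`/`set_current`/`skip` surface used above:

```python
class NoopProgressTracker:
    """Null-object tracker: accepts every progress call and does nothing."""

    def set_total(self, total: int) -> None:
        """Ignore the announced total."""

    def set_current(self, current: int) -> None:
        """Ignore the current position."""

    def skip(self, reason: str = "") -> None:
        """Ignore the skip notification."""


def create_noop_operation() -> NoopProgressTracker:
    """Return a tracker that callers can use unconditionally."""
    return NoopProgressTracker()
```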
kodit/domain/value_objects.py

@@ -1,9 +1,9 @@
 """Pure domain value objects and DTOs."""
 
 import json
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from datetime import datetime
-from enum import Enum, IntEnum
+from enum import Enum, IntEnum, StrEnum
 from pathlib import Path
 from typing import ClassVar
 
@@ -390,18 +390,18 @@ class IndexRunRequest:
 
 
 @dataclass
-class ProgressEvent:
-    """Domain model for progress events."""
+class ProgressState:
+    """Progress state."""
 
-    operation: str
-    current: int
-    total: int
-    message: str | None = None
+    current: int = 0
+    total: int = 0
+    operation: str = ""
+    message: str = ""
 
     @property
     def percentage(self) -> float:
         """Calculate the percentage of completion."""
-        return (self.current / self.total * 100) if self.total > 0 else 0.0
+        return (self.current / self.total) * 100 if self.total > 0 else 0.0
 
 
 @dataclass
@@ -662,3 +662,51 @@ class QueuePriority(IntEnum):
 
     BACKGROUND = 10
     USER_INITIATED = 50
+
+
+# Reporting value objects
+
+
+class ReportingState(StrEnum):
+    """Reporting state."""
+
+    STARTED = "started"
+    IN_PROGRESS = "in_progress"
+    COMPLETED = "completed"
+    FAILED = "failed"
+    SKIPPED = "skipped"
+
+
+@dataclass(frozen=True)
+class Progress:
+    """Immutable representation of a step's state."""
+
+    name: str
+    state: ReportingState
+    message: str = ""
+    error: BaseException | None = None
+    total: int = 0
+    current: int = 0
+
+    @property
+    def completion_percent(self) -> float:
+        """Calculate the percentage of completion."""
+        if self.total == 0:
+            return 0.0
+        return min(100.0, max(0.0, (self.current / self.total) * 100.0))
+
+    def with_error(self, error: BaseException) -> "Progress":
+        """Return a new snapshot with updated error."""
+        return replace(self, error=error)
+
+    def with_total(self, total: int) -> "Progress":
+        """Return a new snapshot with updated total."""
+        return replace(self, total=total)
+
+    def with_progress(self, current: int) -> "Progress":
+        """Return a new snapshot with updated progress."""
+        return replace(self, current=current)
+
+    def with_state(self, state: ReportingState, message: str = "") -> "Progress":
+        """Return a new snapshot with updated state."""
+        return replace(self, state=state, message=message)
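
Note: `Progress` is a frozen dataclass, so each `with_*` helper returns a new snapshot via `dataclasses.replace` rather than mutating in place. A minimal usage sketch, assuming only the definitions added above:

```python
from kodit.domain.value_objects import Progress, ReportingState

# Each transition yields a fresh immutable snapshot.
progress = Progress(name="extract_snippets", state=ReportingState.STARTED)
progress = progress.with_total(4).with_state(ReportingState.IN_PROGRESS)
progress = progress.with_progress(2)

assert progress.completion_percent == 50.0
assert progress.state is ReportingState.IN_PROGRESS
```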
kodit/infrastructure/api/v1/dependencies.py

@@ -1,13 +1,13 @@
 """FastAPI dependencies for the REST API."""
 
-from collections.abc import AsyncGenerator
+from collections.abc import AsyncGenerator, Callable
 from typing import Annotated, cast
 
 from fastapi import Depends, Request
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from kodit.application.factories.code_indexing_factory import (
-    create_code_indexing_application_service,
+    create_server_code_indexing_application_service,
 )
 from kodit.application.services.code_indexing_application_service import (
     CodeIndexingApplicationService,
@@ -16,7 +16,7 @@ from kodit.application.services.queue_service import QueueService
 from kodit.config import AppContext
 from kodit.domain.services.index_query_service import IndexQueryService
 from kodit.infrastructure.indexing.fusion_service import ReciprocalRankFusionService
-from kodit.infrastructure.sqlalchemy.index_repository import SqlAlchemyIndexRepository
+from kodit.infrastructure.sqlalchemy.index_repository import create_index_repository
 
 
 def get_app_context(request: Request) -> AppContext:
@@ -42,12 +42,25 @@ async def get_db_session(
 DBSessionDep = Annotated[AsyncSession, Depends(get_db_session)]
 
 
+async def get_db_session_factory(
+    app_context: AppContextDep,
+) -> AsyncGenerator[Callable[[], AsyncSession], None]:
+    """Get database session dependency."""
+    db = await app_context.get_db()
+    yield db.session_factory
+
+
+DBSessionFactoryDep = Annotated[
+    Callable[[], AsyncSession], Depends(get_db_session_factory)
+]
+
+
 async def get_index_query_service(
-    session: DBSessionDep,
+    session_factory: DBSessionFactoryDep,
 ) -> IndexQueryService:
     """Get index query service dependency."""
     return IndexQueryService(
-        index_repository=SqlAlchemyIndexRepository(session=session),
+        index_repository=create_index_repository(session_factory=session_factory),
         fusion_service=ReciprocalRankFusionService(),
     )
 
@@ -58,11 +71,11 @@ IndexQueryServiceDep = Annotated[IndexQueryService, Depends(get_index_query_serv
 async def get_indexing_app_service(
     app_context: AppContextDep,
     session: DBSessionDep,
+    session_factory: DBSessionFactoryDep,
 ) -> CodeIndexingApplicationService:
     """Get indexing application service dependency."""
-    return create_code_indexing_application_service(
-        app_context=app_context,
-        session=session,
+    return create_server_code_indexing_application_service(
+        app_context, session, session_factory
     )
 
 
@@ -72,11 +85,11 @@ IndexingAppServiceDep = Annotated[
 ]
 
 async def get_queue_service(
-    session: DBSessionDep,
+    session_factory: DBSessionFactoryDep,
 ) -> QueueService:
     """Get queue service dependency."""
     return QueueService(
-        session=session,
+        session_factory=session_factory,
     )
 
 
kodit/infrastructure/cloning/git/working_copy.py

@@ -7,6 +7,8 @@ from pathlib import Path
 import git
 import structlog
 
+from kodit.application.factories.reporting_factory import create_noop_operation
+from kodit.application.services.reporting import ProgressTracker
 from kodit.domain.entities import WorkingCopy
 
 
@@ -25,18 +27,42 @@ class GitWorkingCopyProvider:
         dir_name = f"repo-{dir_hash}"
         return self.clone_dir / dir_name
 
-    async def prepare(self, uri: str) -> Path:
+    async def prepare(
+        self,
+        uri: str,
+        step: ProgressTracker | None = None,
+    ) -> Path:
         """Prepare a Git working copy."""
+        step = step or create_noop_operation()
         sanitized_uri = WorkingCopy.sanitize_git_url(uri)
         clone_path = self.get_clone_path(uri)
         clone_path.mkdir(parents=True, exist_ok=True)
 
+        step_record = []
+        step.set_total(12)
+
+        def _clone_progress_callback(
+            a: int, _: str | float | None, __: str | float | None, _d: str
+        ) -> None:
+            if a not in step_record:
+                step_record.append(a)
+
+            # Git reports a really weird format. This is a quick hack to get some
+            # progress.
+            step.set_current(len(step_record))
+
         try:
             self.log.info(
                 "Cloning repository", uri=sanitized_uri, clone_path=str(clone_path)
             )
             # Use the original URI for cloning (with credentials if present)
-            git.Repo.clone_from(uri, clone_path)
+            options = ["--depth=1", "--single-branch"]
+            git.Repo.clone_from(
+                uri,
+                clone_path,
+                progress=_clone_progress_callback,
+                multi_options=options,
+            )
         except git.GitCommandError as e:
             if "already exists and is not an empty directory" not in str(e):
                 msg = f"Failed to clone repository: {e}"
@@ -45,8 +71,9 @@ class GitWorkingCopyProvider:
 
         return clone_path
 
-    async def sync(self, uri: str) -> Path:
+    async def sync(self, uri: str, step: ProgressTracker | None = None) -> Path:
         """Refresh a Git working copy."""
+        step = step or create_noop_operation()
         clone_path = self.get_clone_path(uri)
 
         # Check if the clone directory exists and is a valid Git repository
@@ -54,9 +81,10 @@ class GitWorkingCopyProvider:
             self.log.info(
                 "Clone directory does not exist or is not a Git repository, "
                 "preparing...",
-                uri=uri, clone_path=str(clone_path)
+                uri=uri,
+                clone_path=str(clone_path),
             )
-            return await self.prepare(uri)
+            return await self.prepare(uri, step)
 
         try:
             repo = git.Repo(clone_path)
@@ -64,10 +92,11 @@ class GitWorkingCopyProvider:
         except git.InvalidGitRepositoryError:
             self.log.warning(
                 "Invalid Git repository found, re-cloning...",
-                uri=uri, clone_path=str(clone_path)
+                uri=uri,
+                clone_path=str(clone_path),
             )
             # Remove the invalid directory and re-clone
             shutil.rmtree(clone_path)
-            return await self.prepare(uri)
+            return await self.prepare(uri, step)
 
         return clone_path
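
Note: GitPython invokes the `progress` callable passed to `clone_from` as `(op_code, cur_count, max_count, message)`; the `_clone_progress_callback` above counts distinct op codes against a fixed total of 12 because the per-phase counts are awkward to aggregate, as the inline comment admits. For reference, the more conventional approach is a `git.RemoteProgress` subclass; a sketch for illustration only, not what this release ships:

```python
import git


class CloneProgress(git.RemoteProgress):
    """Report per-phase clone progress using GitPython's parsed values."""

    def update(
        self,
        op_code: int,
        cur_count: str | float,
        max_count: str | float | None = None,
        message: str = "",
    ) -> None:
        # max_count is None for phases git does not size up front.
        if max_count:
            percent = 100.0 * float(cur_count) / float(max_count)
            print(f"phase {op_code}: {percent:.1f}% {message}")


git.Repo.clone_from(
    "https://github.com/gitpython-developers/GitPython",  # example repo
    "/tmp/gitpython-clone",
    progress=CloneProgress(),
    multi_options=["--depth=1", "--single-branch"],
)
```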
kodit/infrastructure/embedding/embedding_factory.py

@@ -1,5 +1,7 @@
 """Factory for creating embedding services with DDD architecture."""
 
+from collections.abc import Callable
+
 import structlog
 from sqlalchemy.ext.asyncio import AsyncSession
 
@@ -24,7 +26,7 @@ from kodit.infrastructure.embedding.vectorchord_vector_search_repository import
     VectorChordVectorSearchRepository,
 )
 from kodit.infrastructure.sqlalchemy.embedding_repository import (
-    SqlAlchemyEmbeddingRepository,
+    create_embedding_repository,
 )
 from kodit.infrastructure.sqlalchemy.entities import EmbeddingType
 from kodit.log import log_event
@@ -36,12 +38,15 @@ def _get_endpoint_configuration(app_context: AppContext) -> Endpoint | None:
 
 
 def embedding_domain_service_factory(
-    task_name: TaskName, app_context: AppContext, session: AsyncSession
+    task_name: TaskName,
+    app_context: AppContext,
+    session: AsyncSession,
+    session_factory: Callable[[], AsyncSession],
 ) -> EmbeddingDomainService:
     """Create an embedding domain service."""
     structlog.get_logger(__name__)
     # Create embedding repository
-    embedding_repository = SqlAlchemyEmbeddingRepository(session=session)
+    embedding_repository = create_embedding_repository(session_factory=session_factory)
 
     # Create embedding provider
     embedding_provider: EmbeddingProvider | None = None
kodit/infrastructure/embedding/embedding_providers/litellm_embedding_provider.py

@@ -7,16 +7,15 @@ from typing import Any
 import httpx
 import litellm
 import structlog
+import tiktoken
 from litellm import aembedding
 
 from kodit.config import Endpoint
 from kodit.domain.services.embedding_service import EmbeddingProvider
 from kodit.domain.value_objects import EmbeddingRequest, EmbeddingResponse
-
-# Constants
-MAX_TOKENS = 8192  # Conservative token limit for the embedding model
-BATCH_SIZE = 10  # Maximum number of items per API call
-DEFAULT_NUM_PARALLEL_TASKS = 10  # Semaphore limit for concurrent requests
+from kodit.infrastructure.embedding.embedding_providers.batching import (
+    split_sub_batches,
+)
 
 
 class LiteLLMEmbeddingProvider(EmbeddingProvider):
@@ -32,46 +31,36 @@ class LiteLLMEmbeddingProvider(EmbeddingProvider):
             endpoint: The endpoint configuration containing all settings.
 
         """
-        self.model_name = endpoint.model or "text-embedding-3-small"
-        self.api_key = endpoint.api_key
-        self.base_url = endpoint.base_url
-        self.socket_path = endpoint.socket_path
-        self.num_parallel_tasks = (
-            endpoint.num_parallel_tasks or DEFAULT_NUM_PARALLEL_TASKS
-        )
-        self.timeout = endpoint.timeout or 30.0
-        self.extra_params = endpoint.extra_params or {}
+        self.endpoint = endpoint
         self.log = structlog.get_logger(__name__)
+        self._encoding: tiktoken.Encoding | None = None
 
         # Configure LiteLLM with custom HTTPX client for Unix socket support if needed
         self._setup_litellm_client()
 
     def _setup_litellm_client(self) -> None:
         """Set up LiteLLM with custom HTTPX client for Unix socket support."""
-        if self.socket_path:
+        if self.endpoint.socket_path:
             # Create HTTPX client with Unix socket transport
-            transport = httpx.AsyncHTTPTransport(uds=self.socket_path)
+            transport = httpx.AsyncHTTPTransport(uds=self.endpoint.socket_path)
             unix_client = httpx.AsyncClient(
                 transport=transport,
                 base_url="http://localhost",  # Base URL for Unix socket
-                timeout=self.timeout,
+                timeout=self.endpoint.timeout,
             )
             # Set as LiteLLM's async client session
             litellm.aclient_session = unix_client
 
     def _split_sub_batches(
-        self, data: list[EmbeddingRequest]
+        self, encoding: tiktoken.Encoding, data: list[EmbeddingRequest]
     ) -> list[list[EmbeddingRequest]]:
-        """Split data into manageable batches.
-
-        For LiteLLM, we use a simpler batching approach since token counting
-        varies by provider. We use a conservative batch size approach.
-        """
-        batches = []
-        for i in range(0, len(data), BATCH_SIZE):
-            batch = data[i : i + BATCH_SIZE]
-            batches.append(batch)
-        return batches
+        """Proxy to the shared batching utility (kept for backward-compat)."""
+        return split_sub_batches(
+            encoding,
+            data,
+            max_tokens=self.endpoint.max_tokens,
+            batch_size=self.endpoint.num_parallel_tasks,
+        )
 
     async def _call_embeddings_api(self, texts: list[str]) -> Any:
         """Call the embeddings API using LiteLLM.
@@ -84,21 +73,21 @@ class LiteLLMEmbeddingProvider(EmbeddingProvider):
 
         """
         kwargs = {
-            "model": self.model_name,
+            "model": self.endpoint.model,
             "input": texts,
-            "timeout": self.timeout,
+            "timeout": self.endpoint.timeout,
         }
 
         # Add API key if provided
-        if self.api_key:
-            kwargs["api_key"] = self.api_key
+        if self.endpoint.api_key:
+            kwargs["api_key"] = self.endpoint.api_key
 
         # Add base_url if provided
-        if self.base_url:
-            kwargs["api_base"] = self.base_url
+        if self.endpoint.base_url:
+            kwargs["api_base"] = self.endpoint.base_url
 
         # Add extra parameters
-        kwargs.update(self.extra_params)
+        kwargs.update(self.endpoint.extra_params or {})
 
         try:
             # Use litellm's async embedding function
@@ -108,7 +97,7 @@ class LiteLLMEmbeddingProvider(EmbeddingProvider):
             )
         except Exception as e:
             self.log.exception(
-                "LiteLLM embedding API error", error=str(e), model=self.model_name
+                "LiteLLM embedding API error", error=str(e), model=self.endpoint.model
             )
             raise
 
@@ -121,32 +110,28 @@ class LiteLLMEmbeddingProvider(EmbeddingProvider):
             return
 
         # Split into batches
-        batched_data = self._split_sub_batches(data)
+        encoding = self._get_encoding()
+        batched_data = self._split_sub_batches(encoding, data)
 
         # Process batches concurrently with semaphore
-        sem = asyncio.Semaphore(self.num_parallel_tasks)
+        sem = asyncio.Semaphore(self.endpoint.num_parallel_tasks or 10)
 
         async def _process_batch(
             batch: list[EmbeddingRequest],
         ) -> list[EmbeddingResponse]:
             async with sem:
-                try:
-                    response = await self._call_embeddings_api(
-                        [item.text for item in batch]
+                response = await self._call_embeddings_api(
+                    [item.text for item in batch]
+                )
+                embeddings_data = response.get("data", [])
+
+                return [
+                    EmbeddingResponse(
+                        snippet_id=item.snippet_id,
+                        embedding=emb_data.get("embedding", []),
                     )
-                    embeddings_data = response.get("data", [])
-
-                    return [
-                        EmbeddingResponse(
-                            snippet_id=item.snippet_id,
-                            embedding=emb_data.get("embedding", []),
-                        )
-                        for item, emb_data in zip(batch, embeddings_data, strict=True)
-                    ]
-                except Exception as e:
-                    self.log.exception("Error embedding batch", error=str(e))
-                    # Return no embeddings for this batch if there was an error
-                    return []
+                    for item, emb_data in zip(batch, embeddings_data, strict=True)
+                ]
 
         tasks = [_process_batch(batch) for batch in batched_data]
         for task in asyncio.as_completed(tasks):
@@ -155,9 +140,17 @@ class LiteLLMEmbeddingProvider(EmbeddingProvider):
     async def close(self) -> None:
         """Close the provider and cleanup HTTPX client if using Unix sockets."""
         if (
-            self.socket_path
+            self.endpoint.socket_path
            and hasattr(litellm, "aclient_session")
            and litellm.aclient_session
        ):
            await litellm.aclient_session.aclose()
            litellm.aclient_session = None
+
+    def _get_encoding(self) -> tiktoken.Encoding:
+        """Return (and cache) the tiktoken encoding for the chosen model."""
+        if self._encoding is None:
+            self._encoding = tiktoken.get_encoding(
+                "o200k_base"
+            )  # Reasonable default for most models, but might not be perfect.
+        return self._encoding
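
Note: the batching logic moved to a shared `split_sub_batches` helper whose body is not part of this diff; the call site above only shows that it takes an encoding, the requests, `max_tokens`, and `batch_size`. A plausible sketch of token-aware splitting under those assumptions (the greedy strategy and defaults are guesses, not kodit's actual implementation):

```python
from dataclasses import dataclass

import tiktoken


@dataclass
class EmbeddingRequest:
    """Stand-in for kodit's value object of the same name."""

    snippet_id: int
    text: str


def split_sub_batches(
    encoding: tiktoken.Encoding,
    data: list[EmbeddingRequest],
    max_tokens: int = 8192,
    batch_size: int = 10,
) -> list[list[EmbeddingRequest]]:
    """Greedily cap each batch by item count and total token count."""
    batches: list[list[EmbeddingRequest]] = []
    current: list[EmbeddingRequest] = []
    current_tokens = 0
    for item in data:
        tokens = len(encoding.encode(item.text))
        if current and (
            len(current) >= batch_size or current_tokens + tokens > max_tokens
        ):
            batches.append(current)
            current, current_tokens = [], 0
        current.append(item)
        current_tokens += tokens
    if current:
        batches.append(current)
    return batches


encoding = tiktoken.get_encoding("o200k_base")
requests = [EmbeddingRequest(i, f"def f{i}(): ...") for i in range(25)]
print([len(b) for b in split_sub_batches(encoding, requests)])  # e.g. [10, 10, 5]
```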