kodit 0.3.17__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kodit might be problematic.

kodit/_version.py CHANGED
@@ -1,7 +1,14 @@
 # file generated by setuptools-scm
 # don't change, don't track in version control
 
-__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
+__all__ = [
+    "__version__",
+    "__version_tuple__",
+    "version",
+    "version_tuple",
+    "__commit_id__",
+    "commit_id",
+]
 
 TYPE_CHECKING = False
 if TYPE_CHECKING:
@@ -9,13 +16,19 @@ if TYPE_CHECKING:
     from typing import Union
 
     VERSION_TUPLE = Tuple[Union[int, str], ...]
+    COMMIT_ID = Union[str, None]
 else:
     VERSION_TUPLE = object
+    COMMIT_ID = object
 
 version: str
 __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
+commit_id: COMMIT_ID
+__commit_id__: COMMIT_ID
 
-__version__ = version = '0.3.17'
-__version_tuple__ = version_tuple = (0, 3, 17)
+__version__ = version = '0.4.1'
+__version_tuple__ = version_tuple = (0, 4, 1)
+
+__commit_id__ = commit_id = None
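Note: 0.4.1 picks up setuptools-scm's newer version-file template, which also exports a commit id. A minimal sketch of reading the new attributes (assuming kodit 0.4.1 is installed):

```python
# Minimal sketch: the commit-id attributes added alongside the version.
# __commit_id__ is None in this wheel, per the generated file above.
from kodit import _version

print(_version.__version__)    # "0.4.1"
print(_version.version_tuple)  # (0, 4, 1)
print(_version.__commit_id__)  # None
```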
kodit/app.py CHANGED
@@ -12,7 +12,11 @@ from kodit.application.services.auto_indexing_service import AutoIndexingService
 from kodit.application.services.indexing_worker_service import IndexingWorkerService
 from kodit.application.services.sync_scheduler import SyncSchedulerService
 from kodit.config import AppContext
-from kodit.infrastructure.api.v1.routers import indexes_router, search_router
+from kodit.infrastructure.api.v1.routers import (
+    indexes_router,
+    queue_router,
+    search_router,
+)
 from kodit.infrastructure.api.v1.schemas.context import AppLifespanState
 from kodit.mcp import mcp
 from kodit.middleware import ASGICancelledErrorMiddleware, logging_middleware
@@ -113,6 +117,7 @@ async def healthz() -> Response:
 
 # Include API routers
 app.include_router(indexes_router)
+app.include_router(queue_router)
 app.include_router(search_router)
 
 
kodit/application/services/queue_service.py CHANGED
@@ -50,3 +50,8 @@ class QueueService:
         """List all tasks in the queue."""
         repo = SqlAlchemyTaskRepository(self.session)
         return await repo.list(task_type)
+
+    async def get_task(self, task_id: str) -> Task | None:
+        """Get a specific task by ID."""
+        repo = SqlAlchemyTaskRepository(self.session)
+        return await repo.get(task_id)
kodit/config.py CHANGED
@@ -38,15 +38,17 @@ DEFAULT_LOG_FORMAT = LogFormat.PRETTY
 DEFAULT_DISABLE_TELEMETRY = False
 T = TypeVar("T")
 
-EndpointType = Literal["openai"]
+EndpointType = Literal["openai", "litellm"]
 
 
 class Endpoint(BaseModel):
     """Endpoint provides configuration for an AI service."""
 
-    type: EndpointType | None = None
     base_url: str | None = None
-    model: str | None = None
+    model: str | None = Field(
+        default=None,
+        description="Model to use for the endpoint in litellm format (e.g. 'openai/text-embedding-3-small')",  # noqa: E501
+    )
     api_key: str | None = None
     num_parallel_tasks: int | None = None
     socket_path: str | None = Field(
@@ -57,6 +59,10 @@ class Endpoint(BaseModel):
         default=None,
         description="Request timeout in seconds (default: 30.0)",
     )
+    extra_params: dict[str, Any] | None = Field(
+        default=None,
+        description="Extra provider-specific non-secret parameters for LiteLLM",
+    )
 
 
 class Search(BaseModel):
@@ -114,15 +120,11 @@ class PeriodicSyncConfig(BaseModel):
 class RemoteConfig(BaseModel):
     """Configuration for remote server connection."""
 
-    server_url: str | None = Field(
-        default=None, description="Remote Kodit server URL"
-    )
+    server_url: str | None = Field(default=None, description="Remote Kodit server URL")
     api_key: str | None = Field(default=None, description="API key for authentication")
     timeout: float = Field(default=30.0, description="Request timeout in seconds")
    max_retries: int = Field(default=3, description="Maximum retry attempts")
-    verify_ssl: bool = Field(
-        default=True, description="Verify SSL certificates"
-    )
+    verify_ssl: bool = Field(default=True, description="Verify SSL certificates")
 
 
 class CustomAutoIndexingEnvSource(EnvSettingsSource):
@@ -198,13 +200,6 @@ class AppContext(BaseSettings):
     log_level: str = Field(default=DEFAULT_LOG_LEVEL)
     log_format: LogFormat = Field(default=DEFAULT_LOG_FORMAT)
     disable_telemetry: bool = Field(default=DEFAULT_DISABLE_TELEMETRY)
-    default_endpoint: Endpoint | None = Field(
-        default=None,
-        description=(
-            "Default endpoint to use for all AI interactions "
-            "(can be overridden by task-specific configuration)."
-        ),
-    )
     embedding_endpoint: Endpoint | None = Field(
         default=None,
         description="Endpoint to use for embedding.",
kodit/infrastructure/api/v1/__init__.py CHANGED
@@ -1,5 +1,5 @@
 """API v1 modules."""
 
-from .routers import indexes_router, search_router
+from .routers import indexes_router, queue_router, search_router
 
-__all__ = ["indexes_router", "search_router"]
+__all__ = ["indexes_router", "queue_router", "search_router"]
kodit/infrastructure/api/v1/routers/__init__.py CHANGED
@@ -1,6 +1,7 @@
 """API v1 routers."""
 
 from .indexes import router as indexes_router
+from .queue import router as queue_router
 from .search import router as search_router
 
-__all__ = ["indexes_router", "search_router"]
+__all__ = ["indexes_router", "queue_router", "search_router"]
kodit/infrastructure/api/v1/routers/queue.py ADDED
@@ -0,0 +1,76 @@
+"""Queue management router for the REST API."""
+
+from fastapi import APIRouter, Depends, HTTPException
+
+from kodit.domain.value_objects import TaskType
+from kodit.infrastructure.api.middleware.auth import api_key_auth
+from kodit.infrastructure.api.v1.dependencies import QueueServiceDep
+from kodit.infrastructure.api.v1.schemas.queue import (
+    TaskAttributes,
+    TaskData,
+    TaskListResponse,
+    TaskResponse,
+)
+
+router = APIRouter(
+    prefix="/api/v1/queue",
+    tags=["queue"],
+    dependencies=[Depends(api_key_auth)],
+    responses={
+        401: {"description": "Unauthorized"},
+        422: {"description": "Invalid request"},
+    },
+)
+
+
+@router.get("")
+async def list_queue_tasks(
+    queue_service: QueueServiceDep,
+    task_type: TaskType | None = None,
+) -> TaskListResponse:
+    """List all tasks in the queue.
+
+    Optionally filter by task type.
+    """
+    tasks = await queue_service.list_tasks(task_type)
+    return TaskListResponse(
+        data=[
+            TaskData(
+                type="task",
+                id=task.id,
+                attributes=TaskAttributes(
+                    type=str(task.type),
+                    priority=task.priority,
+                    payload=task.payload,
+                    created_at=task.created_at,
+                    updated_at=task.updated_at,
+                ),
+            )
+            for task in tasks
+        ]
+    )
+
+
+@router.get("/{task_id}", responses={404: {"description": "Task not found"}})
+async def get_queue_task(
+    task_id: str,
+    queue_service: QueueServiceDep,
+) -> TaskResponse:
+    """Get details of a specific task in the queue."""
+    task = await queue_service.get_task(task_id)
+    if not task:
+        raise HTTPException(status_code=404, detail="Task not found")
+
+    return TaskResponse(
+        data=TaskData(
+            type="task",
+            id=task.id,
+            attributes=TaskAttributes(
+                type=str(task.type),
+                priority=task.priority,
+                payload=task.payload,
+                created_at=task.created_at,
+                updated_at=task.updated_at,
+            ),
+        )
+    )
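A hedged usage sketch for the new queue endpoints. The paths, the optional `task_type` filter, and the 404 behaviour come from the router above; the host and the exact header checked by `api_key_auth` are assumptions:

```python
# Minimal sketch, assuming a kodit server on localhost:8080.
# The auth header name is an assumption; api_key_auth's scheme is not
# shown in this diff.
import httpx

with httpx.Client(
    base_url="http://localhost:8080",
    headers={"Authorization": "Bearer my-api-key"},  # assumed header
) as client:
    tasks = client.get("/api/v1/queue").json()["data"]
    if tasks:
        task = client.get(f"/api/v1/queue/{tasks[0]['id']}").json()["data"]
        print(task["attributes"]["type"], task["attributes"]["priority"])
    # An unknown id returns 404 with detail "Task not found".
```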
kodit/infrastructure/api/v1/schemas/queue.py ADDED
@@ -0,0 +1,35 @@
+"""JSON:API schemas for queue operations."""
+
+from datetime import datetime
+
+from pydantic import BaseModel
+
+
+class TaskAttributes(BaseModel):
+    """Task attributes for JSON:API responses."""
+
+    type: str
+    priority: int
+    payload: dict
+    created_at: datetime | None
+    updated_at: datetime | None
+
+
+class TaskData(BaseModel):
+    """Task data for JSON:API responses."""
+
+    type: str = "task"
+    id: str
+    attributes: TaskAttributes
+
+
+class TaskResponse(BaseModel):
+    """JSON:API response for single task."""
+
+    data: TaskData
+
+
+class TaskListResponse(BaseModel):
+    """JSON:API response for task list."""
+
+    data: list[TaskData]
kodit/infrastructure/embedding/embedding_factory.py CHANGED
@@ -1,5 +1,6 @@
 """Factory for creating embedding services with DDD architecture."""
 
+import structlog
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from kodit.config import AppContext, Endpoint
@@ -8,14 +9,13 @@ from kodit.domain.services.embedding_service import (
     EmbeddingProvider,
     VectorSearchRepository,
 )
+from kodit.infrastructure.embedding.embedding_providers.litellm_embedding_provider import (  # noqa: E501
+    LiteLLMEmbeddingProvider,
+)
 from kodit.infrastructure.embedding.embedding_providers.local_embedding_provider import (  # noqa: E501
     CODE,
     LocalEmbeddingProvider,
 )
-from kodit.infrastructure.embedding.embedding_providers.openai_embedding_provider import (  # noqa: E501
-    OPENAI_NUM_PARALLEL_TASKS,
-    OpenAIEmbeddingProvider,
-)
 from kodit.infrastructure.embedding.local_vector_search_repository import (
     LocalVectorSearchRepository,
 )
@@ -32,30 +32,24 @@ from kodit.log import log_event
 
 def _get_endpoint_configuration(app_context: AppContext) -> Endpoint | None:
     """Get the endpoint configuration for the embedding service."""
-    return app_context.embedding_endpoint or app_context.default_endpoint or None
+    return app_context.embedding_endpoint or None
 
 
 def embedding_domain_service_factory(
     task_name: TaskName, app_context: AppContext, session: AsyncSession
 ) -> EmbeddingDomainService:
     """Create an embedding domain service."""
+    structlog.get_logger(__name__)
     # Create embedding repository
     embedding_repository = SqlAlchemyEmbeddingRepository(session=session)
 
     # Create embedding provider
     embedding_provider: EmbeddingProvider | None = None
     endpoint = _get_endpoint_configuration(app_context)
-    if endpoint and endpoint.type == "openai":
-        log_event("kodit.embedding", {"provider": "openai"})
-        # Use new httpx-based provider with socket support
-        embedding_provider = OpenAIEmbeddingProvider(
-            api_key=endpoint.api_key,
-            base_url=endpoint.base_url or "https://api.openai.com/v1",
-            model_name=endpoint.model or "text-embedding-3-small",
-            num_parallel_tasks=endpoint.num_parallel_tasks or OPENAI_NUM_PARALLEL_TASKS,
-            socket_path=endpoint.socket_path,
-            timeout=endpoint.timeout or 30.0,
-        )
+
+    if endpoint:
+        log_event("kodit.embedding", {"provider": "litellm"})
+        embedding_provider = LiteLLMEmbeddingProvider(endpoint=endpoint)
     else:
         log_event("kodit.embedding", {"provider": "local"})
         embedding_provider = LocalEmbeddingProvider(CODE)
kodit/infrastructure/embedding/embedding_providers/litellm_embedding_provider.py ADDED
@@ -0,0 +1,163 @@
+"""LiteLLM embedding provider implementation."""
+
+import asyncio
+from collections.abc import AsyncGenerator
+from typing import Any
+
+import httpx
+import litellm
+import structlog
+from litellm import aembedding
+
+from kodit.config import Endpoint
+from kodit.domain.services.embedding_service import EmbeddingProvider
+from kodit.domain.value_objects import EmbeddingRequest, EmbeddingResponse
+
+# Constants
+MAX_TOKENS = 8192  # Conservative token limit for the embedding model
+BATCH_SIZE = 10  # Maximum number of items per API call
+DEFAULT_NUM_PARALLEL_TASKS = 10  # Semaphore limit for concurrent requests
+
+
+class LiteLLMEmbeddingProvider(EmbeddingProvider):
+    """LiteLLM embedding provider that supports 100+ providers."""
+
+    def __init__(
+        self,
+        endpoint: Endpoint,
+    ) -> None:
+        """Initialize the LiteLLM embedding provider.
+
+        Args:
+            endpoint: The endpoint configuration containing all settings.
+
+        """
+        self.model_name = endpoint.model or "text-embedding-3-small"
+        self.api_key = endpoint.api_key
+        self.base_url = endpoint.base_url
+        self.socket_path = endpoint.socket_path
+        self.num_parallel_tasks = (
+            endpoint.num_parallel_tasks or DEFAULT_NUM_PARALLEL_TASKS
+        )
+        self.timeout = endpoint.timeout or 30.0
+        self.extra_params = endpoint.extra_params or {}
+        self.log = structlog.get_logger(__name__)
+
+        # Configure LiteLLM with custom HTTPX client for Unix socket support if needed
+        self._setup_litellm_client()
+
+    def _setup_litellm_client(self) -> None:
+        """Set up LiteLLM with custom HTTPX client for Unix socket support."""
+        if self.socket_path:
+            # Create HTTPX client with Unix socket transport
+            transport = httpx.AsyncHTTPTransport(uds=self.socket_path)
+            unix_client = httpx.AsyncClient(
+                transport=transport,
+                base_url="http://localhost",  # Base URL for Unix socket
+                timeout=self.timeout,
+            )
+            # Set as LiteLLM's async client session
+            litellm.aclient_session = unix_client
+
+    def _split_sub_batches(
+        self, data: list[EmbeddingRequest]
+    ) -> list[list[EmbeddingRequest]]:
+        """Split data into manageable batches.
+
+        For LiteLLM, we use a simpler batching approach since token counting
+        varies by provider. We use a conservative batch size approach.
+        """
+        batches = []
+        for i in range(0, len(data), BATCH_SIZE):
+            batch = data[i : i + BATCH_SIZE]
+            batches.append(batch)
+        return batches
+
+    async def _call_embeddings_api(self, texts: list[str]) -> Any:
+        """Call the embeddings API using LiteLLM.
+
+        Args:
+            texts: The texts to embed.
+
+        Returns:
+            The API response as a dictionary.
+
+        """
+        kwargs = {
+            "model": self.model_name,
+            "input": texts,
+            "timeout": self.timeout,
+        }
+
+        # Add API key if provided
+        if self.api_key:
+            kwargs["api_key"] = self.api_key
+
+        # Add base_url if provided
+        if self.base_url:
+            kwargs["api_base"] = self.base_url
+
+        # Add extra parameters
+        kwargs.update(self.extra_params)
+
+        try:
+            # Use litellm's async embedding function
+            response = await aembedding(**kwargs)
+            return (
+                response.model_dump() if hasattr(response, "model_dump") else response
+            )
+        except Exception as e:
+            self.log.exception(
+                "LiteLLM embedding API error", error=str(e), model=self.model_name
+            )
+            raise
+
+    async def embed(
+        self, data: list[EmbeddingRequest]
+    ) -> AsyncGenerator[list[EmbeddingResponse], None]:
+        """Embed a list of strings using LiteLLM."""
+        if not data:
+            yield []
+            return
+
+        # Split into batches
+        batched_data = self._split_sub_batches(data)
+
+        # Process batches concurrently with semaphore
+        sem = asyncio.Semaphore(self.num_parallel_tasks)
+
+        async def _process_batch(
+            batch: list[EmbeddingRequest],
+        ) -> list[EmbeddingResponse]:
+            async with sem:
+                try:
+                    response = await self._call_embeddings_api(
+                        [item.text for item in batch]
+                    )
+                    embeddings_data = response.get("data", [])
+
+                    return [
+                        EmbeddingResponse(
+                            snippet_id=item.snippet_id,
+                            embedding=emb_data.get("embedding", []),
+                        )
+                        for item, emb_data in zip(batch, embeddings_data, strict=True)
+                    ]
+                except Exception as e:
+                    self.log.exception("Error embedding batch", error=str(e))
+                    # Return no embeddings for this batch if there was an error
+                    return []
+
+        tasks = [_process_batch(batch) for batch in batched_data]
+        for task in asyncio.as_completed(tasks):
+            yield await task
+
+    async def close(self) -> None:
+        """Close the provider and cleanup HTTPX client if using Unix sockets."""
+        if (
+            self.socket_path
+            and hasattr(litellm, "aclient_session")
+            and litellm.aclient_session
+        ):
+            await litellm.aclient_session.aclose()
+            litellm.aclient_session = None
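A usage sketch for the new provider. The `EmbeddingRequest` constructor arguments are inferred from the `item.snippet_id`/`item.text` accesses above; note that `embed()` yields batches in completion order, not submission order, because it drains `asyncio.as_completed`:

```python
# Minimal sketch, assuming EmbeddingRequest(snippet_id=..., text=...) matches
# kodit.domain.value_objects (inferred, not confirmed by this diff).
import asyncio

from kodit.config import Endpoint
from kodit.domain.value_objects import EmbeddingRequest
from kodit.infrastructure.embedding.embedding_providers.litellm_embedding_provider import (  # noqa: E501
    LiteLLMEmbeddingProvider,
)


async def main() -> None:
    provider = LiteLLMEmbeddingProvider(
        endpoint=Endpoint(model="openai/text-embedding-3-small", api_key="sk-...")
    )
    requests = [EmbeddingRequest(snippet_id=1, text="def add(a, b): return a + b")]
    async for batch in provider.embed(requests):
        for r in batch:
            print(r.snippet_id, len(r.embedding))
    await provider.close()


asyncio.run(main())
```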
kodit/infrastructure/enrichment/enrichment_factory.py CHANGED
@@ -5,13 +5,12 @@ from kodit.domain.services.enrichment_service import (
     EnrichmentDomainService,
     EnrichmentProvider,
 )
+from kodit.infrastructure.enrichment.litellm_enrichment_provider import (
+    LiteLLMEnrichmentProvider,
+)
 from kodit.infrastructure.enrichment.local_enrichment_provider import (
     LocalEnrichmentProvider,
 )
-from kodit.infrastructure.enrichment.openai_enrichment_provider import (
-    OPENAI_NUM_PARALLEL_TASKS,
-    OpenAIEnrichmentProvider,
-)
 from kodit.log import log_event
 
 
@@ -25,7 +24,7 @@ def _get_endpoint_configuration(app_context: AppContext) -> Endpoint | None:
        The endpoint configuration or None.
 
    """
-    return app_context.enrichment_endpoint or app_context.default_endpoint or None
+    return app_context.enrichment_endpoint or None
 
 
 def enrichment_domain_service_factory(
@@ -43,17 +42,9 @@ def enrichment_domain_service_factory(
     endpoint = _get_endpoint_configuration(app_context)
 
     enrichment_provider: EnrichmentProvider | None = None
-    if endpoint and endpoint.type == "openai":
-        log_event("kodit.enrichment", {"provider": "openai"})
-        # Use new httpx-based provider with socket support
-        enrichment_provider = OpenAIEnrichmentProvider(
-            api_key=endpoint.api_key,
-            base_url=endpoint.base_url or "https://api.openai.com/v1",
-            model_name=endpoint.model or "gpt-4o-mini",
-            num_parallel_tasks=endpoint.num_parallel_tasks or OPENAI_NUM_PARALLEL_TASKS,
-            socket_path=endpoint.socket_path,
-            timeout=endpoint.timeout or 30.0,
-        )
+    if endpoint:
+        log_event("kodit.enrichment", {"provider": "litellm"})
+        enrichment_provider = LiteLLMEnrichmentProvider(endpoint=endpoint)
     else:
         log_event("kodit.enrichment", {"provider": "local"})
         enrichment_provider = LocalEnrichmentProvider()
kodit/infrastructure/enrichment/openai_enrichment_provider.py → kodit/infrastructure/enrichment/litellm_enrichment_provider.py RENAMED
@@ -1,12 +1,15 @@
-"""OpenAI enrichment provider implementation using httpx."""
+"""LiteLLM enrichment provider implementation."""
 
 import asyncio
 from collections.abc import AsyncGenerator
 from typing import Any
 
 import httpx
+import litellm
 import structlog
+from litellm import acompletion
 
+from kodit.config import Endpoint
 from kodit.domain.services.enrichment_service import EnrichmentProvider
 from kodit.domain.value_objects import EnrichmentRequest, EnrichmentResponse
 from kodit.infrastructure.enrichment.utils import clean_thinking_tags
@@ -16,60 +19,52 @@ You are a professional software developer. You will be given a snippet of code.
 Please provide a concise explanation of the code.
 """
 
-# Default tuned to approximately fit within OpenAI's rate limit of 500 / RPM
-OPENAI_NUM_PARALLEL_TASKS = 40
+# Default tuned conservatively for broad provider compatibility
+DEFAULT_NUM_PARALLEL_TASKS = 20
 
 
+class LiteLLMEnrichmentProvider(EnrichmentProvider):
+    """LiteLLM enrichment provider that supports 100+ providers."""
 
-class OpenAIEnrichmentProvider(EnrichmentProvider):
-    """OpenAI enrichment provider implementation using httpx."""
-
-    def __init__(  # noqa: PLR0913
+    def __init__(
         self,
-        api_key: str | None = None,
-        base_url: str = "https://api.openai.com",
-        model_name: str = "gpt-4o-mini",
-        num_parallel_tasks: int = OPENAI_NUM_PARALLEL_TASKS,
-        socket_path: str | None = None,
-        timeout: float = 30.0,
+        endpoint: Endpoint,
     ) -> None:
-        """Initialize the OpenAI enrichment provider.
+        """Initialize the LiteLLM enrichment provider.
 
         Args:
-            api_key: The OpenAI API key.
-            base_url: The base URL for the OpenAI API.
-            model_name: The model name to use for enrichment.
-            num_parallel_tasks: Maximum number of concurrent requests.
-            socket_path: Optional Unix socket path for local communication.
-            timeout: Request timeout in seconds.
+            endpoint: The endpoint configuration containing all settings.
 
         """
         self.log = structlog.get_logger(__name__)
-        self.model_name = model_name
-        self.num_parallel_tasks = num_parallel_tasks
-        self.api_key = api_key
-        self.base_url = base_url
-        self.socket_path = socket_path
-        self.timeout = timeout
-
-        # Create httpx client with optional Unix socket support
-        if socket_path:
-            transport = httpx.AsyncHTTPTransport(uds=socket_path)
-            self.http_client = httpx.AsyncClient(
+        self.model_name = endpoint.model or "gpt-4o-mini"
+        self.api_key = endpoint.api_key
+        self.base_url = endpoint.base_url
+        self.socket_path = endpoint.socket_path
+        self.num_parallel_tasks = (
+            endpoint.num_parallel_tasks or DEFAULT_NUM_PARALLEL_TASKS
+        )
+        self.timeout = endpoint.timeout or 30.0
+        self.extra_params = endpoint.extra_params or {}
+
+        # Configure LiteLLM with custom HTTPX client for Unix socket support if needed
+        self._setup_litellm_client()
+
+    def _setup_litellm_client(self) -> None:
+        """Set up LiteLLM with custom HTTPX client for Unix socket support."""
+        if self.socket_path:
+            # Create HTTPX client with Unix socket transport
+            transport = httpx.AsyncHTTPTransport(uds=self.socket_path)
+            unix_client = httpx.AsyncClient(
                 transport=transport,
                 base_url="http://localhost",  # Base URL for Unix socket
-                timeout=timeout,
-            )
-        else:
-            self.http_client = httpx.AsyncClient(
-                base_url=base_url,
-                timeout=timeout,
+                timeout=self.timeout,
             )
+            # Set as LiteLLM's async client session
+            litellm.aclient_session = unix_client
 
-    async def _call_chat_completion(
-        self, messages: list[dict[str, str]]
-    ) -> dict[str, Any]:
-        """Call the chat completion API using httpx.
+    async def _call_chat_completion(self, messages: list[dict[str, str]]) -> Any:
+        """Call the chat completion API using LiteLLM.
 
         Args:
             messages: The messages to send to the API.
@@ -78,29 +73,39 @@ class OpenAIEnrichmentProvider(EnrichmentProvider):
            The API response as a dictionary.
 
        """
-        headers = {
-            "Content-Type": "application/json",
-        }
-        if self.api_key:
-            headers["Authorization"] = f"Bearer {self.api_key}"
-
-        data = {
+        kwargs = {
             "model": self.model_name,
             "messages": messages,
+            "timeout": self.timeout,
         }
 
-        response = await self.http_client.post(
-            "/v1/chat/completions",
-            json=data,
-            headers=headers,
-        )
-        response.raise_for_status()
-        return response.json()
+        # Add API key if provided
+        if self.api_key:
+            kwargs["api_key"] = self.api_key
+
+        # Add base_url if provided
+        if self.base_url:
+            kwargs["api_base"] = self.base_url
+
+        # Add extra parameters
+        kwargs.update(self.extra_params)
+
+        try:
+            # Use litellm's async completion function
+            response = await acompletion(**kwargs)
+            return (
+                response.model_dump() if hasattr(response, "model_dump") else response
+            )
+        except Exception as e:
+            self.log.exception(
+                "LiteLLM completion API error", error=str(e), model=self.model_name
+            )
+            raise
 
     async def enrich(
         self, requests: list[EnrichmentRequest]
     ) -> AsyncGenerator[EnrichmentResponse, None]:
-        """Enrich a list of requests using OpenAI API.
+        """Enrich a list of requests using LiteLLM.
 
         Args:
            requests: List of enrichment requests.
@@ -113,7 +118,7 @@ class OpenAIEnrichmentProvider(EnrichmentProvider):
             self.log.warning("No requests for enrichment")
             return
 
-        # Process batches in parallel with a semaphore to limit concurrent requests
+        # Process requests in parallel with a semaphore to limit concurrent requests
         sem = asyncio.Semaphore(self.num_parallel_tasks)
 
         async def process_request(request: EnrichmentRequest) -> EnrichmentResponse:
@@ -158,6 +163,11 @@ class OpenAIEnrichmentProvider(EnrichmentProvider):
             yield await task
 
     async def close(self) -> None:
-        """Close the HTTP client."""
-        if hasattr(self, "http_client"):
-            await self.http_client.aclose()
+        """Close the provider and cleanup HTTPX client if using Unix sockets."""
+        if (
+            self.socket_path
+            and hasattr(litellm, "aclient_session")
+            and litellm.aclient_session
+        ):
+            await litellm.aclient_session.aclose()
+            litellm.aclient_session = None
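Both LiteLLM providers wire Unix-socket support the same way: they hang a custom HTTPX client on `litellm.aclient_session`, which is module-level state, so the last provider to configure it wins and `close()` tears it down for everyone. A standalone sketch of that wiring (the socket path is a placeholder):

```python
# Minimal sketch of the shared Unix-socket setup used by both providers above.
import httpx
import litellm

litellm.aclient_session = httpx.AsyncClient(
    transport=httpx.AsyncHTTPTransport(uds="/var/run/llm.sock"),  # placeholder path
    base_url="http://localhost",  # a scheme/host is required even over a socket
    timeout=30.0,
)
```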
kodit/infrastructure/git/git_utils.py CHANGED
@@ -3,6 +3,7 @@
 import tempfile
 
 import git
+import structlog
 
 
 # FUTURE: move to clone dir
@@ -19,7 +20,12 @@ def is_valid_clone_target(target: str) -> bool:
     with tempfile.TemporaryDirectory() as temp_dir:
         try:
             git.Repo.clone_from(target, temp_dir)
-        except git.GitCommandError:
+        except git.GitCommandError as e:
+            structlog.get_logger(__name__).warning(
+                "Failed to clone git repository",
+                target=target,
+                error=e,
+            )
             return False
         else:
             return True
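For completeness, the changed helper in use; it still clones the whole repository into a temporary directory just to validate the target (hence the `# FUTURE: move to clone dir` note above), but now emits a structured warning on failure:

```python
# Minimal sketch; the URL is only an example target.
from kodit.infrastructure.git.git_utils import is_valid_clone_target

if not is_valid_clone_target("https://example.com/repo.git"):
    print("not a cloneable target")  # a warning was also logged
```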
kodit/log.py CHANGED
@@ -11,6 +11,7 @@ from functools import lru_cache
 from pathlib import Path
 from typing import Any
 
+import litellm
 import rudderstack.analytics as rudder_analytics  # type: ignore[import-untyped]
 import structlog
 from structlog.types import EventDict
@@ -99,6 +100,7 @@ def configure_logging(app_context: AppContext) -> None:
         "bm25s",
         "sentence_transformers.SentenceTransformer",
         "httpx",
+        "LiteLLM",
     ]:
         if root_logger.getEffectiveLevel() == logging.DEBUG:
             logging.getLogger(_log).handlers.clear()
@@ -106,6 +108,9 @@ def configure_logging(app_context: AppContext) -> None:
         else:
             logging.getLogger(_log).disabled = True
 
+    # More litellm logging cruft
+    litellm.suppress_debug_info = True
+
     # Configure SQLAlchemy loggers to use our structlog setup
     for _log in ["sqlalchemy.engine", "alembic"]:
         engine_logger = logging.getLogger(_log)
@@ -138,6 +143,7 @@ def configure_logging(app_context: AppContext) -> None:
 
 def configure_telemetry(app_context: AppContext) -> None:
     """Configure telemetry for the application."""
+    litellm.telemetry = False  # Disable litellm telemetry by default
     if app_context.disable_telemetry:
         structlog.stdlib.get_logger(__name__).info("Telemetry has been disabled")
         rudder_analytics.send = False
kodit-0.3.17.dist-info/METADATA → kodit-0.4.1.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: kodit
-Version: 0.3.17
+Version: 0.4.1
 Summary: Code indexing for better AI code generation
 Project-URL: Homepage, https://docs.helixml.tech/kodit/
 Project-URL: Documentation, https://docs.helixml.tech/kodit/
@@ -35,7 +35,8 @@ Requires-Dist: gitpython>=3.1.44
 Requires-Dist: hf-xet>=1.1.2
 Requires-Dist: httpx-retries>=0.3.2
 Requires-Dist: httpx>=0.28.1
-Requires-Dist: openai>=1.82.0
+Requires-Dist: litellm>=1.75.8
+Requires-Dist: openai==1.99.9
 Requires-Dist: pathspec>=0.12.1
 Requires-Dist: pydantic-settings>=2.9.1
 Requires-Dist: pystemmer>=3.0.0
kodit-0.3.17.dist-info/RECORD → kodit-0.4.1.dist-info/RECORD RENAMED
@@ -1,12 +1,12 @@
 kodit/.gitignore,sha256=ztkjgRwL9Uud1OEi36hGQeDGk3OLK1NfDEO8YqGYy8o,11
 kodit/__init__.py,sha256=aEKHYninUq1yh6jaNfvJBYg-6fenpN132nJt1UU6Jxs,59
-kodit/_version.py,sha256=cGWA93r_16rR-tkjWO67nlhhjy0fPwiHpKtUFrImZgo,513
-kodit/app.py,sha256=r0w0JJsOacrQ5aAQ1yf-BK1CrYZPrvtUoH1LACd9FaA,4262
+kodit/_version.py,sha256=k7cu0JKra64gmMNU_UfA5sw2eNc_GRvf3QmesiYAy8g,704
+kodit/app.py,sha256=FKqHbJNoHpnBb5KLJaUINDm0a5cHpxlrQ5qNoWnHsBc,4326
 kodit/cli.py,sha256=VUZD4cPRgAnrKEWUl2PbS-nOA0FkDVqmJ2SR0g1yJsk,28202
 kodit/cli_utils.py,sha256=bW4rIm-elrsyM_pSGHh30zV0_oX7V-64pL3YSaBcOt0,2810
-kodit/config.py,sha256=YYo36lBi3auCssyCVrpw3Z3ZXSzXKTDo45nIsUjkzfs,10305
+kodit/config.py,sha256=EOyB3BfSDfGJeOts2vrOyX8lhGXbAHVUl69kzjqGHXE,10326
 kodit/database.py,sha256=kI9yBm4uunsgV4-QeVoCBL0wLzU4kYmYv5qZilGnbPE,1740
-kodit/log.py,sha256=XyuseZk90gUBj1B7np2UO2EW9eE_ApayIpPRvI19KCE,8651
+kodit/log.py,sha256=ZpM0eMo_DVGQqrHxg0VV6dMrN2AAmu_3C0I3G7p2nMw,8828
 kodit/mcp.py,sha256=aEcPc8dQiZaR0AswCZZNxcm_rhhUZNsEBimYti0ibSI,7221
 kodit/middleware.py,sha256=TiwebNpaEmiP7QRuZrfZcCL51IUefQyNLSPuzVyk8UM,2813
 kodit/reporting.py,sha256=icce1ZyiADsA_Qz-mSjgn2H4SSqKuGfLKnw-yrl9nsg,2722
@@ -17,7 +17,7 @@ kodit/application/services/__init__.py,sha256=p5UQNw-H5sxQvs5Etfte93B3cJ1kKW6DNx
 kodit/application/services/auto_indexing_service.py,sha256=O5BNR5HypgghzUFG4ykIWMl9mxHCUExnBmJuITIhECk,3457
 kodit/application/services/code_indexing_application_service.py,sha256=nrnd_Md-D0AfNKku7Aqt3YHDbXsBV9f44Z6XsjhiF3E,15877
 kodit/application/services/indexing_worker_service.py,sha256=Un4PytnWJU4uwROcxOMUFkt4cD7nmPezaBLsEHrMN6U,5185
-kodit/application/services/queue_service.py,sha256=GaixRoCUaDhLYfwZLVED8C3w_NPiy_QbuVp_jhwP4GI,1727
+kodit/application/services/queue_service.py,sha256=vf_TEl76B0F0RSvfCeGDuM-QFzW-VUuj3zQaRmDPEYI,1921
 kodit/application/services/sync_scheduler.py,sha256=aLpEczZdTM8ubfAEY0Ajdh3MLfDcB9s-0ILZJrtIuZs,3504
 kodit/domain/__init__.py,sha256=TCpg4Xx-oF4mKV91lo4iXqMEfBT1OoRSYnbG-zVWolA,66
 kodit/domain/entities.py,sha256=QsCzKXT7gF9jTPAjJo5lqjFGRsIklAFC2qRy_Gt3RbA,10377
@@ -41,14 +41,16 @@ kodit/infrastructure/api/client/index_client.py,sha256=OxsakDQBEulwmqZVzwOSSI0Lk
 kodit/infrastructure/api/client/search_client.py,sha256=f4mM5ZJpAuR7w-i9yASbh4SYMxOq7_f4hXgaQesGquI,2614
 kodit/infrastructure/api/middleware/__init__.py,sha256=6m7eE5k5buboJbuzyX5E9-Tf99yNwFaeJF0f_6HwLyM,30
 kodit/infrastructure/api/middleware/auth.py,sha256=QSnMcMLWvfumqN1iG4ePj2vEZb2Dlsgr-WHptkEkkhE,1064
-kodit/infrastructure/api/v1/__init__.py,sha256=XYv4_9Z6fo69oMvC2mEbtD6DaMqHth29KHUOelmQFwM,121
+kodit/infrastructure/api/v1/__init__.py,sha256=hQ03es21FSgzQlmdP5xWZzK80woIvuYGjiZLwFYuYwk,151
 kodit/infrastructure/api/v1/dependencies.py,sha256=jaM000IfSnvU8uzwnC1cBZsfsMC-19jWFjObHfqBYuM,2475
-kodit/infrastructure/api/v1/routers/__init__.py,sha256=L8hT_SkDzmCXIiWrFQWCkZXQ3UDy_ZMxPr8AIhjSWK0,160
+kodit/infrastructure/api/v1/routers/__init__.py,sha256=YYyeiuyphIPc-Q_2totF8zfR0BoseOH4ZYFdHP0ed_M,218
 kodit/infrastructure/api/v1/routers/indexes.py,sha256=_lUir1M0SW6kPHeGqjiPjtSa50rY4PN2es5TZEpSHYE,3442
+kodit/infrastructure/api/v1/routers/queue.py,sha256=EZbR-G0qDO9W5ajV_75GRk2pW1Qdgc0ggOwrGKlBE2A,2138
 kodit/infrastructure/api/v1/routers/search.py,sha256=da9YTR6VTzU85_6X3aaZemdTHGCEvcPNeKuMFBgmT_A,2452
 kodit/infrastructure/api/v1/schemas/__init__.py,sha256=_5BVqv4EUi_vvWlAQOE_VfRulUDAF21ZQ7z27y7YOdw,498
 kodit/infrastructure/api/v1/schemas/context.py,sha256=NlsIn9j1R3se7JkGZivS_CUN4gGP5NYaAtkRe3QH6dk,214
 kodit/infrastructure/api/v1/schemas/index.py,sha256=NtL09YtO50h-ddpAFxNf-dyxu_Xi5v3yOpKW0W4xsAM,1950
+kodit/infrastructure/api/v1/schemas/queue.py,sha256=oa4wumWOvGzi53Q3cjwIrQJRoentp5nsQSsaj-l-B4U,652
 kodit/infrastructure/api/v1/schemas/search.py,sha256=CWzg5SIMUJ_4yM-ZfgSLWCanMxov6AyGgQQcOMkRlGw,5618
 kodit/infrastructure/bm25/__init__.py,sha256=DmGbrEO34FOJy4e685BbyxLA7gPW1eqs2gAxsp6JOuM,34
 kodit/infrastructure/bm25/bm25_factory.py,sha256=I4eo7qRslnyXIRkBf-StZ5ga2Evrr5J5YFocTChFD3g,884
@@ -59,22 +61,22 @@ kodit/infrastructure/cloning/metadata.py,sha256=GD2UnCC1oR82RD0SVUqk9CJOqzXPxhOA
 kodit/infrastructure/cloning/git/__init__.py,sha256=20ePcp0qE6BuLsjsv4KYB1DzKhMIMsPXwEqIEZtjTJs,34
 kodit/infrastructure/cloning/git/working_copy.py,sha256=qYcrR5qP1rhWZiYGMT1p-1Alavi_YvQLXx4MgIV7eXs,2611
 kodit/infrastructure/embedding/__init__.py,sha256=F-8nLlWAerYJ0MOIA4tbXHLan8bW5rRR84vzxx6tRKI,39
-kodit/infrastructure/embedding/embedding_factory.py,sha256=8LC2jKf2vx-P-TCh8ZanxwF3hT5PSjWA3vuSR6ggcXk,3731
+kodit/infrastructure/embedding/embedding_factory.py,sha256=wngBD2g6NniHDq_-KcYhhwSvmcMYyI8yIzoXvGQvu1U,3287
 kodit/infrastructure/embedding/local_vector_search_repository.py,sha256=ExweyNEL5cP-g3eDhGqZSih7zhdOrop2WdFPPJL-tB4,3505
 kodit/infrastructure/embedding/vectorchord_vector_search_repository.py,sha256=PIoU0HsDlaoXDXnGjOR0LAkAcW4JiE3ymJy_SBhEopc,8030
 kodit/infrastructure/embedding/embedding_providers/__init__.py,sha256=qeZ-oAIAxMl5QqebGtO1lq-tHjl_ucAwOXePklcwwGk,34
 kodit/infrastructure/embedding/embedding_providers/batching.py,sha256=a8CL9PX2VLmbeg616fc_lQzfC4BWTVn32m4SEhXpHxc,3279
 kodit/infrastructure/embedding/embedding_providers/hash_embedding_provider.py,sha256=V6OdCuWyQQOvo3OJGRi-gBKDApIcrELydFg7T696P5s,2257
+kodit/infrastructure/embedding/embedding_providers/litellm_embedding_provider.py,sha256=5LCrPSQn3ZaLZ1XTKzJV_LzANH7FdaR4NL-gJupaiDE,5579
 kodit/infrastructure/embedding/embedding_providers/local_embedding_provider.py,sha256=9aLV1Zg4KMhYWlGRwgAUtswW4aIabNqbsipWhAn64RI,4133
-kodit/infrastructure/embedding/embedding_providers/openai_embedding_provider.py,sha256=CE86s8IicieUjIDWn2xzswteHXCzmw1Qz6Kp4GBIcus,6316
 kodit/infrastructure/enrichment/__init__.py,sha256=8acZKNzql8Fs0lceFu9U3KoUrOptRBtVIxr_Iw6lz3Y,40
-kodit/infrastructure/enrichment/enrichment_factory.py,sha256=jZWGgAvFjEuRUc1oW3iGhgipvX-EnVJZpw6ybzp9NGM,2016
+kodit/infrastructure/enrichment/enrichment_factory.py,sha256=NFGY6u9SJ_GOgiB_RtotbQmte0kGFQUymwzZCbbsx34,1530
+kodit/infrastructure/enrichment/litellm_enrichment_provider.py,sha256=AM4-4KApDndzWzQzzKAedy21iGMhkwylR5VCmV9K-uI,6040
 kodit/infrastructure/enrichment/local_enrichment_provider.py,sha256=aVU3_kbLJ0BihwGIwvJ00DBe0voHkiKdFSjPxxkVfVA,4150
 kodit/infrastructure/enrichment/null_enrichment_provider.py,sha256=DhZkJBnkvXg_XSAs-oKiFnKqYFPnmTl3ikdxrqeEfbc,713
-kodit/infrastructure/enrichment/openai_enrichment_provider.py,sha256=C0y0NEPu1GpFr22TGi1voxYGsYTV0ZITYuDzvRJ5vW4,5573
 kodit/infrastructure/enrichment/utils.py,sha256=FE9UCuxxzSdoHrmAC8Si2b5D6Nf6kVqgM1yjUVyCvW0,930
 kodit/infrastructure/git/__init__.py,sha256=0iMosFzudj4_xNIMe2SRbV6l5bWqkjnUsZoFsoZFuM8,33
-kodit/infrastructure/git/git_utils.py,sha256=KERwmhWDR4ooMQKS-nSPxjvdCzoWF9NS6nhdeXyzdtY,571
+kodit/infrastructure/git/git_utils.py,sha256=3Fg2ZX9pkp8Mk1mWuW30PSO_ZKXrPL7wTCS9TMTfIUM,765
 kodit/infrastructure/ignore/__init__.py,sha256=VzFv8XOzHmsu0MEGnWVSF6KsgqLBmvHlRqAkT1Xb1MY,36
 kodit/infrastructure/ignore/ignore_pattern_provider.py,sha256=zdxun3GodLfXxyssBK8QDUK58xb4fBJ0SKcHUyn3pzM,2131
 kodit/infrastructure/indexing/__init__.py,sha256=7UPRa2jwCAsa0Orsp6PqXSF8iIXJVzXHMFmrKkI9yH8,38
@@ -109,8 +111,8 @@ kodit/utils/__init__.py,sha256=DPEB1i8evnLF4Ns3huuAYg-0pKBFKUFuiDzOKG9r-sw,33
 kodit/utils/dump_openapi.py,sha256=29VdjHpNSaGAg7RjQw0meq1OLhljCx1ElgBlTC8xoF4,1247
 kodit/utils/generate_api_paths.py,sha256=TMtx9v55podDfUmiWaHgJHLtEWLV2sLL-5ejGFMPzAo,3569
 kodit/utils/path_utils.py,sha256=thK6YGGNvQThdBaCYCCeCvS1L8x-lwl3AoGht2jnjGw,1645
-kodit-0.3.17.dist-info/METADATA,sha256=xxyRRv2pL9aSP6MYMRPDsVNLR0gdFNg5CJ4UbbLwOAU,7672
-kodit-0.3.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-kodit-0.3.17.dist-info/entry_points.txt,sha256=hoTn-1aKyTItjnY91fnO-rV5uaWQLQ-Vi7V5et2IbHY,40
-kodit-0.3.17.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-kodit-0.3.17.dist-info/RECORD,,
+kodit-0.4.1.dist-info/METADATA,sha256=Mf4UuPg2D08hfp1STtsZ1DwOHA2a1J4ba5_-T7Pifr4,7702
+kodit-0.4.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+kodit-0.4.1.dist-info/entry_points.txt,sha256=hoTn-1aKyTItjnY91fnO-rV5uaWQLQ-Vi7V5et2IbHY,40
+kodit-0.4.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+kodit-0.4.1.dist-info/RECORD,,
kodit/infrastructure/embedding/embedding_providers/openai_embedding_provider.py REMOVED
@@ -1,183 +0,0 @@
-"""OpenAI embedding provider implementation using httpx."""
-
-import asyncio
-from collections.abc import AsyncGenerator
-from typing import Any
-
-import httpx
-import structlog
-import tiktoken
-from tiktoken import Encoding
-
-from kodit.domain.services.embedding_service import EmbeddingProvider
-from kodit.domain.value_objects import EmbeddingRequest, EmbeddingResponse
-
-from .batching import split_sub_batches
-
-# Constants
-MAX_TOKENS = 8192  # Conservative token limit for the embedding model
-BATCH_SIZE = (
-    10  # Maximum number of items per API call (keeps existing test expectations)
-)
-OPENAI_NUM_PARALLEL_TASKS = 10  # Semaphore limit for concurrent OpenAI requests
-
-
-class OpenAIEmbeddingProvider(EmbeddingProvider):
-    """OpenAI embedding provider that uses OpenAI's embedding API via httpx."""
-
-    def __init__(  # noqa: PLR0913
-        self,
-        api_key: str | None = None,
-        base_url: str = "https://api.openai.com",
-        model_name: str = "text-embedding-3-small",
-        num_parallel_tasks: int = OPENAI_NUM_PARALLEL_TASKS,
-        socket_path: str | None = None,
-        timeout: float = 30.0,
-    ) -> None:
-        """Initialize the OpenAI embedding provider.
-
-        Args:
-            api_key: The OpenAI API key.
-            base_url: The base URL for the OpenAI API.
-            model_name: The model name to use for embeddings.
-            num_parallel_tasks: Maximum number of concurrent requests.
-            socket_path: Optional Unix socket path for local communication.
-            timeout: Request timeout in seconds.
-
-        """
-        self.model_name = model_name
-        self.num_parallel_tasks = num_parallel_tasks
-        self.log = structlog.get_logger(__name__)
-        self.api_key = api_key
-        self.base_url = base_url
-        self.socket_path = socket_path
-        self.timeout = timeout
-
-        # Lazily initialised token encoding
-        self._encoding: Encoding | None = None
-
-        # Create httpx client with optional Unix socket support
-        if socket_path:
-            transport = httpx.AsyncHTTPTransport(uds=socket_path)
-            self.http_client = httpx.AsyncClient(
-                transport=transport,
-                base_url="http://localhost",  # Base URL for Unix socket
-                timeout=timeout,
-            )
-        else:
-            self.http_client = httpx.AsyncClient(
-                base_url=base_url,
-                timeout=timeout,
-            )
-
-    # ---------------------------------------------------------------------
-    # Helper utilities
-    # ---------------------------------------------------------------------
-
-    def _get_encoding(self) -> "Encoding":
-        """Return (and cache) the tiktoken encoding for the chosen model."""
-        if self._encoding is None:
-            try:
-                self._encoding = tiktoken.encoding_for_model(self.model_name)
-            except KeyError:
-                # If the model is not supported by tiktoken, use a default encoding
-                self.log.info(
-                    "Model not supported by tiktoken, using default encoding",
-                    model_name=self.model_name,
-                    default_encoding="o200k_base",
-                )
-                self._encoding = tiktoken.get_encoding("o200k_base")
-
-        return self._encoding
-
-    def _split_sub_batches(
-        self, encoding: "Encoding", data: list[EmbeddingRequest]
-    ) -> list[list[EmbeddingRequest]]:
-        """Proxy to the shared batching utility (kept for backward-compat)."""
-        return split_sub_batches(
-            encoding,
-            data,
-            max_tokens=MAX_TOKENS,
-            batch_size=BATCH_SIZE,
-        )
-
-    async def _call_embeddings_api(
-        self, texts: list[str]
-    ) -> dict[str, Any]:
-        """Call the embeddings API using httpx.
-
-        Args:
-            texts: The texts to embed.
-
-        Returns:
-            The API response as a dictionary.
-
-        """
-        headers = {
-            "Content-Type": "application/json",
-        }
-        if self.api_key:
-            headers["Authorization"] = f"Bearer {self.api_key}"
-
-        data = {
-            "model": self.model_name,
-            "input": texts,
-        }
-
-        response = await self.http_client.post(
-            "/v1/embeddings",
-            json=data,
-            headers=headers,
-        )
-        response.raise_for_status()
-        return response.json()
-
-    async def embed(
-        self, data: list[EmbeddingRequest]
-    ) -> AsyncGenerator[list[EmbeddingResponse], None]:
-        """Embed a list of strings using OpenAI's API."""
-        if not data:
-            yield []
-
-        encoding = self._get_encoding()
-
-        # First, split by token limits (and max batch size)
-        batched_data = self._split_sub_batches(encoding, data)
-
-        # -----------------------------------------------------------------
-        # Process batches concurrently (but bounded by a semaphore)
-        # -----------------------------------------------------------------
-
-        sem = asyncio.Semaphore(self.num_parallel_tasks)
-
-        async def _process_batch(
-            batch: list[EmbeddingRequest],
-        ) -> list[EmbeddingResponse]:
-            async with sem:
-                try:
-                    response = await self._call_embeddings_api(
-                        [item.text for item in batch]
-                    )
-                    embeddings_data = response.get("data", [])
-
-                    return [
-                        EmbeddingResponse(
-                            snippet_id=item.snippet_id,
-                            embedding=emb_data.get("embedding", []),
-                        )
-                        for item, emb_data in zip(batch, embeddings_data, strict=True)
-                    ]
-                except Exception as e:
-                    self.log.exception("Error embedding batch", error=str(e))
-                    # Return no embeddings for this batch if there was an error
-                    return []
-
-        tasks = [_process_batch(batch) for batch in batched_data]
-        for task in asyncio.as_completed(tasks):
-            yield await task
-
-    async def close(self) -> None:
-        """Close the HTTP client."""
-        if hasattr(self, "http_client"):
-            await self.http_client.aclose()
-