donkit-llm 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff compares the contents of publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
donkit/llm/__init__.py CHANGED
@@ -23,6 +23,8 @@ from .openai_model import (
 from .claude_model import ClaudeModel, ClaudeVertexModel
 from .vertex_model import VertexAIModel, VertexEmbeddingModel
 from .factory import ModelFactory
+from .gemini_model import GeminiModel, GeminiEmbeddingModel
+from .donkit_model import DonkitModel
 
 __all__ = [
     "ModelFactory",
@@ -52,4 +54,7 @@ __all__ = [
     "ClaudeVertexModel",
     "VertexAIModel",
     "VertexEmbeddingModel",
+    "GeminiModel",
+    "GeminiEmbeddingModel",
+    "DonkitModel",
 ]
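
Note: 0.1.3 re-exports the new Gemini and Donkit models at package level. A minimal sketch of the resulting import surface (assuming the distribution installs the donkit.llm namespace shown above):

from donkit.llm import (
    DonkitModel,
    GeminiEmbeddingModel,
    GeminiModel,
    ModelFactory,
)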
donkit/llm/claude_model.py CHANGED
@@ -20,6 +20,8 @@ from .model_abstract import (
 class ClaudeModel(LLMModelAbstract):
     """Anthropic Claude model implementation."""
 
+    name = "claude"
+
     def __init__(
         self,
         model_name: str,
@@ -82,9 +84,9 @@ class ClaudeModel(LLMModelAbstract):
             # Multimodal content
             content_parts = []
             for part in msg.content:
-                if part.type == ContentType.TEXT:
+                if part.content_type == ContentType.TEXT:
                     content_parts.append({"type": "text", "text": part.content})
-                elif part.type == ContentType.IMAGE_URL:
+                elif part.content_type == ContentType.IMAGE_URL:
                     # Claude expects base64 images, not URLs
                     content_parts.append(
                         {
@@ -95,7 +97,7 @@ class ClaudeModel(LLMModelAbstract):
                             },
                         }
                     )
-                elif part.type == ContentType.IMAGE_BASE64:
+                elif part.content_type == ContentType.IMAGE_BASE64:
                     content_parts.append(
                         {
                             "type": "image",
@@ -319,9 +321,9 @@ class ClaudeVertexModel(LLMModelAbstract):
             # Multimodal content
             content_parts = []
             for part in msg.content:
-                if part.type == ContentType.TEXT:
+                if part.content_type == ContentType.TEXT:
                     content_parts.append({"type": "text", "text": part.content})
-                elif part.type == ContentType.IMAGE_BASE64:
+                elif part.content_type == ContentType.IMAGE_BASE64:
                     content_parts.append(
                         {
                             "type": "image",
donkit/llm/donkit_model.py ADDED
@@ -0,0 +1,239 @@
+from typing import Any, AsyncIterator
+
+from donkit.ragops_api_gateway_client.client import RagopsAPIGatewayClient
+from .model_abstract import (
+    EmbeddingRequest,
+    EmbeddingResponse,
+    FunctionCall,
+    GenerateRequest,
+    GenerateResponse,
+    LLMModelAbstract,
+    Message,
+    ModelCapability,
+    StreamChunk,
+    Tool,
+    ToolCall,
+)
+
+
+class DonkitModel(LLMModelAbstract):
+    """
+    Implementation of LLMModelAbstract that proxies requests via RagopsAPIGatewayClient.
+    """
+
+    name = "donkit"
+
+    def __init__(
+        self,
+        base_url: str,
+        api_token: str,
+        provider: str = "default",
+        model_name: str | None = None,
+        project_id: str | None = None,
+    ):
+        """
+        Initialize DonkitModel.
+
+        Args:
+            base_url: Base URL for the API Gateway
+            api_token: API token for authentication
+            provider: The LLM provider name
+                (e.g., "openai", "anthropic", "vertex", "azure_openai", "ollama", "default")
+            model_name: The specific model identifier (e.g., "gpt-4o", "claude-3-opus")
+            project_id: The project ID for the gateway
+        """
+        self.base_url = base_url
+        self.api_token = api_token
+        self.provider = provider
+        self._model_name = model_name
+        self.project_id = project_id
+        self._capabilities = self._determine_capabilities()
+
+    @property
+    def model_name(self) -> str:
+        return self._model_name
+
+    @model_name.setter
+    def model_name(self, value: str):
+        self._model_name = value
+        self._capabilities = self._determine_capabilities()
+
+    @property
+    def capabilities(self) -> ModelCapability:
+        return self._capabilities
+
+    def _determine_capabilities(self) -> ModelCapability:
+        """
+        Estimate capabilities based on model name.
+        Since this is a proxy, we assume modern defaults but refine based on keywords.
+        """
+        caps = (
+            ModelCapability.TEXT_GENERATION
+            | ModelCapability.STREAMING
+            | ModelCapability.STRUCTURED_OUTPUT
+            | ModelCapability.TOOL_CALLING
+            | ModelCapability.MULTIMODAL_INPUT
+            | ModelCapability.EMBEDDINGS
+        )
+        return caps
+
+    def _convert_message(self, msg: Message) -> dict:
+        """Convert internal Message to dictionary format expected by the Gateway."""
+        result: dict[str, Any] = {"role": msg.role}
+        if isinstance(msg.content, str):
+            result["content"] = msg.content
+        else:
+            # Multimodal content processing
+            content_parts = []
+            for part in msg.content if msg.content else []:
+                content_parts.append(part.model_dump(exclude_none=True))
+            result["content"] = content_parts
+        if msg.tool_calls:
+            result["tool_calls"] = [tc.model_dump() for tc in msg.tool_calls]
+        if msg.tool_call_id:
+            result["tool_call_id"] = msg.tool_call_id
+        if msg.name:
+            result["name"] = msg.name
+
+        return result
+
+    def _convert_tools(self, tools: list[Tool]) -> list[dict]:
+        """Convert internal Tool definitions to Gateway dictionary format."""
+        return [tool.model_dump(exclude_none=True) for tool in tools]
+
+    def _prepare_generate_kwargs(self, request: GenerateRequest) -> dict:
+        """Prepare kwargs for generate/generate_stream calls."""
+        messages = [self._convert_message(msg) for msg in request.messages]
+        tools_payload = self._convert_tools(request.tools) if request.tools else None
+
+        kwargs: dict[str, Any] = {
+            "provider": self.provider,
+            "model_name": self._model_name,
+            "messages": messages,
+            "project_id": self.project_id,
+        }
+
+        if request.temperature is not None:
+            kwargs["temperature"] = request.temperature
+        if request.max_tokens is not None:
+            kwargs["max_tokens"] = request.max_tokens
+        if request.top_p is not None:
+            kwargs["top_p"] = request.top_p
+        if request.stop:
+            kwargs["stop"] = request.stop
+        if tools_payload:
+            kwargs["tools"] = tools_payload
+        if request.tool_choice:
+            if isinstance(request.tool_choice, (str, dict)):
+                kwargs["tool_choice"] = request.tool_choice
+            else:
+                kwargs["tool_choice"] = "auto"
+        if request.response_format:
+            kwargs["response_format"] = request.response_format
+
+        return kwargs
+
+    async def generate(self, request: GenerateRequest) -> GenerateResponse:
+        """Generate a response using RagopsAPIGatewayClient."""
+        await self.validate_request(request)
+
+        kwargs = self._prepare_generate_kwargs(request)
+
+        async with RagopsAPIGatewayClient(
+            base_url=self.base_url,
+            api_token=self.api_token,
+        ) as client:
+            response_dict = await client.generate(**kwargs)
+
+            # Gateway returns simplified format: {content, tool_calls, finish_reason, usage}
+            content = response_dict.get("content")
+            finish_reason = response_dict.get("finish_reason")
+
+            # Extract tool calls
+            tool_calls = None
+            if response_dict.get("tool_calls"):
+                tool_calls = [
+                    ToolCall(
+                        id=tc.get("id"),
+                        type=tc.get("type", "function"),
+                        function=FunctionCall(
+                            name=tc.get("function", {}).get("name"),
+                            arguments=tc.get("function", {}).get("arguments"),
+                        ),
+                    )
+                    for tc in response_dict["tool_calls"]
+                ]
+
+            usage_data = response_dict.get("usage", {})
+
+            return GenerateResponse(
+                content=content,
+                tool_calls=tool_calls,
+                finish_reason=finish_reason,
+                usage={
+                    "prompt_tokens": usage_data.get("prompt_tokens"),
+                    "completion_tokens": usage_data.get("completion_tokens"),
+                    "total_tokens": usage_data.get("total_tokens"),
+                }
+                if usage_data
+                else None,
+                metadata=response_dict.get("metadata"),
+            )
+
+    async def generate_stream(
+        self, request: GenerateRequest
+    ) -> AsyncIterator[StreamChunk]:
+        """Generate a streaming response using RagopsAPIGatewayClient."""
+        await self.validate_request(request)
+
+        kwargs = self._prepare_generate_kwargs(request)
+
+        async with RagopsAPIGatewayClient(
+            base_url=self.base_url,
+            api_token=self.api_token,
+        ) as client:
+            # Iterate over the stream from client
+            async for chunk_dict in client.generate_stream(**kwargs):
+                content = chunk_dict.get("content")
+                finish_reason = chunk_dict.get("finish_reason")
+
+                tool_calls = None
+                if chunk_dict.get("tool_calls"):
+                    tool_calls = [
+                        ToolCall(
+                            id=tc.get("id", ""),
+                            type=tc.get("type", "function"),
+                            function=FunctionCall(
+                                name=tc.get("function", {}).get("name", ""),
+                                arguments=tc.get("function", {}).get("arguments", ""),
+                            ),
+                        )
+                        for tc in chunk_dict["tool_calls"]
+                    ]
+
+                yield StreamChunk(
+                    content=content,
+                    tool_calls=tool_calls,
+                    finish_reason=finish_reason,
+                    metadata=chunk_dict.get("metadata", {}),
+                )
+
+    async def embed(self, request: EmbeddingRequest) -> EmbeddingResponse:
+        """Generate embeddings using RagopsAPIGatewayClient."""
+
+        kwargs: dict[str, Any] = {
+            "provider": self.provider,
+            "input": request.input,
+            "model_name": self._model_name,
+            "project_id": self.project_id,
+        }
+
+        if request.dimensions:
+            kwargs["dimensions"] = request.dimensions
+        async with RagopsAPIGatewayClient(
+            base_url=self.base_url,
+            api_token=self.api_token,
+        ) as client:
+            response_dict = await client.embeddings(**kwargs)
+
+            return EmbeddingResponse(**response_dict)
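
Note: a usage sketch of the new proxy model, assuming Message and GenerateRequest accept the fields the adapter reads above (role/content, messages plus sampling knobs); the URL and token are placeholders. Each call opens and closes its own gateway client, so no connection state is shared between requests:

import asyncio

from donkit.llm import DonkitModel
from donkit.llm.model_abstract import GenerateRequest, Message


async def main() -> None:
    model = DonkitModel(
        base_url="http://localhost:9017",  # gateway address (placeholder)
        api_token="YOUR_TOKEN",  # placeholder credential
        provider="openai",
        model_name="gpt-4o",
    )
    request = GenerateRequest(
        messages=[Message(role="user", content="One-line summary of RAG?")],
        temperature=0.2,
    )
    # Blocking call: returns a single GenerateResponse.
    response = await model.generate(request)
    print(response.content)
    # Streaming call: yields StreamChunk objects with incremental content.
    async for chunk in model.generate_stream(request):
        if chunk.content:
            print(chunk.content, end="")


asyncio.run(main())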
donkit/llm/factory.py CHANGED
@@ -1,14 +1,16 @@
 from typing import Literal
 
-from .claude_model import ClaudeModel, ClaudeVertexModel
+from .claude_model import ClaudeModel
+from .claude_model import ClaudeVertexModel
+from .donkit_model import DonkitModel
+from .gemini_model import GeminiModel
 from .model_abstract import LLMModelAbstract
-from .openai_model import (
-    AzureOpenAIEmbeddingModel,
-    AzureOpenAIModel,
-    OpenAIEmbeddingModel,
-    OpenAIModel,
-)
-from .vertex_model import VertexAIModel, VertexEmbeddingModel
+from .openai_model import AzureOpenAIEmbeddingModel
+from .openai_model import AzureOpenAIModel
+from .openai_model import OpenAIEmbeddingModel
+from .openai_model import OpenAIModel
+from .vertex_model import VertexAIModel
+from .vertex_model import VertexEmbeddingModel
 
 
 class ModelFactory:
@@ -46,7 +48,7 @@ class ModelFactory:
 
     @staticmethod
     def create_embedding_model(
-        provider: Literal["openai", "azure_openai", "vertex"],
+        provider: Literal["openai", "azure_openai", "vertex", "custom", "default"],
         model_name: str | None = None,
        api_key: str | None = None,
         **kwargs,
@@ -92,6 +94,35 @@ class ModelFactory:
             base_url=base_url,
         )
 
+    @staticmethod
+    def create_gemini_model(
+        model_name: str,
+        api_key: str | None = None,
+        project_id: str | None = None,
+        location: str = "us-central1",
+        use_vertex: bool = False,
+    ) -> GeminiModel:
+        """
+        Create a Gemini model instance.
+
+        Args:
+            model_name: Model identifier (e.g., "gemini-2.0-flash-exp")
+            api_key: Google AI API key (for AI Studio)
+            project_id: GCP project ID (for Vertex AI)
+            location: GCP location (for Vertex AI)
+            use_vertex: Whether to use Vertex AI instead of AI Studio
+
+        Returns:
+            Configured Gemini model instance
+        """
+        return GeminiModel(
+            model_name=model_name,
+            api_key=api_key,
+            project_id=project_id,
+            location=location,
+            use_vertex=use_vertex,
+        )
+
     @staticmethod
     def create_claude_vertex_model(
         model_name: str,
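
Note: the new factory method covers both authentication paths. A sketch (key, project, and model values are placeholders; the model name is the docstring's example):

from donkit.llm import ModelFactory

# AI Studio path: API key only.
studio_model = ModelFactory.create_gemini_model(
    model_name="gemini-2.0-flash-exp",
    api_key="YOUR_GOOGLE_AI_KEY",  # placeholder
)

# Vertex AI path: GCP project/location instead of an API key.
vertex_model = ModelFactory.create_gemini_model(
    model_name="gemini-2.0-flash-exp",
    project_id="my-gcp-project",  # placeholder
    location="us-central1",
    use_vertex=True,
)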
@@ -118,14 +149,57 @@ class ModelFactory:
             credentials=credentials,
         )
 
+    @staticmethod
+    def create_donkit_model(
+        model_name: str | None,
+        api_key: str,
+        base_url: str = "http://localhost:9017",
+        provider: str = "default",
+    ) -> DonkitModel:
+        """Create a Donkit model that proxies through RagOps API Gateway.
+
+        Args:
+            model_name: Name of the model to use
+            api_key: API key for authentication
+            base_url: Base URL of the RagOps API Gateway
+            provider: Provider to use e.g.:
+                vertex, openai, azure_openai, ollama, default
+        Returns:
+            DonkitModel instance
+        """
+        return DonkitModel(
+            base_url=base_url,
+            api_token=api_key,
+            provider=provider,
+            model_name=model_name,
+        )
+
     @staticmethod
     def create_model(
         provider: Literal[
-            "openai", "azure_openai", "claude", "claude_vertex", "vertex", "ollama"
+            "openai",
+            "azure_openai",
+            "claude",
+            "claude_vertex",
+            "vertex",
+            "ollama",
+            "donkit",
         ],
-        model_name: str,
+        model_name: str | None,
         credentials: dict,
     ) -> LLMModelAbstract:
+        if model_name is None:
+            default_models = {
+                "openai": "gpt-5-mini",
+                "azure_openai": "gpt-4.1-mini",
+                "claude": "claude-4-5-sonnet",
+                "claude_vertex": "claude-4-5-sonnet",
+                "gemini": "gemini-2.5-flash",
+                "vertex": "gemini-2.5-flash",
+                "ollama": "mistral",
+                "donkit": None,
+            }
+            model_name = default_models.get(provider, "default")
         if provider == "openai":
             return ModelFactory.create_openai_model(
                 model_name=model_name,
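
Note: create_model now accepts model_name=None and substitutes a per-provider default from the table above. (The table also carries a "gemini" entry even though "gemini" is absent from the Literal annotation; the elif branch below handles it regardless.) A sketch of the fallback; the "api_key" credential key is assumed from the neighboring branches, since the openai branch body is not shown in this hunk:

from donkit.llm import ModelFactory

# model_name=None resolves to "gpt-5-mini" for "openai" per the table above.
model = ModelFactory.create_model(
    provider="openai",
    model_name=None,
    credentials={"api_key": "YOUR_OPENAI_KEY"},  # key name assumed, see note
)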
@@ -136,11 +210,19 @@ class ModelFactory:
         elif provider == "azure_openai":
             return ModelFactory.create_azure_openai_model(
                 model_name=model_name,
-                api_key=credentials["api_key"],
-                azure_endpoint=credentials["azure_endpoint"],
+                api_key=credentials.get("api_key"),
+                azure_endpoint=credentials.get("azure_endpoint"),
                 api_version=credentials.get("api_version", "2024-08-01-preview"),
                 deployment_name=credentials.get("deployment_name"),
             )
+        elif provider == "gemini":
+            return ModelFactory.create_gemini_model(
+                model_name=model_name,
+                api_key=credentials.get("api_key"),
+                project_id=credentials.get("project_id"),
+                location=credentials.get("location", "us-central1"),
+                use_vertex=credentials.get("use_vertex", False),
+            )
         elif provider == "claude":
             return ModelFactory.create_claude_model(
                 model_name=model_name,
@@ -162,10 +244,19 @@ class ModelFactory:
             )
         elif provider == "ollama":
             # Ollama uses OpenAI-compatible API
+            ollama_url = credentials.get("ollama_url")
+            if "/v1" not in ollama_url:
+                ollama_url += "/v1"
             return ModelFactory.create_openai_model(
                 model_name=model_name,
                 api_key=credentials.get("api_key", "ollama"),
-                base_url=credentials.get("base_url"),
+                base_url=ollama_url,
+            )
+        elif provider == "donkit":
+            return ModelFactory.create_donkit_model(
+                model_name=model_name,
+                api_key=credentials["api_key"],
+                base_url=credentials["base_url"],
             )
         else:
             raise ValueError(f"Unknown provider: {provider}")
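
Note: two details in this last hunk are easy to miss. The ollama branch now reads the "ollama_url" credential (not "base_url") and appends "/v1" when missing; because credentials.get("ollama_url") can return None, the "/v1" not in ollama_url check raises a TypeError if the key is absent, so treat it as required. The donkit branch indexes credentials directly, so "api_key" and "base_url" are likewise required. A sketch of both paths with placeholder values:

from donkit.llm import ModelFactory

# Ollama: "ollama_url" is effectively required; "/v1" is appended if absent.
ollama_model = ModelFactory.create_model(
    provider="ollama",
    model_name="mistral",
    credentials={"ollama_url": "http://localhost:11434"},
)

# Donkit proxy: both credential keys are required (indexed, not .get()).
donkit_model = ModelFactory.create_model(
    provider="donkit",
    model_name=None,  # passed through as None for the "donkit" provider
    credentials={"api_key": "YOUR_TOKEN", "base_url": "http://localhost:9017"},
)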