donkit-llm 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
@@ -0,0 +1,406 @@
+ import json
+ from typing import AsyncIterator
+
+ import google.genai as genai
+ from google.genai.types import Content, FunctionDeclaration, Part
+ from google.genai.types import Tool as GeminiTool
+
+ from .model_abstract import (
+     ContentType,
+     EmbeddingRequest,
+     EmbeddingResponse,
+     FunctionCall,
+     GenerateRequest,
+     GenerateResponse,
+     LLMModelAbstract,
+     Message,
+     ModelCapability,
+     StreamChunk,
+     Tool,
+     ToolCall,
+ )
+
+
+ class GeminiModel(LLMModelAbstract):
+     """Google Gemini model implementation."""
+
+     name = "gemini"
+
+     def __init__(
+         self,
+         model_name: str,
+         api_key: str | None = None,
+         project_id: str | None = None,
+         location: str = "us-central1",
+         use_vertex: bool = False,
+     ):
+         """
+         Initialize Gemini model.
+
+         Args:
+             model_name: Model identifier (e.g., "gemini-2.0-flash-exp", "gemini-1.5-pro")
+             api_key: Google AI API key (for AI Studio)
+             project_id: GCP project ID (for Vertex AI)
+             location: GCP location (for Vertex AI)
+             use_vertex: Whether to use Vertex AI instead of AI Studio
+         """
+         self._model_name = model_name
+         self._use_vertex = use_vertex
+
+         if use_vertex:
+             if not project_id:
+                 raise ValueError("project_id required for Vertex AI")
+             self.client = genai.Client(
+                 vertexai=True,
+                 project=project_id,
+                 location=location,
+             )
+         else:
+             self.client = genai.Client(api_key=api_key)
+
+         self._capabilities = self._determine_capabilities()
+
+     def _determine_capabilities(self) -> ModelCapability:
+         """Determine capabilities based on model name."""
+         caps = (
+             ModelCapability.TEXT_GENERATION
+             | ModelCapability.STREAMING
+             | ModelCapability.STRUCTURED_OUTPUT
+             | ModelCapability.TOOL_CALLING
+             | ModelCapability.MULTIMODAL_INPUT
+             | ModelCapability.VISION
+             | ModelCapability.AUDIO_INPUT
+         )
+         return caps
+
+     @property
+     def model_name(self) -> str:
+         return self._model_name
+
+     @property
+     def capabilities(self) -> ModelCapability:
+         return self._capabilities
+
+     def _convert_message(self, msg: Message) -> Content:
+         """Convert internal Message to Gemini Content format."""
+         parts = []
+
+         if isinstance(msg.content, str):
+             parts.append(Part(text=msg.content))
+         else:
+             # Multimodal content
+             for part in msg.content:
+                 if part.content_type == ContentType.TEXT:
+                     parts.append(Part(text=part.content))
+                 elif part.content_type == ContentType.IMAGE_URL:
+                     # Gemini expects inline data for images
+                     parts.append(
+                         Part(
+                             inline_data={
+                                 "mime_type": part.mime_type or "image/jpeg",
+                                 "data": part.content,
+                             }
+                         )
+                     )
+                 elif part.content_type == ContentType.IMAGE_BASE64:
+                     parts.append(
+                         Part(
+                             inline_data={
+                                 "mime_type": part.mime_type or "image/jpeg",
+                                 "data": part.content,
+                             }
+                         )
+                     )
+                 # Add more content types as needed
+
+         # Map roles: assistant -> model
+         role = "model" if msg.role == "assistant" else msg.role
+
+         return Content(role=role, parts=parts)
+
+     def _convert_tools(self, tools: list[Tool]) -> list[GeminiTool]:
+         """Convert internal Tool definitions to Gemini format."""
+         function_declarations = []
+         for tool in tools:
+             func_def = tool.function
+             # Clean schema: remove $ref and $defs
+             parameters = self._clean_json_schema(func_def.parameters)
+
+             function_declarations.append(
+                 FunctionDeclaration(
+                     name=func_def.name,
+                     description=func_def.description,
+                     parameters=parameters,
+                 )
+             )
+
+         return [GeminiTool(function_declarations=function_declarations)]
+
+     def _clean_json_schema(self, schema: dict) -> dict:
+         """
+         Remove $ref and $defs from JSON Schema as Gemini doesn't support them.
+
+         This is a simplified version - for production you'd want to resolve references.
+         """
+         if not isinstance(schema, dict):
+             return schema
+
+         cleaned = {}
+         for key, value in schema.items():
+             if key in ("$ref", "$defs", "definitions"):
+                 continue
+             if isinstance(value, dict):
+                 cleaned[key] = self._clean_json_schema(value)
+             elif isinstance(value, list):
+                 cleaned[key] = [
+                     self._clean_json_schema(item) if isinstance(item, dict) else item
+                     for item in value
+                 ]
+             else:
+                 cleaned[key] = value
+
+         return cleaned
+
+     async def generate(self, request: GenerateRequest) -> GenerateResponse:
+         """Generate a response using Gemini API."""
+         await self.validate_request(request)
+
+         # Separate system message from conversation
+         system_instruction = None
+         messages = []
+         for msg in request.messages:
+             if msg.role == "system":
+                 system_instruction = msg.content if isinstance(msg.content, str) else ""
+             else:
+                 messages.append(self._convert_message(msg))
+
+         config_kwargs = {}
+         if request.temperature is not None:
+             config_kwargs["temperature"] = request.temperature
+         if request.max_tokens is not None:
+             config_kwargs["max_output_tokens"] = request.max_tokens
+         if request.top_p is not None:
+             config_kwargs["top_p"] = request.top_p
+         if request.stop:
+             config_kwargs["stop_sequences"] = request.stop
+         if request.response_format:
+             # Gemini uses response_mime_type and response_schema
+             config_kwargs["response_mime_type"] = "application/json"
+             if "schema" in request.response_format:
+                 config_kwargs["response_schema"] = self._clean_json_schema(
+                     request.response_format["schema"]
+                 )
+
+         generate_kwargs = {
+             "model": self._model_name,
+             "contents": messages,
+         }
+
+         if system_instruction:
+             generate_kwargs["system_instruction"] = system_instruction
+         if config_kwargs:
+             generate_kwargs["config"] = config_kwargs
+         if request.tools:
+             generate_kwargs["tools"] = self._convert_tools(request.tools)
+
+         response = await self.client.aio.models.generate_content(**generate_kwargs)
+
+         # Extract content
+         content = None
+         if response.text:
+             content = response.text
+
+         # Extract tool calls (function calls in Gemini)
+         tool_calls = None
+         if response.candidates and response.candidates[0].content.parts:
+             function_calls = []
+             for part in response.candidates[0].content.parts:
+                 if hasattr(part, "function_call") and part.function_call:
+                     fc = part.function_call
+                     # Convert function call args to JSON string
+                     args_dict = dict(fc.args) if fc.args else {}
+                     function_calls.append(
+                         ToolCall(
+                             id=fc.name,  # Gemini doesn't have separate ID
+                             type="function",
+                             function=FunctionCall(
+                                 name=fc.name,
+                                 arguments=json.dumps(args_dict),
+                             ),
+                         )
+                     )
+             if function_calls:
+                 tool_calls = function_calls
+
+         # Extract finish reason
+         finish_reason = None
+         if response.candidates:
+             finish_reason = str(response.candidates[0].finish_reason)
+
+         # Extract usage
+         usage = None
+         if response.usage_metadata:
+             usage = {
+                 "prompt_tokens": response.usage_metadata.prompt_token_count,
+                 "completion_tokens": response.usage_metadata.candidates_token_count,
+                 "total_tokens": response.usage_metadata.total_token_count,
+             }
+
+         return GenerateResponse(
+             content=content,
+             tool_calls=tool_calls,
+             finish_reason=finish_reason,
+             usage=usage,
+         )
+
+     async def generate_stream(
+         self, request: GenerateRequest
+     ) -> AsyncIterator[StreamChunk]:
+         """Generate a streaming response using Gemini API."""
+         await self.validate_request(request)
+
+         # Separate system message from conversation
+         system_instruction = None
+         messages = []
+         for msg in request.messages:
+             if msg.role == "system":
+                 system_instruction = msg.content if isinstance(msg.content, str) else ""
+             else:
+                 messages.append(self._convert_message(msg))
+
+         config_kwargs = {}
+         if request.temperature is not None:
+             config_kwargs["temperature"] = request.temperature
+         if request.max_tokens is not None:
+             config_kwargs["max_output_tokens"] = request.max_tokens
+         if request.top_p is not None:
+             config_kwargs["top_p"] = request.top_p
+         if request.stop:
+             config_kwargs["stop_sequences"] = request.stop
+         if request.response_format:
+             config_kwargs["response_mime_type"] = "application/json"
+             if "schema" in request.response_format:
+                 config_kwargs["response_schema"] = self._clean_json_schema(
+                     request.response_format["schema"]
+                 )
+
+         generate_kwargs = {
+             "model": self._model_name,
+             "contents": messages,
+         }
+
+         if system_instruction:
+             generate_kwargs["system_instruction"] = system_instruction
+         if config_kwargs:
+             generate_kwargs["config"] = config_kwargs
+         if request.tools:
+             generate_kwargs["tools"] = self._convert_tools(request.tools)
+
+         stream = await self.client.aio.models.generate_content_stream(**generate_kwargs)
+
+         async for chunk in stream:
+             content = None
+             if chunk.text:
+                 content = chunk.text
+
+             # Extract tool calls from chunk
+             tool_calls = None
+             if chunk.candidates and chunk.candidates[0].content.parts:
+                 function_calls = []
+                 for part in chunk.candidates[0].content.parts:
+                     if hasattr(part, "function_call") and part.function_call:
+                         fc = part.function_call
+                         args_dict = dict(fc.args) if fc.args else {}
+                         function_calls.append(
+                             ToolCall(
+                                 id=fc.name,
+                                 type="function",
+                                 function=FunctionCall(
+                                     name=fc.name,
+                                     arguments=json.dumps(args_dict),
+                                 ),
+                             )
+                         )
+                 if function_calls:
+                     tool_calls = function_calls
+
+             finish_reason = None
+             if chunk.candidates:
+                 finish_reason = str(chunk.candidates[0].finish_reason)
+
+             yield StreamChunk(
+                 content=content,
+                 tool_calls=tool_calls,
+                 finish_reason=finish_reason,
+             )
+
+
+ class GeminiEmbeddingModel(LLMModelAbstract):
+     """Google Gemini embedding model implementation."""
+
+     def __init__(
+         self,
+         model_name: str = "text-embedding-004",
+         api_key: str | None = None,
+         project_id: str | None = None,
+         location: str = "us-central1",
+         use_vertex: bool = False,
+     ):
+         """
+         Initialize Gemini embedding model.
+
+         Args:
+             model_name: Embedding model name
+             api_key: Google AI API key (for AI Studio)
+             project_id: GCP project ID (for Vertex AI)
+             location: GCP location (for Vertex AI)
+             use_vertex: Whether to use Vertex AI
+         """
+         self._model_name = model_name
+         self._use_vertex = use_vertex
+
+         if use_vertex:
+             if not project_id:
+                 raise ValueError("project_id required for Vertex AI")
+             self.client = genai.Client(
+                 vertexai=True,
+                 project=project_id,
+                 location=location,
+             )
+         else:
+             self.client = genai.Client(api_key=api_key)
+
+     @property
+     def model_name(self) -> str:
+         return self._model_name
+
+     @property
+     def capabilities(self) -> ModelCapability:
+         return ModelCapability.EMBEDDINGS
+
+     async def generate(self, request: GenerateRequest) -> GenerateResponse:
+         raise NotImplementedError("Embedding models do not support text generation")
+
+     async def generate_stream(
+         self, request: GenerateRequest
+     ) -> AsyncIterator[StreamChunk]:
+         raise NotImplementedError("Embedding models do not support text generation")
+
+     async def embed(self, request: EmbeddingRequest) -> EmbeddingResponse:
+         """Generate embeddings using Gemini API."""
+         inputs = [request.input] if isinstance(request.input, str) else request.input
+
+         # Gemini embedding API expects list of Content objects
+         contents = [Content(parts=[Part(text=text)]) for text in inputs]
+
+         response = await self.client.aio.models.embed_content(
+             model=self._model_name,
+             contents=contents,
+         )
+
+         embeddings = [emb.values for emb in response.embeddings]
+
+         return EmbeddingResponse(
+             embeddings=embeddings,
+             usage=0,  # Gemini doesn't provide token usage for embeddings
+         )
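
For orientation, here is a minimal sketch of driving the new adapter through the shared abstraction. The import paths `donkit_llm.gemini` and `donkit_llm.model_abstract` are assumptions inferred from the package name and the relative import above; the request and response fields are only those visible in this diff, and running it requires a valid Google AI API key.

```python
import asyncio

from donkit_llm.gemini import GeminiModel  # assumed module path
from donkit_llm.model_abstract import GenerateRequest, Message  # assumed module path


async def main() -> None:
    # AI Studio mode; pass project_id and use_vertex=True for Vertex AI instead.
    model = GeminiModel(model_name="gemini-1.5-pro", api_key="YOUR_API_KEY")

    request = GenerateRequest(
        messages=[
            Message(role="system", content="You are a terse assistant."),
            Message(role="user", content="Name one moon of Jupiter."),
        ],
        temperature=0.2,
        max_tokens=64,
    )

    response = await model.generate(request)
    print(response.content, response.finish_reason, response.usage)


asyncio.run(main())
```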
@@ -34,7 +34,7 @@ class ContentType(str, Enum):
  class ContentPart(BaseModel):
      """A single part of message content (text, image, file, etc.)."""

-     type: ContentType
+     content_type: ContentType
      content: str  # Text, URL, or base64 data
      mime_type: str | None = None  # For files/images/audio
      metadata: dict[str, Any] | None = None
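
In practice the rename means multimodal parts are now constructed with `content_type`, which is how the new Gemini adapter reads them. A hedged sketch, assuming the module is importable as `donkit_llm.model_abstract`:

```python
from donkit_llm.model_abstract import ContentPart, ContentType  # assumed module path

# An image part built with the renamed field.
image_part = ContentPart(
    content_type=ContentType.IMAGE_BASE64,
    content="<base64-encoded image bytes>",
    mime_type="image/png",
)
```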
@@ -44,20 +44,14 @@ class Message(BaseModel):
      """A single message in a conversation."""

      role: Literal["system", "user", "assistant", "tool"]
-     content: str | list[ContentPart]  # Simple text or multimodal parts
+     content: str | list[ContentPart] | None = (
+         None  # Simple text, multimodal parts, or None for tool calls
+     )
      name: str | None = None  # For tool/function messages
      tool_call_id: str | None = None  # For tool response messages
      tool_calls: list["ToolCall"] | None = None  # For assistant requesting tools


- class ToolCall(BaseModel):
-     """A tool/function call requested by the model."""
-
-     id: str
-     type: Literal["function"] = "function"
-     function: "FunctionCall"
-
-
  class FunctionCall(BaseModel):
      """Details of a function call."""

@@ -65,11 +59,12 @@ class FunctionCall(BaseModel):
      arguments: str  # JSON string of arguments


- class Tool(BaseModel):
-     """Definition of a tool/function available to the model."""
+ class ToolCall(BaseModel):
+     """A tool/function call requested by the model."""

+     id: str
      type: Literal["function"] = "function"
-     function: "FunctionDefinition"
+     function: FunctionCall


  class FunctionDefinition(BaseModel):
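
Moving `ToolCall` below `FunctionCall` lets it reference `FunctionCall` directly instead of through a forward-declared string, and together with the `Message.content` change above, an assistant turn that only requests a tool can now carry `content=None`. A sketch under the same assumed module path:

```python
from donkit_llm.model_abstract import FunctionCall, Message, ToolCall  # assumed module path

# Assistant turn that requests a tool call and carries no text content.
weather_call = ToolCall(
    id="call_get_weather_1",  # illustrative id
    function=FunctionCall(name="get_weather", arguments='{"city": "Oslo"}'),
)
assistant_turn = Message(role="assistant", content=None, tool_calls=[weather_call])
```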
@@ -81,6 +76,13 @@ class FunctionDefinition(BaseModel):
      strict: bool | None = None  # For strict schema adherence (OpenAI)


+ class Tool(BaseModel):
+     """Definition of a tool/function available to the model."""
+
+     type: Literal["function"] = "function"
+     function: FunctionDefinition
+
+
  class GenerateRequest(BaseModel):
      """Request for text generation."""

@@ -145,17 +147,25 @@ class LLMModelAbstract(ABC):
      Not all models support all capabilities. Use `supports_capability()` to check.
      """

+     name: str = ""
+
      @property
      @abstractmethod
      def model_name(self) -> str:
          """Return the model name/identifier."""
-         pass
+         raise NotImplementedError
+
+     @model_name.setter
+     @abstractmethod
+     def model_name(self, value: str) -> None:
+         """Set the model name/identifier."""
+         raise NotImplementedError

      @property
      @abstractmethod
      def capabilities(self) -> ModelCapability:
          """Return the capabilities supported by this model."""
-         pass
+         raise NotImplementedError

      def supports_capability(self, capability: ModelCapability) -> bool:
          """Check if this model supports a specific capability."""
@@ -176,7 +186,7 @@ class LLMModelAbstract(ABC):
              NotImplementedError: If the model doesn't support text generation.
              ValueError: If request contains unsupported features.
          """
-         pass
+         raise NotImplementedError

      @abstractmethod
      async def generate_stream(
@@ -195,7 +205,7 @@ class LLMModelAbstract(ABC):
              NotImplementedError: If the model doesn't support streaming.
              ValueError: If request contains unsupported features.
          """
-         pass
+         raise NotImplementedError

      async def embed(self, request: EmbeddingRequest) -> EmbeddingResponse:
          """