donkit-llm 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff compares two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the packages exactly as they appear there.
@@ -21,6 +21,8 @@ from .model_abstract import (
 class OpenAIModel(LLMModelAbstract):
     """OpenAI model implementation supporting GPT-4, GPT-3.5, etc."""
 
+    name = "openai"
+
     def __init__(
         self,
         model_name: str,
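
The new class-level name attribute gives each provider class a stable string identifier. The diff does not show how donkit-llm consumes it; the sketch below is a hypothetical registry pattern that such attributes commonly serve, with MODEL_REGISTRY, register, and OpenAIModelStub all invented for illustration.

# Hypothetical sketch only: donkit-llm's actual lookup code is not in this diff.
MODEL_REGISTRY: dict[str, type] = {}

def register(cls: type) -> type:
    """Index a provider class by its class-level name attribute."""
    MODEL_REGISTRY[cls.name] = cls
    return cls

@register
class OpenAIModelStub:  # stand-in for OpenAIModel
    name = "openai"

assert MODEL_REGISTRY["openai"] is OpenAIModelStub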
@@ -61,17 +63,11 @@ class OpenAIModel(LLMModelAbstract):
             | ModelCapability.STREAMING
             | ModelCapability.STRUCTURED_OUTPUT
             | ModelCapability.TOOL_CALLING
+            | ModelCapability.VISION
+            | ModelCapability.MULTIMODAL_INPUT
         )
 
         model_lower = self._model_name.lower()
-
-        # Vision models (GPT-4o, GPT-4 Turbo, GPT-5, etc.)
-        if any(
-            x in model_lower
-            for x in ["gpt-4o", "gpt-4-turbo", "gpt-4-vision", "gpt-5", "o1", "o3"]
-        ):
-            caps |= ModelCapability.VISION | ModelCapability.MULTIMODAL_INPUT
-
         # Audio models
         if "audio" in model_lower:
             caps |= ModelCapability.AUDIO_INPUT | ModelCapability.MULTIMODAL_INPUT
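
With 0.1.2 the VISION and MULTIMODAL_INPUT flags are granted to every OpenAI model unconditionally, which is why the name-based gate (gpt-4o, gpt-4-turbo, gpt-4-vision, gpt-5, o1, o3) could be deleted. Below is a minimal sketch of how such flag capabilities compose, assuming ModelCapability is an enum.Flag; its real definition lives in .model_abstract and is not part of this diff.

# Sketch assuming ModelCapability is an enum.Flag (assumption: not shown in this diff).
from enum import Flag, auto

class ModelCapability(Flag):
    STREAMING = auto()
    STRUCTURED_OUTPUT = auto()
    TOOL_CALLING = auto()
    VISION = auto()
    MULTIMODAL_INPUT = auto()
    AUDIO_INPUT = auto()

# 0.1.2 composes VISION and MULTIMODAL_INPUT in unconditionally:
caps = (
    ModelCapability.STREAMING
    | ModelCapability.STRUCTURED_OUTPUT
    | ModelCapability.TOOL_CALLING
    | ModelCapability.VISION
    | ModelCapability.MULTIMODAL_INPUT
)
# Membership is a bitwise test, so every OpenAIModel now reports vision support.
assert ModelCapability.VISION in caps
assert ModelCapability.AUDIO_INPUT not in caps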
@@ -109,13 +105,13 @@ class OpenAIModel(LLMModelAbstract):
             # Multimodal content
             content_parts = []
             for part in msg.content:
-                if part.type == ContentType.TEXT:
+                if part.content_type == ContentType.TEXT:
                     content_parts.append({"type": "text", "text": part.content})
-                elif part.type == ContentType.IMAGE_URL:
+                elif part.content_type == ContentType.IMAGE_URL:
                     content_parts.append(
                         {"type": "image_url", "image_url": {"url": part.content}}
                     )
-                elif part.type == ContentType.IMAGE_BASE64:
+                elif part.content_type == ContentType.IMAGE_BASE64:
                     content_parts.append(
                         {
                             "type": "image_url",
@@ -180,12 +176,19 @@ class OpenAIModel(LLMModelAbstract):
             "messages": messages,
         }
 
-        # if request.temperature is not None:
-        #     kwargs["temperature"] = request.temperature
+        if request.temperature is not None:
+            kwargs["temperature"] = request.temperature
         if request.max_tokens is not None:
-            kwargs["max_completion_tokens"] = (
-                request.max_tokens if request.max_tokens <= 16384 else 16384
-            )
+            # Use max_completion_tokens for GPT models, max_tokens for others
+            model_lower = self._model_name.lower()
+            if "gpt" in model_lower and "oss" not in model_lower:
+                kwargs["max_completion_tokens"] = (
+                    request.max_tokens if request.max_tokens <= 16384 else 16384
+                )
+            else:
+                kwargs["max_tokens"] = (
+                    request.max_tokens if request.max_tokens <= 16384 else 16384
+                )
         if request.top_p is not None:
            kwargs["top_p"] = request.top_p
         if request.stop:
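
The new parameter handling re-enables the previously commented-out temperature pass-through and reduces the token limit to a small heuristic: clamp to 16384 and choose the keyword argument by model name. A standalone restatement of that heuristic follows (token_kwargs is an illustrative name, not part of the package); the same branch is repeated verbatim in generate_stream further down.

def token_kwargs(model_name: str, max_tokens: int) -> dict[str, int]:
    """Illustrative restatement of the 0.1.2 heuristic: GPT models (excluding
    names containing "oss", e.g. gpt-oss) get max_completion_tokens; everything
    else gets max_tokens. Both are clamped to 16384."""
    clamped = min(max_tokens, 16384)
    model_lower = model_name.lower()
    if "gpt" in model_lower and "oss" not in model_lower:
        return {"max_completion_tokens": clamped}
    return {"max_tokens": clamped}

assert token_kwargs("gpt-4o", 32000) == {"max_completion_tokens": 16384}
assert token_kwargs("gpt-oss-120b", 1024) == {"max_tokens": 1024}
assert token_kwargs("o3-mini", 2048) == {"max_tokens": 2048}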
@@ -206,40 +209,47 @@ class OpenAIModel(LLMModelAbstract):
         if request.response_format:
             kwargs["response_format"] = request.response_format
 
-        response = await self.client.chat.completions.create(**kwargs)
+        try:
+            response = await self.client.chat.completions.create(**kwargs)
 
-        choice = response.choices[0]
-        message = choice.message
+            if not response.choices:
+                return GenerateResponse(content="Error: No response choices returned")
 
-        # Extract content
-        content = message.content
+            choice = response.choices[0]
+            message = choice.message
 
-        # Extract tool calls
-        tool_calls = None
-        if message.tool_calls:
-            tool_calls = [
-                ToolCall(
-                    id=tc.id,
-                    type=tc.type,
-                    function=FunctionCall(
-                        name=tc.function.name, arguments=tc.function.arguments
-                    ),
-                )
-                for tc in message.tool_calls
-            ]
+            # Extract content
+            content = message.content
 
-        return GenerateResponse(
-            content=content,
-            tool_calls=tool_calls,
-            finish_reason=choice.finish_reason,
-            usage={
-                "prompt_tokens": response.usage.prompt_tokens,
-                "completion_tokens": response.usage.completion_tokens,
-                "total_tokens": response.usage.total_tokens,
-            }
-            if response.usage
-            else None,
-        )
+            # Extract tool calls
+            tool_calls = None
+            if message.tool_calls:
+                tool_calls = [
+                    ToolCall(
+                        id=tc.id,
+                        type=tc.type,
+                        function=FunctionCall(
+                            name=tc.function.name, arguments=tc.function.arguments
+                        ),
+                    )
+                    for tc in message.tool_calls
+                ]
+
+            return GenerateResponse(
+                content=content,
+                tool_calls=tool_calls,
+                finish_reason=choice.finish_reason,
+                usage={
+                    "prompt_tokens": response.usage.prompt_tokens,
+                    "completion_tokens": response.usage.completion_tokens,
+                    "total_tokens": response.usage.total_tokens,
+                }
+                if response.usage
+                else None,
+            )
+        except Exception as e:
+            error_msg = str(e)
+            return GenerateResponse(content=f"Error: {error_msg}")
 
     async def generate_stream(
         self, request: GenerateRequest
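
Note the error-handling choice in generate: API failures now come back in-band as a GenerateResponse whose content starts with "Error:", rather than as raised exceptions, so callers must inspect the content string. An illustrative caller-side check (is_error_response is a hypothetical helper, not package code):

def is_error_response(content: str | None) -> bool:
    # Hypothetical helper: 0.1.2 encodes failures in-band as "Error: ..." strings.
    return content is not None and content.startswith("Error:")

assert is_error_response("Error: No response choices returned")
assert not is_error_response("Hello!")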
@@ -258,7 +268,16 @@ class OpenAIModel(LLMModelAbstract):
         if request.temperature is not None:
             kwargs["temperature"] = request.temperature
         if request.max_tokens is not None:
-            kwargs["max_tokens"] = request.max_tokens
+            # Use max_completion_tokens for GPT models, max_tokens for others
+            model_lower = self._model_name.lower()
+            if "gpt" in model_lower and "oss" not in model_lower:
+                kwargs["max_completion_tokens"] = (
+                    request.max_tokens if request.max_tokens <= 16384 else 16384
+                )
+            else:
+                kwargs["max_tokens"] = (
+                    request.max_tokens if request.max_tokens <= 16384 else 16384
+                )
         if request.top_p is not None:
             kwargs["top_p"] = request.top_p
         if request.stop:
@@ -279,38 +298,70 @@ class OpenAIModel(LLMModelAbstract):
         if request.response_format:
             kwargs["response_format"] = request.response_format
 
-        stream = await self.client.chat.completions.create(**kwargs)
-
-        async for chunk in stream:
-            if not chunk.choices:
-                continue
-
-            choice = chunk.choices[0]
-            delta = choice.delta
-
-            content = delta.content if delta.content else None
-            finish_reason = choice.finish_reason
-
-            # Handle tool calls in streaming
-            tool_calls = None
-            if delta.tool_calls:
+        try:
+            stream = await self.client.chat.completions.create(**kwargs)
+
+            # Accumulate tool calls across chunks
+            accumulated_tool_calls: dict[int, dict] = {}
+
+            async for chunk in stream:
+                if not chunk.choices:
+                    continue
+
+                choice = chunk.choices[0]
+                delta = choice.delta
+
+                # Yield text content if present
+                if delta.content:
+                    yield StreamChunk(content=delta.content, tool_calls=None)
+
+                # Accumulate tool calls
+                if delta.tool_calls:
+                    for tc_delta in delta.tool_calls:
+                        idx = tc_delta.index
+                        if idx not in accumulated_tool_calls:
+                            accumulated_tool_calls[idx] = {
+                                "id": tc_delta.id or "",
+                                "type": tc_delta.type or "function",
+                                "function": {"name": "", "arguments": ""},
+                            }
+
+                        if tc_delta.id:
+                            accumulated_tool_calls[idx]["id"] = tc_delta.id
+                        if tc_delta.type:
+                            accumulated_tool_calls[idx]["type"] = tc_delta.type
+                        if tc_delta.function:
+                            if tc_delta.function.name:
+                                accumulated_tool_calls[idx]["function"]["name"] = (
+                                    tc_delta.function.name
+                                )
+                            if tc_delta.function.arguments:
+                                accumulated_tool_calls[idx]["function"][
+                                    "arguments"
+                                ] += tc_delta.function.arguments
+
+                # Yield finish reason if present
+                if choice.finish_reason:
+                    yield StreamChunk(content=None, finish_reason=choice.finish_reason)
+
+            # Yield final response with accumulated tool calls if any
+            if accumulated_tool_calls:
                 tool_calls = [
                     ToolCall(
-                        id=tc.id or "",
-                        type=tc.type or "function",
+                        id=tc_data["id"],
+                        type=tc_data["type"],
                         function=FunctionCall(
-                            name=tc.function.name or "",
-                            arguments=tc.function.arguments or "",
+                            name=tc_data["function"]["name"],
+                            arguments=tc_data["function"]["arguments"],
                         ),
                     )
-                    for tc in delta.tool_calls
+                    for tc_data in accumulated_tool_calls.values()
                 ]
+                yield StreamChunk(content=None, tool_calls=tool_calls)
 
-            yield StreamChunk(
-                content=content,
-                tool_calls=tool_calls,
-                finish_reason=finish_reason,
-            )
+        except Exception as e:
+            error_msg = str(e)
+            yield StreamChunk(content=f"Error: {error_msg}")
 
 
 class AzureOpenAIModel(OpenAIModel):
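
The streaming rewrite yields text deltas as they arrive, accumulates tool-call fragments keyed by index, and emits the assembled calls after the stream ends; the crucial step is concatenating the incrementally delivered function.arguments strings. A self-contained sketch of that accumulation step, using SimpleNamespace stand-ins for the OpenAI SDK's delta objects:

from types import SimpleNamespace

def accumulate(deltas) -> dict[int, dict]:
    """Merge streamed tool-call fragments by index: the id and function name
    are taken from whichever delta carries them, while argument fragments
    are concatenated in arrival order."""
    acc: dict[int, dict] = {}
    for tc in deltas:
        entry = acc.setdefault(tc.index, {
            "id": "", "type": "function",
            "function": {"name": "", "arguments": ""},
        })
        if tc.id:
            entry["id"] = tc.id
        if tc.function and tc.function.name:
            entry["function"]["name"] = tc.function.name
        if tc.function and tc.function.arguments:
            entry["function"]["arguments"] += tc.function.arguments
    return acc

# Two chunks of one call: the JSON arguments arrive split across deltas.
chunks = [
    SimpleNamespace(index=0, id="call_1", type="function",
                    function=SimpleNamespace(name="get_weather", arguments='{"city": ')),
    SimpleNamespace(index=0, id=None, type=None,
                    function=SimpleNamespace(name=None, arguments='"Oslo"}')),
]
assert accumulate(chunks)[0]["function"]["arguments"] == '{"city": "Oslo"}'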
@@ -366,9 +417,8 @@ class AzureOpenAIModel(OpenAIModel):
             | ModelCapability.TOOL_CALLING
             | ModelCapability.STRUCTURED_OUTPUT
             | ModelCapability.MULTIMODAL_INPUT
+            | ModelCapability.VISION
         )
-        if "vision" in self._base_model_name.lower() or "4o" in self._base_model_name:
-            caps |= ModelCapability.MULTIMODAL_INPUT
         return caps
 
     @property
@@ -390,8 +440,27 @@ class AzureOpenAIModel(OpenAIModel):
 
     async def generate(self, request: GenerateRequest) -> GenerateResponse:
         """Generate a response using Azure OpenAI API with parameter adaptation."""
-        # Override to adapt parameters where needed, then call parent
-        return await super().generate(request)
+        # Azure OpenAI uses deployment name instead of model name
+        # Temporarily override model_name with deployment_name
+        original_model = self._model_name
+        self._model_name = self._deployment_name
+        try:
+            return await super().generate(request)
+        finally:
+            self._model_name = original_model
+
+    async def generate_stream(
+        self, request: GenerateRequest
+    ) -> AsyncIterator[StreamChunk]:
+        """Generate a streaming response using Azure OpenAI API."""
+        # Azure OpenAI uses deployment name instead of model name
+        original_model = self._model_name
+        self._model_name = self._deployment_name
+        try:
+            async for chunk in super().generate_stream(request):
+                yield chunk
+        finally:
+            self._model_name = original_model
 
 
 class OpenAIEmbeddingModel(LLMModelAbstract):
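
Both Azure overrides use the same temporary attribute swap: point _model_name at the deployment name, delegate to the parent implementation, and restore the original in finally. The pattern is equivalent to a context manager, sketched below with illustrative names (swapped_attr, Model are not package code); one design caveat is that mutating shared instance state this way is not safe if a single model instance serves concurrent requests.

from contextlib import contextmanager

@contextmanager
def swapped_attr(obj, attr, value):
    """Temporarily set obj.<attr> to value, restoring the original on exit
    even if the body raises (the try/finally in the diff does the same)."""
    original = getattr(obj, attr)
    setattr(obj, attr, value)
    try:
        yield
    finally:
        setattr(obj, attr, original)

class Model:  # illustrative stand-in for AzureOpenAIModel
    def __init__(self):
        self._model_name = "gpt-4o"

m = Model()
with swapped_attr(m, "_model_name", "my-azure-deployment"):
    assert m._model_name == "my-azure-deployment"
assert m._model_name == "gpt-4o"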
@@ -481,19 +550,25 @@ class OpenAIEmbeddingModel(LLMModelAbstract):
         if request.dimensions:
             kwargs["dimensions"] = request.dimensions
 
-        response = await self.client.embeddings.create(**kwargs)
+        try:
+            response = await self.client.embeddings.create(**kwargs)
 
-        embeddings = [item.embedding for item in response.data]
+            embeddings = [item.embedding for item in response.data]
 
-        return EmbeddingResponse(
-            embeddings=embeddings,
-            usage={
-                "prompt_tokens": response.usage.prompt_tokens,
-                "total_tokens": response.usage.total_tokens,
-            }
-            if response.usage
-            else None,
-        )
+            return EmbeddingResponse(
+                embeddings=embeddings,
+                usage={
+                    "prompt_tokens": response.usage.prompt_tokens,
+                    "total_tokens": response.usage.total_tokens,
+                }
+                if response.usage
+                else None,
+                metadata={
+                    "dimensionality": len(embeddings[0]) if len(embeddings) > 0 else 0
+                },
+            )
+        except Exception as e:
+            raise Exception(f"Failed to generate embeddings: {e}")
 
 
 class AzureOpenAIEmbeddingModel(LLMModelAbstract):
@@ -556,12 +631,12 @@ class AzureOpenAIEmbeddingModel(LLMModelAbstract):
         return ModelCapability.EMBEDDINGS
 
     async def generate(self, request: GenerateRequest) -> GenerateResponse:
-        raise NotImplementedError("Embedding models do not support text generation")
+        raise NotImplementedError("Embedding models does not support text generation")
 
     async def generate_stream(
         self, request: GenerateRequest
     ) -> AsyncIterator[StreamChunk]:
-        raise NotImplementedError("Embedding models do not support text generation")
+        raise NotImplementedError("Embedding models does not support text generation")
 
     async def embed(self, request: EmbeddingRequest) -> EmbeddingResponse:
         """Generate embeddings using Azure OpenAI API."""
@@ -572,16 +647,22 @@ class AzureOpenAIEmbeddingModel(LLMModelAbstract):
         if request.dimensions:
             kwargs["dimensions"] = request.dimensions
 
-        response = await self.client.embeddings.create(**kwargs)
+        try:
+            response = await self.client.embeddings.create(**kwargs)
 
-        embeddings = [item.embedding for item in response.data]
+            embeddings = [item.embedding for item in response.data]
 
-        return EmbeddingResponse(
-            embeddings=embeddings,
-            usage={
-                "prompt_tokens": response.usage.prompt_tokens,
-                "total_tokens": response.usage.total_tokens,
-            }
-            if response.usage
-            else None,
-        )
+            return EmbeddingResponse(
+                embeddings=embeddings,
+                usage={
+                    "prompt_tokens": response.usage.prompt_tokens,
+                    "total_tokens": response.usage.total_tokens,
+                }
+                if response.usage
+                else None,
+                metadata={
+                    "dimensions": len(embeddings[0]) if len(embeddings) > 0 else 0
+                },
+            )
+        except Exception as e:
+            raise Exception(f"Failed to generate embeddings: {e}")
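
Both embed implementations now wrap any failure in a bare Exception, and they report the vector size under different metadata keys ("dimensionality" in OpenAIEmbeddingModel, "dimensions" here). A sketch of an alternative wrapper, not the package's code: a dedicated exception type raised with "from e" lets callers catch narrowly while keeping the original error reachable as __cause__.

# Alternative sketch (assumption: EmbeddingError is invented for illustration).
class EmbeddingError(Exception):
    """Raised when an embeddings API call fails."""

def call_that_fails():
    raise TimeoutError("connection timed out")  # stand-in for the API call

try:
    try:
        call_that_fails()
    except Exception as e:
        # "from e" records the original exception as __cause__.
        raise EmbeddingError(f"Failed to generate embeddings: {e}") from e
except EmbeddingError as err:
    assert isinstance(err.__cause__, TimeoutError)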