donkit-llm 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -21,6 +21,8 @@ from .model_abstract import (
 class OpenAIModel(LLMModelAbstract):
     """OpenAI model implementation supporting GPT-4, GPT-3.5, etc."""
 
+    name = "openai"
+
     def __init__(
         self,
         model_name: str,
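
Note: the new class-level `name` attribute gives each implementation a stable provider identifier. A minimal sketch of how such an attribute could back a provider registry (the registry and `resolve` helper below are hypothetical illustrations, not part of donkit-llm):

    # Hypothetical provider registry keyed by the class-level `name` attribute.
    class LLMModelAbstract:
        name: str = "abstract"

    class OpenAIModel(LLMModelAbstract):
        name = "openai"

    REGISTRY = {cls.name: cls for cls in (OpenAIModel,)}

    def resolve(provider: str) -> type:
        # Look up the implementation class by its declared provider name.
        return REGISTRY[provider]

    assert resolve("openai") is OpenAIModel
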
@@ -103,13 +105,13 @@ class OpenAIModel(LLMModelAbstract):
                 # Multimodal content
                 content_parts = []
                 for part in msg.content:
-                    if part.type == ContentType.TEXT:
+                    if part.content_type == ContentType.TEXT:
                         content_parts.append({"type": "text", "text": part.content})
-                    elif part.type == ContentType.IMAGE_URL:
+                    elif part.content_type == ContentType.IMAGE_URL:
                         content_parts.append(
                             {"type": "image_url", "image_url": {"url": part.content}}
                         )
-                    elif part.type == ContentType.IMAGE_BASE64:
+                    elif part.content_type == ContentType.IMAGE_BASE64:
                         content_parts.append(
                             {
                                 "type": "image_url",
@@ -174,12 +176,19 @@ class OpenAIModel(LLMModelAbstract):
             "messages": messages,
         }
 
-        # if request.temperature is not None:
-        #     kwargs["temperature"] = request.temperature
+        if request.temperature is not None:
+            kwargs["temperature"] = request.temperature
         if request.max_tokens is not None:
-            kwargs["max_completion_tokens"] = (
-                request.max_tokens if request.max_tokens <= 16384 else 16384
-            )
+            # Use max_completion_tokens for GPT models, max_tokens for others
+            model_lower = self._model_name.lower()
+            if "gpt" in model_lower and "oss" not in model_lower:
+                kwargs["max_completion_tokens"] = (
+                    request.max_tokens if request.max_tokens <= 16384 else 16384
+                )
+            else:
+                kwargs["max_tokens"] = (
+                    request.max_tokens if request.max_tokens <= 16384 else 16384
+                )
         if request.top_p is not None:
             kwargs["top_p"] = request.top_p
         if request.stop:
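
Note: both `generate` and `generate_stream` now choose between `max_completion_tokens` (GPT-family model names) and `max_tokens` (everything else, e.g. `gpt-oss` deployments) and clamp the value to 16384. The same branch, pulled out as a small helper for illustration (the helper name is ours):

    def token_limit_kwargs(model_name: str, max_tokens: int) -> dict:
        """Pick the parameter name for the token limit and clamp it to 16384."""
        capped = min(max_tokens, 16384)
        model_lower = model_name.lower()
        if "gpt" in model_lower and "oss" not in model_lower:
            # OpenAI GPT chat endpoints expect max_completion_tokens.
            return {"max_completion_tokens": capped}
        # OpenAI-compatible servers (e.g. gpt-oss) still use max_tokens.
        return {"max_tokens": capped}

    assert token_limit_kwargs("gpt-4o", 50_000) == {"max_completion_tokens": 16384}
    assert token_limit_kwargs("gpt-oss-120b", 2_048) == {"max_tokens": 2048}
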
@@ -200,40 +209,47 @@ class OpenAIModel(LLMModelAbstract):
         if request.response_format:
             kwargs["response_format"] = request.response_format
 
-        response = await self.client.chat.completions.create(**kwargs)
+        try:
+            response = await self.client.chat.completions.create(**kwargs)
 
-        choice = response.choices[0]
-        message = choice.message
+            if not response.choices:
+                return GenerateResponse(content="Error: No response choices returned")
 
-        # Extract content
-        content = message.content
+            choice = response.choices[0]
+            message = choice.message
 
-        # Extract tool calls
-        tool_calls = None
-        if message.tool_calls:
-            tool_calls = [
-                ToolCall(
-                    id=tc.id,
-                    type=tc.type,
-                    function=FunctionCall(
-                        name=tc.function.name, arguments=tc.function.arguments
-                    ),
-                )
-                for tc in message.tool_calls
-            ]
+            # Extract content
+            content = message.content
 
-        return GenerateResponse(
-            content=content,
-            tool_calls=tool_calls,
-            finish_reason=choice.finish_reason,
-            usage={
-                "prompt_tokens": response.usage.prompt_tokens,
-                "completion_tokens": response.usage.completion_tokens,
-                "total_tokens": response.usage.total_tokens,
-            }
-            if response.usage
-            else None,
-        )
+            # Extract tool calls
+            tool_calls = None
+            if message.tool_calls:
+                tool_calls = [
+                    ToolCall(
+                        id=tc.id,
+                        type=tc.type,
+                        function=FunctionCall(
+                            name=tc.function.name, arguments=tc.function.arguments
+                        ),
+                    )
+                    for tc in message.tool_calls
+                ]
+
+            return GenerateResponse(
+                content=content,
+                tool_calls=tool_calls,
+                finish_reason=choice.finish_reason,
+                usage={
+                    "prompt_tokens": response.usage.prompt_tokens,
+                    "completion_tokens": response.usage.completion_tokens,
+                    "total_tokens": response.usage.total_tokens,
+                }
+                if response.usage
+                else None,
+            )
+        except Exception as e:
+            error_msg = str(e)
+            return GenerateResponse(content=f"Error: {error_msg}")
 
     async def generate_stream(
         self, request: GenerateRequest
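
Note: `generate` no longer lets API exceptions propagate; a failure (or an empty `choices` list) now comes back as a normal `GenerateResponse` whose `content` starts with `Error:`. Callers therefore need to inspect the content rather than catch exceptions. A caller-side sketch (the `Error:` prefix check is our convention built on this diff, not a documented API):

    async def ask(model, request):
        response = await model.generate(request)
        # Failures are encoded in the content string rather than raised,
        # so check for the "Error:" prefix explicitly.
        if response.content and response.content.startswith("Error:"):
            raise RuntimeError(f"generation failed: {response.content}")
        return response

One trade-off of this design is that a model reply that legitimately begins with "Error:" is indistinguishable from a transport failure.
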
@@ -252,7 +268,16 @@ class OpenAIModel(LLMModelAbstract):
         if request.temperature is not None:
             kwargs["temperature"] = request.temperature
         if request.max_tokens is not None:
-            kwargs["max_tokens"] = request.max_tokens
+            # Use max_completion_tokens for GPT models, max_tokens for others
+            model_lower = self._model_name.lower()
+            if "gpt" in model_lower and "oss" not in model_lower:
+                kwargs["max_completion_tokens"] = (
+                    request.max_tokens if request.max_tokens <= 16384 else 16384
+                )
+            else:
+                kwargs["max_tokens"] = (
+                    request.max_tokens if request.max_tokens <= 16384 else 16384
+                )
         if request.top_p is not None:
             kwargs["top_p"] = request.top_p
         if request.stop:
@@ -273,38 +298,70 @@ class OpenAIModel(LLMModelAbstract):
         if request.response_format:
             kwargs["response_format"] = request.response_format
 
-        stream = await self.client.chat.completions.create(**kwargs)
-
-        async for chunk in stream:
-            if not chunk.choices:
-                continue
-
-            choice = chunk.choices[0]
-            delta = choice.delta
-
-            content = delta.content if delta.content else None
-            finish_reason = choice.finish_reason
-
-            # Handle tool calls in streaming
-            tool_calls = None
-            if delta.tool_calls:
+        try:
+            stream = await self.client.chat.completions.create(**kwargs)
+
+            # Accumulate tool calls across chunks
+            accumulated_tool_calls: dict[int, dict] = {}
+
+            async for chunk in stream:
+                if not chunk.choices:
+                    continue
+
+                choice = chunk.choices[0]
+                delta = choice.delta
+
+                # Yield text content if present
+                if delta.content:
+                    yield StreamChunk(content=delta.content, tool_calls=None)
+
+                # Accumulate tool calls
+                if delta.tool_calls:
+                    for tc_delta in delta.tool_calls:
+                        idx = tc_delta.index
+                        if idx not in accumulated_tool_calls:
+                            accumulated_tool_calls[idx] = {
+                                "id": tc_delta.id or "",
+                                "type": tc_delta.type or "function",
+                                "function": {"name": "", "arguments": ""},
+                            }
+
+                        if tc_delta.id:
+                            accumulated_tool_calls[idx]["id"] = tc_delta.id
+                        if tc_delta.type:
+                            accumulated_tool_calls[idx]["type"] = tc_delta.type
+                        if tc_delta.function:
+                            if tc_delta.function.name:
+                                accumulated_tool_calls[idx]["function"]["name"] = (
+                                    tc_delta.function.name
+                                )
+                            if tc_delta.function.arguments:
+                                accumulated_tool_calls[idx]["function"][
+                                    "arguments"
+                                ] += tc_delta.function.arguments
+
+                # Yield finish reason if present
+                if choice.finish_reason:
+                    yield StreamChunk(content=None, finish_reason=choice.finish_reason)
+
+            # Yield final response with accumulated tool calls if any
+            if accumulated_tool_calls:
                 tool_calls = [
                     ToolCall(
-                        id=tc.id or "",
-                        type=tc.type or "function",
+                        id=tc_data["id"],
+                        type=tc_data["type"],
                         function=FunctionCall(
-                            name=tc.function.name or "",
-                            arguments=tc.function.arguments or "",
+                            name=tc_data["function"]["name"],
+                            arguments=tc_data["function"]["arguments"],
                         ),
                     )
-                    for tc in delta.tool_calls
+                    for tc_data in accumulated_tool_calls.values()
                 ]
+                yield StreamChunk(content=None, tool_calls=tool_calls)
 
-            yield StreamChunk(
-                content=content,
-                tool_calls=tool_calls,
-                finish_reason=finish_reason,
-            )
+        except Exception as e:
+            error_msg = str(e)
+            yield StreamChunk(content=f"Error: {error_msg}")
 
 
 class AzureOpenAIModel(OpenAIModel):
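
Note: the streaming path now buffers tool-call fragments by `index` across chunks and emits a single final `StreamChunk` with the assembled calls, instead of yielding partial calls per delta. A self-contained sketch of just the accumulation step over mocked deltas (the `ToolCallDelta`/`FunctionDelta` stand-ins mirror the OpenAI chunk shape but are ours):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FunctionDelta:
        name: Optional[str] = None
        arguments: Optional[str] = None

    @dataclass
    class ToolCallDelta:
        index: int
        id: Optional[str] = None
        type: Optional[str] = None
        function: Optional[FunctionDelta] = None

    def accumulate(deltas: list) -> dict:
        acc: dict = {}
        for tc in deltas:
            entry = acc.setdefault(
                tc.index,
                {
                    "id": tc.id or "",
                    "type": tc.type or "function",
                    "function": {"name": "", "arguments": ""},
                },
            )
            if tc.id:
                entry["id"] = tc.id
            if tc.function and tc.function.name:
                entry["function"]["name"] = tc.function.name
            if tc.function and tc.function.arguments:
                # Argument JSON arrives in pieces; concatenate in arrival order.
                entry["function"]["arguments"] += tc.function.arguments
        return acc

    deltas = [
        ToolCallDelta(0, id="call_1", function=FunctionDelta(name="get_weather")),
        ToolCallDelta(0, function=FunctionDelta(arguments='{"city": "Be')),
        ToolCallDelta(0, function=FunctionDelta(arguments='rlin"}')),
    ]
    assert accumulate(deltas)[0]["function"]["arguments"] == '{"city": "Berlin"}'
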
@@ -360,9 +417,8 @@ class AzureOpenAIModel(OpenAIModel):
             | ModelCapability.TOOL_CALLING
             | ModelCapability.STRUCTURED_OUTPUT
             | ModelCapability.MULTIMODAL_INPUT
+            | ModelCapability.VISION
         )
-        if "vision" in self._base_model_name.lower() or "4o" in self._base_model_name:
-            caps |= ModelCapability.MULTIMODAL_INPUT
         return caps
 
     @property
@@ -384,8 +440,27 @@ class AzureOpenAIModel(OpenAIModel):
 
     async def generate(self, request: GenerateRequest) -> GenerateResponse:
         """Generate a response using Azure OpenAI API with parameter adaptation."""
-        # Override to adapt parameters where needed, then call parent
-        return await super().generate(request)
+        # Azure OpenAI uses deployment name instead of model name
+        # Temporarily override model_name with deployment_name
+        original_model = self._model_name
+        self._model_name = self._deployment_name
+        try:
+            return await super().generate(request)
+        finally:
+            self._model_name = original_model
+
+    async def generate_stream(
+        self, request: GenerateRequest
+    ) -> AsyncIterator[StreamChunk]:
+        """Generate a streaming response using Azure OpenAI API."""
+        # Azure OpenAI uses deployment name instead of model name
+        original_model = self._model_name
+        self._model_name = self._deployment_name
+        try:
+            async for chunk in super().generate_stream(request):
+                yield chunk
+        finally:
+            self._model_name = original_model
 
 
 class OpenAIEmbeddingModel(LLMModelAbstract):
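
Note: `AzureOpenAIModel` now temporarily swaps `self._model_name` for `self._deployment_name` around each call and restores it in `finally`, so the inherited request builder sends the Azure deployment name. The same save/override/restore idea, isolated as a reusable context manager (our own sketch, not part of the package):

    from contextlib import contextmanager

    @contextmanager
    def swapped_attr(obj, attr: str, value):
        """Temporarily replace obj.<attr>, restoring the original on exit."""
        original = getattr(obj, attr)
        setattr(obj, attr, value)
        try:
            yield
        finally:
            setattr(obj, attr, original)

    class Model:
        def __init__(self):
            self._model_name = "gpt-4o"
            self._deployment_name = "my-azure-deployment"

    m = Model()
    with swapped_attr(m, "_model_name", m._deployment_name):
        assert m._model_name == "my-azure-deployment"
    assert m._model_name == "gpt-4o"

Mutating instance state this way is simple but not safe if one model instance serves concurrent requests; two overlapping calls could observe each other's override.
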
@@ -475,19 +550,25 @@ class OpenAIEmbeddingModel(LLMModelAbstract):
         if request.dimensions:
             kwargs["dimensions"] = request.dimensions
 
-        response = await self.client.embeddings.create(**kwargs)
+        try:
+            response = await self.client.embeddings.create(**kwargs)
 
-        embeddings = [item.embedding for item in response.data]
+            embeddings = [item.embedding for item in response.data]
 
-        return EmbeddingResponse(
-            embeddings=embeddings,
-            usage={
-                "prompt_tokens": response.usage.prompt_tokens,
-                "total_tokens": response.usage.total_tokens,
-            }
-            if response.usage
-            else None,
-        )
+            return EmbeddingResponse(
+                embeddings=embeddings,
+                usage={
+                    "prompt_tokens": response.usage.prompt_tokens,
+                    "total_tokens": response.usage.total_tokens,
+                }
+                if response.usage
+                else None,
+                metadata={
+                    "dimensionality": len(embeddings[0]) if len(embeddings) > 0 else 0
+                },
+            )
+        except Exception as e:
+            raise Exception(f"Failed to generate embeddings: {e}")
 
 
 class AzureOpenAIEmbeddingModel(LLMModelAbstract):
@@ -550,12 +631,12 @@ class AzureOpenAIEmbeddingModel(LLMModelAbstract):
         return ModelCapability.EMBEDDINGS
 
     async def generate(self, request: GenerateRequest) -> GenerateResponse:
-        raise NotImplementedError("Embedding models do not support text generation")
+        raise NotImplementedError("Embedding models does not support text generation")
 
     async def generate_stream(
         self, request: GenerateRequest
     ) -> AsyncIterator[StreamChunk]:
-        raise NotImplementedError("Embedding models do not support text generation")
+        raise NotImplementedError("Embedding models does not support text generation")
 
     async def embed(self, request: EmbeddingRequest) -> EmbeddingResponse:
         """Generate embeddings using Azure OpenAI API."""
@@ -566,16 +647,22 @@ class AzureOpenAIEmbeddingModel(LLMModelAbstract):
         if request.dimensions:
             kwargs["dimensions"] = request.dimensions
 
-        response = await self.client.embeddings.create(**kwargs)
+        try:
+            response = await self.client.embeddings.create(**kwargs)
 
-        embeddings = [item.embedding for item in response.data]
+            embeddings = [item.embedding for item in response.data]
 
-        return EmbeddingResponse(
-            embeddings=embeddings,
-            usage={
-                "prompt_tokens": response.usage.prompt_tokens,
-                "total_tokens": response.usage.total_tokens,
-            }
-            if response.usage
-            else None,
-        )
+            return EmbeddingResponse(
+                embeddings=embeddings,
+                usage={
+                    "prompt_tokens": response.usage.prompt_tokens,
+                    "total_tokens": response.usage.total_tokens,
+                }
+                if response.usage
+                else None,
+                metadata={
+                    "dimensions": len(embeddings[0]) if len(embeddings) > 0 else 0
+                },
+            )
+        except Exception as e:
+            raise Exception(f"Failed to generate embeddings: {e}")
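
Note: both embedding classes now report the vector width in `EmbeddingResponse.metadata`, but under different keys: "dimensionality" in `OpenAIEmbeddingModel` and "dimensions" in `AzureOpenAIEmbeddingModel`. A caller-side sketch that tolerates either key (the helper is ours):

    def embedding_width(metadata) -> int:
        """Read the vector width regardless of which metadata key was used."""
        if not metadata:
            return 0
        return metadata.get("dimensionality") or metadata.get("dimensions") or 0

    assert embedding_width({"dimensionality": 1536}) == 1536
    assert embedding_width({"dimensions": 3072}) == 3072
    assert embedding_width(None) == 0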