donkit-llm 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,587 @@
1
+ from typing import AsyncIterator
2
+
3
+ from openai import AsyncAzureOpenAI, AsyncOpenAI
4
+
5
+ from .model_abstract import (
6
+ ContentType,
7
+ EmbeddingRequest,
8
+ EmbeddingResponse,
9
+ FunctionCall,
10
+ GenerateRequest,
11
+ GenerateResponse,
12
+ LLMModelAbstract,
13
+ Message,
14
+ ModelCapability,
15
+ StreamChunk,
16
+ Tool,
17
+ ToolCall,
18
+ )
19
+
20
+
21
class OpenAIModel(LLMModelAbstract):
    """OpenAI model implementation supporting GPT-4, GPT-3.5, etc."""

    def __init__(
        self,
        model_name: str,
        api_key: str,
        base_url: str | None = None,
        organization: str | None = None,
    ):
        """
        Initialize OpenAI model.

        Args:
            model_name: Model identifier (e.g., "gpt-4o", "gpt-4o-mini")
            api_key: OpenAI API key
            base_url: Optional custom base URL
            organization: Optional organization ID
        """
        self._model_name = model_name
        self._init_client(api_key, base_url, organization)
        self._capabilities = self._determine_capabilities()

    def _init_client(
        self,
        api_key: str,
        base_url: str | None = None,
        organization: str | None = None,
    ) -> None:
        """Initialize the OpenAI client. Can be overridden by subclasses."""
        self.client = AsyncOpenAI(
            api_key=api_key,
            base_url=base_url,
            organization=organization,
        )

    def _determine_capabilities(self) -> ModelCapability:
        """Determine capabilities based on model name."""
        caps = (
            ModelCapability.TEXT_GENERATION
            | ModelCapability.STREAMING
            | ModelCapability.STRUCTURED_OUTPUT
            | ModelCapability.TOOL_CALLING
        )

        model_lower = self._model_name.lower()

        # Vision models (GPT-4o, GPT-4 Turbo, GPT-5, o1/o3, etc.)
        if any(
            x in model_lower
            for x in ["gpt-4o", "gpt-4-turbo", "gpt-4-vision", "gpt-5", "o1", "o3"]
        ):
            caps |= ModelCapability.VISION | ModelCapability.MULTIMODAL_INPUT

        # Audio models
        if "audio" in model_lower:
            caps |= ModelCapability.AUDIO_INPUT | ModelCapability.MULTIMODAL_INPUT

        return caps

    @property
    def model_name(self) -> str:
        return self._model_name

    @model_name.setter
    def model_name(self, value: str):
        """
        Set new model name and recalculate capabilities.

        Args:
            value: New model name
        """
        self._model_name = value
        # Recalculate capabilities based on new model name
        self._capabilities = self._determine_capabilities()

    @property
    def capabilities(self) -> ModelCapability:
        return self._capabilities

    def _convert_message(self, msg: Message) -> dict:
        """Convert internal Message to OpenAI chat-completions message format."""
        result: dict = {"role": msg.role}

        # Handle content: plain string passes through, otherwise build
        # the multimodal content-part list OpenAI expects.
        if isinstance(msg.content, str):
            result["content"] = msg.content
        else:
            content_parts = []
            for part in msg.content:
                if part.type == ContentType.TEXT:
                    content_parts.append({"type": "text", "text": part.content})
                elif part.type == ContentType.IMAGE_URL:
                    content_parts.append(
                        {"type": "image_url", "image_url": {"url": part.content}}
                    )
                elif part.type == ContentType.IMAGE_BASE64:
                    # Base64 images are wrapped in a data: URL; default to
                    # image/jpeg when no mime type was supplied.
                    content_parts.append(
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{part.mime_type or 'image/jpeg'};base64,{part.content}"
                            },
                        }
                    )
                # Add more content types as needed
            result["content"] = content_parts

        # Handle assistant tool calls
        if msg.tool_calls:
            result["tool_calls"] = [
                {
                    "id": tc.id,
                    "type": tc.type,
                    "function": {
                        "name": tc.function.name,
                        "arguments": tc.function.arguments,
                    },
                }
                for tc in msg.tool_calls
            ]

        # Handle tool responses (role="tool" messages reference the call id)
        if msg.tool_call_id:
            result["tool_call_id"] = msg.tool_call_id

        if msg.name:
            result["name"] = msg.name

        return result

    def _convert_tools(self, tools: list[Tool]) -> list[dict]:
        """Convert internal Tool definitions to OpenAI format."""
        return [
            {
                "type": tool.type,
                "function": {
                    "name": tool.function.name,
                    "description": tool.function.description,
                    "parameters": tool.function.parameters,
                    # "strict" is only forwarded when explicitly set.
                    **(
                        {"strict": tool.function.strict}
                        if tool.function.strict is not None
                        else {}
                    ),
                },
            }
            for tool in tools
        ]

    @staticmethod
    def _normalize_tool_choice(tool_choice) -> str | dict | None:
        """
        Coerce a tool_choice value into one OpenAI accepts.

        OpenAI only supports 'none', 'auto', 'required', or a dict naming a
        specific function. Invalid strings fall back to 'auto'; any other
        type returns None (caller omits the parameter).
        """
        if isinstance(tool_choice, str):
            if tool_choice in ("none", "auto", "required"):
                return tool_choice
            return "auto"
        if isinstance(tool_choice, dict):
            return tool_choice
        return None

    def _apply_common_options(self, kwargs: dict, request: GenerateRequest) -> None:
        """
        Apply request options shared by generate() and generate_stream().

        Mutates kwargs in place with top_p, stop, tools, tool_choice and
        response_format. Extracted to remove the verbatim duplication the
        two methods previously carried.
        """
        if request.top_p is not None:
            kwargs["top_p"] = request.top_p
        if request.stop:
            kwargs["stop"] = request.stop
        if request.tools:
            kwargs["tools"] = self._convert_tools(request.tools)
            # tool_choice is only valid when tools are present.
            if request.tool_choice:
                choice = self._normalize_tool_choice(request.tool_choice)
                if choice is not None:
                    kwargs["tool_choice"] = choice
        if request.response_format:
            kwargs["response_format"] = request.response_format

    async def generate(self, request: GenerateRequest) -> GenerateResponse:
        """
        Generate a (non-streaming) response using the OpenAI API.

        Raises whatever validate_request() raises for unsupported requests.
        """
        await self.validate_request(request)

        messages = [self._convert_message(msg) for msg in request.messages]

        kwargs: dict = {
            "model": self._model_name,
            "messages": messages,
        }

        # NOTE(review): temperature is deliberately NOT forwarded here, yet
        # generate_stream() still forwards it — confirm the asymmetry is
        # intended (some newer models reject temperature).
        if request.max_tokens is not None:
            # Cap at 16384, the API's maximum completion budget here.
            kwargs["max_completion_tokens"] = min(request.max_tokens, 16384)
        self._apply_common_options(kwargs, request)

        response = await self.client.chat.completions.create(**kwargs)

        choice = response.choices[0]
        message = choice.message

        # Extract content
        content = message.content

        # Extract tool calls, mapping back into internal ToolCall objects.
        tool_calls = None
        if message.tool_calls:
            tool_calls = [
                ToolCall(
                    id=tc.id,
                    type=tc.type,
                    function=FunctionCall(
                        name=tc.function.name, arguments=tc.function.arguments
                    ),
                )
                for tc in message.tool_calls
            ]

        return GenerateResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason=choice.finish_reason,
            usage={
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            }
            if response.usage
            else None,
        )

    async def generate_stream(
        self, request: GenerateRequest
    ) -> AsyncIterator[StreamChunk]:
        """
        Generate a streaming response using the OpenAI API.

        Yields StreamChunk objects as deltas arrive; chunks without choices
        (e.g. usage-only frames) are skipped.
        """
        await self.validate_request(request)

        messages = [self._convert_message(msg) for msg in request.messages]

        kwargs: dict = {
            "model": self._model_name,
            "messages": messages,
            "stream": True,
        }

        # NOTE(review): unlike generate(), this path forwards temperature and
        # uses the legacy "max_tokens" parameter uncapped — confirm intended.
        if request.temperature is not None:
            kwargs["temperature"] = request.temperature
        if request.max_tokens is not None:
            kwargs["max_tokens"] = request.max_tokens
        self._apply_common_options(kwargs, request)

        stream = await self.client.chat.completions.create(**kwargs)

        async for chunk in stream:
            if not chunk.choices:
                continue

            choice = chunk.choices[0]
            delta = choice.delta

            content = delta.content if delta.content else None
            finish_reason = choice.finish_reason

            # Handle tool calls in streaming; fields may arrive piecemeal,
            # so missing values are normalized to empty strings / "function".
            tool_calls = None
            if delta.tool_calls:
                tool_calls = [
                    ToolCall(
                        id=tc.id or "",
                        type=tc.type or "function",
                        function=FunctionCall(
                            name=tc.function.name or "",
                            arguments=tc.function.arguments or "",
                        ),
                    )
                    for tc in delta.tool_calls
                ]

            yield StreamChunk(
                content=content,
                tool_calls=tool_calls,
                finish_reason=finish_reason,
            )
+
315
+
316
class AzureOpenAIModel(OpenAIModel):
    """Azure OpenAI model implementation."""

    def __init__(
        self,
        model_name: str,
        api_key: str,
        azure_endpoint: str,
        deployment_name: str,
        api_version: str = "2024-08-01-preview",
    ):
        """
        Initialize Azure OpenAI model.

        Args:
            model_name: Model identifier for capability detection
            api_key: Azure OpenAI API key
            azure_endpoint: Azure endpoint URL
            deployment_name: Azure deployment name
            api_version: API version
        """
        # Store Azure-specific parameters before calling super().__init__()
        # because the parent constructor invokes our overridden _init_client.
        self._api_key = api_key
        self._azure_endpoint = azure_endpoint
        self._api_version = api_version
        self._base_model_name = model_name
        self._deployment_name = deployment_name

        # Call parent constructor (will call our overridden _init_client)
        super().__init__(model_name, api_key)

    def _init_client(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        organization: str | None = None,
    ) -> None:
        """Initialize Azure OpenAI client (ignores the positional args; uses
        the Azure parameters stored by __init__)."""
        self.client = AsyncAzureOpenAI(
            api_key=self._api_key,
            azure_endpoint=self._azure_endpoint,
            api_version=self._api_version,
            azure_deployment=self._deployment_name,
        )

    def _determine_capabilities(self) -> ModelCapability:
        """Determine capabilities based on base model name."""
        caps = (
            ModelCapability.TEXT_GENERATION
            | ModelCapability.STREAMING
            | ModelCapability.TOOL_CALLING
            | ModelCapability.STRUCTURED_OUTPUT
            | ModelCapability.MULTIMODAL_INPUT
        )
        # BUG FIX: the original branch re-OR'ed MULTIMODAL_INPUT, which is
        # already in the base set, making the vision check a no-op. Vision
        # models now also gain VISION, matching the parent class's pattern.
        if "vision" in self._base_model_name.lower() or "4o" in self._base_model_name:
            caps |= ModelCapability.VISION | ModelCapability.MULTIMODAL_INPUT
        return caps

    @property
    def deployment_name(self) -> str:
        """Get current deployment name."""
        return self._deployment_name

    @deployment_name.setter
    def deployment_name(self, value: str):
        """
        Set new deployment name and reinitialize client.

        Args:
            value: New deployment name
        """
        self._deployment_name = value
        # Reinitialize client with new deployment name
        self._init_client(self._api_key)

    async def generate(self, request: GenerateRequest) -> GenerateResponse:
        """Generate a response using Azure OpenAI API with parameter adaptation."""
        # Override kept as an extension point for Azure-specific parameter
        # adaptation; currently delegates unchanged to the parent.
        return await super().generate(request)
395
+
396
+
397
class OpenAIEmbeddingModel(LLMModelAbstract):
    """OpenAI embedding model implementation."""

    def __init__(
        self,
        model_name: str = "text-embedding-3-small",
        api_key: str | None = None,
        base_url: str | None = None,
    ):
        """
        Initialize OpenAI embedding model.

        Args:
            model_name: Embedding model name
            api_key: OpenAI API key
            base_url: Optional custom base URL
        """
        self._base_url = base_url
        self._api_key = api_key
        self._model_name = model_name
        self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)
        self._capabilities = self._determine_capabilities()

    def _determine_capabilities(self) -> ModelCapability:
        """Determine capabilities based on model name."""
        caps = ModelCapability.EMBEDDINGS
        return caps

    @property
    def capabilities(self) -> ModelCapability:
        """Return the capabilities supported by this model."""
        return self._capabilities

    def supports_capability(self, capability: ModelCapability) -> bool:
        """Check if this model supports a specific capability."""
        return bool(self.capabilities & capability)

    @property
    def model_name(self) -> str:
        return self._model_name

    @model_name.setter
    def model_name(self, value: str):
        """
        Set new model name and recalculate capabilities.

        Args:
            value: New model name
        """
        self._model_name = value
        # BUG FIX: the original comment promised recalculation but never
        # performed it; now actually recalculate (currently a constant
        # EMBEDDINGS set, so this is behavior-preserving).
        self._capabilities = self._determine_capabilities()

    @property
    def base_url(self):
        return self._base_url

    @base_url.setter
    def base_url(self, value: str):
        # Rebuild the client so the new base URL takes effect immediately.
        self._base_url = value
        self.client = AsyncOpenAI(api_key=self._api_key, base_url=value)

    @property
    def api_key(self):
        return self._api_key

    @api_key.setter
    def api_key(self, value: str):
        # Rebuild the client so the new key takes effect immediately.
        self._api_key = value
        self.client = AsyncOpenAI(api_key=value, base_url=self._base_url)

    async def generate(self, request: GenerateRequest) -> GenerateResponse:
        raise NotImplementedError("Embedding models do not support text generation")

    async def generate_stream(
        self, request: GenerateRequest
    ) -> AsyncIterator[StreamChunk]:
        raise NotImplementedError("Embedding models do not support text generation")

    async def embed(self, request: EmbeddingRequest) -> EmbeddingResponse:
        """Generate embeddings using OpenAI API.

        A single string input is wrapped into a one-element list before the
        API call; the optional dimensions parameter is forwarded when set.
        """
        inputs = [request.input] if isinstance(request.input, str) else request.input

        kwargs = {"model": self._model_name, "input": inputs}

        if request.dimensions:
            kwargs["dimensions"] = request.dimensions

        response = await self.client.embeddings.create(**kwargs)

        embeddings = [item.embedding for item in response.data]

        return EmbeddingResponse(
            embeddings=embeddings,
            usage={
                "prompt_tokens": response.usage.prompt_tokens,
                "total_tokens": response.usage.total_tokens,
            }
            if response.usage
            else None,
        )
497
+
498
+
499
class AzureOpenAIEmbeddingModel(LLMModelAbstract):
    """Azure OpenAI embedding model implementation."""

    def __init__(
        self,
        model_name: str,
        api_key: str,
        azure_endpoint: str,
        deployment_name: str,
        api_version: str = "2024-08-01-preview",
    ):
        """
        Initialize Azure OpenAI embedding model.

        Args:
            model_name: Model identifier (e.g., "text-embedding-3-small")
            api_key: Azure OpenAI API key
            azure_endpoint: Azure endpoint URL
            deployment_name: Azure deployment name
            api_version: API version
        """
        self._model_name = model_name
        self._deployment_name = deployment_name
        self._api_key = api_key
        self._azure_endpoint = azure_endpoint
        self._api_version = api_version
        self.client = self._make_client(deployment_name)

    def _make_client(self, deployment: str) -> AsyncAzureOpenAI:
        # Single place that constructs the Azure client from stored settings.
        return AsyncAzureOpenAI(
            api_key=self._api_key,
            azure_endpoint=self._azure_endpoint,
            api_version=self._api_version,
            azure_deployment=deployment,
        )

    @property
    def model_name(self) -> str:
        return self._model_name

    @model_name.setter
    def model_name(self, value: str):
        """
        Set new model name.

        Args:
            value: New model name
        """
        # Azure routes by deployment, so the deployment name tracks the model
        # name and the client is rebuilt against the new deployment.
        self._model_name = value
        self._deployment_name = value
        self.client = self._make_client(value)

    @property
    def capabilities(self) -> ModelCapability:
        return ModelCapability.EMBEDDINGS

    async def generate(self, request: GenerateRequest) -> GenerateResponse:
        raise NotImplementedError("Embedding models do not support text generation")

    async def generate_stream(
        self, request: GenerateRequest
    ) -> AsyncIterator[StreamChunk]:
        raise NotImplementedError("Embedding models do not support text generation")

    async def embed(self, request: EmbeddingRequest) -> EmbeddingResponse:
        """Generate embeddings using Azure OpenAI API."""
        raw = request.input
        texts = [raw] if isinstance(raw, str) else raw

        # Azure addresses the model via its deployment name.
        payload = {"model": self._deployment_name, "input": texts}
        if request.dimensions:
            payload["dimensions"] = request.dimensions

        response = await self.client.embeddings.create(**payload)

        vectors = [item.embedding for item in response.data]

        usage = None
        if response.usage:
            usage = {
                "prompt_tokens": response.usage.prompt_tokens,
                "total_tokens": response.usage.total_tokens,
            }

        return EmbeddingResponse(embeddings=vectors, usage=usage)
+ )