donkit-llm 0.1.0__py3-none-any.whl
This diff shows the content of a publicly available package version released to one of the supported registries. It is provided for informational purposes only and reflects the package as it appears in the public registry.
- donkit/llm/__init__.py +55 -0
- donkit/llm/claude_model.py +484 -0
- donkit/llm/factory.py +164 -0
- donkit/llm/model_abstract.py +256 -0
- donkit/llm/openai_model.py +587 -0
- donkit/llm/vertex_model.py +478 -0
- donkit_llm-0.1.0.dist-info/METADATA +17 -0
- donkit_llm-0.1.0.dist-info/RECORD +9 -0
- donkit_llm-0.1.0.dist-info/WHEEL +4 -0
donkit/llm/openai_model.py

@@ -0,0 +1,587 @@

from typing import AsyncIterator

from openai import AsyncAzureOpenAI, AsyncOpenAI

from .model_abstract import (
    ContentType,
    EmbeddingRequest,
    EmbeddingResponse,
    FunctionCall,
    GenerateRequest,
    GenerateResponse,
    LLMModelAbstract,
    Message,
    ModelCapability,
    StreamChunk,
    Tool,
    ToolCall,
)


class OpenAIModel(LLMModelAbstract):
    """OpenAI model implementation supporting GPT-4, GPT-3.5, etc."""

    def __init__(
        self,
        model_name: str,
        api_key: str,
        base_url: str | None = None,
        organization: str | None = None,
    ):
        """
        Initialize OpenAI model.

        Args:
            model_name: Model identifier (e.g., "gpt-4o", "gpt-4o-mini")
            api_key: OpenAI API key
            base_url: Optional custom base URL
            organization: Optional organization ID
        """
        self._model_name = model_name
        self._init_client(api_key, base_url, organization)
        self._capabilities = self._determine_capabilities()

    def _init_client(
        self,
        api_key: str,
        base_url: str | None = None,
        organization: str | None = None,
    ) -> None:
        """Initialize the OpenAI client. Can be overridden by subclasses."""
        self.client = AsyncOpenAI(
            api_key=api_key,
            base_url=base_url,
            organization=organization,
        )

    def _determine_capabilities(self) -> ModelCapability:
        """Determine capabilities based on model name."""
        caps = (
            ModelCapability.TEXT_GENERATION
            | ModelCapability.STREAMING
            | ModelCapability.STRUCTURED_OUTPUT
            | ModelCapability.TOOL_CALLING
        )

        model_lower = self._model_name.lower()

        # Vision models (GPT-4o, GPT-4 Turbo, GPT-5, etc.)
        if any(
            x in model_lower
            for x in ["gpt-4o", "gpt-4-turbo", "gpt-4-vision", "gpt-5", "o1", "o3"]
        ):
            caps |= ModelCapability.VISION | ModelCapability.MULTIMODAL_INPUT

        # Audio models
        if "audio" in model_lower:
            caps |= ModelCapability.AUDIO_INPUT | ModelCapability.MULTIMODAL_INPUT

        return caps

    @property
    def model_name(self) -> str:
        return self._model_name

    @model_name.setter
    def model_name(self, value: str):
        """
        Set new model name and recalculate capabilities.

        Args:
            value: New model name
        """
        self._model_name = value
        # Recalculate capabilities based on new model name
        self._capabilities = self._determine_capabilities()

    @property
    def capabilities(self) -> ModelCapability:
        return self._capabilities

    def _convert_message(self, msg: Message) -> dict:
        """Convert internal Message to OpenAI format."""
        result = {"role": msg.role}

        # Handle content
        if isinstance(msg.content, str):
            result["content"] = msg.content
        else:
            # Multimodal content
            content_parts = []
            for part in msg.content:
                if part.type == ContentType.TEXT:
                    content_parts.append({"type": "text", "text": part.content})
                elif part.type == ContentType.IMAGE_URL:
                    content_parts.append(
                        {"type": "image_url", "image_url": {"url": part.content}}
                    )
                elif part.type == ContentType.IMAGE_BASE64:
                    content_parts.append(
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{part.mime_type or 'image/jpeg'};base64,{part.content}"
                            },
                        }
                    )
                # Add more content types as needed
            result["content"] = content_parts

        # Handle tool calls
        if msg.tool_calls:
            result["tool_calls"] = [
                {
                    "id": tc.id,
                    "type": tc.type,
                    "function": {
                        "name": tc.function.name,
                        "arguments": tc.function.arguments,
                    },
                }
                for tc in msg.tool_calls
            ]

        # Handle tool responses
        if msg.tool_call_id:
            result["tool_call_id"] = msg.tool_call_id

        if msg.name:
            result["name"] = msg.name

        return result

    def _convert_tools(self, tools: list[Tool]) -> list[dict]:
        """Convert internal Tool definitions to OpenAI format."""
        return [
            {
                "type": tool.type,
                "function": {
                    "name": tool.function.name,
                    "description": tool.function.description,
                    "parameters": tool.function.parameters,
                    **(
                        {"strict": tool.function.strict}
                        if tool.function.strict is not None
                        else {}
                    ),
                },
            }
            for tool in tools
        ]

    async def generate(self, request: GenerateRequest) -> GenerateResponse:
        """Generate a response using OpenAI API."""
        await self.validate_request(request)

        messages = [self._convert_message(msg) for msg in request.messages]

        kwargs = {
            "model": self._model_name,
            "messages": messages,
        }

        # if request.temperature is not None:
        #     kwargs["temperature"] = request.temperature
        if request.max_tokens is not None:
            kwargs["max_completion_tokens"] = (
                request.max_tokens if request.max_tokens <= 16384 else 16384
            )
        if request.top_p is not None:
            kwargs["top_p"] = request.top_p
        if request.stop:
            kwargs["stop"] = request.stop
        if request.tools:
            kwargs["tools"] = self._convert_tools(request.tools)
            # Only add tool_choice if tools are present
            if request.tool_choice:
                # Validate tool_choice - OpenAI only supports 'none', 'auto', 'required', or dict
                if isinstance(request.tool_choice, str):
                    if request.tool_choice in ("none", "auto", "required"):
                        kwargs["tool_choice"] = request.tool_choice
                    else:
                        # Invalid string value - default to 'auto'
                        kwargs["tool_choice"] = "auto"
                elif isinstance(request.tool_choice, dict):
                    kwargs["tool_choice"] = request.tool_choice
        if request.response_format:
            kwargs["response_format"] = request.response_format

        response = await self.client.chat.completions.create(**kwargs)

        choice = response.choices[0]
        message = choice.message

        # Extract content
        content = message.content

        # Extract tool calls
        tool_calls = None
        if message.tool_calls:
            tool_calls = [
                ToolCall(
                    id=tc.id,
                    type=tc.type,
                    function=FunctionCall(
                        name=tc.function.name, arguments=tc.function.arguments
                    ),
                )
                for tc in message.tool_calls
            ]

        return GenerateResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason=choice.finish_reason,
            usage={
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            }
            if response.usage
            else None,
        )

    async def generate_stream(
        self, request: GenerateRequest
    ) -> AsyncIterator[StreamChunk]:
        """Generate a streaming response using OpenAI API."""
        await self.validate_request(request)

        messages = [self._convert_message(msg) for msg in request.messages]

        kwargs = {
            "model": self._model_name,
            "messages": messages,
            "stream": True,
        }

        if request.temperature is not None:
            kwargs["temperature"] = request.temperature
        if request.max_tokens is not None:
            kwargs["max_tokens"] = request.max_tokens
        if request.top_p is not None:
            kwargs["top_p"] = request.top_p
        if request.stop:
            kwargs["stop"] = request.stop
        if request.tools:
            kwargs["tools"] = self._convert_tools(request.tools)
            # Only add tool_choice if tools are present
            if request.tool_choice:
                # Validate tool_choice - OpenAI only supports 'none', 'auto', 'required', or dict
                if isinstance(request.tool_choice, str):
                    if request.tool_choice in ("none", "auto", "required"):
                        kwargs["tool_choice"] = request.tool_choice
                    else:
                        # Invalid string value - default to 'auto'
                        kwargs["tool_choice"] = "auto"
                elif isinstance(request.tool_choice, dict):
                    kwargs["tool_choice"] = request.tool_choice
        if request.response_format:
            kwargs["response_format"] = request.response_format

        stream = await self.client.chat.completions.create(**kwargs)

        async for chunk in stream:
            if not chunk.choices:
                continue

            choice = chunk.choices[0]
            delta = choice.delta

            content = delta.content if delta.content else None
            finish_reason = choice.finish_reason

            # Handle tool calls in streaming
            tool_calls = None
            if delta.tool_calls:
                tool_calls = [
                    ToolCall(
                        id=tc.id or "",
                        type=tc.type or "function",
                        function=FunctionCall(
                            name=tc.function.name or "",
                            arguments=tc.function.arguments or "",
                        ),
                    )
                    for tc in delta.tool_calls
                ]

            yield StreamChunk(
                content=content,
                tool_calls=tool_calls,
                finish_reason=finish_reason,
            )
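For orientation, here is a minimal usage sketch of OpenAIModel. It is not part of the package; the keyword constructors for Message and GenerateRequest below are assumptions inferred from the attributes this module reads (msg.role, msg.content, request.messages, request.max_tokens), since model_abstract.py is not shown in this diff.

# --- Illustrative usage sketch (not part of the package) ---
import asyncio

from donkit.llm.model_abstract import GenerateRequest, Message
from donkit.llm.openai_model import OpenAIModel


async def main() -> None:
    # Hypothetical key and model values; the fields match the signature above.
    model = OpenAIModel(model_name="gpt-4o-mini", api_key="sk-...")

    # Assumed keyword construction; field names come from this module's usage.
    request = GenerateRequest(
        messages=[Message(role="user", content="Say hello in one sentence.")],
        max_tokens=128,
    )

    # One-shot generation: returns a GenerateResponse with content/tool_calls/usage.
    response = await model.generate(request)
    print(response.content)

    # Streaming: yields StreamChunk objects carrying incremental content deltas.
    async for chunk in model.generate_stream(request):
        if chunk.content:
            print(chunk.content, end="", flush=True)


asyncio.run(main())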
class AzureOpenAIModel(OpenAIModel):
    """Azure OpenAI model implementation."""

    def __init__(
        self,
        model_name: str,
        api_key: str,
        azure_endpoint: str,
        deployment_name: str,
        api_version: str = "2024-08-01-preview",
    ):
        """
        Initialize Azure OpenAI model.

        Args:
            model_name: Model identifier for capability detection
            api_key: Azure OpenAI API key
            azure_endpoint: Azure endpoint URL
            deployment_name: Azure deployment name
            api_version: API version
        """
        # Store Azure-specific parameters before calling super().__init__()
        self._api_key = api_key
        self._azure_endpoint = azure_endpoint
        self._api_version = api_version
        self._base_model_name = model_name
        self._deployment_name = deployment_name

        # Call parent constructor (will call our overridden _init_client)
        super().__init__(model_name, api_key)

    def _init_client(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        organization: str | None = None,
    ) -> None:
        """Initialize Azure OpenAI client."""
        self.client = AsyncAzureOpenAI(
            api_key=self._api_key,
            azure_endpoint=self._azure_endpoint,
            api_version=self._api_version,
            azure_deployment=self._deployment_name,
        )

    def _determine_capabilities(self) -> ModelCapability:
        """Determine capabilities based on base model name."""
        caps = (
            ModelCapability.TEXT_GENERATION
            | ModelCapability.STREAMING
            | ModelCapability.TOOL_CALLING
            | ModelCapability.STRUCTURED_OUTPUT
            | ModelCapability.MULTIMODAL_INPUT
        )
        if "vision" in self._base_model_name.lower() or "4o" in self._base_model_name:
            caps |= ModelCapability.MULTIMODAL_INPUT
        return caps

    @property
    def deployment_name(self) -> str:
        """Get current deployment name."""
        return self._deployment_name

    @deployment_name.setter
    def deployment_name(self, value: str):
        """
        Set new deployment name and reinitialize client.

        Args:
            value: New deployment name
        """
        self._deployment_name = value
        # Reinitialize client with new deployment name
        self._init_client(self._api_key)

    async def generate(self, request: GenerateRequest) -> GenerateResponse:
        """Generate a response using Azure OpenAI API with parameter adaptation."""
        # Override to adapt parameters where needed, then call parent
        return await super().generate(request)
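The Azure variant changes only how the client is constructed. A short sketch with placeholder endpoint and deployment values (everything below is hypothetical):

# --- Illustrative usage sketch (not part of the package) ---
from donkit.llm.openai_model import AzureOpenAIModel

model = AzureOpenAIModel(
    model_name="gpt-4o",  # used for capability detection, not request routing
    api_key="<azure-api-key>",
    azure_endpoint="https://<resource>.openai.azure.com",
    deployment_name="<deployment>",
)

# Assigning to deployment_name triggers the setter above, which rebuilds
# the AsyncAzureOpenAI client against the new deployment.
model.deployment_name = "<other-deployment>"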
class OpenAIEmbeddingModel(LLMModelAbstract):
    """OpenAI embedding model implementation."""

    def __init__(
        self,
        model_name: str = "text-embedding-3-small",
        api_key: str | None = None,
        base_url: str | None = None,
    ):
        """
        Initialize OpenAI embedding model.

        Args:
            model_name: Embedding model name
            api_key: OpenAI API key
            base_url: Optional custom base URL
        """
        self._base_url = base_url
        self._api_key = api_key
        self._model_name = model_name
        self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)
        self._capabilities = self._determine_capabilities()

    def _determine_capabilities(self) -> ModelCapability:
        """Determine capabilities based on model name."""
        caps = ModelCapability.EMBEDDINGS
        return caps

    @property
    def capabilities(self) -> ModelCapability:
        """Return the capabilities supported by this model."""
        return self._capabilities

    def supports_capability(self, capability: ModelCapability) -> bool:
        """Check if this model supports a specific capability."""
        return bool(self.capabilities & capability)

    @property
    def model_name(self) -> str:
        return self._model_name

    @model_name.setter
    def model_name(self, value: str):
        """
        Set new model name and recalculate capabilities.

        Args:
            value: New model name
        """
        self._model_name = value
        # Recalculate capabilities based on new model name
        self._capabilities = self._determine_capabilities()

    @property
    def base_url(self):
        return self._base_url

    @base_url.setter
    def base_url(self, value: str):
        self._base_url = value
        self.client = AsyncOpenAI(api_key=self._api_key, base_url=value)

    @property
    def api_key(self):
        return self._api_key

    @api_key.setter
    def api_key(self, value: str):
        self._api_key = value
        self.client = AsyncOpenAI(api_key=value, base_url=self._base_url)

    async def generate(self, request: GenerateRequest) -> GenerateResponse:
        raise NotImplementedError("Embedding models do not support text generation")

    async def generate_stream(
        self, request: GenerateRequest
    ) -> AsyncIterator[StreamChunk]:
        raise NotImplementedError("Embedding models do not support text generation")

    async def embed(self, request: EmbeddingRequest) -> EmbeddingResponse:
        """Generate embeddings using OpenAI API."""
        inputs = [request.input] if isinstance(request.input, str) else request.input

        kwargs = {"model": self._model_name, "input": inputs}

        if request.dimensions:
            kwargs["dimensions"] = request.dimensions

        response = await self.client.embeddings.create(**kwargs)

        embeddings = [item.embedding for item in response.data]

        return EmbeddingResponse(
            embeddings=embeddings,
            usage={
                "prompt_tokens": response.usage.prompt_tokens,
                "total_tokens": response.usage.total_tokens,
            }
            if response.usage
            else None,
        )
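A corresponding sketch for the embedding model; as above, the EmbeddingRequest keyword constructor is an assumption based on the attributes this module reads (request.input, request.dimensions):

# --- Illustrative usage sketch (not part of the package) ---
import asyncio

from donkit.llm.model_abstract import EmbeddingRequest
from donkit.llm.openai_model import OpenAIEmbeddingModel


async def main() -> None:
    # Defaults to text-embedding-3-small; the key is a placeholder.
    embedder = OpenAIEmbeddingModel(api_key="sk-...")
    response = await embedder.embed(
        EmbeddingRequest(input=["first text", "second text"], dimensions=256)
    )
    # One vector per input string, each truncated to the requested dimensions.
    print(len(response.embeddings), len(response.embeddings[0]))


asyncio.run(main())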
class AzureOpenAIEmbeddingModel(LLMModelAbstract):
    """Azure OpenAI embedding model implementation."""

    def __init__(
        self,
        model_name: str,
        api_key: str,
        azure_endpoint: str,
        deployment_name: str,
        api_version: str = "2024-08-01-preview",
    ):
        """
        Initialize Azure OpenAI embedding model.

        Args:
            model_name: Model identifier (e.g., "text-embedding-3-small")
            api_key: Azure OpenAI API key
            azure_endpoint: Azure endpoint URL
            deployment_name: Azure deployment name
            api_version: API version
        """
        self._model_name = model_name
        self._deployment_name = deployment_name
        self._api_key = api_key
        self._azure_endpoint = azure_endpoint
        self._api_version = api_version

        self.client = AsyncAzureOpenAI(
            api_key=api_key,
            azure_endpoint=azure_endpoint,
            api_version=api_version,
            azure_deployment=deployment_name,
        )

    @property
    def model_name(self) -> str:
        return self._model_name

    @model_name.setter
    def model_name(self, value: str):
        """
        Set new model name.

        Args:
            value: New model name
        """
        self._model_name = value
        self._deployment_name = value
        self.client = AsyncAzureOpenAI(
            api_key=self._api_key,
            azure_endpoint=self._azure_endpoint,
            api_version=self._api_version,
            azure_deployment=value,
        )

    @property
    def capabilities(self) -> ModelCapability:
        return ModelCapability.EMBEDDINGS

    async def generate(self, request: GenerateRequest) -> GenerateResponse:
        raise NotImplementedError("Embedding models do not support text generation")

    async def generate_stream(
        self, request: GenerateRequest
    ) -> AsyncIterator[StreamChunk]:
        raise NotImplementedError("Embedding models do not support text generation")

    async def embed(self, request: EmbeddingRequest) -> EmbeddingResponse:
        """Generate embeddings using Azure OpenAI API."""
        inputs = [request.input] if isinstance(request.input, str) else request.input

        kwargs = {"model": self._deployment_name, "input": inputs}

        if request.dimensions:
            kwargs["dimensions"] = request.dimensions

        response = await self.client.embeddings.create(**kwargs)

        embeddings = [item.embedding for item in response.data]

        return EmbeddingResponse(
            embeddings=embeddings,
            usage={
                "prompt_tokens": response.usage.prompt_tokens,
                "total_tokens": response.usage.total_tokens,
            }
            if response.usage
            else None,
        )
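Throughout this module, ModelCapability values are combined with | and tested with &, so the type is presumably an enum.Flag-style bitmask defined in model_abstract.py. A sketch of capability gating, assuming supports_capability is also exposed on the chat models via LLMModelAbstract (only the embedding class defines it explicitly in this file):

# --- Illustrative capability-gating sketch (not part of the package) ---
from donkit.llm.model_abstract import LLMModelAbstract, ModelCapability


def pick_route(model: LLMModelAbstract) -> str:
    # supports_capability reduces to bool(model.capabilities & capability),
    # as shown in OpenAIEmbeddingModel above.
    if model.supports_capability(ModelCapability.EMBEDDINGS):
        return "embeddings"
    if model.supports_capability(ModelCapability.VISION):
        return "vision-chat"
    return "text-chat"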