donkit-llm 0.1.0__py3-none-any.whl
- donkit/llm/__init__.py +55 -0
- donkit/llm/claude_model.py +484 -0
- donkit/llm/factory.py +164 -0
- donkit/llm/model_abstract.py +256 -0
- donkit/llm/openai_model.py +587 -0
- donkit/llm/vertex_model.py +478 -0
- donkit_llm-0.1.0.dist-info/METADATA +17 -0
- donkit_llm-0.1.0.dist-info/RECORD +9 -0
- donkit_llm-0.1.0.dist-info/WHEEL +4 -0
donkit/llm/__init__.py
ADDED
@@ -0,0 +1,55 @@
from .model_abstract import (
    ContentPart,
    ContentType,
    EmbeddingRequest,
    EmbeddingResponse,
    FunctionCall,
    FunctionDefinition,
    GenerateRequest,
    GenerateResponse,
    LLMModelAbstract,
    Message,
    ModelCapability,
    StreamChunk,
    Tool,
    ToolCall,
)
from .openai_model import (
    AzureOpenAIEmbeddingModel,
    AzureOpenAIModel,
    OpenAIEmbeddingModel,
    OpenAIModel,
)
from .claude_model import ClaudeModel, ClaudeVertexModel
from .vertex_model import VertexAIModel, VertexEmbeddingModel
from .factory import ModelFactory

__all__ = [
    "ModelFactory",
    # Abstract base
    "LLMModelAbstract",
    "ModelCapability",
    # Request/Response models
    "Message",
    "ContentPart",
    "ContentType",
    "GenerateRequest",
    "GenerateResponse",
    "StreamChunk",
    "EmbeddingRequest",
    "EmbeddingResponse",
    # Tool/Function calling
    "Tool",
    "ToolCall",
    "FunctionCall",
    "FunctionDefinition",
    # Implementations
    "OpenAIModel",
    "AzureOpenAIModel",
    "OpenAIEmbeddingModel",
    "AzureOpenAIEmbeddingModel",
    "ClaudeModel",
    "ClaudeVertexModel",
    "VertexAIModel",
    "VertexEmbeddingModel",
]
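
The re-exports above form the package's entire public surface. A minimal usage sketch (hypothetical caller code, not part of the wheel; it assumes Message and GenerateRequest are keyword-constructible with the fields the provider modules read, which the hunks below suggest but do not prove):

import asyncio

from donkit.llm import GenerateRequest, Message, ModelFactory

async def main() -> None:
    # Every create_* helper returns an LLMModelAbstract, so callers
    # stay provider-agnostic.
    model = ModelFactory.create_claude_model(
        model_name="claude-3-5-sonnet-20241022",
        api_key="sk-ant-...",  # placeholder, not a real key
    )
    request = GenerateRequest(
        messages=[Message(role="user", content="Hello")],
        max_tokens=256,
    )
    response = await model.generate(request)
    print(response.content)

asyncio.run(main())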
donkit/llm/claude_model.py
ADDED
@@ -0,0 +1,484 @@
import json
from typing import AsyncIterator

from anthropic import AsyncAnthropic, AsyncAnthropicVertex

from .model_abstract import (
    ContentType,
    FunctionCall,
    GenerateRequest,
    GenerateResponse,
    LLMModelAbstract,
    Message,
    ModelCapability,
    StreamChunk,
    Tool,
    ToolCall,
)


class ClaudeModel(LLMModelAbstract):
    """Anthropic Claude model implementation."""

    def __init__(
        self,
        model_name: str,
        api_key: str,
        base_url: str | None = None,
    ):
        """
        Initialize Claude model via Anthropic API.

        Args:
            model_name: Model identifier (e.g., "claude-3-5-sonnet-20241022", "claude-3-opus-20240229")
            api_key: Anthropic API key
            base_url: Optional custom base URL
        """
        self._model_name = model_name
        self.client = AsyncAnthropic(
            api_key=api_key,
            base_url=base_url,
        )
        self._capabilities = self._determine_capabilities()

    def _determine_capabilities(self) -> ModelCapability:
        """Determine capabilities based on model name."""
        caps = (
            ModelCapability.TEXT_GENERATION
            | ModelCapability.STREAMING
            | ModelCapability.TOOL_CALLING
        )

        # Claude 3+ models support vision
        if "claude-3" in self._model_name.lower():
            caps |= ModelCapability.VISION | ModelCapability.MULTIMODAL_INPUT

        # Structured output via tool use (not native JSON mode like OpenAI)
        caps |= ModelCapability.STRUCTURED_OUTPUT

        return caps

    @property
    def model_name(self) -> str:
        return self._model_name

    @property
    def capabilities(self) -> ModelCapability:
        return self._capabilities

    def _convert_message(self, msg: Message) -> dict:
        """Convert internal Message to Claude format."""
        result = {"role": msg.role}

        # Claude uses "user" and "assistant" roles, no "system" in messages array
        if msg.role == "system":
            # System messages should be handled separately
            result["role"] = "user"

        # Handle content
        if isinstance(msg.content, str):
            result["content"] = msg.content
        else:
            # Multimodal content
            content_parts = []
            for part in msg.content:
                if part.type == ContentType.TEXT:
                    content_parts.append({"type": "text", "text": part.content})
                elif part.type == ContentType.IMAGE_URL:
                    # Forward the URL using Claude's "url" image source type
                    content_parts.append(
                        {
                            "type": "image",
                            "source": {
                                "type": "url",
                                "url": part.content,
                            },
                        }
                    )
                elif part.type == ContentType.IMAGE_BASE64:
                    content_parts.append(
                        {
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": part.mime_type or "image/jpeg",
                                "data": part.content,
                            },
                        }
                    )
            result["content"] = content_parts

        return result

    def _convert_tools(self, tools: list[Tool]) -> list[dict]:
        """Convert internal Tool definitions to Claude format."""
        return [
            {
                "name": tool.function.name,
                "description": tool.function.description,
                "input_schema": tool.function.parameters,
            }
            for tool in tools
        ]

    async def generate(self, request: GenerateRequest) -> GenerateResponse:
        """Generate a response using Claude API."""
        await self.validate_request(request)

        # Extract system message
        system_message = None
        messages = []
        for msg in request.messages:
            if msg.role == "system":
                system_message = msg.content if isinstance(msg.content, str) else ""
            else:
                messages.append(self._convert_message(msg))

        kwargs = {
            "model": self._model_name,
            "messages": messages,
            "max_tokens": request.max_tokens or 4096,  # Claude requires max_tokens
        }

        if system_message:
            kwargs["system"] = system_message
        if request.temperature is not None:
            kwargs["temperature"] = request.temperature
        if request.top_p is not None:
            kwargs["top_p"] = request.top_p
        if request.stop:
            kwargs["stop_sequences"] = request.stop
        if request.tools:
            kwargs["tools"] = self._convert_tools(request.tools)

        response = await self.client.messages.create(**kwargs)

        # Extract content
        content = None
        text_blocks = [block.text for block in response.content if block.type == "text"]
        if text_blocks:
            content = "".join(text_blocks)

        # Extract tool calls
        tool_calls = None
        tool_use_blocks = [
            block for block in response.content if block.type == "tool_use"
        ]
        if tool_use_blocks:
            tool_calls = [
                ToolCall(
                    id=block.id,
                    type="function",
                    function=FunctionCall(
                        name=block.name,
                        arguments=json.dumps(block.input),
                    ),
                )
                for block in tool_use_blocks
            ]

        return GenerateResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason=response.stop_reason,
            usage={
                "prompt_tokens": response.usage.input_tokens,
                "completion_tokens": response.usage.output_tokens,
                "total_tokens": response.usage.input_tokens
                + response.usage.output_tokens,
            },
        )

    async def generate_stream(
        self, request: GenerateRequest
    ) -> AsyncIterator[StreamChunk]:
        """Generate a streaming response using Claude API."""
        await self.validate_request(request)

        # Extract system message
        system_message = None
        messages = []
        for msg in request.messages:
            if msg.role == "system":
                system_message = msg.content if isinstance(msg.content, str) else ""
            else:
                messages.append(self._convert_message(msg))

        kwargs = {
            "model": self._model_name,
            "messages": messages,
            "max_tokens": request.max_tokens or 4096,
        }

        if system_message:
            kwargs["system"] = system_message
        if request.temperature is not None:
            kwargs["temperature"] = request.temperature
        if request.top_p is not None:
            kwargs["top_p"] = request.top_p
        if request.stop:
            kwargs["stop_sequences"] = request.stop
        if request.tools:
            kwargs["tools"] = self._convert_tools(request.tools)

        async with self.client.messages.stream(**kwargs) as stream:
            async for event in stream:
                content = None
                tool_calls = None
                finish_reason = None

                if event.type == "content_block_delta":
                    if hasattr(event.delta, "text"):
                        content = event.delta.text

                elif event.type == "content_block_stop":
                    if (
                        hasattr(event, "content_block")
                        and event.content_block.type == "tool_use"
                    ):
                        tool_calls = [
                            ToolCall(
                                id=event.content_block.id,
                                type="function",
                                function=FunctionCall(
                                    name=event.content_block.name,
                                    arguments=json.dumps(event.content_block.input),
                                ),
                            )
                        ]

                elif event.type == "message_stop":
                    finish_reason = "stop"

                if content or tool_calls or finish_reason:
                    yield StreamChunk(
                        content=content,
                        tool_calls=tool_calls,
                        finish_reason=finish_reason,
                    )


class ClaudeVertexModel(LLMModelAbstract):
    """Anthropic Claude model via Vertex AI."""

    def __init__(
        self,
        model_name: str,
        project_id: str,
        location: str = "us-east5",
    ):
        """
        Initialize Claude model via Vertex AI.

        Args:
            model_name: Model identifier (e.g., "claude-3-5-sonnet-v2@20241022")
            project_id: GCP project ID
            location: GCP location (us-east5 for Claude)
        """
        self._model_name = model_name
        self.client = AsyncAnthropicVertex(
            project_id=project_id,
            region=location,
        )
        self._capabilities = self._determine_capabilities()

    def _determine_capabilities(self) -> ModelCapability:
        """Determine capabilities based on model name."""
        caps = (
            ModelCapability.TEXT_GENERATION
            | ModelCapability.STREAMING
            | ModelCapability.TOOL_CALLING
            | ModelCapability.STRUCTURED_OUTPUT
        )

        # Claude 3+ models support vision
        if "claude-3" in self._model_name.lower():
            caps |= ModelCapability.VISION | ModelCapability.MULTIMODAL_INPUT

        return caps

    @property
    def model_name(self) -> str:
        return self._model_name

    @property
    def capabilities(self) -> ModelCapability:
        return self._capabilities

    def _convert_message(self, msg: Message) -> dict:
        """Convert internal Message to Claude format."""
        result = {"role": msg.role}

        if msg.role == "system":
            result["role"] = "user"

        # Handle content
        if isinstance(msg.content, str):
            result["content"] = msg.content
        else:
            # Multimodal content
            content_parts = []
            for part in msg.content:
                if part.type == ContentType.TEXT:
                    content_parts.append({"type": "text", "text": part.content})
                elif part.type == ContentType.IMAGE_BASE64:
                    content_parts.append(
                        {
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": part.mime_type or "image/jpeg",
                                "data": part.content,
                            },
                        }
                    )
            result["content"] = content_parts

        return result

    def _convert_tools(self, tools: list[Tool]) -> list[dict]:
        """Convert internal Tool definitions to Claude format."""
        return [
            {
                "name": tool.function.name,
                "description": tool.function.description,
                "input_schema": tool.function.parameters,
            }
            for tool in tools
        ]

    async def generate(self, request: GenerateRequest) -> GenerateResponse:
        """Generate a response using Claude via Vertex AI."""
        await self.validate_request(request)

        # Extract system message
        system_message = None
        messages = []
        for msg in request.messages:
            if msg.role == "system":
                system_message = msg.content if isinstance(msg.content, str) else ""
            else:
                messages.append(self._convert_message(msg))

        kwargs = {
            "model": self._model_name,
            "messages": messages,
            "max_tokens": request.max_tokens or 4096,
        }

        if system_message:
            kwargs["system"] = system_message
        if request.temperature is not None:
            kwargs["temperature"] = request.temperature
        if request.top_p is not None:
            kwargs["top_p"] = request.top_p
        if request.stop:
            kwargs["stop_sequences"] = request.stop
        if request.tools:
            kwargs["tools"] = self._convert_tools(request.tools)

        response = await self.client.messages.create(**kwargs)

        # Extract content
        content = None
        text_blocks = [block.text for block in response.content if block.type == "text"]
        if text_blocks:
            content = "".join(text_blocks)

        # Extract tool calls
        tool_calls = None
        tool_use_blocks = [
            block for block in response.content if block.type == "tool_use"
        ]
        if tool_use_blocks:
            tool_calls = [
                ToolCall(
                    id=block.id,
                    type="function",
                    function=FunctionCall(
                        name=block.name,
                        arguments=json.dumps(block.input),
                    ),
                )
                for block in tool_use_blocks
            ]

        return GenerateResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason=response.stop_reason,
            usage={
                "prompt_tokens": response.usage.input_tokens,
                "completion_tokens": response.usage.output_tokens,
                "total_tokens": response.usage.input_tokens
                + response.usage.output_tokens,
            },
        )

    async def generate_stream(
        self, request: GenerateRequest
    ) -> AsyncIterator[StreamChunk]:
        """Generate a streaming response using Claude via Vertex AI."""
        await self.validate_request(request)

        # Extract system message
        system_message = None
        messages = []
        for msg in request.messages:
            if msg.role == "system":
                system_message = msg.content if isinstance(msg.content, str) else ""
            else:
                messages.append(self._convert_message(msg))

        kwargs = {
            "model": self._model_name,
            "messages": messages,
            "max_tokens": request.max_tokens or 4096,
        }

        if system_message:
            kwargs["system"] = system_message
        if request.temperature is not None:
            kwargs["temperature"] = request.temperature
        if request.top_p is not None:
            kwargs["top_p"] = request.top_p
        if request.stop:
            kwargs["stop_sequences"] = request.stop
        if request.tools:
            kwargs["tools"] = self._convert_tools(request.tools)

        async with self.client.messages.stream(**kwargs) as stream:
            async for event in stream:
                content = None
                tool_calls = None
                finish_reason = None

                if event.type == "content_block_delta":
                    if hasattr(event.delta, "text"):
                        content = event.delta.text

                elif event.type == "content_block_stop":
                    if (
                        hasattr(event, "content_block")
                        and event.content_block.type == "tool_use"
                    ):
                        tool_calls = [
                            ToolCall(
                                id=event.content_block.id,
                                type="function",
                                function=FunctionCall(
                                    name=event.content_block.name,
                                    arguments=json.dumps(event.content_block.input),
                                ),
                            )
                        ]

                elif event.type == "message_stop":
                    finish_reason = "stop"

                if content or tool_calls or finish_reason:
                    yield StreamChunk(
                        content=content,
                        tool_calls=tool_calls,
                        finish_reason=finish_reason,
                    )
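
A companion sketch for the streaming path (again hypothetical caller code; the chunk fields match what generate_stream yields above, where each chunk carries at most one of content, tool_calls, or finish_reason):

import asyncio

from donkit.llm import ClaudeModel, GenerateRequest, Message

async def main() -> None:
    model = ClaudeModel(
        model_name="claude-3-5-sonnet-20241022",
        api_key="sk-ant-...",  # placeholder
    )
    request = GenerateRequest(
        messages=[Message(role="user", content="Write a haiku about diffs")],
        max_tokens=128,
    )
    async for chunk in model.generate_stream(request):
        if chunk.content:
            print(chunk.content, end="", flush=True)
        if chunk.finish_reason:
            print(f"\n[{chunk.finish_reason}]")

asyncio.run(main())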
donkit/llm/factory.py
ADDED
@@ -0,0 +1,164 @@
from typing import Literal

from .claude_model import ClaudeModel, ClaudeVertexModel
from .model_abstract import LLMModelAbstract
from .openai_model import (
    AzureOpenAIEmbeddingModel,
    AzureOpenAIModel,
    OpenAIEmbeddingModel,
    OpenAIModel,
)
from .vertex_model import VertexAIModel, VertexEmbeddingModel


class ModelFactory:
    """Factory for creating LLM model instances."""

    @staticmethod
    def create_openai_model(
        model_name: str,
        api_key: str,
        base_url: str | None = None,
        organization: str | None = None,
    ) -> OpenAIModel:
        return OpenAIModel(
            model_name=model_name,
            api_key=api_key,
            base_url=base_url,
            organization=organization,
        )

    @staticmethod
    def create_azure_openai_model(
        model_name: str,
        api_key: str,
        azure_endpoint: str,
        api_version: str = "2024-08-01-preview",
        deployment_name: str | None = None,
    ) -> AzureOpenAIModel:
        return AzureOpenAIModel(
            model_name=deployment_name or model_name,
            api_key=api_key,
            azure_endpoint=azure_endpoint,
            api_version=api_version,
            deployment_name=deployment_name,
        )

    @staticmethod
    def create_embedding_model(
        provider: Literal["openai", "azure_openai", "vertex"],
        model_name: str | None = None,
        api_key: str | None = None,
        **kwargs,
    ) -> LLMModelAbstract:
        if provider == "openai":
            return OpenAIEmbeddingModel(
                model_name=model_name or "text-embedding-3-small",
                api_key=api_key,
                base_url=kwargs.get("base_url"),
            )
        elif provider == "azure_openai":
            return AzureOpenAIEmbeddingModel(
                model_name=model_name or "text-embedding-ada-002",
                api_key=api_key,
                azure_endpoint=kwargs["azure_endpoint"],
                deployment_name=kwargs.get("deployment_name")
                or model_name
                or "text-embedding-ada-002",
                api_version=kwargs.get("api_version", "2024-08-01-preview"),
            )
        elif provider == "vertex":
            return VertexEmbeddingModel(
                project_id=kwargs["project_id"],
                model_name=model_name or "text-multilingual-embedding-002",
                location=kwargs.get("location", "us-central1"),
                credentials=kwargs.get("credentials"),
                output_dimensionality=kwargs.get("output_dimensionality"),
                batch_size=kwargs.get("batch_size", 100),
                task_type=kwargs.get("task_type", "RETRIEVAL_DOCUMENT"),
            )
        else:
            raise ValueError(f"Unknown embedding provider: {provider}")

    @staticmethod
    def create_claude_model(
        model_name: str,
        api_key: str,
        base_url: str | None = None,
    ) -> ClaudeModel:
        return ClaudeModel(
            model_name=model_name,
            api_key=api_key,
            base_url=base_url,
        )

    @staticmethod
    def create_claude_vertex_model(
        model_name: str,
        project_id: str,
        location: str = "us-east5",
    ) -> ClaudeVertexModel:
        return ClaudeVertexModel(
            model_name=model_name,
            project_id=project_id,
            location=location,
        )

    @staticmethod
    def create_vertex_model(
        model_name: str,
        project_id: str,
        location: str = "us-central1",
        credentials: dict | None = None,
    ) -> VertexAIModel:
        return VertexAIModel(
            model_name=model_name,
            project_id=project_id,
            location=location,
            credentials=credentials,
        )

    @staticmethod
    def create_model(
        provider: Literal[
            "openai", "azure_openai", "claude", "claude_vertex", "vertex"
        ],
        model_name: str,
        credentials: dict,
    ) -> LLMModelAbstract:
        if provider == "openai":
            return ModelFactory.create_openai_model(
                model_name=model_name,
                api_key=credentials["api_key"],
                base_url=credentials.get("base_url"),
                organization=credentials.get("organization"),
            )
        elif provider == "azure_openai":
            return ModelFactory.create_azure_openai_model(
                model_name=model_name,
                api_key=credentials["api_key"],
                azure_endpoint=credentials["azure_endpoint"],
                api_version=credentials.get("api_version", "2024-08-01-preview"),
                deployment_name=credentials.get("deployment_name"),
            )
        elif provider == "claude":
            return ModelFactory.create_claude_model(
                model_name=model_name,
                api_key=credentials["api_key"],
                base_url=credentials.get("base_url"),
            )
        elif provider == "claude_vertex":
            return ModelFactory.create_claude_vertex_model(
                model_name=model_name,
                project_id=credentials["project_id"],
                location=credentials.get("location", "us-east5"),
            )
        elif provider == "vertex":
            return ModelFactory.create_vertex_model(
                model_name=model_name,
                project_id=credentials["project_id"],
                location=credentials.get("location", "us-central1"),
                credentials=credentials.get("credentials"),
            )
        else:
            raise ValueError(f"Unknown provider: {provider}")
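
A sketch of the credentials-dict entry point (hypothetical values; the per-provider key requirements are read directly off the credentials[...] lookups in create_model above):

from donkit.llm import ModelFactory

# Required credential keys per provider, per the lookups in create_model:
#   openai        -> api_key (base_url/organization optional)
#   azure_openai  -> api_key, azure_endpoint
#   claude        -> api_key
#   claude_vertex -> project_id
#   vertex        -> project_id
model = ModelFactory.create_model(
    provider="claude_vertex",
    model_name="claude-3-5-sonnet-v2@20241022",
    credentials={"project_id": "my-gcp-project"},  # hypothetical project ID
)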