donkit-llm 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- donkit_llm-0.1.0/PKG-INFO +17 -0
- donkit_llm-0.1.0/pyproject.toml +23 -0
- donkit_llm-0.1.0/src/donkit/llm/__init__.py +55 -0
- donkit_llm-0.1.0/src/donkit/llm/claude_model.py +484 -0
- donkit_llm-0.1.0/src/donkit/llm/factory.py +164 -0
- donkit_llm-0.1.0/src/donkit/llm/model_abstract.py +256 -0
- donkit_llm-0.1.0/src/donkit/llm/openai_model.py +587 -0
- donkit_llm-0.1.0/src/donkit/llm/vertex_model.py +478 -0
donkit_llm-0.1.0/PKG-INFO
@@ -0,0 +1,17 @@
+Metadata-Version: 2.3
+Name: donkit-llm
+Version: 0.1.0
+Summary: Unified LLM model implementations for Donkit (OpenAI, Azure OpenAI, Claude, Vertex AI)
+License: MIT
+Author: Donkit AI
+Author-email: opensource@donkit.ai
+Requires-Python: >=3.12,<4.0
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
+Requires-Dist: anthropic[vertex] (>=0.42.0,<0.43.0)
+Requires-Dist: google-auth (>=2.0.0,<3.0.0)
+Requires-Dist: google-genai (>=1.38.0,<2.0.0)
+Requires-Dist: openai (>=2.1.0,<3.0.0)
+Requires-Dist: pydantic (>=2.8.0,<3.0.0)
donkit_llm-0.1.0/pyproject.toml
@@ -0,0 +1,23 @@
+[tool.poetry]
+name = "donkit-llm"
+version = "0.1.0"
+description = "Unified LLM model implementations for Donkit (OpenAI, Azure OpenAI, Claude, Vertex AI)"
+authors = ["Donkit AI <opensource@donkit.ai>"]
+license = "MIT"
+packages = [{ include = "donkit", from = "src" }]
+
+[tool.poetry.dependencies]
+python = "^3.12"
+pydantic = "^2.8.0"
+openai = "^2.1.0"
+anthropic = {version = "^0.42.0", extras=["vertex"]}
+google-genai = "^1.38.0"
+google-auth = "^2.0.0"
+
+[tool.poetry.group.dev.dependencies]
+ruff = "^0.13.3"
+pytest = "^8.4.2"
+
+[build-system]
+requires = ["poetry-core>=1.9.0"]
+build-backend = "poetry.core.masonry.api"
donkit_llm-0.1.0/src/donkit/llm/__init__.py
@@ -0,0 +1,55 @@
+from .model_abstract import (
+    ContentPart,
+    ContentType,
+    EmbeddingRequest,
+    EmbeddingResponse,
+    FunctionCall,
+    FunctionDefinition,
+    GenerateRequest,
+    GenerateResponse,
+    LLMModelAbstract,
+    Message,
+    ModelCapability,
+    StreamChunk,
+    Tool,
+    ToolCall,
+)
+from .openai_model import (
+    AzureOpenAIEmbeddingModel,
+    AzureOpenAIModel,
+    OpenAIEmbeddingModel,
+    OpenAIModel,
+)
+from .claude_model import ClaudeModel, ClaudeVertexModel
+from .vertex_model import VertexAIModel, VertexEmbeddingModel
+from .factory import ModelFactory
+
+__all__ = [
+    "ModelFactory",
+    # Abstract base
+    "LLMModelAbstract",
+    "ModelCapability",
+    # Request/Response models
+    "Message",
+    "ContentPart",
+    "ContentType",
+    "GenerateRequest",
+    "GenerateResponse",
+    "StreamChunk",
+    "EmbeddingRequest",
+    "EmbeddingResponse",
+    # Tool/Function calling
+    "Tool",
+    "ToolCall",
+    "FunctionCall",
+    "FunctionDefinition",
+    # Implementations
+    "OpenAIModel",
+    "AzureOpenAIModel",
+    "OpenAIEmbeddingModel",
+    "AzureOpenAIEmbeddingModel",
+    "ClaudeModel",
+    "ClaudeVertexModel",
+    "VertexAIModel",
+    "VertexEmbeddingModel",
+]
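Editor's note: the __init__.py above re-exports the package's entire public surface, so consumers import from donkit.llm directly. A minimal sketch of constructing a model and request against that API; the ClaudeModel constructor arguments are taken from claude_model.py below, while the Message/GenerateRequest field names (role, content, messages) are inferred from how that module reads them:

    from donkit.llm import ClaudeModel, GenerateRequest, Message

    model = ClaudeModel(
        model_name="claude-3-5-sonnet-20241022",
        api_key="sk-ant-...",  # placeholder Anthropic API key
    )
    request = GenerateRequest(
        messages=[Message(role="user", content="Hello!")],
    )
    # For actually calling generate()/generate_stream(), see the sketches
    # after those methods below.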
donkit_llm-0.1.0/src/donkit/llm/claude_model.py
@@ -0,0 +1,484 @@
+import json
+from typing import AsyncIterator
+
+from anthropic import AsyncAnthropic, AsyncAnthropicVertex
+
+from .model_abstract import (
+    ContentType,
+    FunctionCall,
+    GenerateRequest,
+    GenerateResponse,
+    LLMModelAbstract,
+    Message,
+    ModelCapability,
+    StreamChunk,
+    Tool,
+    ToolCall,
+)
+
+
+class ClaudeModel(LLMModelAbstract):
+    """Anthropic Claude model implementation."""
+
+    def __init__(
+        self,
+        model_name: str,
+        api_key: str,
+        base_url: str | None = None,
+    ):
+        """
+        Initialize Claude model via Anthropic API.
+
+        Args:
+            model_name: Model identifier (e.g., "claude-3-5-sonnet-20241022", "claude-3-opus-20240229")
+            api_key: Anthropic API key
+            base_url: Optional custom base URL
+        """
+        self._model_name = model_name
+        self.client = AsyncAnthropic(
+            api_key=api_key,
+            base_url=base_url,
+        )
+        self._capabilities = self._determine_capabilities()
+
+    def _determine_capabilities(self) -> ModelCapability:
+        """Determine capabilities based on model name."""
+        caps = (
+            ModelCapability.TEXT_GENERATION
+            | ModelCapability.STREAMING
+            | ModelCapability.TOOL_CALLING
+        )
+
+        # Claude 3+ models support vision
+        if "claude-3" in self._model_name.lower():
+            caps |= ModelCapability.VISION | ModelCapability.MULTIMODAL_INPUT
+
+        # Structured output via tool use (not native JSON mode like OpenAI)
+        caps |= ModelCapability.STRUCTURED_OUTPUT
+
+        return caps
+
+    @property
+    def model_name(self) -> str:
+        return self._model_name
+
+    @property
+    def capabilities(self) -> ModelCapability:
+        return self._capabilities
+
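Editor's note: capabilities are composed with bitwise OR above, which implies ModelCapability is an enum.Flag (an inference from this diff, not confirmed by it, since model_abstract.py is not shown in this section). A sketch of feature-gating on the flags under that assumption:

    from donkit.llm import ClaudeModel, ModelCapability

    model = ClaudeModel(model_name="claude-3-5-sonnet-20241022", api_key="sk-ant-...")
    if ModelCapability.VISION in model.capabilities:
        ...  # safe to attach image content parts to messages
    if model.capabilities & ModelCapability.TOOL_CALLING:
        ...  # safe to pass tools on the request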
+    def _convert_message(self, msg: Message) -> dict:
+        """Convert internal Message to Claude format."""
+        result = {"role": msg.role}
+
+        # Claude uses "user" and "assistant" roles, no "system" in messages array
+        if msg.role == "system":
+            # System messages should be handled separately
+            result["role"] = "user"
+
+        # Handle content
+        if isinstance(msg.content, str):
+            result["content"] = msg.content
+        else:
+            # Multimodal content
+            content_parts = []
+            for part in msg.content:
+                if part.type == ContentType.TEXT:
+                    content_parts.append({"type": "text", "text": part.content})
+                elif part.type == ContentType.IMAGE_URL:
+                    # Pass the URL through as a URL image source
+                    content_parts.append(
+                        {
+                            "type": "image",
+                            "source": {
+                                "type": "url",
+                                "url": part.content,
+                            },
+                        }
+                    )
+                elif part.type == ContentType.IMAGE_BASE64:
+                    content_parts.append(
+                        {
+                            "type": "image",
+                            "source": {
+                                "type": "base64",
+                                "media_type": part.mime_type or "image/jpeg",
+                                "data": part.content,
+                            },
+                        }
+                    )
+            result["content"] = content_parts
+
+        return result
+
+    def _convert_tools(self, tools: list[Tool]) -> list[dict]:
+        """Convert internal Tool definitions to Claude format."""
+        return [
+            {
+                "name": tool.function.name,
+                "description": tool.function.description,
+                "input_schema": tool.function.parameters,
+            }
+            for tool in tools
+        ]
+
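Editor's note: for reference, this is the shape _convert_tools maps from and to. The FunctionDefinition fields follow their use just above (name, description, parameters as a JSON-schema dict); the weather tool itself is a hypothetical example:

    from donkit.llm import FunctionDefinition, Tool

    tool = Tool(
        function=FunctionDefinition(
            name="get_weather",  # hypothetical example tool
            description="Look up current weather for a city.",
            parameters={
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        )
    )
    # _convert_tools([tool]) produces Claude's native tool format:
    # [{"name": "get_weather",
    #   "description": "Look up current weather for a city.",
    #   "input_schema": {"type": "object", ...}}]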
+    async def generate(self, request: GenerateRequest) -> GenerateResponse:
+        """Generate a response using Claude API."""
+        await self.validate_request(request)
+
+        # Extract system message
+        system_message = None
+        messages = []
+        for msg in request.messages:
+            if msg.role == "system":
+                system_message = msg.content if isinstance(msg.content, str) else ""
+            else:
+                messages.append(self._convert_message(msg))
+
+        kwargs = {
+            "model": self._model_name,
+            "messages": messages,
+            "max_tokens": request.max_tokens or 4096,  # Claude requires max_tokens
+        }
+
+        if system_message:
+            kwargs["system"] = system_message
+        if request.temperature is not None:
+            kwargs["temperature"] = request.temperature
+        if request.top_p is not None:
+            kwargs["top_p"] = request.top_p
+        if request.stop:
+            kwargs["stop_sequences"] = request.stop
+        if request.tools:
+            kwargs["tools"] = self._convert_tools(request.tools)
+
+        response = await self.client.messages.create(**kwargs)
+
+        # Extract content
+        content = None
+        text_blocks = [block.text for block in response.content if block.type == "text"]
+        if text_blocks:
+            content = "".join(text_blocks)
+
+        # Extract tool calls
+        tool_calls = None
+        tool_use_blocks = [
+            block for block in response.content if block.type == "tool_use"
+        ]
+        if tool_use_blocks:
+            tool_calls = [
+                ToolCall(
+                    id=block.id,
+                    type="function",
+                    function=FunctionCall(
+                        name=block.name,
+                        arguments=json.dumps(block.input),
+                    ),
+                )
+                for block in tool_use_blocks
+            ]
+
+        return GenerateResponse(
+            content=content,
+            tool_calls=tool_calls,
+            finish_reason=response.stop_reason,
+            usage={
+                "prompt_tokens": response.usage.input_tokens,
+                "completion_tokens": response.usage.output_tokens,
+                "total_tokens": response.usage.input_tokens
+                + response.usage.output_tokens,
+            },
+        )
+
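Editor's note: a minimal end-to-end sketch of generate(). It assumes only the request fields the method reads above (messages, max_tokens) and the response fields it sets (content, usage); the model name and API key are placeholders:

    import asyncio

    from donkit.llm import ClaudeModel, GenerateRequest, Message

    async def main() -> None:
        model = ClaudeModel(model_name="claude-3-5-sonnet-20241022", api_key="sk-ant-...")
        response = await model.generate(
            GenerateRequest(
                messages=[
                    Message(role="system", content="Answer in one sentence."),
                    Message(role="user", content="What is an LLM?"),
                ],
                max_tokens=128,
            )
        )
        print(response.content)
        print(response.usage)  # prompt/completion/total token counts

    asyncio.run(main())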
+    async def generate_stream(
+        self, request: GenerateRequest
+    ) -> AsyncIterator[StreamChunk]:
+        """Generate a streaming response using Claude API."""
+        await self.validate_request(request)
+
+        # Extract system message
+        system_message = None
+        messages = []
+        for msg in request.messages:
+            if msg.role == "system":
+                system_message = msg.content if isinstance(msg.content, str) else ""
+            else:
+                messages.append(self._convert_message(msg))
+
+        kwargs = {
+            "model": self._model_name,
+            "messages": messages,
+            "max_tokens": request.max_tokens or 4096,
+        }
+
+        if system_message:
+            kwargs["system"] = system_message
+        if request.temperature is not None:
+            kwargs["temperature"] = request.temperature
+        if request.top_p is not None:
+            kwargs["top_p"] = request.top_p
+        if request.stop:
+            kwargs["stop_sequences"] = request.stop
+        if request.tools:
+            kwargs["tools"] = self._convert_tools(request.tools)
+
+        async with self.client.messages.stream(**kwargs) as stream:
+            async for event in stream:
+                content = None
+                tool_calls = None
+                finish_reason = None
+
+                if event.type == "content_block_delta":
+                    if hasattr(event.delta, "text"):
+                        content = event.delta.text
+
+                elif event.type == "content_block_stop":
+                    if (
+                        hasattr(event, "content_block")
+                        and event.content_block.type == "tool_use"
+                    ):
+                        tool_calls = [
+                            ToolCall(
+                                id=event.content_block.id,
+                                type="function",
+                                function=FunctionCall(
+                                    name=event.content_block.name,
+                                    arguments=json.dumps(event.content_block.input),
+                                ),
+                            )
+                        ]
+
+                elif event.type == "message_stop":
+                    finish_reason = "stop"
+
+                if content or tool_calls or finish_reason:
+                    yield StreamChunk(
+                        content=content,
+                        tool_calls=tool_calls,
+                        finish_reason=finish_reason,
+                    )
+
+
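Editor's note: a companion streaming sketch. Per the event handling above, text arrives as incremental content on content_block_delta, complete tool calls surface on content_block_stop, and finish_reason comes with message_stop:

    from donkit.llm import ClaudeModel, GenerateRequest, Message

    async def stream_demo(model: ClaudeModel) -> None:
        request = GenerateRequest(
            messages=[Message(role="user", content="Tell me a short story.")],
        )
        async for chunk in model.generate_stream(request):
            if chunk.content:
                print(chunk.content, end="", flush=True)  # incremental text
            if chunk.finish_reason:
                print(f"\n[done: {chunk.finish_reason}]")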
+class ClaudeVertexModel(LLMModelAbstract):
+    """Anthropic Claude model via Vertex AI."""
+
+    def __init__(
+        self,
+        model_name: str,
+        project_id: str,
+        location: str = "us-east5",
+    ):
+        """
+        Initialize Claude model via Vertex AI.
+
+        Args:
+            model_name: Model identifier (e.g., "claude-3-5-sonnet-v2@20241022")
+            project_id: GCP project ID
+            location: GCP location (us-east5 for Claude)
+        """
+        self._model_name = model_name
+        self.client = AsyncAnthropicVertex(
+            project_id=project_id,
+            region=location,
+        )
+        self._capabilities = self._determine_capabilities()
+
+    def _determine_capabilities(self) -> ModelCapability:
+        """Determine capabilities based on model name."""
+        caps = (
+            ModelCapability.TEXT_GENERATION
+            | ModelCapability.STREAMING
+            | ModelCapability.TOOL_CALLING
+            | ModelCapability.STRUCTURED_OUTPUT
+        )
+
+        # Claude 3+ models support vision
+        if "claude-3" in self._model_name.lower():
+            caps |= ModelCapability.VISION | ModelCapability.MULTIMODAL_INPUT
+
+        return caps
+
+    @property
+    def model_name(self) -> str:
+        return self._model_name
+
+    @property
+    def capabilities(self) -> ModelCapability:
+        return self._capabilities
+
+    def _convert_message(self, msg: Message) -> dict:
+        """Convert internal Message to Claude format."""
+        result = {"role": msg.role}
+
+        if msg.role == "system":
+            result["role"] = "user"
+
+        # Handle content
+        if isinstance(msg.content, str):
+            result["content"] = msg.content
+        else:
+            # Multimodal content
+            content_parts = []
+            for part in msg.content:
+                if part.type == ContentType.TEXT:
+                    content_parts.append({"type": "text", "text": part.content})
+                elif part.type == ContentType.IMAGE_BASE64:
+                    content_parts.append(
+                        {
+                            "type": "image",
+                            "source": {
+                                "type": "base64",
+                                "media_type": part.mime_type or "image/jpeg",
+                                "data": part.content,
+                            },
+                        }
+                    )
+            result["content"] = content_parts
+
+        return result
+
+    def _convert_tools(self, tools: list[Tool]) -> list[dict]:
+        """Convert internal Tool definitions to Claude format."""
+        return [
+            {
+                "name": tool.function.name,
+                "description": tool.function.description,
+                "input_schema": tool.function.parameters,
+            }
+            for tool in tools
+        ]
+
+    async def generate(self, request: GenerateRequest) -> GenerateResponse:
+        """Generate a response using Claude via Vertex AI."""
+        await self.validate_request(request)
+
+        # Extract system message
+        system_message = None
+        messages = []
+        for msg in request.messages:
+            if msg.role == "system":
+                system_message = msg.content if isinstance(msg.content, str) else ""
+            else:
+                messages.append(self._convert_message(msg))
+
+        kwargs = {
+            "model": self._model_name,
+            "messages": messages,
+            "max_tokens": request.max_tokens or 4096,
+        }
+
+        if system_message:
+            kwargs["system"] = system_message
+        if request.temperature is not None:
+            kwargs["temperature"] = request.temperature
+        if request.top_p is not None:
+            kwargs["top_p"] = request.top_p
+        if request.stop:
+            kwargs["stop_sequences"] = request.stop
+        if request.tools:
+            kwargs["tools"] = self._convert_tools(request.tools)
+
+        response = await self.client.messages.create(**kwargs)
+
+        # Extract content
+        content = None
+        text_blocks = [block.text for block in response.content if block.type == "text"]
+        if text_blocks:
+            content = "".join(text_blocks)
+
+        # Extract tool calls
+        tool_calls = None
+        tool_use_blocks = [
+            block for block in response.content if block.type == "tool_use"
+        ]
+        if tool_use_blocks:
+            tool_calls = [
+                ToolCall(
+                    id=block.id,
+                    type="function",
+                    function=FunctionCall(
+                        name=block.name,
+                        arguments=json.dumps(block.input),
+                    ),
+                )
+                for block in tool_use_blocks
+            ]
+
+        return GenerateResponse(
+            content=content,
+            tool_calls=tool_calls,
+            finish_reason=response.stop_reason,
+            usage={
+                "prompt_tokens": response.usage.input_tokens,
+                "completion_tokens": response.usage.output_tokens,
+                "total_tokens": response.usage.input_tokens
+                + response.usage.output_tokens,
+            },
+        )
+
+    async def generate_stream(
+        self, request: GenerateRequest
+    ) -> AsyncIterator[StreamChunk]:
+        """Generate a streaming response using Claude via Vertex AI."""
+        await self.validate_request(request)
+
+        # Extract system message
+        system_message = None
+        messages = []
+        for msg in request.messages:
+            if msg.role == "system":
+                system_message = msg.content if isinstance(msg.content, str) else ""
+            else:
+                messages.append(self._convert_message(msg))
+
+        kwargs = {
+            "model": self._model_name,
+            "messages": messages,
+            "max_tokens": request.max_tokens or 4096,
+        }
+
+        if system_message:
+            kwargs["system"] = system_message
+        if request.temperature is not None:
+            kwargs["temperature"] = request.temperature
+        if request.top_p is not None:
+            kwargs["top_p"] = request.top_p
+        if request.stop:
+            kwargs["stop_sequences"] = request.stop
+        if request.tools:
+            kwargs["tools"] = self._convert_tools(request.tools)
+
+        async with self.client.messages.stream(**kwargs) as stream:
+            async for event in stream:
+                content = None
+                tool_calls = None
+                finish_reason = None
+
+                if event.type == "content_block_delta":
+                    if hasattr(event.delta, "text"):
+                        content = event.delta.text
+
+                elif event.type == "content_block_stop":
+                    if (
+                        hasattr(event, "content_block")
+                        and event.content_block.type == "tool_use"
+                    ):
+                        tool_calls = [
+                            ToolCall(
+                                id=event.content_block.id,
+                                type="function",
+                                function=FunctionCall(
+                                    name=event.content_block.name,
+                                    arguments=json.dumps(event.content_block.input),
+                                ),
+                            )
+                        ]
+
+                elif event.type == "message_stop":
+                    finish_reason = "stop"
+
+                if content or tool_calls or finish_reason:
+                    yield StreamChunk(
+                        content=content,
+                        tool_calls=tool_calls,
+                        finish_reason=finish_reason,
+                    )
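Editor's note: ClaudeVertexModel takes no API key; AsyncAnthropicVertex presumably authenticates through Application Default Credentials via the google-auth dependency declared in pyproject.toml. A construction sketch with a hypothetical project ID:

    from donkit.llm import ClaudeVertexModel

    model = ClaudeVertexModel(
        model_name="claude-3-5-sonnet-v2@20241022",
        project_id="my-gcp-project",  # hypothetical GCP project ID
        location="us-east5",          # the default region used above for Claude
    )
    # generate() / generate_stream() then behave like ClaudeModel above, with
    # the one visible difference that _convert_message here forwards only TEXT
    # and IMAGE_BASE64 parts (no URL image sources).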