donkit-llm 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- donkit/llm/__init__.py +5 -0
- donkit/llm/claude_model.py +7 -5
- donkit/llm/donkit_model.py +239 -0
- donkit/llm/factory.py +105 -14
- donkit/llm/gemini_model.py +406 -0
- donkit/llm/model_abstract.py +27 -17
- donkit/llm/ollama_integration.py +442 -0
- donkit/llm/openai_model.py +179 -92
- donkit/llm/vertex_model.py +446 -178
- {donkit_llm-0.1.1.dist-info → donkit_llm-0.1.3.dist-info}/METADATA +3 -2
- donkit_llm-0.1.3.dist-info/RECORD +12 -0
- {donkit_llm-0.1.1.dist-info → donkit_llm-0.1.3.dist-info}/WHEEL +1 -1
- donkit_llm-0.1.1.dist-info/RECORD +0 -9
donkit/llm/vertex_model.py
CHANGED
@@ -1,6 +1,6 @@
 import json
 import base64
-from typing import AsyncIterator
+from typing import Any, AsyncIterator
 
 import google.genai as genai
 from google.genai.types import Blob, Content, FunctionDeclaration, Part
@@ -32,6 +32,8 @@ class VertexAIModel(LLMModelAbstract):
     - Claude models via Vertex AI (claude-3-5-sonnet-v2@20241022, etc.)
     """
 
+    name = "vertex"
+
     def __init__(
         self,
         project_id: str,
@@ -104,9 +106,9 @@ class VertexAIModel(LLMModelAbstract):
         else:
             # Multimodal content
             for part in msg.content:
-                if part.
+                if part.content_type == ContentType.TEXT:
                     parts.append(Part(text=part.content))
-                elif part.
+                elif part.content_type == ContentType.IMAGE_URL:
                     # For URLs, we'd need to fetch and convert to inline data
                     parts.append(
                         Part(
@@ -116,7 +118,7 @@ class VertexAIModel(LLMModelAbstract):
                             )
                         )
                     )
-                elif part.
+                elif part.content_type == ContentType.IMAGE_BASE64:
                     # part.content is base64 string; Vertex needs raw bytes
                     raw = base64.b64decode(part.content, validate=True)
                     parts.append(
@@ -127,7 +129,7 @@ class VertexAIModel(LLMModelAbstract):
                             )
                         )
                     )
-                elif part.
+                elif part.content_type == ContentType.AUDIO_BASE64:
                     raw = base64.b64decode(part.content, validate=True)
                     parts.append(
                         Part(
@@ -137,7 +139,7 @@ class VertexAIModel(LLMModelAbstract):
                             )
                         )
                     )
-                elif part.
+                elif part.content_type == ContentType.FILE_BASE64:
                     raw = base64.b64decode(part.content, validate=True)
                     parts.append(
                         Part(
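
The rewritten branches above dispatch each multimodal part on its content_type and, for base64 payloads, decode the string to raw bytes before building a google.genai Part. The exact Part keyword arguments are truncated in this view of the diff, so the following is only a rough sketch of the pattern it implies, using the google-genai types imported at the top of the file; the file path and mime type are invented for illustration and the donkit part classes are not shown:

import base64

from google.genai.types import Blob, Part

# Hypothetical base64 image payload, as a donkit IMAGE_BASE64 part might carry it.
with open("photo.png", "rb") as fh:
    image_b64 = base64.b64encode(fh.read()).decode("ascii")

# Mirrors the diff: decode the base64 string to raw bytes, then wrap it as inline data.
raw = base64.b64decode(image_b64, validate=True)
image_part = Part(inline_data=Blob(mime_type="image/png", data=raw))
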
@@ -167,130 +169,342 @@ class VertexAIModel(LLMModelAbstract):
 
         return [GeminiTool(function_declarations=function_declarations)]
 
-    def
+    def _parse_response(self, response) -> tuple[str | None, list[ToolCall] | None]:
+        """Parse a genai response (or stream chunk) into plain text and tool calls."""
+        calls: list[ToolCall] = []
+
+        try:
+            cand_list = response.candidates
+        except AttributeError:
+            cand_list = None
+        if not cand_list:
+            return None, None
+
+        cand = cand_list[0]
+
+        try:
+            parts = cand.content.parts or []
+        except AttributeError:
+            parts = []
+
+        # Collect text and tool calls in a single pass
+        collected_text: list[str] = []
+        for p in parts:
+            # Try to get text from this part
+            try:
+                t = p.text
+                if t:
+                    collected_text.append(t)
+            except AttributeError:
+                pass
+
+            # Try to get function_call from this part
+            try:
+                fc = p.function_call
+                if fc:
+                    # Extract function name and arguments
+                    try:
+                        name = fc.name
+                    except AttributeError:
+                        name = ""
+
+                    if not name:
+                        continue
+
+                    try:
+                        args = dict(fc.args) if fc.args else {}
+                    except (AttributeError, TypeError):
+                        args = {}
+
+                    calls.append(
+                        ToolCall(
+                            id=name,
+                            type="function",
+                            function=FunctionCall(
+                                name=name,
+                                arguments=json.dumps(args),
+                            ),
+                        )
+                    )
+            except AttributeError:
+                pass
+
+        text = "".join(collected_text)
+        return text or None, calls or None
+
+    def _clean_json_schema(self, schema: dict | None) -> dict:
         """
-
+        Transform an arbitrary JSON Schema-like dict (possibly produced by Pydantic)
+        into a format compatible with google.genai by:
+        - Inlining $ref by replacing references with actual schemas from $defs
+        - Removing $defs after inlining all references
+        - Renaming unsupported keys to the SDK's expected snake_case
+        - Recursively converting nested schemas (properties, items, anyOf)
+        - Preserving fields supported by the SDK Schema model
         """
         if not isinstance(schema, dict):
-            return
+            return {}
+
+        # Step 1: Inline $ref references before any conversion
+        defs = schema.get("$defs", {})
+
+        def inline_refs(obj, definitions):
+            """Recursively inline $ref references."""
+            if isinstance(obj, dict):
+                # If this object has a $ref, replace it with the referenced schema
+                if "$ref" in obj:
+                    ref_path = obj["$ref"]
+                    if ref_path.startswith("#/$defs/"):
+                        ref_name = ref_path.replace("#/$defs/", "")
+                        if ref_name in definitions:
+                            # Get the referenced schema and inline it recursively
+                            referenced = definitions[ref_name].copy()
+                            # Preserve description and default from the referencing object
+                            if "description" in obj and "description" not in referenced:
+                                referenced["description"] = obj["description"]
+                            if "default" in obj:
+                                referenced["default"] = obj["default"]
+                            return inline_refs(referenced, definitions)
+                    # If can't resolve, remove the $ref
+                    return {k: v for k, v in obj.items() if k != "$ref"}
+
+                # Recursively process all properties
+                result = {}
+                for key, value in obj.items():
+                    if key == "$defs":
+                        continue  # Remove $defs after inlining
+                    # Skip additionalProperties: true as it's not well supported
+                    if key == "additionalProperties" and value is True:
+                        continue
+                    result[key] = inline_refs(value, definitions)
+                return result
+            elif isinstance(obj, list):
+                return [inline_refs(item, definitions) for item in obj]
+            else:
+                return obj
+
+        # Inline all references
+        schema = inline_refs(schema, defs)
+
+        # Step 2: Convert to SDK schema format
+        # Mapping from common JSON Schema/OpenAPI keys to google-genai Schema fields
+        key_map = {
+            "anyOf": "any_of",
+            "additionalProperties": "additional_properties",
+            "maxItems": "max_items",
+            "maxLength": "max_length",
+            "maxProperties": "max_properties",
+            "minItems": "min_items",
+            "minLength": "min_length",
+            "minProperties": "min_properties",
+            "propertyOrdering": "property_ordering",
+        }
 
-
-
-
-
-
-
-
-
-
-
-
+        def convert(obj):
+            if isinstance(obj, dict):
+                out: dict[str, object] = {}
+                for k, v in obj.items():
+                    if k == "const":
+                        out["enum"] = [v]
+                        continue
+
+                    kk = key_map.get(k, k)
+                    if kk == "properties" and isinstance(v, dict):
+                        # properties: dict[str, Schema]
+                        out[kk] = {pk: convert(pv) for pk, pv in v.items()}
+                    elif kk == "items":
+                        # items: Schema (treat list as first item schema)
+                        if isinstance(v, list) and v:
+                            out[kk] = convert(v[0])
+                        else:
+                            out[kk] = convert(v)
+                    elif kk == "any_of" and isinstance(v, list):
+                        out[kk] = [convert(iv) for iv in v]
+                    else:
+                        out[kk] = convert(v)
+                return out
+            elif isinstance(obj, list):
+                return [convert(i) for i in obj]
             else:
-
+                return obj
 
-        return
+        return convert(schema)
 
     async def generate(self, request: GenerateRequest) -> GenerateResponse:
         """Generate a response using Vertex AI."""
         await self.validate_request(request)
 
-
-
-
-
-
-
+        def _safe_text(text: str) -> str:
+            try:
+                return text.encode("utf-8", errors="replace").decode(
+                    "utf-8", errors="replace"
+                )
+            except Exception:
+                return ""
+
+        contents: list[Content] = []
+        system_instruction = ""
+
+        # Group consecutive tool messages into single Content
+        i = 0
+        while i < len(request.messages):
+            m = request.messages[i]
+
+            if m.role == "tool":
+                # Collect all consecutive tool messages
+                tool_parts = []
+                while i < len(request.messages) and request.messages[i].role == "tool":
+                    tool_msg = request.messages[i]
+                    content_str = (
+                        tool_msg.content
+                        if isinstance(tool_msg.content, str)
+                        else str(tool_msg.content)
+                    )
+                    part = Part.from_function_response(
+                        name=getattr(tool_msg, "name", "") or "",
+                        response={"result": _safe_text(content_str)},
+                    )
+                    tool_parts.append(part)
+                    i += 1
+                # Add all tool responses as a single Content
+                if tool_parts:
+                    contents.append(Content(role="function", parts=tool_parts))
+                continue
+            elif m.role == "system":
+                content_str = (
+                    m.content if isinstance(m.content, str) else str(m.content)
+                )
+                system_instruction += _safe_text(content_str).strip()
+                i += 1
+            elif m.role == "assistant":
+                # Check if message has tool_calls attribute
+                if hasattr(m, "tool_calls") and m.tool_calls:
+                    # Assistant message with tool calls
+                    parts_list = []
+                    for tc in m.tool_calls:
+                        if not tc.function.name:
+                            continue
+                        args = (
+                            json.loads(tc.function.arguments)
+                            if isinstance(tc.function.arguments, str)
+                            else tc.function.arguments
+                        )
+                        if not isinstance(args, dict):
+                            args = {}
+                        part = Part.from_function_call(name=tc.function.name, args=args)
+                        parts_list.append(part)
+                    if parts_list:
+                        contents.append(Content(role="model", parts=parts_list))
+                else:
+                    # Regular assistant text response
+                    content_str = (
+                        m.content if isinstance(m.content, str) else str(m.content)
+                    )
+                    if content_str:
+                        part = Part(text=_safe_text(content_str))
+                        contents.append(Content(role="model", parts=[part]))
+                i += 1
             else:
-
-
-
-
-
-
-
-
-
+                # User message - use _convert_message to handle multimodal content
+                user_content = self._convert_message(m)
+                contents.append(user_content)
+                i += 1
+
+        config_kwargs = {
+            "temperature": request.temperature
+            if request.temperature is not None
+            else 0.2,
+            "top_p": request.top_p if request.top_p is not None else 0.95,
+            "max_output_tokens": request.max_tokens
+            if request.max_tokens is not None
+            else 8192,
+        }
+        if system_instruction:
+            config_kwargs["system_instruction"] = system_instruction
         if request.stop:
             config_kwargs["stop_sequences"] = request.stop
         if request.response_format:
-            # Vertex AI uses response_mime_type and response_schema
             config_kwargs["response_mime_type"] = "application/json"
             if "schema" in request.response_format:
                 config_kwargs["response_schema"] = self._clean_json_schema(
                     request.response_format["schema"]
                 )
 
-
-        config = (
-            genai.types.GenerateContentConfig(**config_kwargs)
-            if config_kwargs
-            else None
-        )
+        config = genai.types.GenerateContentConfig(**config_kwargs)
 
-        # Add tools to config if present
         if request.tools:
-
-
-
-
-
-
-
-
-            config.system_instruction = system_instruction
-
-        response = await self.client.aio.models.generate_content(
-            model=self._model_name,
-            contents=messages,
-            config=config,
-        )
-        # Extract content
-        content = None
-        if response.text:
-            content = response.text
-
-        # Extract tool calls
-        tool_calls = None
-        if response.candidates and response.candidates[0].content.parts:
-            function_calls = []
-            for part in response.candidates[0].content.parts:
-                if not hasattr(part, "function_call") or not part.function_call:
-                    continue
-                fc = part.function_call
-                args_dict = dict(fc.args) if fc.args else {}
-                function_calls.append(
-                    ToolCall(
-                        id=fc.name,
-                        type="function",
-                        function=FunctionCall(
-                            name=fc.name,
-                            arguments=json.dumps(args_dict),
-                        ),
+            function_declarations: list[FunctionDeclaration] = []
+            for t in request.tools:
+                schema_obj = self._clean_json_schema(t.function.parameters or {})
+                function_declarations.append(
+                    FunctionDeclaration(
+                        name=t.function.name,
+                        description=t.function.description or "",
+                        parameters=schema_obj,
                     )
                 )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            gen_tools = [GeminiTool(function_declarations=function_declarations)]
+            config.tools = gen_tools
+
+        try:
+            response = await self.client.aio.models.generate_content(
+                model=self._model_name,
+                contents=contents,
+                config=config,
+            )
+            text, tool_calls = self._parse_response(response)
+
+            # If no text and no tool calls, check for errors in response
+            if not text and not tool_calls:
+                try:
+                    # Check for blocking reasons or errors
+                    if hasattr(response, "candidates") and response.candidates:
+                        cand = response.candidates[0]
+                        if hasattr(cand, "finish_reason") and cand.finish_reason:
+                            finish_reason = cand.finish_reason
+                            if finish_reason not in ("STOP", None):
+                                error_msg = (
+                                    f"Model finished with reason: {finish_reason}"
+                                )
+                                return GenerateResponse(content=f"Warning: {error_msg}")
+                    # Check for safety ratings that might block content
+                    if hasattr(response, "candidates") and response.candidates:
+                        cand = response.candidates[0]
+                        if hasattr(cand, "safety_ratings"):
+                            blocked = any(
+                                getattr(r, "blocked", False)
+                                for r in getattr(cand, "safety_ratings", [])
+                            )
+                            if blocked:
+                                error_msg = "Response was blocked by safety filters"
+                                return GenerateResponse(content=f"Warning: {error_msg}")
+                except Exception:
+                    pass  # If we can't check, just return empty
+
+            # Extract finish reason
+            finish_reason = None
+            if response.candidates:
+                finish_reason = str(response.candidates[0].finish_reason)
+
+            # Extract usage
+            usage = None
+            if response.usage_metadata:
+                usage = {
+                    "prompt_tokens": response.usage_metadata.prompt_token_count,
+                    "completion_tokens": response.usage_metadata.candidates_token_count,
+                    "total_tokens": response.usage_metadata.total_token_count,
+                }
+
+            return GenerateResponse(
+                content=text,
+                tool_calls=tool_calls,
+                finish_reason=finish_reason,
+                usage=usage,
+            )
+        except Exception as e:
+            error_msg = str(e)
+            # Return error message instead of empty response
+            return GenerateResponse(content=f"Error: {error_msg}")
 
     async def generate_stream(
         self, request: GenerateRequest
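
Two of the additions above carry most of the new behavior: _parse_response folds text parts and function_call parts of a genai response (or stream chunk) into donkit ToolCall objects, and _clean_json_schema rewrites Pydantic-style JSON Schemas into the shape google.genai accepts. As a hedged illustration of what the schema cleaning is documented to do (the input below is an invented schema, not one taken from the package or its tests), the transformation should roughly map:

# Hypothetical input, shaped like what Pydantic's model_json_schema() emits.
raw_schema = {
    "$defs": {
        "Color": {"type": "string", "enum": ["red", "green", "blue"]},
    },
    "type": "object",
    "properties": {
        "color": {"$ref": "#/$defs/Color", "description": "Preferred color"},
        "tags": {"type": "array", "items": {"type": "string"}, "maxItems": 5},
        "kind": {"const": "user"},
    },
    "additionalProperties": True,
}

# Expected result of _clean_json_schema(raw_schema), per the diff above:
# the $ref is inlined from $defs, $defs and additionalProperties: True are dropped,
# const becomes a one-element enum, and maxItems is renamed to max_items.
cleaned = {
    "type": "object",
    "properties": {
        "color": {
            "type": "string",
            "enum": ["red", "green", "blue"],
            "description": "Preferred color",
        },
        "tags": {"type": "array", "items": {"type": "string"}, "max_items": 5},
        "kind": {"enum": ["user"]},
    },
}
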
@@ -298,22 +512,89 @@ class VertexAIModel(LLMModelAbstract):
         """Generate a streaming response using Vertex AI."""
         await self.validate_request(request)
 
-
-
-
-
-
-
+        def _safe_text(text: str) -> str:
+            try:
+                return text.encode("utf-8", errors="replace").decode(
+                    "utf-8", errors="replace"
+                )
+            except Exception:
+                return ""
+
+        contents: list[Content] = []
+        system_instruction = ""
+
+        # Convert messages to genai format (same logic as generate())
+        i = 0
+        while i < len(request.messages):
+            m = request.messages[i]
+
+            if m.role == "tool":
+                # Collect all consecutive tool messages
+                tool_parts = []
+                while i < len(request.messages) and request.messages[i].role == "tool":
+                    tool_msg = request.messages[i]
+                    content_str = (
+                        tool_msg.content
+                        if isinstance(tool_msg.content, str)
+                        else str(tool_msg.content)
+                    )
+                    part = Part.from_function_response(
+                        name=getattr(tool_msg, "name", "") or "",
+                        response={"result": _safe_text(content_str)},
+                    )
+                    tool_parts.append(part)
+                    i += 1
+                if tool_parts:
+                    contents.append(Content(role="function", parts=tool_parts))
+                continue
+            elif m.role == "system":
+                content_str = (
+                    m.content if isinstance(m.content, str) else str(m.content)
+                )
+                system_instruction += _safe_text(content_str).strip()
+                i += 1
+            elif m.role == "assistant":
+                if hasattr(m, "tool_calls") and m.tool_calls:
+                    parts_list = []
+                    for tc in m.tool_calls:
+                        if not tc.function.name:
+                            continue
+                        args = (
+                            json.loads(tc.function.arguments)
+                            if isinstance(tc.function.arguments, str)
+                            else tc.function.arguments
+                        )
+                        if not isinstance(args, dict):
+                            args = {}
+                        part = Part.from_function_call(name=tc.function.name, args=args)
+                        parts_list.append(part)
+                    if parts_list:
+                        contents.append(Content(role="model", parts=parts_list))
+                else:
+                    content_str = (
+                        m.content if isinstance(m.content, str) else str(m.content)
+                    )
+                    if content_str:
+                        part = Part(text=_safe_text(content_str))
+                        contents.append(Content(role="model", parts=[part]))
+                i += 1
             else:
-
-
-
-
-
-
-
-
-
+                # User message - use _convert_message to handle multimodal content
+                user_content = self._convert_message(m)
+                contents.append(user_content)
+                i += 1
+
+        config_kwargs: dict[str, Any] = {
+            "temperature": request.temperature
+            if request.temperature is not None
+            else 0.2,
+            "top_p": request.top_p if request.top_p is not None else 0.95,
+            "max_output_tokens": request.max_tokens
+            if request.max_tokens is not None
+            else 8192,
+        }
+        if system_instruction:
+            config_kwargs["system_instruction"] = system_instruction
         if request.stop:
             config_kwargs["stop_sequences"] = request.stop
         if request.response_format:
@@ -322,70 +603,49 @@ class VertexAIModel(LLMModelAbstract):
                 config_kwargs["response_schema"] = self._clean_json_schema(
                     request.response_format["schema"]
                 )
-
-
-        config = (
-            genai.types.GenerateContentConfig(**config_kwargs)
-            if config_kwargs
-            else None
+        config_kwargs["automatic_function_calling"] = (
+            genai.types.AutomaticFunctionCallingConfig(maximum_remote_calls=100)
         )
 
-
+        config = genai.types.GenerateContentConfig(**config_kwargs)
+
         if request.tools:
-
-
-
+            function_declarations: list[FunctionDeclaration] = []
+            for t in request.tools:
+                schema_obj = self._clean_json_schema(t.function.parameters or {})
+                function_declarations.append(
+                    FunctionDeclaration(
+                        name=t.function.name,
+                        description=t.function.description or "",
+                        parameters=schema_obj,
+                    )
+                )
+            gen_tools = [GeminiTool(function_declarations=function_declarations)]
+            config.tools = gen_tools
+
+        try:
+            # Use generate_content_stream for streaming
+            stream = await self.client.aio.models.generate_content_stream(
+                model=self._model_name,
+                contents=contents,
+                config=config,
+            )
 
-
-
-        if config is None:
-            config = genai.types.GenerateContentConfig()
-        config.system_instruction = system_instruction
-
-        model_name = self._model_name
-        stream = await self.client.aio.models.generate_content_stream(
-            model=model_name,
-            contents=messages,
-            config=config,
-        )
+            async for chunk in stream:
+                text, tool_calls = self._parse_response(chunk)
 
-
-
-
-            content = chunk.text
-
-            # Extract tool calls from chunk
-            tool_calls = None
-            if chunk.candidates and chunk.candidates[0].content.parts:
-                function_calls = []
-                for part in chunk.candidates[0].content.parts:
-                    if not hasattr(part, "function_call") or not part.function_call:
-                        continue
-                    fc = part.function_call
-                    args_dict = dict(fc.args) if fc.args else {}
-                    function_calls.append(
-                        ToolCall(
-                            id=fc.name,
-                            type="function",
-                            function=FunctionCall(
-                                name=fc.name,
-                                arguments=json.dumps(args_dict),
-                            ),
-                        )
-                    )
-                if function_calls:
-                    tool_calls = function_calls
+                # Yield text chunks as they come
+                if text:
+                    yield StreamChunk(content=text, tool_calls=None)
 
-
-
-
-
-
-
-
-
-            finish_reason=finish_reason,
-        )
+                # Tool calls come in final chunk - yield them separately
+                if tool_calls:
+                    yield StreamChunk(content=None, tool_calls=tool_calls)
+
+        except Exception as e:
+            error_msg = str(e)
+            # Yield error message instead of empty response
+            yield StreamChunk(content=f"Error: {error_msg}")
 
 
 class VertexEmbeddingModel(LLMModelAbstract):
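
The streaming path now mirrors generate(): it converts the message history the same way, wraps the API call in a try/except, and yields StreamChunk objects, text incrementally and tool calls as a separate chunk. A minimal consumption sketch, assuming only the StreamChunk and ToolCall fields visible in this diff; the surrounding donkit request and model construction are left abstract because they are not shown here:

async def collect_stream(model, request) -> str:
    """Drain VertexAIModel.generate_stream, printing tool calls as they appear (sketch)."""
    pieces: list[str] = []
    async for chunk in model.generate_stream(request):
        if chunk.content:  # incremental text, yielded as soon as a chunk carries any
            pieces.append(chunk.content)
        if chunk.tool_calls:  # per the diff, tool calls arrive as a separate chunk
            for call in chunk.tool_calls:
                print(f"tool call: {call.function.name} {call.function.arguments}")
    return "".join(pieces)
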
@@ -428,6 +688,10 @@ class VertexEmbeddingModel(LLMModelAbstract):
     def model_name(self) -> str:
         return self._model_name
 
+    @model_name.setter
+    def model_name(self, model_name: str) -> None:
+        self._model_name = model_name
+
     @property
     def capabilities(self) -> ModelCapability:
         return ModelCapability.EMBEDDINGS
@@ -475,4 +739,8 @@ class VertexEmbeddingModel(LLMModelAbstract):
         return EmbeddingResponse(
             embeddings=all_embeddings,
             usage=None,
+            metadata={
+                "dimensions": len(all_embeddings[0]) if len(all_embeddings) > 0 else 0,
+                "batch_size": self._batch_size,
+            },
         )
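
The embedding path now attaches a small metadata dict to the response. A hedged reading sketch follows; how the EmbeddingResponse is obtained from VertexEmbeddingModel is not shown in this diff, so `resp` is assumed to already exist:

# resp: an EmbeddingResponse produced by VertexEmbeddingModel (however it is obtained).
dims = resp.metadata["dimensions"]   # length of the first embedding vector, 0 if none
batch = resp.metadata["batch_size"]  # batch size the model used when calling Vertex AI
print(f"{len(resp.embeddings)} vectors, dimension {dims}, batched by {batch}")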