donkit-llm 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
@@ -1,6 +1,6 @@
  import json
  import base64
- from typing import AsyncIterator
+ from typing import Any, AsyncIterator

  import google.genai as genai
  from google.genai.types import Blob, Content, FunctionDeclaration, Part
@@ -32,6 +32,8 @@ class VertexAIModel(LLMModelAbstract):
      - Claude models via Vertex AI (claude-3-5-sonnet-v2@20241022, etc.)
      """

+     name = "vertex"
+
      def __init__(
          self,
          project_id: str,
@@ -104,9 +106,9 @@ class VertexAIModel(LLMModelAbstract):
          else:
              # Multimodal content
              for part in msg.content:
-                 if part.type == ContentType.TEXT:
+                 if part.content_type == ContentType.TEXT:
                      parts.append(Part(text=part.content))
-                 elif part.type == ContentType.IMAGE_URL:
+                 elif part.content_type == ContentType.IMAGE_URL:
                      # For URLs, we'd need to fetch and convert to inline data
                      parts.append(
                          Part(
@@ -116,7 +118,7 @@ class VertexAIModel(LLMModelAbstract):
                              )
                          )
                      )
-                 elif part.type == ContentType.IMAGE_BASE64:
+                 elif part.content_type == ContentType.IMAGE_BASE64:
                      # part.content is base64 string; Vertex needs raw bytes
                      raw = base64.b64decode(part.content, validate=True)
                      parts.append(
@@ -127,7 +129,7 @@ class VertexAIModel(LLMModelAbstract):
                              )
                          )
                      )
-                 elif part.type == ContentType.AUDIO_BASE64:
+                 elif part.content_type == ContentType.AUDIO_BASE64:
                      raw = base64.b64decode(part.content, validate=True)
                      parts.append(
                          Part(
@@ -137,7 +139,7 @@ class VertexAIModel(LLMModelAbstract):
                              )
                          )
                      )
-                 elif part.type == ContentType.FILE_BASE64:
+                 elif part.content_type == ContentType.FILE_BASE64:
                      raw = base64.b64decode(part.content, validate=True)
                      parts.append(
                          Part(
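Reviewer note on the four hunks above: 0.1.3 renames the content-part discriminator from `part.type` to `part.content_type`. A minimal sketch of a request that exercises the text and image branches; only `ContentType`, `GenerateRequest`, and the `content_type` field are confirmed by this diff, while the `Message`/`ContentPart` model names and import path are assumptions:

```python
import base64

# Assumed import path and model names; adjust to the real donkit-llm API.
from donkit_llm import ContentPart, ContentType, GenerateRequest, Message

with open("chart.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("ascii")

request = GenerateRequest(
    messages=[
        Message(
            role="user",
            content=[
                # 0.1.3 dispatches on content_type, not type
                ContentPart(content_type=ContentType.TEXT, content="Describe this chart."),
                ContentPart(content_type=ContentType.IMAGE_BASE64, content=image_b64),
            ],
        )
    ],
)
```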
@@ -167,130 +169,342 @@ class VertexAIModel(LLMModelAbstract):

          return [GeminiTool(function_declarations=function_declarations)]

-     def _clean_json_schema(self, schema: dict) -> dict:
+     def _parse_response(self, response) -> tuple[str | None, list[ToolCall] | None]:
+         """Parse a genai response (or stream chunk) into plain text and tool calls."""
+         calls: list[ToolCall] = []
+
+         try:
+             cand_list = response.candidates
+         except AttributeError:
+             cand_list = None
+         if not cand_list:
+             return None, None
+
+         cand = cand_list[0]
+
+         try:
+             parts = cand.content.parts or []
+         except AttributeError:
+             parts = []
+
+         # Collect text and tool calls in a single pass
+         collected_text: list[str] = []
+         for p in parts:
+             # Try to get text from this part
+             try:
+                 t = p.text
+                 if t:
+                     collected_text.append(t)
+             except AttributeError:
+                 pass
+
+             # Try to get function_call from this part
+             try:
+                 fc = p.function_call
+                 if fc:
+                     # Extract function name and arguments
+                     try:
+                         name = fc.name
+                     except AttributeError:
+                         name = ""
+
+                     if not name:
+                         continue
+
+                     try:
+                         args = dict(fc.args) if fc.args else {}
+                     except (AttributeError, TypeError):
+                         args = {}
+
+                     calls.append(
+                         ToolCall(
+                             id=name,
+                             type="function",
+                             function=FunctionCall(
+                                 name=name,
+                                 arguments=json.dumps(args),
+                             ),
+                         )
+                     )
+             except AttributeError:
+                 pass
+
+         text = "".join(collected_text)
+         return text or None, calls or None
+
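The new `_parse_response` helper centralizes response parsing for both `generate` and `generate_stream`. Its contract: a `(text, calls)` tuple where each element is `None` when empty, and the tool-call `id` falls back to the function name. A stub with the attribute shape of a genai response illustrates this; the stub and the `model` instance are illustrative, not from the package:

```python
import json
from types import SimpleNamespace

# Stub mimicking the attribute shape of a google-genai response.
fc = SimpleNamespace(name="get_weather", args={"city": "Oslo"})
part = SimpleNamespace(text=None, function_call=fc)
stub = SimpleNamespace(
    candidates=[SimpleNamespace(content=SimpleNamespace(parts=[part]))]
)

text, calls = model._parse_response(stub)  # `model`: an assumed VertexAIModel instance
assert text is None                         # no text parts were collected
assert calls[0].id == "get_weather"         # id falls back to the function name
assert json.loads(calls[0].function.arguments) == {"city": "Oslo"}
```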
+     def _clean_json_schema(self, schema: dict | None) -> dict:
          """
-         Remove $ref and $defs from JSON Schema as Vertex AI doesn't support them.
+         Transform an arbitrary JSON Schema-like dict (possibly produced by Pydantic)
+         into a format compatible with google.genai by:
+         - Inlining $ref by replacing references with actual schemas from $defs
+         - Removing $defs after inlining all references
+         - Renaming unsupported keys to the SDK's expected snake_case
+         - Recursively converting nested schemas (properties, items, anyOf)
+         - Preserving fields supported by the SDK Schema model
          """
          if not isinstance(schema, dict):
-             return schema
+             return {}
+
+         # Step 1: Inline $ref references before any conversion
+         defs = schema.get("$defs", {})
+
+         def inline_refs(obj, definitions):
+             """Recursively inline $ref references."""
+             if isinstance(obj, dict):
+                 # If this object has a $ref, replace it with the referenced schema
+                 if "$ref" in obj:
+                     ref_path = obj["$ref"]
+                     if ref_path.startswith("#/$defs/"):
+                         ref_name = ref_path.replace("#/$defs/", "")
+                         if ref_name in definitions:
+                             # Get the referenced schema and inline it recursively
+                             referenced = definitions[ref_name].copy()
+                             # Preserve description and default from the referencing object
+                             if "description" in obj and "description" not in referenced:
+                                 referenced["description"] = obj["description"]
+                             if "default" in obj:
+                                 referenced["default"] = obj["default"]
+                             return inline_refs(referenced, definitions)
+                     # If can't resolve, remove the $ref
+                     return {k: v for k, v in obj.items() if k != "$ref"}
+
+                 # Recursively process all properties
+                 result = {}
+                 for key, value in obj.items():
+                     if key == "$defs":
+                         continue  # Remove $defs after inlining
+                     # Skip additionalProperties: true as it's not well supported
+                     if key == "additionalProperties" and value is True:
+                         continue
+                     result[key] = inline_refs(value, definitions)
+                 return result
+             elif isinstance(obj, list):
+                 return [inline_refs(item, definitions) for item in obj]
+             else:
+                 return obj
+
+         # Inline all references
+         schema = inline_refs(schema, defs)
+
+         # Step 2: Convert to SDK schema format
+         # Mapping from common JSON Schema/OpenAPI keys to google-genai Schema fields
+         key_map = {
+             "anyOf": "any_of",
+             "additionalProperties": "additional_properties",
+             "maxItems": "max_items",
+             "maxLength": "max_length",
+             "maxProperties": "max_properties",
+             "minItems": "min_items",
+             "minLength": "min_length",
+             "minProperties": "min_properties",
+             "propertyOrdering": "property_ordering",
+         }

-         cleaned = {}
-         for key, value in schema.items():
-             if key in ("$ref", "$defs", "definitions"):
-                 continue
-             if isinstance(value, dict):
-                 cleaned[key] = self._clean_json_schema(value)
-             elif isinstance(value, list):
-                 cleaned[key] = [
-                     self._clean_json_schema(item) if isinstance(item, dict) else item
-                     for item in value
-                 ]
+         def convert(obj):
+             if isinstance(obj, dict):
+                 out: dict[str, object] = {}
+                 for k, v in obj.items():
+                     if k == "const":
+                         out["enum"] = [v]
+                         continue
+
+                     kk = key_map.get(k, k)
+                     if kk == "properties" and isinstance(v, dict):
+                         # properties: dict[str, Schema]
+                         out[kk] = {pk: convert(pv) for pk, pv in v.items()}
+                     elif kk == "items":
+                         # items: Schema (treat list as first item schema)
+                         if isinstance(v, list) and v:
+                             out[kk] = convert(v[0])
+                         else:
+                             out[kk] = convert(v)
+                     elif kk == "any_of" and isinstance(v, list):
+                         out[kk] = [convert(iv) for iv in v]
+                     else:
+                         out[kk] = convert(v)
+                 return out
+             elif isinstance(obj, list):
+                 return [convert(i) for i in obj]
              else:
-                 cleaned[key] = value
+                 return obj

-         return cleaned
+         return convert(schema)

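A worked example of the two-step cleaning, with the behavior read directly off the code above (the output dict is my paraphrase, not captured from a test run):

```python
pydantic_like = {
    "$defs": {"Unit": {"type": "string", "enum": ["C", "F"]}},
    "type": "object",
    "properties": {
        "unit": {"$ref": "#/$defs/Unit", "description": "Temperature unit"},
        "tags": {"type": "array", "items": {"type": "string"}, "maxItems": 5},
        "mode": {"const": "forecast"},
    },
}

cleaned = model._clean_json_schema(pydantic_like)  # `model`: an assumed instance
# Step 1 inlines #/$defs/Unit (keeping the outer description) and drops $defs;
# step 2 renames maxItems -> max_items and rewrites const -> enum:
# {
#     "type": "object",
#     "properties": {
#         "unit": {"type": "string", "enum": ["C", "F"], "description": "Temperature unit"},
#         "tags": {"type": "array", "items": {"type": "string"}, "max_items": 5},
#         "mode": {"enum": ["forecast"]},
#     },
# }
```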
      async def generate(self, request: GenerateRequest) -> GenerateResponse:
          """Generate a response using Vertex AI."""
          await self.validate_request(request)

-         # Separate system message from conversation
-         system_instruction = None
-         messages = []
-         for msg in request.messages:
-             if msg.role == "system":
-                 system_instruction = msg.content if isinstance(msg.content, str) else ""
+         def _safe_text(text: str) -> str:
+             try:
+                 return text.encode("utf-8", errors="replace").decode(
+                     "utf-8", errors="replace"
+                 )
+             except Exception:
+                 return ""
+
+         contents: list[Content] = []
+         system_instruction = ""
+
+         # Group consecutive tool messages into single Content
+         i = 0
+         while i < len(request.messages):
+             m = request.messages[i]
+
+             if m.role == "tool":
+                 # Collect all consecutive tool messages
+                 tool_parts = []
+                 while i < len(request.messages) and request.messages[i].role == "tool":
+                     tool_msg = request.messages[i]
+                     content_str = (
+                         tool_msg.content
+                         if isinstance(tool_msg.content, str)
+                         else str(tool_msg.content)
+                     )
+                     part = Part.from_function_response(
+                         name=getattr(tool_msg, "name", "") or "",
+                         response={"result": _safe_text(content_str)},
+                     )
+                     tool_parts.append(part)
+                     i += 1
+                 # Add all tool responses as a single Content
+                 if tool_parts:
+                     contents.append(Content(role="function", parts=tool_parts))
+                 continue
+             elif m.role == "system":
+                 content_str = (
+                     m.content if isinstance(m.content, str) else str(m.content)
+                 )
+                 system_instruction += _safe_text(content_str).strip()
+                 i += 1
+             elif m.role == "assistant":
+                 # Check if message has tool_calls attribute
+                 if hasattr(m, "tool_calls") and m.tool_calls:
+                     # Assistant message with tool calls
+                     parts_list = []
+                     for tc in m.tool_calls:
+                         if not tc.function.name:
+                             continue
+                         args = (
+                             json.loads(tc.function.arguments)
+                             if isinstance(tc.function.arguments, str)
+                             else tc.function.arguments
+                         )
+                         if not isinstance(args, dict):
+                             args = {}
+                         part = Part.from_function_call(name=tc.function.name, args=args)
+                         parts_list.append(part)
+                     if parts_list:
+                         contents.append(Content(role="model", parts=parts_list))
+                 else:
+                     # Regular assistant text response
+                     content_str = (
+                         m.content if isinstance(m.content, str) else str(m.content)
+                     )
+                     if content_str:
+                         part = Part(text=_safe_text(content_str))
+                         contents.append(Content(role="model", parts=[part]))
+                 i += 1
              else:
-                 messages.append(self._convert_message(msg))
-
-         config_kwargs = {}
-         if request.temperature is not None:
-             config_kwargs["temperature"] = request.temperature
-         if request.max_tokens is not None:
-             config_kwargs["max_output_tokens"] = request.max_tokens
-         if request.top_p is not None:
-             config_kwargs["top_p"] = request.top_p
+                 # User message - use _convert_message to handle multimodal content
+                 user_content = self._convert_message(m)
+                 contents.append(user_content)
+                 i += 1
+
+         config_kwargs = {
+             "temperature": request.temperature
+             if request.temperature is not None
+             else 0.2,
+             "top_p": request.top_p if request.top_p is not None else 0.95,
+             "max_output_tokens": request.max_tokens
+             if request.max_tokens is not None
+             else 8192,
+         }
+         if system_instruction:
+             config_kwargs["system_instruction"] = system_instruction
          if request.stop:
              config_kwargs["stop_sequences"] = request.stop
          if request.response_format:
-             # Vertex AI uses response_mime_type and response_schema
              config_kwargs["response_mime_type"] = "application/json"
              if "schema" in request.response_format:
                  config_kwargs["response_schema"] = self._clean_json_schema(
                      request.response_format["schema"]
                  )

-         # Build config object
-         config = (
-             genai.types.GenerateContentConfig(**config_kwargs)
-             if config_kwargs
-             else None
-         )
+         config = genai.types.GenerateContentConfig(**config_kwargs)

-         # Add tools to config if present
          if request.tools:
-             if config is None:
-                 config = genai.types.GenerateContentConfig()
-             config.tools = self._convert_tools(request.tools)
-
-         # Add system instruction to config if present
-         if system_instruction:
-             if config is None:
-                 config = genai.types.GenerateContentConfig()
-             config.system_instruction = system_instruction
-
-         response = await self.client.aio.models.generate_content(
-             model=self._model_name,
-             contents=messages,
-             config=config,
-         )
-         # Extract content
-         content = None
-         if response.text:
-             content = response.text
-
-         # Extract tool calls
-         tool_calls = None
-         if response.candidates and response.candidates[0].content.parts:
-             function_calls = []
-             for part in response.candidates[0].content.parts:
-                 if not hasattr(part, "function_call") or not part.function_call:
-                     continue
-                 fc = part.function_call
-                 args_dict = dict(fc.args) if fc.args else {}
-                 function_calls.append(
-                     ToolCall(
-                         id=fc.name,
-                         type="function",
-                         function=FunctionCall(
-                             name=fc.name,
-                             arguments=json.dumps(args_dict),
-                         ),
+             function_declarations: list[FunctionDeclaration] = []
+             for t in request.tools:
+                 schema_obj = self._clean_json_schema(t.function.parameters or {})
+                 function_declarations.append(
+                     FunctionDeclaration(
+                         name=t.function.name,
+                         description=t.function.description or "",
+                         parameters=schema_obj,
                      )
                  )
-             if function_calls:
-                 tool_calls = function_calls
-
-         # Extract finish reason
-         finish_reason = None
-         if response.candidates:
-             finish_reason = str(response.candidates[0].finish_reason)
-
-         # Extract usage
-         usage = None
-         if response.usage_metadata:
-             usage = {
-                 "prompt_tokens": response.usage_metadata.prompt_token_count,
-                 "completion_tokens": response.usage_metadata.candidates_token_count,
-                 "total_tokens": response.usage_metadata.total_token_count,
-             }
-
-         return GenerateResponse(
-             content=content,
-             tool_calls=tool_calls,
-             finish_reason=finish_reason,
-             usage=usage,
-         )
+             gen_tools = [GeminiTool(function_declarations=function_declarations)]
+             config.tools = gen_tools
+
+         try:
+             response = await self.client.aio.models.generate_content(
+                 model=self._model_name,
+                 contents=contents,
+                 config=config,
+             )
+             text, tool_calls = self._parse_response(response)
+
+             # If no text and no tool calls, check for errors in response
+             if not text and not tool_calls:
+                 try:
+                     # Check for blocking reasons or errors
+                     if hasattr(response, "candidates") and response.candidates:
+                         cand = response.candidates[0]
+                         if hasattr(cand, "finish_reason") and cand.finish_reason:
+                             finish_reason = cand.finish_reason
+                             if finish_reason not in ("STOP", None):
+                                 error_msg = (
+                                     f"Model finished with reason: {finish_reason}"
+                                 )
+                                 return GenerateResponse(content=f"Warning: {error_msg}")
+                     # Check for safety ratings that might block content
+                     if hasattr(response, "candidates") and response.candidates:
+                         cand = response.candidates[0]
+                         if hasattr(cand, "safety_ratings"):
+                             blocked = any(
+                                 getattr(r, "blocked", False)
+                                 for r in getattr(cand, "safety_ratings", [])
+                             )
+                             if blocked:
+                                 error_msg = "Response was blocked by safety filters"
+                                 return GenerateResponse(content=f"Warning: {error_msg}")
+                 except Exception:
+                     pass  # If we can't check, just return empty
+
+             # Extract finish reason
+             finish_reason = None
+             if response.candidates:
+                 finish_reason = str(response.candidates[0].finish_reason)
+
+             # Extract usage
+             usage = None
+             if response.usage_metadata:
+                 usage = {
+                     "prompt_tokens": response.usage_metadata.prompt_token_count,
+                     "completion_tokens": response.usage_metadata.candidates_token_count,
+                     "total_tokens": response.usage_metadata.total_token_count,
+                 }
+
+             return GenerateResponse(
+                 content=text,
+                 tool_calls=tool_calls,
+                 finish_reason=finish_reason,
+                 usage=usage,
+             )
+         except Exception as e:
+             error_msg = str(e)
+             # Return error message instead of empty response
+             return GenerateResponse(content=f"Error: {error_msg}")

      async def generate_stream(
          self, request: GenerateRequest
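Note the changed failure contract in the hunk above: the rewritten `generate` no longer raises on SDK errors, safety blocks, or non-STOP finish reasons; it folds them into `content` as `"Error: ..."` or `"Warning: ..."` strings. A hedged usage sketch (constructor arguments beyond `project_id` are elided, and `Message` is an assumed model name):

```python
import asyncio

async def main() -> None:
    model = VertexAIModel(project_id="my-project")  # remaining kwargs assumed
    request = GenerateRequest(messages=[Message(role="user", content="Hi")])
    resp = await model.generate(request)
    if resp.content and resp.content.startswith(("Error:", "Warning:")):
        # 0.1.3 reports failures in-band rather than raising
        raise RuntimeError(resp.content)
    if resp.tool_calls:
        print([tc.function.name for tc in resp.tool_calls])
    else:
        print(resp.content)

asyncio.run(main())
```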
@@ -298,22 +512,89 @@ class VertexAIModel(LLMModelAbstract):
          """Generate a streaming response using Vertex AI."""
          await self.validate_request(request)

-         # Separate system message from conversation
-         system_instruction = None
-         messages = []
-         for msg in request.messages:
-             if msg.role == "system":
-                 system_instruction = msg.content if isinstance(msg.content, str) else ""
+         def _safe_text(text: str) -> str:
+             try:
+                 return text.encode("utf-8", errors="replace").decode(
+                     "utf-8", errors="replace"
+                 )
+             except Exception:
+                 return ""
+
+         contents: list[Content] = []
+         system_instruction = ""
+
+         # Convert messages to genai format (same logic as generate())
+         i = 0
+         while i < len(request.messages):
+             m = request.messages[i]
+
+             if m.role == "tool":
+                 # Collect all consecutive tool messages
+                 tool_parts = []
+                 while i < len(request.messages) and request.messages[i].role == "tool":
+                     tool_msg = request.messages[i]
+                     content_str = (
+                         tool_msg.content
+                         if isinstance(tool_msg.content, str)
+                         else str(tool_msg.content)
+                     )
+                     part = Part.from_function_response(
+                         name=getattr(tool_msg, "name", "") or "",
+                         response={"result": _safe_text(content_str)},
+                     )
+                     tool_parts.append(part)
+                     i += 1
+                 if tool_parts:
+                     contents.append(Content(role="function", parts=tool_parts))
+                 continue
+             elif m.role == "system":
+                 content_str = (
+                     m.content if isinstance(m.content, str) else str(m.content)
+                 )
+                 system_instruction += _safe_text(content_str).strip()
+                 i += 1
+             elif m.role == "assistant":
+                 if hasattr(m, "tool_calls") and m.tool_calls:
+                     parts_list = []
+                     for tc in m.tool_calls:
+                         if not tc.function.name:
+                             continue
+                         args = (
+                             json.loads(tc.function.arguments)
+                             if isinstance(tc.function.arguments, str)
+                             else tc.function.arguments
+                         )
+                         if not isinstance(args, dict):
+                             args = {}
+                         part = Part.from_function_call(name=tc.function.name, args=args)
+                         parts_list.append(part)
+                     if parts_list:
+                         contents.append(Content(role="model", parts=parts_list))
+                 else:
+                     content_str = (
+                         m.content if isinstance(m.content, str) else str(m.content)
+                     )
+                     if content_str:
+                         part = Part(text=_safe_text(content_str))
+                         contents.append(Content(role="model", parts=[part]))
+                 i += 1
              else:
-                 messages.append(self._convert_message(msg))
-
-         config_kwargs = {}
-         if request.temperature is not None:
-             config_kwargs["temperature"] = request.temperature
-         if request.max_tokens is not None:
-             config_kwargs["max_output_tokens"] = request.max_tokens
-         if request.top_p is not None:
-             config_kwargs["top_p"] = request.top_p
+                 # User message - use _convert_message to handle multimodal content
+                 user_content = self._convert_message(m)
+                 contents.append(user_content)
+                 i += 1
+
+         config_kwargs: dict[str, Any] = {
+             "temperature": request.temperature
+             if request.temperature is not None
+             else 0.2,
+             "top_p": request.top_p if request.top_p is not None else 0.95,
+             "max_output_tokens": request.max_tokens
+             if request.max_tokens is not None
+             else 8192,
+         }
+         if system_instruction:
+             config_kwargs["system_instruction"] = system_instruction
          if request.stop:
              config_kwargs["stop_sequences"] = request.stop
          if request.response_format:
@@ -322,70 +603,49 @@ class VertexAIModel(LLMModelAbstract):
                  config_kwargs["response_schema"] = self._clean_json_schema(
                      request.response_format["schema"]
                  )
-
-         # Build config object
-         config = (
-             genai.types.GenerateContentConfig(**config_kwargs)
-             if config_kwargs
-             else None
+         config_kwargs["automatic_function_calling"] = (
+             genai.types.AutomaticFunctionCallingConfig(maximum_remote_calls=100)
          )

-         # Add tools to config if present
+         config = genai.types.GenerateContentConfig(**config_kwargs)
+
          if request.tools:
-             if config is None:
-                 config = genai.types.GenerateContentConfig()
-             config.tools = self._convert_tools(request.tools)
+             function_declarations: list[FunctionDeclaration] = []
+             for t in request.tools:
+                 schema_obj = self._clean_json_schema(t.function.parameters or {})
+                 function_declarations.append(
+                     FunctionDeclaration(
+                         name=t.function.name,
+                         description=t.function.description or "",
+                         parameters=schema_obj,
+                     )
+                 )
+             gen_tools = [GeminiTool(function_declarations=function_declarations)]
+             config.tools = gen_tools
+
+         try:
+             # Use generate_content_stream for streaming
+             stream = await self.client.aio.models.generate_content_stream(
+                 model=self._model_name,
+                 contents=contents,
+                 config=config,
+             )

-         # Add system instruction to config if present
-         if system_instruction:
-             if config is None:
-                 config = genai.types.GenerateContentConfig()
-             config.system_instruction = system_instruction
-
-         model_name = self._model_name
-         stream = await self.client.aio.models.generate_content_stream(
-             model=model_name,
-             contents=messages,
-             config=config,
-         )
+             async for chunk in stream:
+                 text, tool_calls = self._parse_response(chunk)

-         async for chunk in stream:
-             content = None
-             if chunk.text:
-                 content = chunk.text
-
-             # Extract tool calls from chunk
-             tool_calls = None
-             if chunk.candidates and chunk.candidates[0].content.parts:
-                 function_calls = []
-                 for part in chunk.candidates[0].content.parts:
-                     if not hasattr(part, "function_call") or not part.function_call:
-                         continue
-                     fc = part.function_call
-                     args_dict = dict(fc.args) if fc.args else {}
-                     function_calls.append(
-                         ToolCall(
-                             id=fc.name,
-                             type="function",
-                             function=FunctionCall(
-                                 name=fc.name,
-                                 arguments=json.dumps(args_dict),
-                             ),
-                         )
-                     )
-                 if function_calls:
-                     tool_calls = function_calls
+                 # Yield text chunks as they come
+                 if text:
+                     yield StreamChunk(content=text, tool_calls=None)

-             finish_reason = None
-             if chunk.candidates:
-                 finish_reason = str(chunk.candidates[0].finish_reason)
-
-             if content or tool_calls or finish_reason:
-                 yield StreamChunk(
-                     content=content,
-                     tool_calls=tool_calls,
-                     finish_reason=finish_reason,
-                 )
+                 # Tool calls come in final chunk - yield them separately
+                 if tool_calls:
+                     yield StreamChunk(content=None, tool_calls=tool_calls)
+
+         except Exception as e:
+             error_msg = str(e)
+             # Yield error message instead of empty response
+             yield StreamChunk(content=f"Error: {error_msg}")


  class VertexEmbeddingModel(LLMModelAbstract):
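Streaming contract after the hunk above: text arrives as content-only chunks as it streams in, tool calls arrive in a separate chunk with `content=None` (per the code comment, typically the final one), and `finish_reason` is no longer set on chunks. A consumption sketch under those assumptions:

```python
async def run_stream(model: VertexAIModel, request: GenerateRequest) -> None:
    pending = []
    async for chunk in model.generate_stream(request):
        if chunk.content:
            print(chunk.content, end="", flush=True)  # incremental text
        if chunk.tool_calls:
            pending.extend(chunk.tool_calls)          # usually the last chunk
    for call in pending:
        print(f"\ntool requested: {call.function.name}({call.function.arguments})")
```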
@@ -428,6 +688,10 @@ class VertexEmbeddingModel(LLMModelAbstract):
      def model_name(self) -> str:
          return self._model_name

+     @model_name.setter
+     def model_name(self, model_name: str) -> None:
+         self._model_name = model_name
+
      @property
      def capabilities(self) -> ModelCapability:
          return ModelCapability.EMBEDDINGS
@@ -475,4 +739,8 @@ class VertexEmbeddingModel(LLMModelAbstract):
          return EmbeddingResponse(
              embeddings=all_embeddings,
              usage=None,
+             metadata={
+                 "dimensions": len(all_embeddings[0]) if len(all_embeddings) > 0 else 0,
+                 "batch_size": self._batch_size,
+             },
          )
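The embedding response now carries a `metadata` dict with the vector width and the configured batch size. A reading sketch; the `embed` entry-point name is an assumption, since this diff only shows the response construction:

```python
async def report(model: VertexEmbeddingModel, texts: list[str]) -> None:
    resp = await model.embed(texts)  # assumed method name; not shown in this diff
    print(
        f"{len(resp.embeddings)} vectors, "
        f"{resp.metadata['dimensions']} dims, "
        f"batch_size={resp.metadata['batch_size']}"
    )
```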