letta-nightly 0.11.7.dev20250911104039__py3-none-any.whl → 0.11.7.dev20250913103940__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -149,7 +149,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
  request_json=self.request_data,
  response_json={
  "content": {
- "tool_call": self.tool_call.model_dump_json(),
+ "tool_call": self.tool_call.model_dump_json() if self.tool_call else None,
  "reasoning": [content.model_dump_json() for content in self.reasoning_content],
  },
  "id": self.interface.message_id,
@@ -19,7 +19,7 @@ from letta.agents.helpers import (
  generate_step_id,
  )
  from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX
- from letta.errors import ContextWindowExceededError
+ from letta.errors import ContextWindowExceededError, LLMError
  from letta.helpers import ToolRulesSolver
  from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns, ns_to_ms
  from letta.helpers.reasoning_helper import scrub_inner_thoughts_from_messages
@@ -213,8 +213,17 @@ class LettaAgentV2(BaseAgentV2):

  if self.stop_reason is None:
  self.stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
- self._request_checkpoint_finish(request_span=request_span, request_start_timestamp_ns=request_start_timestamp_ns)
- return LettaResponse(messages=response_letta_messages, stop_reason=self.stop_reason, usage=self.usage)
+
+ result = LettaResponse(messages=response_letta_messages, stop_reason=self.stop_reason, usage=self.usage)
+ if run_id:
+ if self.job_update_metadata is None:
+ self.job_update_metadata = {}
+ self.job_update_metadata["result"] = result.model_dump(mode="json")
+
+ await self._request_checkpoint_finish(
+ request_span=request_span, request_start_timestamp_ns=request_start_timestamp_ns, run_id=run_id
+ )
+ return result

  @trace_method
  async def stream(
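
Editor's note: both the blocking path above and the streaming path below now stash the final response on the job metadata before finishing the request checkpoint. A minimal, self-contained sketch of that serialization step follows; FakeResponse and attach_result are illustrative stand-ins, not letta APIs.

    # Sketch only: FakeResponse stands in for LettaResponse; attach_result mirrors
    # the metadata bookkeeping done right before _request_checkpoint_finish.
    from pydantic import BaseModel


    class FakeResponse(BaseModel):
        messages: list[str]
        stop_reason: str


    def attach_result(job_update_metadata: dict | None, result: FakeResponse) -> dict:
        metadata = job_update_metadata if job_update_metadata is not None else {}
        # mode="json" converts nested models, enums, and datetimes to JSON-safe values
        metadata["result"] = result.model_dump(mode="json")
        return metadata


    print(attach_result(None, FakeResponse(messages=["hi"], stop_reason="end_turn")))
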
@@ -297,11 +306,24 @@ class LettaAgentV2(BaseAgentV2):
  )

  except:
- if self.stop_reason:
+ if self.stop_reason and not first_chunk:
  yield f"data: {self.stop_reason.model_dump_json()}\n\n"
  raise

- self._request_checkpoint_finish(request_span=request_span, request_start_timestamp_ns=request_start_timestamp_ns)
+ if run_id:
+ letta_messages = Message.to_letta_messages_from_list(
+ self.response_messages,
+ use_assistant_message=use_assistant_message,
+ reverse=False,
+ )
+ result = LettaResponse(messages=letta_messages, stop_reason=self.stop_reason, usage=self.usage)
+ if self.job_update_metadata is None:
+ self.job_update_metadata = {}
+ self.job_update_metadata["result"] = result.model_dump(mode="json")
+
+ await self._request_checkpoint_finish(
+ request_span=request_span, request_start_timestamp_ns=request_start_timestamp_ns, run_id=run_id
+ )
  for finish_chunk in self.get_finish_chunks_for_stream(self.usage, self.stop_reason):
  yield f"data: {finish_chunk}\n\n"

@@ -409,6 +431,9 @@ class LettaAgentV2(BaseAgentV2):
  except ValueError as e:
  self.stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_llm_response.value)
  raise e
+ except LLMError as e:
+ self.stop_reason = LettaStopReason(stop_reason=StopReasonType.llm_api_error.value)
+ raise e
  except Exception as e:
  if isinstance(e, ContextWindowExceededError) and llm_request_attempt < summarizer_settings.max_summarizer_retries:
  # Retry case
@@ -475,6 +500,17 @@ class LettaAgentV2(BaseAgentV2):
  if include_return_message_types is None or message.message_type in include_return_message_types:
  yield message

+ # Persist approval responses immediately to prevent agent from getting into a bad state
+ if (
+ len(input_messages_to_persist) == 1
+ and input_messages_to_persist[0].role == "approval"
+ and persisted_messages[0].role == "approval"
+ and persisted_messages[1].role == "tool"
+ ):
+ self.agent_state.message_ids = self.agent_state.message_ids + [m.id for m in persisted_messages[:2]]
+ await self.agent_manager.update_message_ids_async(
+ agent_id=self.agent_state.id, message_ids=self.agent_state.message_ids, actor=self.actor
+ )
  step_progression, step_metrics = await self._step_checkpoint_finish(step_metrics, agent_step_span, logged_step)
  except Exception as e:
  self.logger.error(f"Error during step processing: {e}")
@@ -489,6 +525,7 @@ class LettaAgentV2(BaseAgentV2):
  StopReasonType.no_tool_call,
  StopReasonType.invalid_tool_call,
  StopReasonType.invalid_llm_response,
+ StopReasonType.llm_api_error,
  ):
  self.logger.error("Error occurred during step processing, with unexpected stop reason: %s", self.stop_reason.stop_reason)
  raise e
@@ -736,11 +773,10 @@ class LettaAgentV2(BaseAgentV2):
  return None

  @trace_method
- def _request_checkpoint_finish(self, request_span: Span | None, request_start_timestamp_ns: int | None) -> None:
- if request_span is not None:
- duration_ns = get_utc_timestamp_ns() - request_start_timestamp_ns
- request_span.add_event(name="letta_request_ms", attributes={"duration_ms": ns_to_ms(duration_ns)})
- request_span.end()
+ async def _request_checkpoint_finish(
+ self, request_span: Span | None, request_start_timestamp_ns: int | None, run_id: str | None
+ ) -> None:
+ await self._log_request(request_start_timestamp_ns, request_span, self.job_update_metadata, is_error=False, run_id=run_id)
  return None

  @trace_method
@@ -62,11 +62,18 @@ class TurbopufferClient:
  """
  from letta.llm_api.llm_client import LLMClient

+ # filter out empty strings after stripping
+ filtered_texts = [text for text in texts if text.strip()]
+
+ # skip embedding if no valid texts
+ if not filtered_texts:
+ return []
+
  embedding_client = LLMClient.create(
  provider_type=self.default_embedding_config.embedding_endpoint_type,
  actor=actor,
  )
- embeddings = await embedding_client.request_embeddings(texts, self.default_embedding_config)
+ embeddings = await embedding_client.request_embeddings(filtered_texts, self.default_embedding_config)
  return embeddings

  @trace_method
@@ -119,8 +126,16 @@ class TurbopufferClient:
  """
  from turbopuffer import AsyncTurbopuffer

+ # filter out empty text chunks
+ filtered_chunks = [(i, text) for i, text in enumerate(text_chunks) if text.strip()]
+
+ if not filtered_chunks:
+ logger.warning("All text chunks were empty, skipping insertion")
+ return []
+
  # generate embeddings using the default config
- embeddings = await self._generate_embeddings(text_chunks, actor)
+ filtered_texts = [text for _, text in filtered_chunks]
+ embeddings = await self._generate_embeddings(filtered_texts, actor)

  namespace_name = await self._get_archive_namespace_name(archive_id)

@@ -152,8 +167,8 @@ class TurbopufferClient:
  tags_arrays = [] # Store tags as arrays
  passages = []

- for idx, (text, embedding) in enumerate(zip(text_chunks, embeddings)):
- passage_id = passage_ids[idx]
+ for (original_idx, text), embedding in zip(filtered_chunks, embeddings):
+ passage_id = passage_ids[original_idx]

  # append to columns
  ids.append(passage_id)
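
Editor's note: a small sketch of the index-preserving filter pattern used in the TurbopufferClient hunks above, with illustrative data; the synchronous embed helper stands in for the async embedding call.

    # Drop empty chunks before embedding, but remember each survivor's original
    # index so parallel arrays (here, hypothetical passage_ids) stay aligned.
    def embed(texts: list[str]) -> list[list[float]]:
        return [[float(len(t))] for t in texts]  # stand-in for request_embeddings


    text_chunks = ["first chunk", "", "   ", "second chunk"]
    passage_ids = ["passage-0", "passage-1", "passage-2", "passage-3"]

    filtered_chunks = [(i, text) for i, text in enumerate(text_chunks) if text.strip()]
    embeddings = embed([text for _, text in filtered_chunks])

    for (original_idx, text), embedding in zip(filtered_chunks, embeddings):
        passage_id = passage_ids[original_idx]  # still points at the right passage
        print(passage_id, repr(text), embedding)
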
@@ -240,8 +255,16 @@ class TurbopufferClient:
  """
  from turbopuffer import AsyncTurbopuffer

+ # filter out empty message texts
+ filtered_messages = [(i, text) for i, text in enumerate(message_texts) if text.strip()]
+
+ if not filtered_messages:
+ logger.warning("All message texts were empty, skipping insertion")
+ return True
+
  # generate embeddings using the default config
- embeddings = await self._generate_embeddings(message_texts, actor)
+ filtered_texts = [text for _, text in filtered_messages]
+ embeddings = await self._generate_embeddings(filtered_texts, actor)

  namespace_name = await self._get_message_namespace_name(organization_id)

@@ -266,8 +289,10 @@ class TurbopufferClient:
  project_ids = []
  template_ids = []

- for idx, (text, embedding, role, created_at) in enumerate(zip(message_texts, embeddings, roles, created_ats)):
- message_id = message_ids[idx]
+ for (original_idx, text), embedding in zip(filtered_messages, embeddings):
+ message_id = message_ids[original_idx]
+ role = roles[original_idx]
+ created_at = created_ats[original_idx]

  # ensure the provided timestamp is timezone-aware and in UTC
  if created_at.tzinfo is None:
@@ -1162,8 +1187,15 @@ class TurbopufferClient:
  if not text_chunks:
  return []

+ # filter out empty text chunks
+ filtered_chunks = [text for text in text_chunks if text.strip()]
+
+ if not filtered_chunks:
+ logger.warning("All text chunks were empty, skipping file passage insertion")
+ return []
+
  # generate embeddings using the default config
- embeddings = await self._generate_embeddings(text_chunks, actor)
+ embeddings = await self._generate_embeddings(filtered_chunks, actor)

  namespace_name = await self._get_file_passages_namespace_name(organization_id)

@@ -1189,7 +1221,7 @@ class TurbopufferClient:
  created_ats = []
  passages = []

- for idx, (text, embedding) in enumerate(zip(text_chunks, embeddings)):
+ for text, embedding in zip(filtered_chunks, embeddings):
  passage = PydanticPassage(
  text=text,
  file_id=file_id,
@@ -24,7 +24,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
  from letta.schemas.message import Message
  from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
  from letta.server.rest_api.json_parser import OptimisticJSONParser
- from letta.streaming_utils import JSONInnerThoughtsExtractor
+ from letta.streaming_utils import FunctionArgumentsStreamHandler, JSONInnerThoughtsExtractor
  from letta.utils import count_tokens

  logger = get_logger(__name__)
@@ -53,6 +53,8 @@ class OpenAIStreamingInterface:

  self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
  self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=put_inner_thoughts_in_kwarg)
+ # Reader that extracts only the assistant message value from send_message args
+ self.assistant_message_json_reader = FunctionArgumentsStreamHandler(json_key=self.assistant_message_tool_kwarg)
  self.function_name_buffer = None
  self.function_args_buffer = None
  self.function_id_buffer = None
@@ -274,6 +276,10 @@ class OpenAIStreamingInterface:
  # Store the ID of the tool call so allow skipping the corresponding response
  if self.function_id_buffer:
  self.prev_assistant_message_id = self.function_id_buffer
+ # Reset message reader at the start of a new send_message stream
+ self.assistant_message_json_reader.reset()
+ self.assistant_message_json_reader.in_message = True
+ self.assistant_message_json_reader.message_started = True

  else:
  if prev_message_type and prev_message_type != "tool_call_message":
@@ -328,39 +334,15 @@ class OpenAIStreamingInterface:
  self.last_flushed_function_name is not None
  and self.last_flushed_function_name == self.assistant_message_tool_name
  ):
- # do an additional parse on the updates_main_json
- if self.function_args_buffer:
- updates_main_json = self.function_args_buffer + updates_main_json
- self.function_args_buffer = None
-
- # Pretty gross hardcoding that assumes that if we're toggling into the keywords, we have the full prefix
- match_str = '{"' + self.assistant_message_tool_kwarg + '":"'
- if updates_main_json == match_str:
- updates_main_json = None
-
- else:
- # Some hardcoding to strip off the trailing "}"
- if updates_main_json in ["}", '"}']:
- updates_main_json = None
- if updates_main_json and len(updates_main_json) > 0 and updates_main_json[-1:] == '"':
- updates_main_json = updates_main_json[:-1]
-
- if not updates_main_json:
- # early exit to turn into content mode
- pass
-
- # There may be a buffer from a previous chunk, for example
- # if the previous chunk had arguments but we needed to flush name
- if self.function_args_buffer:
- # In this case, we should release the buffer + new data at once
- combined_chunk = self.function_args_buffer + updates_main_json
-
+ # Minimal, robust extraction: only emit the value of "message"
+ extracted = self.assistant_message_json_reader.process_json_chunk(tool_call.function.arguments)
+ if extracted:
  if prev_message_type and prev_message_type != "assistant_message":
  message_index += 1
  assistant_message = AssistantMessage(
  id=self.letta_message_id,
  date=datetime.now(timezone.utc),
- content=combined_chunk,
+ content=extracted,
  otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
  )
  prev_message_type = assistant_message.message_type
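
Editor's note: the hunk above replaces the hand-rolled prefix/suffix stripping with letta's FunctionArgumentsStreamHandler. The sketch below is not that class; it is an illustrative re-implementation of the underlying idea, assuming the "message" value only grows as chunks arrive: re-parse the accumulated send_message arguments on every chunk and emit only the newly arrived portion of the value.

    import json


    class MessageValueExtractor:
        """Illustrative only: emits newly arrived text of one key from streamed JSON args."""

        def __init__(self, json_key: str = "message") -> None:
            self.json_key = json_key
            self.buffer = ""
            self.emitted = ""

        def process_json_chunk(self, chunk: str) -> str:
            self.buffer += chunk
            value = self.emitted
            # try to close the partial JSON so the value parses at every step
            for suffix in ("", '"}', "}"):
                try:
                    value = json.loads(self.buffer + suffix).get(self.json_key, "")
                    break
                except json.JSONDecodeError:
                    continue
            new_text = value[len(self.emitted):]
            self.emitted = value
            return new_text


    reader = MessageValueExtractor()
    for chunk in ['{"mess', 'age": "Hel', 'lo there"', "}"]:
        piece = reader.process_json_chunk(chunk)
        if piece:
            print(piece, end="")  # streams "Hello there" incrementally
    print()
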
@@ -368,51 +350,6 @@ class OpenAIStreamingInterface:
  # Store the ID of the tool call so allow skipping the corresponding response
  if self.function_id_buffer:
  self.prev_assistant_message_id = self.function_id_buffer
- # clear buffer
- self.function_args_buffer = None
- self.function_id_buffer = None
-
- else:
- # If there's no buffer to clear, just output a new chunk with new data
- # TODO: THIS IS HORRIBLE
- # TODO: WE USE THE OLD JSON PARSER EARLIER (WHICH DOES NOTHING) AND NOW THE NEW JSON PARSER
- # TODO: THIS IS TOTALLY WRONG AND BAD, BUT SAVING FOR A LARGER REWRITE IN THE NEAR FUTURE
- parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments)
-
- if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get(
- self.assistant_message_tool_kwarg
- ) != self.current_json_parse_result.get(self.assistant_message_tool_kwarg):
- new_content = parsed_args.get(self.assistant_message_tool_kwarg)
- prev_content = self.current_json_parse_result.get(self.assistant_message_tool_kwarg, "")
- # TODO: Assumes consistent state and that prev_content is subset of new_content
- diff = new_content.replace(prev_content, "", 1)
-
- # quick patch to mitigate double message streaming error
- # TODO: root cause this issue and remove patch
- if diff != "" and "\\n" not in new_content:
- converted_new_content = new_content.replace("\n", "\\n")
- converted_content_diff = converted_new_content.replace(prev_content, "", 1)
- if converted_content_diff == "":
- diff = converted_content_diff
-
- self.current_json_parse_result = parsed_args
- if prev_message_type and prev_message_type != "assistant_message":
- message_index += 1
- assistant_message = AssistantMessage(
- id=self.letta_message_id,
- date=datetime.now(timezone.utc),
- content=diff,
- # name=name,
- otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
- )
- prev_message_type = assistant_message.message_type
- yield assistant_message
-
- # Store the ID of the tool call so allow skipping the corresponding response
- if self.function_id_buffer:
- self.prev_assistant_message_id = self.function_id_buffer
- # clear buffers
- self.function_id_buffer = None
  else:
  # There may be a buffer from a previous chunk, for example
  # if the previous chunk had arguments but we needed to flush name
@@ -497,7 +497,7 @@ class AnthropicClient(LLMClientBase):
  try:
  args_json = json.loads(arguments)
  if not isinstance(args_json, dict):
- raise ValueError("Expected parseable json object for arguments")
+ raise LLMServerError("Expected parseable json object for arguments")
  except:
  arguments = str(tool_input["function"]["arguments"])
  else:
@@ -854,7 +854,7 @@ def remap_finish_reason(stop_reason: str) -> str:
  elif stop_reason == "tool_use":
  return "function_call"
  else:
- raise ValueError(f"Unexpected stop_reason: {stop_reason}")
+ raise LLMServerError(f"Unexpected stop_reason: {stop_reason}")


  def strip_xml_tags(string: str, tag: Optional[str]) -> str:
@@ -54,9 +54,12 @@ class AzureClient(OpenAIClient):
  api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
  base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
  api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
+ try:
+ client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
+ response: ChatCompletion = await client.chat.completions.create(**request_data)
+ except Exception as e:
+ raise self.handle_llm_error(e)

- client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
- response: ChatCompletion = await client.chat.completions.create(**request_data)
  return response.model_dump()

  @trace_method
@@ -14,6 +14,19 @@ from google.genai.types import (
  )

  from letta.constants import NON_USER_MSG_PREFIX
+ from letta.errors import (
+ ContextWindowExceededError,
+ ErrorCode,
+ LLMAuthenticationError,
+ LLMBadRequestError,
+ LLMConnectionError,
+ LLMNotFoundError,
+ LLMPermissionDeniedError,
+ LLMRateLimitError,
+ LLMServerError,
+ LLMTimeoutError,
+ LLMUnprocessableEntityError,
+ )
  from letta.helpers.datetime_helpers import get_utc_time_int
  from letta.helpers.json_helpers import json_dumps, json_loads
  from letta.llm_api.llm_client_base import LLMClientBase
@@ -48,13 +61,16 @@ class GoogleVertexClient(LLMClientBase):
  """
  Performs underlying request to llm and returns raw response.
  """
- client = self._get_client()
- response = client.models.generate_content(
- model=llm_config.model,
- contents=request_data["contents"],
- config=request_data["config"],
- )
- return response.model_dump()
+ try:
+ client = self._get_client()
+ response = client.models.generate_content(
+ model=llm_config.model,
+ contents=request_data["contents"],
+ config=request_data["config"],
+ )
+ return response.model_dump()
+ except Exception as e:
+ raise self.handle_llm_error(e)

  @trace_method
  async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
@@ -67,6 +83,7 @@ class GoogleVertexClient(LLMClientBase):
  # https://github.com/googleapis/python-aiplatform/issues/4472
  retry_count = 1
  should_retry = True
+ response_data = None
  while should_retry and retry_count <= self.MAX_RETRIES:
  try:
  response = await client.aio.models.generate_content(
@@ -76,13 +93,15 @@ class GoogleVertexClient(LLMClientBase):
  )
  except errors.APIError as e:
  # Retry on 503 and 500 errors as well, usually ephemeral from Gemini
- if e.code == 503 or e.code == 500:
+ if e.code == 503 or e.code == 500 or e.code == 504:
  logger.warning(f"Received {e}, retrying {retry_count}/{self.MAX_RETRIES}")
  retry_count += 1
+ if retry_count > self.MAX_RETRIES:
+ raise self.handle_llm_error(e)
  continue
- raise e
+ raise self.handle_llm_error(e)
  except Exception as e:
- raise e
+ raise self.handle_llm_error(e)
  response_data = response.model_dump()
  is_malformed_function_call = self.is_malformed_function_call(response_data)
  if is_malformed_function_call:
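
Editor's note: the retry change above now also covers 504s and raises a translated error once the retry budget is spent, instead of falling out of the loop with no response. A condensed, self-contained sketch of that shape; the names are illustrative, not the Google Vertex client's own.

    MAX_RETRIES = 3


    class TransientAPIError(Exception):
        def __init__(self, code: int) -> None:
            super().__init__(f"status {code}")
            self.code = code


    def call_with_retries(call):
        retry_count = 1
        while retry_count <= MAX_RETRIES:
            try:
                return call()
            except TransientAPIError as e:
                if e.code in (500, 503, 504):
                    retry_count += 1
                    if retry_count > MAX_RETRIES:
                        # out of retries: surface a translated, typed error
                        raise RuntimeError("LLM API error after retries") from e
                    continue
                raise  # non-transient codes are not retried
        raise RuntimeError("unreachable")
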
@@ -114,6 +133,8 @@ class GoogleVertexClient(LLMClientBase):
  should_retry = is_malformed_function_call
  retry_count += 1

+ if response_data is None:
+ raise RuntimeError("Failed to get response data after all retries")
  return response_data

  @staticmethod
@@ -358,11 +379,10 @@ class GoogleVertexClient(LLMClientBase):

  if content is None or content.role is None or content.parts is None:
  # This means the response is malformed like MALFORMED_FUNCTION_CALL
- # NOTE: must be a ValueError to trigger a retry
  if candidate.finish_reason == "MALFORMED_FUNCTION_CALL":
- raise ValueError(f"Error in response data from LLM: {candidate.finish_reason}")
+ raise LLMServerError(f"Malformed response from Google Vertex: {candidate.finish_reason}")
  else:
- raise ValueError(f"Error in response data from LLM: {candidate.model_dump()}")
+ raise LLMServerError(f"Invalid response data from Google Vertex: {candidate.model_dump()}")

  role = content.role
  assert role == "model", f"Unknown role in response: {role}"
@@ -456,7 +476,7 @@ class GoogleVertexClient(LLMClientBase):

  except json.decoder.JSONDecodeError:
  if candidate.finish_reason == "MAX_TOKENS":
- raise ValueError("Could not parse response data from LLM: exceeded max token limit")
+ raise LLMServerError("Could not parse response data from LLM: exceeded max token limit")
  # Inner thoughts are the content by default
  inner_thoughts = response_message.text

@@ -485,7 +505,7 @@ class GoogleVertexClient(LLMClientBase):
  elif finish_reason == "RECITATION":
  openai_finish_reason = "content_filter"
  else:
- raise ValueError(f"Unrecognized finish reason in Google AI response: {finish_reason}")
+ raise LLMServerError(f"Unrecognized finish reason in Google AI response: {finish_reason}")

  choices.append(
  Choice(
@@ -576,5 +596,127 @@ class GoogleVertexClient(LLMClientBase):

  @trace_method
  def handle_llm_error(self, e: Exception) -> Exception:
- # Fallback to base implementation
+ # Handle Google GenAI specific errors
+ if isinstance(e, errors.ClientError):
+ logger.warning(f"[Google Vertex] Client error ({e.code}): {e}")
+
+ # Handle specific error codes
+ if e.code == 400:
+ error_str = str(e).lower()
+ if "context" in error_str and ("exceed" in error_str or "limit" in error_str or "too long" in error_str):
+ return ContextWindowExceededError(
+ message=f"Bad request to Google Vertex (context window exceeded): {str(e)}",
+ )
+ else:
+ return LLMBadRequestError(
+ message=f"Bad request to Google Vertex: {str(e)}",
+ code=ErrorCode.INTERNAL_SERVER_ERROR,
+ )
+ elif e.code == 401:
+ return LLMAuthenticationError(
+ message=f"Authentication failed with Google Vertex: {str(e)}",
+ code=ErrorCode.INTERNAL_SERVER_ERROR,
+ )
+ elif e.code == 403:
+ return LLMPermissionDeniedError(
+ message=f"Permission denied by Google Vertex: {str(e)}",
+ code=ErrorCode.INTERNAL_SERVER_ERROR,
+ )
+ elif e.code == 404:
+ return LLMNotFoundError(
+ message=f"Resource not found in Google Vertex: {str(e)}",
+ code=ErrorCode.INTERNAL_SERVER_ERROR,
+ )
+ elif e.code == 408:
+ return LLMTimeoutError(
+ message=f"Request to Google Vertex timed out: {str(e)}",
+ code=ErrorCode.TIMEOUT,
+ details={"cause": str(e.__cause__) if e.__cause__ else None},
+ )
+ elif e.code == 422:
+ return LLMUnprocessableEntityError(
+ message=f"Invalid request content for Google Vertex: {str(e)}",
+ code=ErrorCode.INTERNAL_SERVER_ERROR,
+ )
+ elif e.code == 429:
+ logger.warning("[Google Vertex] Rate limited (429). Consider backoff.")
+ return LLMRateLimitError(
+ message=f"Rate limited by Google Vertex: {str(e)}",
+ code=ErrorCode.RATE_LIMIT_EXCEEDED,
+ )
+ else:
+ return LLMServerError(
+ message=f"Google Vertex client error: {str(e)}",
+ code=ErrorCode.INTERNAL_SERVER_ERROR,
+ details={
+ "status_code": e.code,
+ "response_json": getattr(e, "response_json", None),
+ },
+ )
+
+ if isinstance(e, errors.ServerError):
+ logger.warning(f"[Google Vertex] Server error ({e.code}): {e}")
+
+ # Handle specific server error codes
+ if e.code == 500:
+ return LLMServerError(
+ message=f"Google Vertex internal server error: {str(e)}",
+ code=ErrorCode.INTERNAL_SERVER_ERROR,
+ details={
+ "status_code": e.code,
+ "response_json": getattr(e, "response_json", None),
+ },
+ )
+ elif e.code == 502:
+ return LLMConnectionError(
+ message=f"Bad gateway from Google Vertex: {str(e)}",
+ code=ErrorCode.INTERNAL_SERVER_ERROR,
+ details={"cause": str(e.__cause__) if e.__cause__ else None},
+ )
+ elif e.code == 503:
+ return LLMServerError(
+ message=f"Google Vertex service unavailable: {str(e)}",
+ code=ErrorCode.INTERNAL_SERVER_ERROR,
+ details={
+ "status_code": e.code,
+ "response_json": getattr(e, "response_json", None),
+ },
+ )
+ elif e.code == 504:
+ return LLMTimeoutError(
+ message=f"Gateway timeout from Google Vertex: {str(e)}",
+ code=ErrorCode.TIMEOUT,
+ details={"cause": str(e.__cause__) if e.__cause__ else None},
+ )
+ else:
+ return LLMServerError(
+ message=f"Google Vertex server error: {str(e)}",
+ code=ErrorCode.INTERNAL_SERVER_ERROR,
+ details={
+ "status_code": e.code,
+ "response_json": getattr(e, "response_json", None),
+ },
+ )
+
+ if isinstance(e, errors.APIError):
+ logger.warning(f"[Google Vertex] API error ({e.code}): {e}")
+ return LLMServerError(
+ message=f"Google Vertex API error: {str(e)}",
+ code=ErrorCode.INTERNAL_SERVER_ERROR,
+ details={
+ "status_code": e.code,
+ "response_json": getattr(e, "response_json", None),
+ },
+ )
+
+ # Handle connection-related errors
+ if "connection" in str(e).lower() or "timeout" in str(e).lower():
+ logger.warning(f"[Google Vertex] Connection/timeout error: {e}")
+ return LLMConnectionError(
+ message=f"Failed to connect to Google Vertex: {str(e)}",
+ code=ErrorCode.INTERNAL_SERVER_ERROR,
+ details={"cause": str(e.__cause__) if e.__cause__ else None},
+ )
+
+ # Fallback to base implementation for other errors
  return super().handle_llm_error(e)
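
Editor's note: taken together, these hunks move provider failures behind one typed hierarchy: handle_llm_error translates google.genai errors into LLM* exceptions, and the agent loop earlier in the diff records StopReasonType.llm_api_error when it catches LLMError. A compact, self-contained sketch of that pattern follows; the classes are stand-ins named after the imports above, not letta's actual definitions.

    class LLMError(Exception): ...
    class LLMRateLimitError(LLMError): ...
    class LLMAuthenticationError(LLMError): ...
    class LLMServerError(LLMError): ...


    def translate(status_code: int, message: str) -> LLMError:
        # boundary translation: one place maps provider status codes to typed errors
        if status_code == 429:
            return LLMRateLimitError(message)
        if status_code in (401, 403):
            return LLMAuthenticationError(message)
        return LLMServerError(message)


    try:
        raise translate(429, "quota exceeded")
    except LLMRateLimitError:
        print("rate limited: back off and retry")
    except LLMError as e:
        print(f"stop reason: llm_api_error ({e})")
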