letta-nightly 0.11.7.dev20250911104039__py3-none-any.whl → 0.11.7.dev20250913103940__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- letta/adapters/letta_llm_stream_adapter.py +1 -1
- letta/agents/letta_agent_v2.py +46 -10
- letta/helpers/tpuf_client.py +41 -9
- letta/interfaces/openai_streaming_interface.py +11 -74
- letta/llm_api/anthropic_client.py +2 -2
- letta/llm_api/azure_client.py +5 -2
- letta/llm_api/google_vertex_client.py +158 -16
- letta/llm_api/openai_client.py +14 -11
- letta/orm/job.py +5 -1
- letta/orm/organization.py +2 -0
- letta/otel/sqlalchemy_instrumentation.py +6 -1
- letta/schemas/letta_stop_reason.py +2 -0
- letta/server/rest_api/app.py +61 -1
- letta/server/rest_api/redis_stream_manager.py +15 -2
- letta/server/rest_api/routers/v1/agents.py +53 -15
- letta/server/rest_api/routers/v1/tools.py +23 -39
- letta/services/job_manager.py +15 -3
- letta/services/mcp_manager.py +64 -3
- letta/services/tool_executor/files_tool_executor.py +2 -2
- {letta_nightly-0.11.7.dev20250911104039.dist-info → letta_nightly-0.11.7.dev20250913103940.dist-info}/METADATA +3 -3
- {letta_nightly-0.11.7.dev20250911104039.dist-info → letta_nightly-0.11.7.dev20250913103940.dist-info}/RECORD +24 -24
- {letta_nightly-0.11.7.dev20250911104039.dist-info → letta_nightly-0.11.7.dev20250913103940.dist-info}/WHEEL +0 -0
- {letta_nightly-0.11.7.dev20250911104039.dist-info → letta_nightly-0.11.7.dev20250913103940.dist-info}/entry_points.txt +0 -0
- {letta_nightly-0.11.7.dev20250911104039.dist-info → letta_nightly-0.11.7.dev20250913103940.dist-info}/licenses/LICENSE +0 -0
letta/adapters/letta_llm_stream_adapter.py
CHANGED
@@ -149,7 +149,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
                 request_json=self.request_data,
                 response_json={
                     "content": {
-                        "tool_call": self.tool_call.model_dump_json(),
+                        "tool_call": self.tool_call.model_dump_json() if self.tool_call else None,
                         "reasoning": [content.model_dump_json() for content in self.reasoning_content],
                     },
                     "id": self.interface.message_id,
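Note: this change guards the request log against steps where no tool call was produced. A minimal sketch of the same pattern, assuming a Pydantic-style model with model_dump_json(); the names below are illustrative stand-ins, not the adapter's actual fields:

    from typing import Optional

    from pydantic import BaseModel

    class ToolCall(BaseModel):
        name: str
        arguments: str

    def build_log_payload(tool_call: Optional[ToolCall]) -> dict:
        # Serialize the tool call only when one exists; otherwise log None
        # instead of raising AttributeError on None.model_dump_json().
        return {"tool_call": tool_call.model_dump_json() if tool_call else None}
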
letta/agents/letta_agent_v2.py
CHANGED
@@ -19,7 +19,7 @@ from letta.agents.helpers import (
     generate_step_id,
 )
 from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX
-from letta.errors import ContextWindowExceededError
+from letta.errors import ContextWindowExceededError, LLMError
 from letta.helpers import ToolRulesSolver
 from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns, ns_to_ms
 from letta.helpers.reasoning_helper import scrub_inner_thoughts_from_messages
@@ -213,8 +213,17 @@ class LettaAgentV2(BaseAgentV2):

         if self.stop_reason is None:
             self.stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
-
-
+
+        result = LettaResponse(messages=response_letta_messages, stop_reason=self.stop_reason, usage=self.usage)
+        if run_id:
+            if self.job_update_metadata is None:
+                self.job_update_metadata = {}
+            self.job_update_metadata["result"] = result.model_dump(mode="json")
+
+        await self._request_checkpoint_finish(
+            request_span=request_span, request_start_timestamp_ns=request_start_timestamp_ns, run_id=run_id
+        )
+        return result

     @trace_method
     async def stream(
@@ -297,11 +306,24 @@ class LettaAgentV2(BaseAgentV2):
             )

         except:
-            if self.stop_reason:
+            if self.stop_reason and not first_chunk:
                 yield f"data: {self.stop_reason.model_dump_json()}\n\n"
             raise

-
+        if run_id:
+            letta_messages = Message.to_letta_messages_from_list(
+                self.response_messages,
+                use_assistant_message=use_assistant_message,
+                reverse=False,
+            )
+            result = LettaResponse(messages=letta_messages, stop_reason=self.stop_reason, usage=self.usage)
+            if self.job_update_metadata is None:
+                self.job_update_metadata = {}
+            self.job_update_metadata["result"] = result.model_dump(mode="json")
+
+        await self._request_checkpoint_finish(
+            request_span=request_span, request_start_timestamp_ns=request_start_timestamp_ns, run_id=run_id
+        )
         for finish_chunk in self.get_finish_chunks_for_stream(self.usage, self.stop_reason):
             yield f"data: {finish_chunk}\n\n"

@@ -409,6 +431,9 @@ class LettaAgentV2(BaseAgentV2):
                 except ValueError as e:
                     self.stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_llm_response.value)
                     raise e
+                except LLMError as e:
+                    self.stop_reason = LettaStopReason(stop_reason=StopReasonType.llm_api_error.value)
+                    raise e
                 except Exception as e:
                     if isinstance(e, ContextWindowExceededError) and llm_request_attempt < summarizer_settings.max_summarizer_retries:
                         # Retry case
@@ -475,6 +500,17 @@ class LettaAgentV2(BaseAgentV2):
                 if include_return_message_types is None or message.message_type in include_return_message_types:
                     yield message

+            # Persist approval responses immediately to prevent agent from getting into a bad state
+            if (
+                len(input_messages_to_persist) == 1
+                and input_messages_to_persist[0].role == "approval"
+                and persisted_messages[0].role == "approval"
+                and persisted_messages[1].role == "tool"
+            ):
+                self.agent_state.message_ids = self.agent_state.message_ids + [m.id for m in persisted_messages[:2]]
+                await self.agent_manager.update_message_ids_async(
+                    agent_id=self.agent_state.id, message_ids=self.agent_state.message_ids, actor=self.actor
+                )
             step_progression, step_metrics = await self._step_checkpoint_finish(step_metrics, agent_step_span, logged_step)
         except Exception as e:
             self.logger.error(f"Error during step processing: {e}")
@@ -489,6 +525,7 @@ class LettaAgentV2(BaseAgentV2):
             StopReasonType.no_tool_call,
             StopReasonType.invalid_tool_call,
             StopReasonType.invalid_llm_response,
+            StopReasonType.llm_api_error,
         ):
             self.logger.error("Error occurred during step processing, with unexpected stop reason: %s", self.stop_reason.stop_reason)
             raise e
@@ -736,11 +773,10 @@ class LettaAgentV2(BaseAgentV2):
         return None

     @trace_method
-    def _request_checkpoint_finish(
-
-
-
-        request_span.end()
+    async def _request_checkpoint_finish(
+        self, request_span: Span | None, request_start_timestamp_ns: int | None, run_id: str | None
+    ) -> None:
+        await self._log_request(request_start_timestamp_ns, request_span, self.job_update_metadata, is_error=False, run_id=run_id)
         return None

     @trace_method
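Note: both the blocking and streaming paths in letta_agent_v2.py now build the final LettaResponse, store it under job_update_metadata["result"], and await the now-async _request_checkpoint_finish. A minimal sketch of that persistence step, using a stand-in response model rather than the real LettaResponse:

    from typing import Any, Optional

    from pydantic import BaseModel

    class StubResponse(BaseModel):
        messages: list[str]
        stop_reason: str

    def record_run_result(
        result: StubResponse,
        run_id: Optional[str],
        job_update_metadata: Optional[dict[str, Any]],
    ) -> Optional[dict[str, Any]]:
        # Only persist when the request is tied to a run; mode="json" keeps the
        # payload JSON-serializable for storage in the job's metadata.
        if run_id:
            job_update_metadata = job_update_metadata or {}
            job_update_metadata["result"] = result.model_dump(mode="json")
        return job_update_metadata
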
letta/helpers/tpuf_client.py
CHANGED
@@ -62,11 +62,18 @@ class TurbopufferClient:
         """
         from letta.llm_api.llm_client import LLMClient

+        # filter out empty strings after stripping
+        filtered_texts = [text for text in texts if text.strip()]
+
+        # skip embedding if no valid texts
+        if not filtered_texts:
+            return []
+
         embedding_client = LLMClient.create(
             provider_type=self.default_embedding_config.embedding_endpoint_type,
             actor=actor,
         )
-        embeddings = await embedding_client.request_embeddings(
+        embeddings = await embedding_client.request_embeddings(filtered_texts, self.default_embedding_config)
         return embeddings

     @trace_method
@@ -119,8 +126,16 @@ class TurbopufferClient:
         """
         from turbopuffer import AsyncTurbopuffer

+        # filter out empty text chunks
+        filtered_chunks = [(i, text) for i, text in enumerate(text_chunks) if text.strip()]
+
+        if not filtered_chunks:
+            logger.warning("All text chunks were empty, skipping insertion")
+            return []
+
         # generate embeddings using the default config
-
+        filtered_texts = [text for _, text in filtered_chunks]
+        embeddings = await self._generate_embeddings(filtered_texts, actor)

         namespace_name = await self._get_archive_namespace_name(archive_id)

@@ -152,8 +167,8 @@ class TurbopufferClient:
         tags_arrays = []  # Store tags as arrays
         passages = []

-        for
-        passage_id = passage_ids[
+        for (original_idx, text), embedding in zip(filtered_chunks, embeddings):
+            passage_id = passage_ids[original_idx]

             # append to columns
             ids.append(passage_id)
@@ -240,8 +255,16 @@ class TurbopufferClient:
         """
         from turbopuffer import AsyncTurbopuffer

+        # filter out empty message texts
+        filtered_messages = [(i, text) for i, text in enumerate(message_texts) if text.strip()]
+
+        if not filtered_messages:
+            logger.warning("All message texts were empty, skipping insertion")
+            return True
+
         # generate embeddings using the default config
-
+        filtered_texts = [text for _, text in filtered_messages]
+        embeddings = await self._generate_embeddings(filtered_texts, actor)

         namespace_name = await self._get_message_namespace_name(organization_id)

@@ -266,8 +289,10 @@ class TurbopufferClient:
         project_ids = []
         template_ids = []

-        for
-        message_id = message_ids[
+        for (original_idx, text), embedding in zip(filtered_messages, embeddings):
+            message_id = message_ids[original_idx]
+            role = roles[original_idx]
+            created_at = created_ats[original_idx]

             # ensure the provided timestamp is timezone-aware and in UTC
             if created_at.tzinfo is None:
@@ -1162,8 +1187,15 @@ class TurbopufferClient:
         if not text_chunks:
             return []

+        # filter out empty text chunks
+        filtered_chunks = [text for text in text_chunks if text.strip()]
+
+        if not filtered_chunks:
+            logger.warning("All text chunks were empty, skipping file passage insertion")
+            return []
+
         # generate embeddings using the default config
-        embeddings = await self._generate_embeddings(
+        embeddings = await self._generate_embeddings(filtered_chunks, actor)

         namespace_name = await self._get_file_passages_namespace_name(organization_id)

@@ -1189,7 +1221,7 @@ class TurbopufferClient:
         created_ats = []
         passages = []

-        for
+        for text, embedding in zip(filtered_chunks, embeddings):
             passage = PydanticPassage(
                 text=text,
                 file_id=file_id,
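Note: every tpuf_client.py change above follows the same shape: drop empty or whitespace-only texts before requesting embeddings, and keep each survivor's original index so ids, roles, and timestamps can still be looked up by position. A standalone sketch of that index-preserving filter; embed() is a hypothetical callable, not the client's API:

    from typing import Callable

    def embed_non_empty(
        texts: list[str],
        ids: list[str],
        embed: Callable[[list[str]], list[list[float]]],
    ) -> list[tuple[str, list[float]]]:
        # keep (original_index, text) pairs so parallel metadata lists stay addressable
        filtered = [(i, t) for i, t in enumerate(texts) if t.strip()]
        if not filtered:
            return []  # nothing worth embedding
        vectors = embed([t for _, t in filtered])
        # pair each surviving text's id back up with its embedding
        return [(ids[i], vec) for (i, _), vec in zip(filtered, vectors)]
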
letta/interfaces/openai_streaming_interface.py
CHANGED
@@ -24,7 +24,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
 from letta.schemas.message import Message
 from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
 from letta.server.rest_api.json_parser import OptimisticJSONParser
-from letta.streaming_utils import JSONInnerThoughtsExtractor
+from letta.streaming_utils import FunctionArgumentsStreamHandler, JSONInnerThoughtsExtractor
 from letta.utils import count_tokens

 logger = get_logger(__name__)
@@ -53,6 +53,8 @@ class OpenAIStreamingInterface:

         self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
         self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=put_inner_thoughts_in_kwarg)
+        # Reader that extracts only the assistant message value from send_message args
+        self.assistant_message_json_reader = FunctionArgumentsStreamHandler(json_key=self.assistant_message_tool_kwarg)
         self.function_name_buffer = None
         self.function_args_buffer = None
         self.function_id_buffer = None
@@ -274,6 +276,10 @@ class OpenAIStreamingInterface:
                         # Store the ID of the tool call so allow skipping the corresponding response
                         if self.function_id_buffer:
                             self.prev_assistant_message_id = self.function_id_buffer
+                        # Reset message reader at the start of a new send_message stream
+                        self.assistant_message_json_reader.reset()
+                        self.assistant_message_json_reader.in_message = True
+                        self.assistant_message_json_reader.message_started = True

                     else:
                         if prev_message_type and prev_message_type != "tool_call_message":
@@ -328,39 +334,15 @@ class OpenAIStreamingInterface:
                         self.last_flushed_function_name is not None
                         and self.last_flushed_function_name == self.assistant_message_tool_name
                     ):
-                        #
-
-
-                        self.function_args_buffer = None
-
-                        # Pretty gross hardcoding that assumes that if we're toggling into the keywords, we have the full prefix
-                        match_str = '{"' + self.assistant_message_tool_kwarg + '":"'
-                        if updates_main_json == match_str:
-                            updates_main_json = None
-
-                        else:
-                            # Some hardcoding to strip off the trailing "}"
-                            if updates_main_json in ["}", '"}']:
-                                updates_main_json = None
-                            if updates_main_json and len(updates_main_json) > 0 and updates_main_json[-1:] == '"':
-                                updates_main_json = updates_main_json[:-1]
-
-                        if not updates_main_json:
-                            # early exit to turn into content mode
-                            pass
-
-                        # There may be a buffer from a previous chunk, for example
-                        # if the previous chunk had arguments but we needed to flush name
-                        if self.function_args_buffer:
-                            # In this case, we should release the buffer + new data at once
-                            combined_chunk = self.function_args_buffer + updates_main_json
-
+                        # Minimal, robust extraction: only emit the value of "message"
+                        extracted = self.assistant_message_json_reader.process_json_chunk(tool_call.function.arguments)
+                        if extracted:
                             if prev_message_type and prev_message_type != "assistant_message":
                                 message_index += 1
                             assistant_message = AssistantMessage(
                                 id=self.letta_message_id,
                                 date=datetime.now(timezone.utc),
-                                content=
+                                content=extracted,
                                 otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                             )
                             prev_message_type = assistant_message.message_type
@@ -368,51 +350,6 @@ class OpenAIStreamingInterface:
                             # Store the ID of the tool call so allow skipping the corresponding response
                             if self.function_id_buffer:
                                 self.prev_assistant_message_id = self.function_id_buffer
-                                # clear buffer
-                                self.function_args_buffer = None
-                                self.function_id_buffer = None
-
-                        else:
-                            # If there's no buffer to clear, just output a new chunk with new data
-                            # TODO: THIS IS HORRIBLE
-                            # TODO: WE USE THE OLD JSON PARSER EARLIER (WHICH DOES NOTHING) AND NOW THE NEW JSON PARSER
-                            # TODO: THIS IS TOTALLY WRONG AND BAD, BUT SAVING FOR A LARGER REWRITE IN THE NEAR FUTURE
-                            parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments)
-
-                            if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get(
-                                self.assistant_message_tool_kwarg
-                            ) != self.current_json_parse_result.get(self.assistant_message_tool_kwarg):
-                                new_content = parsed_args.get(self.assistant_message_tool_kwarg)
-                                prev_content = self.current_json_parse_result.get(self.assistant_message_tool_kwarg, "")
-                                # TODO: Assumes consistent state and that prev_content is subset of new_content
-                                diff = new_content.replace(prev_content, "", 1)
-
-                                # quick patch to mitigate double message streaming error
-                                # TODO: root cause this issue and remove patch
-                                if diff != "" and "\\n" not in new_content:
-                                    converted_new_content = new_content.replace("\n", "\\n")
-                                    converted_content_diff = converted_new_content.replace(prev_content, "", 1)
-                                    if converted_content_diff == "":
-                                        diff = converted_content_diff
-
-                                self.current_json_parse_result = parsed_args
-                                if prev_message_type and prev_message_type != "assistant_message":
-                                    message_index += 1
-                                assistant_message = AssistantMessage(
-                                    id=self.letta_message_id,
-                                    date=datetime.now(timezone.utc),
-                                    content=diff,
-                                    # name=name,
-                                    otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
-                                )
-                                prev_message_type = assistant_message.message_type
-                                yield assistant_message
-
-                                # Store the ID of the tool call so allow skipping the corresponding response
-                                if self.function_id_buffer:
-                                    self.prev_assistant_message_id = self.function_id_buffer
-                                # clear buffers
-                                self.function_id_buffer = None
                         else:
                             # There may be a buffer from a previous chunk, for example
                             # if the previous chunk had arguments but we needed to flush name
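Note: the streaming interface now hands send_message argument fragments to FunctionArgumentsStreamHandler, which emits only the value of the configured JSON key, replacing the buffer-trimming heuristics deleted above. The toy generator below illustrates the underlying idea; it is not Letta's implementation and assumes the value contains no escaped quotes:

    def stream_key_value(fragments, json_key="message"):
        """Yield new characters of json_key's string value as JSON fragments arrive."""
        marker = f'"{json_key}":'
        buffer, started, done, emitted = "", False, False, 0
        for fragment in fragments:
            buffer += fragment
            if not started:
                pos = buffer.find(marker)
                if pos == -1:
                    continue  # key not seen yet
                quote = buffer.find('"', pos + len(marker))
                if quote == -1:
                    continue  # opening quote not seen yet
                started, emitted = True, quote + 1
            if started and not done:
                end = buffer.find('"', emitted)
                stop = end if end != -1 else len(buffer)
                if stop > emitted:
                    yield buffer[emitted:stop]
                    emitted = stop
                if end != -1:
                    done = True

    # e.g. "".join(stream_key_value(['{"mess', 'age":"Hel', 'lo"}'])) == "Hello"
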
letta/llm_api/anthropic_client.py
CHANGED
@@ -497,7 +497,7 @@ class AnthropicClient(LLMClientBase):
                 try:
                     args_json = json.loads(arguments)
                     if not isinstance(args_json, dict):
-                        raise
+                        raise LLMServerError("Expected parseable json object for arguments")
                 except:
                     arguments = str(tool_input["function"]["arguments"])
             else:
@@ -854,7 +854,7 @@ def remap_finish_reason(stop_reason: str) -> str:
     elif stop_reason == "tool_use":
         return "function_call"
     else:
-        raise
+        raise LLMServerError(f"Unexpected stop_reason: {stop_reason}")


 def strip_xml_tags(string: str, tag: Optional[str]) -> str:
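Note: a bare `raise` with no active exception only surfaces as "RuntimeError: No active exception to re-raise", so replacing it with a typed LLMServerError gives callers something meaningful to catch. A small sketch with a stand-in exception class (only the tool_use branch comes from the diff; the rest is illustrative):

    class LLMServerError(Exception):
        """Stand-in for the package's typed LLM server error."""

    def remap_finish_reason_sketch(stop_reason: str) -> str:
        # Map provider-specific stop reasons onto a common vocabulary; unknown
        # values raise a typed, catchable error instead of a bare `raise`.
        if stop_reason == "tool_use":
            return "function_call"
        raise LLMServerError(f"Unexpected stop_reason: {stop_reason}")
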
letta/llm_api/azure_client.py
CHANGED
@@ -54,9 +54,12 @@ class AzureClient(OpenAIClient):
         api_key = model_settings.azure_api_key or os.environ.get("AZURE_API_KEY")
         base_url = model_settings.azure_base_url or os.environ.get("AZURE_BASE_URL")
         api_version = model_settings.azure_api_version or os.environ.get("AZURE_API_VERSION")
+        try:
+            client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
+            response: ChatCompletion = await client.chat.completions.create(**request_data)
+        except Exception as e:
+            raise self.handle_llm_error(e)

-        client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
-        response: ChatCompletion = await client.chat.completions.create(**request_data)
         return response.model_dump()

     @trace_method
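Note: the Azure client now routes SDK failures through handle_llm_error so callers see the package's typed LLM errors rather than raw provider exceptions. A minimal sketch of that wrap-and-normalize pattern, with simplified stand-ins for the handler and error class:

    class LLMConnectionError(Exception):
        """Stand-in for a typed connection failure."""

    def handle_llm_error(e: Exception) -> Exception:
        # Map low-level exceptions onto the error taxonomy callers expect.
        if isinstance(e, (ConnectionError, TimeoutError)):
            return LLMConnectionError(f"Failed to reach provider: {e}")
        return e

    async def request(call):
        try:
            return await call()
        except Exception as e:
            # normalize before re-raising so callers only ever catch typed errors
            raise handle_llm_error(e)
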
letta/llm_api/google_vertex_client.py
CHANGED
@@ -14,6 +14,19 @@ from google.genai.types import (
 )

 from letta.constants import NON_USER_MSG_PREFIX
+from letta.errors import (
+    ContextWindowExceededError,
+    ErrorCode,
+    LLMAuthenticationError,
+    LLMBadRequestError,
+    LLMConnectionError,
+    LLMNotFoundError,
+    LLMPermissionDeniedError,
+    LLMRateLimitError,
+    LLMServerError,
+    LLMTimeoutError,
+    LLMUnprocessableEntityError,
+)
 from letta.helpers.datetime_helpers import get_utc_time_int
 from letta.helpers.json_helpers import json_dumps, json_loads
 from letta.llm_api.llm_client_base import LLMClientBase
@@ -48,13 +61,16 @@ class GoogleVertexClient(LLMClientBase):
         """
         Performs underlying request to llm and returns raw response.
         """
-
-
-
-
-
-
-
+        try:
+            client = self._get_client()
+            response = client.models.generate_content(
+                model=llm_config.model,
+                contents=request_data["contents"],
+                config=request_data["config"],
+            )
+            return response.model_dump()
+        except Exception as e:
+            raise self.handle_llm_error(e)

     @trace_method
     async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
@@ -67,6 +83,7 @@ class GoogleVertexClient(LLMClientBase):
         # https://github.com/googleapis/python-aiplatform/issues/4472
         retry_count = 1
         should_retry = True
+        response_data = None
         while should_retry and retry_count <= self.MAX_RETRIES:
             try:
                 response = await client.aio.models.generate_content(
@@ -76,13 +93,15 @@ class GoogleVertexClient(LLMClientBase):
                 )
             except errors.APIError as e:
                 # Retry on 503 and 500 errors as well, usually ephemeral from Gemini
-                if e.code == 503 or e.code == 500:
+                if e.code == 503 or e.code == 500 or e.code == 504:
                     logger.warning(f"Received {e}, retrying {retry_count}/{self.MAX_RETRIES}")
                     retry_count += 1
+                    if retry_count > self.MAX_RETRIES:
+                        raise self.handle_llm_error(e)
                     continue
-                raise e
+                raise self.handle_llm_error(e)
             except Exception as e:
-                raise e
+                raise self.handle_llm_error(e)
             response_data = response.model_dump()
             is_malformed_function_call = self.is_malformed_function_call(response_data)
             if is_malformed_function_call:
@@ -114,6 +133,8 @@ class GoogleVertexClient(LLMClientBase):
             should_retry = is_malformed_function_call
             retry_count += 1

+        if response_data is None:
+            raise RuntimeError("Failed to get response data after all retries")
         return response_data

     @staticmethod
@@ -358,11 +379,10 @@ class GoogleVertexClient(LLMClientBase):

         if content is None or content.role is None or content.parts is None:
             # This means the response is malformed like MALFORMED_FUNCTION_CALL
-            # NOTE: must be a ValueError to trigger a retry
             if candidate.finish_reason == "MALFORMED_FUNCTION_CALL":
-                raise
+                raise LLMServerError(f"Malformed response from Google Vertex: {candidate.finish_reason}")
             else:
-                raise
+                raise LLMServerError(f"Invalid response data from Google Vertex: {candidate.model_dump()}")

         role = content.role
         assert role == "model", f"Unknown role in response: {role}"
@@ -456,7 +476,7 @@ class GoogleVertexClient(LLMClientBase):

                 except json.decoder.JSONDecodeError:
                     if candidate.finish_reason == "MAX_TOKENS":
-                        raise
+                        raise LLMServerError("Could not parse response data from LLM: exceeded max token limit")
                     # Inner thoughts are the content by default
                     inner_thoughts = response_message.text

@@ -485,7 +505,7 @@ class GoogleVertexClient(LLMClientBase):
             elif finish_reason == "RECITATION":
                 openai_finish_reason = "content_filter"
             else:
-                raise
+                raise LLMServerError(f"Unrecognized finish reason in Google AI response: {finish_reason}")

             choices.append(
                 Choice(
@@ -576,5 +596,127 @@ class GoogleVertexClient(LLMClientBase):

     @trace_method
     def handle_llm_error(self, e: Exception) -> Exception:
-        #
+        # Handle Google GenAI specific errors
+        if isinstance(e, errors.ClientError):
+            logger.warning(f"[Google Vertex] Client error ({e.code}): {e}")
+
+            # Handle specific error codes
+            if e.code == 400:
+                error_str = str(e).lower()
+                if "context" in error_str and ("exceed" in error_str or "limit" in error_str or "too long" in error_str):
+                    return ContextWindowExceededError(
+                        message=f"Bad request to Google Vertex (context window exceeded): {str(e)}",
+                    )
+                else:
+                    return LLMBadRequestError(
+                        message=f"Bad request to Google Vertex: {str(e)}",
+                        code=ErrorCode.INTERNAL_SERVER_ERROR,
+                    )
+            elif e.code == 401:
+                return LLMAuthenticationError(
+                    message=f"Authentication failed with Google Vertex: {str(e)}",
+                    code=ErrorCode.INTERNAL_SERVER_ERROR,
+                )
+            elif e.code == 403:
+                return LLMPermissionDeniedError(
+                    message=f"Permission denied by Google Vertex: {str(e)}",
+                    code=ErrorCode.INTERNAL_SERVER_ERROR,
+                )
+            elif e.code == 404:
+                return LLMNotFoundError(
+                    message=f"Resource not found in Google Vertex: {str(e)}",
+                    code=ErrorCode.INTERNAL_SERVER_ERROR,
+                )
+            elif e.code == 408:
+                return LLMTimeoutError(
+                    message=f"Request to Google Vertex timed out: {str(e)}",
+                    code=ErrorCode.TIMEOUT,
+                    details={"cause": str(e.__cause__) if e.__cause__ else None},
+                )
+            elif e.code == 422:
+                return LLMUnprocessableEntityError(
+                    message=f"Invalid request content for Google Vertex: {str(e)}",
+                    code=ErrorCode.INTERNAL_SERVER_ERROR,
+                )
+            elif e.code == 429:
+                logger.warning("[Google Vertex] Rate limited (429). Consider backoff.")
+                return LLMRateLimitError(
+                    message=f"Rate limited by Google Vertex: {str(e)}",
+                    code=ErrorCode.RATE_LIMIT_EXCEEDED,
+                )
+            else:
+                return LLMServerError(
+                    message=f"Google Vertex client error: {str(e)}",
+                    code=ErrorCode.INTERNAL_SERVER_ERROR,
+                    details={
+                        "status_code": e.code,
+                        "response_json": getattr(e, "response_json", None),
+                    },
+                )
+
+        if isinstance(e, errors.ServerError):
+            logger.warning(f"[Google Vertex] Server error ({e.code}): {e}")
+
+            # Handle specific server error codes
+            if e.code == 500:
+                return LLMServerError(
+                    message=f"Google Vertex internal server error: {str(e)}",
+                    code=ErrorCode.INTERNAL_SERVER_ERROR,
+                    details={
+                        "status_code": e.code,
+                        "response_json": getattr(e, "response_json", None),
+                    },
+                )
+            elif e.code == 502:
+                return LLMConnectionError(
+                    message=f"Bad gateway from Google Vertex: {str(e)}",
+                    code=ErrorCode.INTERNAL_SERVER_ERROR,
+                    details={"cause": str(e.__cause__) if e.__cause__ else None},
+                )
+            elif e.code == 503:
+                return LLMServerError(
+                    message=f"Google Vertex service unavailable: {str(e)}",
+                    code=ErrorCode.INTERNAL_SERVER_ERROR,
+                    details={
+                        "status_code": e.code,
+                        "response_json": getattr(e, "response_json", None),
+                    },
+                )
+            elif e.code == 504:
+                return LLMTimeoutError(
+                    message=f"Gateway timeout from Google Vertex: {str(e)}",
+                    code=ErrorCode.TIMEOUT,
+                    details={"cause": str(e.__cause__) if e.__cause__ else None},
+                )
+            else:
+                return LLMServerError(
+                    message=f"Google Vertex server error: {str(e)}",
+                    code=ErrorCode.INTERNAL_SERVER_ERROR,
+                    details={
+                        "status_code": e.code,
+                        "response_json": getattr(e, "response_json", None),
+                    },
+                )
+
+        if isinstance(e, errors.APIError):
+            logger.warning(f"[Google Vertex] API error ({e.code}): {e}")
+            return LLMServerError(
+                message=f"Google Vertex API error: {str(e)}",
+                code=ErrorCode.INTERNAL_SERVER_ERROR,
+                details={
+                    "status_code": e.code,
+                    "response_json": getattr(e, "response_json", None),
+                },
+            )
+
+        # Handle connection-related errors
+        if "connection" in str(e).lower() or "timeout" in str(e).lower():
+            logger.warning(f"[Google Vertex] Connection/timeout error: {e}")
+            return LLMConnectionError(
+                message=f"Failed to connect to Google Vertex: {str(e)}",
+                code=ErrorCode.INTERNAL_SERVER_ERROR,
+                details={"cause": str(e.__cause__) if e.__cause__ else None},
+            )
+
+        # Fallback to base implementation for other errors
         return super().handle_llm_error(e)
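Note: the new handle_llm_error maps google.genai ClientError/ServerError status codes onto the package's typed LLM exceptions, which is what lets the agent loop above translate failures into the llm_api_error stop reason. A condensed sketch of the same status-code-to-exception idea, using a lookup table and stand-in classes:

    class LLMError(Exception): ...
    class LLMAuthenticationError(LLMError): ...
    class LLMRateLimitError(LLMError): ...
    class LLMTimeoutError(LLMError): ...
    class LLMServerError(LLMError): ...

    # status code -> typed exception; anything unlisted falls back to LLMServerError
    _STATUS_MAP = {
        401: LLMAuthenticationError,
        408: LLMTimeoutError,
        429: LLMRateLimitError,
        504: LLMTimeoutError,
    }

    def map_status_to_error(status_code: int, detail: str) -> LLMError:
        exc_cls = _STATUS_MAP.get(status_code, LLMServerError)
        return exc_cls(f"provider returned {status_code}: {detail}")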