grasp_agents 0.4.7__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. grasp_agents/cloud_llm.py +191 -224
  2. grasp_agents/comm_processor.py +101 -100
  3. grasp_agents/errors.py +69 -9
  4. grasp_agents/litellm/__init__.py +106 -0
  5. grasp_agents/litellm/completion_chunk_converters.py +68 -0
  6. grasp_agents/litellm/completion_converters.py +72 -0
  7. grasp_agents/litellm/converters.py +138 -0
  8. grasp_agents/litellm/lite_llm.py +210 -0
  9. grasp_agents/litellm/message_converters.py +66 -0
  10. grasp_agents/llm.py +84 -49
  11. grasp_agents/llm_agent.py +136 -120
  12. grasp_agents/llm_agent_memory.py +3 -3
  13. grasp_agents/llm_policy_executor.py +167 -174
  14. grasp_agents/memory.py +4 -0
  15. grasp_agents/openai/__init__.py +24 -9
  16. grasp_agents/openai/completion_chunk_converters.py +6 -6
  17. grasp_agents/openai/completion_converters.py +12 -14
  18. grasp_agents/openai/content_converters.py +1 -3
  19. grasp_agents/openai/converters.py +6 -8
  20. grasp_agents/openai/message_converters.py +21 -3
  21. grasp_agents/openai/openai_llm.py +155 -103
  22. grasp_agents/openai/tool_converters.py +4 -6
  23. grasp_agents/packet.py +5 -2
  24. grasp_agents/packet_pool.py +14 -13
  25. grasp_agents/printer.py +234 -72
  26. grasp_agents/processor.py +228 -88
  27. grasp_agents/prompt_builder.py +2 -2
  28. grasp_agents/run_context.py +11 -20
  29. grasp_agents/runner.py +42 -0
  30. grasp_agents/typing/completion.py +16 -9
  31. grasp_agents/typing/completion_chunk.py +51 -22
  32. grasp_agents/typing/events.py +95 -19
  33. grasp_agents/typing/message.py +25 -1
  34. grasp_agents/typing/tool.py +2 -0
  35. grasp_agents/usage_tracker.py +31 -37
  36. grasp_agents/utils.py +95 -84
  37. grasp_agents/workflow/looped_workflow.py +60 -11
  38. grasp_agents/workflow/sequential_workflow.py +43 -11
  39. grasp_agents/workflow/workflow_processor.py +25 -24
  40. {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.0.dist-info}/METADATA +7 -6
  41. grasp_agents-0.5.0.dist-info/RECORD +57 -0
  42. grasp_agents-0.4.7.dist-info/RECORD +0 -50
  43. {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.0.dist-info}/WHEEL +0 -0
  44. {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.0.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/cloud_llm.py CHANGED
@@ -1,103 +1,59 @@
-import fnmatch
 import logging
-import os
 from abc import abstractmethod
-from collections.abc import AsyncIterator, Mapping
+from collections.abc import AsyncIterator, Mapping, Sequence
 from copy import deepcopy
-from typing import Any, Generic, Literal, NotRequired
+from typing import Any, Generic, Required, cast
 
 import httpx
 from pydantic import BaseModel
-from tenacity import (
-    RetryCallState,
-    retry,
-    stop_after_attempt,
-    wait_random_exponential,
-)
 from typing_extensions import TypedDict
 
+from .errors import LLMResponseValidationError, LLMToolCallValidationError
 from .http_client import AsyncHTTPClientParams, create_simple_async_httpx_client
 from .llm import LLM, ConvertT_co, LLMSettings, SettingsT_co
 from .rate_limiting.rate_limiter_chunked import RateLimiterC, limit_rate
 from .typing.completion import Completion
-from .typing.completion_chunk import (
-    CompletionChoice,
-    CompletionChunk,
-    combine_completion_chunks,
+from .typing.completion_chunk import CompletionChoice
+from .typing.events import (
+    CompletionChunkEvent,
+    CompletionEvent,
+    LLMStreamingErrorData,
+    LLMStreamingErrorEvent,
 )
-from .typing.events import CompletionChunkEvent, CompletionEvent
 from .typing.message import AssistantMessage, Messages
 from .typing.tool import BaseTool, ToolChoice
 
 logger = logging.getLogger(__name__)
 
 
-APIProviderName = Literal["openai", "openrouter", "google_ai_studio"]
-
-
-class APIProvider(TypedDict):
-    name: APIProviderName
-    base_url: str
-    api_key: NotRequired[str | None]
-    struct_outputs_support: NotRequired[tuple[str, ...]]
-
-
-def get_api_providers() -> dict[APIProviderName, APIProvider]:
-    """Returns a dictionary of available API providers."""
-    return {
-        "openai": APIProvider(
-            name="openai",
-            base_url="https://api.openai.com/v1",
-            api_key=os.getenv("OPENAI_API_KEY"),
-            struct_outputs_support=("*",),
-        ),
-        "openrouter": APIProvider(
-            name="openrouter",
-            base_url="https://openrouter.ai/api/v1",
-            api_key=os.getenv("OPENROUTER_API_KEY"),
-            struct_outputs_support=(),
-        ),
-        "google_ai_studio": APIProvider(
-            name="google_ai_studio",
-            base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
-            api_key=os.getenv("GOOGLE_AI_STUDIO_API_KEY"),
-            struct_outputs_support=("*",),
-        ),
-    }
-
-
-def retry_error_callback(retry_state: RetryCallState) -> Completion:
-    exception = retry_state.outcome.exception() if retry_state.outcome else None
-    if exception:
-        if retry_state.attempt_number == 1:
-            logger.warning(
-                f"\nCloudLLM completion request failed:\n{exception}",
-                # exc_info=exception,
-            )
-        if retry_state.attempt_number > 1:
-            logger.warning(
-                f"\nCloudLLM completion request failed after retrying:\n{exception}",
-                # exc_info=exception,
-            )
-    failed_message = AssistantMessage(content=None, refusal=str(exception))
+class APIProvider(TypedDict, total=False):
+    name: Required[str]
+    base_url: str | None
+    api_key: str | None
+    # Wildcard patterns for model names that support response schema validation:
+    response_schema_support: tuple[str, ...] | None
+
+
+def make_refusal_completion(model_name: str, err: BaseException) -> Completion:
+    failed_message = AssistantMessage(content=None, refusal=str(err))
 
     return Completion(
-        model="",
+        model=model_name,
         choices=[CompletionChoice(message=failed_message, finish_reason=None, index=0)],
     )
 
 
-def retry_before_sleep_callback(retry_state: RetryCallState) -> None:
-    exception = retry_state.outcome.exception() if retry_state.outcome else None
-    if exception:
-        logger.info(
-            "\nRetrying CloudLLM completion request "
-            f"(attempt {retry_state.attempt_number}):\n{exception}"
-        )
+class CloudLLMSettings(LLMSettings, total=False):
+    extra_headers: dict[str, Any] | None
+    extra_body: object | None
+    extra_query: dict[str, Any] | None
 
 
-class CloudLLMSettings(LLMSettings, total=False):
-    use_struct_outputs: bool
+LLMRateLimiter = RateLimiterC[
+    Messages,
+    AssistantMessage
+    | AsyncIterator[CompletionChunkEvent | CompletionEvent | LLMStreamingErrorEvent],
+]
 
 
 class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co]):
@@ -105,25 +61,24 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         self,
         # Base LLM args
         model_name: str,
+        api_provider: APIProvider,
         converters: ConvertT_co,
         llm_settings: SettingsT_co | None = None,
         tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
-        response_format: type | Mapping[str, type] | None = None,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+        apply_response_schema_via_provider: bool = True,
         model_id: str | None = None,
-        # Custom LLM provider
-        api_provider: APIProvider | None = None,
         # Connection settings
         async_http_client: httpx.AsyncClient | None = None,
         async_http_client_params: (
             dict[str, Any] | AsyncHTTPClientParams | None
         ) = None,
+        max_client_retries: int = 2,
         # Rate limiting
-        rate_limiter: (RateLimiterC[Messages, AssistantMessage] | None) = None,
-        rate_limiter_rpm: float | None = None,
-        rate_limiter_chunk_size: int = 1000,
-        rate_limiter_max_concurrency: int = 300,
-        # Retries
-        num_generation_retries: int = 0,
+        rate_limiter: LLMRateLimiter | None = None,
+        # LLM response retries: try to regenerate to pass validation
+        max_response_retries: int = 0,
         **kwargs: Any,
     ) -> None:
         self.llm_settings: CloudLLMSettings | None
@@ -134,61 +89,30 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             converters=converters,
             model_id=model_id,
             tools=tools,
-            response_format=response_format,
+            response_schema=response_schema,
+            response_schema_by_xml_tag=response_schema_by_xml_tag,
             **kwargs,
         )
 
         self._model_name = model_name
-        model_name_parts = model_name.split(":", 1)
-
-        if len(model_name_parts) == 2:
-            api_provider_name, api_model_name = model_name_parts
-            self._api_model_name: str = api_model_name
-
-            api_providers = get_api_providers()
-
-            if api_provider_name not in api_providers:
-                raise ValueError(
-                    f"API provider '{api_provider_name}' is not supported. "
-                    f"Supported providers are: {', '.join(api_providers.keys())}"
-                )
-
-            _api_provider = api_providers[api_provider_name]
-        elif api_provider is not None:
-            self._api_model_name: str = model_name
-            _api_provider = api_provider
-        else:
-            raise ValueError(
-                "API provider must be specified either in the model name "
-                "or as a separate argument."
-            )
-
-        self._api_provider_name: APIProviderName = _api_provider["name"]
-        self._base_url: str | None = _api_provider.get("base_url")
-        self._api_key: str | None = _api_provider.get("api_key")
-        self._struct_outputs_support: bool = any(
-            fnmatch.fnmatch(self._model_name, pat)
-            for pat in _api_provider.get("struct_outputs_support", ())
-        )
+        self._api_provider = api_provider
+        self._apply_response_schema_via_provider = apply_response_schema_via_provider
 
         if (
-            self._llm_settings.get("use_struct_outputs")
-            and not self._struct_outputs_support
+            apply_response_schema_via_provider
+            and response_schema_by_xml_tag is not None
         ):
             raise ValueError(
-                f"Model {self._model_name} does not support structured outputs."
+                "Response schema by XML tag is not supported "
+                "when apply_response_schema_via_provider is True."
             )
 
-        self._tool_call_settings: dict[str, Any] = {}
-
-        self._rate_limiter: RateLimiterC[Messages, AssistantMessage] | None = (
-            self._get_rate_limiter(
-                rate_limiter=rate_limiter,
-                rpm=rate_limiter_rpm,
-                chunk_size=rate_limiter_chunk_size,
-                max_concurrency=rate_limiter_max_concurrency,
+        self._rate_limiter: LLMRateLimiter | None = None
+        if rate_limiter is not None:
+            self._rate_limiter = rate_limiter
+            logger.info(
+                f"[{self.__class__.__name__}] Set rate limit to {rate_limiter.rpm} RPM"
             )
-        )
 
         self._async_http_client: httpx.AsyncClient | None = None
         if async_http_client is not None:
@@ -198,18 +122,31 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
                 async_http_client_params
             )
 
-        self.num_generation_retries = num_generation_retries
+        self.max_client_retries = max_client_retries
+        self.max_response_retries = max_response_retries
 
     @property
-    def api_provider_name(self) -> APIProviderName | None:
-        return self._api_provider_name
+    def api_provider(self) -> APIProvider:
+        return self._api_provider
 
     @property
-    def rate_limiter(
-        self,
-    ) -> RateLimiterC[Messages, AssistantMessage] | None:
+    def rate_limiter(self) -> LLMRateLimiter | None:
         return self._rate_limiter
 
+    @property
+    def tools(self) -> dict[str, BaseTool[BaseModel, Any, Any]] | None:
+        return self._tools
+
+    @tools.setter
+    def tools(self, tools: Sequence[BaseTool[BaseModel, Any, Any]] | None) -> None:
+        if not tools:
+            self._tools = None
+            return
+        strict_value = True if self._apply_response_schema_via_provider else None
+        for t in tools:
+            t.strict = strict_value
+        self._tools = {t.name: t for t in tools}
+
     def _make_completion_kwargs(
         self,
         conversation: Messages,
@@ -221,21 +158,17 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         api_tools = None
         api_tool_choice = None
         if self.tools:
-            api_tools = [
-                self._converters.to_tool(t, **self._tool_call_settings)
-                for t in self.tools.values()
-            ]
+            api_tools = [self._converters.to_tool(t) for t in self.tools.values()]
             if tool_choice is not None:
                 api_tool_choice = self._converters.to_tool_choice(tool_choice)
 
         api_llm_settings = deepcopy(self.llm_settings or {})
-        api_llm_settings.pop("use_struct_outputs", None)
 
         return dict(
             api_messages=api_messages,
             api_tools=api_tools,
             api_tool_choice=api_tool_choice,
-            api_response_format=self._response_format,
+            api_response_schema=self._response_schema,
             n_choices=n_choices,
             **api_llm_settings,
         )
@@ -247,19 +180,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         *,
         api_tools: list[Any] | None = None,
         api_tool_choice: Any | None = None,
-        n_choices: int | None = None,
-        **api_llm_settings: Any,
-    ) -> Any:
-        pass
-
-    @abstractmethod
-    async def _get_parsed_completion(
-        self,
-        api_messages: list[Any],
-        *,
-        api_tools: list[Any] | None = None,
-        api_tool_choice: Any | None = None,
-        api_response_format: type | None = None,
+        api_response_schema: type | None = None,
         n_choices: int | None = None,
         **api_llm_settings: Any,
     ) -> Any:
@@ -272,25 +193,14 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         *,
         api_tools: list[Any] | None = None,
         api_tool_choice: Any | None = None,
+        api_response_schema: type | None = None,
         n_choices: int | None = None,
         **api_llm_settings: Any,
     ) -> AsyncIterator[Any]:
         pass
 
-    @abstractmethod
-    async def _get_parsed_completion_stream(
-        self,
-        api_messages: list[Any],
-        *,
-        api_tools: list[Any] | None = None,
-        api_tool_choice: Any | None = None,
-        api_response_format: type | None = None,
-        n_choices: int | None = None,
-        **api_llm_settings: Any,
-    ) -> AsyncIterator[Any]:
-        pass
-
-    async def generate_completion_no_retry(
+    @limit_rate
+    async def _generate_completion_once(
         self,
         conversation: Messages,
         *,
@@ -301,98 +211,155 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
         )
 
-        if not self._llm_settings.get("use_struct_outputs"):
-            completion_kwargs.pop("api_response_format", None)
-            api_completion = await self._get_completion(**completion_kwargs)
-        else:
-            api_completion = await self._get_parsed_completion(**completion_kwargs)
+        if not self._apply_response_schema_via_provider:
+            completion_kwargs.pop("api_response_schema", None)
+        api_completion = await self._get_completion(**completion_kwargs)
 
         completion = self._converters.from_completion(
             api_completion, name=self.model_id
         )
 
-        if not self._llm_settings.get("use_struct_outputs"):
-            # If validation is not handled by the structured output functionality
-            # of the LLM provider
-            self._validate_completion(completion)
+        if not self._apply_response_schema_via_provider:
+            self._validate_response(completion)
             self._validate_tool_calls(completion)
 
         return completion
 
-    async def generate_completion_stream(
+    async def generate_completion(
+        self,
+        conversation: Messages,
+        *,
+        tool_choice: ToolChoice | None = None,
+        n_choices: int | None = None,
+        proc_name: str | None = None,
+        call_id: str | None = None,
+    ) -> Completion:
+        n_attempt = 0
+        while n_attempt <= self.max_response_retries:
+            try:
+                return await self._generate_completion_once(
+                    conversation,  # type: ignore[return]
+                    tool_choice=tool_choice,
+                    n_choices=n_choices,
+                )
+            except (LLMResponseValidationError, LLMToolCallValidationError) as err:
+                n_attempt += 1
+
+                if n_attempt > self.max_response_retries:
+                    if n_attempt == 1:
+                        logger.warning(f"\nCloudLLM completion request failed:\n{err}")
+                    if n_attempt > 1:
+                        logger.warning(
+                            f"\nCloudLLM completion request failed after retrying:\n{err}"
+                        )
+                    raise err
+                    # return make_refusal_completion(self._model_name, err)
+
+                logger.warning(
+                    f"\nCloudLLM completion request failed (retry attempt {n_attempt}):"
+                    f"\n{err}"
+                )
+
+        return make_refusal_completion(
+            self._model_name,
+            Exception("Unexpected error: retry loop exited without returning"),
+        )
+
+    @limit_rate
+    async def _generate_completion_stream_once(
         self,
         conversation: Messages,
         *,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
+        proc_name: str | None = None,
+        call_id: str | None = None,
     ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
         completion_kwargs = self._make_completion_kwargs(
             conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
         )
+        if not self._apply_response_schema_via_provider:
+            completion_kwargs.pop("api_response_schema", None)
+
+        api_stream = self._get_completion_stream(**completion_kwargs)
+        api_stream = cast("AsyncIterator[Any]", api_stream)
 
-        if not self._llm_settings.get("use_struct_outputs"):
-            completion_kwargs.pop("api_response_format", None)
-            api_stream = await self._get_completion_stream(**completion_kwargs)
-        else:
-            api_stream = await self._get_parsed_completion_stream(**completion_kwargs)
+        async def iterator() -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
+            api_completion_chunks: list[Any] = []
 
-        async def iterate() -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
-            completion_chunks: list[CompletionChunk] = []
             async for api_completion_chunk in api_stream:
+                api_completion_chunks.append(api_completion_chunk)
                 completion_chunk = self._converters.from_completion_chunk(
                     api_completion_chunk, name=self.model_id
                 )
-                completion_chunks.append(completion_chunk)
-                yield CompletionChunkEvent(data=completion_chunk, name=self.model_id)
 
-            # TODO: can be done using the OpenAI final_completion_chunk
-            completion = combine_completion_chunks(completion_chunks)
+                yield CompletionChunkEvent(
+                    data=completion_chunk, proc_name=proc_name, call_id=call_id
+                )
+
+            api_completion = self.combine_completion_chunks(api_completion_chunks)
+            completion = self._converters.from_completion(
+                api_completion, name=self.model_id
+            )
 
-            yield CompletionEvent(data=completion, name=self.model_id)
+            yield CompletionEvent(data=completion, proc_name=proc_name, call_id=call_id)
 
-            if not self._llm_settings.get("use_struct_outputs"):
-                # If validation is not handled by the structured outputs functionality
-                # of the LLM provider
-                self._validate_completion(completion)
+            if not self._apply_response_schema_via_provider:
+                self._validate_response(completion)
                 self._validate_tool_calls(completion)
 
-        return iterate()
+        return iterator()
 
-    @limit_rate
-    async def generate_completion(  # type: ignore[override]
+    async def generate_completion_stream(  # type: ignore[override]
         self,
         conversation: Messages,
         *,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
-    ) -> Completion:
-        wrapped_func = retry(
-            wait=wait_random_exponential(min=1, max=8),
-            stop=stop_after_attempt(self.num_generation_retries + 1),
-            before_sleep=retry_before_sleep_callback,
-            retry_error_callback=retry_error_callback,
-        )(self.__class__.generate_completion_no_retry)
-
-        return await wrapped_func(
-            self, conversation, tool_choice=tool_choice, n_choices=n_choices
-        )
-
-    def _get_rate_limiter(
-        self,
-        rate_limiter: RateLimiterC[Messages, AssistantMessage] | None = None,
-        rpm: float | None = None,
-        chunk_size: int = 1000,
-        max_concurrency: int = 300,
-    ) -> RateLimiterC[Messages, AssistantMessage] | None:
-        if rate_limiter is not None:
-            logger.info(
-                f"[{self.__class__.__name__}] Set rate limit to {rate_limiter.rpm} RPM"
-            )
-            return rate_limiter
-        if rpm is not None:
-            logger.info(f"[{self.__class__.__name__}] Set rate limit to {rpm} RPM")
-            return RateLimiterC(
-                rpm=rpm, chunk_size=chunk_size, max_concurrency=max_concurrency
-            )
+        proc_name: str | None = None,
+        call_id: str | None = None,
+    ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent | LLMStreamingErrorEvent]:
+        n_attempt = 0
+        while n_attempt <= self.max_response_retries:
+            try:
+                async for event in await self._generate_completion_stream_once(  # type: ignore[return]
+                    conversation,  # type: ignore[arg-type]
+                    tool_choice=tool_choice,
+                    n_choices=n_choices,
+                    proc_name=proc_name,
+                    call_id=call_id,
+                ):
+                    yield event
+                return
+            except (LLMResponseValidationError, LLMToolCallValidationError) as err:
+                err_data = LLMStreamingErrorData(
+                    error=err, model_name=self._model_name, model_id=self.model_id
+                )
+                yield LLMStreamingErrorEvent(
+                    data=err_data, proc_name=proc_name, call_id=call_id
+                )
 
-        return None
+                n_attempt += 1
+                if n_attempt > self.max_response_retries:
+                    if n_attempt == 1:
+                        logger.warning(f"\nCloudLLM completion request failed:\n{err}")
+                    if n_attempt > 1:
+                        logger.warning(
+                            "\nCloudLLM completion request failed after "
+                            f"retrying:\n{err}"
+                        )
+                    refusal_completion = make_refusal_completion(
+                        self._model_name, err
+                    )
+                    yield CompletionEvent(
+                        data=refusal_completion,
+                        proc_name=proc_name,
+                        call_id=call_id,
+                    )
+                    raise err
+                    # return
+
+                logger.warning(
+                    "\nCloudLLM completion request failed "
+                    f"(retry attempt {n_attempt}):\n{err}"
+                )
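
For orientation, below is a minimal sketch of the reworked cloud_llm.py surface shown in the hunks above. The APIProvider TypedDict and the make_refusal_completion helper are taken directly from this diff; the commented-out OpenAILLM call, the model name "gpt-4o-mini", and the converters placeholder are hypothetical illustrations of how a concrete subclass might accept the new CloudLLM constructor arguments (its exact 0.5.0 signature is not part of this hunk).

import os

from grasp_agents.cloud_llm import APIProvider, make_refusal_completion

# APIProvider is now a plain TypedDict: only "name" is required, and the old
# get_api_providers() registry and "provider:model_name" parsing are gone.
provider: APIProvider = {
    "name": "openai",
    "base_url": "https://api.openai.com/v1",
    "api_key": os.getenv("OPENAI_API_KEY"),
    # Wildcard patterns for model names that support provider-side response schemas:
    "response_schema_support": ("*",),
}

# Hypothetical call to a concrete subclass, mirroring CloudLLM.__init__ above:
# llm = OpenAILLM(
#     model_name="gpt-4o-mini",                # assumed model name
#     api_provider=provider,                   # now a required, explicit argument
#     converters=...,                          # provider-specific converters
#     apply_response_schema_via_provider=True,
#     max_client_retries=2,                    # transport-level retries
#     max_response_retries=1,                  # regenerate on validation errors
# )

# make_refusal_completion() replaces the old tenacity retry_error_callback: it
# wraps an error into a Completion whose assistant message carries a refusal.
fallback = make_refusal_completion("gpt-4o-mini", RuntimeError("validation failed"))
print(fallback.choices[0].message.refusal)

The net effect visible in this file: tenacity-based generation retries are replaced by an explicit validation-retry loop (max_response_retries), transport retries move to the HTTP client (max_client_retries), and provider configuration becomes an explicit APIProvider argument instead of a provider-prefixed model name.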