grasp_agents 0.4.6__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (44)
  1. grasp_agents/cloud_llm.py +191 -218
  2. grasp_agents/comm_processor.py +101 -100
  3. grasp_agents/errors.py +69 -9
  4. grasp_agents/litellm/__init__.py +106 -0
  5. grasp_agents/litellm/completion_chunk_converters.py +68 -0
  6. grasp_agents/litellm/completion_converters.py +72 -0
  7. grasp_agents/litellm/converters.py +138 -0
  8. grasp_agents/litellm/lite_llm.py +210 -0
  9. grasp_agents/litellm/message_converters.py +66 -0
  10. grasp_agents/llm.py +84 -49
  11. grasp_agents/llm_agent.py +136 -120
  12. grasp_agents/llm_agent_memory.py +3 -3
  13. grasp_agents/llm_policy_executor.py +167 -174
  14. grasp_agents/memory.py +4 -0
  15. grasp_agents/openai/__init__.py +24 -9
  16. grasp_agents/openai/completion_chunk_converters.py +6 -6
  17. grasp_agents/openai/completion_converters.py +12 -14
  18. grasp_agents/openai/content_converters.py +1 -3
  19. grasp_agents/openai/converters.py +6 -8
  20. grasp_agents/openai/message_converters.py +21 -3
  21. grasp_agents/openai/openai_llm.py +155 -103
  22. grasp_agents/openai/tool_converters.py +4 -6
  23. grasp_agents/packet.py +5 -2
  24. grasp_agents/packet_pool.py +14 -13
  25. grasp_agents/printer.py +234 -72
  26. grasp_agents/processor.py +228 -88
  27. grasp_agents/prompt_builder.py +2 -2
  28. grasp_agents/run_context.py +11 -20
  29. grasp_agents/runner.py +42 -0
  30. grasp_agents/typing/completion.py +16 -9
  31. grasp_agents/typing/completion_chunk.py +51 -22
  32. grasp_agents/typing/events.py +95 -19
  33. grasp_agents/typing/message.py +25 -1
  34. grasp_agents/typing/tool.py +2 -0
  35. grasp_agents/usage_tracker.py +31 -37
  36. grasp_agents/utils.py +95 -84
  37. grasp_agents/workflow/looped_workflow.py +60 -11
  38. grasp_agents/workflow/sequential_workflow.py +43 -11
  39. grasp_agents/workflow/workflow_processor.py +25 -24
  40. {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/METADATA +7 -6
  41. grasp_agents-0.5.0.dist-info/RECORD +57 -0
  42. grasp_agents-0.4.6.dist-info/RECORD +0 -50
  43. {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/WHEEL +0 -0
  44. {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/cloud_llm.py CHANGED
@@ -1,101 +1,59 @@
- import fnmatch
  import logging
- import os
  from abc import abstractmethod
- from collections.abc import AsyncIterator, Mapping
+ from collections.abc import AsyncIterator, Mapping, Sequence
  from copy import deepcopy
- from typing import Any, Generic, Literal, NotRequired
+ from typing import Any, Generic, Required, cast

  import httpx
  from pydantic import BaseModel
- from tenacity import (
-     RetryCallState,
-     retry,
-     stop_after_attempt,
-     wait_random_exponential,
- )
  from typing_extensions import TypedDict

+ from .errors import LLMResponseValidationError, LLMToolCallValidationError
  from .http_client import AsyncHTTPClientParams, create_simple_async_httpx_client
  from .llm import LLM, ConvertT_co, LLMSettings, SettingsT_co
  from .rate_limiting.rate_limiter_chunked import RateLimiterC, limit_rate
  from .typing.completion import Completion
- from .typing.completion_chunk import (
-     CompletionChoice,
-     CompletionChunk,
-     combine_completion_chunks,
+ from .typing.completion_chunk import CompletionChoice
+ from .typing.events import (
+     CompletionChunkEvent,
+     CompletionEvent,
+     LLMStreamingErrorData,
+     LLMStreamingErrorEvent,
  )
- from .typing.events import CompletionChunkEvent, CompletionEvent
  from .typing.message import AssistantMessage, Messages
  from .typing.tool import BaseTool, ToolChoice

  logger = logging.getLogger(__name__)


- APIProviderName = Literal["openai", "openrouter", "google_ai_studio"]
-
-
- class APIProvider(TypedDict):
-     name: APIProviderName
-     base_url: str
-     api_key: NotRequired[str | None]
-     struct_outputs_support: NotRequired[tuple[str, ...]]
-
-
- API_PROVIDERS: dict[APIProviderName, APIProvider] = {
-     "openai": APIProvider(
-         name="openai",
-         base_url="https://api.openai.com/v1",
-         api_key=os.getenv("OPENAI_API_KEY"),
-         struct_outputs_support=("*",),
-     ),
-     "openrouter": APIProvider(
-         name="openrouter",
-         base_url="https://openrouter.ai/api/v1",
-         api_key=os.getenv("OPENROUTER_API_KEY"),
-         struct_outputs_support=(),
-     ),
-     "google_ai_studio": APIProvider(
-         name="google_ai_studio",
-         base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
-         api_key=os.getenv("GOOGLE_AI_STUDIO_API_KEY"),
-         struct_outputs_support=("*",),
-     ),
- }
-
-
- def retry_error_callback(retry_state: RetryCallState) -> Completion:
-     exception = retry_state.outcome.exception() if retry_state.outcome else None
-     if exception:
-         if retry_state.attempt_number == 1:
-             logger.warning(
-                 f"\nCloudLLM completion request failed:\n{exception}",
-                 # exc_info=exception,
-             )
-         if retry_state.attempt_number > 1:
-             logger.warning(
-                 f"\nCloudLLM completion request failed after retrying:\n{exception}",
-                 # exc_info=exception,
-             )
-     failed_message = AssistantMessage(content=None, refusal=str(exception))
+ class APIProvider(TypedDict, total=False):
+     name: Required[str]
+     base_url: str | None
+     api_key: str | None
+     # Wildcard patterns for model names that support response schema validation:
+     response_schema_support: tuple[str, ...] | None
+
+
+ def make_refusal_completion(model_name: str, err: BaseException) -> Completion:
+     failed_message = AssistantMessage(content=None, refusal=str(err))

      return Completion(
-         model="",
+         model=model_name,
          choices=[CompletionChoice(message=failed_message, finish_reason=None, index=0)],
      )


- def retry_before_sleep_callback(retry_state: RetryCallState) -> None:
-     exception = retry_state.outcome.exception() if retry_state.outcome else None
-     if exception:
-         logger.info(
-             "\nRetrying CloudLLM completion request "
-             f"(attempt {retry_state.attempt_number}):\n{exception}"
-         )
+ class CloudLLMSettings(LLMSettings, total=False):
+     extra_headers: dict[str, Any] | None
+     extra_body: object | None
+     extra_query: dict[str, Any] | None


- class CloudLLMSettings(LLMSettings, total=False):
-     use_struct_outputs: bool
+ LLMRateLimiter = RateLimiterC[
+     Messages,
+     AssistantMessage
+     | AsyncIterator[CompletionChunkEvent | CompletionEvent | LLMStreamingErrorEvent],
+ ]


  class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co]):
@@ -103,25 +61,24 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
          self,
          # Base LLM args
          model_name: str,
+         api_provider: APIProvider,
          converters: ConvertT_co,
          llm_settings: SettingsT_co | None = None,
          tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
-         response_format: type | Mapping[str, type] | None = None,
+         response_schema: Any | None = None,
+         response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+         apply_response_schema_via_provider: bool = True,
          model_id: str | None = None,
-         # Custom LLM provider
-         api_provider: APIProvider | None = None,
          # Connection settings
          async_http_client: httpx.AsyncClient | None = None,
          async_http_client_params: (
              dict[str, Any] | AsyncHTTPClientParams | None
          ) = None,
+         max_client_retries: int = 2,
          # Rate limiting
-         rate_limiter: (RateLimiterC[Messages, AssistantMessage] | None) = None,
-         rate_limiter_rpm: float | None = None,
-         rate_limiter_chunk_size: int = 1000,
-         rate_limiter_max_concurrency: int = 300,
-         # Retries
-         num_generation_retries: int = 0,
+         rate_limiter: LLMRateLimiter | None = None,
+         # LLM response retries: try to regenerate to pass validation
+         max_response_retries: int = 0,
          **kwargs: Any,
      ) -> None:
          self.llm_settings: CloudLLMSettings | None
@@ -132,57 +89,30 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
              converters=converters,
              model_id=model_id,
              tools=tools,
-             response_format=response_format,
+             response_schema=response_schema,
+             response_schema_by_xml_tag=response_schema_by_xml_tag,
              **kwargs,
          )

          self._model_name = model_name
-         model_name_parts = model_name.split(":", 1)
-
-         if len(model_name_parts) == 2:
-             api_provider_name, api_model_name = model_name_parts
-             self._api_model_name: str = api_model_name
-             if api_provider_name not in API_PROVIDERS:
-                 raise ValueError(
-                     f"API provider '{api_provider_name}' is not supported. "
-                     f"Supported providers are: {', '.join(API_PROVIDERS.keys())}"
-                 )
-             _api_provider = API_PROVIDERS[api_provider_name]
-         elif api_provider is not None:
-             self._api_model_name: str = model_name
-             _api_provider = api_provider
-         else:
-             raise ValueError(
-                 "API provider must be specified either in the model name "
-                 "or as a separate argument."
-             )
-
-         self._api_provider_name: APIProviderName = _api_provider["name"]
-         self._base_url: str | None = _api_provider.get("base_url")
-         self._api_key: str | None = _api_provider.get("api_key")
-         self._struct_outputs_support: bool = any(
-             fnmatch.fnmatch(self._model_name, pat)
-             for pat in _api_provider.get("struct_outputs_support", ())
-         )
+         self._api_provider = api_provider
+         self._apply_response_schema_via_provider = apply_response_schema_via_provider

          if (
-             self._llm_settings.get("use_struct_outputs")
-             and not self._struct_outputs_support
+             apply_response_schema_via_provider
+             and response_schema_by_xml_tag is not None
          ):
              raise ValueError(
-                 f"Model {self._model_name} does not support structured outputs."
+                 "Response schema by XML tag is not supported "
+                 "when apply_response_schema_via_provider is True."
              )

-         self._tool_call_settings: dict[str, Any] = {}
-
-         self._rate_limiter: RateLimiterC[Messages, AssistantMessage] | None = (
-             self._get_rate_limiter(
-                 rate_limiter=rate_limiter,
-                 rpm=rate_limiter_rpm,
-                 chunk_size=rate_limiter_chunk_size,
-                 max_concurrency=rate_limiter_max_concurrency,
+         self._rate_limiter: LLMRateLimiter | None = None
+         if rate_limiter is not None:
+             self._rate_limiter = rate_limiter
+             logger.info(
+                 f"[{self.__class__.__name__}] Set rate limit to {rate_limiter.rpm} RPM"
              )
-         )

          self._async_http_client: httpx.AsyncClient | None = None
          if async_http_client is not None:
@@ -192,18 +122,31 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
                  async_http_client_params
              )

-         self.num_generation_retries = num_generation_retries
+         self.max_client_retries = max_client_retries
+         self.max_response_retries = max_response_retries

      @property
-     def api_provider_name(self) -> APIProviderName | None:
-         return self._api_provider_name
+     def api_provider(self) -> APIProvider:
+         return self._api_provider

      @property
-     def rate_limiter(
-         self,
-     ) -> RateLimiterC[Messages, AssistantMessage] | None:
+     def rate_limiter(self) -> LLMRateLimiter | None:
          return self._rate_limiter

+     @property
+     def tools(self) -> dict[str, BaseTool[BaseModel, Any, Any]] | None:
+         return self._tools
+
+     @tools.setter
+     def tools(self, tools: Sequence[BaseTool[BaseModel, Any, Any]] | None) -> None:
+         if not tools:
+             self._tools = None
+             return
+         strict_value = True if self._apply_response_schema_via_provider else None
+         for t in tools:
+             t.strict = strict_value
+         self._tools = {t.name: t for t in tools}
+
      def _make_completion_kwargs(
          self,
          conversation: Messages,
@@ -215,21 +158,17 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
          api_tools = None
          api_tool_choice = None
          if self.tools:
-             api_tools = [
-                 self._converters.to_tool(t, **self._tool_call_settings)
-                 for t in self.tools.values()
-             ]
+             api_tools = [self._converters.to_tool(t) for t in self.tools.values()]
          if tool_choice is not None:
              api_tool_choice = self._converters.to_tool_choice(tool_choice)

          api_llm_settings = deepcopy(self.llm_settings or {})
-         api_llm_settings.pop("use_struct_outputs", None)

          return dict(
              api_messages=api_messages,
              api_tools=api_tools,
              api_tool_choice=api_tool_choice,
-             api_response_format=self._response_format,
+             api_response_schema=self._response_schema,
              n_choices=n_choices,
              **api_llm_settings,
          )
@@ -241,19 +180,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
          *,
          api_tools: list[Any] | None = None,
          api_tool_choice: Any | None = None,
-         n_choices: int | None = None,
-         **api_llm_settings: Any,
-     ) -> Any:
-         pass
-
-     @abstractmethod
-     async def _get_parsed_completion(
-         self,
-         api_messages: list[Any],
-         *,
-         api_tools: list[Any] | None = None,
-         api_tool_choice: Any | None = None,
-         api_response_format: type | None = None,
+         api_response_schema: type | None = None,
          n_choices: int | None = None,
          **api_llm_settings: Any,
      ) -> Any:
@@ -266,25 +193,14 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
          *,
          api_tools: list[Any] | None = None,
          api_tool_choice: Any | None = None,
+         api_response_schema: type | None = None,
          n_choices: int | None = None,
          **api_llm_settings: Any,
      ) -> AsyncIterator[Any]:
          pass

-     @abstractmethod
-     async def _get_parsed_completion_stream(
-         self,
-         api_messages: list[Any],
-         *,
-         api_tools: list[Any] | None = None,
-         api_tool_choice: Any | None = None,
-         api_response_format: type | None = None,
-         n_choices: int | None = None,
-         **api_llm_settings: Any,
-     ) -> AsyncIterator[Any]:
-         pass
-
-     async def generate_completion_no_retry(
+     @limit_rate
+     async def _generate_completion_once(
          self,
          conversation: Messages,
          *,
@@ -295,98 +211,155 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
              conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
          )

-         if not self._llm_settings.get("use_struct_outputs"):
-             completion_kwargs.pop("api_response_format", None)
-             api_completion = await self._get_completion(**completion_kwargs)
-         else:
-             api_completion = await self._get_parsed_completion(**completion_kwargs)
+         if not self._apply_response_schema_via_provider:
+             completion_kwargs.pop("api_response_schema", None)
+         api_completion = await self._get_completion(**completion_kwargs)

          completion = self._converters.from_completion(
              api_completion, name=self.model_id
          )

-         if not self._llm_settings.get("use_struct_outputs"):
-             # If validation is not handled by the structured output functionality
-             # of the LLM provider
-             self._validate_completion(completion)
+         if not self._apply_response_schema_via_provider:
+             self._validate_response(completion)
              self._validate_tool_calls(completion)

          return completion

-     async def generate_completion_stream(
+     async def generate_completion(
+         self,
+         conversation: Messages,
+         *,
+         tool_choice: ToolChoice | None = None,
+         n_choices: int | None = None,
+         proc_name: str | None = None,
+         call_id: str | None = None,
+     ) -> Completion:
+         n_attempt = 0
+         while n_attempt <= self.max_response_retries:
+             try:
+                 return await self._generate_completion_once(
+                     conversation,  # type: ignore[return]
+                     tool_choice=tool_choice,
+                     n_choices=n_choices,
+                 )
+             except (LLMResponseValidationError, LLMToolCallValidationError) as err:
+                 n_attempt += 1
+
+                 if n_attempt > self.max_response_retries:
+                     if n_attempt == 1:
+                         logger.warning(f"\nCloudLLM completion request failed:\n{err}")
+                     if n_attempt > 1:
+                         logger.warning(
+                             f"\nCloudLLM completion request failed after retrying:\n{err}"
+                         )
+                     raise err
+                     # return make_refusal_completion(self._model_name, err)
+
+                 logger.warning(
+                     f"\nCloudLLM completion request failed (retry attempt {n_attempt}):"
+                     f"\n{err}"
+                 )
+
+         return make_refusal_completion(
+             self._model_name,
+             Exception("Unexpected error: retry loop exited without returning"),
+         )
+
+     @limit_rate
+     async def _generate_completion_stream_once(
          self,
          conversation: Messages,
          *,
          tool_choice: ToolChoice | None = None,
          n_choices: int | None = None,
+         proc_name: str | None = None,
+         call_id: str | None = None,
      ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
          completion_kwargs = self._make_completion_kwargs(
              conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
          )
+         if not self._apply_response_schema_via_provider:
+             completion_kwargs.pop("api_response_schema", None)
+
+         api_stream = self._get_completion_stream(**completion_kwargs)
+         api_stream = cast("AsyncIterator[Any]", api_stream)

-         if not self._llm_settings.get("use_struct_outputs"):
-             completion_kwargs.pop("api_response_format", None)
-             api_stream = await self._get_completion_stream(**completion_kwargs)
-         else:
-             api_stream = await self._get_parsed_completion_stream(**completion_kwargs)
+         async def iterator() -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
+             api_completion_chunks: list[Any] = []

-         async def iterate() -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
-             completion_chunks: list[CompletionChunk] = []
              async for api_completion_chunk in api_stream:
+                 api_completion_chunks.append(api_completion_chunk)
                  completion_chunk = self._converters.from_completion_chunk(
                      api_completion_chunk, name=self.model_id
                  )
-                 completion_chunks.append(completion_chunk)
-                 yield CompletionChunkEvent(data=completion_chunk, name=self.model_id)

-             # TODO: can be done using the OpenAI final_completion_chunk
-             completion = combine_completion_chunks(completion_chunks)
+                 yield CompletionChunkEvent(
+                     data=completion_chunk, proc_name=proc_name, call_id=call_id
+                 )
+
+             api_completion = self.combine_completion_chunks(api_completion_chunks)
+             completion = self._converters.from_completion(
+                 api_completion, name=self.model_id
+             )

-             yield CompletionEvent(data=completion, name=self.model_id)
+             yield CompletionEvent(data=completion, proc_name=proc_name, call_id=call_id)

-             if not self._llm_settings.get("use_struct_outputs"):
-                 # If validation is not handled by the structured outputs functionality
-                 # of the LLM provider
-                 self._validate_completion(completion)
+             if not self._apply_response_schema_via_provider:
+                 self._validate_response(completion)
                  self._validate_tool_calls(completion)

-         return iterate()
+         return iterator()

-     @limit_rate
-     async def generate_completion(  # type: ignore[override]
+     async def generate_completion_stream(  # type: ignore[override]
          self,
          conversation: Messages,
          *,
          tool_choice: ToolChoice | None = None,
          n_choices: int | None = None,
-     ) -> Completion:
-         wrapped_func = retry(
-             wait=wait_random_exponential(min=1, max=8),
-             stop=stop_after_attempt(self.num_generation_retries + 1),
-             before_sleep=retry_before_sleep_callback,
-             retry_error_callback=retry_error_callback,
-         )(self.__class__.generate_completion_no_retry)
-
-         return await wrapped_func(
-             self, conversation, tool_choice=tool_choice, n_choices=n_choices
-         )
-
-     def _get_rate_limiter(
-         self,
-         rate_limiter: RateLimiterC[Messages, AssistantMessage] | None = None,
-         rpm: float | None = None,
-         chunk_size: int = 1000,
-         max_concurrency: int = 300,
-     ) -> RateLimiterC[Messages, AssistantMessage] | None:
-         if rate_limiter is not None:
-             logger.info(
-                 f"[{self.__class__.__name__}] Set rate limit to {rate_limiter.rpm} RPM"
-             )
-             return rate_limiter
-         if rpm is not None:
-             logger.info(f"[{self.__class__.__name__}] Set rate limit to {rpm} RPM")
-             return RateLimiterC(
-                 rpm=rpm, chunk_size=chunk_size, max_concurrency=max_concurrency
-             )
+         proc_name: str | None = None,
+         call_id: str | None = None,
+     ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent | LLMStreamingErrorEvent]:
+         n_attempt = 0
+         while n_attempt <= self.max_response_retries:
+             try:
+                 async for event in await self._generate_completion_stream_once(  # type: ignore[return]
+                     conversation,  # type: ignore[arg-type]
+                     tool_choice=tool_choice,
+                     n_choices=n_choices,
+                     proc_name=proc_name,
+                     call_id=call_id,
+                 ):
+                     yield event
+                 return
+             except (LLMResponseValidationError, LLMToolCallValidationError) as err:
+                 err_data = LLMStreamingErrorData(
+                     error=err, model_name=self._model_name, model_id=self.model_id
+                 )
+                 yield LLMStreamingErrorEvent(
+                     data=err_data, proc_name=proc_name, call_id=call_id
+                 )

-         return None
+                 n_attempt += 1
+                 if n_attempt > self.max_response_retries:
+                     if n_attempt == 1:
+                         logger.warning(f"\nCloudLLM completion request failed:\n{err}")
+                     if n_attempt > 1:
+                         logger.warning(
+                             "\nCloudLLM completion request failed after "
+                             f"retrying:\n{err}"
+                         )
+                     refusal_completion = make_refusal_completion(
+                         self._model_name, err
+                     )
+                     yield CompletionEvent(
+                         data=refusal_completion,
+                         proc_name=proc_name,
+                         call_id=call_id,
+                     )
+                     raise err
+                     # return
+
+                 logger.warning(
+                     "\nCloudLLM completion request failed "
+                     f"(retry attempt {n_attempt}):\n{err}"
+                 )
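
Note on the cloud_llm.py changes above: the built-in API_PROVIDERS registry and tenacity-based retries are gone; a provider is now passed in explicitly as an APIProvider mapping, rate limiting is injected via the LLMRateLimiter alias, and failed validations are retried up to max_response_retries by regenerating the completion. The following is a minimal usage sketch only. It assumes a concrete CloudLLM subclass named OpenAILLM importable from grasp_agents.openai, a UserMessage type in grasp_agents.typing.message, and that these constructor keywords pass through unchanged; none of those details are shown in this diff.

import asyncio
import os

from grasp_agents.cloud_llm import APIProvider
from grasp_agents.openai import OpenAILLM  # assumed concrete CloudLLM subclass
from grasp_agents.typing.message import UserMessage  # assumed message type

# Providers are no longer looked up in a built-in registry; they are passed
# explicitly as an APIProvider TypedDict (only "name" is required).
provider: APIProvider = {
    "name": "openai",
    "base_url": "https://api.openai.com/v1",
    "api_key": os.environ.get("OPENAI_API_KEY"),
    "response_schema_support": ("*",),
}

llm = OpenAILLM(
    model_name="gpt-4o-mini",  # placeholder model name
    api_provider=provider,
    apply_response_schema_via_provider=True,
    max_client_retries=2,     # HTTP-client-level retries
    max_response_retries=1,   # regenerate once if response/tool-call validation fails
)


async def main() -> None:
    # generate_completion retries on LLMResponseValidationError and
    # LLMToolCallValidationError up to max_response_retries times.
    completion = await llm.generate_completion(
        [UserMessage(content="Say hello")], proc_name="demo"
    )
    print(completion.choices[0].message.content)


asyncio.run(main())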