grasp_agents 0.4.7__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- grasp_agents/cloud_llm.py +191 -224
- grasp_agents/comm_processor.py +101 -100
- grasp_agents/errors.py +69 -9
- grasp_agents/litellm/__init__.py +106 -0
- grasp_agents/litellm/completion_chunk_converters.py +68 -0
- grasp_agents/litellm/completion_converters.py +72 -0
- grasp_agents/litellm/converters.py +138 -0
- grasp_agents/litellm/lite_llm.py +210 -0
- grasp_agents/litellm/message_converters.py +66 -0
- grasp_agents/llm.py +84 -49
- grasp_agents/llm_agent.py +136 -120
- grasp_agents/llm_agent_memory.py +3 -3
- grasp_agents/llm_policy_executor.py +167 -174
- grasp_agents/memory.py +4 -0
- grasp_agents/openai/__init__.py +24 -9
- grasp_agents/openai/completion_chunk_converters.py +6 -6
- grasp_agents/openai/completion_converters.py +12 -14
- grasp_agents/openai/content_converters.py +1 -3
- grasp_agents/openai/converters.py +6 -8
- grasp_agents/openai/message_converters.py +21 -3
- grasp_agents/openai/openai_llm.py +155 -103
- grasp_agents/openai/tool_converters.py +4 -6
- grasp_agents/packet.py +5 -2
- grasp_agents/packet_pool.py +14 -13
- grasp_agents/printer.py +234 -72
- grasp_agents/processor.py +228 -88
- grasp_agents/prompt_builder.py +2 -2
- grasp_agents/run_context.py +11 -20
- grasp_agents/runner.py +42 -0
- grasp_agents/typing/completion.py +16 -9
- grasp_agents/typing/completion_chunk.py +51 -22
- grasp_agents/typing/events.py +95 -19
- grasp_agents/typing/message.py +25 -1
- grasp_agents/typing/tool.py +2 -0
- grasp_agents/usage_tracker.py +31 -37
- grasp_agents/utils.py +95 -84
- grasp_agents/workflow/looped_workflow.py +60 -11
- grasp_agents/workflow/sequential_workflow.py +43 -11
- grasp_agents/workflow/workflow_processor.py +25 -24
- {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.0.dist-info}/METADATA +7 -6
- grasp_agents-0.5.0.dist-info/RECORD +57 -0
- grasp_agents-0.4.7.dist-info/RECORD +0 -50
- {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.0.dist-info}/WHEEL +0 -0
- {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.0.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/cloud_llm.py
CHANGED
@@ -1,103 +1,59 @@
-import fnmatch
 import logging
-import os
 from abc import abstractmethod
-from collections.abc import AsyncIterator, Mapping
+from collections.abc import AsyncIterator, Mapping, Sequence
 from copy import deepcopy
-from typing import Any, Generic,
+from typing import Any, Generic, Required, cast

 import httpx
 from pydantic import BaseModel
-from tenacity import (
-    RetryCallState,
-    retry,
-    stop_after_attempt,
-    wait_random_exponential,
-)
 from typing_extensions import TypedDict

+from .errors import LLMResponseValidationError, LLMToolCallValidationError
 from .http_client import AsyncHTTPClientParams, create_simple_async_httpx_client
 from .llm import LLM, ConvertT_co, LLMSettings, SettingsT_co
 from .rate_limiting.rate_limiter_chunked import RateLimiterC, limit_rate
 from .typing.completion import Completion
-from .typing.completion_chunk import
-
-
-
+from .typing.completion_chunk import CompletionChoice
+from .typing.events import (
+    CompletionChunkEvent,
+    CompletionEvent,
+    LLMStreamingErrorData,
+    LLMStreamingErrorEvent,
 )
-from .typing.events import CompletionChunkEvent, CompletionEvent
 from .typing.message import AssistantMessage, Messages
 from .typing.tool import BaseTool, ToolChoice

 logger = logging.getLogger(__name__)


-
-
-
-
-
-
-
-
-
-
-def get_api_providers() -> dict[APIProviderName, APIProvider]:
-    """Returns a dictionary of available API providers."""
-    return {
-        "openai": APIProvider(
-            name="openai",
-            base_url="https://api.openai.com/v1",
-            api_key=os.getenv("OPENAI_API_KEY"),
-            struct_outputs_support=("*",),
-        ),
-        "openrouter": APIProvider(
-            name="openrouter",
-            base_url="https://openrouter.ai/api/v1",
-            api_key=os.getenv("OPENROUTER_API_KEY"),
-            struct_outputs_support=(),
-        ),
-        "google_ai_studio": APIProvider(
-            name="google_ai_studio",
-            base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
-            api_key=os.getenv("GOOGLE_AI_STUDIO_API_KEY"),
-            struct_outputs_support=("*",),
-        ),
-    }
-
-
-def retry_error_callback(retry_state: RetryCallState) -> Completion:
-    exception = retry_state.outcome.exception() if retry_state.outcome else None
-    if exception:
-        if retry_state.attempt_number == 1:
-            logger.warning(
-                f"\nCloudLLM completion request failed:\n{exception}",
-                # exc_info=exception,
-            )
-        if retry_state.attempt_number > 1:
-            logger.warning(
-                f"\nCloudLLM completion request failed after retrying:\n{exception}",
-                # exc_info=exception,
-            )
-    failed_message = AssistantMessage(content=None, refusal=str(exception))
+class APIProvider(TypedDict, total=False):
+    name: Required[str]
+    base_url: str | None
+    api_key: str | None
+    # Wildcard patterns for model names that support response schema validation:
+    response_schema_support: tuple[str, ...] | None
+
+
+def make_refusal_completion(model_name: str, err: BaseException) -> Completion:
+    failed_message = AssistantMessage(content=None, refusal=str(err))

     return Completion(
-        model=
+        model=model_name,
         choices=[CompletionChoice(message=failed_message, finish_reason=None, index=0)],
     )


-
-
-
-
-            "\nRetrying CloudLLM completion request "
-            f"(attempt {retry_state.attempt_number}):\n{exception}"
-        )
+class CloudLLMSettings(LLMSettings, total=False):
+    extra_headers: dict[str, Any] | None
+    extra_body: object | None
+    extra_query: dict[str, Any] | None


-
-
+LLMRateLimiter = RateLimiterC[
+    Messages,
+    AssistantMessage
+    | AsyncIterator[CompletionChunkEvent | CompletionEvent | LLMStreamingErrorEvent],
+]


 class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co]):
@@ -105,25 +61,24 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         self,
         # Base LLM args
         model_name: str,
+        api_provider: APIProvider,
         converters: ConvertT_co,
         llm_settings: SettingsT_co | None = None,
         tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
-
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+        apply_response_schema_via_provider: bool = True,
         model_id: str | None = None,
-        # Custom LLM provider
-        api_provider: APIProvider | None = None,
         # Connection settings
         async_http_client: httpx.AsyncClient | None = None,
         async_http_client_params: (
             dict[str, Any] | AsyncHTTPClientParams | None
         ) = None,
+        max_client_retries: int = 2,
         # Rate limiting
-        rate_limiter:
-
-
-        rate_limiter_max_concurrency: int = 300,
-        # Retries
-        num_generation_retries: int = 0,
+        rate_limiter: LLMRateLimiter | None = None,
+        # LLM response retries: try to regenerate to pass validation
+        max_response_retries: int = 0,
         **kwargs: Any,
     ) -> None:
         self.llm_settings: CloudLLMSettings | None
@@ -134,61 +89,30 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             converters=converters,
             model_id=model_id,
             tools=tools,
-
+            response_schema=response_schema,
+            response_schema_by_xml_tag=response_schema_by_xml_tag,
             **kwargs,
         )

         self._model_name = model_name
-
-
-        if len(model_name_parts) == 2:
-            api_provider_name, api_model_name = model_name_parts
-            self._api_model_name: str = api_model_name
-
-            api_providers = get_api_providers()
-
-            if api_provider_name not in api_providers:
-                raise ValueError(
-                    f"API provider '{api_provider_name}' is not supported. "
-                    f"Supported providers are: {', '.join(api_providers.keys())}"
-                )
-
-            _api_provider = api_providers[api_provider_name]
-        elif api_provider is not None:
-            self._api_model_name: str = model_name
-            _api_provider = api_provider
-        else:
-            raise ValueError(
-                "API provider must be specified either in the model name "
-                "or as a separate argument."
-            )
-
-        self._api_provider_name: APIProviderName = _api_provider["name"]
-        self._base_url: str | None = _api_provider.get("base_url")
-        self._api_key: str | None = _api_provider.get("api_key")
-        self._struct_outputs_support: bool = any(
-            fnmatch.fnmatch(self._model_name, pat)
-            for pat in _api_provider.get("struct_outputs_support", ())
-        )
+        self._api_provider = api_provider
+        self._apply_response_schema_via_provider = apply_response_schema_via_provider

         if (
-
-            and not
+            apply_response_schema_via_provider
+            and response_schema_by_xml_tag is not None
         ):
             raise ValueError(
-
+                "Response schema by XML tag is not supported "
+                "when apply_response_schema_via_provider is True."
             )

-        self.
-
-
-
-            rate_limiter
-            rpm=rate_limiter_rpm,
-            chunk_size=rate_limiter_chunk_size,
-            max_concurrency=rate_limiter_max_concurrency,
+        self._rate_limiter: LLMRateLimiter | None = None
+        if rate_limiter is not None:
+            self._rate_limiter = rate_limiter
+            logger.info(
+                f"[{self.__class__.__name__}] Set rate limit to {rate_limiter.rpm} RPM"
             )
-        )

         self._async_http_client: httpx.AsyncClient | None = None
         if async_http_client is not None:
@@ -198,18 +122,31 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
                 async_http_client_params
             )

-        self.
+        self.max_client_retries = max_client_retries
+        self.max_response_retries = max_response_retries

     @property
-    def
-        return self.
+    def api_provider(self) -> APIProvider:
+        return self._api_provider

     @property
-    def rate_limiter(
-        self,
-    ) -> RateLimiterC[Messages, AssistantMessage] | None:
+    def rate_limiter(self) -> LLMRateLimiter | None:
         return self._rate_limiter

+    @property
+    def tools(self) -> dict[str, BaseTool[BaseModel, Any, Any]] | None:
+        return self._tools
+
+    @tools.setter
+    def tools(self, tools: Sequence[BaseTool[BaseModel, Any, Any]] | None) -> None:
+        if not tools:
+            self._tools = None
+            return
+        strict_value = True if self._apply_response_schema_via_provider else None
+        for t in tools:
+            t.strict = strict_value
+        self._tools = {t.name: t for t in tools}
+
     def _make_completion_kwargs(
         self,
         conversation: Messages,
@@ -221,21 +158,17 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         api_tools = None
         api_tool_choice = None
         if self.tools:
-            api_tools = [
-                self._converters.to_tool(t, **self._tool_call_settings)
-                for t in self.tools.values()
-            ]
+            api_tools = [self._converters.to_tool(t) for t in self.tools.values()]
             if tool_choice is not None:
                 api_tool_choice = self._converters.to_tool_choice(tool_choice)

         api_llm_settings = deepcopy(self.llm_settings or {})
-        api_llm_settings.pop("use_struct_outputs", None)

         return dict(
             api_messages=api_messages,
             api_tools=api_tools,
             api_tool_choice=api_tool_choice,
-
+            api_response_schema=self._response_schema,
             n_choices=n_choices,
             **api_llm_settings,
         )
@@ -247,19 +180,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         *,
         api_tools: list[Any] | None = None,
         api_tool_choice: Any | None = None,
-
-        **api_llm_settings: Any,
-    ) -> Any:
-        pass
-
-    @abstractmethod
-    async def _get_parsed_completion(
-        self,
-        api_messages: list[Any],
-        *,
-        api_tools: list[Any] | None = None,
-        api_tool_choice: Any | None = None,
-        api_response_format: type | None = None,
+        api_response_schema: type | None = None,
         n_choices: int | None = None,
         **api_llm_settings: Any,
     ) -> Any:
@@ -272,25 +193,14 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         *,
         api_tools: list[Any] | None = None,
         api_tool_choice: Any | None = None,
+        api_response_schema: type | None = None,
         n_choices: int | None = None,
         **api_llm_settings: Any,
     ) -> AsyncIterator[Any]:
         pass

-    @
-    async def
-        self,
-        api_messages: list[Any],
-        *,
-        api_tools: list[Any] | None = None,
-        api_tool_choice: Any | None = None,
-        api_response_format: type | None = None,
-        n_choices: int | None = None,
-        **api_llm_settings: Any,
-    ) -> AsyncIterator[Any]:
-        pass
-
-    async def generate_completion_no_retry(
+    @limit_rate
+    async def _generate_completion_once(
         self,
         conversation: Messages,
         *,
@@ -301,98 +211,155 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
         )

-        if not self.
-            completion_kwargs.pop("
-
-        else:
-            api_completion = await self._get_parsed_completion(**completion_kwargs)
+        if not self._apply_response_schema_via_provider:
+            completion_kwargs.pop("api_response_schema", None)
+        api_completion = await self._get_completion(**completion_kwargs)

         completion = self._converters.from_completion(
             api_completion, name=self.model_id
         )

-        if not self.
-
-            # of the LLM provider
-            self._validate_completion(completion)
+        if not self._apply_response_schema_via_provider:
+            self._validate_response(completion)
             self._validate_tool_calls(completion)

         return completion

-    async def
+    async def generate_completion(
+        self,
+        conversation: Messages,
+        *,
+        tool_choice: ToolChoice | None = None,
+        n_choices: int | None = None,
+        proc_name: str | None = None,
+        call_id: str | None = None,
+    ) -> Completion:
+        n_attempt = 0
+        while n_attempt <= self.max_response_retries:
+            try:
+                return await self._generate_completion_once(
+                    conversation,  # type: ignore[return]
+                    tool_choice=tool_choice,
+                    n_choices=n_choices,
+                )
+            except (LLMResponseValidationError, LLMToolCallValidationError) as err:
+                n_attempt += 1
+
+                if n_attempt > self.max_response_retries:
+                    if n_attempt == 1:
+                        logger.warning(f"\nCloudLLM completion request failed:\n{err}")
+                    if n_attempt > 1:
+                        logger.warning(
+                            f"\nCloudLLM completion request failed after retrying:\n{err}"
+                        )
+                    raise err
+                    # return make_refusal_completion(self._model_name, err)
+
+                logger.warning(
+                    f"\nCloudLLM completion request failed (retry attempt {n_attempt}):"
+                    f"\n{err}"
+                )
+
+        return make_refusal_completion(
+            self._model_name,
+            Exception("Unexpected error: retry loop exited without returning"),
+        )
+
+    @limit_rate
+    async def _generate_completion_stream_once(
         self,
         conversation: Messages,
         *,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
+        proc_name: str | None = None,
+        call_id: str | None = None,
     ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
         completion_kwargs = self._make_completion_kwargs(
             conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
         )
+        if not self._apply_response_schema_via_provider:
+            completion_kwargs.pop("api_response_schema", None)
+
+        api_stream = self._get_completion_stream(**completion_kwargs)
+        api_stream = cast("AsyncIterator[Any]", api_stream)

-
-
-            api_stream = await self._get_completion_stream(**completion_kwargs)
-        else:
-            api_stream = await self._get_parsed_completion_stream(**completion_kwargs)
+        async def iterator() -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
+            api_completion_chunks: list[Any] = []

-        async def iterate() -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
-            completion_chunks: list[CompletionChunk] = []
             async for api_completion_chunk in api_stream:
+                api_completion_chunks.append(api_completion_chunk)
                 completion_chunk = self._converters.from_completion_chunk(
                     api_completion_chunk, name=self.model_id
                 )
-                completion_chunks.append(completion_chunk)
-                yield CompletionChunkEvent(data=completion_chunk, name=self.model_id)

-
-
+                yield CompletionChunkEvent(
+                    data=completion_chunk, proc_name=proc_name, call_id=call_id
+                )
+
+            api_completion = self.combine_completion_chunks(api_completion_chunks)
+            completion = self._converters.from_completion(
+                api_completion, name=self.model_id
+            )

-            yield CompletionEvent(data=completion,
+            yield CompletionEvent(data=completion, proc_name=proc_name, call_id=call_id)

-            if not self.
-
-                # of the LLM provider
-                self._validate_completion(completion)
+            if not self._apply_response_schema_via_provider:
+                self._validate_response(completion)
                 self._validate_tool_calls(completion)

-        return
+        return iterator()

-
-    async def generate_completion( # type: ignore[override]
+    async def generate_completion_stream(  # type: ignore[override]
         self,
         conversation: Messages,
         *,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        )
-        return rate_limiter
-        if rpm is not None:
-            logger.info(f"[{self.__class__.__name__}] Set rate limit to {rpm} RPM")
-            return RateLimiterC(
-                rpm=rpm, chunk_size=chunk_size, max_concurrency=max_concurrency
-            )
+        proc_name: str | None = None,
+        call_id: str | None = None,
+    ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent | LLMStreamingErrorEvent]:
+        n_attempt = 0
+        while n_attempt <= self.max_response_retries:
+            try:
+                async for event in await self._generate_completion_stream_once(  # type: ignore[return]
+                    conversation,  # type: ignore[arg-type]
+                    tool_choice=tool_choice,
+                    n_choices=n_choices,
+                    proc_name=proc_name,
+                    call_id=call_id,
+                ):
+                    yield event
+                return
+            except (LLMResponseValidationError, LLMToolCallValidationError) as err:
+                err_data = LLMStreamingErrorData(
+                    error=err, model_name=self._model_name, model_id=self.model_id
+                )
+                yield LLMStreamingErrorEvent(
+                    data=err_data, proc_name=proc_name, call_id=call_id
+                )

-
+                n_attempt += 1
+                if n_attempt > self.max_response_retries:
+                    if n_attempt == 1:
+                        logger.warning(f"\nCloudLLM completion request failed:\n{err}")
+                    if n_attempt > 1:
+                        logger.warning(
+                            "\nCloudLLM completion request failed after "
+                            f"retrying:\n{err}"
+                        )
+                    refusal_completion = make_refusal_completion(
+                        self._model_name, err
+                    )
+                    yield CompletionEvent(
+                        data=refusal_completion,
+                        proc_name=proc_name,
+                        call_id=call_id,
+                    )
+                    raise err
+                    # return
+
+                logger.warning(
+                    "\nCloudLLM completion request failed "
+                    f"(retry attempt {n_attempt}):\n{err}"
+                )
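Taken together, the cloud_llm.py changes drop the built-in provider registry (get_api_providers) and the "provider:model" name parsing: callers now build an APIProvider TypedDict themselves and pass it to the constructor, provider-side response schemas replace the old struct-outputs flags, and failed response or tool-call validation is retried via max_response_retries instead of tenacity. A minimal sketch of the new call shape, assuming an OpenAI-backed subclass named OpenAILLM in grasp_agents/openai/openai_llm.py (the class name and its remaining constructor defaults are assumptions, not shown in this diff):

import os

from grasp_agents.cloud_llm import APIProvider
from grasp_agents.openai.openai_llm import OpenAILLM  # assumed class name

# APIProvider is now a caller-built TypedDict (see the new class in the diff);
# the values below mirror the removed "openai" entry of get_api_providers().
openai_provider: APIProvider = {
    "name": "openai",
    "base_url": "https://api.openai.com/v1",
    "api_key": os.getenv("OPENAI_API_KEY"),
    # wildcard patterns for models with provider-side response schema support
    "response_schema_support": ("*",),
}

llm = OpenAILLM(
    model_name="gpt-4o",  # assumed model; the name is no longer parsed for a provider prefix
    api_provider=openai_provider,  # 0.5.0: provider is passed explicitly
    apply_response_schema_via_provider=True,  # validate via the provider's structured outputs
    max_response_retries=2,  # regenerate on response/tool-call validation errors
)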