grasp_agents 0.4.6__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/cloud_llm.py +191 -218
- grasp_agents/comm_processor.py +101 -100
- grasp_agents/errors.py +69 -9
- grasp_agents/litellm/__init__.py +106 -0
- grasp_agents/litellm/completion_chunk_converters.py +68 -0
- grasp_agents/litellm/completion_converters.py +72 -0
- grasp_agents/litellm/converters.py +138 -0
- grasp_agents/litellm/lite_llm.py +210 -0
- grasp_agents/litellm/message_converters.py +66 -0
- grasp_agents/llm.py +84 -49
- grasp_agents/llm_agent.py +136 -120
- grasp_agents/llm_agent_memory.py +3 -3
- grasp_agents/llm_policy_executor.py +167 -174
- grasp_agents/memory.py +4 -0
- grasp_agents/openai/__init__.py +24 -9
- grasp_agents/openai/completion_chunk_converters.py +6 -6
- grasp_agents/openai/completion_converters.py +12 -14
- grasp_agents/openai/content_converters.py +1 -3
- grasp_agents/openai/converters.py +6 -8
- grasp_agents/openai/message_converters.py +21 -3
- grasp_agents/openai/openai_llm.py +155 -103
- grasp_agents/openai/tool_converters.py +4 -6
- grasp_agents/packet.py +5 -2
- grasp_agents/packet_pool.py +14 -13
- grasp_agents/printer.py +234 -72
- grasp_agents/processor.py +228 -88
- grasp_agents/prompt_builder.py +2 -2
- grasp_agents/run_context.py +11 -20
- grasp_agents/runner.py +42 -0
- grasp_agents/typing/completion.py +16 -9
- grasp_agents/typing/completion_chunk.py +51 -22
- grasp_agents/typing/events.py +95 -19
- grasp_agents/typing/message.py +25 -1
- grasp_agents/typing/tool.py +2 -0
- grasp_agents/usage_tracker.py +31 -37
- grasp_agents/utils.py +95 -84
- grasp_agents/workflow/looped_workflow.py +60 -11
- grasp_agents/workflow/sequential_workflow.py +43 -11
- grasp_agents/workflow/workflow_processor.py +25 -24
- {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/METADATA +7 -6
- grasp_agents-0.5.0.dist-info/RECORD +57 -0
- grasp_agents-0.4.6.dist-info/RECORD +0 -50
- {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/WHEEL +0 -0
- {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/cloud_llm.py
CHANGED
@@ -1,101 +1,59 @@
-import fnmatch
 import logging
-import os
 from abc import abstractmethod
-from collections.abc import AsyncIterator, Mapping
+from collections.abc import AsyncIterator, Mapping, Sequence
 from copy import deepcopy
-from typing import Any, Generic,
+from typing import Any, Generic, Required, cast
 
 import httpx
 from pydantic import BaseModel
-from tenacity import (
-    RetryCallState,
-    retry,
-    stop_after_attempt,
-    wait_random_exponential,
-)
 from typing_extensions import TypedDict
 
+from .errors import LLMResponseValidationError, LLMToolCallValidationError
 from .http_client import AsyncHTTPClientParams, create_simple_async_httpx_client
 from .llm import LLM, ConvertT_co, LLMSettings, SettingsT_co
 from .rate_limiting.rate_limiter_chunked import RateLimiterC, limit_rate
 from .typing.completion import Completion
-from .typing.completion_chunk import
-
-
-
+from .typing.completion_chunk import CompletionChoice
+from .typing.events import (
+    CompletionChunkEvent,
+    CompletionEvent,
+    LLMStreamingErrorData,
+    LLMStreamingErrorEvent,
 )
-from .typing.events import CompletionChunkEvent, CompletionEvent
 from .typing.message import AssistantMessage, Messages
 from .typing.tool import BaseTool, ToolChoice
 
 logger = logging.getLogger(__name__)
 
 
-
-
-
-
-
-
-
-
-
-
-API_PROVIDERS: dict[APIProviderName, APIProvider] = {
-    "openai": APIProvider(
-        name="openai",
-        base_url="https://api.openai.com/v1",
-        api_key=os.getenv("OPENAI_API_KEY"),
-        struct_outputs_support=("*",),
-    ),
-    "openrouter": APIProvider(
-        name="openrouter",
-        base_url="https://openrouter.ai/api/v1",
-        api_key=os.getenv("OPENROUTER_API_KEY"),
-        struct_outputs_support=(),
-    ),
-    "google_ai_studio": APIProvider(
-        name="google_ai_studio",
-        base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
-        api_key=os.getenv("GOOGLE_AI_STUDIO_API_KEY"),
-        struct_outputs_support=("*",),
-    ),
-}
-
-
-def retry_error_callback(retry_state: RetryCallState) -> Completion:
-    exception = retry_state.outcome.exception() if retry_state.outcome else None
-    if exception:
-        if retry_state.attempt_number == 1:
-            logger.warning(
-                f"\nCloudLLM completion request failed:\n{exception}",
-                # exc_info=exception,
-            )
-        if retry_state.attempt_number > 1:
-            logger.warning(
-                f"\nCloudLLM completion request failed after retrying:\n{exception}",
-                # exc_info=exception,
-            )
-    failed_message = AssistantMessage(content=None, refusal=str(exception))
+class APIProvider(TypedDict, total=False):
+    name: Required[str]
+    base_url: str | None
+    api_key: str | None
+    # Wildcard patterns for model names that support response schema validation:
+    response_schema_support: tuple[str, ...] | None
+
+
+def make_refusal_completion(model_name: str, err: BaseException) -> Completion:
+    failed_message = AssistantMessage(content=None, refusal=str(err))
 
     return Completion(
-        model=
+        model=model_name,
         choices=[CompletionChoice(message=failed_message, finish_reason=None, index=0)],
     )
 
 
-
-
-
-
-        "\nRetrying CloudLLM completion request "
-        f"(attempt {retry_state.attempt_number}):\n{exception}"
-    )
+class CloudLLMSettings(LLMSettings, total=False):
+    extra_headers: dict[str, Any] | None
+    extra_body: object | None
+    extra_query: dict[str, Any] | None
 
 
-
-
+LLMRateLimiter = RateLimiterC[
+    Messages,
+    AssistantMessage
+    | AsyncIterator[CompletionChunkEvent | CompletionEvent | LLMStreamingErrorEvent],
+]
 
 
 class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co]):
@@ -103,25 +61,24 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         self,
         # Base LLM args
         model_name: str,
+        api_provider: APIProvider,
         converters: ConvertT_co,
         llm_settings: SettingsT_co | None = None,
         tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
-
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+        apply_response_schema_via_provider: bool = True,
         model_id: str | None = None,
-        # Custom LLM provider
-        api_provider: APIProvider | None = None,
         # Connection settings
         async_http_client: httpx.AsyncClient | None = None,
         async_http_client_params: (
            dict[str, Any] | AsyncHTTPClientParams | None
         ) = None,
+        max_client_retries: int = 2,
         # Rate limiting
-        rate_limiter:
-
-
-        rate_limiter_max_concurrency: int = 300,
-        # Retries
-        num_generation_retries: int = 0,
+        rate_limiter: LLMRateLimiter | None = None,
+        # LLM response retries: try to regenerate to pass validation
+        max_response_retries: int = 0,
         **kwargs: Any,
     ) -> None:
         self.llm_settings: CloudLLMSettings | None
@@ -132,57 +89,30 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             converters=converters,
             model_id=model_id,
             tools=tools,
-
+            response_schema=response_schema,
+            response_schema_by_xml_tag=response_schema_by_xml_tag,
             **kwargs,
         )
 
         self._model_name = model_name
-
-
-        if len(model_name_parts) == 2:
-            api_provider_name, api_model_name = model_name_parts
-            self._api_model_name: str = api_model_name
-            if api_provider_name not in API_PROVIDERS:
-                raise ValueError(
-                    f"API provider '{api_provider_name}' is not supported. "
-                    f"Supported providers are: {', '.join(API_PROVIDERS.keys())}"
-                )
-            _api_provider = API_PROVIDERS[api_provider_name]
-        elif api_provider is not None:
-            self._api_model_name: str = model_name
-            _api_provider = api_provider
-        else:
-            raise ValueError(
-                "API provider must be specified either in the model name "
-                "or as a separate argument."
-            )
-
-        self._api_provider_name: APIProviderName = _api_provider["name"]
-        self._base_url: str | None = _api_provider.get("base_url")
-        self._api_key: str | None = _api_provider.get("api_key")
-        self._struct_outputs_support: bool = any(
-            fnmatch.fnmatch(self._model_name, pat)
-            for pat in _api_provider.get("struct_outputs_support", ())
-        )
+        self._api_provider = api_provider
+        self._apply_response_schema_via_provider = apply_response_schema_via_provider
 
         if (
-
-            and not
+            apply_response_schema_via_provider
+            and response_schema_by_xml_tag is not None
         ):
             raise ValueError(
-
+                "Response schema by XML tag is not supported "
+                "when apply_response_schema_via_provider is True."
             )
 
-        self.
-
-
-
-            rate_limiter
-            rpm=rate_limiter_rpm,
-            chunk_size=rate_limiter_chunk_size,
-            max_concurrency=rate_limiter_max_concurrency,
+        self._rate_limiter: LLMRateLimiter | None = None
+        if rate_limiter is not None:
+            self._rate_limiter = rate_limiter
+            logger.info(
+                f"[{self.__class__.__name__}] Set rate limit to {rate_limiter.rpm} RPM"
             )
-        )
 
         self._async_http_client: httpx.AsyncClient | None = None
         if async_http_client is not None:
@@ -192,18 +122,31 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
                 async_http_client_params
             )
 
-        self.
+        self.max_client_retries = max_client_retries
+        self.max_response_retries = max_response_retries
 
     @property
-    def
-        return self.
+    def api_provider(self) -> APIProvider:
+        return self._api_provider
 
     @property
-    def rate_limiter(
-        self,
-    ) -> RateLimiterC[Messages, AssistantMessage] | None:
+    def rate_limiter(self) -> LLMRateLimiter | None:
         return self._rate_limiter
 
+    @property
+    def tools(self) -> dict[str, BaseTool[BaseModel, Any, Any]] | None:
+        return self._tools
+
+    @tools.setter
+    def tools(self, tools: Sequence[BaseTool[BaseModel, Any, Any]] | None) -> None:
+        if not tools:
+            self._tools = None
+            return
+        strict_value = True if self._apply_response_schema_via_provider else None
+        for t in tools:
+            t.strict = strict_value
+        self._tools = {t.name: t for t in tools}
+
     def _make_completion_kwargs(
         self,
         conversation: Messages,
@@ -215,21 +158,17 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         api_tools = None
         api_tool_choice = None
         if self.tools:
-            api_tools = [
-                self._converters.to_tool(t, **self._tool_call_settings)
-                for t in self.tools.values()
-            ]
+            api_tools = [self._converters.to_tool(t) for t in self.tools.values()]
             if tool_choice is not None:
                 api_tool_choice = self._converters.to_tool_choice(tool_choice)
 
         api_llm_settings = deepcopy(self.llm_settings or {})
-        api_llm_settings.pop("use_struct_outputs", None)
 
         return dict(
             api_messages=api_messages,
             api_tools=api_tools,
             api_tool_choice=api_tool_choice,
-
+            api_response_schema=self._response_schema,
             n_choices=n_choices,
             **api_llm_settings,
         )
@@ -241,19 +180,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         *,
         api_tools: list[Any] | None = None,
         api_tool_choice: Any | None = None,
-
-        **api_llm_settings: Any,
-    ) -> Any:
-        pass
-
-    @abstractmethod
-    async def _get_parsed_completion(
-        self,
-        api_messages: list[Any],
-        *,
-        api_tools: list[Any] | None = None,
-        api_tool_choice: Any | None = None,
-        api_response_format: type | None = None,
+        api_response_schema: type | None = None,
         n_choices: int | None = None,
         **api_llm_settings: Any,
     ) -> Any:
@@ -266,25 +193,14 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         *,
         api_tools: list[Any] | None = None,
         api_tool_choice: Any | None = None,
+        api_response_schema: type | None = None,
         n_choices: int | None = None,
         **api_llm_settings: Any,
     ) -> AsyncIterator[Any]:
         pass
 
-    @
-    async def
-        self,
-        api_messages: list[Any],
-        *,
-        api_tools: list[Any] | None = None,
-        api_tool_choice: Any | None = None,
-        api_response_format: type | None = None,
-        n_choices: int | None = None,
-        **api_llm_settings: Any,
-    ) -> AsyncIterator[Any]:
-        pass
-
-    async def generate_completion_no_retry(
+    @limit_rate
+    async def _generate_completion_once(
         self,
         conversation: Messages,
         *,
@@ -295,98 +211,155 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
         )
 
-        if not self.
-            completion_kwargs.pop("
-
-        else:
-            api_completion = await self._get_parsed_completion(**completion_kwargs)
+        if not self._apply_response_schema_via_provider:
+            completion_kwargs.pop("api_response_schema", None)
+        api_completion = await self._get_completion(**completion_kwargs)
 
         completion = self._converters.from_completion(
             api_completion, name=self.model_id
         )
 
-        if not self.
-
-            # of the LLM provider
-            self._validate_completion(completion)
+        if not self._apply_response_schema_via_provider:
+            self._validate_response(completion)
             self._validate_tool_calls(completion)
 
         return completion
 
-    async def
+    async def generate_completion(
+        self,
+        conversation: Messages,
+        *,
+        tool_choice: ToolChoice | None = None,
+        n_choices: int | None = None,
+        proc_name: str | None = None,
+        call_id: str | None = None,
+    ) -> Completion:
+        n_attempt = 0
+        while n_attempt <= self.max_response_retries:
+            try:
+                return await self._generate_completion_once(
+                    conversation,  # type: ignore[return]
+                    tool_choice=tool_choice,
+                    n_choices=n_choices,
+                )
+            except (LLMResponseValidationError, LLMToolCallValidationError) as err:
+                n_attempt += 1
+
+                if n_attempt > self.max_response_retries:
+                    if n_attempt == 1:
+                        logger.warning(f"\nCloudLLM completion request failed:\n{err}")
+                    if n_attempt > 1:
+                        logger.warning(
+                            f"\nCloudLLM completion request failed after retrying:\n{err}"
+                        )
+                    raise err
+                    # return make_refusal_completion(self._model_name, err)
+
+                logger.warning(
+                    f"\nCloudLLM completion request failed (retry attempt {n_attempt}):"
+                    f"\n{err}"
+                )
+
+        return make_refusal_completion(
+            self._model_name,
+            Exception("Unexpected error: retry loop exited without returning"),
+        )
+
+    @limit_rate
+    async def _generate_completion_stream_once(
         self,
         conversation: Messages,
         *,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
+        proc_name: str | None = None,
+        call_id: str | None = None,
     ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
         completion_kwargs = self._make_completion_kwargs(
             conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
         )
+        if not self._apply_response_schema_via_provider:
+            completion_kwargs.pop("api_response_schema", None)
+
+        api_stream = self._get_completion_stream(**completion_kwargs)
+        api_stream = cast("AsyncIterator[Any]", api_stream)
 
-
-
-            api_stream = await self._get_completion_stream(**completion_kwargs)
-        else:
-            api_stream = await self._get_parsed_completion_stream(**completion_kwargs)
+        async def iterator() -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
+            api_completion_chunks: list[Any] = []
 
-        async def iterate() -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
-            completion_chunks: list[CompletionChunk] = []
             async for api_completion_chunk in api_stream:
+                api_completion_chunks.append(api_completion_chunk)
                 completion_chunk = self._converters.from_completion_chunk(
                     api_completion_chunk, name=self.model_id
                 )
-                completion_chunks.append(completion_chunk)
-                yield CompletionChunkEvent(data=completion_chunk, name=self.model_id)
 
-
-
+                yield CompletionChunkEvent(
+                    data=completion_chunk, proc_name=proc_name, call_id=call_id
+                )
+
+            api_completion = self.combine_completion_chunks(api_completion_chunks)
+            completion = self._converters.from_completion(
+                api_completion, name=self.model_id
+            )
 
-            yield CompletionEvent(data=completion,
+            yield CompletionEvent(data=completion, proc_name=proc_name, call_id=call_id)
 
-            if not self.
-
-                # of the LLM provider
-                self._validate_completion(completion)
+            if not self._apply_response_schema_via_provider:
+                self._validate_response(completion)
                 self._validate_tool_calls(completion)
 
-        return
+        return iterator()
 
-
-    async def generate_completion(  # type: ignore[override]
+    async def generate_completion_stream(  # type: ignore[override]
         self,
         conversation: Messages,
         *,
         tool_choice: ToolChoice | None = None,
         n_choices: int | None = None,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        )
-        return rate_limiter
-    if rpm is not None:
-        logger.info(f"[{self.__class__.__name__}] Set rate limit to {rpm} RPM")
-    return RateLimiterC(
-        rpm=rpm, chunk_size=chunk_size, max_concurrency=max_concurrency
-    )
+        proc_name: str | None = None,
+        call_id: str | None = None,
+    ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent | LLMStreamingErrorEvent]:
+        n_attempt = 0
+        while n_attempt <= self.max_response_retries:
+            try:
+                async for event in await self._generate_completion_stream_once(  # type: ignore[return]
+                    conversation,  # type: ignore[arg-type]
+                    tool_choice=tool_choice,
+                    n_choices=n_choices,
+                    proc_name=proc_name,
+                    call_id=call_id,
+                ):
+                    yield event
+                return
+            except (LLMResponseValidationError, LLMToolCallValidationError) as err:
+                err_data = LLMStreamingErrorData(
+                    error=err, model_name=self._model_name, model_id=self.model_id
+                )
+                yield LLMStreamingErrorEvent(
+                    data=err_data, proc_name=proc_name, call_id=call_id
+                )
 
-
+                n_attempt += 1
+                if n_attempt > self.max_response_retries:
+                    if n_attempt == 1:
+                        logger.warning(f"\nCloudLLM completion request failed:\n{err}")
+                    if n_attempt > 1:
+                        logger.warning(
+                            "\nCloudLLM completion request failed after "
+                            f"retrying:\n{err}"
+                        )
+                    refusal_completion = make_refusal_completion(
+                        self._model_name, err
+                    )
+                    yield CompletionEvent(
+                        data=refusal_completion,
+                        proc_name=proc_name,
+                        call_id=call_id,
+                    )
+                    raise err
+                    # return
+
+                logger.warning(
+                    "\nCloudLLM completion request failed "
+                    f"(retry attempt {n_attempt}):\n{err}"
+                )
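The 0.5.0 changes above drop the hard-coded API_PROVIDERS table and the old "provider/model" name parsing: callers now build an APIProvider mapping themselves and pass it to the constructor, together with the new response-schema and retry options. Below is a minimal usage sketch against that new surface, assuming OpenAILLM (the concrete subclass in grasp_agents/openai/openai_llm.py) forwards these CloudLLM keyword arguments unchanged; the import of OpenAILLM, the model name, and the settings values are illustrative assumptions, not taken from the package.

import os

from grasp_agents.cloud_llm import APIProvider
from grasp_agents.openai.openai_llm import OpenAILLM  # assumed concrete CloudLLM subclass

# APIProvider is now a caller-constructed TypedDict; only "name" is required.
provider: APIProvider = {
    "name": "openai",
    "base_url": "https://api.openai.com/v1",
    "api_key": os.getenv("OPENAI_API_KEY"),
    "response_schema_support": ("*",),  # wildcard patterns for models with schema support
}

llm = OpenAILLM(
    model_name="gpt-4o-mini",   # illustrative model name
    api_provider=provider,      # replaces the old "provider/model" string parsing
    max_client_retries=2,       # HTTP client retries (new in 0.5.0)
    max_response_retries=1,     # regenerate if response/tool-call validation fails
)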