grasp_agents 0.4.7.tar.gz → 0.5.1.tar.gz
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/PKG-INFO +7 -6
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/README.md +5 -5
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/pyproject.toml +2 -1
- grasp_agents-0.5.1/src/grasp_agents/cloud_llm.py +365 -0
- grasp_agents-0.5.1/src/grasp_agents/comm_processor.py +214 -0
- grasp_agents-0.5.1/src/grasp_agents/errors.py +94 -0
- grasp_agents-0.5.1/src/grasp_agents/litellm/__init__.py +106 -0
- grasp_agents-0.5.1/src/grasp_agents/litellm/completion_chunk_converters.py +68 -0
- grasp_agents-0.5.1/src/grasp_agents/litellm/completion_converters.py +72 -0
- grasp_agents-0.5.1/src/grasp_agents/litellm/converters.py +138 -0
- grasp_agents-0.5.1/src/grasp_agents/litellm/lite_llm.py +210 -0
- grasp_agents-0.5.1/src/grasp_agents/litellm/message_converters.py +66 -0
- grasp_agents-0.5.1/src/grasp_agents/llm.py +185 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/llm_agent.py +136 -120
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/llm_agent_memory.py +3 -3
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/llm_policy_executor.py +167 -174
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/memory.py +23 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/openai/__init__.py +24 -9
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/openai/completion_chunk_converters.py +6 -6
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/openai/completion_converters.py +12 -14
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/openai/content_converters.py +1 -3
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/openai/converters.py +6 -8
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/openai/message_converters.py +21 -3
- grasp_agents-0.5.1/src/grasp_agents/openai/openai_llm.py +283 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/openai/tool_converters.py +4 -6
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/packet.py +5 -2
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/packet_pool.py +14 -13
- grasp_agents-0.5.1/src/grasp_agents/printer.py +330 -0
- grasp_agents-0.5.1/src/grasp_agents/processor.py +467 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/prompt_builder.py +2 -2
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/run_context.py +11 -20
- grasp_agents-0.5.1/src/grasp_agents/runner.py +42 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/typing/completion.py +16 -9
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/typing/completion_chunk.py +51 -22
- grasp_agents-0.5.1/src/grasp_agents/typing/events.py +162 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/typing/message.py +25 -1
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/typing/tool.py +2 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/usage_tracker.py +31 -37
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/utils.py +95 -84
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/workflow/looped_workflow.py +60 -11
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/workflow/sequential_workflow.py +43 -11
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/workflow/workflow_processor.py +25 -24
- grasp_agents-0.4.7/src/grasp_agents/cloud_llm.py +0 -398
- grasp_agents-0.4.7/src/grasp_agents/comm_processor.py +0 -213
- grasp_agents-0.4.7/src/grasp_agents/errors.py +0 -34
- grasp_agents-0.4.7/src/grasp_agents/llm.py +0 -150
- grasp_agents-0.4.7/src/grasp_agents/openai/openai_llm.py +0 -231
- grasp_agents-0.4.7/src/grasp_agents/printer.py +0 -170
- grasp_agents-0.4.7/src/grasp_agents/processor.py +0 -329
- grasp_agents-0.4.7/src/grasp_agents/typing/events.py +0 -86
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/.gitignore +0 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/LICENSE.md +0 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/__init__.py +0 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/costs_dict.yaml +0 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/generics_utils.py +0 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/grasp_logging.py +0 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/http_client.py +0 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/rate_limiting/__init__.py +0 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/rate_limiting/rate_limiter_chunked.py +0 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/rate_limiting/types.py +0 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/rate_limiting/utils.py +0 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/typing/__init__.py +0 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/typing/content.py +0 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/typing/converters.py +0 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/typing/io.py +0 -0
- {grasp_agents-0.4.7 → grasp_agents-0.5.1}/src/grasp_agents/workflow/__init__.py +0 -0
--- grasp_agents-0.4.7/PKG-INFO
+++ grasp_agents-0.5.1/PKG-INFO
@@ -1,11 +1,12 @@
 Metadata-Version: 2.4
 Name: grasp_agents
-Version: 0.4.7
+Version: 0.5.1
 Summary: Grasp Agents Library
 License-File: LICENSE.md
 Requires-Python: <4,>=3.11.4
 Requires-Dist: dotenv>=0.9.9
 Requires-Dist: httpx<1,>=0.27.0
+Requires-Dist: litellm>=1.74.4
 Requires-Dist: openai<2,>=1.68.2
 Requires-Dist: pydantic>=2
 Requires-Dist: pyyaml>=6.0.2
@@ -103,7 +104,7 @@ Ensure you have a `.env` file with your OpenAI and Google AI Studio API keys set
 
 ```
 OPENAI_API_KEY=your_openai_api_key
-
+GEMINI_API_KEY=your_gemini_api_key
 ```
 
 Create a script, e.g., `problem_recommender.py`:
@@ -117,7 +118,7 @@ from dotenv import load_dotenv
 from pydantic import BaseModel, Field
 
 from grasp_agents.grasp_logging import setup_logging
-from grasp_agents.
+from grasp_agents.litellm import LiteLLM, LiteLLMSettings
 from grasp_agents import LLMAgent, BaseTool, RunContext
 
 load_dotenv()
@@ -178,9 +179,9 @@ class Problem(BaseModel):
 
 teacher = LLMAgent[None, Problem, None](
     name="teacher",
-    llm=
-        model_name="
-        llm_settings=
+    llm=LiteLLM(
+        model_name="gpt-4.1",
+        llm_settings=LiteLLMSettings(temperature=0.5),
     ),
     tools=[AskStudentTool()],
     react_mode=True,
--- grasp_agents-0.4.7/README.md
+++ grasp_agents-0.5.1/README.md
@@ -87,7 +87,7 @@ Ensure you have a `.env` file with your OpenAI and Google AI Studio API keys set
 
 ```
 OPENAI_API_KEY=your_openai_api_key
-
+GEMINI_API_KEY=your_gemini_api_key
 ```
 
 Create a script, e.g., `problem_recommender.py`:
@@ -101,7 +101,7 @@ from dotenv import load_dotenv
 from pydantic import BaseModel, Field
 
 from grasp_agents.grasp_logging import setup_logging
-from grasp_agents.
+from grasp_agents.litellm import LiteLLM, LiteLLMSettings
 from grasp_agents import LLMAgent, BaseTool, RunContext
 
 load_dotenv()
@@ -162,9 +162,9 @@ class Problem(BaseModel):
 
 teacher = LLMAgent[None, Problem, None](
     name="teacher",
-    llm=
-        model_name="
-        llm_settings=
+    llm=LiteLLM(
+        model_name="gpt-4.1",
+        llm_settings=LiteLLMSettings(temperature=0.5),
     ),
     tools=[AskStudentTool()],
     react_mode=True,
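Both the PKG-INFO and README hunks above make the same quickstart change: the example agent is now built on the new LiteLLM backend instead of a provider-specific class. Pieced together from the changed lines, the migrated construction looks roughly like this (the `Problem` model and `AskStudentTool` are defined earlier in the README and are not part of this diff):

```
from dotenv import load_dotenv

from grasp_agents import LLMAgent
from grasp_agents.litellm import LiteLLM, LiteLLMSettings

load_dotenv()  # expects OPENAI_API_KEY (and GEMINI_API_KEY for Gemini models)

teacher = LLMAgent[None, Problem, None](
    name="teacher",
    llm=LiteLLM(
        model_name="gpt-4.1",
        llm_settings=LiteLLMSettings(temperature=0.5),
    ),
    tools=[AskStudentTool()],
    react_mode=True,
)
```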
--- grasp_agents-0.4.7/pyproject.toml
+++ grasp_agents-0.5.1/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "grasp_agents"
-version = "0.4.7"
+version = "0.5.1"
 description = "Grasp Agents Library"
 readme = "README.md"
 requires-python = ">=3.11.4,<4"
@@ -13,6 +13,7 @@ dependencies = [
     "dotenv>=0.9.9",
     "pyyaml>=6.0.2",
     "pydantic>=2",
+    "litellm>=1.74.4",
 ]
 
 [dependency-groups]
--- /dev/null
+++ grasp_agents-0.5.1/src/grasp_agents/cloud_llm.py
@@ -0,0 +1,365 @@
+import logging
+from abc import abstractmethod
+from collections.abc import AsyncIterator, Mapping, Sequence
+from copy import deepcopy
+from typing import Any, Generic, Required, cast
+
+import httpx
+from pydantic import BaseModel
+from typing_extensions import TypedDict
+
+from .errors import LLMResponseValidationError, LLMToolCallValidationError
+from .http_client import AsyncHTTPClientParams, create_simple_async_httpx_client
+from .llm import LLM, ConvertT_co, LLMSettings, SettingsT_co
+from .rate_limiting.rate_limiter_chunked import RateLimiterC, limit_rate
+from .typing.completion import Completion
+from .typing.completion_chunk import CompletionChoice
+from .typing.events import (
+    CompletionChunkEvent,
+    CompletionEvent,
+    LLMStreamingErrorData,
+    LLMStreamingErrorEvent,
+)
+from .typing.message import AssistantMessage, Messages
+from .typing.tool import BaseTool, ToolChoice
+
+logger = logging.getLogger(__name__)
+
+
+class APIProvider(TypedDict, total=False):
+    name: Required[str]
+    base_url: str | None
+    api_key: str | None
+    # Wildcard patterns for model names that support response schema validation:
+    response_schema_support: tuple[str, ...] | None
+
+
+def make_refusal_completion(model_name: str, err: BaseException) -> Completion:
+    failed_message = AssistantMessage(content=None, refusal=str(err))
+
+    return Completion(
+        model=model_name,
+        choices=[CompletionChoice(message=failed_message, finish_reason=None, index=0)],
+    )
+
+
+class CloudLLMSettings(LLMSettings, total=False):
+    extra_headers: dict[str, Any] | None
+    extra_body: object | None
+    extra_query: dict[str, Any] | None
+
+
+LLMRateLimiter = RateLimiterC[
+    Messages,
+    AssistantMessage
+    | AsyncIterator[CompletionChunkEvent | CompletionEvent | LLMStreamingErrorEvent],
+]
+
+
+class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co]):
+    def __init__(
+        self,
+        # Base LLM args
+        model_name: str,
+        api_provider: APIProvider,
+        converters: ConvertT_co,
+        llm_settings: SettingsT_co | None = None,
+        tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+        apply_response_schema_via_provider: bool = True,
+        model_id: str | None = None,
+        # Connection settings
+        async_http_client: httpx.AsyncClient | None = None,
+        async_http_client_params: (
+            dict[str, Any] | AsyncHTTPClientParams | None
+        ) = None,
+        max_client_retries: int = 2,
+        # Rate limiting
+        rate_limiter: LLMRateLimiter | None = None,
+        # LLM response retries: try to regenerate to pass validation
+        max_response_retries: int = 0,
+        **kwargs: Any,
+    ) -> None:
+        self.llm_settings: CloudLLMSettings | None
+
+        super().__init__(
+            model_name=model_name,
+            llm_settings=llm_settings,
+            converters=converters,
+            model_id=model_id,
+            tools=tools,
+            response_schema=response_schema,
+            response_schema_by_xml_tag=response_schema_by_xml_tag,
+            **kwargs,
+        )
+
+        self._model_name = model_name
+        self._api_provider = api_provider
+        self._apply_response_schema_via_provider = apply_response_schema_via_provider
+
+        if (
+            apply_response_schema_via_provider
+            and response_schema_by_xml_tag is not None
+        ):
+            raise ValueError(
+                "Response schema by XML tag is not supported "
+                "when apply_response_schema_via_provider is True."
+            )
+
+        self._rate_limiter: LLMRateLimiter | None = None
+        if rate_limiter is not None:
+            self._rate_limiter = rate_limiter
+            logger.info(
+                f"[{self.__class__.__name__}] Set rate limit to {rate_limiter.rpm} RPM"
+            )
+
+        self._async_http_client: httpx.AsyncClient | None = None
+        if async_http_client is not None:
+            self._async_http_client = async_http_client
+        elif async_http_client_params is not None:
+            self._async_http_client = create_simple_async_httpx_client(
+                async_http_client_params
+            )
+
+        self.max_client_retries = max_client_retries
+        self.max_response_retries = max_response_retries
+
+    @property
+    def api_provider(self) -> APIProvider:
+        return self._api_provider
+
+    @property
+    def rate_limiter(self) -> LLMRateLimiter | None:
+        return self._rate_limiter
+
+    @property
+    def tools(self) -> dict[str, BaseTool[BaseModel, Any, Any]] | None:
+        return self._tools
+
+    @tools.setter
+    def tools(self, tools: Sequence[BaseTool[BaseModel, Any, Any]] | None) -> None:
+        if not tools:
+            self._tools = None
+            return
+        strict_value = True if self._apply_response_schema_via_provider else None
+        for t in tools:
+            t.strict = strict_value
+        self._tools = {t.name: t for t in tools}
+
+    def _make_completion_kwargs(
+        self,
+        conversation: Messages,
+        tool_choice: ToolChoice | None = None,
+        n_choices: int | None = None,
+    ) -> dict[str, Any]:
+        api_messages = [self._converters.to_message(m) for m in conversation]
+
+        api_tools = None
+        api_tool_choice = None
+        if self.tools:
+            api_tools = [self._converters.to_tool(t) for t in self.tools.values()]
+            if tool_choice is not None:
+                api_tool_choice = self._converters.to_tool_choice(tool_choice)
+
+        api_llm_settings = deepcopy(self.llm_settings or {})
+
+        return dict(
+            api_messages=api_messages,
+            api_tools=api_tools,
+            api_tool_choice=api_tool_choice,
+            api_response_schema=self._response_schema,
+            n_choices=n_choices,
+            **api_llm_settings,
+        )
+
+    @abstractmethod
+    async def _get_completion(
+        self,
+        api_messages: list[Any],
+        *,
+        api_tools: list[Any] | None = None,
+        api_tool_choice: Any | None = None,
+        api_response_schema: type | None = None,
+        n_choices: int | None = None,
+        **api_llm_settings: Any,
+    ) -> Any:
+        pass
+
+    @abstractmethod
+    async def _get_completion_stream(
+        self,
+        api_messages: list[Any],
+        *,
+        api_tools: list[Any] | None = None,
+        api_tool_choice: Any | None = None,
+        api_response_schema: type | None = None,
+        n_choices: int | None = None,
+        **api_llm_settings: Any,
+    ) -> AsyncIterator[Any]:
+        pass
+
+    @limit_rate
+    async def _generate_completion_once(
+        self,
+        conversation: Messages,
+        *,
+        tool_choice: ToolChoice | None = None,
+        n_choices: int | None = None,
+    ) -> Completion:
+        completion_kwargs = self._make_completion_kwargs(
+            conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
+        )
+
+        if not self._apply_response_schema_via_provider:
+            completion_kwargs.pop("api_response_schema", None)
+        api_completion = await self._get_completion(**completion_kwargs)
+
+        completion = self._converters.from_completion(
+            api_completion, name=self.model_id
+        )
+
+        if not self._apply_response_schema_via_provider:
+            self._validate_response(completion)
+            self._validate_tool_calls(completion)
+
+        return completion
+
+    async def generate_completion(
+        self,
+        conversation: Messages,
+        *,
+        tool_choice: ToolChoice | None = None,
+        n_choices: int | None = None,
+        proc_name: str | None = None,
+        call_id: str | None = None,
+    ) -> Completion:
+        n_attempt = 0
+        while n_attempt <= self.max_response_retries:
+            try:
+                return await self._generate_completion_once(
+                    conversation,  # type: ignore[return]
+                    tool_choice=tool_choice,
+                    n_choices=n_choices,
+                )
+            except (LLMResponseValidationError, LLMToolCallValidationError) as err:
+                n_attempt += 1
+
+                if n_attempt > self.max_response_retries:
+                    if n_attempt == 1:
+                        logger.warning(f"\nCloudLLM completion request failed:\n{err}")
+                    if n_attempt > 1:
+                        logger.warning(
+                            f"\nCloudLLM completion request failed after retrying:\n{err}"
+                        )
+                    raise err
+                    # return make_refusal_completion(self._model_name, err)
+
+                logger.warning(
+                    f"\nCloudLLM completion request failed (retry attempt {n_attempt}):"
+                    f"\n{err}"
+                )
+
+        return make_refusal_completion(
+            self._model_name,
+            Exception("Unexpected error: retry loop exited without returning"),
+        )
+
+    @limit_rate
+    async def _generate_completion_stream_once(
+        self,
+        conversation: Messages,
+        *,
+        tool_choice: ToolChoice | None = None,
+        n_choices: int | None = None,
+        proc_name: str | None = None,
+        call_id: str | None = None,
+    ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
+        completion_kwargs = self._make_completion_kwargs(
+            conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
+        )
+        if not self._apply_response_schema_via_provider:
+            completion_kwargs.pop("api_response_schema", None)
+
+        api_stream = self._get_completion_stream(**completion_kwargs)
+        api_stream = cast("AsyncIterator[Any]", api_stream)
+
+        async def iterator() -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
+            api_completion_chunks: list[Any] = []
+
+            async for api_completion_chunk in api_stream:
+                api_completion_chunks.append(api_completion_chunk)
+                completion_chunk = self._converters.from_completion_chunk(
+                    api_completion_chunk, name=self.model_id
+                )
+
+                yield CompletionChunkEvent(
+                    data=completion_chunk, proc_name=proc_name, call_id=call_id
+                )
+
+            api_completion = self.combine_completion_chunks(api_completion_chunks)
+            completion = self._converters.from_completion(
+                api_completion, name=self.model_id
+            )
+
+            yield CompletionEvent(data=completion, proc_name=proc_name, call_id=call_id)
+
+            if not self._apply_response_schema_via_provider:
+                self._validate_response(completion)
+                self._validate_tool_calls(completion)
+
+        return iterator()
+
+    async def generate_completion_stream(  # type: ignore[override]
+        self,
+        conversation: Messages,
+        *,
+        tool_choice: ToolChoice | None = None,
+        n_choices: int | None = None,
+        proc_name: str | None = None,
+        call_id: str | None = None,
+    ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent | LLMStreamingErrorEvent]:
+        n_attempt = 0
+        while n_attempt <= self.max_response_retries:
+            try:
+                async for event in await self._generate_completion_stream_once(  # type: ignore[return]
+                    conversation,  # type: ignore[arg-type]
+                    tool_choice=tool_choice,
+                    n_choices=n_choices,
+                    proc_name=proc_name,
+                    call_id=call_id,
+                ):
+                    yield event
+                return
+            except (LLMResponseValidationError, LLMToolCallValidationError) as err:
+                err_data = LLMStreamingErrorData(
+                    error=err, model_name=self._model_name, model_id=self.model_id
+                )
+                yield LLMStreamingErrorEvent(
+                    data=err_data, proc_name=proc_name, call_id=call_id
+                )
+
+                n_attempt += 1
+                if n_attempt > self.max_response_retries:
+                    if n_attempt == 1:
+                        logger.warning(f"\nCloudLLM completion request failed:\n{err}")
+                    if n_attempt > 1:
+                        logger.warning(
+                            "\nCloudLLM completion request failed after "
+                            f"retrying:\n{err}"
+                        )
+                    refusal_completion = make_refusal_completion(
+                        self._model_name, err
+                    )
+                    yield CompletionEvent(
+                        data=refusal_completion,
+                        proc_name=proc_name,
+                        call_id=call_id,
+                    )
+                    raise err
+                    # return
+
+                logger.warning(
+                    "\nCloudLLM completion request failed "
+                    f"(retry attempt {n_attempt}):\n{err}"
+                )
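The new `cloud_llm.py` makes `CloudLLM` an abstract base: concrete backends (the `openai_llm.py` and `lite_llm.py` modules listed in this release) supply `_get_completion` and `_get_completion_stream`, while the base class handles message conversion, rate limiting, and validation retries. A minimal sketch of that plug-in contract, where `MyLLM` and its stub bodies are hypothetical:

```
from collections.abc import AsyncIterator
from typing import Any

from grasp_agents.cloud_llm import CloudLLM


class MyLLM(CloudLLM[Any, Any]):  # hypothetical backend, for illustration only
    async def _get_completion(self, api_messages: list[Any], **api_kwargs: Any) -> Any:
        # A real backend would call its provider SDK here and return the raw
        # completion; CloudLLM converts it via self._converters.from_completion().
        raise NotImplementedError

    async def _get_completion_stream(
        self, api_messages: list[Any], **api_kwargs: Any
    ) -> AsyncIterator[Any]:
        # Must be an async generator: CloudLLM iterates it without awaiting,
        # wrapping each raw chunk in a CompletionChunkEvent and emitting a
        # final CompletionEvent built from the combined chunks.
        raise NotImplementedError
        yield  # marks this function as an async generator
```

On top of this contract, `generate_completion()` retries `LLMResponseValidationError` and `LLMToolCallValidationError` up to `max_response_retries` times before re-raising, and the streaming variant surfaces each failure as an `LLMStreamingErrorEvent` before its final raise.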
--- /dev/null
+++ grasp_agents-0.5.1/src/grasp_agents/comm_processor.py
@@ -0,0 +1,214 @@
+import logging
+from collections.abc import AsyncIterator, Sequence
+from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast
+
+from .errors import PacketRoutingError
+from .memory import MemT
+from .packet import Packet
+from .packet_pool import PacketPool
+from .processor import Processor
+from .run_context import CtxT, RunContext
+from .typing.events import Event, ProcPacketOutputEvent, RunResultEvent
+from .typing.io import InT, OutT_co, ProcName
+
+logger = logging.getLogger(__name__)
+
+
+_OutT_contra = TypeVar("_OutT_contra", contravariant=True)
+
+
+class ExitCommunicationHandler(Protocol[_OutT_contra, CtxT]):
+    def __call__(
+        self,
+        out_packet: Packet[_OutT_contra],
+        ctx: RunContext[CtxT],
+    ) -> bool: ...
+
+
+class SetRecipientsHandler(Protocol[_OutT_contra, CtxT]):
+    def __call__(
+        self, out_packet: Packet[_OutT_contra], ctx: RunContext[CtxT]
+    ) -> None: ...
+
+
+class CommProcessor(
+    Processor[InT, OutT_co, MemT, CtxT],
+    Generic[InT, OutT_co, MemT, CtxT],
+):
+    _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
+        0: "_in_type",
+        1: "_out_type",
+    }
+
+    def __init__(
+        self,
+        name: ProcName,
+        *,
+        recipients: Sequence[ProcName] | None = None,
+        packet_pool: PacketPool[CtxT] | None = None,
+        max_retries: int = 0,
+    ) -> None:
+        super().__init__(name=name, max_retries=max_retries)
+
+        self.recipients = recipients or []
+        self._packet_pool = packet_pool
+        self._is_listening = False
+
+        self._exit_communication_impl: (
+            ExitCommunicationHandler[OutT_co, CtxT] | None
+        ) = None
+        self._set_recipients_impl: SetRecipientsHandler[OutT_co, CtxT] | None = None
+
+    @property
+    def packet_pool(self) -> PacketPool[CtxT] | None:
+        return self._packet_pool
+
+    @property
+    def is_listening(self) -> bool:
+        return self._is_listening
+
+    def _set_recipients(
+        self, out_packet: Packet[OutT_co], ctx: RunContext[CtxT]
+    ) -> None:
+        if self._set_recipients_impl:
+            self._set_recipients_impl(out_packet=out_packet, ctx=ctx)
+            return
+
+        out_packet.recipients = self.recipients
+
+    def _validate_routing(self, recipients: Sequence[ProcName]) -> Sequence[ProcName]:
+        for r in recipients:
+            if r not in self.recipients:
+                raise PacketRoutingError(
+                    selected_recipient=r,
+                    allowed_recipients=cast("list[str]", self.recipients),
+                )
+
+        return self.recipients
+
+    async def run(
+        self,
+        chat_inputs: Any | None = None,
+        *,
+        in_packet: Packet[InT] | None = None,
+        in_args: InT | Sequence[InT] | None = None,
+        forgetful: bool = False,
+        call_id: str | None = None,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> Packet[OutT_co]:
+        call_id = self._generate_call_id(call_id)
+
+        out_packet = await super().run(
+            chat_inputs=chat_inputs,
+            in_packet=in_packet,
+            in_args=in_args,
+            forgetful=forgetful,
+            call_id=call_id,
+            ctx=ctx,
+        )
+
+        if self._packet_pool is not None:
+            if ctx is None:
+                raise ValueError("RunContext must be provided when using PacketPool")
+            if self._exit_communication(out_packet=out_packet, ctx=ctx):
+                ctx.result = out_packet
+                await self._packet_pool.stop_all()
+                return out_packet
+
+            self._set_recipients(out_packet=out_packet, ctx=ctx)
+            out_packet.recipients = self._validate_routing(out_packet.recipients)
+
+            await self._packet_pool.post(out_packet)
+
+        return out_packet
+
+    async def run_stream(
+        self,
+        chat_inputs: Any | None = None,
+        *,
+        in_packet: Packet[InT] | None = None,
+        in_args: InT | Sequence[InT] | None = None,
+        forgetful: bool = False,
+        call_id: str | None = None,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> AsyncIterator[Event[Any]]:
+        call_id = self._generate_call_id(call_id)
+
+        out_packet: Packet[OutT_co] | None = None
+        async for event in super().run_stream(
+            chat_inputs=chat_inputs,
+            in_packet=in_packet,
+            in_args=in_args,
+            forgetful=forgetful,
+            call_id=call_id,
+            ctx=ctx,
+        ):
+            if isinstance(event, ProcPacketOutputEvent):
+                out_packet = event.data
+            else:
+                yield event
+
+        if out_packet is None:
+            return
+
+        if self._packet_pool is not None:
+            if ctx is None:
+                raise ValueError("RunContext must be provided when using PacketPool")
+            if self._exit_communication(out_packet=out_packet, ctx=ctx):
+                ctx.result = out_packet
+                yield RunResultEvent(
+                    data=out_packet, proc_name=self.name, call_id=call_id
+                )
+                await self._packet_pool.stop_all()
+                return
+
+            self._set_recipients(out_packet=out_packet, ctx=ctx)
+            out_packet.recipients = self._validate_routing(out_packet.recipients)
+
+            await self._packet_pool.post(out_packet)
+
+        yield ProcPacketOutputEvent(
+            data=out_packet, proc_name=self.name, call_id=call_id
+        )
+
+    def start_listening(self, ctx: RunContext[CtxT], **run_kwargs: Any) -> None:
+        if self._packet_pool is None:
+            raise RuntimeError("Packet pool must be initialized before listening")
+
+        if self._is_listening:
+            return
+        self._is_listening = True
+
+        self._packet_pool.register_packet_handler(
+            processor_name=self.name,
+            handler=self.run_stream if ctx.is_streaming else self.run,  # type: ignore[call-arg]
+            ctx=ctx,
+            **run_kwargs,
+        )
+
+    def _exit_communication(
+        self, out_packet: Packet[OutT_co], ctx: RunContext[CtxT]
+    ) -> bool:
+        if self._exit_communication_impl:
+            return self._exit_communication_impl(out_packet=out_packet, ctx=ctx)
+
+        return False
+
+    def exit_communication(
+        self, func: ExitCommunicationHandler[OutT_co, CtxT]
+    ) -> ExitCommunicationHandler[OutT_co, CtxT]:
+        self._exit_communication_impl = func
+
+        return func
+
+    def set_recipients(
+        self, func: SetRecipientsHandler[OutT_co, CtxT]
+    ) -> SetRecipientsHandler[OutT_co, CtxT]:
+        self._set_recipients_impl = func
+
+        return func
+
+    # async def stop_listening(self) -> None:
+    #     assert self._packet_pool is not None
+    #     self._is_listening = False
+    #     await self._packet_pool.unregister_packet_handler(self.name)
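`exit_communication` and `set_recipients` are decorator-style hooks: after each run, `CommProcessor` consults the registered handlers to decide whether to stop the packet pool (recording the packet as `ctx.result`) or where to route the output. A minimal sketch, assuming `proc` is an already-constructed `CommProcessor` and using placeholder logic in both handlers:

```
from typing import Any

from grasp_agents.packet import Packet
from grasp_agents.run_context import RunContext

# `proc` is assumed to be a CommProcessor instance constructed elsewhere,
# e.g. with recipients=["reviewer", "editor"] and a shared PacketPool.

@proc.exit_communication
def should_exit(out_packet: Packet[Any], ctx: RunContext[Any]) -> bool:
    # Returning True stores out_packet as ctx.result and stops all processors.
    return False  # placeholder rule; real logic would inspect out_packet

@proc.set_recipients
def route(out_packet: Packet[Any], ctx: RunContext[Any]) -> None:
    # Recipients set here are checked by _validate_routing: anything not in
    # proc.recipients raises a PacketRoutingError.
    out_packet.recipients = list(proc.recipients)[:1]
```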