grasp_agents 0.4.7__tar.gz → 0.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/PKG-INFO +7 -6
  2. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/README.md +5 -5
  3. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/pyproject.toml +2 -1
  4. grasp_agents-0.5.0/src/grasp_agents/cloud_llm.py +365 -0
  5. grasp_agents-0.5.0/src/grasp_agents/comm_processor.py +214 -0
  6. grasp_agents-0.5.0/src/grasp_agents/errors.py +94 -0
  7. grasp_agents-0.5.0/src/grasp_agents/litellm/__init__.py +106 -0
  8. grasp_agents-0.5.0/src/grasp_agents/litellm/completion_chunk_converters.py +68 -0
  9. grasp_agents-0.5.0/src/grasp_agents/litellm/completion_converters.py +72 -0
  10. grasp_agents-0.5.0/src/grasp_agents/litellm/converters.py +138 -0
  11. grasp_agents-0.5.0/src/grasp_agents/litellm/lite_llm.py +210 -0
  12. grasp_agents-0.5.0/src/grasp_agents/litellm/message_converters.py +66 -0
  13. grasp_agents-0.5.0/src/grasp_agents/llm.py +185 -0
  14. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/llm_agent.py +136 -120
  15. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/llm_agent_memory.py +3 -3
  16. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/llm_policy_executor.py +167 -174
  17. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/memory.py +4 -0
  18. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/openai/__init__.py +24 -9
  19. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/openai/completion_chunk_converters.py +6 -6
  20. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/openai/completion_converters.py +12 -14
  21. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/openai/content_converters.py +1 -3
  22. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/openai/converters.py +6 -8
  23. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/openai/message_converters.py +21 -3
  24. grasp_agents-0.5.0/src/grasp_agents/openai/openai_llm.py +283 -0
  25. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/openai/tool_converters.py +4 -6
  26. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/packet.py +5 -2
  27. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/packet_pool.py +14 -13
  28. grasp_agents-0.5.0/src/grasp_agents/printer.py +332 -0
  29. grasp_agents-0.5.0/src/grasp_agents/processor.py +469 -0
  30. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/prompt_builder.py +2 -2
  31. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/run_context.py +11 -20
  32. grasp_agents-0.5.0/src/grasp_agents/runner.py +42 -0
  33. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/typing/completion.py +16 -9
  34. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/typing/completion_chunk.py +51 -22
  35. grasp_agents-0.5.0/src/grasp_agents/typing/events.py +162 -0
  36. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/typing/message.py +25 -1
  37. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/typing/tool.py +2 -0
  38. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/usage_tracker.py +31 -37
  39. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/utils.py +95 -84
  40. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/workflow/looped_workflow.py +60 -11
  41. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/workflow/sequential_workflow.py +43 -11
  42. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/workflow/workflow_processor.py +25 -24
  43. grasp_agents-0.4.7/src/grasp_agents/cloud_llm.py +0 -398
  44. grasp_agents-0.4.7/src/grasp_agents/comm_processor.py +0 -213
  45. grasp_agents-0.4.7/src/grasp_agents/errors.py +0 -34
  46. grasp_agents-0.4.7/src/grasp_agents/llm.py +0 -150
  47. grasp_agents-0.4.7/src/grasp_agents/openai/openai_llm.py +0 -231
  48. grasp_agents-0.4.7/src/grasp_agents/printer.py +0 -170
  49. grasp_agents-0.4.7/src/grasp_agents/processor.py +0 -329
  50. grasp_agents-0.4.7/src/grasp_agents/typing/events.py +0 -86
  51. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/.gitignore +0 -0
  52. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/LICENSE.md +0 -0
  53. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/__init__.py +0 -0
  54. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/costs_dict.yaml +0 -0
  55. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/generics_utils.py +0 -0
  56. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/grasp_logging.py +0 -0
  57. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/http_client.py +0 -0
  58. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/rate_limiting/__init__.py +0 -0
  59. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/rate_limiting/rate_limiter_chunked.py +0 -0
  60. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/rate_limiting/types.py +0 -0
  61. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/rate_limiting/utils.py +0 -0
  62. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/typing/__init__.py +0 -0
  63. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/typing/content.py +0 -0
  64. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/typing/converters.py +0 -0
  65. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/typing/io.py +0 -0
  66. {grasp_agents-0.4.7 → grasp_agents-0.5.0}/src/grasp_agents/workflow/__init__.py +0 -0
--- grasp_agents-0.4.7/PKG-INFO
+++ grasp_agents-0.5.0/PKG-INFO
@@ -1,11 +1,12 @@
 Metadata-Version: 2.4
 Name: grasp_agents
-Version: 0.4.7
+Version: 0.5.0
 Summary: Grasp Agents Library
 License-File: LICENSE.md
 Requires-Python: <4,>=3.11.4
 Requires-Dist: dotenv>=0.9.9
 Requires-Dist: httpx<1,>=0.27.0
+Requires-Dist: litellm>=1.74.4
 Requires-Dist: openai<2,>=1.68.2
 Requires-Dist: pydantic>=2
 Requires-Dist: pyyaml>=6.0.2
@@ -103,7 +104,7 @@ Ensure you have a `.env` file with your OpenAI and Google AI Studio API keys set
 
 ```
 OPENAI_API_KEY=your_openai_api_key
-GOOGLE_AI_STUDIO_API_KEY=your_google_ai_studio_api_key
+GEMINI_API_KEY=your_gemini_api_key
 ```
 
 Create a script, e.g., `problem_recommender.py`:
@@ -117,7 +118,7 @@ from dotenv import load_dotenv
 from pydantic import BaseModel, Field
 
 from grasp_agents.grasp_logging import setup_logging
-from grasp_agents.openai import OpenAILLM, OpenAILLMSettings
+from grasp_agents.litellm import LiteLLM, LiteLLMSettings
 from grasp_agents import LLMAgent, BaseTool, RunContext
 
 load_dotenv()
@@ -178,9 +179,9 @@ class Problem(BaseModel):
 
 teacher = LLMAgent[None, Problem, None](
     name="teacher",
-    llm=OpenAILLM(
-        model_name="openai:gpt-4.1",
-        llm_settings=OpenAILLMSettings(temperature=0.5),
+    llm=LiteLLM(
+        model_name="gpt-4.1",
+        llm_settings=LiteLLMSettings(temperature=0.5),
     ),
     tools=[AskStudentTool()],
     react_mode=True,
--- grasp_agents-0.4.7/README.md
+++ grasp_agents-0.5.0/README.md
@@ -87,7 +87,7 @@ Ensure you have a `.env` file with your OpenAI and Google AI Studio API keys set
 
 ```
 OPENAI_API_KEY=your_openai_api_key
-GOOGLE_AI_STUDIO_API_KEY=your_google_ai_studio_api_key
+GEMINI_API_KEY=your_gemini_api_key
 ```
 
 Create a script, e.g., `problem_recommender.py`:
@@ -101,7 +101,7 @@ from dotenv import load_dotenv
 from pydantic import BaseModel, Field
 
 from grasp_agents.grasp_logging import setup_logging
-from grasp_agents.openai import OpenAILLM, OpenAILLMSettings
+from grasp_agents.litellm import LiteLLM, LiteLLMSettings
 from grasp_agents import LLMAgent, BaseTool, RunContext
 
 load_dotenv()
@@ -162,9 +162,9 @@ class Problem(BaseModel):
 
 teacher = LLMAgent[None, Problem, None](
     name="teacher",
-    llm=OpenAILLM(
-        model_name="openai:gpt-4.1",
-        llm_settings=OpenAILLMSettings(temperature=0.5),
+    llm=LiteLLM(
+        model_name="gpt-4.1",
+        llm_settings=LiteLLMSettings(temperature=0.5),
     ),
     tools=[AskStudentTool()],
     react_mode=True,
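
The PKG-INFO and README hunks above make the same quickstart migration: the OpenAI-specific backend is replaced by LiteLLM, the Google AI Studio key moves to the `GEMINI_API_KEY` variable that LiteLLM reads, and model names drop the `openai:` prefix. A minimal sketch of the migrated construction, with a stand-in `Problem` schema in place of the README's fuller example:

```python
from dotenv import load_dotenv
from pydantic import BaseModel, Field

from grasp_agents import LLMAgent
from grasp_agents.litellm import LiteLLM, LiteLLMSettings

load_dotenv()  # expects OPENAI_API_KEY (and/or GEMINI_API_KEY) in .env


class Problem(BaseModel):
    # Stand-in for the README's richer output schema.
    statement: str = Field(description="Problem statement")


teacher = LLMAgent[None, Problem, None](
    name="teacher",
    llm=LiteLLM(
        # Plain LiteLLM model names: "gpt-4.1" for OpenAI, or e.g.
        # "gemini/gemini-2.0-flash" for Google AI Studio.
        model_name="gpt-4.1",
        llm_settings=LiteLLMSettings(temperature=0.5),
    ),
)
```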
--- grasp_agents-0.4.7/pyproject.toml
+++ grasp_agents-0.5.0/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "grasp_agents"
-version = "0.4.7"
+version = "0.5.0"
 description = "Grasp Agents Library"
 readme = "README.md"
 requires-python = ">=3.11.4,<4"
@@ -13,6 +13,7 @@ dependencies = [
     "dotenv>=0.9.9",
     "pyyaml>=6.0.2",
     "pydantic>=2",
+    "litellm>=1.74.4",
 ]
 
 [dependency-groups]
--- /dev/null
+++ grasp_agents-0.5.0/src/grasp_agents/cloud_llm.py
@@ -0,0 +1,365 @@
+import logging
+from abc import abstractmethod
+from collections.abc import AsyncIterator, Mapping, Sequence
+from copy import deepcopy
+from typing import Any, Generic, Required, cast
+
+import httpx
+from pydantic import BaseModel
+from typing_extensions import TypedDict
+
+from .errors import LLMResponseValidationError, LLMToolCallValidationError
+from .http_client import AsyncHTTPClientParams, create_simple_async_httpx_client
+from .llm import LLM, ConvertT_co, LLMSettings, SettingsT_co
+from .rate_limiting.rate_limiter_chunked import RateLimiterC, limit_rate
+from .typing.completion import Completion
+from .typing.completion_chunk import CompletionChoice
+from .typing.events import (
+    CompletionChunkEvent,
+    CompletionEvent,
+    LLMStreamingErrorData,
+    LLMStreamingErrorEvent,
+)
+from .typing.message import AssistantMessage, Messages
+from .typing.tool import BaseTool, ToolChoice
+
+logger = logging.getLogger(__name__)
+
+
+class APIProvider(TypedDict, total=False):
+    name: Required[str]
+    base_url: str | None
+    api_key: str | None
+    # Wildcard patterns for model names that support response schema validation:
+    response_schema_support: tuple[str, ...] | None
+
+
+def make_refusal_completion(model_name: str, err: BaseException) -> Completion:
+    failed_message = AssistantMessage(content=None, refusal=str(err))
+
+    return Completion(
+        model=model_name,
+        choices=[CompletionChoice(message=failed_message, finish_reason=None, index=0)],
+    )
+
+
+class CloudLLMSettings(LLMSettings, total=False):
+    extra_headers: dict[str, Any] | None
+    extra_body: object | None
+    extra_query: dict[str, Any] | None
+
+
+LLMRateLimiter = RateLimiterC[
+    Messages,
+    AssistantMessage
+    | AsyncIterator[CompletionChunkEvent | CompletionEvent | LLMStreamingErrorEvent],
+]
+
+
+class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co]):
+    def __init__(
+        self,
+        # Base LLM args
+        model_name: str,
+        api_provider: APIProvider,
+        converters: ConvertT_co,
+        llm_settings: SettingsT_co | None = None,
+        tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
+        apply_response_schema_via_provider: bool = True,
+        model_id: str | None = None,
+        # Connection settings
+        async_http_client: httpx.AsyncClient | None = None,
+        async_http_client_params: (
+            dict[str, Any] | AsyncHTTPClientParams | None
+        ) = None,
+        max_client_retries: int = 2,
+        # Rate limiting
+        rate_limiter: LLMRateLimiter | None = None,
+        # LLM response retries: try to regenerate to pass validation
+        max_response_retries: int = 0,
+        **kwargs: Any,
+    ) -> None:
+        self.llm_settings: CloudLLMSettings | None
+
+        super().__init__(
+            model_name=model_name,
+            llm_settings=llm_settings,
+            converters=converters,
+            model_id=model_id,
+            tools=tools,
+            response_schema=response_schema,
+            response_schema_by_xml_tag=response_schema_by_xml_tag,
+            **kwargs,
+        )
+
+        self._model_name = model_name
+        self._api_provider = api_provider
+        self._apply_response_schema_via_provider = apply_response_schema_via_provider
+
+        if (
+            apply_response_schema_via_provider
+            and response_schema_by_xml_tag is not None
+        ):
+            raise ValueError(
+                "Response schema by XML tag is not supported "
+                "when apply_response_schema_via_provider is True."
+            )
+
+        self._rate_limiter: LLMRateLimiter | None = None
+        if rate_limiter is not None:
+            self._rate_limiter = rate_limiter
+            logger.info(
+                f"[{self.__class__.__name__}] Set rate limit to {rate_limiter.rpm} RPM"
+            )
+
+        self._async_http_client: httpx.AsyncClient | None = None
+        if async_http_client is not None:
+            self._async_http_client = async_http_client
+        elif async_http_client_params is not None:
+            self._async_http_client = create_simple_async_httpx_client(
+                async_http_client_params
+            )
+
+        self.max_client_retries = max_client_retries
+        self.max_response_retries = max_response_retries
+
+    @property
+    def api_provider(self) -> APIProvider:
+        return self._api_provider
+
+    @property
+    def rate_limiter(self) -> LLMRateLimiter | None:
+        return self._rate_limiter
+
+    @property
+    def tools(self) -> dict[str, BaseTool[BaseModel, Any, Any]] | None:
+        return self._tools
+
+    @tools.setter
+    def tools(self, tools: Sequence[BaseTool[BaseModel, Any, Any]] | None) -> None:
+        if not tools:
+            self._tools = None
+            return
+        strict_value = True if self._apply_response_schema_via_provider else None
+        for t in tools:
+            t.strict = strict_value
+        self._tools = {t.name: t for t in tools}
+
+    def _make_completion_kwargs(
+        self,
+        conversation: Messages,
+        tool_choice: ToolChoice | None = None,
+        n_choices: int | None = None,
+    ) -> dict[str, Any]:
+        api_messages = [self._converters.to_message(m) for m in conversation]
+
+        api_tools = None
+        api_tool_choice = None
+        if self.tools:
+            api_tools = [self._converters.to_tool(t) for t in self.tools.values()]
+            if tool_choice is not None:
+                api_tool_choice = self._converters.to_tool_choice(tool_choice)
+
+        api_llm_settings = deepcopy(self.llm_settings or {})
+
+        return dict(
+            api_messages=api_messages,
+            api_tools=api_tools,
+            api_tool_choice=api_tool_choice,
+            api_response_schema=self._response_schema,
+            n_choices=n_choices,
+            **api_llm_settings,
+        )
+
+    @abstractmethod
+    async def _get_completion(
+        self,
+        api_messages: list[Any],
+        *,
+        api_tools: list[Any] | None = None,
+        api_tool_choice: Any | None = None,
+        api_response_schema: type | None = None,
+        n_choices: int | None = None,
+        **api_llm_settings: Any,
+    ) -> Any:
+        pass
+
+    @abstractmethod
+    async def _get_completion_stream(
+        self,
+        api_messages: list[Any],
+        *,
+        api_tools: list[Any] | None = None,
+        api_tool_choice: Any | None = None,
+        api_response_schema: type | None = None,
+        n_choices: int | None = None,
+        **api_llm_settings: Any,
+    ) -> AsyncIterator[Any]:
+        pass
+
+    @limit_rate
+    async def _generate_completion_once(
+        self,
+        conversation: Messages,
+        *,
+        tool_choice: ToolChoice | None = None,
+        n_choices: int | None = None,
+    ) -> Completion:
+        completion_kwargs = self._make_completion_kwargs(
+            conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
+        )
+
+        if not self._apply_response_schema_via_provider:
+            completion_kwargs.pop("api_response_schema", None)
+        api_completion = await self._get_completion(**completion_kwargs)
+
+        completion = self._converters.from_completion(
+            api_completion, name=self.model_id
+        )
+
+        if not self._apply_response_schema_via_provider:
+            self._validate_response(completion)
+            self._validate_tool_calls(completion)
+
+        return completion
+
+    async def generate_completion(
+        self,
+        conversation: Messages,
+        *,
+        tool_choice: ToolChoice | None = None,
+        n_choices: int | None = None,
+        proc_name: str | None = None,
+        call_id: str | None = None,
+    ) -> Completion:
+        n_attempt = 0
+        while n_attempt <= self.max_response_retries:
+            try:
+                return await self._generate_completion_once(
+                    conversation,  # type: ignore[return]
+                    tool_choice=tool_choice,
+                    n_choices=n_choices,
+                )
+            except (LLMResponseValidationError, LLMToolCallValidationError) as err:
+                n_attempt += 1
+
+                if n_attempt > self.max_response_retries:
+                    if n_attempt == 1:
+                        logger.warning(f"\nCloudLLM completion request failed:\n{err}")
+                    if n_attempt > 1:
+                        logger.warning(
+                            f"\nCloudLLM completion request failed after retrying:\n{err}"
+                        )
+                    raise err
+                    # return make_refusal_completion(self._model_name, err)
+
+                logger.warning(
+                    f"\nCloudLLM completion request failed (retry attempt {n_attempt}):"
+                    f"\n{err}"
+                )
+
+        return make_refusal_completion(
+            self._model_name,
+            Exception("Unexpected error: retry loop exited without returning"),
+        )
+
+    @limit_rate
+    async def _generate_completion_stream_once(
+        self,
+        conversation: Messages,
+        *,
+        tool_choice: ToolChoice | None = None,
+        n_choices: int | None = None,
+        proc_name: str | None = None,
+        call_id: str | None = None,
+    ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
+        completion_kwargs = self._make_completion_kwargs(
+            conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
+        )
+        if not self._apply_response_schema_via_provider:
+            completion_kwargs.pop("api_response_schema", None)
+
+        api_stream = self._get_completion_stream(**completion_kwargs)
+        api_stream = cast("AsyncIterator[Any]", api_stream)
+
+        async def iterator() -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
+            api_completion_chunks: list[Any] = []
+
+            async for api_completion_chunk in api_stream:
+                api_completion_chunks.append(api_completion_chunk)
+                completion_chunk = self._converters.from_completion_chunk(
+                    api_completion_chunk, name=self.model_id
+                )
+
+                yield CompletionChunkEvent(
+                    data=completion_chunk, proc_name=proc_name, call_id=call_id
+                )
+
+            api_completion = self.combine_completion_chunks(api_completion_chunks)
+            completion = self._converters.from_completion(
+                api_completion, name=self.model_id
+            )
+
+            yield CompletionEvent(data=completion, proc_name=proc_name, call_id=call_id)
+
+            if not self._apply_response_schema_via_provider:
+                self._validate_response(completion)
+                self._validate_tool_calls(completion)
+
+        return iterator()
+
+    async def generate_completion_stream(  # type: ignore[override]
+        self,
+        conversation: Messages,
+        *,
+        tool_choice: ToolChoice | None = None,
+        n_choices: int | None = None,
+        proc_name: str | None = None,
+        call_id: str | None = None,
+    ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent | LLMStreamingErrorEvent]:
+        n_attempt = 0
+        while n_attempt <= self.max_response_retries:
+            try:
+                async for event in await self._generate_completion_stream_once(  # type: ignore[return]
+                    conversation,  # type: ignore[arg-type]
+                    tool_choice=tool_choice,
+                    n_choices=n_choices,
+                    proc_name=proc_name,
+                    call_id=call_id,
+                ):
+                    yield event
+                return
+            except (LLMResponseValidationError, LLMToolCallValidationError) as err:
+                err_data = LLMStreamingErrorData(
+                    error=err, model_name=self._model_name, model_id=self.model_id
+                )
+                yield LLMStreamingErrorEvent(
+                    data=err_data, proc_name=proc_name, call_id=call_id
+                )
+
+                n_attempt += 1
+                if n_attempt > self.max_response_retries:
+                    if n_attempt == 1:
+                        logger.warning(f"\nCloudLLM completion request failed:\n{err}")
+                    if n_attempt > 1:
+                        logger.warning(
+                            "\nCloudLLM completion request failed after "
+                            f"retrying:\n{err}"
+                        )
+                    refusal_completion = make_refusal_completion(
+                        self._model_name, err
+                    )
+                    yield CompletionEvent(
+                        data=refusal_completion,
+                        proc_name=proc_name,
+                        call_id=call_id,
+                    )
+                    raise err
+                    # return
+
+                logger.warning(
+                    "\nCloudLLM completion request failed "
+                    f"(retry attempt {n_attempt}):\n{err}"
+                )
--- /dev/null
+++ grasp_agents-0.5.0/src/grasp_agents/comm_processor.py
@@ -0,0 +1,214 @@
+import logging
+from collections.abc import AsyncIterator, Sequence
+from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast
+
+from .errors import PacketRoutingError
+from .memory import MemT
+from .packet import Packet
+from .packet_pool import PacketPool
+from .processor import Processor
+from .run_context import CtxT, RunContext
+from .typing.events import Event, ProcPacketOutputEvent, RunResultEvent
+from .typing.io import InT, OutT_co, ProcName
+
+logger = logging.getLogger(__name__)
+
+
+_OutT_contra = TypeVar("_OutT_contra", contravariant=True)
+
+
+class ExitCommunicationHandler(Protocol[_OutT_contra, CtxT]):
+    def __call__(
+        self,
+        out_packet: Packet[_OutT_contra],
+        ctx: RunContext[CtxT],
+    ) -> bool: ...
+
+
+class SetRecipientsHandler(Protocol[_OutT_contra, CtxT]):
+    def __call__(
+        self, out_packet: Packet[_OutT_contra], ctx: RunContext[CtxT]
+    ) -> None: ...
+
+
+class CommProcessor(
+    Processor[InT, OutT_co, MemT, CtxT],
+    Generic[InT, OutT_co, MemT, CtxT],
+):
+    _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
+        0: "_in_type",
+        1: "_out_type",
+    }
+
+    def __init__(
+        self,
+        name: ProcName,
+        *,
+        recipients: Sequence[ProcName] | None = None,
+        packet_pool: PacketPool[CtxT] | None = None,
+        max_retries: int = 0,
+    ) -> None:
+        super().__init__(name=name, max_retries=max_retries)
+
+        self.recipients = recipients or []
+        self._packet_pool = packet_pool
+        self._is_listening = False
+
+        self._exit_communication_impl: (
+            ExitCommunicationHandler[OutT_co, CtxT] | None
+        ) = None
+        self._set_recipients_impl: SetRecipientsHandler[OutT_co, CtxT] | None = None
+
+    @property
+    def packet_pool(self) -> PacketPool[CtxT] | None:
+        return self._packet_pool
+
+    @property
+    def is_listening(self) -> bool:
+        return self._is_listening
+
+    def _set_recipients(
+        self, out_packet: Packet[OutT_co], ctx: RunContext[CtxT]
+    ) -> None:
+        if self._set_recipients_impl:
+            self._set_recipients_impl(out_packet=out_packet, ctx=ctx)
+            return
+
+        out_packet.recipients = self.recipients
+
+    def _validate_routing(self, recipients: Sequence[ProcName]) -> Sequence[ProcName]:
+        for r in recipients:
+            if r not in self.recipients:
+                raise PacketRoutingError(
+                    selected_recipient=r,
+                    allowed_recipients=cast("list[str]", self.recipients),
+                )
+
+        return recipients
+
+    async def run(
+        self,
+        chat_inputs: Any | None = None,
+        *,
+        in_packet: Packet[InT] | None = None,
+        in_args: InT | Sequence[InT] | None = None,
+        forgetful: bool = False,
+        call_id: str | None = None,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> Packet[OutT_co]:
+        call_id = self._generate_call_id(call_id)
+
+        out_packet = await super().run(
+            chat_inputs=chat_inputs,
+            in_packet=in_packet,
+            in_args=in_args,
+            forgetful=forgetful,
+            call_id=call_id,
+            ctx=ctx,
+        )
+
+        if self._packet_pool is not None:
+            if ctx is None:
+                raise ValueError("RunContext must be provided when using PacketPool")
+            if self._exit_communication(out_packet=out_packet, ctx=ctx):
+                ctx.result = out_packet
+                await self._packet_pool.stop_all()
+                return out_packet
+
+            self._set_recipients(out_packet=out_packet, ctx=ctx)
+            out_packet.recipients = self._validate_routing(out_packet.recipients)
+
+            await self._packet_pool.post(out_packet)
+
+        return out_packet
+
+    async def run_stream(
+        self,
+        chat_inputs: Any | None = None,
+        *,
+        in_packet: Packet[InT] | None = None,
+        in_args: InT | Sequence[InT] | None = None,
+        forgetful: bool = False,
+        call_id: str | None = None,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> AsyncIterator[Event[Any]]:
+        call_id = self._generate_call_id(call_id)
+
+        out_packet: Packet[OutT_co] | None = None
+        async for event in super().run_stream(
+            chat_inputs=chat_inputs,
+            in_packet=in_packet,
+            in_args=in_args,
+            forgetful=forgetful,
+            call_id=call_id,
+            ctx=ctx,
+        ):
+            if isinstance(event, ProcPacketOutputEvent):
+                out_packet = event.data
+            else:
+                yield event
+
+        if out_packet is None:
+            return
+
+        if self._packet_pool is not None:
+            if ctx is None:
+                raise ValueError("RunContext must be provided when using PacketPool")
+            if self._exit_communication(out_packet=out_packet, ctx=ctx):
+                ctx.result = out_packet
+                yield RunResultEvent(
+                    data=out_packet, proc_name=self.name, call_id=call_id
+                )
+                await self._packet_pool.stop_all()
+                return
+
+            self._set_recipients(out_packet=out_packet, ctx=ctx)
+            out_packet.recipients = self._validate_routing(out_packet.recipients)
+
+            await self._packet_pool.post(out_packet)
+
+        yield ProcPacketOutputEvent(
+            data=out_packet, proc_name=self.name, call_id=call_id
+        )
+
+    def start_listening(self, ctx: RunContext[CtxT], **run_kwargs: Any) -> None:
+        if self._packet_pool is None:
+            raise RuntimeError("Packet pool must be initialized before listening")
+
+        if self._is_listening:
+            return
+        self._is_listening = True
+
+        self._packet_pool.register_packet_handler(
+            processor_name=self.name,
+            handler=self.run_stream if ctx.is_streaming else self.run,  # type: ignore[call-arg]
+            ctx=ctx,
+            **run_kwargs,
+        )
+
+    def _exit_communication(
+        self, out_packet: Packet[OutT_co], ctx: RunContext[CtxT]
+    ) -> bool:
+        if self._exit_communication_impl:
+            return self._exit_communication_impl(out_packet=out_packet, ctx=ctx)
+
+        return False
+
+    def exit_communication(
+        self, func: ExitCommunicationHandler[OutT_co, CtxT]
+    ) -> ExitCommunicationHandler[OutT_co, CtxT]:
+        self._exit_communication_impl = func
+
+        return func
+
+    def set_recipients(
+        self, func: SetRecipientsHandler[OutT_co, CtxT]
+    ) -> SetRecipientsHandler[OutT_co, CtxT]:
+        self._set_recipients_impl = func
+
+        return func
+
+    # async def stop_listening(self) -> None:
+    #     assert self._packet_pool is not None
+    #     self._is_listening = False
+    #     await self._packet_pool.unregister_packet_handler(self.name)
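
`CommProcessor` layers packet-pool communication on top of `Processor`: after each run it either terminates the pool (the `exit_communication` hook sets `ctx.result` and calls `stop_all()`) or stamps recipients and posts the packet, with `_validate_routing` rejecting any recipient outside the allowed list. Both hooks are registered as decorators on an instance. A hypothetical sketch (the `Verdict` payload type, the `payloads` field access, and the processor wiring are illustrative assumptions, not package code):

```python
from pydantic import BaseModel

from grasp_agents.packet import Packet
from grasp_agents.run_context import RunContext


class Verdict(BaseModel):
    approved: bool


# Assume `judge` is a concrete CommProcessor[..., Verdict, ...] instance
# registered on a PacketPool together with "writer" and "publisher".

@judge.exit_communication
def stop_when_approved(out_packet: Packet[Verdict], ctx: RunContext[None]) -> bool:
    # Returning True ends the conversation: the packet becomes
    # ctx.result and the pool's listeners are stopped.
    return all(v.approved for v in out_packet.payloads)  # payloads: assumed field


@judge.set_recipients
def route(out_packet: Packet[Verdict], ctx: RunContext[None]) -> None:
    # Must pick from judge.recipients, or _validate_routing raises
    # PacketRoutingError before the packet is posted.
    out_packet.recipients = ["writer"]
```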