grasp_agents 0.2.10-py3-none-any.whl → 0.3.1-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- grasp_agents/__init__.py +15 -14
- grasp_agents/cloud_llm.py +118 -131
- grasp_agents/comm_processor.py +201 -0
- grasp_agents/generics_utils.py +15 -7
- grasp_agents/llm.py +60 -31
- grasp_agents/llm_agent.py +229 -278
- grasp_agents/llm_agent_memory.py +58 -0
- grasp_agents/llm_policy_executor.py +482 -0
- grasp_agents/memory.py +20 -134
- grasp_agents/message_history.py +140 -0
- grasp_agents/openai/__init__.py +54 -36
- grasp_agents/openai/completion_chunk_converters.py +78 -0
- grasp_agents/openai/completion_converters.py +53 -30
- grasp_agents/openai/content_converters.py +13 -14
- grasp_agents/openai/converters.py +44 -68
- grasp_agents/openai/message_converters.py +58 -72
- grasp_agents/openai/openai_llm.py +101 -42
- grasp_agents/openai/tool_converters.py +24 -19
- grasp_agents/packet.py +24 -0
- grasp_agents/packet_pool.py +91 -0
- grasp_agents/printer.py +29 -15
- grasp_agents/processor.py +194 -0
- grasp_agents/prompt_builder.py +173 -176
- grasp_agents/run_context.py +21 -41
- grasp_agents/typing/completion.py +58 -12
- grasp_agents/typing/completion_chunk.py +173 -0
- grasp_agents/typing/converters.py +8 -12
- grasp_agents/typing/events.py +86 -0
- grasp_agents/typing/io.py +4 -13
- grasp_agents/typing/message.py +12 -50
- grasp_agents/typing/tool.py +52 -26
- grasp_agents/usage_tracker.py +6 -6
- grasp_agents/utils.py +3 -3
- grasp_agents/workflow/looped_workflow.py +132 -0
- grasp_agents/workflow/parallel_processor.py +95 -0
- grasp_agents/workflow/sequential_workflow.py +66 -0
- grasp_agents/workflow/workflow_processor.py +78 -0
- {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/METADATA +41 -50
- grasp_agents-0.3.1.dist-info/RECORD +51 -0
- grasp_agents/agent_message.py +0 -27
- grasp_agents/agent_message_pool.py +0 -92
- grasp_agents/base_agent.py +0 -51
- grasp_agents/comm_agent.py +0 -217
- grasp_agents/llm_agent_state.py +0 -79
- grasp_agents/tool_orchestrator.py +0 -203
- grasp_agents/workflow/looped_agent.py +0 -120
- grasp_agents/workflow/sequential_agent.py +0 -63
- grasp_agents/workflow/workflow_agent.py +0 -73
- grasp_agents-0.2.10.dist-info/RECORD +0 -46
- {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/WHEEL +0 -0
- {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/__init__.py
CHANGED
@@ -1,38 +1,39 @@
 # pyright: reportUnusedImport=false
 
 
-from .
-from .base_agent import BaseAgent
-from .comm_agent import CommunicatingAgent
+from .comm_processor import CommProcessor
 from .llm import LLM, LLMSettings
 from .llm_agent import LLMAgent
-from .
+from .llm_agent_memory import LLMAgentMemory
+from .memory import Memory
+from .packet import Packet
+from .processor import Processor
+from .run_context import RunArgs, RunContext
 from .typing.completion import Completion
 from .typing.content import Content, ImageData
-from .typing.io import
-from .typing.message import AssistantMessage,
+from .typing.io import LLMPrompt, LLMPromptArgs, ProcName
+from .typing.message import AssistantMessage, Messages, SystemMessage, UserMessage
 from .typing.tool import BaseTool
 
 __all__ = [
     "LLM",
-    "AgentID",
-    "AgentMessage",
-    "AgentState",
     "AssistantMessage",
-    "BaseAgent",
     "BaseTool",
-    "
+    "CommProcessor",
     "Completion",
     "Content",
-    "Conversation",
     "ImageData",
     "LLMAgent",
-    "LLMFormattedArgs",
     "LLMPrompt",
     "LLMPromptArgs",
     "LLMSettings",
+    "Messages",
+    "Packet",
+    "Packet",
+    "ProcName",
+    "Processor",
     "RunArgs",
-    "
+    "RunContext",
     "SystemMessage",
     "UserMessage",
 ]
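The public API surface changes substantially here: the agent-centric exports of 0.2.10 (`BaseAgent`, `CommunicatingAgent`, `AgentMessage`, `AgentState`) are dropped in favor of processor-centric ones. A minimal import-migration sketch follows; the old-to-new correspondences are inferred from the file renames in this diff (base_agent.py → processor.py, comm_agent.py → comm_processor.py, agent_message.py → packet.py), not from any stated mapping:

```python
# 0.2.10-style imports (removed in 0.3.1):
# from grasp_agents import BaseAgent, CommunicatingAgent, AgentMessage

# 0.3.1 equivalents (correspondence inferred, see note above):
from grasp_agents import (
    CommProcessor,  # communicating run unit, in place of CommunicatingAgent
    LLMAgent,
    Packet,         # routed payload container, in place of AgentMessage
    Processor,      # base run unit, in place of BaseAgent
    RunContext,
)
```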
grasp_agents/cloud_llm.py
CHANGED
@@ -17,16 +17,14 @@ from tenacity import (
 from typing_extensions import TypedDict
 
 from .http_client import AsyncHTTPClientParams, create_async_http_client
-from .llm import LLM,
-from .
-from .rate_limiting.rate_limiter_chunked import
-
-
-
-from .typing.
-from .typing.message import AssistantMessage, Conversation
+from .llm import LLM, ConvertT_co, LLMSettings, SettingsT_co
+from .message_history import MessageHistory
+from .rate_limiting.rate_limiter_chunked import RateLimiterC, limit_rate_chunked
+from .typing.completion import Completion
+from .typing.completion_chunk import CompletionChunk, combine_completion_chunks
+from .typing.events import CompletionChunkEvent, CompletionEvent
+from .typing.message import AssistantMessage, Messages
 from .typing.tool import BaseTool, ToolChoice
-from .utils import validate_obj_from_json_or_py_string
 
 logger = logging.getLogger(__name__)
 
@@ -38,7 +36,7 @@ class APIProviderInfo(TypedDict):
     name: APIProvider
     base_url: str
     api_key: str | None
-
+    struct_outputs_support: tuple[str, ...]
 
 
 PROVIDERS: dict[APIProvider, APIProviderInfo] = {
@@ -46,19 +44,19 @@ PROVIDERS: dict[APIProvider, APIProviderInfo] = {
         name="openai",
         base_url="https://api.openai.com/v1",
         api_key=os.getenv("OPENAI_API_KEY"),
-
+        struct_outputs_support=("*",),
     ),
     "openrouter": APIProviderInfo(
         name="openrouter",
         base_url="https://openrouter.ai/api/v1",
         api_key=os.getenv("OPENROUTER_API_KEY"),
-
+        struct_outputs_support=(),
    ),
     "google_ai_studio": APIProviderInfo(
         name="google_ai_studio",
         base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
         api_key=os.getenv("GOOGLE_AI_STUDIO_API_KEY"),
-
+        struct_outputs_support=("*",),
     ),
 }
 
@@ -88,20 +86,16 @@ def retry_before_callback(retry_state: RetryCallState) -> None:
 
 
 class CloudLLMSettings(LLMSettings, total=False):
-
-    temperature: float | None
-    top_p: float | None
-    seed: int | None
-    use_structured_outputs: bool
+    use_struct_outputs: bool
 
 
-class CloudLLM(LLM[
+class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co]):
     def __init__(
         self,
         # Base LLM args
         model_name: str,
-        converters:
-        llm_settings:
+        converters: ConvertT_co,
+        llm_settings: SettingsT_co | None = None,
         model_id: str | None = None,
         tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
         response_format: type | Mapping[str, type] | None = None,
@@ -110,7 +104,7 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
             dict[str, Any] | AsyncHTTPClientParams | None
         ) = None,
         # Rate limiting
-        rate_limiter: (RateLimiterC[
+        rate_limiter: (RateLimiterC[Messages, AssistantMessage] | None) = None,
         rate_limiter_rpm: float | None = None,
         rate_limiter_chunk_size: int = 1000,
         rate_limiter_max_concurrency: int = 300,
@@ -144,24 +138,26 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
         self._api_provider: APIProvider = api_provider
         self._api_model_name: str = api_model_name
 
-        self.
+        self._struct_outputs_support: bool = any(
             fnmatch.fnmatch(self._model_name, pat)
-            for pat in PROVIDERS[api_provider]["
+            for pat in PROVIDERS[api_provider]["struct_outputs_support"]
         )
         if (
-            self._llm_settings.get("
-            and not self.
+            self._llm_settings.get("use_struct_outputs")
+            and not self._struct_outputs_support
         ):
             raise ValueError(
                 f"Model {self._model_name} does not support structured outputs."
             )
 
-        self.
+        self._tool_call_settings: dict[str, Any] = {}
+
+        self._rate_limiter: RateLimiterC[Messages, AssistantMessage] | None = (
             self._get_rate_limiter(
                 rate_limiter=rate_limiter,
-
-
-
+                rpm=rate_limiter_rpm,
+                chunk_size=rate_limiter_chunk_size,
+                max_concurrency=rate_limiter_max_concurrency,
             )
         )
         self.no_tqdm = no_tqdm
@@ -188,29 +184,36 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
     @property
     def rate_limiter(
         self,
-    ) -> RateLimiterC[
+    ) -> RateLimiterC[Messages, AssistantMessage] | None:
         return self._rate_limiter
 
     def _make_completion_kwargs(
-        self,
+        self,
+        conversation: Messages,
+        tool_choice: ToolChoice | None = None,
+        n_choices: int | None = None,
     ) -> dict[str, Any]:
         api_messages = [self._converters.to_message(m) for m in conversation]
 
         api_tools = None
         api_tool_choice = None
         if self.tools:
-            api_tools = [
+            api_tools = [
+                self._converters.to_tool(t, **self._tool_call_settings)
+                for t in self.tools.values()
+            ]
             if tool_choice is not None:
                 api_tool_choice = self._converters.to_tool_choice(tool_choice)
 
         api_llm_settings = deepcopy(self.llm_settings or {})
-        api_llm_settings.pop("
+        api_llm_settings.pop("use_struct_outputs", None)
 
         return dict(
             api_messages=api_messages,
             api_tools=api_tools,
             api_tool_choice=api_tool_choice,
             api_response_format=self._response_format,
+            n_choices=n_choices,
             **api_llm_settings,
         )
 
@@ -221,6 +224,7 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
         *,
         api_tools: list[Any] | None = None,
         api_tool_choice: Any | None = None,
+        n_choices: int | None = None,
         **api_llm_settings: Any,
     ) -> Any:
         pass
@@ -233,6 +237,7 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
         api_tools: list[Any] | None = None,
         api_tool_choice: Any | None = None,
         api_response_format: type | None = None,
+        n_choices: int | None = None,
         **api_llm_settings: Any,
     ) -> Any:
         pass
@@ -244,161 +249,143 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
         *,
         api_tools: list[Any] | None = None,
         api_tool_choice: Any | None = None,
+        n_choices: int | None = None,
         **api_llm_settings: Any,
     ) -> AsyncIterator[Any]:
         pass
 
-
+    @abstractmethod
+    async def _get_parsed_completion_stream(
+        self,
+        api_messages: list[Any],
+        *,
+        api_tools: list[Any] | None = None,
+        api_tool_choice: Any | None = None,
+        api_response_format: type | None = None,
+        n_choices: int | None = None,
+        **api_llm_settings: Any,
+    ) -> AsyncIterator[Any]:
+        pass
+
+    async def generate_completion_no_retry(
         self,
-        conversation:
+        conversation: Messages,
         *,
         tool_choice: ToolChoice | None = None,
-
+        n_choices: int | None = None,
     ) -> Completion:
         completion_kwargs = self._make_completion_kwargs(
-            conversation=conversation, tool_choice=tool_choice
+            conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
         )
 
-        if (
-            self._response_format is None
-            or (not self._struct_output_support)
-            or (not self._llm_settings.get("use_structured_outputs"))
-        ):
+        if not self._llm_settings.get("use_struct_outputs"):
             completion_kwargs.pop("api_response_format", None)
-            api_completion = await self._get_completion(**completion_kwargs
+            api_completion = await self._get_completion(**completion_kwargs)
         else:
-            api_completion = await self._get_parsed_completion(
-                **completion_kwargs, **kwargs
-            )
+            api_completion = await self._get_parsed_completion(**completion_kwargs)
 
         completion = self._converters.from_completion(
-            api_completion,
+            api_completion, name=self.model_id
        )
 
-        self.
+        if not self._llm_settings.get("use_struct_outputs"):
+            # If validation is not handled by the structured output functionality
+            # of the LLM provider
+            self._validate_completion(completion)
+            self._validate_tool_calls(completion)
 
         return completion
 
-    def _validate_completion(self, completion: Completion) -> None:
-        for choice in completion.choices:
-            message = choice.message
-            if (
-                self._response_format_pyd is not None
-                and not self._llm_settings.get("use_structured_outputs")
-                and not message.tool_calls
-            ):
-                validate_obj_from_json_or_py_string(
-                    message.content or "",
-                    adapter=self._response_format_pyd,
-                    from_substring=True,
-                )
-
     async def generate_completion_stream(
         self,
-        conversation:
+        conversation: Messages,
         *,
         tool_choice: ToolChoice | None = None,
-
-    ) -> AsyncIterator[
+        n_choices: int | None = None,
+    ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
         completion_kwargs = self._make_completion_kwargs(
-            conversation=conversation, tool_choice=tool_choice
-        )
-        completion_kwargs.pop("api_response_format", None)
-        api_completion_chunk_iterator = await self._get_completion_stream(
-            **completion_kwargs, **kwargs
+            conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
         )
 
-
-
-
+        if not self._llm_settings.get("use_struct_outputs"):
+            completion_kwargs.pop("api_response_format", None)
+            api_stream = await self._get_completion_stream(**completion_kwargs)
+        else:
+            api_stream = await self._get_parsed_completion_stream(**completion_kwargs)
+
+        async def iterate() -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
+            completion_chunks: list[CompletionChunk] = []
+            async for api_completion_chunk in api_stream:
+                completion_chunk = self._converters.from_completion_chunk(
+                    api_completion_chunk, name=self.model_id
+                )
+                completion_chunks.append(completion_chunk)
+                yield CompletionChunkEvent(data=completion_chunk, name=self.model_id)
+
+            # TODO: can be done using the OpenAI final_completion_chunk
+            completion = combine_completion_chunks(completion_chunks)
+
+            yield CompletionEvent(data=completion, name=self.model_id)
 
-
+            if not self._llm_settings.get("use_struct_outputs"):
+                # If validation is not handled by the structured outputs functionality
+                # of the LLM provider
+                self._validate_completion(completion)
+                self._validate_tool_calls(completion)
+
+        return iterate()
+
+    async def generate_completion(
         self,
-        conversation:
+        conversation: Messages,
         *,
         tool_choice: ToolChoice | None = None,
-
+        n_choices: int | None = None,
     ) -> Completion:
         wrapped_func = retry(
             wait=wait_random_exponential(min=1, max=8),
             stop=stop_after_attempt(self.num_generation_retries + 1),
             before=retry_before_callback,
             retry_error_callback=retry_error_callback,
-        )(self.__class__.
+        )(self.__class__.generate_completion_no_retry)
 
-        return await wrapped_func(
+        return await wrapped_func(
+            self, conversation, tool_choice=tool_choice, n_choices=n_choices
+        )
 
     @limit_rate_chunked # type: ignore
-    async def
+    async def _generate_completion_batch(
         self,
-        conversation:
+        conversation: Messages,
         *,
         tool_choice: ToolChoice | None = None,
-        **kwargs: Any,
     ) -> Completion:
-        return await self.
-            conversation, tool_choice=tool_choice, **kwargs
-        )
+        return await self.generate_completion(conversation, tool_choice=tool_choice)
 
     async def generate_completion_batch(
-        self,
-        message_history: MessageHistory,
-        *,
-        tool_choice: ToolChoice | None = None,
-        **kwargs: Any,
+        self, message_history: MessageHistory, *, tool_choice: ToolChoice | None = None
     ) -> Sequence[Completion]:
-        return await self.
-            list(message_history.
+        return await self._generate_completion_batch(
+            list(message_history.conversations), # type: ignore
             tool_choice=tool_choice,
-            **kwargs,
-        )
-
-    async def generate_message(
-        self,
-        conversation: Conversation,
-        *,
-        tool_choice: ToolChoice | None = None,
-        **kwargs: Any,
-    ) -> AssistantMessage:
-        completion = await self.generate_completion(
-            conversation, tool_choice=tool_choice, **kwargs
         )
 
-        return completion.choices[0].message
-
-    async def generate_message_batch(
-        self,
-        message_history: MessageHistory,
-        *,
-        tool_choice: ToolChoice | None = None,
-        **kwargs: Any,
-    ) -> Sequence[AssistantMessage]:
-        completion_batch = await self.generate_completion_batch(
-            message_history, tool_choice=tool_choice, **kwargs
-        )
-
-        return [completion.choices[0].message for completion in completion_batch]
-
     def _get_rate_limiter(
         self,
-        rate_limiter: RateLimiterC[
-
-
-
-    ) -> RateLimiterC[
+        rate_limiter: RateLimiterC[Messages, AssistantMessage] | None = None,
+        rpm: float | None = None,
+        chunk_size: int = 1000,
+        max_concurrency: int = 300,
+    ) -> RateLimiterC[Messages, AssistantMessage] | None:
         if rate_limiter is not None:
             logger.info(
                 f"[{self.__class__.__name__}] Set rate limit to {rate_limiter.rpm} RPM"
             )
             return rate_limiter
-        if
-            logger.info(
-                f"[{self.__class__.__name__}] Set rate limit to {rate_limiter_rpm} RPM"
-            )
+        if rpm is not None:
+            logger.info(f"[{self.__class__.__name__}] Set rate limit to {rpm} RPM")
             return RateLimiterC(
-                rpm=
-                chunk_size=rate_limiter_chunk_size,
-                max_concurrency=rate_limiter_max_concurrency,
+                rpm=rpm, chunk_size=chunk_size, max_concurrency=max_concurrency
            )
 
         return None
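With `generate_message`/`generate_message_batch` removed, streaming is now the event-based entry point: `generate_completion_stream` yields `CompletionChunkEvent`s followed by a final `CompletionEvent` combined from the chunks. A minimal consumption sketch against the signatures above; the concrete `llm` instance (a `CloudLLM` subclass such as the package's OpenAI backend) and the conversation contents are assumed:

```python
from grasp_agents.typing.completion import Completion
from grasp_agents.typing.events import CompletionChunkEvent, CompletionEvent


async def stream_completion(llm, conversation) -> Completion:
    # The coroutine returns the iterator, so await it first, then iterate.
    stream = await llm.generate_completion_stream(conversation)
    final: Completion | None = None
    async for event in stream:
        if isinstance(event, CompletionChunkEvent):
            print(event.data)  # incremental CompletionChunk
        elif isinstance(event, CompletionEvent):
            final = event.data  # Completion combined from all chunks
    assert final is not None
    return final
```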
grasp_agents/comm_processor.py
ADDED
@@ -0,0 +1,201 @@
+import logging
+from collections.abc import AsyncIterator, Sequence
+from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast
+
+from pydantic import BaseModel
+from pydantic.json_schema import SkipJsonSchema
+
+from .packet import Packet
+from .packet_pool import PacketPool
+from .processor import Processor
+from .run_context import CtxT, RunContext
+from .typing.events import Event, PacketEvent
+from .typing.io import InT_contra, MemT_co, OutT_co, ProcName
+
+logger = logging.getLogger(__name__)
+
+
+class DynCommPayload(BaseModel):
+    selected_recipients: SkipJsonSchema[Sequence[ProcName]]
+
+
+_OutT_contra = TypeVar("_OutT_contra", contravariant=True)
+
+
+class ExitCommunicationHandler(Protocol[_OutT_contra, CtxT]):
+    def __call__(
+        self,
+        out_packet: Packet[_OutT_contra],
+        ctx: RunContext[CtxT] | None,
+    ) -> bool: ...
+
+
+class CommProcessor(
+    Processor[InT_contra, OutT_co, MemT_co, CtxT],
+    Generic[InT_contra, OutT_co, MemT_co, CtxT],
+):
+    _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
+        0: "_in_type",
+        1: "_out_type",
+    }
+
+    def __init__(
+        self,
+        name: ProcName,
+        *,
+        recipients: Sequence[ProcName] | None = None,
+        packet_pool: PacketPool[CtxT] | None = None,
+    ) -> None:
+        super().__init__(name=name)
+
+        self.recipients = recipients or []
+
+        self._packet_pool = packet_pool
+        self._is_listening = False
+        self._exit_communication_impl: (
+            ExitCommunicationHandler[OutT_co, CtxT] | None
+        ) = None
+
+    def _validate_routing(self, payloads: Sequence[OutT_co]) -> Sequence[ProcName]:
+        if all(isinstance(p, DynCommPayload) for p in payloads):
+            payloads_ = cast("Sequence[DynCommPayload]", payloads)
+            selected_recipients_per_payload = [
+                set(p.selected_recipients or []) for p in payloads_
+            ]
+            assert all(
+                x == selected_recipients_per_payload[0]
+                for x in selected_recipients_per_payload
+            ), "All payloads must have the same recipient IDs for dynamic routing"
+
+            assert payloads_[0].selected_recipients is not None
+            selected_recipients = payloads_[0].selected_recipients
+
+            assert all(rid in self.recipients for rid in selected_recipients), (
+                "Dynamic routing is enabled, but recipient IDs are not in "
+                "the allowed agent's recipient IDs"
+            )
+
+            return selected_recipients
+
+        if all((not isinstance(p, DynCommPayload)) for p in payloads):
+            return self.recipients
+
+        raise ValueError(
+            "All payloads must be either DCommAgentPayload or not DCommAgentPayload"
+        )
+
+    async def run(
+        self,
+        chat_inputs: Any | None = None,
+        *,
+        in_packet: Packet[InT_contra] | None = None,
+        in_args: InT_contra | Sequence[InT_contra] | None = None,
+        forgetful: bool = True,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> Packet[OutT_co]:
+        out_packet = await super().run(
+            chat_inputs=chat_inputs,
+            in_packet=in_packet,
+            in_args=in_args,
+            ctx=ctx,
+        )
+        recipients = self._validate_routing(out_packet.payloads)
+        routed_out_packet = Packet(
+            payloads=out_packet.payloads, sender=self.name, recipients=recipients
+        )
+        if self._packet_pool is not None and in_packet is None and in_args is None:
+            # If no input packet or args, we assume this is the first run.
+            await self._packet_pool.post(routed_out_packet)
+
+        return routed_out_packet
+
+    async def run_stream(
+        self,
+        chat_inputs: Any | None = None,
+        *,
+        in_packet: Packet[InT_contra] | None = None,
+        in_args: InT_contra | Sequence[InT_contra] | None = None,
+        forgetful: bool = True,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> AsyncIterator[Event[Any]]:
+        out_packet: Packet[OutT_co] | None = None
+        async for event in super().run_stream(
+            chat_inputs=chat_inputs,
+            in_packet=in_packet,
+            in_args=in_args,
+            ctx=ctx,
+        ):
+            if isinstance(event, PacketEvent):
+                out_packet = event.data
+            else:
+                yield event
+
+        if out_packet is None:
+            raise RuntimeError("No output packet generated during stream run")
+
+        recipients = self._validate_routing(out_packet.payloads)
+        routed_out_packet = Packet(
+            payloads=out_packet.payloads, sender=self.name, recipients=recipients
+        )
+        if self._packet_pool is not None and in_packet is None and in_args is None:
+            # If no input packet or args, we assume this is the first run.
+            await self._packet_pool.post(routed_out_packet)
+
+        yield PacketEvent(data=routed_out_packet, name=self.name)
+
+    def exit_communication(
+        self, func: ExitCommunicationHandler[OutT_co, CtxT]
+    ) -> ExitCommunicationHandler[OutT_co, CtxT]:
+        self._exit_communication_impl = func
+
+        return func
+
+    def _exit_communication_fn(
+        self, out_packet: Packet[OutT_co], ctx: RunContext[CtxT] | None
+    ) -> bool:
+        if self._exit_communication_impl:
+            return self._exit_communication_impl(out_packet=out_packet, ctx=ctx)
+
+        return False
+
+    async def _packet_handler(
+        self,
+        packet: Packet[InT_contra],
+        ctx: RunContext[CtxT] | None = None,
+        **run_kwargs: Any,
+    ) -> None:
+        assert self._packet_pool is not None, "Packet pool must be initialized"
+
+        out_packet = await self.run(ctx=ctx, in_packet=packet, **run_kwargs)
+
+        if self._exit_communication_fn(out_packet=out_packet, ctx=ctx):
+            await self._packet_pool.stop_all()
+            return
+
+        await self._packet_pool.post(out_packet)
+
+    @property
+    def is_listening(self) -> bool:
+        return self._is_listening
+
+    async def start_listening(
+        self, ctx: RunContext[CtxT] | None = None, **run_kwargs: Any
+    ) -> None:
+        assert self._packet_pool is not None, "Packet pool must be initialized"
+
+        if self._is_listening:
+            return
+
+        self._is_listening = True
+        self._packet_pool.register_packet_handler(
+            processor_name=self.name,
+            handler=self._packet_handler,
+            ctx=ctx,
+            **run_kwargs,
+        )
+
+    async def stop_listening(self) -> None:
+        assert self._packet_pool is not None, "Packet pool must be initialized"
+
+        self._is_listening = False
+        await self._packet_pool.unregister_packet_handler(self.name)