grasp_agents 0.4.7__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/cloud_llm.py +191 -224
- grasp_agents/comm_processor.py +101 -100
- grasp_agents/errors.py +69 -9
- grasp_agents/litellm/__init__.py +106 -0
- grasp_agents/litellm/completion_chunk_converters.py +68 -0
- grasp_agents/litellm/completion_converters.py +72 -0
- grasp_agents/litellm/converters.py +138 -0
- grasp_agents/litellm/lite_llm.py +210 -0
- grasp_agents/litellm/message_converters.py +66 -0
- grasp_agents/llm.py +84 -49
- grasp_agents/llm_agent.py +136 -120
- grasp_agents/llm_agent_memory.py +3 -3
- grasp_agents/llm_policy_executor.py +167 -174
- grasp_agents/memory.py +23 -0
- grasp_agents/openai/__init__.py +24 -9
- grasp_agents/openai/completion_chunk_converters.py +6 -6
- grasp_agents/openai/completion_converters.py +12 -14
- grasp_agents/openai/content_converters.py +1 -3
- grasp_agents/openai/converters.py +6 -8
- grasp_agents/openai/message_converters.py +21 -3
- grasp_agents/openai/openai_llm.py +155 -103
- grasp_agents/openai/tool_converters.py +4 -6
- grasp_agents/packet.py +5 -2
- grasp_agents/packet_pool.py +14 -13
- grasp_agents/printer.py +233 -73
- grasp_agents/processor.py +229 -91
- grasp_agents/prompt_builder.py +2 -2
- grasp_agents/run_context.py +11 -20
- grasp_agents/runner.py +42 -0
- grasp_agents/typing/completion.py +16 -9
- grasp_agents/typing/completion_chunk.py +51 -22
- grasp_agents/typing/events.py +95 -19
- grasp_agents/typing/message.py +25 -1
- grasp_agents/typing/tool.py +2 -0
- grasp_agents/usage_tracker.py +31 -37
- grasp_agents/utils.py +95 -84
- grasp_agents/workflow/looped_workflow.py +60 -11
- grasp_agents/workflow/sequential_workflow.py +43 -11
- grasp_agents/workflow/workflow_processor.py +25 -24
- {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.1.dist-info}/METADATA +7 -6
- grasp_agents-0.5.1.dist-info/RECORD +57 -0
- grasp_agents-0.4.7.dist-info/RECORD +0 -50
- {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.1.dist-info}/WHEEL +0 -0
- {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.1.dist-info}/licenses/LICENSE.md +0 -0
@@ -1,21 +1,22 @@
|
|
1
|
+
import fnmatch
|
1
2
|
import logging
|
3
|
+
import os
|
2
4
|
from collections.abc import AsyncIterator, Iterable, Mapping
|
3
5
|
from copy import deepcopy
|
4
|
-
from typing import Any, Literal
|
6
|
+
from typing import Any, Literal
|
5
7
|
|
6
8
|
import httpx
|
7
|
-
from openai import AsyncOpenAI
|
9
|
+
from openai import AsyncOpenAI, AsyncStream
|
8
10
|
from openai._types import NOT_GIVEN # type: ignore[import]
|
9
11
|
from openai.lib.streaming.chat import (
|
10
12
|
AsyncChatCompletionStreamManager as OpenAIAsyncChatCompletionStreamManager,
|
11
13
|
)
|
14
|
+
from openai.lib.streaming.chat import ChatCompletionStreamState
|
12
15
|
from openai.lib.streaming.chat import ChunkEvent as OpenAIChunkEvent
|
13
16
|
from pydantic import BaseModel
|
14
17
|
|
15
|
-
from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings
|
18
|
+
from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings, LLMRateLimiter
|
16
19
|
from ..http_client import AsyncHTTPClientParams
|
17
|
-
from ..rate_limiting.rate_limiter_chunked import RateLimiterC
|
18
|
-
from ..typing.message import AssistantMessage, Messages
|
19
20
|
from ..typing.tool import BaseTool
|
20
21
|
from . import (
|
21
22
|
OpenAICompletion,
|
@@ -23,17 +24,40 @@ from . import (
|
|
23
24
|
OpenAIMessageParam,
|
24
25
|
OpenAIParsedCompletion,
|
25
26
|
OpenAIPredictionContentParam,
|
27
|
+
OpenAIResponseFormatJSONObject,
|
28
|
+
OpenAIResponseFormatText,
|
26
29
|
OpenAIStreamOptionsParam,
|
27
30
|
OpenAIToolChoiceOptionParam,
|
28
31
|
OpenAIToolParam,
|
32
|
+
OpenAIWebSearchOptions,
|
29
33
|
)
|
30
34
|
from .converters import OpenAIConverters
|
31
35
|
|
32
36
|
logger = logging.getLogger(__name__)
|
33
37
|
|
34
38
|
|
35
|
-
|
36
|
-
|
39
|
+
def get_openai_compatible_providers() -> list[APIProvider]:
|
40
|
+
"""Returns a dictionary of available OpenAI-compatible API providers."""
|
41
|
+
return [
|
42
|
+
APIProvider(
|
43
|
+
name="openai",
|
44
|
+
base_url="https://api.openai.com/v1",
|
45
|
+
api_key=os.getenv("OPENAI_API_KEY"),
|
46
|
+
response_schema_support=("*",),
|
47
|
+
),
|
48
|
+
APIProvider(
|
49
|
+
name="gemini_openai",
|
50
|
+
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
|
51
|
+
api_key=os.getenv("GEMINI_API_KEY"),
|
52
|
+
response_schema_support=("*",),
|
53
|
+
),
|
54
|
+
APIProvider(
|
55
|
+
name="openrouter",
|
56
|
+
base_url="https://openrouter.ai/api/v1",
|
57
|
+
api_key=os.getenv("OPENROUTER_API_KEY"),
|
58
|
+
response_schema_support=(),
|
59
|
+
),
|
60
|
+
]
|
37
61
|
|
38
62
|
|
39
63
|
class OpenAILLMSettings(CloudLLMSettings, total=False):
|
@@ -41,7 +65,7 @@ class OpenAILLMSettings(CloudLLMSettings, total=False):
|
|
41
65
|
|
42
66
|
parallel_tool_calls: bool
|
43
67
|
|
44
|
-
modalities: list[Literal["text"
|
68
|
+
modalities: list[Literal["text"]] | None
|
45
69
|
|
46
70
|
frequency_penalty: float | None
|
47
71
|
presence_penalty: float | None
|
@@ -50,19 +74,20 @@ class OpenAILLMSettings(CloudLLMSettings, total=False):
|
|
50
74
|
logprobs: bool | None
|
51
75
|
top_logprobs: int | None
|
52
76
|
|
77
|
+
stream_options: OpenAIStreamOptionsParam | None
|
78
|
+
|
53
79
|
prediction: OpenAIPredictionContentParam | None
|
54
80
|
|
55
|
-
|
81
|
+
web_search_options: OpenAIWebSearchOptions | None
|
56
82
|
|
57
83
|
metadata: dict[str, str] | None
|
58
84
|
store: bool | None
|
59
85
|
user: str
|
60
86
|
|
61
|
-
#
|
62
|
-
|
63
|
-
|
64
|
-
#
|
65
|
-
# )
|
87
|
+
# To support the old JSON mode without respose schemas
|
88
|
+
response_format: OpenAIResponseFormatJSONObject | OpenAIResponseFormatText
|
89
|
+
|
90
|
+
# TODO: support audio
|
66
91
|
|
67
92
|
|
68
93
|
class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
|
@@ -72,52 +97,87 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
|
|
72
97
|
model_name: str,
|
73
98
|
llm_settings: OpenAILLMSettings | None = None,
|
74
99
|
tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
|
75
|
-
|
100
|
+
response_schema: Any | None = None,
|
101
|
+
response_schema_by_xml_tag: Mapping[str, Any] | None = None,
|
102
|
+
apply_response_schema_via_provider: bool = False,
|
76
103
|
model_id: str | None = None,
|
77
104
|
# Custom LLM provider
|
78
105
|
api_provider: APIProvider | None = None,
|
79
106
|
# Connection settings
|
107
|
+
max_client_retries: int = 2,
|
80
108
|
async_http_client: httpx.AsyncClient | None = None,
|
81
109
|
async_http_client_params: (
|
82
110
|
dict[str, Any] | AsyncHTTPClientParams | None
|
83
111
|
) = None,
|
84
112
|
async_openai_client_params: dict[str, Any] | None = None,
|
85
113
|
# Rate limiting
|
86
|
-
rate_limiter:
|
87
|
-
|
88
|
-
|
89
|
-
rate_limiter_max_concurrency: int = 300,
|
90
|
-
# Retries
|
91
|
-
num_generation_retries: int = 0,
|
114
|
+
rate_limiter: LLMRateLimiter | None = None,
|
115
|
+
# LLM response retries: try to regenerate to pass validation
|
116
|
+
max_response_retries: int = 1,
|
92
117
|
) -> None:
|
118
|
+
openai_compatible_providers = get_openai_compatible_providers()
|
119
|
+
|
120
|
+
model_name_parts = model_name.split("/", 1)
|
121
|
+
if api_provider is not None:
|
122
|
+
provider_model_name = model_name
|
123
|
+
elif len(model_name_parts) == 2:
|
124
|
+
compat_providers_map = {
|
125
|
+
provider["name"]: provider for provider in openai_compatible_providers
|
126
|
+
}
|
127
|
+
provider_name, provider_model_name = model_name_parts
|
128
|
+
if provider_name not in compat_providers_map:
|
129
|
+
raise ValueError(
|
130
|
+
f"OpenAI compatible API provider '{provider_name}' "
|
131
|
+
"is not supported. Supported providers are: "
|
132
|
+
f"{', '.join(compat_providers_map.keys())}"
|
133
|
+
)
|
134
|
+
api_provider = compat_providers_map[provider_name]
|
135
|
+
else:
|
136
|
+
raise ValueError(
|
137
|
+
"Model name must be in the format 'provider/model_name' or "
|
138
|
+
"you must provide an 'api_provider' argument."
|
139
|
+
)
|
140
|
+
|
93
141
|
super().__init__(
|
94
|
-
model_name=
|
142
|
+
model_name=provider_model_name,
|
95
143
|
model_id=model_id,
|
96
144
|
llm_settings=llm_settings,
|
97
145
|
converters=OpenAIConverters(),
|
98
146
|
tools=tools,
|
99
|
-
|
147
|
+
response_schema=response_schema,
|
148
|
+
response_schema_by_xml_tag=response_schema_by_xml_tag,
|
149
|
+
apply_response_schema_via_provider=apply_response_schema_via_provider,
|
100
150
|
api_provider=api_provider,
|
101
151
|
async_http_client=async_http_client,
|
102
152
|
async_http_client_params=async_http_client_params,
|
103
153
|
rate_limiter=rate_limiter,
|
104
|
-
|
105
|
-
|
106
|
-
rate_limiter_max_concurrency=rate_limiter_max_concurrency,
|
107
|
-
num_generation_retries=num_generation_retries,
|
154
|
+
max_client_retries=max_client_retries,
|
155
|
+
max_response_retries=max_response_retries,
|
108
156
|
)
|
109
157
|
|
110
|
-
|
111
|
-
|
112
|
-
|
158
|
+
response_schema_support: bool = any(
|
159
|
+
fnmatch.fnmatch(self._model_name, pat)
|
160
|
+
for pat in api_provider.get("response_schema_support") or []
|
161
|
+
)
|
162
|
+
if apply_response_schema_via_provider:
|
163
|
+
if self._tools:
|
164
|
+
for tool in self._tools.values():
|
165
|
+
tool.strict = True
|
166
|
+
if not response_schema_support:
|
167
|
+
raise ValueError(
|
168
|
+
"Native response schema validation is not supported for model "
|
169
|
+
f"'{self._model_name}' by the API provider. Please set "
|
170
|
+
"apply_response_schema_via_provider=False."
|
171
|
+
)
|
113
172
|
|
114
173
|
_async_openai_client_params = deepcopy(async_openai_client_params or {})
|
115
174
|
if self._async_http_client is not None:
|
116
175
|
_async_openai_client_params["http_client"] = self._async_http_client
|
117
176
|
|
118
177
|
self._client: AsyncOpenAI = AsyncOpenAI(
|
119
|
-
base_url=self.
|
120
|
-
api_key=self.
|
178
|
+
base_url=self.api_provider.get("base_url"),
|
179
|
+
api_key=self.api_provider.get("api_key"),
|
180
|
+
max_retries=max_client_retries,
|
121
181
|
**_async_openai_client_params,
|
122
182
|
)
|
123
183
|
|
@@ -126,15 +186,28 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
|
|
126
186
|
api_messages: Iterable[OpenAIMessageParam],
|
127
187
|
api_tools: list[OpenAIToolParam] | None = None,
|
128
188
|
api_tool_choice: OpenAIToolChoiceOptionParam | None = None,
|
189
|
+
api_response_schema: type[Any] | None = None,
|
129
190
|
n_choices: int | None = None,
|
130
191
|
**api_llm_settings: Any,
|
131
|
-
) -> OpenAICompletion:
|
192
|
+
) -> OpenAICompletion | OpenAIParsedCompletion[Any]:
|
132
193
|
tools = api_tools or NOT_GIVEN
|
133
194
|
tool_choice = api_tool_choice or NOT_GIVEN
|
195
|
+
response_format = api_response_schema or NOT_GIVEN
|
134
196
|
n = n_choices or NOT_GIVEN
|
135
197
|
|
198
|
+
if self._apply_response_schema_via_provider:
|
199
|
+
return await self._client.beta.chat.completions.parse(
|
200
|
+
model=self._model_name,
|
201
|
+
messages=api_messages,
|
202
|
+
tools=tools,
|
203
|
+
tool_choice=tool_choice,
|
204
|
+
response_format=response_format,
|
205
|
+
n=n,
|
206
|
+
**api_llm_settings,
|
207
|
+
)
|
208
|
+
|
136
209
|
return await self._client.chat.completions.create(
|
137
|
-
model=self.
|
210
|
+
model=self._model_name,
|
138
211
|
messages=api_messages,
|
139
212
|
tools=tools,
|
140
213
|
tool_choice=tool_choice,
|
@@ -143,89 +216,68 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
|
|
143
216
|
**api_llm_settings,
|
144
217
|
)
|
145
218
|
|
146
|
-
async def
|
147
|
-
self,
|
148
|
-
api_messages: Iterable[OpenAIMessageParam],
|
149
|
-
api_tools: list[OpenAIToolParam] | None = None,
|
150
|
-
api_tool_choice: OpenAIToolChoiceOptionParam | None = None,
|
151
|
-
api_response_format: type | None = None,
|
152
|
-
n_choices: int | None = None,
|
153
|
-
**api_llm_settings: Any,
|
154
|
-
) -> OpenAIParsedCompletion[Any]:
|
155
|
-
tools = api_tools or NOT_GIVEN
|
156
|
-
tool_choice = api_tool_choice or NOT_GIVEN
|
157
|
-
n = n_choices or NOT_GIVEN
|
158
|
-
response_format = api_response_format or NOT_GIVEN
|
159
|
-
|
160
|
-
return await self._client.beta.chat.completions.parse(
|
161
|
-
model=self._api_model_name,
|
162
|
-
messages=api_messages,
|
163
|
-
tools=tools,
|
164
|
-
tool_choice=tool_choice,
|
165
|
-
response_format=response_format,
|
166
|
-
n=n,
|
167
|
-
**api_llm_settings,
|
168
|
-
)
|
169
|
-
|
170
|
-
async def _get_completion_stream(
|
219
|
+
async def _get_completion_stream( # type: ignore[override]
|
171
220
|
self,
|
172
221
|
api_messages: Iterable[OpenAIMessageParam],
|
173
222
|
api_tools: list[OpenAIToolParam] | None = None,
|
174
223
|
api_tool_choice: OpenAIToolChoiceOptionParam | None = None,
|
224
|
+
api_response_schema: type[Any] | None = None,
|
175
225
|
n_choices: int | None = None,
|
176
226
|
**api_llm_settings: Any,
|
177
227
|
) -> AsyncIterator[OpenAICompletionChunk]:
|
178
228
|
tools = api_tools or NOT_GIVEN
|
179
229
|
tool_choice = api_tool_choice or NOT_GIVEN
|
230
|
+
response_format = api_response_schema or NOT_GIVEN
|
180
231
|
n = n_choices or NOT_GIVEN
|
181
232
|
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
233
|
+
if self._apply_response_schema_via_provider:
|
234
|
+
stream_manager: OpenAIAsyncChatCompletionStreamManager[Any] = (
|
235
|
+
self._client.beta.chat.completions.stream(
|
236
|
+
model=self._model_name,
|
237
|
+
messages=api_messages,
|
238
|
+
tools=tools,
|
239
|
+
tool_choice=tool_choice,
|
240
|
+
response_format=response_format,
|
241
|
+
n=n,
|
242
|
+
**api_llm_settings,
|
243
|
+
)
|
244
|
+
)
|
245
|
+
async with stream_manager as stream:
|
246
|
+
async for chunk_event in stream:
|
247
|
+
if isinstance(chunk_event, OpenAIChunkEvent):
|
248
|
+
yield chunk_event.chunk
|
249
|
+
else:
|
250
|
+
stream_generator: AsyncStream[
|
251
|
+
OpenAICompletionChunk
|
252
|
+
] = await self._client.chat.completions.create(
|
253
|
+
model=self._model_name,
|
254
|
+
messages=api_messages,
|
255
|
+
tools=tools,
|
256
|
+
tool_choice=tool_choice,
|
257
|
+
stream=True,
|
258
|
+
n=n,
|
259
|
+
**api_llm_settings,
|
260
|
+
)
|
193
261
|
async with stream_generator as stream:
|
194
262
|
async for completion_chunk in stream:
|
195
263
|
yield completion_chunk
|
196
264
|
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
n = n_choices or NOT_GIVEN
|
212
|
-
|
213
|
-
stream_manager: OpenAIAsyncChatCompletionStreamManager[
|
214
|
-
OpenAICompletionChunk
|
215
|
-
] = self._client.beta.chat.completions.stream(
|
216
|
-
model=self._api_model_name,
|
217
|
-
messages=api_messages,
|
218
|
-
tools=tools,
|
219
|
-
tool_choice=tool_choice,
|
220
|
-
response_format=response_format,
|
221
|
-
n=n,
|
222
|
-
**api_llm_settings,
|
265
|
+
def combine_completion_chunks(
|
266
|
+
self, completion_chunks: list[OpenAICompletionChunk]
|
267
|
+
) -> OpenAICompletion:
|
268
|
+
response_format = NOT_GIVEN
|
269
|
+
input_tools = NOT_GIVEN
|
270
|
+
if self._apply_response_schema_via_provider:
|
271
|
+
if self._response_schema:
|
272
|
+
response_format = self._response_schema
|
273
|
+
if self._tools:
|
274
|
+
input_tools = [
|
275
|
+
self._converters.to_tool(tool) for tool in self._tools.values()
|
276
|
+
]
|
277
|
+
state = ChatCompletionStreamState[Any](
|
278
|
+
input_tools=input_tools, response_format=response_format
|
223
279
|
)
|
280
|
+
for chunk in completion_chunks:
|
281
|
+
state.handle_chunk(chunk)
|
224
282
|
|
225
|
-
|
226
|
-
async with stream_manager as stream:
|
227
|
-
async for chunk_event in stream:
|
228
|
-
if isinstance(chunk_event, OpenAIChunkEvent):
|
229
|
-
yield chunk_event.chunk
|
230
|
-
|
231
|
-
return iterate()
|
283
|
+
return state.get_final_completion()
|
@@ -13,10 +13,8 @@ from . import (
|
|
13
13
|
)
|
14
14
|
|
15
15
|
|
16
|
-
def to_api_tool(
|
17
|
-
tool
|
18
|
-
) -> OpenAIToolParam:
|
19
|
-
if strict:
|
16
|
+
def to_api_tool(tool: BaseTool[BaseModel, Any, Any]) -> OpenAIToolParam:
|
17
|
+
if tool.strict:
|
20
18
|
return pydantic_function_tool(
|
21
19
|
model=tool.in_type, name=tool.name, description=tool.description
|
22
20
|
)
|
@@ -25,9 +23,9 @@ def to_api_tool(
|
|
25
23
|
name=tool.name,
|
26
24
|
description=tool.description,
|
27
25
|
parameters=tool.in_type.model_json_schema(),
|
28
|
-
strict=strict,
|
26
|
+
strict=tool.strict,
|
29
27
|
)
|
30
|
-
if strict is None:
|
28
|
+
if tool.strict is None:
|
31
29
|
function.pop("strict")
|
32
30
|
|
33
31
|
return OpenAIToolParam(type="function", function=function)
|
grasp_agents/packet.py
CHANGED
@@ -15,10 +15,13 @@ class Packet(BaseModel, Generic[_PayloadT_co]):
|
|
15
15
|
sender: ProcName
|
16
16
|
recipients: Sequence[ProcName] = Field(default_factory=list)
|
17
17
|
|
18
|
-
model_config = ConfigDict(extra="forbid"
|
18
|
+
model_config = ConfigDict(extra="forbid")
|
19
19
|
|
20
20
|
def __repr__(self) -> str:
|
21
21
|
return (
|
22
|
-
f"
|
22
|
+
f"{self.__class__.__name__}:\n"
|
23
|
+
f"ID: {self.id}\n"
|
24
|
+
f"From: {self.sender}\n"
|
25
|
+
f"To: {', '.join(self.recipients)}\n"
|
23
26
|
f"Payloads: {len(self.payloads)}"
|
24
27
|
)
|
grasp_agents/packet_pool.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1
1
|
import asyncio
|
2
2
|
import logging
|
3
|
+
from collections.abc import AsyncIterator
|
3
4
|
from typing import Any, Generic, Protocol, TypeVar
|
4
5
|
|
5
6
|
from .packet import Packet
|
6
7
|
from .run_context import CtxT, RunContext
|
8
|
+
from .typing.events import Event
|
7
9
|
from .typing.io import ProcName
|
8
10
|
|
9
11
|
logger = logging.getLogger(__name__)
|
@@ -16,16 +18,16 @@ class PacketHandler(Protocol[_PayloadT_contra, CtxT]):
|
|
16
18
|
async def __call__(
|
17
19
|
self,
|
18
20
|
packet: Packet[_PayloadT_contra],
|
19
|
-
ctx: RunContext[CtxT]
|
21
|
+
ctx: RunContext[CtxT],
|
20
22
|
**kwargs: Any,
|
21
|
-
) -> None: ...
|
23
|
+
) -> AsyncIterator[Event[Any]] | None: ...
|
22
24
|
|
23
25
|
|
24
26
|
class PacketPool(Generic[CtxT]):
|
25
27
|
def __init__(self) -> None:
|
26
28
|
self._queues: dict[ProcName, asyncio.Queue[Packet[Any]]] = {}
|
27
29
|
self._packet_handlers: dict[ProcName, PacketHandler[Any, CtxT]] = {}
|
28
|
-
self._tasks: dict[ProcName, asyncio.Task[None]] = {}
|
30
|
+
self._tasks: dict[ProcName, asyncio.Task[AsyncIterator[Event[Any]] | None]] = {}
|
29
31
|
|
30
32
|
async def post(self, packet: Packet[Any]) -> None:
|
31
33
|
for recipient_id in packet.recipients:
|
@@ -36,7 +38,7 @@ class PacketPool(Generic[CtxT]):
|
|
36
38
|
self,
|
37
39
|
processor_name: ProcName,
|
38
40
|
handler: PacketHandler[Any, CtxT],
|
39
|
-
ctx: RunContext[CtxT]
|
41
|
+
ctx: RunContext[CtxT],
|
40
42
|
**run_kwargs: Any,
|
41
43
|
) -> None:
|
42
44
|
self._packet_handlers[processor_name] = handler
|
@@ -47,11 +49,8 @@ class PacketPool(Generic[CtxT]):
|
|
47
49
|
)
|
48
50
|
|
49
51
|
async def _handle_packets(
|
50
|
-
self,
|
51
|
-
|
52
|
-
ctx: RunContext[CtxT] | None = None,
|
53
|
-
**run_kwargs: Any,
|
54
|
-
) -> None:
|
52
|
+
self, processor_name: ProcName, ctx: RunContext[CtxT], **run_kwargs: Any
|
53
|
+
) -> AsyncIterator[Event[Any]] | None:
|
55
54
|
queue = self._queues[processor_name]
|
56
55
|
while True:
|
57
56
|
try:
|
@@ -59,11 +58,13 @@ class PacketPool(Generic[CtxT]):
|
|
59
58
|
handler = self._packet_handlers.get(processor_name)
|
60
59
|
if handler is None:
|
61
60
|
break
|
62
|
-
|
63
61
|
try:
|
64
|
-
|
65
|
-
packet, ctx=ctx, **run_kwargs
|
66
|
-
|
62
|
+
if ctx.is_streaming:
|
63
|
+
async for event in handler(packet, ctx=ctx, **run_kwargs): # type: ignore[return-value]
|
64
|
+
yield event
|
65
|
+
else:
|
66
|
+
await handler(packet, ctx=ctx, **run_kwargs)
|
67
|
+
|
67
68
|
except Exception:
|
68
69
|
logger.exception(f"Error handling packet for {processor_name}")
|
69
70
|
|