grasp_agents 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/__init__.py +4 -6
- grasp_agents/errors.py +80 -18
- grasp_agents/llm_agent.py +106 -146
- grasp_agents/llm_agent_memory.py +1 -1
- grasp_agents/llm_policy_executor.py +17 -15
- grasp_agents/packet.py +23 -4
- grasp_agents/packet_pool.py +117 -50
- grasp_agents/printer.py +9 -5
- grasp_agents/processor.py +217 -166
- grasp_agents/prompt_builder.py +75 -138
- grasp_agents/run_context.py +3 -16
- grasp_agents/runner.py +110 -21
- grasp_agents/typing/events.py +8 -4
- grasp_agents/typing/io.py +1 -8
- grasp_agents/workflow/looped_workflow.py +13 -19
- grasp_agents/workflow/sequential_workflow.py +6 -10
- grasp_agents/workflow/workflow_processor.py +23 -16
- {grasp_agents-0.5.3.dist-info → grasp_agents-0.5.5.dist-info}/METADATA +1 -1
- {grasp_agents-0.5.3.dist-info → grasp_agents-0.5.5.dist-info}/RECORD +21 -22
- grasp_agents/comm_processor.py +0 -214
- {grasp_agents-0.5.3.dist-info → grasp_agents-0.5.5.dist-info}/WHEEL +0 -0
- {grasp_agents-0.5.3.dist-info → grasp_agents-0.5.5.dist-info}/licenses/LICENSE.md +0 -0
@@ -3,7 +3,7 @@ import json
|
|
3
3
|
from collections.abc import AsyncIterator, Coroutine, Sequence
|
4
4
|
from itertools import starmap
|
5
5
|
from logging import getLogger
|
6
|
-
from typing import Any, Generic, Protocol
|
6
|
+
from typing import Any, Generic, Protocol, final
|
7
7
|
|
8
8
|
from pydantic import BaseModel
|
9
9
|
|
@@ -29,7 +29,7 @@ from .typing.tool import BaseTool, NamedToolChoice, ToolCall, ToolChoice
|
|
29
29
|
logger = getLogger(__name__)
|
30
30
|
|
31
31
|
|
32
|
-
class
|
32
|
+
class ToolCallLoopTerminator(Protocol[CtxT]):
|
33
33
|
def __call__(
|
34
34
|
self,
|
35
35
|
conversation: Messages,
|
@@ -39,7 +39,7 @@ class ExitToolCallLoopHandler(Protocol[CtxT]):
|
|
39
39
|
) -> bool: ...
|
40
40
|
|
41
41
|
|
42
|
-
class
|
42
|
+
class MemoryManager(Protocol[CtxT]):
|
43
43
|
def __call__(
|
44
44
|
self,
|
45
45
|
memory: LLMAgentMemory,
|
@@ -78,8 +78,8 @@ class LLMPolicyExecutor(Generic[CtxT]):
|
|
78
78
|
self._max_turns = max_turns
|
79
79
|
self._react_mode = react_mode
|
80
80
|
|
81
|
-
self.
|
82
|
-
self.
|
81
|
+
self.tool_call_loop_terminator: ToolCallLoopTerminator[CtxT] | None = None
|
82
|
+
self.memory_manager: MemoryManager[CtxT] | None = None
|
83
83
|
|
84
84
|
@property
|
85
85
|
def agent_name(self) -> str:
|
@@ -97,18 +97,20 @@ class LLMPolicyExecutor(Generic[CtxT]):
|
|
97
97
|
def max_turns(self) -> int:
|
98
98
|
return self._max_turns
|
99
99
|
|
100
|
-
|
100
|
+
@final
def _terminate_tool_call_loop(
    self,
    conversation: Messages,
    *,
    ctx: RunContext[CtxT] | None = None,
    **kwargs: Any,
) -> bool:
    """Ask the optional terminator hook whether to leave the tool-call loop.

    Args:
        conversation: Message history passed through to the hook.
        ctx: Optional run context passed through to the hook.
        **kwargs: Extra keyword arguments forwarded verbatim (callers in this
            class pass ``num_turns``).

    Returns:
        The hook's result when ``self.tool_call_loop_terminator`` is set
        (truthy); otherwise ``False``, i.e. never request termination.
    """
    terminator = self.tool_call_loop_terminator
    if not terminator:
        # No hook installed: the loop is never cut short from here.
        return False
    return terminator(conversation, ctx=ctx, **kwargs)
|
111
112
|
|
113
|
+
@final
|
112
114
|
def _manage_memory(
|
113
115
|
self,
|
114
116
|
memory: LLMAgentMemory,
|
@@ -116,8 +118,8 @@ class LLMPolicyExecutor(Generic[CtxT]):
|
|
116
118
|
ctx: RunContext[CtxT] | None = None,
|
117
119
|
**kwargs: Any,
|
118
120
|
) -> None:
|
119
|
-
if self.
|
120
|
-
self.
|
121
|
+
if self.memory_manager:
|
122
|
+
self.memory_manager(memory=memory, ctx=ctx, **kwargs)
|
121
123
|
|
122
124
|
async def generate_message(
|
123
125
|
self,
|
@@ -255,7 +257,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
|
|
255
257
|
|
256
258
|
final_answer_message = self._extract_final_answer_from_tool_calls(memory=memory)
|
257
259
|
if final_answer_message is None:
|
258
|
-
raise AgentFinalAnswerError
|
260
|
+
raise AgentFinalAnswerError(proc_name=self.agent_name, call_id=call_id)
|
259
261
|
|
260
262
|
return final_answer_message
|
261
263
|
|
@@ -282,7 +284,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
|
|
282
284
|
|
283
285
|
final_answer_message = self._extract_final_answer_from_tool_calls(memory)
|
284
286
|
if final_answer_message is None:
|
285
|
-
raise AgentFinalAnswerError
|
287
|
+
raise AgentFinalAnswerError(proc_name=self.agent_name, call_id=call_id)
|
286
288
|
yield GenMessageEvent(
|
287
289
|
proc_name=self.agent_name, call_id=call_id, data=final_answer_message
|
288
290
|
)
|
@@ -309,8 +311,8 @@ class LLMPolicyExecutor(Generic[CtxT]):
|
|
309
311
|
# 2. Check if we should exit the tool call loop
|
310
312
|
|
311
313
|
# If a final answer is not provided via a tool call, we use
|
312
|
-
#
|
313
|
-
if not self._final_answer_as_tool_call and self.
|
314
|
+
# _terminate_tool_call_loop to determine whether to exit the loop.
|
315
|
+
if not self._final_answer_as_tool_call and self._terminate_tool_call_loop(
|
314
316
|
memory.message_history, ctx=ctx, num_turns=turns
|
315
317
|
):
|
316
318
|
return gen_message
|
@@ -390,7 +392,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
|
|
390
392
|
turns = 0
|
391
393
|
|
392
394
|
while True:
|
393
|
-
if not self._final_answer_as_tool_call and self.
|
395
|
+
if not self._final_answer_as_tool_call and self._terminate_tool_call_loop(
|
394
396
|
memory.message_history, ctx=ctx, num_turns=turns
|
395
397
|
):
|
396
398
|
return
|
grasp_agents/packet.py
CHANGED
@@ -1,11 +1,13 @@
|
|
1
1
|
from collections.abc import Sequence
|
2
|
-
from typing import Generic, TypeVar
|
2
|
+
from typing import Annotated, Any, Generic, Literal, TypeVar
|
3
3
|
from uuid import uuid4
|
4
4
|
|
5
|
-
from pydantic import BaseModel, ConfigDict, Field
|
5
|
+
from pydantic import AfterValidator, BaseModel, ConfigDict, Field
|
6
6
|
|
7
7
|
from .typing.io import ProcName
|
8
8
|
|
9
|
+
START_PROC_NAME: Literal["*START*"] = "*START*"
|
10
|
+
|
9
11
|
_PayloadT_co = TypeVar("_PayloadT_co", covariant=True)
|
10
12
|
|
11
13
|
|
@@ -13,15 +15,32 @@ class Packet(BaseModel, Generic[_PayloadT_co]):
|
|
13
15
|
id: str = Field(default_factory=lambda: str(uuid4())[:8])
|
14
16
|
payloads: Sequence[_PayloadT_co]
|
15
17
|
sender: ProcName
|
16
|
-
recipients: Sequence[ProcName] =
|
18
|
+
recipients: Sequence[ProcName] | None = None
|
17
19
|
|
18
20
|
model_config = ConfigDict(extra="forbid")
|
19
21
|
|
20
22
|
def __repr__(self) -> str:
    """Return a multi-line summary: class name, id, sender, recipients, payload count.

    Recipients are joined with ", "; an empty or missing recipient list is
    rendered as the literal text "None".
    """
    if self.recipients:
        recipients_str = ", ".join(self.recipients)
    else:
        recipients_str = "None"
    header = f"{self.__class__.__name__}:\n"
    return (
        header
        + f"ID: {self.id}\n"
        + f"From: {self.sender}\n"
        + f"To: {recipients_str}\n"
        + f"Payloads: {len(self.payloads)}"
    )
|
31
|
+
|
32
|
+
|
33
|
+
def _check_recipients_length(v: Sequence[ProcName] | None) -> Sequence[ProcName] | None:
    """Validate an optional recipients sequence as "none, or exactly one".

    Args:
        v: ``None`` or a sequence of processor names.

    Returns:
        ``v`` unchanged when it is ``None`` or has length 1.

    Raises:
        ValueError: if ``v`` is a sequence whose length is not 1.
    """
    if v is None or len(v) == 1:
        return v
    raise ValueError("recipients must contain exactly one item")
|
37
|
+
|
38
|
+
|
39
|
+
# Packet emitted on behalf of the START_PROC_NAME pseudo-processor
# (presumably the packet that kicks off a run — confirm against the runner).
# Comments are used instead of a class docstring on purpose: pydantic copies a
# model's docstring into its JSON schema description.
class StartPacket(Packet[_PayloadT_co]):
    # NOTE(review): defaults to the string "start"; how chat_inputs is
    # interpreted is decided by its consumer, which is not visible here.
    chat_inputs: Any | None = "start"
    # Sender defaults to the start pseudo-processor and is frozen (cannot be
    # reassigned after construction).
    sender: ProcName = Field(START_PROC_NAME, frozen=True)
    # Empty payload tuple by default; frozen like the sender.
    payloads: Sequence[_PayloadT_co] = Field((), frozen=True)
    # Either None or a sequence of exactly one recipient name
    # (enforced by _check_recipients_length after type validation).
    recipients: Annotated[
        Sequence[ProcName] | None, AfterValidator(_check_recipients_length)
    ] = None
|
grasp_agents/packet_pool.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1
1
|
import asyncio
|
2
2
|
import logging
|
3
3
|
from collections.abc import AsyncIterator
|
4
|
-
from
|
4
|
+
from types import TracebackType
|
5
|
+
from typing import Any, Generic, Literal, Protocol, TypeVar
|
5
6
|
|
6
7
|
from .packet import Packet
|
7
8
|
from .run_context import CtxT, RunContext
|
@@ -11,6 +12,9 @@ from .typing.io import ProcName
|
|
11
12
|
logger = logging.getLogger(__name__)
|
12
13
|
|
13
14
|
|
15
|
+
END_PROC_NAME: Literal["*END*"] = "*END*"
|
16
|
+
|
17
|
+
|
14
18
|
_PayloadT_contra = TypeVar("_PayloadT_contra", contravariant=True)
|
15
19
|
|
16
20
|
|
@@ -20,73 +24,136 @@ class PacketHandler(Protocol[_PayloadT_contra, CtxT]):
|
|
20
24
|
packet: Packet[_PayloadT_contra],
|
21
25
|
ctx: RunContext[CtxT],
|
22
26
|
**kwargs: Any,
|
23
|
-
) ->
|
27
|
+
) -> None: ...
|
24
28
|
|
25
29
|
|
26
30
|
class PacketPool(Generic[CtxT]):
    """Asynchronous router delivering packets to per-processor handlers.

    Each registered processor name owns an ``asyncio.Queue`` that a worker task
    drains. The worker runs inside the pool's ``asyncio.TaskGroup`` and feeds
    each packet to the processor's handler. Handlers may be registered before
    or after entering the pool with ``async with``. Workers for early
    registrations are started on entry.

    A packet whose only recipient is ``END_PROC_NAME`` resolves
    ``final_result`` and shuts the pool down. ``None`` is the stop sentinel on
    every packet queue and on the event queue.
    """

    def __init__(self) -> None:
        self._packet_queues: dict[ProcName, asyncio.Queue[Packet[Any] | None]] = {}
        self._packet_handlers: dict[ProcName, PacketHandler[Any, CtxT]] = {}
        self._task_group: asyncio.TaskGroup | None = None
        # Handlers registered before __aenter__ (no task group yet); their
        # worker tasks are started when the pool is entered.
        self._pending_workers: dict[
            ProcName, tuple[RunContext[CtxT], dict[str, Any]]
        ] = {}

        self._event_queue: asyncio.Queue[Event[Any] | None] = asyncio.Queue()

        self._final_result_fut: asyncio.Future[Packet[Any]] | None = None

        self._stopping = False
        self._stopped_evt = asyncio.Event()

        self._errors: list[Exception] = []

    async def post(self, packet: Packet[Any]) -> None:
        """Deliver *packet* to the queue of each of its recipients.

        A packet addressed solely to ``END_PROC_NAME`` is not queued. It becomes
        the final result (the first one wins) and the pool is shut down.
        Recipients without a registered handler get a queue created on demand,
        so their packets stay buffered until a handler is registered.
        """
        # Normalize to a list: ``recipients`` may be any Sequence (e.g. a
        # tuple), and comparing a tuple with a list is always False, which
        # previously made tuple-addressed END packets go undetected.
        recipients = list(packet.recipients or [])

        if recipients == [END_PROC_NAME]:
            fut = self._ensure_final_future()
            if not fut.done():
                fut.set_result(packet)
            await self.shutdown()
            return

        for recipient_id in recipients:
            queue = self._packet_queues.setdefault(recipient_id, asyncio.Queue())
            await queue.put(packet)

    def _ensure_final_future(self) -> asyncio.Future[Packet[Any]]:
        """Return the final-result future, creating it lazily on the running loop."""
        fut = self._final_result_fut
        if fut is None:
            fut = asyncio.get_running_loop().create_future()
            self._final_result_fut = fut
        return fut

    async def final_result(self) -> Packet[Any]:
        """Wait for the END packet and return it.

        The pool is always shut down afterwards, also on error or cancellation.
        If a worker fails while this future exists, the worker's exception is
        raised here.
        """
        fut = self._ensure_final_future()
        try:
            return await fut
        finally:
            await self.shutdown()

    def register_packet_handler(
        self,
        proc_name: ProcName,
        handler: PacketHandler[Any, CtxT],
        ctx: RunContext[CtxT],
        **run_kwargs: Any,
    ) -> None:
        """Register *handler* for packets addressed to *proc_name*.

        If the pool has already been entered, a worker task is started
        immediately. Otherwise the worker is started by ``__aenter__``.
        ``ctx`` and ``run_kwargs`` are forwarded to every handler call.

        Raises:
            RuntimeError: if the pool is stopping or already stopped.
        """
        if self._stopping:
            raise RuntimeError("PacketPool is stopping/stopped")

        self._packet_handlers[proc_name] = handler
        self._packet_queues.setdefault(proc_name, asyncio.Queue())

        if self._task_group is None:
            # No task group exists before __aenter__. Defer the worker start
            # instead of silently never draining this processor's queue.
            self._pending_workers[proc_name] = (ctx, run_kwargs)
        else:
            self._start_worker(self._task_group, proc_name, ctx, run_kwargs)

    def _start_worker(
        self,
        task_group: asyncio.TaskGroup,
        proc_name: ProcName,
        ctx: RunContext[CtxT],
        run_kwargs: dict[str, Any],
    ) -> None:
        """Spawn the queue-draining worker for *proc_name* in *task_group*."""
        task_group.create_task(
            self._handle_packets(proc_name, ctx=ctx, **run_kwargs),
            name=f"packet-handler:{proc_name}",
        )

    async def push_event(self, event: Event[Any]) -> None:
        """Queue *event* for consumers of ``stream_events``."""
        await self._event_queue.put(event)

    async def __aenter__(self) -> "PacketPool[CtxT]":
        """Open the task group and start workers for pre-registered handlers."""
        task_group = asyncio.TaskGroup()
        await task_group.__aenter__()
        self._task_group = task_group

        pending, self._pending_workers = self._pending_workers, {}
        for proc_name, (ctx, run_kwargs) in pending.items():
            self._start_worker(task_group, proc_name, ctx, run_kwargs)

        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> bool | None:
        """Shut down, wait for all workers, then surface recorded worker errors."""
        await self.shutdown()

        if self._task_group is not None:
            try:
                return await self._task_group.__aexit__(exc_type, exc, tb)
            finally:
                self._task_group = None

        if self._errors:
            raise ExceptionGroup("PacketPool worker errors", self._errors)

        return False

    async def _handle_packets(
        self, proc_name: ProcName, ctx: RunContext[CtxT], **run_kwargs: Any
    ) -> None:
        """Feed packets from *proc_name*'s queue to its handler until ``None``.

        When the handler raises, the exception is handled in this order:
        - it is logged and recorded;
        - it is forwarded to the ``final_result`` future if one exists;
        - the pool is shut down;
        - it is re-raised, so the task group cancels the other workers.
        """
        queue = self._packet_queues[proc_name]
        handler = self._packet_handlers[proc_name]

        # asyncio.CancelledError derives from BaseException, so it passes
        # through the ``except Exception`` below untouched.
        while (packet := await queue.get()) is not None:
            try:
                await handler(packet, ctx=ctx, **run_kwargs)
            except Exception as err:
                logger.exception("Error handling packet for %s", proc_name)
                self._errors.append(err)
                fut = self._final_result_fut
                if fut is not None and not fut.done():
                    fut.set_exception(err)
                await self.shutdown()
                raise

    async def stream_events(self) -> AsyncIterator[Event[Any]]:
        """Yield pushed events until the ``None`` sentinel from ``shutdown``."""
        while (event := await self._event_queue.get()) is not None:
            yield event

    async def shutdown(self) -> None:
        """Stop the event stream and every worker by enqueueing ``None``.

        Packets already queued ahead of the sentinel are still handled. The
        call is idempotent: later or concurrent callers wait until the first
        call has finished posting the sentinels.
        """
        if self._stopping:
            await self._stopped_evt.wait()
            return
        self._stopping = True
        try:
            await self._event_queue.put(None)
            for queue in self._packet_queues.values():
                await queue.put(None)
        finally:
            self._stopped_evt.set()
|
grasp_agents/printer.py
CHANGED
@@ -119,7 +119,7 @@ class Printer:
|
|
119
119
|
# Thinking
|
120
120
|
if isinstance(message, AssistantMessage) and message.reasoning_content:
|
121
121
|
thinking = message.reasoning_content.strip(" \n")
|
122
|
-
out += f"
|
122
|
+
out += f"<thinking>\n{thinking}\n</thinking>\n"
|
123
123
|
|
124
124
|
# Content
|
125
125
|
content = self.content_to_str(message.content or "", message.role)
|
@@ -219,12 +219,12 @@ async def print_event_stream(
|
|
219
219
|
) -> None:
|
220
220
|
color = _get_color(event, Role.ASSISTANT)
|
221
221
|
|
222
|
-
if isinstance(event,
|
223
|
-
src = "processor"
|
224
|
-
elif isinstance(event, WorkflowResultEvent):
|
222
|
+
if isinstance(event, WorkflowResultEvent):
|
225
223
|
src = "workflow"
|
226
|
-
|
224
|
+
elif isinstance(event, RunResultEvent):
|
227
225
|
src = "run"
|
226
|
+
else:
|
227
|
+
src = "processor"
|
228
228
|
|
229
229
|
text = f"\n<{event.proc_name}> [{event.call_id}]\n"
|
230
230
|
|
@@ -232,6 +232,10 @@ async def print_event_stream(
|
|
232
232
|
text += f"<{src} output>\n"
|
233
233
|
for p in event.data.payloads:
|
234
234
|
if isinstance(p, BaseModel):
|
235
|
+
for field_info in type(p).model_fields.values():
|
236
|
+
if field_info.exclude:
|
237
|
+
field_info.exclude = False
|
238
|
+
type(p).model_rebuild(force=True)
|
235
239
|
p_str = p.model_dump_json(indent=2)
|
236
240
|
else:
|
237
241
|
try:
|