grasp_agents 0.2.11__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/__init__.py +15 -14
- grasp_agents/cloud_llm.py +118 -131
- grasp_agents/comm_processor.py +201 -0
- grasp_agents/generics_utils.py +15 -7
- grasp_agents/llm.py +60 -31
- grasp_agents/llm_agent.py +229 -273
- grasp_agents/llm_agent_memory.py +58 -0
- grasp_agents/llm_policy_executor.py +482 -0
- grasp_agents/memory.py +20 -134
- grasp_agents/message_history.py +140 -0
- grasp_agents/openai/__init__.py +54 -36
- grasp_agents/openai/completion_chunk_converters.py +78 -0
- grasp_agents/openai/completion_converters.py +53 -30
- grasp_agents/openai/content_converters.py +13 -14
- grasp_agents/openai/converters.py +44 -68
- grasp_agents/openai/message_converters.py +58 -72
- grasp_agents/openai/openai_llm.py +101 -42
- grasp_agents/openai/tool_converters.py +24 -19
- grasp_agents/packet.py +24 -0
- grasp_agents/packet_pool.py +91 -0
- grasp_agents/printer.py +29 -15
- grasp_agents/processor.py +193 -0
- grasp_agents/prompt_builder.py +175 -192
- grasp_agents/run_context.py +20 -37
- grasp_agents/typing/completion.py +58 -12
- grasp_agents/typing/completion_chunk.py +173 -0
- grasp_agents/typing/converters.py +8 -12
- grasp_agents/typing/events.py +86 -0
- grasp_agents/typing/io.py +4 -13
- grasp_agents/typing/message.py +12 -50
- grasp_agents/typing/tool.py +52 -26
- grasp_agents/usage_tracker.py +6 -6
- grasp_agents/utils.py +3 -3
- grasp_agents/workflow/looped_workflow.py +132 -0
- grasp_agents/workflow/parallel_processor.py +95 -0
- grasp_agents/workflow/sequential_workflow.py +66 -0
- grasp_agents/workflow/workflow_processor.py +78 -0
- {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.2.dist-info}/METADATA +41 -50
- grasp_agents-0.3.2.dist-info/RECORD +51 -0
- grasp_agents/agent_message.py +0 -27
- grasp_agents/agent_message_pool.py +0 -92
- grasp_agents/base_agent.py +0 -51
- grasp_agents/comm_agent.py +0 -217
- grasp_agents/llm_agent_state.py +0 -79
- grasp_agents/tool_orchestrator.py +0 -203
- grasp_agents/workflow/looped_agent.py +0 -134
- grasp_agents/workflow/sequential_agent.py +0 -72
- grasp_agents/workflow/workflow_agent.py +0 -88
- grasp_agents-0.2.11.dist-info/RECORD +0 -46
- {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.2.dist-info}/WHEEL +0 -0
- {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.2.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/openai/tool_converters.py
CHANGED
@@ -1,38 +1,43 @@
|
|
1
1
|
from typing import Any
|
2
2
|
|
3
|
+
from openai import pydantic_function_tool
|
3
4
|
from pydantic import BaseModel
|
4
5
|
|
5
|
-
from ..typing.tool import BaseTool, ToolChoice
|
6
|
+
from ..typing.tool import BaseTool, NamedToolChoice, ToolChoice
|
6
7
|
from . import (
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
8
|
+
OpenAIFunctionDefinition,
|
9
|
+
OpenAINamedToolChoiceFunction,
|
10
|
+
OpenAINamedToolChoiceParam,
|
11
|
+
OpenAIToolChoiceOptionParam,
|
12
|
+
OpenAIToolParam,
|
12
13
|
)
|
13
14
|
|
14
15
|
|
15
16
|
def to_api_tool(
    tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None
) -> OpenAIToolParam:
    """Convert ``tool`` into an OpenAI function-tool parameter.

    Args:
        tool: Tool whose ``name``, ``description`` and pydantic ``in_type``
            describe the function.
        strict: ``True`` builds the tool with ``openai.pydantic_function_tool``
            (strict mode for the pydantic model); ``False`` sends
            ``strict=False`` explicitly; ``None`` (the default) omits the
            ``strict`` key altogether.

    Returns:
        An ``OpenAIToolParam`` with ``type="function"``.
    """
    if strict:
        # Strict mode: let the OpenAI SDK derive the schema from the model.
        return pydantic_function_tool(
            model=tool.in_type, name=tool.name, description=tool.description
        )

    definition_fields: dict[str, Any] = {
        "name": tool.name,
        "description": tool.description,
        "parameters": tool.in_type.model_json_schema(),
    }
    if strict is not None:
        # Only an explicit False reaches this point; None leaves the key out.
        definition_fields["strict"] = strict

    return OpenAIToolParam(
        type="function", function=OpenAIFunctionDefinition(**definition_fields)
    )
|
28
35
|
|
29
36
|
|
30
|
-
def to_api_tool_choice(tool_choice: ToolChoice) -> OpenAIToolChoiceOptionParam:
    """Translate a framework ``ToolChoice`` into OpenAI's tool-choice param.

    A ``NamedToolChoice`` becomes an ``OpenAINamedToolChoiceParam`` forcing the
    function with that name; any other value is returned unchanged.
    """
    if not isinstance(tool_choice, NamedToolChoice):
        # Non-named choices are passed straight through.
        return tool_choice

    forced_function = OpenAINamedToolChoiceFunction(name=tool_choice.name)
    return OpenAINamedToolChoiceParam(type="function", function=forced_function)
|
grasp_agents/packet.py
ADDED
@@ -0,0 +1,24 @@
|
|
1
|
+
from collections.abc import Sequence
|
2
|
+
from typing import Generic, TypeVar
|
3
|
+
from uuid import uuid4
|
4
|
+
|
5
|
+
from pydantic import BaseModel, ConfigDict, Field
|
6
|
+
|
7
|
+
from .typing.io import ProcName
|
8
|
+
|
9
|
+
_PayloadT_co = TypeVar("_PayloadT_co", covariant=True)


class Packet(BaseModel, Generic[_PayloadT_co]):
    """Immutable envelope carrying payloads from a sender to recipients.

    Validated by pydantic: unknown fields are rejected (``extra="forbid"``)
    and instances cannot be mutated after creation (``frozen=True``).
    """

    # Short random identifier: the first 8 hex digits of a fresh UUID4.
    id: str = Field(default_factory=lambda: uuid4().hex[:8])
    # The data being delivered.
    payloads: Sequence[_PayloadT_co]
    # Name of the processor that produced this packet.
    sender: ProcName
    # Names of the processors the packet is addressed to (empty by default).
    recipients: Sequence[ProcName] = Field(default_factory=list)

    model_config = ConfigDict(extra="forbid", frozen=True)

    def __repr__(self) -> str:
        to_part = ", ".join(self.recipients)
        payload_count = len(self.payloads)
        return f"From: {self.sender}, To: {to_part}, Payloads: {payload_count}"
|
grasp_agents/packet_pool.py
ADDED
@@ -0,0 +1,91 @@
|
|
1
|
+
import asyncio
|
2
|
+
import logging
|
3
|
+
from typing import Any, Generic, Protocol, TypeVar
|
4
|
+
|
5
|
+
from .packet import Packet
|
6
|
+
from .run_context import CtxT, RunContext
|
7
|
+
from .typing.io import ProcName
|
8
|
+
|
9
|
+
logger = logging.getLogger(__name__)


_PayloadT_contra = TypeVar("_PayloadT_contra", contravariant=True)


class PacketHandler(Protocol[_PayloadT_contra, CtxT]):
    """Callback signature for consumers registered with ``PacketPool``.

    ``PacketPool`` awaits a handler once per packet taken from the
    processor's queue, passing the ``ctx`` and extra keyword arguments given
    at registration time; the handler's return value is ignored.
    """

    async def __call__(
        self,
        packet: Packet[_PayloadT_contra],
        ctx: RunContext[CtxT] | None,
        **kwargs: Any,
    ) -> None:
        """Handle a single delivered ``packet``."""
|
22
|
+
|
23
|
+
|
24
|
+
class PacketPool(Generic[CtxT]):
    """Per-processor packet queues, each drained by a background handler task.

    ``post`` fans a packet out to the queue of every recipient it names.
    ``register_packet_handler`` installs a handler for a processor and starts
    a task that awaits that handler for each packet taken from the
    processor's queue. Packets posted before a processor registers wait in
    its queue until the handler task starts consuming it.
    """

    def __init__(self) -> None:
        # Pending packets, keyed by recipient processor name.
        self._queues: dict[ProcName, asyncio.Queue[Packet[Any]]] = {}
        # Currently installed handler, keyed by processor name.
        self._packet_handlers: dict[ProcName, PacketHandler[Any, CtxT]] = {}
        # Background consumer task, keyed by processor name.
        self._tasks: dict[ProcName, asyncio.Task[None]] = {}

    def _get_queue(self, processor_name: ProcName) -> asyncio.Queue[Packet[Any]]:
        """Return the queue for ``processor_name``, creating it if missing."""
        queue = self._queues.get(processor_name)
        if queue is None:
            # Create lazily: ``setdefault(name, asyncio.Queue())`` would build
            # a throwaway Queue on every call.
            queue = asyncio.Queue()
            self._queues[processor_name] = queue
        return queue

    async def post(self, packet: Packet[Any]) -> None:
        """Put ``packet`` on the queue of each name in ``packet.recipients``."""
        for recipient_id in packet.recipients:
            await self._get_queue(recipient_id).put(packet)

    def register_packet_handler(
        self,
        processor_name: ProcName,
        handler: PacketHandler[Any, CtxT],
        ctx: RunContext[CtxT] | None = None,
        **run_kwargs: Any,
    ) -> None:
        """Install ``handler`` for ``processor_name`` and start its consumer.

        Must be called while an event loop is running, because the consumer
        is started with ``asyncio.create_task``. If a consumer task already
        exists for this name, only the handler is replaced; the ``ctx`` and
        ``run_kwargs`` captured by the existing task are kept.
        """
        self._packet_handlers[processor_name] = handler
        self._get_queue(processor_name)
        if processor_name not in self._tasks:
            self._tasks[processor_name] = asyncio.create_task(
                self._handle_packets(processor_name, ctx=ctx, **run_kwargs)
            )

    async def _handle_packets(
        self,
        processor_name: ProcName,
        ctx: RunContext[CtxT] | None = None,
        **run_kwargs: Any,
    ) -> None:
        """Consume ``processor_name``'s queue until cancelled or unregistered.

        Handler exceptions are logged and do not stop the loop. Every packet
        taken off the queue is marked done, even if its handler failed, was
        cancelled, or had been removed, so ``queue.join()`` cannot hang.
        """
        queue = self._queues[processor_name]
        while True:
            try:
                packet = await queue.get()
                try:
                    handler = self._packet_handlers.get(processor_name)
                    if handler is None:
                        # Handler removed without cancelling this task: stop.
                        break
                    try:
                        await handler(packet, ctx=ctx, **run_kwargs)
                    except Exception:
                        logger.exception(
                            "Error handling packet for %s", processor_name
                        )
                finally:
                    queue.task_done()
            except Exception:
                logger.exception(
                    "Unexpected error in processing loop for %s", processor_name
                )

    async def unregister_packet_handler(self, processor_name: ProcName) -> None:
        """Stop ``processor_name``'s consumer task and drop all of its state.

        Packets still waiting in the processor's queue are discarded together
        with the queue.
        """
        if task := self._tasks.get(processor_name):
            task.cancel()
            # A handler may unregister its own processor. Awaiting the current
            # task from inside itself would consume the pending cancellation
            # and leave the consumer loop running detached, so only wait for
            # other tasks; the current one is cancelled at its next await.
            if task is not asyncio.current_task():
                try:
                    await task
                except asyncio.CancelledError:
                    logger.debug("%s exited", processor_name)

        self._tasks.pop(processor_name, None)
        self._queues.pop(processor_name, None)
        self._packet_handlers.pop(processor_name, None)

    async def stop_all(self) -> None:
        """Unregister every processor that currently has a consumer task."""
        # Iterate over a copy: unregistering mutates ``self._tasks``.
        for processor_name in list(self._tasks):
            await self.unregister_packet_handler(processor_name)
|
grasp_agents/printer.py
CHANGED
@@ -6,13 +6,14 @@ from typing import Literal, TypeAlias
|
|
6
6
|
|
7
7
|
from termcolor._types import Color # type: ignore[import]
|
8
8
|
|
9
|
+
from .typing.completion import Usage
|
9
10
|
from .typing.content import Content, ContentPartText
|
10
11
|
from .typing.message import AssistantMessage, Message, Role, ToolMessage
|
11
12
|
|
12
13
|
logger = logging.getLogger(__name__)
|
13
14
|
|
14
15
|
|
15
|
-
ColoringMode: TypeAlias = Literal["
|
16
|
+
ColoringMode: TypeAlias = Literal["agent", "role"]
|
16
17
|
|
17
18
|
ROLE_TO_COLOR: Mapping[Role, Color] = {
|
18
19
|
Role.SYSTEM: "magenta",
|
@@ -50,9 +51,9 @@ class Printer:
|
|
50
51
|
return ROLE_TO_COLOR[role]
|
51
52
|
|
52
53
|
@staticmethod
|
53
|
-
def get_agent_color(
|
54
|
+
def get_agent_color(agent_name: str) -> Color:
|
54
55
|
idx = int(
|
55
|
-
hashlib.md5(
|
56
|
+
hashlib.md5(agent_name.encode()).hexdigest(), # noqa :S324
|
56
57
|
16,
|
57
58
|
) % len(AVAILABLE_COLORS)
|
58
59
|
|
@@ -82,16 +83,22 @@ class Printer:
|
|
82
83
|
|
83
84
|
return content_str
|
84
85
|
|
85
|
-
def print_llm_message(
|
86
|
+
def print_llm_message(
|
87
|
+
self, message: Message, agent_name: str, usage: Usage | None = None
|
88
|
+
) -> None:
|
86
89
|
if not self.print_messages:
|
87
90
|
return
|
88
91
|
|
92
|
+
if usage is not None and not isinstance(message, AssistantMessage):
|
93
|
+
raise ValueError(
|
94
|
+
"Usage information can only be printed for AssistantMessage"
|
95
|
+
)
|
96
|
+
|
89
97
|
role = message.role
|
90
|
-
usage = message.usage if isinstance(message, AssistantMessage) else None
|
91
98
|
content_str = self.content_to_str(message.content or "", message.role)
|
92
99
|
|
93
|
-
if self.color_by == "
|
94
|
-
color = self.get_agent_color(
|
100
|
+
if self.color_by == "agent":
|
101
|
+
color = self.get_agent_color(agent_name)
|
95
102
|
elif self.color_by == "role":
|
96
103
|
color = self.get_role_color(role)
|
97
104
|
|
@@ -99,11 +106,11 @@ class Printer:
|
|
99
106
|
|
100
107
|
# Print message title
|
101
108
|
|
102
|
-
out = f"\n<{
|
109
|
+
out = f"\n<{agent_name}>"
|
103
110
|
out += "[" + role.value.upper() + "]"
|
104
111
|
|
105
112
|
if isinstance(message, ToolMessage):
|
106
|
-
out += f"\
|
113
|
+
out += f"\n{message.name} | {message.tool_call_id}"
|
107
114
|
|
108
115
|
# Print message content
|
109
116
|
|
@@ -123,12 +130,12 @@ class Printer:
|
|
123
130
|
|
124
131
|
if isinstance(message, AssistantMessage) and message.tool_calls is not None:
|
125
132
|
for tool_call in message.tool_calls:
|
126
|
-
if self.color_by == "
|
127
|
-
tool_color = self.get_agent_color(
|
133
|
+
if self.color_by == "agent":
|
134
|
+
tool_color = self.get_agent_color(agent_name=agent_name)
|
128
135
|
elif self.color_by == "role":
|
129
136
|
tool_color = self.get_role_color(role=Role.TOOL)
|
130
137
|
logger.debug(
|
131
|
-
f"\n[TOOL_CALL]<{
|
138
|
+
f"\n[TOOL_CALL]<{agent_name}>\n{tool_call.tool_name} "
|
132
139
|
f"| {tool_call.id}\n{tool_call.tool_arguments}",
|
133
140
|
extra={"color": tool_color}, # type: ignore
|
134
141
|
)
|
@@ -148,9 +155,16 @@ class Printer:
|
|
148
155
|
**log_kwargs, # type: ignore
|
149
156
|
)
|
150
157
|
|
151
|
-
def print_llm_messages(
    self,
    messages: Sequence[Message],
    agent_name: str,
    usages: Sequence[Usage | None] | None = None,
) -> None:
    """Print every message in ``messages`` via ``print_llm_message``.

    Args:
        messages: Messages to print, in order.
        agent_name: Name shown in each message title (and used for coloring
            when ``color_by == "agent"``).
        usages: Optional per-message usage, aligned by index with
            ``messages``. Missing entries (a shorter or absent sequence) are
            treated as ``None``; extra entries are ignored.
    """
    if not self.print_messages:
        return

    # Pad with None instead of letting zip truncate, so every message is
    # printed even when fewer usages than messages are supplied.
    padded_usages: list[Usage | None] = list(usages or ())
    padded_usages += [None] * (len(messages) - len(padded_usages))

    for _message, _usage in zip(messages, padded_usages, strict=False):
        self.print_llm_message(_message, usage=_usage, agent_name=agent_name)
|
grasp_agents/processor.py
ADDED
@@ -0,0 +1,193 @@
|
|
1
|
+
from abc import ABC
|
2
|
+
from collections.abc import AsyncIterator, Sequence
|
3
|
+
from typing import Any, ClassVar, Generic, cast, final
|
4
|
+
|
5
|
+
from pydantic import BaseModel, TypeAdapter
|
6
|
+
|
7
|
+
from .generics_utils import AutoInstanceAttributesMixin
|
8
|
+
from .packet import Packet
|
9
|
+
from .run_context import CtxT, RunContext
|
10
|
+
from .typing.events import Event, PacketEvent, ProcOutputEvent
|
11
|
+
from .typing.io import InT_contra, MemT_co, OutT_co, ProcName
|
12
|
+
from .typing.tool import BaseTool
|
13
|
+
|
14
|
+
|
15
|
+
class Processor(
    AutoInstanceAttributesMixin, ABC, Generic[InT_contra, OutT_co, MemT_co, CtxT]
):
    """Base class for a typed processing step that turns inputs into a packet.

    ``_in_type`` and ``_out_type`` are declared but not assigned in
    ``__init__``; the class-level ``_generic_arg_to_instance_attr_map`` ties
    them to generic arguments 0 and 1, presumably so that
    ``AutoInstanceAttributesMixin`` fills them in from the parametrized class
    (NOTE(review): confirm against generics_utils). Pydantic ``TypeAdapter``s
    built from them are used to validate outputs.

    The default ``_process`` / ``_process_stream`` implementations pass
    ``in_args`` through unchanged as the outputs.
    """

    _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
        0: "_in_type",
        1: "_out_type",
    }

    def __init__(self, name: ProcName, **kwargs: Any) -> None:
        self._in_type: type[InT_contra]
        self._out_type: type[OutT_co]

        # Must run before the adapters below are built, since they read
        # ``_in_type`` / ``_out_type`` (assumed to be set by the mixin).
        super().__init__()

        self._in_type_adapter: TypeAdapter[InT_contra] = TypeAdapter(self._in_type)
        self._out_type_adapter: TypeAdapter[OutT_co] = TypeAdapter(self._out_type)

        self._name: ProcName = name
        # Declared only: reading ``memory`` raises AttributeError until a
        # subclass assigns ``_memory``.
        self._memory: MemT_co

    @property
    def in_type(self) -> type[InT_contra]:  # type: ignore[reportInvalidTypeVarUse]
        # Exposing the type of a contravariant variable only, should be type safe
        return self._in_type

    @property
    def out_type(self) -> type[OutT_co]:
        return self._out_type

    @property
    def name(self) -> ProcName:
        return self._name

    @property
    def memory(self) -> MemT_co:
        return self._memory

    def _validate_and_resolve_inputs(
        self,
        chat_inputs: Any | None = None,
        in_packet: Packet[InT_contra] | None = None,
        in_args: InT_contra | Sequence[InT_contra] | None = None,
    ) -> Sequence[InT_contra] | None:
        """Ensure at most one input source is given and normalize it.

        Returns:
            ``in_packet.payloads`` if a packet is given; ``[in_args]`` if
            ``in_args`` is a single instance of the input type; ``None`` if
            ``in_args`` is ``None``; otherwise ``in_args`` itself, treated as
            a sequence of inputs. ``chat_inputs`` never contributes to the
            result.

        Raises:
            ValueError: If more than one of ``chat_inputs``, ``in_packet`` and
                ``in_args`` is not ``None``.
        """
        n_provided = sum(
            source is not None for source in (chat_inputs, in_packet, in_args)
        )
        if n_provided > 1:
            raise ValueError(
                "Only one of chat_inputs, in_args, or in_packet may be provided."
            )

        if in_packet is not None:
            return in_packet.payloads
        # Keep the isinstance check ahead of the None check: with a NoneType
        # input type, ``in_args=None`` resolves to ``[None]``.
        if isinstance(in_args, self._in_type):
            # A single input instance: wrap it so callers always get a sequence.
            return cast("Sequence[InT_contra]", [in_args])
        if in_args is None:
            return None
        return cast("Sequence[InT_contra]", in_args)

    async def _process(
        self,
        chat_inputs: Any | None = None,
        *,
        in_args: Sequence[InT_contra] | None = None,
        forgetful: bool = False,
        ctx: RunContext[CtxT] | None = None,
    ) -> Sequence[OutT_co]:
        """Default processing: return ``in_args`` unchanged as the outputs.

        ``chat_inputs``, ``forgetful`` and ``ctx`` are not used by this
        default implementation.
        """
        assert in_args is not None, (
            "Default implementation of _process requires in_args"
        )

        return cast("Sequence[OutT_co]", in_args)

    async def _process_stream(
        self,
        chat_inputs: Any | None = None,
        *,
        in_args: Sequence[InT_contra] | None = None,
        forgetful: bool = False,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """Default streaming: yield one ``ProcOutputEvent`` per item of ``in_args``.

        ``chat_inputs``, ``forgetful`` and ``ctx`` are not used by this
        default implementation.
        """
        assert in_args is not None, (
            "Default implementation of _process_stream requires in_args"
        )
        outputs = cast("Sequence[OutT_co]", in_args)
        for out in outputs:
            yield ProcOutputEvent(data=out, name=self.name)

    def _validate_outputs(self, out_payloads: Sequence[OutT_co]) -> Sequence[OutT_co]:
        """Validate each payload with the output ``TypeAdapter``.

        Raises:
            pydantic.ValidationError: If a payload does not match the output type.
        """
        return [
            self._out_type_adapter.validate_python(payload) for payload in out_payloads
        ]

    async def run(
        self,
        chat_inputs: Any | None = None,
        *,
        in_packet: Packet[InT_contra] | None = None,
        in_args: InT_contra | Sequence[InT_contra] | None = None,
        forgetful: bool = False,
        ctx: RunContext[CtxT] | None = None,
    ) -> Packet[OutT_co]:
        """Process the given input and return the validated outputs as a packet.

        At most one of ``chat_inputs``, ``in_packet`` and ``in_args`` may be
        given. The returned packet has this processor as ``sender`` and no
        recipients.

        Raises:
            ValueError: If more than one input source is provided.
        """
        resolved_in_args = self._validate_and_resolve_inputs(
            chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
        )
        outputs = await self._process(
            chat_inputs=chat_inputs,
            in_args=resolved_in_args,
            forgetful=forgetful,
            ctx=ctx,
        )
        val_outputs = self._validate_outputs(outputs)

        return Packet(payloads=val_outputs, sender=self.name)

    async def run_stream(
        self,
        chat_inputs: Any | None = None,
        *,
        in_packet: Packet[InT_contra] | None = None,
        in_args: InT_contra | Sequence[InT_contra] | None = None,
        forgetful: bool = False,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """Stream processing events, finishing with a single ``PacketEvent``.

        Events from ``_process_stream`` are re-yielded as they arrive, except
        ``ProcOutputEvent``s: their data are collected, validated and emitted
        at the end inside one ``PacketEvent``.

        Raises:
            ValueError: If more than one input source is provided.
        """
        resolved_in_args = self._validate_and_resolve_inputs(
            chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
        )

        outputs: list[OutT_co] = []
        async for output_event in self._process_stream(
            chat_inputs=chat_inputs,
            in_args=resolved_in_args,
            forgetful=forgetful,
            ctx=ctx,
        ):
            if isinstance(output_event, ProcOutputEvent):
                outputs.append(output_event.data)
            else:
                yield output_event

        val_outputs = self._validate_outputs(outputs)
        out_packet = Packet[OutT_co](payloads=val_outputs, sender=self.name)

        yield PacketEvent(data=out_packet, name=self.name)

    @final
    def as_tool(
        self, tool_name: str, tool_description: str
    ) -> BaseTool[InT_contra, OutT_co, Any]:  # type: ignore[override]
        """Wrap this processor as a ``BaseTool`` named ``tool_name``.

        The tool validates its input with the processor's input model, runs
        the processor with ``forgetful=True`` and returns the first output
        payload.

        Raises:
            TypeError: If the processor's input type is not a pydantic
                ``BaseModel`` subclass.
        """
        # TODO: stream tools
        processor_instance = self
        in_type = processor_instance.in_type
        out_type = processor_instance.out_type
        if not issubclass(in_type, BaseModel):
            raise TypeError(
                "Cannot create a tool from an agent with "
                f"non-BaseModel input type: {in_type}"
            )

        class ProcessorTool(BaseTool[in_type, out_type, Any]):
            name: str = tool_name
            description: str = tool_description

            async def run(
                self, inp: InT_contra, ctx: RunContext[CtxT] | None = None
            ) -> OutT_co:
                result = await processor_instance.run(
                    in_args=in_type.model_validate(inp), forgetful=True, ctx=ctx
                )

                return result.payloads[0]

        return ProcessorTool()  # type: ignore[return-value]
|