grasp_agents 0.5.3__tar.gz → 0.5.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/PKG-INFO +1 -1
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/pyproject.toml +1 -1
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/__init__.py +3 -4
- grasp_agents-0.5.4/src/grasp_agents/errors.py +156 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/llm_agent.py +36 -58
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/llm_policy_executor.py +2 -2
- grasp_agents-0.5.4/src/grasp_agents/packet.py +46 -0
- grasp_agents-0.5.4/src/grasp_agents/packet_pool.py +159 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/printer.py +8 -4
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/processor.py +208 -163
- grasp_agents-0.5.4/src/grasp_agents/prompt_builder.py +193 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/run_context.py +1 -9
- grasp_agents-0.5.4/src/grasp_agents/runner.py +131 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/typing/events.py +8 -4
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/typing/io.py +1 -1
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/workflow/looped_workflow.py +13 -19
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/workflow/sequential_workflow.py +6 -10
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/workflow/workflow_processor.py +7 -15
- grasp_agents-0.5.3/src/grasp_agents/comm_processor.py +0 -214
- grasp_agents-0.5.3/src/grasp_agents/errors.py +0 -94
- grasp_agents-0.5.3/src/grasp_agents/packet.py +0 -27
- grasp_agents-0.5.3/src/grasp_agents/packet_pool.py +0 -92
- grasp_agents-0.5.3/src/grasp_agents/prompt_builder.py +0 -218
- grasp_agents-0.5.3/src/grasp_agents/runner.py +0 -42
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/.gitignore +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/LICENSE.md +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/README.md +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/cloud_llm.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/costs_dict.yaml +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/generics_utils.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/grasp_logging.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/http_client.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/litellm/__init__.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/litellm/completion_chunk_converters.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/litellm/completion_converters.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/litellm/converters.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/litellm/lite_llm.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/litellm/message_converters.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/llm.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/llm_agent_memory.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/memory.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/openai/__init__.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/openai/completion_chunk_converters.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/openai/completion_converters.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/openai/content_converters.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/openai/converters.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/openai/message_converters.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/openai/openai_llm.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/openai/tool_converters.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/rate_limiting/__init__.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/rate_limiting/rate_limiter_chunked.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/rate_limiting/types.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/rate_limiting/utils.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/typing/__init__.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/typing/completion.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/typing/completion_chunk.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/typing/content.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/typing/converters.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/typing/message.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/typing/tool.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/usage_tracker.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/utils.py +0 -0
- {grasp_agents-0.5.3 → grasp_agents-0.5.4}/src/grasp_agents/workflow/__init__.py +0 -0
@@ -1,14 +1,13 @@
|
|
1
1
|
# pyright: reportUnusedImport=false
|
2
2
|
|
3
3
|
|
4
|
-
from .comm_processor import CommProcessor
|
5
4
|
from .llm import LLM, LLMSettings
|
6
5
|
from .llm_agent import LLMAgent
|
7
6
|
from .llm_agent_memory import LLMAgentMemory
|
8
7
|
from .memory import Memory
|
9
8
|
from .packet import Packet
|
10
9
|
from .processor import Processor
|
11
|
-
from .run_context import
|
10
|
+
from .run_context import RunContext
|
12
11
|
from .typing.completion import Completion
|
13
12
|
from .typing.content import Content, ImageData
|
14
13
|
from .typing.io import LLMPrompt, LLMPromptArgs, ProcName
|
@@ -19,20 +18,20 @@ __all__ = [
|
|
19
18
|
"LLM",
|
20
19
|
"AssistantMessage",
|
21
20
|
"BaseTool",
|
22
|
-
"CommProcessor",
|
23
21
|
"Completion",
|
24
22
|
"Content",
|
25
23
|
"ImageData",
|
26
24
|
"LLMAgent",
|
25
|
+
"LLMAgentMemory",
|
27
26
|
"LLMPrompt",
|
28
27
|
"LLMPromptArgs",
|
29
28
|
"LLMSettings",
|
29
|
+
"Memory",
|
30
30
|
"Messages",
|
31
31
|
"Packet",
|
32
32
|
"Packet",
|
33
33
|
"ProcName",
|
34
34
|
"Processor",
|
35
|
-
"RunArgs",
|
36
35
|
"RunContext",
|
37
36
|
"SystemMessage",
|
38
37
|
"UserMessage",
|
@@ -0,0 +1,156 @@
|
|
1
|
+
# from openai import APIResponseValidationError
|
2
|
+
|
3
|
+
|
4
|
+
class ProcRunError(Exception):
|
5
|
+
def __init__(
|
6
|
+
self, proc_name: str, call_id: str, message: str | None = None
|
7
|
+
) -> None:
|
8
|
+
super().__init__(
|
9
|
+
message
|
10
|
+
or f"Processor run failed [proc_name: {proc_name}; call_id: {call_id}]."
|
11
|
+
)
|
12
|
+
self.proc_name = proc_name
|
13
|
+
self.call_id = call_id
|
14
|
+
|
15
|
+
|
16
|
+
class ProcInputValidationError(ProcRunError):
|
17
|
+
pass
|
18
|
+
|
19
|
+
|
20
|
+
class ProcOutputValidationError(ProcRunError):
|
21
|
+
def __init__(
|
22
|
+
self, schema: object, proc_name: str, call_id: str, message: str | None = None
|
23
|
+
):
|
24
|
+
super().__init__(
|
25
|
+
proc_name=proc_name,
|
26
|
+
call_id=call_id,
|
27
|
+
message=message
|
28
|
+
or (
|
29
|
+
"Processor output validation failed "
|
30
|
+
f"[proc_name: {proc_name}; call_id: {call_id}]. "
|
31
|
+
f"Expected type:\n{schema}"
|
32
|
+
),
|
33
|
+
)
|
34
|
+
|
35
|
+
|
36
|
+
class AgentFinalAnswerError(ProcRunError):
|
37
|
+
def __init__(
|
38
|
+
self, proc_name: str, call_id: str, message: str | None = None
|
39
|
+
) -> None:
|
40
|
+
super().__init__(
|
41
|
+
proc_name=proc_name,
|
42
|
+
call_id=call_id,
|
43
|
+
message=message
|
44
|
+
or "Final answer tool call did not return a final answer message "
|
45
|
+
f"[proc_name={proc_name}; call_id={call_id}]",
|
46
|
+
)
|
47
|
+
self.message = message
|
48
|
+
|
49
|
+
|
50
|
+
class WorkflowConstructionError(Exception):
|
51
|
+
pass
|
52
|
+
|
53
|
+
|
54
|
+
class PacketRoutingError(ProcRunError):
|
55
|
+
def __init__(
|
56
|
+
self,
|
57
|
+
proc_name: str,
|
58
|
+
call_id: str,
|
59
|
+
selected_recipient: str | None = None,
|
60
|
+
allowed_recipients: list[str] | None = None,
|
61
|
+
message: str | None = None,
|
62
|
+
) -> None:
|
63
|
+
default_message = (
|
64
|
+
f"Selected recipient '{selected_recipient}' is not in the allowed "
|
65
|
+
f"recipients: {allowed_recipients} "
|
66
|
+
f"[proc_name={proc_name}; call_id={call_id}]"
|
67
|
+
)
|
68
|
+
super().__init__(
|
69
|
+
proc_name=proc_name, call_id=call_id, message=message or default_message
|
70
|
+
)
|
71
|
+
self.selected_recipient = selected_recipient
|
72
|
+
self.allowed_recipients = allowed_recipients
|
73
|
+
|
74
|
+
|
75
|
+
class RunnerError(Exception):
|
76
|
+
pass
|
77
|
+
|
78
|
+
|
79
|
+
class PromptBuilderError(Exception):
|
80
|
+
def __init__(self, proc_name: str, message: str | None = None) -> None:
|
81
|
+
super().__init__(message or f"Prompt builder failed [proc_name={proc_name}]")
|
82
|
+
self.proc_name = proc_name
|
83
|
+
self.message = message
|
84
|
+
|
85
|
+
|
86
|
+
class SystemPromptBuilderError(PromptBuilderError):
|
87
|
+
def __init__(self, proc_name: str, message: str | None = None) -> None:
|
88
|
+
super().__init__(
|
89
|
+
proc_name=proc_name,
|
90
|
+
message=message
|
91
|
+
or "System prompt builder failed to make system prompt "
|
92
|
+
f"[proc_name={proc_name}]",
|
93
|
+
)
|
94
|
+
self.message = message
|
95
|
+
|
96
|
+
|
97
|
+
class InputPromptBuilderError(PromptBuilderError):
|
98
|
+
def __init__(self, proc_name: str, message: str | None = None) -> None:
|
99
|
+
super().__init__(
|
100
|
+
proc_name=proc_name,
|
101
|
+
message=message
|
102
|
+
or "Input prompt builder failed to make input content "
|
103
|
+
f"[proc_name={proc_name}]",
|
104
|
+
)
|
105
|
+
self.message = message
|
106
|
+
|
107
|
+
|
108
|
+
class PyJSONStringParsingError(Exception):
|
109
|
+
def __init__(self, s: str, message: str | None = None) -> None:
|
110
|
+
super().__init__(
|
111
|
+
message
|
112
|
+
or "Both ast.literal_eval and json.loads failed to parse the following "
|
113
|
+
f"JSON/Python string:\n{s}"
|
114
|
+
)
|
115
|
+
self.s = s
|
116
|
+
|
117
|
+
|
118
|
+
class JSONSchemaValidationError(Exception):
|
119
|
+
def __init__(self, s: str, schema: object, message: str | None = None) -> None:
|
120
|
+
super().__init__(
|
121
|
+
message
|
122
|
+
or f"JSON schema validation failed for:\n{s}\nExpected type: {schema}"
|
123
|
+
)
|
124
|
+
self.s = s
|
125
|
+
self.schema = schema
|
126
|
+
|
127
|
+
|
128
|
+
class CompletionError(Exception):
|
129
|
+
pass
|
130
|
+
|
131
|
+
|
132
|
+
class CombineCompletionChunksError(Exception):
|
133
|
+
pass
|
134
|
+
|
135
|
+
|
136
|
+
class LLMToolCallValidationError(Exception):
|
137
|
+
def __init__(
|
138
|
+
self, tool_name: str, tool_args: str, message: str | None = None
|
139
|
+
) -> None:
|
140
|
+
super().__init__(
|
141
|
+
message
|
142
|
+
or f"Failed to validate tool call '{tool_name}' with arguments:"
|
143
|
+
f"\n{tool_args}."
|
144
|
+
)
|
145
|
+
self.tool_name = tool_name
|
146
|
+
self.tool_args = tool_args
|
147
|
+
|
148
|
+
|
149
|
+
class LLMResponseValidationError(JSONSchemaValidationError):
|
150
|
+
def __init__(self, s: str, schema: object, message: str | None = None) -> None:
|
151
|
+
super().__init__(
|
152
|
+
s,
|
153
|
+
schema,
|
154
|
+
message
|
155
|
+
or f"Failed to validate LLM response:\n{s}\nExpected type: {schema}",
|
156
|
+
)
|
@@ -4,7 +4,6 @@ from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast
|
|
4
4
|
|
5
5
|
from pydantic import BaseModel
|
6
6
|
|
7
|
-
from .comm_processor import CommProcessor
|
8
7
|
from .llm import LLM, LLMSettings
|
9
8
|
from .llm_agent_memory import LLMAgentMemory, PrepareMemoryHandler
|
10
9
|
from .llm_policy_executor import (
|
@@ -12,7 +11,7 @@ from .llm_policy_executor import (
|
|
12
11
|
LLMPolicyExecutor,
|
13
12
|
ManageMemoryHandler,
|
14
13
|
)
|
15
|
-
from .
|
14
|
+
from .processor import Processor
|
16
15
|
from .prompt_builder import (
|
17
16
|
MakeInputContentHandler,
|
18
17
|
MakeSystemPromptHandler,
|
@@ -27,7 +26,7 @@ from .typing.events import (
|
|
27
26
|
SystemMessageEvent,
|
28
27
|
UserMessageEvent,
|
29
28
|
)
|
30
|
-
from .typing.io import InT, LLMPrompt, LLMPromptArgs,
|
29
|
+
from .typing.io import InT, LLMPrompt, LLMPromptArgs, OutT, ProcName
|
31
30
|
from .typing.message import Message, Messages, SystemMessage, UserMessage
|
32
31
|
from .typing.tool import BaseTool
|
33
32
|
from .utils import get_prompt, validate_obj_from_json_or_py_string
|
@@ -47,8 +46,8 @@ class ParseOutputHandler(Protocol[_InT_contra, _OutT_co, CtxT]):
|
|
47
46
|
|
48
47
|
|
49
48
|
class LLMAgent(
|
50
|
-
|
51
|
-
Generic[InT,
|
49
|
+
Processor[InT, OutT, LLMAgentMemory, CtxT],
|
50
|
+
Generic[InT, OutT, CtxT],
|
52
51
|
):
|
53
52
|
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
54
53
|
0: "_in_type",
|
@@ -71,8 +70,6 @@ class LLMAgent(
|
|
71
70
|
sys_prompt_path: str | Path | None = None,
|
72
71
|
# System args (static args provided via RunContext)
|
73
72
|
sys_args_schema: type[LLMPromptArgs] | None = None,
|
74
|
-
# User args (static args provided via RunContext)
|
75
|
-
usr_args_schema: type[LLMPromptArgs] | None = None,
|
76
73
|
# Agent loop settings
|
77
74
|
max_turns: int = 100,
|
78
75
|
react_mode: bool = False,
|
@@ -82,15 +79,9 @@ class LLMAgent(
|
|
82
79
|
# Retries
|
83
80
|
max_retries: int = 0,
|
84
81
|
# Multi-agent routing
|
85
|
-
packet_pool: PacketPool[CtxT] | None = None,
|
86
82
|
recipients: list[ProcName] | None = None,
|
87
83
|
) -> None:
|
88
|
-
super().__init__(
|
89
|
-
name=name,
|
90
|
-
packet_pool=packet_pool,
|
91
|
-
recipients=recipients,
|
92
|
-
max_retries=max_retries,
|
93
|
-
)
|
84
|
+
super().__init__(name=name, recipients=recipients, max_retries=max_retries)
|
94
85
|
|
95
86
|
# Agent memory
|
96
87
|
|
@@ -132,14 +123,13 @@ class LLMAgent(
|
|
132
123
|
self.in_type, CtxT
|
133
124
|
](
|
134
125
|
agent_name=self._name,
|
135
|
-
|
136
|
-
|
126
|
+
sys_prompt=sys_prompt,
|
127
|
+
in_prompt=in_prompt,
|
137
128
|
sys_args_schema=sys_args_schema,
|
138
|
-
usr_args_schema=usr_args_schema,
|
139
129
|
)
|
140
130
|
|
141
131
|
self._prepare_memory_impl: PrepareMemoryHandler | None = None
|
142
|
-
self._parse_output_impl: ParseOutputHandler[InT,
|
132
|
+
self._parse_output_impl: ParseOutputHandler[InT, OutT, CtxT] | None = None
|
143
133
|
self._register_overridden_handlers()
|
144
134
|
|
145
135
|
@property
|
@@ -158,17 +148,13 @@ class LLMAgent(
|
|
158
148
|
def sys_args_schema(self) -> type[LLMPromptArgs] | None:
|
159
149
|
return self._prompt_builder.sys_args_schema
|
160
150
|
|
161
|
-
@property
|
162
|
-
def usr_args_schema(self) -> type[LLMPromptArgs] | None:
|
163
|
-
return self._prompt_builder.usr_args_schema
|
164
|
-
|
165
151
|
@property
|
166
152
|
def sys_prompt(self) -> LLMPrompt | None:
|
167
|
-
return self._prompt_builder.
|
153
|
+
return self._prompt_builder.sys_prompt
|
168
154
|
|
169
155
|
@property
|
170
156
|
def in_prompt(self) -> LLMPrompt | None:
|
171
|
-
return self._prompt_builder.
|
157
|
+
return self._prompt_builder.in_prompt
|
172
158
|
|
173
159
|
def _prepare_memory(
|
174
160
|
self,
|
@@ -184,20 +170,13 @@ class LLMAgent(
|
|
184
170
|
|
185
171
|
def _memorize_inputs(
|
186
172
|
self,
|
173
|
+
memory: LLMAgentMemory,
|
187
174
|
chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
|
188
|
-
*,
|
189
175
|
in_args: InT | None = None,
|
190
|
-
memory: LLMAgentMemory,
|
191
176
|
ctx: RunContext[CtxT] | None = None,
|
192
177
|
) -> tuple[SystemMessage | None, UserMessage | None]:
|
193
|
-
# 1. Get
|
194
|
-
sys_args
|
195
|
-
usr_args: LLMPromptArgs | None = None
|
196
|
-
if ctx is not None:
|
197
|
-
run_args = ctx.run_args.get(self.name)
|
198
|
-
if run_args is not None:
|
199
|
-
sys_args = run_args.sys
|
200
|
-
usr_args = run_args.usr
|
178
|
+
# 1. Get system arguments
|
179
|
+
sys_args = ctx.sys_args.get(self.name) if ctx and ctx.sys_args else None
|
201
180
|
|
202
181
|
# 2. Make system prompt (can be None)
|
203
182
|
|
@@ -214,16 +193,13 @@ class LLMAgent(
|
|
214
193
|
system_message = cast("SystemMessage", memory.message_history[0])
|
215
194
|
else:
|
216
195
|
self._prepare_memory(
|
217
|
-
memory=memory,
|
218
|
-
in_args=in_args,
|
219
|
-
sys_prompt=formatted_sys_prompt,
|
220
|
-
ctx=ctx,
|
196
|
+
memory=memory, in_args=in_args, sys_prompt=formatted_sys_prompt, ctx=ctx
|
221
197
|
)
|
222
198
|
|
223
199
|
# 3. Make and add input messages
|
224
200
|
|
225
201
|
input_message = self._prompt_builder.make_input_message(
|
226
|
-
chat_inputs=chat_inputs, in_args=in_args,
|
202
|
+
chat_inputs=chat_inputs, in_args=in_args, ctx=ctx
|
227
203
|
)
|
228
204
|
if input_message:
|
229
205
|
memory.update([input_message])
|
@@ -236,7 +212,7 @@ class LLMAgent(
|
|
236
212
|
*,
|
237
213
|
in_args: InT | None = None,
|
238
214
|
ctx: RunContext[CtxT] | None = None,
|
239
|
-
) ->
|
215
|
+
) -> OutT:
|
240
216
|
if self._parse_output_impl:
|
241
217
|
return self._parse_output_impl(
|
242
218
|
conversation=conversation, in_args=in_args, ctx=ctx
|
@@ -257,9 +233,12 @@ class LLMAgent(
|
|
257
233
|
memory: LLMAgentMemory,
|
258
234
|
call_id: str,
|
259
235
|
ctx: RunContext[CtxT] | None = None,
|
260
|
-
) ->
|
236
|
+
) -> OutT:
|
261
237
|
system_message, input_message = self._memorize_inputs(
|
262
|
-
|
238
|
+
memory=memory,
|
239
|
+
chat_inputs=chat_inputs,
|
240
|
+
in_args=in_args,
|
241
|
+
ctx=ctx,
|
263
242
|
)
|
264
243
|
if system_message:
|
265
244
|
self._print_messages([system_message], call_id=call_id, ctx=ctx)
|
@@ -268,11 +247,9 @@ class LLMAgent(
|
|
268
247
|
|
269
248
|
await self._policy_executor.execute(memory, call_id=call_id, ctx=ctx)
|
270
249
|
|
271
|
-
return
|
272
|
-
|
273
|
-
|
274
|
-
)
|
275
|
-
]
|
250
|
+
return self._parse_output(
|
251
|
+
conversation=memory.message_history, in_args=in_args, ctx=ctx
|
252
|
+
)
|
276
253
|
|
277
254
|
async def _process_stream(
|
278
255
|
self,
|
@@ -284,7 +261,10 @@ class LLMAgent(
|
|
284
261
|
ctx: RunContext[CtxT] | None = None,
|
285
262
|
) -> AsyncIterator[Event[Any]]:
|
286
263
|
system_message, input_message = self._memorize_inputs(
|
287
|
-
|
264
|
+
memory=memory,
|
265
|
+
chat_inputs=chat_inputs,
|
266
|
+
in_args=in_args,
|
267
|
+
ctx=ctx,
|
288
268
|
)
|
289
269
|
if system_message:
|
290
270
|
self._print_messages([system_message], call_id=call_id, ctx=ctx)
|
@@ -322,6 +302,10 @@ class LLMAgent(
|
|
322
302
|
cur_cls = type(self)
|
323
303
|
base_cls = LLMAgent[Any, Any, Any]
|
324
304
|
|
305
|
+
# Packet routing
|
306
|
+
if cur_cls._select_recipients is not base_cls._select_recipients: # noqa: SLF001
|
307
|
+
self.select_recipients_impl = self._select_recipients
|
308
|
+
|
325
309
|
# Prompt builder
|
326
310
|
|
327
311
|
if cur_cls._make_system_prompt is not base_cls._make_system_prompt: # noqa: SLF001
|
@@ -354,15 +338,9 @@ class LLMAgent(
|
|
354
338
|
return self._prompt_builder.make_system_prompt(sys_args=sys_args, ctx=ctx)
|
355
339
|
|
356
340
|
def _make_input_content(
|
357
|
-
self,
|
358
|
-
*,
|
359
|
-
in_args: InT | None = None,
|
360
|
-
usr_args: LLMPromptArgs | None = None,
|
361
|
-
ctx: RunContext[CtxT] | None = None,
|
341
|
+
self, in_args: InT | None = None, *, ctx: RunContext[CtxT] | None = None
|
362
342
|
) -> Content:
|
363
|
-
return self._prompt_builder.make_input_content(
|
364
|
-
in_args=in_args, usr_args=usr_args, ctx=ctx
|
365
|
-
)
|
343
|
+
return self._prompt_builder.make_input_content(in_args=in_args, ctx=ctx)
|
366
344
|
|
367
345
|
def _exit_tool_call_loop(
|
368
346
|
self,
|
@@ -403,8 +381,8 @@ class LLMAgent(
|
|
403
381
|
return func
|
404
382
|
|
405
383
|
def parse_output(
|
406
|
-
self, func: ParseOutputHandler[InT,
|
407
|
-
) -> ParseOutputHandler[InT,
|
384
|
+
self, func: ParseOutputHandler[InT, OutT, CtxT]
|
385
|
+
) -> ParseOutputHandler[InT, OutT, CtxT]:
|
408
386
|
if self._used_default_llm_response_schema:
|
409
387
|
self._policy_executor.llm.response_schema = None
|
410
388
|
self._parse_output_impl = func
|
@@ -255,7 +255,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
|
|
255
255
|
|
256
256
|
final_answer_message = self._extract_final_answer_from_tool_calls(memory=memory)
|
257
257
|
if final_answer_message is None:
|
258
|
-
raise AgentFinalAnswerError
|
258
|
+
raise AgentFinalAnswerError(proc_name=self.agent_name, call_id=call_id)
|
259
259
|
|
260
260
|
return final_answer_message
|
261
261
|
|
@@ -282,7 +282,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
|
|
282
282
|
|
283
283
|
final_answer_message = self._extract_final_answer_from_tool_calls(memory)
|
284
284
|
if final_answer_message is None:
|
285
|
-
raise AgentFinalAnswerError
|
285
|
+
raise AgentFinalAnswerError(proc_name=self.agent_name, call_id=call_id)
|
286
286
|
yield GenMessageEvent(
|
287
287
|
proc_name=self.agent_name, call_id=call_id, data=final_answer_message
|
288
288
|
)
|
@@ -0,0 +1,46 @@
|
|
1
|
+
from collections.abc import Sequence
|
2
|
+
from typing import Annotated, Any, Generic, Literal, TypeVar
|
3
|
+
from uuid import uuid4
|
4
|
+
|
5
|
+
from pydantic import AfterValidator, BaseModel, ConfigDict, Field
|
6
|
+
|
7
|
+
from .typing.io import ProcName
|
8
|
+
|
9
|
+
START_PROC_NAME: Literal["*START*"] = "*START*"
|
10
|
+
|
11
|
+
_PayloadT_co = TypeVar("_PayloadT_co", covariant=True)
|
12
|
+
|
13
|
+
|
14
|
+
class Packet(BaseModel, Generic[_PayloadT_co]):
|
15
|
+
id: str = Field(default_factory=lambda: str(uuid4())[:8])
|
16
|
+
payloads: Sequence[_PayloadT_co]
|
17
|
+
sender: ProcName
|
18
|
+
recipients: Sequence[ProcName] | None = None
|
19
|
+
|
20
|
+
model_config = ConfigDict(extra="forbid")
|
21
|
+
|
22
|
+
def __repr__(self) -> str:
|
23
|
+
_to = ", ".join(self.recipients) if self.recipients else "None"
|
24
|
+
return (
|
25
|
+
f"{self.__class__.__name__}:\n"
|
26
|
+
f"ID: {self.id}\n"
|
27
|
+
f"From: {self.sender}\n"
|
28
|
+
f"To: {_to}\n"
|
29
|
+
f"Payloads: {len(self.payloads)}"
|
30
|
+
)
|
31
|
+
|
32
|
+
|
33
|
+
def _check_recipients_length(v: Sequence[ProcName] | None) -> Sequence[ProcName] | None:
|
34
|
+
if v is not None and len(v) != 1:
|
35
|
+
raise ValueError("recipients must contain exactly one item")
|
36
|
+
return v
|
37
|
+
|
38
|
+
|
39
|
+
class StartPacket(Packet[_PayloadT_co]):
|
40
|
+
chat_inputs: Any | None = "start"
|
41
|
+
sender: ProcName = Field(default=START_PROC_NAME, frozen=True)
|
42
|
+
payloads: Sequence[_PayloadT_co] = Field(default=(), frozen=True)
|
43
|
+
recipients: Annotated[
|
44
|
+
Sequence[ProcName] | None,
|
45
|
+
AfterValidator(_check_recipients_length),
|
46
|
+
] = None
|
@@ -0,0 +1,159 @@
|
|
1
|
+
import asyncio
|
2
|
+
import logging
|
3
|
+
from collections.abc import AsyncIterator
|
4
|
+
from types import TracebackType
|
5
|
+
from typing import Any, Generic, Literal, Protocol, TypeVar
|
6
|
+
|
7
|
+
from .packet import Packet
|
8
|
+
from .run_context import CtxT, RunContext
|
9
|
+
from .typing.events import Event
|
10
|
+
from .typing.io import ProcName
|
11
|
+
|
12
|
+
logger = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
END_PROC_NAME: Literal["*END*"] = "*END*"
|
16
|
+
|
17
|
+
|
18
|
+
_PayloadT_contra = TypeVar("_PayloadT_contra", contravariant=True)
|
19
|
+
|
20
|
+
|
21
|
+
class PacketHandler(Protocol[_PayloadT_contra, CtxT]):
|
22
|
+
async def __call__(
|
23
|
+
self,
|
24
|
+
packet: Packet[_PayloadT_contra],
|
25
|
+
ctx: RunContext[CtxT],
|
26
|
+
**kwargs: Any,
|
27
|
+
) -> None: ...
|
28
|
+
|
29
|
+
|
30
|
+
class PacketPool(Generic[CtxT]):
|
31
|
+
def __init__(self) -> None:
|
32
|
+
self._packet_queues: dict[ProcName, asyncio.Queue[Packet[Any] | None]] = {}
|
33
|
+
self._packet_handlers: dict[ProcName, PacketHandler[Any, CtxT]] = {}
|
34
|
+
self._task_group: asyncio.TaskGroup | None = None
|
35
|
+
|
36
|
+
self._event_queue: asyncio.Queue[Event[Any] | None] = asyncio.Queue()
|
37
|
+
|
38
|
+
self._final_result_fut: asyncio.Future[Packet[Any]] | None = None
|
39
|
+
|
40
|
+
self._stopping = False
|
41
|
+
self._stopped_evt = asyncio.Event()
|
42
|
+
|
43
|
+
self._errors: list[Exception] = []
|
44
|
+
|
45
|
+
async def post(self, packet: Packet[Any]) -> None:
|
46
|
+
if packet.recipients == [END_PROC_NAME]:
|
47
|
+
fut = self._ensure_final_future()
|
48
|
+
if not fut.done():
|
49
|
+
fut.set_result(packet)
|
50
|
+
await self.shutdown()
|
51
|
+
return
|
52
|
+
|
53
|
+
for recipient_id in packet.recipients or []:
|
54
|
+
queue = self._packet_queues.setdefault(recipient_id, asyncio.Queue())
|
55
|
+
await queue.put(packet)
|
56
|
+
|
57
|
+
def _ensure_final_future(self) -> asyncio.Future[Packet[Any]]:
|
58
|
+
fut = self._final_result_fut
|
59
|
+
if fut is None:
|
60
|
+
fut = asyncio.get_running_loop().create_future()
|
61
|
+
self._final_result_fut = fut
|
62
|
+
return fut
|
63
|
+
|
64
|
+
async def final_result(self) -> Packet[Any]:
|
65
|
+
fut = self._ensure_final_future()
|
66
|
+
try:
|
67
|
+
return await fut
|
68
|
+
finally:
|
69
|
+
await self.shutdown()
|
70
|
+
|
71
|
+
def register_packet_handler(
|
72
|
+
self,
|
73
|
+
proc_name: ProcName,
|
74
|
+
handler: PacketHandler[Any, CtxT],
|
75
|
+
ctx: RunContext[CtxT],
|
76
|
+
**run_kwargs: Any,
|
77
|
+
) -> None:
|
78
|
+
if self._stopping:
|
79
|
+
raise RuntimeError("PacketPool is stopping/stopped")
|
80
|
+
|
81
|
+
self._packet_handlers[proc_name] = handler
|
82
|
+
self._packet_queues.setdefault(proc_name, asyncio.Queue())
|
83
|
+
|
84
|
+
if self._task_group is not None:
|
85
|
+
self._task_group.create_task(
|
86
|
+
self._handle_packets(proc_name, ctx=ctx, **run_kwargs),
|
87
|
+
name=f"packet-handler:{proc_name}",
|
88
|
+
)
|
89
|
+
|
90
|
+
async def push_event(self, event: Event[Any]) -> None:
|
91
|
+
await self._event_queue.put(event)
|
92
|
+
|
93
|
+
async def __aenter__(self) -> "PacketPool[CtxT]":
|
94
|
+
self._task_group = asyncio.TaskGroup()
|
95
|
+
await self._task_group.__aenter__()
|
96
|
+
|
97
|
+
return self
|
98
|
+
|
99
|
+
async def __aexit__(
|
100
|
+
self,
|
101
|
+
exc_type: type[BaseException] | None,
|
102
|
+
exc: BaseException | None,
|
103
|
+
tb: TracebackType | None,
|
104
|
+
) -> bool | None:
|
105
|
+
await self.shutdown()
|
106
|
+
|
107
|
+
if self._task_group is not None:
|
108
|
+
try:
|
109
|
+
return await self._task_group.__aexit__(exc_type, exc, tb)
|
110
|
+
finally:
|
111
|
+
self._task_group = None
|
112
|
+
|
113
|
+
if self._errors:
|
114
|
+
raise ExceptionGroup("PacketPool worker errors", self._errors)
|
115
|
+
|
116
|
+
return False
|
117
|
+
|
118
|
+
async def _handle_packets(
|
119
|
+
self, proc_name: ProcName, ctx: RunContext[CtxT], **run_kwargs: Any
|
120
|
+
) -> None:
|
121
|
+
queue = self._packet_queues[proc_name]
|
122
|
+
handler = self._packet_handlers[proc_name]
|
123
|
+
|
124
|
+
while True:
|
125
|
+
packet = await queue.get()
|
126
|
+
if packet is None:
|
127
|
+
break
|
128
|
+
try:
|
129
|
+
await handler(packet, ctx=ctx, **run_kwargs)
|
130
|
+
except asyncio.CancelledError:
|
131
|
+
raise
|
132
|
+
except Exception as err:
|
133
|
+
logger.exception("Error handling packet for %s", proc_name)
|
134
|
+
self._errors.append(err)
|
135
|
+
fut = self._final_result_fut
|
136
|
+
if fut and not fut.done():
|
137
|
+
fut.set_exception(err)
|
138
|
+
await self.shutdown()
|
139
|
+
raise
|
140
|
+
|
141
|
+
async def stream_events(self) -> AsyncIterator[Event[Any]]:
|
142
|
+
while True:
|
143
|
+
event = await self._event_queue.get()
|
144
|
+
if event is None:
|
145
|
+
break
|
146
|
+
yield event
|
147
|
+
|
148
|
+
async def shutdown(self) -> None:
|
149
|
+
if self._stopping:
|
150
|
+
await self._stopped_evt.wait()
|
151
|
+
return
|
152
|
+
self._stopping = True
|
153
|
+
try:
|
154
|
+
await self._event_queue.put(None)
|
155
|
+
for queue in self._packet_queues.values():
|
156
|
+
await queue.put(None)
|
157
|
+
|
158
|
+
finally:
|
159
|
+
self._stopped_evt.set()
|
@@ -219,12 +219,12 @@ async def print_event_stream(
|
|
219
219
|
) -> None:
|
220
220
|
color = _get_color(event, Role.ASSISTANT)
|
221
221
|
|
222
|
-
if isinstance(event,
|
223
|
-
src = "processor"
|
224
|
-
elif isinstance(event, WorkflowResultEvent):
|
222
|
+
if isinstance(event, WorkflowResultEvent):
|
225
223
|
src = "workflow"
|
226
|
-
|
224
|
+
elif isinstance(event, RunResultEvent):
|
227
225
|
src = "run"
|
226
|
+
else:
|
227
|
+
src = "processor"
|
228
228
|
|
229
229
|
text = f"\n<{event.proc_name}> [{event.call_id}]\n"
|
230
230
|
|
@@ -232,6 +232,10 @@ async def print_event_stream(
|
|
232
232
|
text += f"<{src} output>\n"
|
233
233
|
for p in event.data.payloads:
|
234
234
|
if isinstance(p, BaseModel):
|
235
|
+
for field_info in type(p).model_fields.values():
|
236
|
+
if field_info.exclude:
|
237
|
+
field_info.exclude = False
|
238
|
+
type(p).model_rebuild(force=True)
|
235
239
|
p_str = p.model_dump_json(indent=2)
|
236
240
|
else:
|
237
241
|
try:
|