grasp_agents 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/agent_message.py +28 -0
- grasp_agents/agent_message_pool.py +94 -0
- grasp_agents/base_agent.py +72 -0
- grasp_agents/cloud_llm.py +353 -0
- grasp_agents/comm_agent.py +230 -0
- grasp_agents/costs_dict.yaml +122 -0
- grasp_agents/data_retrieval/__init__.py +7 -0
- grasp_agents/data_retrieval/rate_limiter_chunked.py +195 -0
- grasp_agents/data_retrieval/types.py +57 -0
- grasp_agents/data_retrieval/utils.py +57 -0
- grasp_agents/grasp_logging.py +36 -0
- grasp_agents/http_client.py +24 -0
- grasp_agents/llm.py +106 -0
- grasp_agents/llm_agent.py +361 -0
- grasp_agents/llm_agent_state.py +73 -0
- grasp_agents/memory.py +150 -0
- grasp_agents/openai/__init__.py +83 -0
- grasp_agents/openai/completion_converters.py +49 -0
- grasp_agents/openai/content_converters.py +80 -0
- grasp_agents/openai/converters.py +170 -0
- grasp_agents/openai/message_converters.py +155 -0
- grasp_agents/openai/openai_llm.py +179 -0
- grasp_agents/openai/tool_converters.py +37 -0
- grasp_agents/printer.py +156 -0
- grasp_agents/prompt_builder.py +204 -0
- grasp_agents/run_context.py +90 -0
- grasp_agents/tool_orchestrator.py +181 -0
- grasp_agents/typing/__init__.py +0 -0
- grasp_agents/typing/completion.py +30 -0
- grasp_agents/typing/content.py +116 -0
- grasp_agents/typing/converters.py +118 -0
- grasp_agents/typing/io.py +32 -0
- grasp_agents/typing/message.py +130 -0
- grasp_agents/typing/tool.py +52 -0
- grasp_agents/usage_tracker.py +99 -0
- grasp_agents/utils.py +151 -0
- grasp_agents/workflow/__init__.py +0 -0
- grasp_agents/workflow/looped_agent.py +113 -0
- grasp_agents/workflow/sequential_agent.py +57 -0
- grasp_agents/workflow/workflow_agent.py +69 -0
- grasp_agents-0.1.5.dist-info/METADATA +14 -0
- grasp_agents-0.1.5.dist-info/RECORD +44 -0
- grasp_agents-0.1.5.dist-info/WHEEL +4 -0
- grasp_agents-0.1.5.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,24 @@
|
|
1
|
+
import httpx
|
2
|
+
from pydantic import BaseModel, NonNegativeFloat, PositiveInt
|
3
|
+
|
4
|
+
|
5
|
+
class AsyncHTTPClientParams(BaseModel):
    """Validated settings used to configure an ``httpx.AsyncClient``.

    ``create_async_http_client`` forwards ``timeout`` to ``httpx.Timeout`` and
    the remaining fields to ``httpx.Limits``.
    """

    # Single value handed to ``httpx.Timeout``; must be >= 0.
    timeout: NonNegativeFloat = 10

    # Upper bound on connections in the pool; must be > 0.
    max_connections: PositiveInt = 2000

    # Upper bound on idle keep-alive connections retained; must be > 0.
    max_keepalive_connections: PositiveInt = 500

    # Passed through as ``keepalive_expiry``; ``None`` is forwarded unchanged.
    keepalive_expiry: float | None = 5
|
10
|
+
|
11
|
+
|
12
|
+
def create_async_http_client(
    client_params: AsyncHTTPClientParams,
) -> httpx.AsyncClient:
    """Create an ``httpx.AsyncClient`` configured from ``client_params``.

    Args:
        client_params: Timeout and connection-pool settings to apply.

    Returns:
        A newly constructed ``httpx.AsyncClient``. This function does not
        close it; NOTE(review): presumably the caller owns its lifecycle.
    """
    request_timeout = httpx.Timeout(client_params.timeout)
    pool_limits = httpx.Limits(
        max_connections=client_params.max_connections,
        max_keepalive_connections=client_params.max_keepalive_connections,
        keepalive_expiry=client_params.keepalive_expiry,
    )
    return httpx.AsyncClient(timeout=request_timeout, limits=pool_limits)
|
grasp_agents/llm.py
ADDED
@@ -0,0 +1,106 @@
|
|
1
|
+
import logging
|
2
|
+
from abc import ABC, abstractmethod
|
3
|
+
from collections.abc import AsyncIterator, Sequence
|
4
|
+
from typing import Any, Generic, TypeVar
|
5
|
+
from uuid import uuid4
|
6
|
+
|
7
|
+
from pydantic import BaseModel
|
8
|
+
from typing_extensions import TypedDict
|
9
|
+
|
10
|
+
from .memory import MessageHistory
|
11
|
+
from .typing.completion import Completion, CompletionChunk
|
12
|
+
from .typing.converters import Converters
|
13
|
+
from .typing.message import AssistantMessage, Conversation
|
14
|
+
from .typing.tool import BaseTool, ToolChoice
|
15
|
+
|
16
|
+
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
17
|
+
|
18
|
+
|
19
|
+
class LLMSettings(TypedDict):
    """Base ``TypedDict`` for LLM settings.

    Declares no keys of its own; it is used as the upper bound of the
    ``SettingsT`` type variable.
    """
|
21
|
+
|
22
|
+
|
23
|
+
# Type parameters of ``LLM``: the settings dictionary type and the converters
# type. Both are declared covariant; the ``noqa`` markers silence the PLC0105
# lint about the names not reflecting that variance.
SettingsT = TypeVar("SettingsT", bound=LLMSettings, covariant=True)  # noqa: PLC0105
ConvertT = TypeVar("ConvertT", bound=Converters, covariant=True)  # noqa: PLC0105
|
25
|
+
|
26
|
+
|
27
|
+
class LLM(ABC, Generic[SettingsT, ConvertT]):
    """Abstract base class for a language-model backend.

    Stores the converters, model identifiers, settings, tools and response
    format, and declares the three generation coroutines subclasses must
    implement.
    """

    @abstractmethod
    def __init__(
        self,
        converters: ConvertT,
        model_name: str | None = None,
        model_id: str | None = None,
        llm_settings: SettingsT | None = None,
        tools: list[BaseTool[BaseModel, BaseModel, Any]] | None = None,
        response_format: type | None = None,
        **kwargs: Any,
    ) -> None:
        """Store the common LLM configuration.

        Declared abstract so subclasses must define their own ``__init__``;
        they can still call this one through ``super().__init__(...)``.

        Args:
            converters: Converters object kept on the instance.
            model_name: Optional model name.
            model_id: Optional identifier; when falsy, the first 8 characters
                of a random UUID4 string are used instead.
            llm_settings: Optional settings dictionary.
            tools: Optional list of tools; stored as a mapping keyed by each
                tool's ``name`` (``None`` when no tools are given).
            response_format: Optional response format type.
            **kwargs: Accepted but not used by this base implementation.
        """
        super().__init__()

        self._converters = converters
        self._model_name = model_name
        self._model_id = model_id or str(uuid4())[:8]
        self._llm_settings = llm_settings
        self._response_format = response_format
        self._tools = {tool.name: tool for tool in tools} if tools else None

    @property
    def model_id(self) -> str:
        """Identifier of this LLM instance."""
        return self._model_id

    @property
    def model_name(self) -> str | None:
        """Model name given at construction, if any."""
        return self._model_name

    @property
    def llm_settings(self) -> SettingsT | None:
        """Settings dictionary given at construction, if any."""
        return self._llm_settings

    @property
    def response_format(self) -> type | None:
        """Response format type given at construction, if any."""
        return self._response_format

    @property
    def tools(self) -> dict[str, BaseTool[BaseModel, BaseModel, Any]] | None:
        """Registered tools keyed by name, or ``None`` when there are none."""
        return self._tools

    @tools.setter
    def tools(self, tools: list[BaseTool[BaseModel, BaseModel, Any]] | None) -> None:
        # Same indexing as in ``__init__``: an empty list or ``None`` becomes
        # ``None``.
        self._tools = {tool.name: tool for tool in tools} if tools else None

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return (
            f"{cls_name}(model_id={self.model_id}; "
            f"model_name={self._model_name})"
        )

    @abstractmethod
    async def generate_completion(
        self,
        conversation: Conversation,
        *,
        tool_choice: ToolChoice | None = None,
        **kwargs: Any,
    ) -> Completion:
        """Generate a completion for ``conversation`` (subclass-defined)."""

    @abstractmethod
    async def generate_completion_stream(
        self,
        conversation: Conversation,
        *,
        tool_choice: ToolChoice | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[CompletionChunk]:
        """Generate completion chunks for ``conversation`` (subclass-defined)."""

    @abstractmethod
    async def generate_message_batch(
        self,
        message_history: MessageHistory,
        *,
        tool_choice: ToolChoice | None = None,
        **kwargs: Any,
    ) -> Sequence[AssistantMessage]:
        """Generate assistant messages for ``message_history`` (subclass-defined)."""
|
@@ -0,0 +1,361 @@
|
|
1
|
+
from collections.abc import Sequence
|
2
|
+
from pathlib import Path
|
3
|
+
from typing import Any, Generic, cast, final
|
4
|
+
|
5
|
+
from pydantic import BaseModel
|
6
|
+
|
7
|
+
from .agent_message import AgentMessage
|
8
|
+
from .agent_message_pool import AgentMessagePool
|
9
|
+
from .comm_agent import CommunicatingAgent
|
10
|
+
from .llm import LLM, LLMSettings
|
11
|
+
from .llm_agent_state import (
|
12
|
+
LLMAgentState,
|
13
|
+
MakeCustomAgentState,
|
14
|
+
SetAgentStateStrategy,
|
15
|
+
)
|
16
|
+
from .prompt_builder import (
|
17
|
+
FormatInputArgsHandler,
|
18
|
+
FormatSystemArgsHandler,
|
19
|
+
PromptBuilder,
|
20
|
+
)
|
21
|
+
from .run_context import (
|
22
|
+
CtxT,
|
23
|
+
InteractionRecord,
|
24
|
+
RunContextWrapper,
|
25
|
+
SystemRunArgs,
|
26
|
+
UserRunArgs,
|
27
|
+
)
|
28
|
+
from .tool_orchestrator import ToolCallLoopExitHandler, ToolOrchestrator
|
29
|
+
from .typing.content import ImageData
|
30
|
+
from .typing.converters import Converters
|
31
|
+
from .typing.io import (
|
32
|
+
AgentID,
|
33
|
+
AgentPayload,
|
34
|
+
AgentState,
|
35
|
+
InT,
|
36
|
+
LLMPrompt,
|
37
|
+
LLMPromptArgs,
|
38
|
+
OutT,
|
39
|
+
)
|
40
|
+
from .typing.message import Conversation, Message, SystemMessage
|
41
|
+
from .typing.tool import BaseTool
|
42
|
+
from .utils import get_prompt
|
43
|
+
|
44
|
+
|
45
|
+
class LLMAgent(
    CommunicatingAgent[InT, OutT, LLMAgentState, CtxT],
    Generic[InT, OutT, CtxT],
):
    """Communicating agent whose outputs are generated by an LLM.

    Generation is delegated to a ``ToolOrchestrator`` (single generation when
    no tools are registered, a tool-call loop otherwise) and prompt
    construction to a ``PromptBuilder``. The agent state is an
    ``LLMAgentState`` holding the message history.
    """

    def __init__(
        self,
        agent_id: AgentID,
        *,
        # LLM
        llm: LLM[LLMSettings, Converters],
        # Input prompt template (combines user and received arguments)
        inp_prompt: LLMPrompt | None = None,
        inp_prompt_path: str | Path | None = None,
        # System prompt template
        sys_prompt: LLMPrompt | None = None,
        sys_prompt_path: str | Path | None = None,
        # System args (static args provided via RunContextWrapper)
        sys_args_schema: type[LLMPromptArgs] = LLMPromptArgs,
        # User args (static args provided via RunContextWrapper)
        usr_args_schema: type[LLMPromptArgs] = LLMPromptArgs,
        # Received args (args from another agent)
        rcv_args_schema: type[InT] = cast("type[InT]", AgentPayload),
        # Output schema
        out_schema: type[OutT] = cast("type[OutT]", AgentPayload),
        # Tools
        tools: list[BaseTool[BaseModel, BaseModel, CtxT]] | None = None,
        max_turns: int = 1000,
        react_mode: bool = False,
        # Agent state management
        set_state_strategy: SetAgentStateStrategy = "keep",
        # Multi-agent routing
        message_pool: AgentMessagePool[CtxT] | None = None,
        recipient_ids: list[AgentID] | None = None,
        dynamic_routing: bool = False,
    ) -> None:
        """Wire up the agent's state, tool orchestrator and prompt builder.

        Args:
            agent_id: Identifier of this agent.
            llm: LLM handed to the tool orchestrator.
            inp_prompt: Input prompt template text.
            inp_prompt_path: Path to the input prompt template; resolved
                together with ``inp_prompt`` by ``get_prompt``.
            sys_prompt: System prompt template text.
            sys_prompt_path: Path to the system prompt template; resolved
                together with ``sys_prompt`` by ``get_prompt``.
            sys_args_schema: Schema of the system prompt arguments.
            usr_args_schema: Schema of the user prompt arguments.
            rcv_args_schema: Schema of payloads received from another agent.
            out_schema: Schema the agent's outputs are validated against.
            tools: Tools forwarded to the tool orchestrator.
            max_turns: Forwarded to the tool orchestrator.
            react_mode: Forwarded to the tool orchestrator.
            set_state_strategy: Strategy passed to
                ``LLMAgentState.from_cur_and_rcv_states`` on every run.
            message_pool: Forwarded to ``CommunicatingAgent``.
            recipient_ids: Forwarded to ``CommunicatingAgent``.
            dynamic_routing: Forwarded to ``CommunicatingAgent``.
        """
        super().__init__(
            agent_id=agent_id,
            out_schema=out_schema,
            rcv_args_schema=rcv_args_schema,
            message_pool=message_pool,
            recipient_ids=recipient_ids,
            dynamic_routing=dynamic_routing,
        )

        # Agent state
        self._state: LLMAgentState = LLMAgentState()
        self.set_state_strategy: SetAgentStateStrategy = set_state_strategy
        self._make_custom_agent_state_impl: MakeCustomAgentState | None = None

        # Tool orchestrator
        self._tool_orchestrator: ToolOrchestrator[CtxT] = ToolOrchestrator[CtxT](
            agent_id=self.agent_id,
            llm=llm,
            tools=tools,
            max_turns=max_turns,
            react_mode=react_mode,
        )

        # Prompt builder
        sys_prompt = get_prompt(prompt_text=sys_prompt, prompt_path=sys_prompt_path)
        inp_prompt = get_prompt(prompt_text=inp_prompt, prompt_path=inp_prompt_path)
        self._prompt_builder: PromptBuilder[InT, CtxT] = PromptBuilder[InT, CtxT](
            agent_id=self._agent_id,
            sys_prompt=sys_prompt,
            inp_prompt=inp_prompt,
            sys_args_schema=sys_args_schema,
            usr_args_schema=usr_args_schema,
            rcv_args_schema=rcv_args_schema,
        )

        # Mirror the LLM's ``no_tqdm`` flag when it has one.
        self.no_tqdm = getattr(llm, "no_tqdm", False)

    @property
    def llm(self) -> LLM[LLMSettings, Converters]:
        """LLM held by the tool orchestrator."""
        return self._tool_orchestrator.llm

    @property
    def tools(self) -> dict[str, BaseTool[BaseModel, BaseModel, CtxT]]:
        """Tools held by the tool orchestrator."""
        return self._tool_orchestrator.tools

    @property
    def max_turns(self) -> int:
        """``max_turns`` value held by the tool orchestrator."""
        return self._tool_orchestrator.max_turns

    @property
    def sys_args_schema(self) -> type[LLMPromptArgs]:
        """System-argument schema held by the prompt builder."""
        return self._prompt_builder.sys_args_schema

    @property
    def usr_args_schema(self) -> type[LLMPromptArgs]:
        """User-argument schema held by the prompt builder."""
        return self._prompt_builder.usr_args_schema

    @property
    def sys_prompt(self) -> LLMPrompt | None:
        """System prompt template held by the prompt builder."""
        return self._prompt_builder.sys_prompt

    @property
    def inp_prompt(self) -> LLMPrompt | None:
        """Input prompt template held by the prompt builder."""
        return self._prompt_builder.inp_prompt

    def format_sys_args_handler(
        self, func: FormatSystemArgsHandler[CtxT]
    ) -> FormatSystemArgsHandler[CtxT]:
        """Register ``func`` as the prompt builder's system-args formatter.

        Returns ``func`` unchanged, so this can be used as a decorator.
        """
        self._prompt_builder.format_sys_args_impl = func

        return func

    def format_inp_args_handler(
        self, func: FormatInputArgsHandler[InT, CtxT]
    ) -> FormatInputArgsHandler[InT, CtxT]:
        """Register ``func`` as the prompt builder's input-args formatter.

        Returns ``func`` unchanged, so this can be used as a decorator.
        """
        self._prompt_builder.format_inp_args_impl = func

        return func

    def make_custom_agent_state_handler(
        self, func: MakeCustomAgentState
    ) -> MakeCustomAgentState:
        """Register ``func`` as the custom agent-state factory.

        Returns ``func`` unchanged, so this can be used as a decorator.
        """
        self._make_custom_agent_state_impl = func

        return func

    def tool_call_loop_exit_handler(
        self, func: ToolCallLoopExitHandler[CtxT]
    ) -> ToolCallLoopExitHandler[CtxT]:
        """Register ``func`` as the orchestrator's tool-call-loop exit check.

        Returns ``func`` unchanged, so this can be used as a decorator.
        """
        self._tool_orchestrator.tool_call_loop_exit_impl = func

        return func

    def _parse_output(
        self,
        conversation: Conversation,
        rcv_args: InT | None = None,
        ctx: RunContextWrapper[CtxT] | None = None,
        **kwargs: Any,
    ) -> OutT:
        """Turn the final message of ``conversation`` into an output payload.

        If a custom ``_parse_output_impl`` is set, it is called and its result
        returned. Otherwise the content of the last message is validated as
        JSON against the output schema; when that fails, a default-constructed
        output schema instance is returned instead.
        """
        if self._parse_output_impl:
            return self._parse_output_impl(
                conversation=conversation, rcv_args=rcv_args, ctx=ctx, **kwargs
            )

        # Imported locally: only this fallback path needs the exception type.
        from pydantic import ValidationError

        try:
            return self._out_schema.model_validate_json(str(conversation[-1].content))
        except (ValidationError, IndexError):
            # Best-effort fallback for the expected failures only (content is
            # not valid JSON / does not match the schema, or the conversation
            # is empty). Anything else is a genuine bug and should surface.
            return self._out_schema()

    @final
    async def run(
        self,
        inp_items: LLMPrompt | list[str | ImageData] | None = None,
        *,
        ctx: RunContextWrapper[CtxT] | None = None,
        rcv_message: AgentMessage[InT, LLMAgentState] | None = None,
        entry_point: bool = False,
        forbid_state_change: bool = False,
        **gen_kwargs: Any,  # noqa: ARG002
    ) -> AgentMessage[OutT, LLMAgentState]:
        """Run one agent step and return the resulting agent message.

        Args:
            inp_items: User-provided inputs; must not be combined with
                ``rcv_message``.
            ctx: Optional run context supplying run arguments and receiving
                the interaction record.
            rcv_message: Message received from another agent, if any.
            entry_point: Whether this is an entry-point run; must not be
                combined with ``rcv_message``.
            forbid_state_change: When true, the agent's stored state is left
                untouched after the run.
            **gen_kwargs: Accepted but unused.

        Returns:
            An ``AgentMessage`` carrying the validated outputs, this agent's
            id, the resulting state and the recipient ids.

        Raises:
            AssertionError: If ``rcv_message`` is given together with
                ``entry_point=True`` or with non-empty ``inp_items``.
        """
        # Get run arguments
        sys_args: SystemRunArgs = LLMPromptArgs()
        usr_args: UserRunArgs = LLMPromptArgs()
        if ctx is not None:
            run_args = ctx.run_args.get(self.agent_id)
            if run_args is not None:
                sys_args = run_args.sys
                usr_args = run_args.usr

        # Raised explicitly instead of via ``assert`` so the checks survive
        # ``python -O``; the exception type is kept for backward compatibility.
        if entry_point and rcv_message is not None:
            raise AssertionError(
                "Entry point run should not have a received message"
            )
        if inp_items and rcv_message is not None:
            raise AssertionError(
                "There must be no received message with user inputs"
            )

        # Work on a deep copy so the stored state is only replaced at the end.
        cur_state = self.state.model_copy(deep=True)

        # 1. Make system prompt (can be None)
        formatted_sys_prompt = self._prompt_builder.make_sys_prompt(
            sys_args=sys_args, ctx=ctx
        )

        # 2. Set agent state

        rcv_state = rcv_message.sender_state if rcv_message else None
        prev_mh_len = len(cur_state.message_history)

        state = LLMAgentState.from_cur_and_rcv_states(
            cur_state=cur_state,
            rcv_state=rcv_state,
            sys_prompt=formatted_sys_prompt,
            strategy=self.set_state_strategy,
            make_custom_state_impl=self._make_custom_agent_state_impl,
            ctx=ctx,
        )

        self._print_sys_msg(state=state, prev_mh_len=prev_mh_len, ctx=ctx)

        # 3. Make and add user messages (can be empty)
        user_message_batch = self._prompt_builder.make_user_messages(
            inp_items=inp_items,
            usr_args=usr_args,
            rcv_args_batch=rcv_message.payloads if rcv_message else None,
            entry_point=entry_point,
            ctx=ctx,
        )
        if user_message_batch:
            state.message_history.add_message_batch(user_message_batch)
            self._print_msgs(user_message_batch, ctx=ctx)

        if not self.tools:
            # 4. Generate messages without tools
            await self._tool_orchestrator.generate_once(agent_state=state, ctx=ctx)
        else:
            # 4. Run tool call loop (new messages are added to the message
            # history inside the loop)
            await self._tool_orchestrator.run_loop(agent_state=state, ctx=ctx)

        # 5. Parse outputs
        batch_size = state.message_history.batch_size
        rcv_args_batch = rcv_message.payloads if rcv_message else batch_size * [None]
        val_output_batch = [
            self.out_schema.model_validate(
                self._parse_output(conversation=conv, rcv_args=rcv_args, ctx=ctx)
            )
            for conv, rcv_args in zip(
                state.message_history.batched_conversations,
                rcv_args_batch,
                strict=False,
            )
        ]

        # 6. Write interaction history to context

        if self.dynamic_routing:
            recipient_ids = self._validate_dynamic_routing(val_output_batch)
        else:
            recipient_ids = self._validate_static_routing(val_output_batch)

        if ctx:
            interaction_record = InteractionRecord(
                source_id=self.agent_id,
                recipient_ids=recipient_ids,
                inp_items=inp_items,
                sys_prompt=self.sys_prompt,
                inp_prompt=self.inp_prompt,
                sys_args=sys_args,
                usr_args=usr_args,
                rcv_args=(rcv_message.payloads if rcv_message is not None else None),
                outputs=val_output_batch,
                state=state,
            )
            ctx.interaction_history.append(
                cast(
                    "InteractionRecord[AgentPayload, AgentPayload, AgentState]",
                    interaction_record,
                )
            )

        agent_message = AgentMessage(
            payloads=val_output_batch,
            sender_id=self.agent_id,
            sender_state=state,
            recipient_ids=recipient_ids,
        )

        if not forbid_state_change:
            self._state = state

        return agent_message

    def _print_msgs(
        self,
        messages: Sequence[Message],
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> None:
        """Print ``messages`` through the context's printer, if ``ctx`` is set."""
        if ctx:
            ctx.printer.print_llm_messages(messages, agent_id=self.agent_id)

    def _print_sys_msg(
        self,
        state: LLMAgentState,
        prev_mh_len: int,
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> None:
        """Print the system message when this run has just created it.

        That is the case when the history was empty before the run, now has
        length 1, and its first entry starts with a ``SystemMessage``.
        """
        if (
            len(state.message_history) == 1
            and prev_mh_len == 0
            and isinstance(state.message_history[0][0], SystemMessage)
        ):
            self._print_msgs([state.message_history[0][0]], ctx=ctx)
|
@@ -0,0 +1,73 @@
|
|
1
|
+
from copy import deepcopy
|
2
|
+
from typing import Any, Literal, Optional, Protocol
|
3
|
+
|
4
|
+
from pydantic import ConfigDict, Field
|
5
|
+
|
6
|
+
from .memory import MessageHistory
|
7
|
+
from .run_context import RunContextWrapper
|
8
|
+
from .typing.io import AgentState, LLMPrompt
|
9
|
+
|
10
|
+
# How ``LLMAgentState.from_cur_and_rcv_states`` derives the message history:
#   "keep"        - use the current history as is;
#   "reset"       - reset the history with the system prompt;
#   "from_sender" - deep-copy the received state's history when it is
#                   non-empty, otherwise reset with the system prompt;
#   "custom"      - delegate to a user-supplied ``MakeCustomAgentState``.
SetAgentStateStrategy = Literal["keep", "reset", "from_sender", "custom"]
|
11
|
+
|
12
|
+
|
13
|
+
class MakeCustomAgentState(Protocol):
    """Call signature of a user-supplied agent-state factory.

    An implementation is invoked with keyword arguments by
    ``LLMAgentState.from_cur_and_rcv_states`` when the strategy is
    ``"custom"``: the current state, the received state (passed as
    ``rec_state``), the system prompt and the run context. It must return the
    ``LLMAgentState`` to use.
    """

    def __call__(
        self,
        cur_state: Optional["LLMAgentState"],
        rec_state: Optional["LLMAgentState"],
        sys_prompt: LLMPrompt | None,
        ctx: RunContextWrapper[Any] | None,
    ) -> "LLMAgentState": ...
|
21
|
+
|
22
|
+
|
23
|
+
class LLMAgentState(AgentState):
    """Agent state made of a message history."""

    # NOTE(review): presumably required because ``MessageHistory`` is not a
    # type pydantic can validate natively - confirm against ``memory.py``.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    message_history: MessageHistory = Field(default_factory=MessageHistory)

    @property
    def batch_size(self) -> int:
        """Batch size reported by the message history."""
        return self.message_history.batch_size

    @classmethod
    def from_cur_and_rcv_states(
        cls,
        cur_state: Optional["LLMAgentState"] = None,
        rcv_state: Optional["LLMAgentState"] = None,
        sys_prompt: LLMPrompt | None = None,
        strategy: SetAgentStateStrategy = "from_sender",
        make_custom_state_impl: MakeCustomAgentState | None = None,
        ctx: RunContextWrapper[Any] | None = None,
    ) -> "LLMAgentState":
        """Build the state for a run from the current and received states.

        Args:
            cur_state: The agent's current state, if any.
            rcv_state: The state received from the sender, if any.
            sys_prompt: System prompt used when a history is created or reset.
            strategy: How to derive the message history: ``"keep"``,
                ``"reset"``, ``"from_sender"`` or ``"custom"``.
            make_custom_state_impl: Factory used by the ``"custom"`` strategy.
            ctx: Run context forwarded to the custom factory.

        Returns:
            The state produced by the custom factory for ``"custom"``;
            otherwise a new instance wrapping the derived message history.

        Raises:
            AssertionError: If ``strategy`` is ``"custom"`` but no
                ``make_custom_state_impl`` is provided.
        """
        # Start from the current history; a missing or empty one is replaced
        # by a fresh history seeded with the system prompt.
        upd_mh = cur_state.message_history if cur_state else None
        if upd_mh is None or len(upd_mh) == 0:
            upd_mh = MessageHistory(sys_prompt=sys_prompt)

        if strategy == "keep":
            pass

        elif strategy == "reset":
            # ``upd_mh`` may be the very object held by ``cur_state``, so this
            # call acts on the caller's history object.
            upd_mh.reset(sys_prompt)

        elif strategy == "from_sender":
            rcv_mh = rcv_state.message_history if rcv_state else None
            if rcv_mh:
                # Copy so later changes do not touch the sender's history.
                upd_mh = deepcopy(rcv_mh)
            else:
                upd_mh.reset(sys_prompt)

        elif strategy == "custom":
            # Raised explicitly instead of via ``assert`` so the check survives
            # ``python -O`` (where calling ``None`` would otherwise produce an
            # unrelated ``TypeError``); the exception type is unchanged.
            if make_custom_state_impl is None:
                raise AssertionError(
                    "Custom message history handler implementation is not provided."
                )
            return make_custom_state_impl(
                cur_state=cur_state,
                rec_state=rcv_state,
                sys_prompt=sys_prompt,
                ctx=ctx,
            )

        return cls(message_history=upd_mh)

    def __repr__(self) -> str:
        return f"Message History: {len(self.message_history)}"
|