grasp_agents 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/agent_message.py +28 -0
- grasp_agents/agent_message_pool.py +94 -0
- grasp_agents/base_agent.py +72 -0
- grasp_agents/cloud_llm.py +353 -0
- grasp_agents/comm_agent.py +230 -0
- grasp_agents/costs_dict.yaml +122 -0
- grasp_agents/data_retrieval/__init__.py +7 -0
- grasp_agents/data_retrieval/rate_limiter_chunked.py +195 -0
- grasp_agents/data_retrieval/types.py +57 -0
- grasp_agents/data_retrieval/utils.py +57 -0
- grasp_agents/grasp_logging.py +36 -0
- grasp_agents/http_client.py +24 -0
- grasp_agents/llm.py +106 -0
- grasp_agents/llm_agent.py +361 -0
- grasp_agents/llm_agent_state.py +73 -0
- grasp_agents/memory.py +150 -0
- grasp_agents/openai/__init__.py +83 -0
- grasp_agents/openai/completion_converters.py +49 -0
- grasp_agents/openai/content_converters.py +80 -0
- grasp_agents/openai/converters.py +170 -0
- grasp_agents/openai/message_converters.py +155 -0
- grasp_agents/openai/openai_llm.py +179 -0
- grasp_agents/openai/tool_converters.py +37 -0
- grasp_agents/printer.py +156 -0
- grasp_agents/prompt_builder.py +204 -0
- grasp_agents/run_context.py +90 -0
- grasp_agents/tool_orchestrator.py +181 -0
- grasp_agents/typing/__init__.py +0 -0
- grasp_agents/typing/completion.py +30 -0
- grasp_agents/typing/content.py +116 -0
- grasp_agents/typing/converters.py +118 -0
- grasp_agents/typing/io.py +32 -0
- grasp_agents/typing/message.py +130 -0
- grasp_agents/typing/tool.py +52 -0
- grasp_agents/usage_tracker.py +99 -0
- grasp_agents/utils.py +151 -0
- grasp_agents/workflow/__init__.py +0 -0
- grasp_agents/workflow/looped_agent.py +113 -0
- grasp_agents/workflow/sequential_agent.py +57 -0
- grasp_agents/workflow/workflow_agent.py +69 -0
- grasp_agents-0.1.5.dist-info/METADATA +14 -0
- grasp_agents-0.1.5.dist-info/RECORD +44 -0
- grasp_agents-0.1.5.dist-info/WHEEL +4 -0
- grasp_agents-0.1.5.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,179 @@
|
|
1
|
+
import logging
|
2
|
+
from collections.abc import Iterable
|
3
|
+
from copy import deepcopy
|
4
|
+
from typing import Any, Literal
|
5
|
+
|
6
|
+
from openai import AsyncOpenAI
|
7
|
+
from openai._types import NOT_GIVEN # noqa: PLC2701 # type: ignore[import]
|
8
|
+
from pydantic import BaseModel
|
9
|
+
|
10
|
+
from ..data_retrieval.rate_limiter_chunked import RateLimiterC
|
11
|
+
|
12
|
+
from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings
|
13
|
+
from ..http_client import AsyncHTTPClientParams
|
14
|
+
from ..typing.message import AssistantMessage, Conversation
|
15
|
+
from ..typing.tool import BaseTool
|
16
|
+
from . import (
|
17
|
+
ChatCompletion,
|
18
|
+
ChatCompletionAsyncStream, # type: ignore[import]
|
19
|
+
ChatCompletionChunk,
|
20
|
+
ChatCompletionMessageParam,
|
21
|
+
ChatCompletionPredictionContentParam,
|
22
|
+
ChatCompletionStreamOptionsParam,
|
23
|
+
ChatCompletionToolChoiceOptionParam,
|
24
|
+
ChatCompletionToolParam,
|
25
|
+
ParsedChatCompletion,
|
26
|
+
# ResponseFormatJSONObject,
|
27
|
+
# ResponseFormatJSONSchema,
|
28
|
+
# ResponseFormatText,
|
29
|
+
)
|
30
|
+
from .converters import OpenAIConverters
|
31
|
+
|
32
|
+
logger = logging.getLogger(__name__)


class OpenAILLMSettings(CloudLLMSettings, total=False):
    # OpenAI-specific generation settings layered on top of CloudLLMSettings.
    # ``total=False`` (TypedDict semantics): every key declared here is optional.
    # NOTE(review): key names match OpenAI Chat Completions request parameters;
    # presumably they end up as **api_llm_settings in OpenAILLM's request
    # methods — confirm in CloudLLM / the converters (not visible here).

    # Reasoning-effort hint for reasoning models.
    reasoning_effort: Literal["low", "medium", "high"] | None

    parallel_tool_calls: bool

    # Left disabled: OpenAILLM takes a response format separately (its
    # ``response_format`` constructor argument / ``api_response_format``).
    # response_format: (
    #     ResponseFormatText | ResponseFormatJSONSchema | ResponseFormatJSONObject
    # )

    modalities: list[Literal["text", "audio"]] | None

    # Sampling controls.
    frequency_penalty: float | None
    presence_penalty: float | None
    logit_bias: dict[str, int] | None
    stop: str | list[str] | None
    logprobs: bool | None
    top_logprobs: int | None
    n: int | None

    prediction: ChatCompletionPredictionContentParam | None

    stream_options: ChatCompletionStreamOptionsParam | None

    # Request bookkeeping.
    metadata: dict[str, str] | None
    store: bool | None
    user: str
|
61
|
+
|
62
|
+
|
63
|
+
class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
    """CloudLLM backed by the Chat Completions API of the ``openai`` SDK.

    All requests go through a single ``AsyncOpenAI`` client created in
    ``__init__``. The base URL and API key come from the CloudLLM base class
    (``self._base_url`` / ``self._api_key``). Message and tool conversion is
    delegated to ``OpenAIConverters``.
    """

    def __init__(
        self,
        # Base LLM args
        model_name: str,
        model_id: str | None = None,
        llm_settings: OpenAILLMSettings | None = None,
        tools: list[BaseTool[BaseModel, BaseModel, Any]] | None = None,
        response_format: type | None = None,
        # Connection settings
        api_provider: APIProvider = "openai",
        async_http_client_params: (
            dict[str, Any] | AsyncHTTPClientParams | None
        ) = None,
        async_openai_client_params: dict[str, Any] | None = None,
        # Rate limiting
        rate_limiter: (RateLimiterC[Conversation, AssistantMessage] | None) = None,
        rate_limiter_rpm: float | None = None,
        rate_limiter_chunk_size: int = 1000,
        rate_limiter_max_concurrency: int = 300,
        # Retries
        num_generation_retries: int = 0,
        # Disable tqdm for batch processing
        no_tqdm: bool = True,
        **kwargs: Any,
    ) -> None:
        """Create the model wrapper and its ``AsyncOpenAI`` client.

        Every argument except ``async_openai_client_params`` is forwarded to
        ``CloudLLM.__init__``.

        Args:
            async_openai_client_params: Extra keyword arguments for
                ``AsyncOpenAI``. The caller's dict is not modified. If the base
                class built an HTTP client (``self._async_http_client``), it is
                passed as ``http_client`` and overrides any ``http_client``
                entry given here.
        """
        super().__init__(
            model_name=model_name,
            model_id=model_id,
            llm_settings=llm_settings,
            converters=OpenAIConverters(),
            tools=tools,
            response_format=response_format,
            api_provider=api_provider,
            async_http_client_params=async_http_client_params,
            rate_limiter=rate_limiter,
            rate_limiter_rpm=rate_limiter_rpm,
            rate_limiter_chunk_size=rate_limiter_chunk_size,
            rate_limiter_max_concurrency=rate_limiter_max_concurrency,
            num_generation_retries=num_generation_retries,
            no_tqdm=no_tqdm,
            **kwargs,
        )

        # A shallow copy is enough because only a top-level key is added below.
        # deepcopy is deliberately avoided: it would also try to clone
        # caller-supplied objects (e.g. an ``http_client`` instance holding SSL
        # contexts), which may not support copying.
        async_openai_client_params_ = dict(async_openai_client_params or {})
        if self._async_http_client is not None:
            async_openai_client_params_["http_client"] = self._async_http_client

        # TODO: context manager for async client (it is never closed here)
        self._client: AsyncOpenAI = AsyncOpenAI(
            base_url=self._base_url,
            api_key=self._api_key,
            **async_openai_client_params_,
        )

    async def _get_completion(
        self,
        api_messages: Iterable[ChatCompletionMessageParam],
        api_tools: list[ChatCompletionToolParam] | None = None,
        api_tool_choice: ChatCompletionToolChoiceOptionParam | None = None,
        **api_llm_settings: Any,
    ) -> ChatCompletion:
        """Request a single non-streaming chat completion.

        Empty or ``None`` tools / tool choice are sent as ``NOT_GIVEN``, the
        SDK sentinel that omits the parameter from the request.
        """
        return await self._client.chat.completions.create(
            model=self._model_name,
            messages=api_messages,
            tools=api_tools or NOT_GIVEN,
            tool_choice=api_tool_choice or NOT_GIVEN,
            stream=False,
            **api_llm_settings,
        )

    async def _get_parsed_completion(
        self,
        api_messages: Iterable[ChatCompletionMessageParam],
        api_tools: list[ChatCompletionToolParam] | None = None,
        api_tool_choice: ChatCompletionToolChoiceOptionParam | None = None,
        api_response_format: type | None = None,
        **api_llm_settings: Any,
    ) -> ParsedChatCompletion[Any]:
        """Request a completion parsed into ``api_response_format``.

        This uses the SDK's ``beta.chat.completions.parse``. Empty or ``None``
        tools, tool choice and response format are sent as ``NOT_GIVEN``.
        """
        return await self._client.beta.chat.completions.parse(
            model=self._model_name,
            messages=api_messages,
            tools=api_tools or NOT_GIVEN,
            tool_choice=api_tool_choice or NOT_GIVEN,
            response_format=api_response_format or NOT_GIVEN,  # type: ignore[arg-type]
            **api_llm_settings,
        )

    async def _get_completion_stream(
        self,
        api_messages: Iterable[ChatCompletionMessageParam],
        api_tools: list[ChatCompletionToolParam] | None = None,
        api_tool_choice: ChatCompletionToolChoiceOptionParam | None = None,
        **api_llm_settings: Any,
    ) -> ChatCompletionAsyncStream[ChatCompletionChunk]:
        """Open a streaming chat completion (``stream=True``).

        Raises:
            ValueError: if tools are passed; tool use is not supported by this
                streaming path. An explicit raise is used (rather than
                ``assert``) so the check survives ``python -O``.
        """
        if api_tools:
            raise ValueError("Tool use is not supported in streaming mode")

        return await self._client.chat.completions.create(
            model=self._model_name,
            messages=api_messages,
            tools=NOT_GIVEN,
            tool_choice=api_tool_choice or NOT_GIVEN,
            stream=True,
            **api_llm_settings,
        )
|
@@ -0,0 +1,37 @@
|
|
1
|
+
from typing import Any
|
2
|
+
|
3
|
+
from pydantic import BaseModel
|
4
|
+
|
5
|
+
from ..typing.tool import BaseTool, ToolChoice
|
6
|
+
from . import (
|
7
|
+
ChatCompletionFunctionDefinition,
|
8
|
+
ChatCompletionNamedToolChoiceFunction,
|
9
|
+
ChatCompletionNamedToolChoiceParam,
|
10
|
+
ChatCompletionToolChoiceOptionParam,
|
11
|
+
ChatCompletionToolParam,
|
12
|
+
)
|
13
|
+
|
14
|
+
|
15
|
+
def to_api_tool(
    tool: BaseTool[BaseModel, BaseModel, Any],
) -> ChatCompletionToolParam:
    """Convert a framework tool into an OpenAI ``function`` tool definition.

    The tool's input pydantic model provides the JSON schema for the
    function parameters. Name, description and ``strict`` are copied from
    the tool as-is.
    """
    function_definition = ChatCompletionFunctionDefinition(
        name=tool.name,
        description=tool.description,
        parameters=tool.in_schema.model_json_schema(),
        strict=tool.strict,
    )
    return ChatCompletionToolParam(type="function", function=function_definition)
|
27
|
+
|
28
|
+
|
29
|
+
def to_api_tool_choice(
    tool_choice: ToolChoice,
) -> ChatCompletionToolChoiceOptionParam:
    """Convert a framework tool choice into the OpenAI ``tool_choice`` value.

    A concrete tool becomes a named ``function`` choice that forces that tool.
    Any other value is forwarded unchanged.
    """
    if not isinstance(tool_choice, BaseTool):
        return tool_choice

    forced_function = ChatCompletionNamedToolChoiceFunction(name=tool_choice.name)
    return ChatCompletionNamedToolChoiceParam(type="function", function=forced_function)
|
grasp_agents/printer.py
ADDED
@@ -0,0 +1,156 @@
|
|
1
|
+
import hashlib
|
2
|
+
import json
|
3
|
+
import logging
|
4
|
+
from collections.abc import Mapping, Sequence
|
5
|
+
from typing import Literal, TypeAlias
|
6
|
+
|
7
|
+
from termcolor._types import Color # type: ignore[import]
|
8
|
+
|
9
|
+
from .typing.content import Content, ContentPartText
|
10
|
+
from .typing.message import AssistantMessage, Message, Role, ToolMessage
|
11
|
+
|
12
|
+
logger = logging.getLogger(__name__)


# How Printer chooses log colors: by agent id (hashed) or by message role.
ColoringMode: TypeAlias = Literal["agent_id", "role"]

# Fixed color for each message role (looked up when coloring by role).
ROLE_TO_COLOR: Mapping[Role, Color] = {
    Role.SYSTEM: "magenta",
    Role.USER: "green",
    Role.ASSISTANT: "light_blue",
    Role.TOOL: "light_cyan",
}

# Palette that agent ids are hashed into (used when coloring by agent id).
AVAILABLE_COLORS: list[Color] = [
    "magenta",
    "green",
    "light_blue",
    "light_cyan",
    "yellow",
    "blue",
    "red",
]
|
33
|
+
|
34
|
+
|
35
|
+
class Printer:
    """Log LLM messages (title, content, tool calls, usage) at DEBUG level.

    Every record is emitted with ``extra={"color": ...}``. The color comes
    from the agent id (hashed into AVAILABLE_COLORS) or from the message role
    (ROLE_TO_COLOR), depending on ``color_by``. Nothing is logged unless
    ``print_messages`` is True.
    """

    def __init__(
        self,
        source_id: str,
        color_by: ColoringMode = "role",
        msg_trunc_len: int = 20000,
        print_messages: bool = False,
    ) -> None:
        self.source_id = source_id
        self.color_by = color_by
        self.msg_trunc_len = msg_trunc_len
        self.print_messages = print_messages

    @staticmethod
    def get_role_color(role: Role) -> Color:
        """Return the fixed color assigned to ``role``."""
        return ROLE_TO_COLOR[role]

    @staticmethod
    def get_agent_color(agent_id: str) -> Color:
        """Return a palette color chosen deterministically from ``agent_id``.

        MD5 is only used to spread ids over the palette, not for security.
        """
        idx = int(
            hashlib.md5(agent_id.encode()).hexdigest(),  # noqa: S324
            16,
        ) % len(AVAILABLE_COLORS)

        return AVAILABLE_COLORS[idx]

    @staticmethod
    def content_to_str(content: Content | str | None, role: Role) -> str:
        """Render message content as plain text for logging.

        For user messages with structured ``Content``:
          * text parts are stripped of surrounding spaces/newlines;
          * URL images are shown as their URL;
          * base64 images are shown as ``<ENCODED_IMAGE>``.
        Parts are joined with newlines.

        String content is stripped the same way. ``None`` (e.g. an assistant
        message that only carries tool calls — TODO confirm such messages
        reach the printer) renders as an empty string, so only the title is
        logged.
        """
        if content is None:
            return ""

        if role == Role.USER and isinstance(content, Content):
            content_str_parts: list[str] = []
            for content_part in content.parts:
                if isinstance(content_part, ContentPartText):
                    content_str_parts.append(content_part.data.strip(" \n"))
                elif content_part.data.type == "url":
                    content_str_parts.append(str(content_part.data.url))
                elif content_part.data.type == "base64":
                    content_str_parts.append("<ENCODED_IMAGE>")
            return "\n".join(content_str_parts)

        assert isinstance(content, str)

        return content.strip(" \n")

    @staticmethod
    def truncate_content_str(content_str: str, trunc_len: int = 2000) -> str:
        """Cut ``content_str`` to ``trunc_len`` characters, appending ``[...]`` if cut."""
        if len(content_str) > trunc_len:
            return content_str[:trunc_len] + "[...]"

        return content_str

    def _pick_color(self, agent_id: str, role: Role) -> Color:
        """Choose the log color for ``agent_id`` / ``role`` according to ``color_by``.

        Any mode other than ``"agent_id"`` falls back to role coloring, so the
        color is always defined.
        """
        if self.color_by == "agent_id":
            return self.get_agent_color(agent_id)
        return self.get_role_color(role)

    def print_llm_message(self, message: Message, agent_id: str) -> None:
        """Log one message: title, pretty-printed content, tool calls, usage."""
        if not self.print_messages:
            return

        role = message.role
        usage = message.usage if isinstance(message, AssistantMessage) else None
        content_str = self.content_to_str(message.content, role)

        log_kwargs = {"extra": {"color": self._pick_color(agent_id, role)}}  # type: ignore

        # Print message title

        out = f"\n<{agent_id}>[{role.value.upper()}]"

        if isinstance(message, ToolMessage):
            out += f"\nTool call ID: {message.tool_call_id}"

        # Print message content: pretty-print it if it parses as JSON, then truncate

        if content_str:
            try:
                # ensure_ascii=False keeps non-ASCII text readable instead of
                # escaping it to \uXXXX sequences.
                content_str = json.dumps(
                    json.loads(content_str), indent=2, ensure_ascii=False
                )
            except (ValueError, RecursionError):
                # Not JSON (or too deeply nested): log it verbatim.
                pass
            content_str_truncated = self.truncate_content_str(
                content_str, trunc_len=self.msg_trunc_len
            )
            out += f"\n{content_str_truncated}"

        logger.debug(out, **log_kwargs)  # type: ignore

        # Print tool calls

        if isinstance(message, AssistantMessage) and message.tool_calls is not None:
            tool_color = self._pick_color(agent_id, Role.TOOL)
            for tool_call in message.tool_calls:
                logger.debug(
                    f"\n[TOOL_CALL]<{agent_id}>\n{tool_call.tool_name} "
                    f"| {tool_call.id}\n{tool_call.tool_arguments}",
                    extra={"color": tool_color},  # type: ignore
                )

        # Print usage: input/output/(reasoning)/(cached) tokens

        if usage is not None:
            usage_str = (
                f"I/O/(R)/(C) tokens: {usage.input_tokens}/{usage.output_tokens}"
            )
            if usage.reasoning_tokens is not None or usage.cached_tokens is not None:
                # Keep slot positions stable: a missing reasoning count is
                # shown as "-" so a cached count is never read as reasoning.
                reasoning = usage.reasoning_tokens
                usage_str += f"/{reasoning if reasoning is not None else '-'}"
            if usage.cached_tokens is not None:
                usage_str += f"/{usage.cached_tokens}"
            logger.debug(
                f"\n------------------------------------\n{usage_str}",
                **log_kwargs,  # type: ignore
            )

    def print_llm_messages(self, messages: Sequence[Message], agent_id: str) -> None:
        """Log each message in order (no-op unless ``print_messages`` is True)."""
        if not self.print_messages:
            return

        for message in messages:
            self.print_llm_message(message, agent_id)
|
@@ -0,0 +1,204 @@
|
|
1
|
+
from collections.abc import Sequence
|
2
|
+
from copy import deepcopy
|
3
|
+
from typing import Generic, Protocol
|
4
|
+
|
5
|
+
from .run_context import CtxT, RunContextWrapper, UserRunArgs
|
6
|
+
from .typing.content import ImageData
|
7
|
+
from .typing.io import (
|
8
|
+
AgentPayload,
|
9
|
+
InT,
|
10
|
+
LLMFormattedArgs,
|
11
|
+
LLMFormattedSystemArgs,
|
12
|
+
LLMPrompt,
|
13
|
+
LLMPromptArgs,
|
14
|
+
)
|
15
|
+
from .typing.message import UserMessage
|
16
|
+
|
17
|
+
|
18
|
+
class FormatSystemArgsHandler(Protocol[CtxT]):
    """Callable hook that turns validated system-prompt args into template fields."""

    def __call__(
        self,
        sys_args: LLMPromptArgs,
        *,
        ctx: RunContextWrapper[CtxT] | None,
    ) -> LLMFormattedSystemArgs:
        """Return the keyword arguments used to ``str.format`` the system prompt."""
|
25
|
+
|
26
|
+
|
27
|
+
class FormatInputArgsHandler(Protocol[InT, CtxT]):
    """Callable hook that merges user args and one received payload into template fields."""

    def __call__(
        self,
        usr_args: LLMPromptArgs,
        rcv_args: InT,
        ctx: RunContextWrapper[CtxT] | None,
    ) -> LLMFormattedArgs:
        """Return the arguments used to fill the input prompt template."""
|
34
|
+
|
35
|
+
|
36
|
+
class PromptBuilder(Generic[InT, CtxT]):
    """Build the system prompt and user messages for one agent.

    Prompt arguments are validated against the configured schemas
    (``sys_args_schema``, ``usr_args_schema``, ``rcv_args_schema``) and then
    turned into template fields. This is done either by the optional hooks
    ``format_sys_args_impl`` / ``format_inp_args_impl`` (assignable after
    construction) or, by default, by dumping the models with
    ``exclude_unset=True``.
    """

    def __init__(
        self,
        agent_id: str,
        sys_prompt: LLMPrompt | None,
        inp_prompt: LLMPrompt | None,
        sys_args_schema: type[LLMPromptArgs],
        usr_args_schema: type[LLMPromptArgs],
        rcv_args_schema: type[InT],
    ):
        self._agent_id = agent_id
        self.sys_prompt = sys_prompt
        self.inp_prompt = inp_prompt
        self.sys_args_schema = sys_args_schema
        self.usr_args_schema = usr_args_schema
        self.rcv_args_schema = rcv_args_schema
        # Optional formatting hooks; None means "use the default model dump".
        self.format_sys_args_impl: FormatSystemArgsHandler[CtxT] | None = None
        self.format_inp_args_impl: FormatInputArgsHandler[InT, CtxT] | None = None

    def _format_sys_args(
        self,
        sys_args: LLMPromptArgs,
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> LLMFormattedSystemArgs:
        """Turn validated system args into fields for the system prompt template."""
        if self.format_sys_args_impl is not None:
            return self.format_sys_args_impl(sys_args=sys_args, ctx=ctx)

        return sys_args.model_dump(exclude_unset=True)

    def _format_inp_args(
        self,
        usr_args: LLMPromptArgs,
        rcv_args: InT,
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> LLMFormattedArgs:
        """Merge user args and one received payload into input-template fields.

        By default both models are dumped with ``exclude_unset=True`` and
        merged, with received-args keys winning on conflicts.
        ``selected_recipient_ids`` is never exposed to the template.
        """
        if self.format_inp_args_impl is not None:
            return self.format_inp_args_impl(
                usr_args=usr_args, rcv_args=rcv_args, ctx=ctx
            )

        return usr_args.model_dump(exclude_unset=True) | rcv_args.model_dump(
            exclude_unset=True, exclude={"selected_recipient_ids"}
        )

    def make_sys_prompt(
        self,
        sys_args: LLMPromptArgs,
        ctx: RunContextWrapper[CtxT] | None,
    ) -> LLMPrompt | None:
        """Return the formatted system prompt, or None if no template is set.

        ``sys_args`` is validated against ``sys_args_schema`` first. The
        template is filled with ``str.format``, so it raises ``KeyError`` if
        the template names a field the formatted args do not provide.
        """
        if self.sys_prompt is None:
            return None
        val_sys_args = self.sys_args_schema.model_validate(sys_args)
        fmt_sys_args = self._format_sys_args(val_sys_args, ctx=ctx)

        return self.sys_prompt.format(**fmt_sys_args)

    def _usr_messages_from_text(self, text: str) -> list[UserMessage]:
        """Wrap plain text into a single user message."""
        return [UserMessage.from_text(text, model_id=self._agent_id)]

    def _usr_messages_from_content_parts(
        self, content_parts: list[str | ImageData]
    ) -> list[UserMessage]:
        """Wrap mixed text/image parts into a single user message."""
        return [UserMessage.from_content_parts(content_parts, model_id=self._agent_id)]

    def _usr_messages_from_rcv_args(
        self, rcv_args_batch: Sequence[InT]
    ) -> list[UserMessage]:
        """Create one user message per received payload.

        Each payload is validated against ``rcv_args_schema`` and serialized as
        indented JSON, excluding unset fields and ``selected_recipient_ids``.
        """
        val_rcv_args_batch = [
            self.rcv_args_schema.model_validate(rcv) for rcv in rcv_args_batch
        ]

        return [
            UserMessage.from_text(
                rcv.model_dump_json(
                    exclude_unset=True,
                    indent=2,
                    exclude={"selected_recipient_ids"},
                ),
                model_id=self._agent_id,
            )
            for rcv in val_rcv_args_batch
        ]

    def _usr_messages_from_prompt_template(
        self,
        inp_prompt: LLMPrompt,
        usr_args: UserRunArgs | None = None,
        rcv_args_batch: Sequence[InT] | None = None,
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> Sequence[UserMessage]:
        """Fill ``inp_prompt`` once per (user args, received args) pair.

        Batches are aligned by ``_make_batched``, then validated against their
        schemas and formatted.
        """
        usr_args_batch, rcv_args_batch = self._make_batched(usr_args, rcv_args_batch)
        val_usr_args_batch = [
            self.usr_args_schema.model_validate(u) for u in usr_args_batch
        ]
        val_rcv_args_batch = [
            self.rcv_args_schema.model_validate(r) for r in rcv_args_batch
        ]
        # Lengths are guaranteed equal by _make_batched, so strict=True only
        # guards against future regressions.
        formatted_inp_args_batch = [
            self._format_inp_args(usr_args=val_usr_args, rcv_args=val_rcv_args, ctx=ctx)
            for val_usr_args, val_rcv_args in zip(
                val_usr_args_batch, val_rcv_args_batch, strict=True
            )
        ]

        return [
            UserMessage.from_formatted_prompt(
                prompt_template=inp_prompt, prompt_args=inp_args
            )
            for inp_args in formatted_inp_args_batch
        ]

    def make_user_messages(
        self,
        inp_items: LLMPrompt | list[str | ImageData] | None = None,
        usr_args: UserRunArgs | None = None,
        rcv_args_batch: Sequence[InT] | None = None,
        entry_point: bool = False,
        ctx: RunContextWrapper[CtxT] | None = None,
    ) -> Sequence[UserMessage]:
        """Build the user messages for one agent run.

        Resolution order:
        1. ``inp_items`` is given, or this is the entry-point run: use the
           direct input (text or non-empty list of content parts) and ignore
           the input prompt template. Returns ``[]`` if nothing usable is given.
        2. There is no input prompt template but received args exist: emit one
           JSON message per received payload.
        3. There is an input prompt template: format it for every
           (user args, received args) pair.
        Otherwise returns ``[]``.
        """
        # 1) Direct user input (e.g. chat input)
        # If user inputs are provided, they are used instead of the predefined
        # input prompt template. In a multi-agent system that template builds
        # agent inputs from received and user arguments. However, the first
        # agent run (entry point) has no received messages, so the user inputs
        # are used directly, if provided.
        if inp_items is not None or entry_point:
            if isinstance(inp_items, LLMPrompt):
                return self._usr_messages_from_text(inp_items)
            if isinstance(inp_items, list) and inp_items:
                return self._usr_messages_from_content_parts(inp_items)
            return []

        # 2) No input prompt template + received args → raw JSON messages
        if self.inp_prompt is None and rcv_args_batch:
            return self._usr_messages_from_rcv_args(rcv_args_batch)

        # 3) Input prompt template + any args → batch & format
        if self.inp_prompt is not None:
            return self._usr_messages_from_prompt_template(
                inp_prompt=self.inp_prompt,
                usr_args=usr_args,
                rcv_args_batch=rcv_args_batch,
                ctx=ctx,
            )
        return []

    def _make_batched(
        self,
        usr_args: UserRunArgs | None = None,
        rcv_args_batch: Sequence[InT] | None = None,
    ) -> tuple[Sequence[LLMPromptArgs], Sequence[InT]]:
        """Normalize user args and received args into equal-length batches.

        ``None`` or an empty list/sequence on either side becomes a single
        default item (``LLMPromptArgs()`` / ``AgentPayload()``). A single-item
        side is broadcast, using deep copies, to the length of the other side.

        Raises:
            ValueError: if both sides have several items and their lengths differ.
        """
        if isinstance(usr_args, list) and usr_args:
            usr_args_batch_ = usr_args
        else:
            # A single args model, None, or an empty list means one set of
            # args. This mirrors how an empty rcv_args_batch is treated below.
            usr_args_batch_ = [usr_args or LLMPromptArgs()]
        rcv_args_batch_ = rcv_args_batch or [AgentPayload()]

        # Broadcast singleton → match lengths
        if len(usr_args_batch_) == 1 and len(rcv_args_batch_) > 1:
            usr_args_batch_ = [deepcopy(usr_args_batch_[0]) for _ in rcv_args_batch_]
        if len(rcv_args_batch_) == 1 and len(usr_args_batch_) > 1:
            rcv_args_batch_ = [deepcopy(rcv_args_batch_[0]) for _ in usr_args_batch_]
        if len(usr_args_batch_) != len(rcv_args_batch_):
            raise ValueError("User args and received args must have the same length")

        return usr_args_batch_, rcv_args_batch_  # type: ignore
|
@@ -0,0 +1,90 @@
|
|
1
|
+
from collections.abc import Sequence
|
2
|
+
from typing import Any, Generic, TypeAlias, TypeVar
|
3
|
+
from uuid import uuid4
|
4
|
+
|
5
|
+
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
6
|
+
|
7
|
+
from .printer import Printer
|
8
|
+
from .typing.content import ImageData
|
9
|
+
from .typing.io import (
|
10
|
+
AgentID,
|
11
|
+
AgentPayload,
|
12
|
+
AgentState,
|
13
|
+
InT,
|
14
|
+
LLMPrompt,
|
15
|
+
LLMPromptArgs,
|
16
|
+
OutT,
|
17
|
+
StateT,
|
18
|
+
)
|
19
|
+
from .usage_tracker import UsageTracker
|
20
|
+
|
21
|
+
# System prompt args: a single args model per run.
SystemRunArgs: TypeAlias = LLMPromptArgs
# User prompt args: a single args model, or a batch (one per input message).
UserRunArgs: TypeAlias = LLMPromptArgs | list[LLMPromptArgs]


class RunArgs(BaseModel):
    """Prompt arguments supplied for one agent in a run.

    ``sys`` fills the system prompt and ``usr`` the input prompt. Both default
    to an empty ``LLMPromptArgs``; unknown keys are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    sys: SystemRunArgs = Field(default_factory=LLMPromptArgs)
    usr: UserRunArgs = Field(default_factory=LLMPromptArgs)
|
30
|
+
|
31
|
+
|
32
|
+
class InteractionRecord(BaseModel, Generic[InT, OutT, StateT]):
    """Immutable snapshot of one agent interaction within a run.

    It stores:
      * the sender (``source_id``) and its recipients;
      * the agent state;
      * the inputs, prompts and args that were used;
      * the produced outputs.
    Unknown keys are rejected and instances are frozen.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    source_id: str
    recipient_ids: Sequence[AgentID]
    state: StateT
    inp_items: LLMPrompt | list[str | ImageData] | None = None
    sys_prompt: LLMPrompt | None = None
    inp_prompt: LLMPrompt | None = None
    sys_args: SystemRunArgs | None = None
    usr_args: UserRunArgs | None = None
    rcv_args: Sequence[InT] | None = None
    outputs: Sequence[OutT]


# Interaction log of a run, with the generic payload/state types filled in.
InteractionHistory: TypeAlias = list[
    InteractionRecord[AgentPayload, AgentPayload, AgentState]
]
|
50
|
+
|
51
|
+
|
52
|
+
CtxT = TypeVar("CtxT")


class RunContextWrapper(BaseModel, Generic[CtxT]):
    """Shared state for a single run.

    It holds:
      * an optional user context object;
      * a short, immutable run id;
      * per-agent run args;
      * the interaction history.
    A ``UsageTracker`` and a ``Printer`` keyed by the run id are created right
    after validation. They are exposed read-only via the ``usage_tracker`` and
    ``printer`` properties. Unknown keys are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    context: CtxT | None = None
    run_id: str = Field(default_factory=lambda: str(uuid4())[:8], frozen=True)
    run_args: dict[AgentID, RunArgs] = Field(default_factory=dict)
    interaction_history: InteractionHistory = Field(default_factory=list)

    # Passed to the Printer at construction time only.
    print_messages: bool = False

    # Private (non-field) helpers, populated in model_post_init.
    _usage_tracker: UsageTracker = PrivateAttr()
    _printer: Printer = PrivateAttr()

    def model_post_init(self, context: Any) -> None:  # noqa: ARG002
        current_run_id = self.run_id
        self._usage_tracker = UsageTracker(source_id=current_run_id)
        self._printer = Printer(
            source_id=current_run_id, print_messages=self.print_messages
        )

    @property
    def usage_tracker(self) -> UsageTracker:
        """Usage tracker created for this run."""
        return self._usage_tracker

    @property
    def printer(self) -> Printer:
        """Message printer created for this run."""
        return self._printer
|