grasp_agents-0.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. grasp_agents/agent_message.py +28 -0
  2. grasp_agents/agent_message_pool.py +94 -0
  3. grasp_agents/base_agent.py +72 -0
  4. grasp_agents/cloud_llm.py +353 -0
  5. grasp_agents/comm_agent.py +230 -0
  6. grasp_agents/costs_dict.yaml +122 -0
  7. grasp_agents/data_retrieval/__init__.py +7 -0
  8. grasp_agents/data_retrieval/rate_limiter_chunked.py +195 -0
  9. grasp_agents/data_retrieval/types.py +57 -0
  10. grasp_agents/data_retrieval/utils.py +57 -0
  11. grasp_agents/grasp_logging.py +36 -0
  12. grasp_agents/http_client.py +24 -0
  13. grasp_agents/llm.py +106 -0
  14. grasp_agents/llm_agent.py +361 -0
  15. grasp_agents/llm_agent_state.py +73 -0
  16. grasp_agents/memory.py +150 -0
  17. grasp_agents/openai/__init__.py +83 -0
  18. grasp_agents/openai/completion_converters.py +49 -0
  19. grasp_agents/openai/content_converters.py +80 -0
  20. grasp_agents/openai/converters.py +170 -0
  21. grasp_agents/openai/message_converters.py +155 -0
  22. grasp_agents/openai/openai_llm.py +179 -0
  23. grasp_agents/openai/tool_converters.py +37 -0
  24. grasp_agents/printer.py +156 -0
  25. grasp_agents/prompt_builder.py +204 -0
  26. grasp_agents/run_context.py +90 -0
  27. grasp_agents/tool_orchestrator.py +181 -0
  28. grasp_agents/typing/__init__.py +0 -0
  29. grasp_agents/typing/completion.py +30 -0
  30. grasp_agents/typing/content.py +116 -0
  31. grasp_agents/typing/converters.py +118 -0
  32. grasp_agents/typing/io.py +32 -0
  33. grasp_agents/typing/message.py +130 -0
  34. grasp_agents/typing/tool.py +52 -0
  35. grasp_agents/usage_tracker.py +99 -0
  36. grasp_agents/utils.py +151 -0
  37. grasp_agents/workflow/__init__.py +0 -0
  38. grasp_agents/workflow/looped_agent.py +113 -0
  39. grasp_agents/workflow/sequential_agent.py +57 -0
  40. grasp_agents/workflow/workflow_agent.py +69 -0
  41. grasp_agents-0.1.5.dist-info/METADATA +14 -0
  42. grasp_agents-0.1.5.dist-info/RECORD +44 -0
  43. grasp_agents-0.1.5.dist-info/WHEEL +4 -0
  44. grasp_agents-0.1.5.dist-info/licenses/LICENSE +21 -0
grasp_agents/openai/openai_llm.py
@@ -0,0 +1,179 @@
+ import logging
+ from collections.abc import Iterable
+ from copy import deepcopy
+ from typing import Any, Literal
+
+ from openai import AsyncOpenAI
+ from openai._types import NOT_GIVEN  # noqa: PLC2701 # type: ignore[import]
+ from pydantic import BaseModel
+
+ from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings
+ from ..data_retrieval.rate_limiter_chunked import RateLimiterC
+ from ..http_client import AsyncHTTPClientParams
+ from ..typing.message import AssistantMessage, Conversation
+ from ..typing.tool import BaseTool
+ from . import (
+     ChatCompletion,
+     ChatCompletionAsyncStream,  # type: ignore[import]
+     ChatCompletionChunk,
+     ChatCompletionMessageParam,
+     ChatCompletionPredictionContentParam,
+     ChatCompletionStreamOptionsParam,
+     ChatCompletionToolChoiceOptionParam,
+     ChatCompletionToolParam,
+     ParsedChatCompletion,
+     # ResponseFormatJSONObject,
+     # ResponseFormatJSONSchema,
+     # ResponseFormatText,
+ )
+ from .converters import OpenAIConverters
+
+ logger = logging.getLogger(__name__)
+
+
+ class OpenAILLMSettings(CloudLLMSettings, total=False):
+     reasoning_effort: Literal["low", "medium", "high"] | None
+
+     parallel_tool_calls: bool
+
+     # response_format: (
+     #     ResponseFormatText | ResponseFormatJSONSchema | ResponseFormatJSONObject
+     # )
+
+     modalities: list[Literal["text", "audio"]] | None
+
+     frequency_penalty: float | None
+     presence_penalty: float | None
+     logit_bias: dict[str, int] | None
+     stop: str | list[str] | None
+     logprobs: bool | None
+     top_logprobs: int | None
+     n: int | None
+
+     prediction: ChatCompletionPredictionContentParam | None
+
+     stream_options: ChatCompletionStreamOptionsParam | None
+
+     metadata: dict[str, str] | None
+     store: bool | None
+     user: str
+
+
+ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
+     def __init__(
+         self,
+         # Base LLM args
+         model_name: str,
+         model_id: str | None = None,
+         llm_settings: OpenAILLMSettings | None = None,
+         tools: list[BaseTool[BaseModel, BaseModel, Any]] | None = None,
+         response_format: type | None = None,
+         # Connection settings
+         api_provider: APIProvider = "openai",
+         async_http_client_params: (
+             dict[str, Any] | AsyncHTTPClientParams | None
+         ) = None,
+         async_openai_client_params: dict[str, Any] | None = None,
+         # Rate limiting
+         rate_limiter: RateLimiterC[Conversation, AssistantMessage] | None = None,
+         rate_limiter_rpm: float | None = None,
+         rate_limiter_chunk_size: int = 1000,
+         rate_limiter_max_concurrency: int = 300,
+         # Retries
+         num_generation_retries: int = 0,
+         # Disable tqdm for batch processing
+         no_tqdm: bool = True,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(
+             model_name=model_name,
+             model_id=model_id,
+             llm_settings=llm_settings,
+             converters=OpenAIConverters(),
+             tools=tools,
+             response_format=response_format,
+             api_provider=api_provider,
+             async_http_client_params=async_http_client_params,
+             rate_limiter=rate_limiter,
+             rate_limiter_rpm=rate_limiter_rpm,
+             rate_limiter_chunk_size=rate_limiter_chunk_size,
+             rate_limiter_max_concurrency=rate_limiter_max_concurrency,
+             num_generation_retries=num_generation_retries,
+             no_tqdm=no_tqdm,
+             **kwargs,
+         )
+
+         async_openai_client_params_ = deepcopy(async_openai_client_params or {})
+         if self._async_http_client is not None:
+             async_openai_client_params_["http_client"] = self._async_http_client
+
+         # TODO: context manager for async client
+         self._client: AsyncOpenAI = AsyncOpenAI(
+             base_url=self._base_url,
+             api_key=self._api_key,
+             **async_openai_client_params_,
+             # timeout=10.0,
+             # max_retries=3,
+         )
+
+     async def _get_completion(
+         self,
+         api_messages: Iterable[ChatCompletionMessageParam],
+         api_tools: list[ChatCompletionToolParam] | None = None,
+         api_tool_choice: ChatCompletionToolChoiceOptionParam | None = None,
+         **api_llm_settings: Any,
+     ) -> ChatCompletion:
+         tools = api_tools or NOT_GIVEN
+         tool_choice = api_tool_choice or NOT_GIVEN
+
+         return await self._client.chat.completions.create(
+             model=self._model_name,
+             messages=api_messages,
+             tools=tools,
+             tool_choice=tool_choice,
+             stream=False,
+             **api_llm_settings,
+         )
+
+     async def _get_parsed_completion(
+         self,
+         api_messages: Iterable[ChatCompletionMessageParam],
+         api_tools: list[ChatCompletionToolParam] | None = None,
+         api_tool_choice: ChatCompletionToolChoiceOptionParam | None = None,
+         api_response_format: type | None = None,
+         **api_llm_settings: Any,
+     ) -> ParsedChatCompletion[Any]:
+         tools = api_tools or NOT_GIVEN
+         tool_choice = api_tool_choice or NOT_GIVEN
+         response_format = api_response_format or NOT_GIVEN
+
+         return await self._client.beta.chat.completions.parse(
+             model=self._model_name,
+             messages=api_messages,
+             tools=tools,
+             tool_choice=tool_choice,
+             response_format=response_format,  # type: ignore[arg-type]
+             **api_llm_settings,
+         )
+
+     async def _get_completion_stream(
+         self,
+         api_messages: Iterable[ChatCompletionMessageParam],
+         api_tools: list[ChatCompletionToolParam] | None = None,
+         api_tool_choice: ChatCompletionToolChoiceOptionParam | None = None,
+         **api_llm_settings: Any,
+     ) -> ChatCompletionAsyncStream[ChatCompletionChunk]:
+         assert not api_tools, "Tool use is not supported in streaming mode"
+
+         tools = api_tools or NOT_GIVEN
+         tool_choice = api_tool_choice or NOT_GIVEN
+
+         return await self._client.chat.completions.create(
+             model=self._model_name,
+             messages=api_messages,
+             tools=tools,
+             tool_choice=tool_choice,
+             stream=True,
+             **api_llm_settings,
+         )
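The constructor above only forwards its arguments to CloudLLM and builds the AsyncOpenAI client. A minimal, hypothetical usage sketch (mine, not from the package docs): it assumes the public import path shown in the file list and an arbitrary model name, and uses only settings keys and constructor parameters visible in the diff. OpenAILLMSettings is a TypedDict with total=False, so any subset of keys may be passed.

# Hypothetical sketch based on the __init__ signature above.
from grasp_agents.openai.openai_llm import OpenAILLM, OpenAILLMSettings

llm = OpenAILLM(
    model_name="gpt-4o-mini",  # assumed model name; any OpenAI chat model
    llm_settings=OpenAILLMSettings(
        reasoning_effort=None,
        parallel_tool_calls=False,
        stop=["\n\n"],
    ),
    rate_limiter_rpm=60.0,       # simple requests-per-minute cap
    num_generation_retries=2,    # retry failed generations twice
)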
grasp_agents/openai/tool_converters.py
@@ -0,0 +1,37 @@
+ from typing import Any
+
+ from pydantic import BaseModel
+
+ from ..typing.tool import BaseTool, ToolChoice
+ from . import (
+     ChatCompletionFunctionDefinition,
+     ChatCompletionNamedToolChoiceFunction,
+     ChatCompletionNamedToolChoiceParam,
+     ChatCompletionToolChoiceOptionParam,
+     ChatCompletionToolParam,
+ )
+
+
+ def to_api_tool(
+     tool: BaseTool[BaseModel, BaseModel, Any],
+ ) -> ChatCompletionToolParam:
+     return ChatCompletionToolParam(
+         type="function",
+         function=ChatCompletionFunctionDefinition(
+             name=tool.name,
+             description=tool.description,
+             parameters=tool.in_schema.model_json_schema(),
+             strict=tool.strict,
+         ),
+     )
+
+
+ def to_api_tool_choice(
+     tool_choice: ToolChoice,
+ ) -> ChatCompletionToolChoiceOptionParam:
+     if isinstance(tool_choice, BaseTool):
+         return ChatCompletionNamedToolChoiceParam(
+             type="function",
+             function=ChatCompletionNamedToolChoiceFunction(name=tool_choice.name),
+         )
+     return tool_choice
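A hypothetical sketch (mine) of what these converters do at the call site. It assumes a tool object exposing the attributes referenced above (name, description, in_schema, strict); the actual BaseTool constructor lives in grasp_agents/typing/tool.py, which is not shown in this hunk, and my_tool is a placeholder for such an instance. The last line assumes ToolChoice also admits OpenAI's literal options, since non-BaseTool values pass through unchanged.

# Hypothetical sketch; my_tool is a placeholder BaseTool instance.
from grasp_agents.openai.tool_converters import to_api_tool, to_api_tool_choice

api_tool = to_api_tool(my_tool)           # {"type": "function", "function": {...}}
api_choice = to_api_tool_choice(my_tool)  # named tool choice for a BaseTool
api_auto = to_api_tool_choice("auto")     # literal choices pass through unchanged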
grasp_agents/printer.py
@@ -0,0 +1,156 @@
+ import hashlib
+ import json
+ import logging
+ from collections.abc import Mapping, Sequence
+ from typing import Literal, TypeAlias
+
+ from termcolor._types import Color  # type: ignore[import]
+
+ from .typing.content import Content, ContentPartText
+ from .typing.message import AssistantMessage, Message, Role, ToolMessage
+
+ logger = logging.getLogger(__name__)
+
+
+ ColoringMode: TypeAlias = Literal["agent_id", "role"]
+
+ ROLE_TO_COLOR: Mapping[Role, Color] = {
+     Role.SYSTEM: "magenta",
+     Role.USER: "green",
+     Role.ASSISTANT: "light_blue",
+     Role.TOOL: "light_cyan",
+ }
+
+ AVAILABLE_COLORS: list[Color] = [
+     "magenta",
+     "green",
+     "light_blue",
+     "light_cyan",
+     "yellow",
+     "blue",
+     "red",
+ ]
+
+
+ class Printer:
+     def __init__(
+         self,
+         source_id: str,
+         color_by: ColoringMode = "role",
+         msg_trunc_len: int = 20000,
+         print_messages: bool = False,
+     ) -> None:
+         self.source_id = source_id
+         self.color_by = color_by
+         self.msg_trunc_len = msg_trunc_len
+         self.print_messages = print_messages
+
+     @staticmethod
+     def get_role_color(role: Role) -> Color:
+         return ROLE_TO_COLOR[role]
+
+     @staticmethod
+     def get_agent_color(agent_id: str) -> Color:
+         idx = int(
+             hashlib.md5(agent_id.encode()).hexdigest(),  # noqa: S324
+             16,
+         ) % len(AVAILABLE_COLORS)
+
+         return AVAILABLE_COLORS[idx]
+
+     @staticmethod
+     def content_to_str(content: Content | str, role: Role) -> str:
+         if role == Role.USER and isinstance(content, Content):
+             content_str_parts: list[str] = []
+             for content_part in content.parts:
+                 if isinstance(content_part, ContentPartText):
+                     content_str_parts.append(content_part.data.strip(" \n"))
+                 elif content_part.data.type == "url":
+                     content_str_parts.append(str(content_part.data.url))
+                 elif content_part.data.type == "base64":
+                     content_str_parts.append("<ENCODED_IMAGE>")
+             return "\n".join(content_str_parts)
+
+         assert isinstance(content, str)
+
+         return content.strip(" \n")
+
+     @staticmethod
+     def truncate_content_str(content_str: str, trunc_len: int = 2000) -> str:
+         if len(content_str) > trunc_len:
+             return content_str[:trunc_len] + "[...]"
+
+         return content_str
+
+     def print_llm_message(self, message: Message, agent_id: str) -> None:
+         if not self.print_messages:
+             return
+
+         role = message.role
+         usage = message.usage if isinstance(message, AssistantMessage) else None
+         content_str = self.content_to_str(message.content, message.role)
+
+         if self.color_by == "agent_id":
+             color = self.get_agent_color(agent_id)
+         elif self.color_by == "role":
+             color = self.get_role_color(role)
+
+         log_kwargs = {"extra": {"color": color}}  # type: ignore
+
+         # Print message title
+
+         out = f"\n<{agent_id}>"
+         out += "[" + role.value.upper() + "]"
+
+         if isinstance(message, ToolMessage):
+             out += f"\nTool call ID: {message.tool_call_id}"
+
+         # Print message content
+
+         if content_str:
+             try:
+                 content_str = json.dumps(json.loads(content_str), indent=2)
+             except Exception:
+                 pass
+             content_str_truncated = self.truncate_content_str(
+                 content_str, trunc_len=self.msg_trunc_len
+             )
+             out += f"\n{content_str_truncated}"
+
+         logger.debug(out, **log_kwargs)  # type: ignore
+
+         # Print tool calls
+
+         if isinstance(message, AssistantMessage) and message.tool_calls is not None:
+             for tool_call in message.tool_calls:
+                 if self.color_by == "agent_id":
+                     tool_color = self.get_agent_color(agent_id=agent_id)
+                 elif self.color_by == "role":
+                     tool_color = self.get_role_color(role=Role.TOOL)
+                 logger.debug(
+                     f"\n[TOOL_CALL]<{agent_id}>\n{tool_call.tool_name} "
+                     f"| {tool_call.id}\n{tool_call.tool_arguments}",
+                     extra={"color": tool_color},  # type: ignore
+                 )
+
+         # Print usage
+
+         if usage is not None:
+             usage_str = (
+                 f"I/O/(R)/(C) tokens: {usage.input_tokens}/{usage.output_tokens}"
+             )
+             if usage.reasoning_tokens is not None:
+                 usage_str += f"/{usage.reasoning_tokens}"
+             if usage.cached_tokens is not None:
+                 usage_str += f"/{usage.cached_tokens}"
+             logger.debug(
+                 f"\n------------------------------------\n{usage_str}",
+                 **log_kwargs,  # type: ignore
+             )
+
+     def print_llm_messages(self, messages: Sequence[Message], agent_id: str) -> None:
+         if not self.print_messages:
+             return
+
+         for message in messages:
+             self.print_llm_message(message, agent_id)
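Printer only emits when print_messages=True, and it writes through logger.debug with a "color" extra, so a DEBUG-level handler must be attached for anything to appear (presumably configured via grasp_agents/grasp_logging.py, which is not shown in this hunk). A hypothetical sketch (mine); conversation is a placeholder Sequence[Message]:

# Hypothetical sketch based on the Printer API above.
import logging
from grasp_agents.printer import Printer

logging.basicConfig(level=logging.DEBUG)  # plain handler; colors need a color-aware one

printer = Printer(source_id="run-1234", color_by="agent_id", print_messages=True)
printer.print_llm_messages(conversation, agent_id="researcher")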
grasp_agents/prompt_builder.py
@@ -0,0 +1,204 @@
+ from collections.abc import Sequence
+ from copy import deepcopy
+ from typing import Generic, Protocol
+
+ from .run_context import CtxT, RunContextWrapper, UserRunArgs
+ from .typing.content import ImageData
+ from .typing.io import (
+     AgentPayload,
+     InT,
+     LLMFormattedArgs,
+     LLMFormattedSystemArgs,
+     LLMPrompt,
+     LLMPromptArgs,
+ )
+ from .typing.message import UserMessage
+
+
+ class FormatSystemArgsHandler(Protocol[CtxT]):
+     def __call__(
+         self,
+         sys_args: LLMPromptArgs,
+         *,
+         ctx: RunContextWrapper[CtxT] | None,
+     ) -> LLMFormattedSystemArgs: ...
+
+
+ class FormatInputArgsHandler(Protocol[InT, CtxT]):
+     def __call__(
+         self,
+         usr_args: LLMPromptArgs,
+         rcv_args: InT,
+         ctx: RunContextWrapper[CtxT] | None,
+     ) -> LLMFormattedArgs: ...
+
+
+ class PromptBuilder(Generic[InT, CtxT]):
+     def __init__(
+         self,
+         agent_id: str,
+         sys_prompt: LLMPrompt | None,
+         inp_prompt: LLMPrompt | None,
+         sys_args_schema: type[LLMPromptArgs],
+         usr_args_schema: type[LLMPromptArgs],
+         rcv_args_schema: type[InT],
+     ):
+         self._agent_id = agent_id
+         self.sys_prompt = sys_prompt
+         self.inp_prompt = inp_prompt
+         self.sys_args_schema = sys_args_schema
+         self.usr_args_schema = usr_args_schema
+         self.rcv_args_schema = rcv_args_schema
+         self.format_sys_args_impl: FormatSystemArgsHandler[CtxT] | None = None
+         self.format_inp_args_impl: FormatInputArgsHandler[InT, CtxT] | None = None
+
+     def _format_sys_args(
+         self,
+         sys_args: LLMPromptArgs,
+         ctx: RunContextWrapper[CtxT] | None = None,
+     ) -> LLMFormattedSystemArgs:
+         if self.format_sys_args_impl:
+             return self.format_sys_args_impl(sys_args=sys_args, ctx=ctx)
+
+         return sys_args.model_dump(exclude_unset=True)
+
+     def _format_inp_args(
+         self,
+         usr_args: LLMPromptArgs,
+         rcv_args: InT,
+         ctx: RunContextWrapper[CtxT] | None = None,
+     ) -> LLMFormattedArgs:
+         if self.format_inp_args_impl:
+             return self.format_inp_args_impl(
+                 usr_args=usr_args, rcv_args=rcv_args, ctx=ctx
+             )
+
+         return usr_args.model_dump(exclude_unset=True) | rcv_args.model_dump(
+             exclude_unset=True, exclude={"selected_recipient_ids"}
+         )
+
+     def make_sys_prompt(
+         self,
+         sys_args: LLMPromptArgs,
+         ctx: RunContextWrapper[CtxT] | None,
+     ) -> LLMPrompt | None:
+         if self.sys_prompt is None:
+             return None
+         val_sys_args = self.sys_args_schema.model_validate(sys_args)
+         fmt_sys_args = self._format_sys_args(val_sys_args, ctx=ctx)
+
+         return self.sys_prompt.format(**fmt_sys_args)
+
+     def _usr_messages_from_text(self, text: str) -> list[UserMessage]:
+         return [UserMessage.from_text(text, model_id=self._agent_id)]
+
+     def _usr_messages_from_content_parts(
+         self, content_parts: list[str | ImageData]
+     ) -> list[UserMessage]:
+         return [UserMessage.from_content_parts(content_parts, model_id=self._agent_id)]
+
+     def _usr_messages_from_rcv_args(
+         self, rcv_args_batch: Sequence[InT]
+     ) -> list[UserMessage]:
+         val_rcv_args_batch = [
+             self.rcv_args_schema.model_validate(rcv) for rcv in rcv_args_batch
+         ]
+
+         return [
+             UserMessage.from_text(
+                 rcv.model_dump_json(
+                     exclude_unset=True,
+                     indent=2,
+                     exclude={"selected_recipient_ids"},
+                 ),
+                 model_id=self._agent_id,
+             )
+             for rcv in val_rcv_args_batch
+         ]
+
+     def _usr_messages_from_prompt_template(
+         self,
+         inp_prompt: LLMPrompt,
+         usr_args: UserRunArgs | None = None,
+         rcv_args_batch: Sequence[InT] | None = None,
+         ctx: RunContextWrapper[CtxT] | None = None,
+     ) -> Sequence[UserMessage]:
+         usr_args_batch, rcv_args_batch = self._make_batched(usr_args, rcv_args_batch)
+         val_usr_args_batch = [
+             self.usr_args_schema.model_validate(u) for u in usr_args_batch
+         ]
+         val_rcv_args_batch = [
+             self.rcv_args_schema.model_validate(r) for r in rcv_args_batch
+         ]
+         formatted_inp_args_batch = [
+             self._format_inp_args(usr_args=val_usr_args, rcv_args=val_rcv_args, ctx=ctx)
+             for val_usr_args, val_rcv_args in zip(
+                 val_usr_args_batch, val_rcv_args_batch, strict=False
+             )
+         ]
+
+         return [
+             UserMessage.from_formatted_prompt(
+                 prompt_template=inp_prompt, prompt_args=inp_args
+             )
+             for inp_args in formatted_inp_args_batch
+         ]
+
+     def make_user_messages(
+         self,
+         inp_items: LLMPrompt | list[str | ImageData] | None = None,
+         usr_args: UserRunArgs | None = None,
+         rcv_args_batch: Sequence[InT] | None = None,
+         entry_point: bool = False,
+         ctx: RunContextWrapper[CtxT] | None = None,
+     ) -> Sequence[UserMessage]:
+         # 1) Direct user input (e.g. chat input)
+         if inp_items is not None or entry_point:
+             # If user inputs are provided, use them instead of the predefined
+             # input prompt template.
+             # In a multi-agent system, the predefined input prompt is used to
+             # construct agent inputs from the combination of received and user
+             # arguments. However, the first agent run (the entry point) has no
+             # received messages, so we use the user inputs directly, if provided.
+             if isinstance(inp_items, LLMPrompt):
+                 return self._usr_messages_from_text(inp_items)
+             if isinstance(inp_items, list) and inp_items:
+                 return self._usr_messages_from_content_parts(inp_items)
+             return []
+
+         # 2) No input prompt template + received args → raw JSON messages
+         if self.inp_prompt is None and rcv_args_batch:
+             return self._usr_messages_from_rcv_args(rcv_args_batch)
+
+         # 3) Input prompt template + any args → batch & format
+         if self.inp_prompt is not None:
+             return self._usr_messages_from_prompt_template(
+                 inp_prompt=self.inp_prompt,
+                 usr_args=usr_args,
+                 rcv_args_batch=rcv_args_batch,
+                 ctx=ctx,
+             )
+         return []
+
+     def _make_batched(
+         self,
+         usr_args: UserRunArgs | None = None,
+         rcv_args_batch: Sequence[InT] | None = None,
+     ) -> tuple[Sequence[LLMPromptArgs], Sequence[InT]]:
+         usr_args_batch_ = (
+             usr_args if isinstance(usr_args, list) else [usr_args or LLMPromptArgs()]
+         )
+         rcv_args_batch_ = rcv_args_batch or [AgentPayload()]
+
+         # Broadcast singleton → match lengths
+         if len(usr_args_batch_) == 1 and len(rcv_args_batch_) > 1:
+             usr_args_batch_ = [deepcopy(usr_args_batch_[0]) for _ in rcv_args_batch_]
+         if len(rcv_args_batch_) == 1 and len(usr_args_batch_) > 1:
+             rcv_args_batch_ = [deepcopy(rcv_args_batch_[0]) for _ in usr_args_batch_]
+         if len(usr_args_batch_) != len(rcv_args_batch_):
+             raise ValueError("User args and received args must have the same length")
+
+         return usr_args_batch_, rcv_args_batch_  # type: ignore
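A hypothetical sketch (mine) of the batching and broadcast rule in _make_batched: a singleton on either side is deep-copied to match the other batch, and mismatched lengths greater than one raise ValueError. It assumes LLMPromptArgs and AgentPayload are pydantic models (consistent with the model_validate/model_dump calls above); TopicArgs is a made-up schema for illustration.

# Hypothetical sketch based on the PromptBuilder API above.
from grasp_agents.prompt_builder import PromptBuilder
from grasp_agents.typing.io import AgentPayload, LLMPromptArgs

class TopicArgs(LLMPromptArgs):
    topic: str = "ants"

builder = PromptBuilder(
    agent_id="writer",
    sys_prompt=None,
    inp_prompt="Write one paragraph on {topic}.",
    sys_args_schema=LLMPromptArgs,
    usr_args_schema=TopicArgs,
    rcv_args_schema=AgentPayload,
)

# Two user-arg sets, no received args: the default AgentPayload() singleton is
# broadcast to length 2, yielding one formatted UserMessage per TopicArgs.
msgs = builder.make_user_messages(
    usr_args=[TopicArgs(topic="ants"), TopicArgs(topic="bees")],
)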
grasp_agents/run_context.py
@@ -0,0 +1,90 @@
+ from collections.abc import Sequence
+ from typing import Any, Generic, TypeAlias, TypeVar
+ from uuid import uuid4
+
+ from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
+
+ from .printer import Printer
+ from .typing.content import ImageData
+ from .typing.io import (
+     AgentID,
+     AgentPayload,
+     AgentState,
+     InT,
+     LLMPrompt,
+     LLMPromptArgs,
+     OutT,
+     StateT,
+ )
+ from .usage_tracker import UsageTracker
+
+ SystemRunArgs: TypeAlias = LLMPromptArgs
+ UserRunArgs: TypeAlias = LLMPromptArgs | list[LLMPromptArgs]
+
+
+ class RunArgs(BaseModel):
+     sys: SystemRunArgs = Field(default_factory=LLMPromptArgs)
+     usr: UserRunArgs = Field(default_factory=LLMPromptArgs)
+
+     model_config = ConfigDict(extra="forbid")
+
+
+ class InteractionRecord(BaseModel, Generic[InT, OutT, StateT]):
+     source_id: str
+     recipient_ids: Sequence[AgentID]
+     state: StateT
+     inp_items: LLMPrompt | list[str | ImageData] | None = None
+     sys_prompt: LLMPrompt | None = None
+     inp_prompt: LLMPrompt | None = None
+     sys_args: SystemRunArgs | None = None
+     usr_args: UserRunArgs | None = None
+     rcv_args: Sequence[InT] | None = None
+     outputs: Sequence[OutT]
+
+     model_config = ConfigDict(extra="forbid", frozen=True)
+
+
+ InteractionHistory: TypeAlias = list[
+     InteractionRecord[AgentPayload, AgentPayload, AgentState]
+ ]
+
+
+ CtxT = TypeVar("CtxT")
+
+
+ class RunContextWrapper(BaseModel, Generic[CtxT]):
+     context: CtxT | None = None
+     run_id: str = Field(default_factory=lambda: str(uuid4())[:8], frozen=True)
+     run_args: dict[AgentID, RunArgs] = Field(default_factory=dict)
+     interaction_history: InteractionHistory = Field(default_factory=list)
+
+     print_messages: bool = False
+
+     _usage_tracker: UsageTracker = PrivateAttr()
+     _printer: Printer = PrivateAttr()
+
+     def model_post_init(self, context: Any) -> None:  # noqa: ARG002
+         self._usage_tracker = UsageTracker(source_id=self.run_id)
+         self._printer = Printer(
+             source_id=self.run_id, print_messages=self.print_messages
+         )
+
+     @property
+     def usage_tracker(self) -> UsageTracker:
+         return self._usage_tracker
+
+     @property
+     def printer(self) -> Printer:
+         return self._printer
+
+     model_config = ConfigDict(extra="forbid")
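RunContextWrapper wires up its own UsageTracker and Printer in model_post_init, keyed to a short random run_id, so callers set only public fields and read the helpers back through read-only properties. A hypothetical sketch (mine), assuming nothing beyond the fields and properties shown above:

# Hypothetical sketch based on the RunContextWrapper model above.
from grasp_agents.run_context import RunContextWrapper

ctx = RunContextWrapper[dict](context={"db": None}, print_messages=True)
print(ctx.run_id)   # e.g. "3f9c1a2b": first 8 chars of a uuid4, frozen after init
ctx.printer         # Printer(source_id=ctx.run_id, print_messages=True)
ctx.usage_tracker   # UsageTracker(source_id=ctx.run_id)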