mail-swarms 1.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mail/__init__.py +35 -0
- mail/api.py +1964 -0
- mail/cli.py +432 -0
- mail/client.py +1657 -0
- mail/config/__init__.py +8 -0
- mail/config/client.py +87 -0
- mail/config/server.py +165 -0
- mail/core/__init__.py +72 -0
- mail/core/actions.py +69 -0
- mail/core/agents.py +73 -0
- mail/core/message.py +366 -0
- mail/core/runtime.py +3537 -0
- mail/core/tasks.py +311 -0
- mail/core/tools.py +1206 -0
- mail/db/__init__.py +0 -0
- mail/db/init.py +182 -0
- mail/db/types.py +65 -0
- mail/db/utils.py +523 -0
- mail/examples/__init__.py +27 -0
- mail/examples/analyst_dummy/__init__.py +15 -0
- mail/examples/analyst_dummy/agent.py +136 -0
- mail/examples/analyst_dummy/prompts.py +44 -0
- mail/examples/consultant_dummy/__init__.py +15 -0
- mail/examples/consultant_dummy/agent.py +136 -0
- mail/examples/consultant_dummy/prompts.py +42 -0
- mail/examples/data_analysis/__init__.py +40 -0
- mail/examples/data_analysis/analyst/__init__.py +9 -0
- mail/examples/data_analysis/analyst/agent.py +67 -0
- mail/examples/data_analysis/analyst/prompts.py +53 -0
- mail/examples/data_analysis/processor/__init__.py +13 -0
- mail/examples/data_analysis/processor/actions.py +293 -0
- mail/examples/data_analysis/processor/agent.py +67 -0
- mail/examples/data_analysis/processor/prompts.py +48 -0
- mail/examples/data_analysis/reporter/__init__.py +10 -0
- mail/examples/data_analysis/reporter/actions.py +187 -0
- mail/examples/data_analysis/reporter/agent.py +67 -0
- mail/examples/data_analysis/reporter/prompts.py +49 -0
- mail/examples/data_analysis/statistics/__init__.py +18 -0
- mail/examples/data_analysis/statistics/actions.py +343 -0
- mail/examples/data_analysis/statistics/agent.py +67 -0
- mail/examples/data_analysis/statistics/prompts.py +60 -0
- mail/examples/mafia/__init__.py +0 -0
- mail/examples/mafia/game.py +1537 -0
- mail/examples/mafia/narrator_tools.py +396 -0
- mail/examples/mafia/personas.py +240 -0
- mail/examples/mafia/prompts.py +489 -0
- mail/examples/mafia/roles.py +147 -0
- mail/examples/mafia/spec.md +350 -0
- mail/examples/math_dummy/__init__.py +23 -0
- mail/examples/math_dummy/actions.py +252 -0
- mail/examples/math_dummy/agent.py +136 -0
- mail/examples/math_dummy/prompts.py +46 -0
- mail/examples/math_dummy/types.py +5 -0
- mail/examples/research/__init__.py +39 -0
- mail/examples/research/researcher/__init__.py +9 -0
- mail/examples/research/researcher/agent.py +67 -0
- mail/examples/research/researcher/prompts.py +54 -0
- mail/examples/research/searcher/__init__.py +10 -0
- mail/examples/research/searcher/actions.py +324 -0
- mail/examples/research/searcher/agent.py +67 -0
- mail/examples/research/searcher/prompts.py +53 -0
- mail/examples/research/summarizer/__init__.py +18 -0
- mail/examples/research/summarizer/actions.py +255 -0
- mail/examples/research/summarizer/agent.py +67 -0
- mail/examples/research/summarizer/prompts.py +55 -0
- mail/examples/research/verifier/__init__.py +10 -0
- mail/examples/research/verifier/actions.py +337 -0
- mail/examples/research/verifier/agent.py +67 -0
- mail/examples/research/verifier/prompts.py +52 -0
- mail/examples/supervisor/__init__.py +11 -0
- mail/examples/supervisor/agent.py +4 -0
- mail/examples/supervisor/prompts.py +93 -0
- mail/examples/support/__init__.py +33 -0
- mail/examples/support/classifier/__init__.py +10 -0
- mail/examples/support/classifier/actions.py +307 -0
- mail/examples/support/classifier/agent.py +68 -0
- mail/examples/support/classifier/prompts.py +56 -0
- mail/examples/support/coordinator/__init__.py +9 -0
- mail/examples/support/coordinator/agent.py +67 -0
- mail/examples/support/coordinator/prompts.py +48 -0
- mail/examples/support/faq/__init__.py +10 -0
- mail/examples/support/faq/actions.py +182 -0
- mail/examples/support/faq/agent.py +67 -0
- mail/examples/support/faq/prompts.py +42 -0
- mail/examples/support/sentiment/__init__.py +15 -0
- mail/examples/support/sentiment/actions.py +341 -0
- mail/examples/support/sentiment/agent.py +67 -0
- mail/examples/support/sentiment/prompts.py +54 -0
- mail/examples/weather_dummy/__init__.py +23 -0
- mail/examples/weather_dummy/actions.py +75 -0
- mail/examples/weather_dummy/agent.py +136 -0
- mail/examples/weather_dummy/prompts.py +35 -0
- mail/examples/weather_dummy/types.py +5 -0
- mail/factories/__init__.py +27 -0
- mail/factories/action.py +223 -0
- mail/factories/base.py +1531 -0
- mail/factories/supervisor.py +241 -0
- mail/net/__init__.py +7 -0
- mail/net/registry.py +712 -0
- mail/net/router.py +728 -0
- mail/net/server_utils.py +114 -0
- mail/net/types.py +247 -0
- mail/server.py +1605 -0
- mail/stdlib/__init__.py +0 -0
- mail/stdlib/anthropic/__init__.py +0 -0
- mail/stdlib/fs/__init__.py +15 -0
- mail/stdlib/fs/actions.py +209 -0
- mail/stdlib/http/__init__.py +19 -0
- mail/stdlib/http/actions.py +333 -0
- mail/stdlib/interswarm/__init__.py +11 -0
- mail/stdlib/interswarm/actions.py +208 -0
- mail/stdlib/mcp/__init__.py +19 -0
- mail/stdlib/mcp/actions.py +294 -0
- mail/stdlib/openai/__init__.py +13 -0
- mail/stdlib/openai/agents.py +451 -0
- mail/summarizer.py +234 -0
- mail/swarms_json/__init__.py +27 -0
- mail/swarms_json/types.py +87 -0
- mail/swarms_json/utils.py +255 -0
- mail/url_scheme.py +51 -0
- mail/utils/__init__.py +53 -0
- mail/utils/auth.py +194 -0
- mail/utils/context.py +17 -0
- mail/utils/logger.py +73 -0
- mail/utils/openai.py +212 -0
- mail/utils/parsing.py +89 -0
- mail/utils/serialize.py +292 -0
- mail/utils/store.py +49 -0
- mail/utils/string_builder.py +119 -0
- mail/utils/version.py +20 -0
- mail_swarms-1.3.2.dist-info/METADATA +237 -0
- mail_swarms-1.3.2.dist-info/RECORD +137 -0
- mail_swarms-1.3.2.dist-info/WHEEL +4 -0
- mail_swarms-1.3.2.dist-info/entry_points.txt +2 -0
- mail_swarms-1.3.2.dist-info/licenses/LICENSE +202 -0
- mail_swarms-1.3.2.dist-info/licenses/NOTICE +10 -0
- mail_swarms-1.3.2.dist-info/licenses/THIRD_PARTY_NOTICES.md +12334 -0
|
@@ -0,0 +1,451 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
# Copyright (c) 2025 Addison Kline
|
|
3
|
+
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
from collections.abc import Awaitable
|
|
7
|
+
from typing import Any, Literal
|
|
8
|
+
from uuid import uuid4
|
|
9
|
+
|
|
10
|
+
import openai
|
|
11
|
+
from openai.types.chat import (
|
|
12
|
+
ChatCompletionMessageParam,
|
|
13
|
+
ChatCompletionMessageToolCallUnion,
|
|
14
|
+
ChatCompletionToolUnionParam,
|
|
15
|
+
)
|
|
16
|
+
from openai.types.responses import ResponseInputParam, ToolParam
|
|
17
|
+
|
|
18
|
+
from mail.core.agents import AgentOutput
|
|
19
|
+
from mail.core.tools import AgentToolCall
|
|
20
|
+
from mail.factories.base import MAILAgentFunction
|
|
21
|
+
from mail.factories.supervisor import SupervisorFunction
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class OpenAIChatCompletionsAgentFunction(MAILAgentFunction):
    """
    A MAIL agent function that uses the OpenAI Chat Completions API.

    Wraps an ``openai.AsyncOpenAI`` client and converts between MAIL's
    message/tool structures and the Chat Completions wire format.
    """

    def __init__(
        self,
        name: str,
        comm_targets: list[str],
        model: str,
        tools: list[dict[str, Any]],
        enable_entrypoint: bool = False,
        enable_interswarm: bool = False,
        can_complete_tasks: bool = False,
        tool_format: Literal[
            "completions", "responses"
        ] = "completions",  # kept for compatibility
        exclude_tools: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Create the agent function.

        Args:
            name: Agent name registered with MAIL.
            comm_targets: Names of agents this agent may message.
            model: OpenAI model identifier passed to the API.
            tools: Tool schemas (name/description/parameters dicts).
            exclude_tools: Tool names to exclude; ``None`` means none.
                (Previously this parameter used a mutable ``[]`` default,
                which is shared across calls — fixed to ``None``.)
        """
        super().__init__(
            name=name,
            comm_targets=comm_targets,
            tools=tools,
            enable_entrypoint=enable_entrypoint,
            enable_interswarm=enable_interswarm,
            can_complete_tasks=can_complete_tasks,
            tool_format="completions",
            # Fresh list per call — avoids the shared-mutable-default pitfall.
            exclude_tools=list(exclude_tools) if exclude_tools else [],
            **kwargs,
        )
        self.model = model
        # API key comes from the environment; openai raises on first request
        # if it is missing.
        self.client = openai.AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    def __call__(
        self,
        messages: list[dict[str, Any]],
        tool_choice: str | dict[str, str] = "required",
    ) -> Awaitable[AgentOutput]:
        """
        Generate a chat completion using the OpenAI API.

        Returns an awaitable so the caller controls scheduling.
        """
        return self._run_chat_completion(messages, tool_choice)

    async def _run_chat_completion(
        self,
        messages: list[dict[str, Any]],
        tool_choice: str | dict[str, str] = "required",
    ) -> AgentOutput:
        """
        Run a chat completion using the OpenAI API.

        Returns ``(text_or_None, tool_calls)`` from the first choice.
        """
        response = await self.client.chat.completions.create(  # type: ignore
            model=self.model,
            messages=self._preprocess_messages(messages),
            tool_choice=tool_choice,
            tools=self._preprocess_tools(),
        )
        choice = response.choices[0]
        message = choice.message
        tool_calls = self._postprocess_tool_calls(message)
        return message.content or None, tool_calls

    def _preprocess_messages(
        self,
        messages: list[dict[str, Any]],
    ) -> list[ChatCompletionMessageParam]:
        """
        Preprocess the messages for the OpenAI API.

        Only role/content plus the optional name/tool_calls/tool_call_id
        keys are forwarded; any other keys on incoming dicts are dropped.
        """
        normalized: list[ChatCompletionMessageParam] = []
        for message in messages:
            entry: dict[str, Any] = {
                "role": message.get("role"),
                "content": message.get("content"),
            }
            if "name" in message:
                entry["name"] = message["name"]
            if "tool_calls" in message:
                entry["tool_calls"] = message["tool_calls"]
            if "tool_call_id" in message:
                entry["tool_call_id"] = message["tool_call_id"]
            normalized.append(entry)  # type: ignore[arg-type]
        return normalized

    def _preprocess_tools(self) -> list[ChatCompletionToolUnionParam]:
        """
        Preprocess the tools for the OpenAI API (function-tool format).
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": tool["name"],
                    "description": tool["description"],
                    "parameters": tool["parameters"],
                },
            }
            for tool in self.tools
        ]

    def _postprocess_tool_calls(
        self,
        message: Any,
    ) -> list[AgentToolCall]:
        """
        Postprocess the tool calls from the OpenAI API response.

        Normalizes both function-style and custom-style tool calls into
        ``AgentToolCall`` records. Malformed JSON arguments are preserved
        under a ``"raw"`` key rather than dropped.
        """
        tool_calls: list[ChatCompletionMessageToolCallUnion] = list(
            getattr(message, "tool_calls", []) or []
        )
        if not tool_calls:
            return []

        message_dict = message.model_dump(exclude_none=False)
        call_records: list[tuple[str, str, str, dict[str, Any]]] = []

        for tool_call in tool_calls:
            call_id = getattr(tool_call, "id", None) or f"call_{uuid4()}"
            function_call = getattr(tool_call, "function", None)
            custom_call = getattr(tool_call, "custom", None)
            name = None
            raw_args = "{}"
            if function_call is not None:
                name = getattr(function_call, "name", None)
                raw_args = getattr(function_call, "arguments", "{}") or "{}"
            elif custom_call is not None:
                name = getattr(custom_call, "name", None)
                raw_args = getattr(custom_call, "input", "{}") or "{}"

            if not name:
                # A nameless call cannot be dispatched; skip it.
                continue

            try:
                parsed_args = json.loads(raw_args)
            except json.JSONDecodeError:
                # Keep the malformed payload visible to the tool.
                parsed_args = {"raw": raw_args}

            call_records.append((call_id, name, raw_args, parsed_args))

        # Rewrite the completion's tool_calls in the canonical function
        # format so downstream consumers see a uniform shape.
        message_dict["tool_calls"] = [
            {
                "id": call_id,
                "type": "function",
                "function": {
                    "name": name,
                    "arguments": raw_args,
                },
            }
            for call_id, name, raw_args, _ in call_records
        ]

        return [
            AgentToolCall(
                tool_name=name,
                tool_args=parsed_args,
                tool_call_id=call_id,
                completion=message_dict,
            )
            for call_id, name, _, parsed_args in call_records
        ]
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class OpenAIChatCompletionsSupervisorFunction(SupervisorFunction):
    """
    A MAIL supervisor function that uses the OpenAI Chat Completions API.

    Delegates the actual model call to an internal
    ``OpenAIChatCompletionsAgentFunction`` configured with the same settings.
    """

    def __init__(
        self,
        name: str,
        comm_targets: list[str],
        model: str,
        tools: list[dict[str, Any]],
        can_complete_tasks: bool = True,
        enable_entrypoint: bool = False,
        enable_interswarm: bool = False,
        tool_format: Literal["completions", "responses"] = "responses",
        exclude_tools: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Create the supervisor function.

        Note: ``exclude_tools`` previously used a mutable ``[]`` default,
        which is shared across calls — fixed to ``None``.
        """
        # Normalize once so the same fresh list is shared with super() and
        # the delegate below.
        exclude_tools = list(exclude_tools) if exclude_tools else []
        super().__init__(
            name=name,
            comm_targets=comm_targets,
            tools=tools,
            can_complete_tasks=True,  # supervisor can always complete tasks; param kept for compatibility
            enable_entrypoint=enable_entrypoint,
            enable_interswarm=enable_interswarm,
            tool_format="completions",
            exclude_tools=exclude_tools,
            **kwargs,
        )
        self.model = model
        # Retained for API compatibility; the delegate below constructs and
        # uses its own client.
        self.client = openai.AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        self.supervisor_fn = OpenAIChatCompletionsAgentFunction(
            name=name,
            comm_targets=comm_targets,
            model=model,
            tools=self.tools,
            enable_entrypoint=enable_entrypoint,
            enable_interswarm=enable_interswarm,
            can_complete_tasks=True,  # supervisor can always complete tasks; param kept for compatibility
            tool_format="completions",
            exclude_tools=exclude_tools,
            **kwargs,
        )

    def __call__(
        self,
        messages: list[dict[str, Any]],
        tool_choice: str | dict[str, str] = "required",
    ) -> Awaitable[AgentOutput]:
        """
        Generate a chat completion using the OpenAI API.

        Delegates to the wrapped agent function.
        """
        return self.supervisor_fn(messages, tool_choice)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
class OpenAIResponsesAgentFunction(MAILAgentFunction):
    """
    A MAIL agent function that uses the OpenAI Responses API.

    Converts MAIL message/tool structures to the Responses wire format and
    normalizes the returned output blocks into ``(text, tool_calls)``.
    """

    def __init__(
        self,
        name: str,
        comm_targets: list[str],
        model: str,
        tools: list[dict[str, Any]],
        enable_entrypoint: bool = False,
        enable_interswarm: bool = False,
        can_complete_tasks: bool = False,
        tool_format: Literal[
            "completions", "responses"
        ] = "responses",  # kept for compatibility
        exclude_tools: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Create the agent function.

        Note: ``exclude_tools`` previously used a mutable ``[]`` default,
        which is shared across calls — fixed to ``None``.
        """
        super().__init__(
            name=name,
            comm_targets=comm_targets,
            tools=tools,
            enable_entrypoint=enable_entrypoint,
            enable_interswarm=enable_interswarm,
            can_complete_tasks=can_complete_tasks,
            tool_format="responses",
            # Fresh list per call — avoids the shared-mutable-default pitfall.
            exclude_tools=list(exclude_tools) if exclude_tools else [],
            **kwargs,
        )
        self.model = model
        self.client = openai.AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    def __call__(
        self,
        messages: list[dict[str, Any]],
        tool_choice: str | dict[str, str] = "required",
    ) -> Awaitable[AgentOutput]:
        """
        Generate a response using the OpenAI API.

        Returns an awaitable so the caller controls scheduling.
        """
        return self._run_response(messages, tool_choice)

    async def _run_response(
        self,
        messages: list[dict[str, Any]],
        tool_choice: str | dict[str, str] = "required",
    ) -> AgentOutput:
        """
        Run a response using the OpenAI API.

        Text segments from all message blocks are joined with newlines;
        tool-call blocks are parsed into ``AgentToolCall`` records.
        """
        response = await self.client.responses.create(  # type: ignore
            model=self.model,
            input=self._preprocess_messages(messages),
            tool_choice=tool_choice,
            tools=self._preprocess_tools(),
        )
        response_dict = response.model_dump()
        outputs: list[dict[str, Any]] = response_dict.get("output", [])
        # Deep-copy outputs so we can normalise in place without mutating original objects
        normalized_outputs: list[dict[str, Any]] = json.loads(json.dumps(outputs))
        text_segments: list[str] = []
        call_records: list[tuple[str, str, dict[str, Any]]] = []

        for block in outputs:
            block_type = block.get("type")
            if block_type == "message":
                for content in block.get("content", []):
                    if content.get("type") in {"output_text", "text"}:
                        text_segments.append(content.get("text", ""))
            elif block_type in {"custom_tool_call", "tool_call", "function_call"}:
                # Different block variants carry the tool name in
                # different places; probe each in turn.
                name = (
                    block.get("name")
                    or block.get("tool", {}).get("name")
                    or block.get("function", {}).get("name")
                )
                if not name:
                    # A nameless call cannot be dispatched; skip it.
                    continue
                call_id = block.get("call_id") or block.get("id") or f"call_{uuid4()}"
                raw_input = (
                    block.get("input")
                    or block.get("tool_input")
                    or block.get("arguments")
                    or "{}"
                )
                if isinstance(raw_input, dict):
                    parsed_input = raw_input
                else:
                    try:
                        parsed_input = json.loads(raw_input)
                    except json.JSONDecodeError:
                        # Keep the malformed payload visible to the tool.
                        parsed_input = {"raw": raw_input}
                call_records.append((call_id, name, parsed_input))

        # Rewrite "output_text" content blocks as plain "text" for
        # downstream consumers expecting the completions-style shape.
        for block in normalized_outputs:
            if block.get("type") == "message":
                for content in block.get("content", []):
                    if content.get("type") == "output_text":
                        content["type"] = "text"

        agent_tool_calls = [
            AgentToolCall(
                tool_name=name,
                tool_args=tool_args,
                tool_call_id=call_id,
                responses=normalized_outputs,
            )
            for call_id, name, tool_args in call_records
        ]

        response_text = "\n".join(filter(None, text_segments)) or None
        return response_text, agent_tool_calls

    def _preprocess_messages(
        self,
        messages: list[dict[str, Any]],
    ) -> list[ResponseInputParam]:
        """
        Preprocess the messages for the OpenAI API.

        Only role/content plus optional name/tool_call_id keys are
        forwarded; other keys on incoming dicts are dropped.
        """
        normalized: list[ResponseInputParam] = []
        for message in messages:
            entry: dict[str, Any] = {
                "role": message.get("role"),
                "content": message.get("content"),
            }
            if "name" in message:
                entry["name"] = message["name"]
            if "tool_call_id" in message:
                entry["tool_call_id"] = message["tool_call_id"]
            normalized.append(entry)  # type: ignore[arg-type]
        return normalized

    def _preprocess_tools(self) -> list[ToolParam]:
        """
        Preprocess the tools for the OpenAI API.
        """
        return [
            {  # type: ignore
                "type": "function",
                "function": {
                    "name": tool["name"],
                    "description": tool.get("description"),
                    "parameters": tool.get("parameters"),
                },
            }
            for tool in self.tools
        ]
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
class OpenAIResponsesSupervisorFunction(SupervisorFunction):
    """
    A MAIL supervisor function that uses the OpenAI Responses API.

    Delegates the actual model call to an internal
    ``OpenAIResponsesAgentFunction`` configured with the same settings.
    """

    def __init__(
        self,
        name: str,
        comm_targets: list[str],
        model: str,
        tools: list[dict[str, Any]],
        can_complete_tasks: bool = True,
        enable_entrypoint: bool = False,
        enable_interswarm: bool = False,
        tool_format: Literal["completions", "responses"] = "responses",
        exclude_tools: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Create the supervisor function.

        Note: ``exclude_tools`` previously used a mutable ``[]`` default,
        which is shared across calls — fixed to ``None``.
        """
        # Normalize once so the same fresh list is shared with super() and
        # the delegate below.
        exclude_tools = list(exclude_tools) if exclude_tools else []
        super().__init__(
            name=name,
            comm_targets=comm_targets,
            tools=tools,
            can_complete_tasks=True,  # supervisor can always complete tasks; param kept for compatibility
            enable_entrypoint=enable_entrypoint,
            enable_interswarm=enable_interswarm,
            tool_format="responses",
            exclude_tools=exclude_tools,
            **kwargs,
        )
        self.model = model
        # Retained for API compatibility; the delegate below constructs and
        # uses its own client.
        self.client = openai.AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        self.supervisor_fn = OpenAIResponsesAgentFunction(
            name=name,
            comm_targets=comm_targets,
            model=model,
            tools=self.tools,
            enable_entrypoint=enable_entrypoint,
            enable_interswarm=enable_interswarm,
            can_complete_tasks=True,  # supervisor can always complete tasks; param kept for compatibility
            tool_format="responses",
            exclude_tools=exclude_tools,
            **kwargs,
        )

    def __call__(
        self,
        messages: list[dict[str, Any]],
        tool_choice: str | dict[str, str] = "required",
    ) -> Awaitable[AgentOutput]:
        """
        Generate a response using the OpenAI API.

        Delegates to the wrapped agent function.
        """
        return self.supervisor_fn(messages, tool_choice)
|
mail/summarizer.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Task Summarizer using MAIL.
|
|
3
|
+
|
|
4
|
+
Generates short titles (max 6 words) for conversation tasks using Haiku.
|
|
5
|
+
Uses the breakpoint tool pattern for structured output.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
from typing import TYPE_CHECKING
|
|
12
|
+
|
|
13
|
+
from pydantic import BaseModel, Field
|
|
14
|
+
|
|
15
|
+
from mail import MAILAction, MAILAgentTemplate, MAILSwarmTemplate
|
|
16
|
+
from mail.factories import LiteLLMSupervisorFunction
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from mail import MAILSwarm
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# ============================================================================
|
|
23
|
+
# Tool Definition: submit_title
|
|
24
|
+
# ============================================================================
|
|
25
|
+
|
|
26
|
+
class SubmitTitleArgs(BaseModel):
    """Arguments for the submit_title tool."""
    # max_length=50 bounds the raw character count; the 6-word limit is
    # enforced only by the system prompt, not by validation.
    title: str = Field(
        description="A concise title for the conversation (maximum 6 words)",
        max_length=50,
    )
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
async def _submit_title_fn(args: dict) -> str:
|
|
35
|
+
"""
|
|
36
|
+
Submit a title. This is a breakpoint tool - execution pauses here
|
|
37
|
+
and the args are returned to the caller for processing.
|
|
38
|
+
"""
|
|
39
|
+
return json.dumps({"status": "submitted", "title": args.get("title", "")})
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Wrap the pydantic schema + async handler as a MAIL action; presumably the
# tool's JSON schema is derived from SubmitTitleArgs — confirm against
# MAILAction.from_pydantic_model.
submit_title_action = MAILAction.from_pydantic_model(
    model=SubmitTitleArgs,
    function=_submit_title_fn,
    name="submit_title",
    description="Submit a short title (max 6 words) summarizing the conversation.",
)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# ============================================================================
|
|
51
|
+
# System prompt
|
|
52
|
+
# ============================================================================
|
|
53
|
+
|
|
54
|
+
SYSTEM_PROMPT = """You are a title generator. Given a conversation between a user and an AI assistant, generate a concise title that captures the main topic or request.
|
|
55
|
+
|
|
56
|
+
Rules:
|
|
57
|
+
- Maximum 6 words
|
|
58
|
+
- Be specific and descriptive
|
|
59
|
+
- Use title case
|
|
60
|
+
- No quotes or punctuation at the end
|
|
61
|
+
- Focus on what the user wanted, not what the assistant did
|
|
62
|
+
|
|
63
|
+
Examples:
|
|
64
|
+
- "Weather Forecast for Tokyo"
|
|
65
|
+
- "Debug Python Import Error"
|
|
66
|
+
- "Explain Quantum Entanglement"
|
|
67
|
+
- "Create React Login Form"
|
|
68
|
+
"""
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# ============================================================================
|
|
72
|
+
# TaskSummarizer class
|
|
73
|
+
# ============================================================================
|
|
74
|
+
|
|
75
|
+
class TaskSummarizer:
|
|
76
|
+
"""
|
|
77
|
+
Generates short titles for conversation tasks.
|
|
78
|
+
|
|
79
|
+
Uses Haiku for fast, cheap summarization with structured output
|
|
80
|
+
via the breakpoint tool pattern.
|
|
81
|
+
|
|
82
|
+
Example:
|
|
83
|
+
summarizer = TaskSummarizer()
|
|
84
|
+
|
|
85
|
+
messages = [
|
|
86
|
+
{"role": "user", "content": "What's the weather in Tokyo?"},
|
|
87
|
+
{"role": "assistant", "content": "The weather in Tokyo is..."},
|
|
88
|
+
]
|
|
89
|
+
|
|
90
|
+
title = await summarizer.summarize(messages)
|
|
91
|
+
# Returns: "Tokyo Weather Forecast"
|
|
92
|
+
|
|
93
|
+
await summarizer.shutdown()
|
|
94
|
+
"""
|
|
95
|
+
|
|
96
|
+
def __init__(self, model: str = "anthropic/claude-haiku-4-5-20251001"):
|
|
97
|
+
self.model = model
|
|
98
|
+
self._swarm: MAILSwarm | None = None
|
|
99
|
+
self._template = self._create_template()
|
|
100
|
+
|
|
101
|
+
def _create_template(self) -> MAILSwarmTemplate:
|
|
102
|
+
"""Create the MAIL swarm template with submit_title as breakpoint tool."""
|
|
103
|
+
agent = MAILAgentTemplate(
|
|
104
|
+
name="summarizer",
|
|
105
|
+
factory=LiteLLMSupervisorFunction,
|
|
106
|
+
comm_targets=[],
|
|
107
|
+
actions=[submit_title_action],
|
|
108
|
+
agent_params={
|
|
109
|
+
"llm": self.model,
|
|
110
|
+
"system": SYSTEM_PROMPT,
|
|
111
|
+
"use_proxy": False,
|
|
112
|
+
},
|
|
113
|
+
enable_entrypoint=True,
|
|
114
|
+
can_complete_tasks=True,
|
|
115
|
+
tool_format="completions",
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
return MAILSwarmTemplate(
|
|
119
|
+
name="task_summarizer",
|
|
120
|
+
version="1.0.0",
|
|
121
|
+
agents=[agent],
|
|
122
|
+
actions=[submit_title_action],
|
|
123
|
+
entrypoint="summarizer",
|
|
124
|
+
breakpoint_tools=["submit_title"],
|
|
125
|
+
exclude_tools=["task_complete"], # Force use of submit_title breakpoint
|
|
126
|
+
)
|
|
127
|
+
|
|
128
|
+
async def _get_swarm(self) -> "MAILSwarm":
|
|
129
|
+
"""Get or create the swarm instance."""
|
|
130
|
+
if self._swarm is None:
|
|
131
|
+
self._swarm = self._template.instantiate(
|
|
132
|
+
instance_params={"user_token": "summarizer"},
|
|
133
|
+
user_id="summarizer_user",
|
|
134
|
+
)
|
|
135
|
+
return self._swarm
|
|
136
|
+
|
|
137
|
+
def _parse_title(self, response: dict) -> str | None:
|
|
138
|
+
"""Parse the title from the breakpoint tool call response."""
|
|
139
|
+
message = response.get("message", {})
|
|
140
|
+
subject = message.get("subject", "")
|
|
141
|
+
body = message.get("body", "")
|
|
142
|
+
|
|
143
|
+
if subject != "::breakpoint_tool_call::" or not body:
|
|
144
|
+
return None
|
|
145
|
+
|
|
146
|
+
try:
|
|
147
|
+
body_data = json.loads(body)
|
|
148
|
+
|
|
149
|
+
# Tool calls are standardized to OpenAI/LiteLLM format:
|
|
150
|
+
# [{"arguments": "{\"title\":\"...\"}", "name": "submit_title", "id": "..."}]
|
|
151
|
+
if isinstance(body_data, list):
|
|
152
|
+
for call in body_data:
|
|
153
|
+
if call.get("name") == "submit_title":
|
|
154
|
+
args = call.get("arguments", "{}")
|
|
155
|
+
if isinstance(args, str):
|
|
156
|
+
args = json.loads(args)
|
|
157
|
+
return args.get("title")
|
|
158
|
+
|
|
159
|
+
return None
|
|
160
|
+
except (json.JSONDecodeError, KeyError, TypeError):
|
|
161
|
+
return None
|
|
162
|
+
|
|
163
|
+
async def summarize(
|
|
164
|
+
self,
|
|
165
|
+
messages: list[dict],
|
|
166
|
+
max_messages: int = 10,
|
|
167
|
+
) -> str | None:
|
|
168
|
+
"""
|
|
169
|
+
Generate a title for a conversation.
|
|
170
|
+
|
|
171
|
+
Args:
|
|
172
|
+
messages: List of message dicts with 'role' and 'content' keys.
|
|
173
|
+
Only 'user' and 'assistant' roles are included.
|
|
174
|
+
max_messages: Maximum number of recent messages to include (default 10)
|
|
175
|
+
|
|
176
|
+
Returns:
|
|
177
|
+
A short title string, or None if generation failed.
|
|
178
|
+
"""
|
|
179
|
+
# Filter to user/assistant messages and take last N
|
|
180
|
+
filtered = [
|
|
181
|
+
m for m in messages
|
|
182
|
+
if m.get("role") in ("user", "assistant")
|
|
183
|
+
][-max_messages:]
|
|
184
|
+
|
|
185
|
+
if not filtered:
|
|
186
|
+
return None
|
|
187
|
+
|
|
188
|
+
# Format messages for the prompt
|
|
189
|
+
formatted = []
|
|
190
|
+
for msg in filtered:
|
|
191
|
+
role = msg["role"].upper()
|
|
192
|
+
content = msg.get("content", "")
|
|
193
|
+
# Truncate long messages
|
|
194
|
+
if len(content) > 500:
|
|
195
|
+
content = content[:500] + "..."
|
|
196
|
+
formatted.append(f"{role}: {content}")
|
|
197
|
+
|
|
198
|
+
prompt = "Generate a title for this conversation:\n\n" + "\n\n".join(formatted)
|
|
199
|
+
|
|
200
|
+
swarm = await self._get_swarm()
|
|
201
|
+
|
|
202
|
+
response, _ = await swarm.post_message_and_run(
|
|
203
|
+
body=prompt,
|
|
204
|
+
subject="Summarize",
|
|
205
|
+
show_events=False,
|
|
206
|
+
)
|
|
207
|
+
|
|
208
|
+
return self._parse_title(response) # type: ignore
|
|
209
|
+
|
|
210
|
+
async def shutdown(self):
|
|
211
|
+
"""Shutdown the swarm."""
|
|
212
|
+
if self._swarm is not None:
|
|
213
|
+
await self._swarm.shutdown()
|
|
214
|
+
self._swarm = None
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
async def summarize_task(messages: list[dict], max_messages: int = 10) -> str | None:
    """
    Generate a title for a conversation.

    A fresh swarm is built for every call and torn down afterwards, so
    concurrent requests never share state.

    Args:
        messages: List of message dicts with 'role' and 'content' keys
        max_messages: Maximum recent messages to include

    Returns:
        A short title string, or None if generation failed
    """
    worker = TaskSummarizer()
    try:
        title = await worker.summarize(messages, max_messages)
    finally:
        await worker.shutdown()
    return title
|