mail-swarms 1.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. mail/__init__.py +35 -0
  2. mail/api.py +1964 -0
  3. mail/cli.py +432 -0
  4. mail/client.py +1657 -0
  5. mail/config/__init__.py +8 -0
  6. mail/config/client.py +87 -0
  7. mail/config/server.py +165 -0
  8. mail/core/__init__.py +72 -0
  9. mail/core/actions.py +69 -0
  10. mail/core/agents.py +73 -0
  11. mail/core/message.py +366 -0
  12. mail/core/runtime.py +3537 -0
  13. mail/core/tasks.py +311 -0
  14. mail/core/tools.py +1206 -0
  15. mail/db/__init__.py +0 -0
  16. mail/db/init.py +182 -0
  17. mail/db/types.py +65 -0
  18. mail/db/utils.py +523 -0
  19. mail/examples/__init__.py +27 -0
  20. mail/examples/analyst_dummy/__init__.py +15 -0
  21. mail/examples/analyst_dummy/agent.py +136 -0
  22. mail/examples/analyst_dummy/prompts.py +44 -0
  23. mail/examples/consultant_dummy/__init__.py +15 -0
  24. mail/examples/consultant_dummy/agent.py +136 -0
  25. mail/examples/consultant_dummy/prompts.py +42 -0
  26. mail/examples/data_analysis/__init__.py +40 -0
  27. mail/examples/data_analysis/analyst/__init__.py +9 -0
  28. mail/examples/data_analysis/analyst/agent.py +67 -0
  29. mail/examples/data_analysis/analyst/prompts.py +53 -0
  30. mail/examples/data_analysis/processor/__init__.py +13 -0
  31. mail/examples/data_analysis/processor/actions.py +293 -0
  32. mail/examples/data_analysis/processor/agent.py +67 -0
  33. mail/examples/data_analysis/processor/prompts.py +48 -0
  34. mail/examples/data_analysis/reporter/__init__.py +10 -0
  35. mail/examples/data_analysis/reporter/actions.py +187 -0
  36. mail/examples/data_analysis/reporter/agent.py +67 -0
  37. mail/examples/data_analysis/reporter/prompts.py +49 -0
  38. mail/examples/data_analysis/statistics/__init__.py +18 -0
  39. mail/examples/data_analysis/statistics/actions.py +343 -0
  40. mail/examples/data_analysis/statistics/agent.py +67 -0
  41. mail/examples/data_analysis/statistics/prompts.py +60 -0
  42. mail/examples/mafia/__init__.py +0 -0
  43. mail/examples/mafia/game.py +1537 -0
  44. mail/examples/mafia/narrator_tools.py +396 -0
  45. mail/examples/mafia/personas.py +240 -0
  46. mail/examples/mafia/prompts.py +489 -0
  47. mail/examples/mafia/roles.py +147 -0
  48. mail/examples/mafia/spec.md +350 -0
  49. mail/examples/math_dummy/__init__.py +23 -0
  50. mail/examples/math_dummy/actions.py +252 -0
  51. mail/examples/math_dummy/agent.py +136 -0
  52. mail/examples/math_dummy/prompts.py +46 -0
  53. mail/examples/math_dummy/types.py +5 -0
  54. mail/examples/research/__init__.py +39 -0
  55. mail/examples/research/researcher/__init__.py +9 -0
  56. mail/examples/research/researcher/agent.py +67 -0
  57. mail/examples/research/researcher/prompts.py +54 -0
  58. mail/examples/research/searcher/__init__.py +10 -0
  59. mail/examples/research/searcher/actions.py +324 -0
  60. mail/examples/research/searcher/agent.py +67 -0
  61. mail/examples/research/searcher/prompts.py +53 -0
  62. mail/examples/research/summarizer/__init__.py +18 -0
  63. mail/examples/research/summarizer/actions.py +255 -0
  64. mail/examples/research/summarizer/agent.py +67 -0
  65. mail/examples/research/summarizer/prompts.py +55 -0
  66. mail/examples/research/verifier/__init__.py +10 -0
  67. mail/examples/research/verifier/actions.py +337 -0
  68. mail/examples/research/verifier/agent.py +67 -0
  69. mail/examples/research/verifier/prompts.py +52 -0
  70. mail/examples/supervisor/__init__.py +11 -0
  71. mail/examples/supervisor/agent.py +4 -0
  72. mail/examples/supervisor/prompts.py +93 -0
  73. mail/examples/support/__init__.py +33 -0
  74. mail/examples/support/classifier/__init__.py +10 -0
  75. mail/examples/support/classifier/actions.py +307 -0
  76. mail/examples/support/classifier/agent.py +68 -0
  77. mail/examples/support/classifier/prompts.py +56 -0
  78. mail/examples/support/coordinator/__init__.py +9 -0
  79. mail/examples/support/coordinator/agent.py +67 -0
  80. mail/examples/support/coordinator/prompts.py +48 -0
  81. mail/examples/support/faq/__init__.py +10 -0
  82. mail/examples/support/faq/actions.py +182 -0
  83. mail/examples/support/faq/agent.py +67 -0
  84. mail/examples/support/faq/prompts.py +42 -0
  85. mail/examples/support/sentiment/__init__.py +15 -0
  86. mail/examples/support/sentiment/actions.py +341 -0
  87. mail/examples/support/sentiment/agent.py +67 -0
  88. mail/examples/support/sentiment/prompts.py +54 -0
  89. mail/examples/weather_dummy/__init__.py +23 -0
  90. mail/examples/weather_dummy/actions.py +75 -0
  91. mail/examples/weather_dummy/agent.py +136 -0
  92. mail/examples/weather_dummy/prompts.py +35 -0
  93. mail/examples/weather_dummy/types.py +5 -0
  94. mail/factories/__init__.py +27 -0
  95. mail/factories/action.py +223 -0
  96. mail/factories/base.py +1531 -0
  97. mail/factories/supervisor.py +241 -0
  98. mail/net/__init__.py +7 -0
  99. mail/net/registry.py +712 -0
  100. mail/net/router.py +728 -0
  101. mail/net/server_utils.py +114 -0
  102. mail/net/types.py +247 -0
  103. mail/server.py +1605 -0
  104. mail/stdlib/__init__.py +0 -0
  105. mail/stdlib/anthropic/__init__.py +0 -0
  106. mail/stdlib/fs/__init__.py +15 -0
  107. mail/stdlib/fs/actions.py +209 -0
  108. mail/stdlib/http/__init__.py +19 -0
  109. mail/stdlib/http/actions.py +333 -0
  110. mail/stdlib/interswarm/__init__.py +11 -0
  111. mail/stdlib/interswarm/actions.py +208 -0
  112. mail/stdlib/mcp/__init__.py +19 -0
  113. mail/stdlib/mcp/actions.py +294 -0
  114. mail/stdlib/openai/__init__.py +13 -0
  115. mail/stdlib/openai/agents.py +451 -0
  116. mail/summarizer.py +234 -0
  117. mail/swarms_json/__init__.py +27 -0
  118. mail/swarms_json/types.py +87 -0
  119. mail/swarms_json/utils.py +255 -0
  120. mail/url_scheme.py +51 -0
  121. mail/utils/__init__.py +53 -0
  122. mail/utils/auth.py +194 -0
  123. mail/utils/context.py +17 -0
  124. mail/utils/logger.py +73 -0
  125. mail/utils/openai.py +212 -0
  126. mail/utils/parsing.py +89 -0
  127. mail/utils/serialize.py +292 -0
  128. mail/utils/store.py +49 -0
  129. mail/utils/string_builder.py +119 -0
  130. mail/utils/version.py +20 -0
  131. mail_swarms-1.3.2.dist-info/METADATA +237 -0
  132. mail_swarms-1.3.2.dist-info/RECORD +137 -0
  133. mail_swarms-1.3.2.dist-info/WHEEL +4 -0
  134. mail_swarms-1.3.2.dist-info/entry_points.txt +2 -0
  135. mail_swarms-1.3.2.dist-info/licenses/LICENSE +202 -0
  136. mail_swarms-1.3.2.dist-info/licenses/NOTICE +10 -0
  137. mail_swarms-1.3.2.dist-info/licenses/THIRD_PARTY_NOTICES.md +12334 -0
@@ -0,0 +1,451 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Copyright (c) 2025 Addison Kline
3
+
4
+ import json
5
+ import os
6
+ from collections.abc import Awaitable
7
+ from typing import Any, Literal
8
+ from uuid import uuid4
9
+
10
+ import openai
11
+ from openai.types.chat import (
12
+ ChatCompletionMessageParam,
13
+ ChatCompletionMessageToolCallUnion,
14
+ ChatCompletionToolUnionParam,
15
+ )
16
+ from openai.types.responses import ResponseInputParam, ToolParam
17
+
18
+ from mail.core.agents import AgentOutput
19
+ from mail.core.tools import AgentToolCall
20
+ from mail.factories.base import MAILAgentFunction
21
+ from mail.factories.supervisor import SupervisorFunction
22
+
23
+
24
class OpenAIChatCompletionsAgentFunction(MAILAgentFunction):
    """
    A MAIL agent function backed by the OpenAI Chat Completions API.

    Each call sends the conversation plus this agent's tools to
    ``client.chat.completions.create`` and returns ``(text, tool_calls)``.
    Tools are always sent in the Chat Completions shape; ``tool_format`` is
    accepted only so existing call sites keep working.
    """

    def __init__(
        self,
        name: str,
        comm_targets: list[str],
        model: str,
        tools: list[dict[str, Any]],
        enable_entrypoint: bool = False,
        enable_interswarm: bool = False,
        can_complete_tasks: bool = False,
        tool_format: Literal[
            "completions", "responses"
        ] = "completions",  # kept for compatibility; "completions" is always used
        exclude_tools: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            name: Agent name.
            comm_targets: Agents this agent may address.
            model: OpenAI model name passed to ``chat.completions.create``.
            tools: Tool definitions; ``_preprocess_tools`` reads each one's
                ``name``, ``description`` and ``parameters`` keys.
            enable_entrypoint: Forwarded to ``MAILAgentFunction``.
            enable_interswarm: Forwarded to ``MAILAgentFunction``.
            can_complete_tasks: Forwarded to ``MAILAgentFunction``.
            tool_format: Ignored; ``"completions"`` is always forwarded.
            exclude_tools: Tool names forwarded to ``MAILAgentFunction``;
                ``None`` (the default) means an empty list.
            **kwargs: Forwarded to ``MAILAgentFunction``.
        """
        super().__init__(
            name=name,
            comm_targets=comm_targets,
            tools=tools,
            enable_entrypoint=enable_entrypoint,
            enable_interswarm=enable_interswarm,
            can_complete_tasks=can_complete_tasks,
            tool_format="completions",
            # Fresh list per instance: a shared ``[]`` default would carry any
            # mutation of it over to every later instance.
            exclude_tools=exclude_tools if exclude_tools is not None else [],
            **kwargs,
        )
        self.model = model
        self.client = openai.AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    def __call__(
        self,
        messages: list[dict[str, Any]],
        tool_choice: str | dict[str, str] = "required",
    ) -> Awaitable[AgentOutput]:
        """
        Generate a chat completion using the OpenAI API.

        Returns an awaitable resolving to ``(text or None, tool_calls)``.
        """
        return self._run_chat_completion(messages, tool_choice)

    async def _run_chat_completion(
        self,
        messages: list[dict[str, Any]],
        tool_choice: str | dict[str, str] = "required",
    ) -> AgentOutput:
        """
        Run one chat completion and convert the first choice into agent output.
        """
        response = await self.client.chat.completions.create(  # type: ignore
            model=self.model,
            messages=self._preprocess_messages(messages),
            tool_choice=tool_choice,
            tools=self._preprocess_tools(),
        )
        message = response.choices[0].message
        tool_calls = self._postprocess_tool_calls(message)
        # Empty-string content is reported as None.
        return message.content or None, tool_calls

    def _preprocess_messages(
        self,
        messages: list[dict[str, Any]],
    ) -> list[ChatCompletionMessageParam]:
        """
        Reduce MAIL history dicts to the keys Chat Completions accepts.

        ``role`` and ``content`` are always emitted (``None`` when missing);
        ``name``, ``tool_calls`` and ``tool_call_id`` are copied only when
        present. Any other keys are dropped.
        """
        normalized: list[ChatCompletionMessageParam] = []
        for message in messages:
            entry: dict[str, Any] = {
                "role": message.get("role"),
                "content": message.get("content"),
            }
            for key in ("name", "tool_calls", "tool_call_id"):
                if key in message:
                    entry[key] = message[key]
            normalized.append(entry)  # type: ignore[arg-type]
        return normalized

    def _preprocess_tools(self) -> list[ChatCompletionToolUnionParam]:
        """
        Wrap each tool as a Chat Completions ``{"type": "function", "function": {...}}`` entry.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": tool["name"],
                    "description": tool["description"],
                    "parameters": tool["parameters"],
                },
            }
            for tool in self.tools
        ]

    @staticmethod
    def _parse_arguments(raw_args: Any) -> dict[str, Any]:
        """
        Decode a tool-call argument string into a dict.

        Input that is not valid JSON, or whose JSON is not an object
        (e.g. ``"42"`` or ``"[1, 2]"``), is wrapped as ``{"raw": raw_args}`` so
        callers always receive a mapping.
        """
        try:
            parsed = json.loads(raw_args)
        except (json.JSONDecodeError, TypeError):
            return {"raw": raw_args}
        return parsed if isinstance(parsed, dict) else {"raw": raw_args}

    def _postprocess_tool_calls(
        self,
        message: Any,
    ) -> list[AgentToolCall]:
        """
        Convert the tool calls on a Chat Completions message into ``AgentToolCall``s.

        Every returned call shares one ``completion`` dict: the assistant message
        dumped via ``model_dump`` with ``tool_calls`` rewritten to the plain
        ``function`` shape (custom tool calls included). Calls without a name are
        dropped, and a missing call id is replaced by a generated ``call_<uuid>``.
        """
        tool_calls: list[ChatCompletionMessageToolCallUnion] = list(
            getattr(message, "tool_calls", None) or []
        )
        if not tool_calls:
            return []

        call_records: list[tuple[str, str, str, dict[str, Any]]] = []
        for tool_call in tool_calls:
            call_id = getattr(tool_call, "id", None) or f"call_{uuid4()}"
            function_call = getattr(tool_call, "function", None)
            custom_call = getattr(tool_call, "custom", None)
            name = None
            raw_args = "{}"
            if function_call is not None:
                name = getattr(function_call, "name", None)
                raw_args = getattr(function_call, "arguments", "{}") or "{}"
            elif custom_call is not None:
                name = getattr(custom_call, "name", None)
                raw_args = getattr(custom_call, "input", "{}") or "{}"

            if not name:
                continue
            call_records.append(
                (call_id, name, raw_args, self._parse_arguments(raw_args))
            )

        message_dict = message.model_dump(exclude_none=False)
        message_dict["tool_calls"] = [
            {
                "id": call_id,
                "type": "function",
                "function": {"name": name, "arguments": raw_args},
            }
            for call_id, name, raw_args, _ in call_records
        ]

        return [
            AgentToolCall(
                tool_name=name,
                tool_args=parsed_args,
                tool_call_id=call_id,
                completion=message_dict,
            )
            for call_id, name, _, parsed_args in call_records
        ]
191
+
192
+
193
class OpenAIChatCompletionsSupervisorFunction(SupervisorFunction):
    """
    A MAIL supervisor function backed by the OpenAI Chat Completions API.

    Calls are delegated to an internal ``OpenAIChatCompletionsAgentFunction``
    built with the supervisor's final tool list (``self.tools``). Supervisors can
    always complete tasks, and the Chat Completions tool format is always used;
    ``can_complete_tasks`` and ``tool_format`` are accepted only for compatibility.
    """

    def __init__(
        self,
        name: str,
        comm_targets: list[str],
        model: str,
        tools: list[dict[str, Any]],
        can_complete_tasks: bool = True,
        enable_entrypoint: bool = False,
        enable_interswarm: bool = False,
        tool_format: Literal["completions", "responses"] = "responses",
        exclude_tools: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            name: Supervisor name.
            comm_targets: Agents this supervisor may address.
            model: OpenAI model name used by the delegate.
            tools: Tool definitions passed to ``SupervisorFunction``.
            can_complete_tasks: Ignored; always ``True``.
            enable_entrypoint: Forwarded to the base class and the delegate.
            enable_interswarm: Forwarded to the base class and the delegate.
            tool_format: Ignored; ``"completions"`` is always used.
            exclude_tools: Tool names to exclude; ``None`` (the default) means
                an empty list.
            **kwargs: Forwarded to the base class and the delegate.
        """
        # Fresh list per instance instead of a shared mutable default.
        if exclude_tools is None:
            exclude_tools = []
        super().__init__(
            name=name,
            comm_targets=comm_targets,
            tools=tools,
            can_complete_tasks=True,  # supervisor can always complete tasks; param kept for compatibility
            enable_entrypoint=enable_entrypoint,
            enable_interswarm=enable_interswarm,
            tool_format="completions",
            exclude_tools=exclude_tools,
            **kwargs,
        )
        self.model = model
        self.supervisor_fn = OpenAIChatCompletionsAgentFunction(
            name=name,
            comm_targets=comm_targets,
            model=model,
            tools=self.tools,
            enable_entrypoint=enable_entrypoint,
            enable_interswarm=enable_interswarm,
            can_complete_tasks=True,  # supervisor can always complete tasks; param kept for compatibility
            tool_format="completions",
            exclude_tools=exclude_tools,
            **kwargs,
        )
        # Share the delegate's client rather than opening a second, unused one;
        # the attribute is kept for callers that read ``self.client``.
        self.client = self.supervisor_fn.client

    def __call__(
        self,
        messages: list[dict[str, Any]],
        tool_choice: str | dict[str, str] = "required",
    ) -> Awaitable[AgentOutput]:
        """
        Generate a chat completion using the OpenAI API via the delegate agent function.
        """
        return self.supervisor_fn(messages, tool_choice)
246
+
247
+
248
class OpenAIResponsesAgentFunction(MAILAgentFunction):
    """
    A MAIL agent function backed by the OpenAI Responses API.

    Each call sends the conversation plus this agent's tools to
    ``client.responses.create`` and returns ``(text, tool_calls)``. The Responses
    tool format is always used; ``tool_format`` is accepted only for compatibility.
    """

    def __init__(
        self,
        name: str,
        comm_targets: list[str],
        model: str,
        tools: list[dict[str, Any]],
        enable_entrypoint: bool = False,
        enable_interswarm: bool = False,
        can_complete_tasks: bool = False,
        tool_format: Literal[
            "completions", "responses"
        ] = "responses",  # kept for compatibility; "responses" is always used
        exclude_tools: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            name: Agent name.
            comm_targets: Agents this agent may address.
            model: OpenAI model name passed to ``responses.create``.
            tools: Tool definitions; ``_preprocess_tools`` reads each one's
                ``name`` and, optionally, ``description``, ``parameters`` and
                ``strict`` keys.
            enable_entrypoint: Forwarded to ``MAILAgentFunction``.
            enable_interswarm: Forwarded to ``MAILAgentFunction``.
            can_complete_tasks: Forwarded to ``MAILAgentFunction``.
            tool_format: Ignored; ``"responses"`` is always forwarded.
            exclude_tools: Tool names forwarded to ``MAILAgentFunction``;
                ``None`` (the default) means an empty list.
            **kwargs: Forwarded to ``MAILAgentFunction``.
        """
        super().__init__(
            name=name,
            comm_targets=comm_targets,
            tools=tools,
            enable_entrypoint=enable_entrypoint,
            enable_interswarm=enable_interswarm,
            can_complete_tasks=can_complete_tasks,
            tool_format="responses",
            # Fresh list per instance instead of a shared mutable default.
            exclude_tools=exclude_tools if exclude_tools is not None else [],
            **kwargs,
        )
        self.model = model
        self.client = openai.AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    def __call__(
        self,
        messages: list[dict[str, Any]],
        tool_choice: str | dict[str, str] = "required",
    ) -> Awaitable[AgentOutput]:
        """
        Generate a response using the OpenAI API.

        Returns an awaitable resolving to ``(text or None, tool_calls)``.
        """
        return self._run_response(messages, tool_choice)

    async def _run_response(
        self,
        messages: list[dict[str, Any]],
        tool_choice: str | dict[str, str] = "required",
    ) -> AgentOutput:
        """
        Call ``client.responses.create`` and split the output into text and tool calls.

        Text is the newline-joined ``output_text``/``text`` parts of ``message``
        output items (``None`` when there are none). Each named
        ``function_call``/``tool_call``/``custom_tool_call`` item becomes an
        ``AgentToolCall`` whose ``responses`` field holds a copy of the whole
        output list with ``output_text`` content parts relabelled ``text``.
        """
        response = await self.client.responses.create(  # type: ignore
            model=self.model,
            input=self._preprocess_messages(messages),
            tool_choice=tool_choice,
            tools=self._preprocess_tools(),
        )
        outputs: list[dict[str, Any]] = response.model_dump().get("output") or []
        # Deep-copy outputs so we can normalise in place without mutating the originals.
        normalized_outputs: list[dict[str, Any]] = json.loads(json.dumps(outputs))
        text_segments: list[str] = []
        call_records: list[tuple[str, str, dict[str, Any]]] = []

        for block in outputs:
            block_type = block.get("type")
            if block_type == "message":
                for content in block.get("content") or []:
                    if content.get("type") in {"output_text", "text"}:
                        text_segments.append(content.get("text", ""))
            elif block_type in {"custom_tool_call", "tool_call", "function_call"}:
                # ``model_dump`` may emit keys whose value is None, so fall back
                # to ``{}`` on falsy values rather than only on missing keys.
                name = (
                    block.get("name")
                    or (block.get("tool") or {}).get("name")
                    or (block.get("function") or {}).get("name")
                )
                if not name:
                    continue
                call_id = block.get("call_id") or block.get("id") or f"call_{uuid4()}"
                raw_input = (
                    block.get("input")
                    or block.get("tool_input")
                    or block.get("arguments")
                    or "{}"
                )
                call_records.append((call_id, name, self._parse_arguments(raw_input)))

        for block in normalized_outputs:
            if block.get("type") == "message":
                for content in block.get("content") or []:
                    if content.get("type") == "output_text":
                        content["type"] = "text"

        agent_tool_calls = [
            AgentToolCall(
                tool_name=name,
                tool_args=tool_args,
                tool_call_id=call_id,
                responses=normalized_outputs,
            )
            for call_id, name, tool_args in call_records
        ]

        response_text = "\n".join(filter(None, text_segments)) or None
        return response_text, agent_tool_calls

    @staticmethod
    def _parse_arguments(raw_input: Any) -> dict[str, Any]:
        """
        Turn a tool call's raw input into a dict.

        Dicts are returned as-is. Strings are JSON-decoded; anything that fails to
        decode, or decodes to a non-object (e.g. ``"[1, 2]"``), is wrapped as
        ``{"raw": raw_input}`` so callers always receive a mapping.
        """
        if isinstance(raw_input, dict):
            return raw_input
        try:
            parsed = json.loads(raw_input)
        except (json.JSONDecodeError, TypeError):
            return {"raw": raw_input}
        return parsed if isinstance(parsed, dict) else {"raw": raw_input}

    def _preprocess_messages(
        self,
        messages: list[dict[str, Any]],
    ) -> ResponseInputParam:
        """
        Convert MAIL history dicts into Responses API input items.

        Messages with a ``role`` are reduced to ``role``/``content`` plus
        ``name``/``tool_call_id`` when present. Typed items that carry no
        ``role`` (e.g. ``function_call`` / ``function_call_output``) are forwarded
        unchanged. Rebuilding them as role/content dicts would yield
        ``{"role": None, "content": None}``, which is not a valid input item.
        """
        normalized: list[dict[str, Any]] = []
        for message in messages:
            if "role" not in message and "type" in message:
                normalized.append(dict(message))
                continue
            entry: dict[str, Any] = {
                "role": message.get("role"),
                "content": message.get("content"),
            }
            for key in ("name", "tool_call_id"):
                if key in message:
                    entry[key] = message[key]
            normalized.append(entry)
        return normalized  # type: ignore[return-value]

    def _preprocess_tools(self) -> list[ToolParam]:
        """
        Convert ``self.tools`` into Responses API function tools.

        Unlike Chat Completions, the Responses API takes function tools in a
        flat shape (``{"type": "function", "name": ..., "parameters": ...}``),
        not nested under a ``"function"`` key. ``strict`` is set explicitly
        (default ``False``, i.e. the non-strict behaviour Chat Completions uses
        when it is omitted), because the Responses API documents strict mode as
        the default.
        """
        return [
            {  # type: ignore
                "type": "function",
                "name": tool["name"],
                "description": tool.get("description"),
                "parameters": tool.get("parameters"),
                "strict": tool.get("strict", False),
            }
            for tool in self.tools
        ]
397
+
398
+
399
class OpenAIResponsesSupervisorFunction(SupervisorFunction):
    """
    A MAIL supervisor function backed by the OpenAI Responses API.

    Calls are delegated to an internal ``OpenAIResponsesAgentFunction`` built with
    the supervisor's final tool list (``self.tools``). Supervisors can always
    complete tasks, and the Responses tool format is always used;
    ``can_complete_tasks`` and ``tool_format`` are accepted only for compatibility.
    """

    def __init__(
        self,
        name: str,
        comm_targets: list[str],
        model: str,
        tools: list[dict[str, Any]],
        can_complete_tasks: bool = True,
        enable_entrypoint: bool = False,
        enable_interswarm: bool = False,
        tool_format: Literal["completions", "responses"] = "responses",
        exclude_tools: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            name: Supervisor name.
            comm_targets: Agents this supervisor may address.
            model: OpenAI model name used by the delegate.
            tools: Tool definitions passed to ``SupervisorFunction``.
            can_complete_tasks: Ignored; always ``True``.
            enable_entrypoint: Forwarded to the base class and the delegate.
            enable_interswarm: Forwarded to the base class and the delegate.
            tool_format: Ignored; ``"responses"`` is always used.
            exclude_tools: Tool names to exclude; ``None`` (the default) means
                an empty list.
            **kwargs: Forwarded to the base class and the delegate.
        """
        # Fresh list per instance instead of a shared mutable default.
        if exclude_tools is None:
            exclude_tools = []
        super().__init__(
            name=name,
            comm_targets=comm_targets,
            tools=tools,
            can_complete_tasks=True,  # supervisor can always complete tasks; param kept for compatibility
            enable_entrypoint=enable_entrypoint,
            enable_interswarm=enable_interswarm,
            tool_format="responses",
            exclude_tools=exclude_tools,
            **kwargs,
        )
        self.model = model
        self.supervisor_fn = OpenAIResponsesAgentFunction(
            name=name,
            comm_targets=comm_targets,
            model=model,
            tools=self.tools,
            enable_entrypoint=enable_entrypoint,
            enable_interswarm=enable_interswarm,
            can_complete_tasks=True,  # supervisor can always complete tasks; param kept for compatibility
            tool_format="responses",
            exclude_tools=exclude_tools,
            **kwargs,
        )
        # Share the delegate's client rather than opening a second, unused one;
        # the attribute is kept for callers that read ``self.client``.
        self.client = self.supervisor_fn.client

    def __call__(
        self,
        messages: list[dict[str, Any]],
        tool_choice: str | dict[str, str] = "required",
    ) -> Awaitable[AgentOutput]:
        """
        Generate a response using the OpenAI API via the delegate agent function.
        """
        return self.supervisor_fn(messages, tool_choice)
mail/summarizer.py ADDED
@@ -0,0 +1,234 @@
1
+ """
2
+ Task Summarizer using MAIL.
3
+
4
+ Generates short titles (max 6 words) for conversation tasks using Haiku.
5
+ Uses the breakpoint tool pattern for structured output.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import json
11
+ from typing import TYPE_CHECKING
12
+
13
+ from pydantic import BaseModel, Field
14
+
15
+ from mail import MAILAction, MAILAgentTemplate, MAILSwarmTemplate
16
+ from mail.factories import LiteLLMSupervisorFunction
17
+
18
+ if TYPE_CHECKING:
19
+ from mail import MAILSwarm
20
+
21
+
22
+ # ============================================================================
23
+ # Tool Definition: submit_title
24
+ # ============================================================================
25
+
26
class SubmitTitleArgs(BaseModel):
    """Schema for the ``submit_title`` tool call: a single required ``title``."""

    # No default, so the field is required. The schema caps it at 50 characters,
    # while the description asks the model for at most 6 words.
    title: str = Field(
        ...,
        max_length=50,
        description="A concise title for the conversation (maximum 6 words)",
    )
32
+
33
+
34
async def _submit_title_fn(args: dict) -> str:
    """
    Handler for the ``submit_title`` tool.

    The summarizer swarm template registers ``submit_title`` as a breakpoint
    tool. This handler only echoes the submitted title back as a JSON status
    payload, using an empty string when ``title`` is absent.
    """
    payload = {"status": "submitted", "title": args.get("title", "")}
    return json.dumps(payload)
40
+
41
+
42
# The ``submit_title`` MAIL action: ``SubmitTitleArgs`` supplies the argument
# schema and ``_submit_title_fn`` handles the call.
submit_title_action = MAILAction.from_pydantic_model(
    name="submit_title",
    description="Submit a short title (max 6 words) summarizing the conversation.",
    model=SubmitTitleArgs,
    function=_submit_title_fn,
)
48
+
49
+
50
+ # ============================================================================
51
+ # System prompt
52
+ # ============================================================================
53
+
54
# System prompt for the title-generator agent, passed as the ``system`` agent
# parameter in ``TaskSummarizer._create_template``. Its "Maximum 6 words" rule
# matches the ``submit_title`` tool description. The text is sent to the model
# verbatim, so any edit changes model input.
SYSTEM_PROMPT = """You are a title generator. Given a conversation between a user and an AI assistant, generate a concise title that captures the main topic or request.

Rules:
- Maximum 6 words
- Be specific and descriptive
- Use title case
- No quotes or punctuation at the end
- Focus on what the user wanted, not what the assistant did

Examples:
- "Weather Forecast for Tokyo"
- "Debug Python Import Error"
- "Explain Quantum Entanglement"
- "Create React Login Form"
"""
69
+
70
+
71
+ # ============================================================================
72
+ # TaskSummarizer class
73
+ # ============================================================================
74
+
75
+ class TaskSummarizer:
76
+ """
77
+ Generates short titles for conversation tasks.
78
+
79
+ Uses Haiku for fast, cheap summarization with structured output
80
+ via the breakpoint tool pattern.
81
+
82
+ Example:
83
+ summarizer = TaskSummarizer()
84
+
85
+ messages = [
86
+ {"role": "user", "content": "What's the weather in Tokyo?"},
87
+ {"role": "assistant", "content": "The weather in Tokyo is..."},
88
+ ]
89
+
90
+ title = await summarizer.summarize(messages)
91
+ # Returns: "Tokyo Weather Forecast"
92
+
93
+ await summarizer.shutdown()
94
+ """
95
+
96
+ def __init__(self, model: str = "anthropic/claude-haiku-4-5-20251001"):
97
+ self.model = model
98
+ self._swarm: MAILSwarm | None = None
99
+ self._template = self._create_template()
100
+
101
+ def _create_template(self) -> MAILSwarmTemplate:
102
+ """Create the MAIL swarm template with submit_title as breakpoint tool."""
103
+ agent = MAILAgentTemplate(
104
+ name="summarizer",
105
+ factory=LiteLLMSupervisorFunction,
106
+ comm_targets=[],
107
+ actions=[submit_title_action],
108
+ agent_params={
109
+ "llm": self.model,
110
+ "system": SYSTEM_PROMPT,
111
+ "use_proxy": False,
112
+ },
113
+ enable_entrypoint=True,
114
+ can_complete_tasks=True,
115
+ tool_format="completions",
116
+ )
117
+
118
+ return MAILSwarmTemplate(
119
+ name="task_summarizer",
120
+ version="1.0.0",
121
+ agents=[agent],
122
+ actions=[submit_title_action],
123
+ entrypoint="summarizer",
124
+ breakpoint_tools=["submit_title"],
125
+ exclude_tools=["task_complete"], # Force use of submit_title breakpoint
126
+ )
127
+
128
+ async def _get_swarm(self) -> "MAILSwarm":
129
+ """Get or create the swarm instance."""
130
+ if self._swarm is None:
131
+ self._swarm = self._template.instantiate(
132
+ instance_params={"user_token": "summarizer"},
133
+ user_id="summarizer_user",
134
+ )
135
+ return self._swarm
136
+
137
+ def _parse_title(self, response: dict) -> str | None:
138
+ """Parse the title from the breakpoint tool call response."""
139
+ message = response.get("message", {})
140
+ subject = message.get("subject", "")
141
+ body = message.get("body", "")
142
+
143
+ if subject != "::breakpoint_tool_call::" or not body:
144
+ return None
145
+
146
+ try:
147
+ body_data = json.loads(body)
148
+
149
+ # Tool calls are standardized to OpenAI/LiteLLM format:
150
+ # [{"arguments": "{\"title\":\"...\"}", "name": "submit_title", "id": "..."}]
151
+ if isinstance(body_data, list):
152
+ for call in body_data:
153
+ if call.get("name") == "submit_title":
154
+ args = call.get("arguments", "{}")
155
+ if isinstance(args, str):
156
+ args = json.loads(args)
157
+ return args.get("title")
158
+
159
+ return None
160
+ except (json.JSONDecodeError, KeyError, TypeError):
161
+ return None
162
+
163
+ async def summarize(
164
+ self,
165
+ messages: list[dict],
166
+ max_messages: int = 10,
167
+ ) -> str | None:
168
+ """
169
+ Generate a title for a conversation.
170
+
171
+ Args:
172
+ messages: List of message dicts with 'role' and 'content' keys.
173
+ Only 'user' and 'assistant' roles are included.
174
+ max_messages: Maximum number of recent messages to include (default 10)
175
+
176
+ Returns:
177
+ A short title string, or None if generation failed.
178
+ """
179
+ # Filter to user/assistant messages and take last N
180
+ filtered = [
181
+ m for m in messages
182
+ if m.get("role") in ("user", "assistant")
183
+ ][-max_messages:]
184
+
185
+ if not filtered:
186
+ return None
187
+
188
+ # Format messages for the prompt
189
+ formatted = []
190
+ for msg in filtered:
191
+ role = msg["role"].upper()
192
+ content = msg.get("content", "")
193
+ # Truncate long messages
194
+ if len(content) > 500:
195
+ content = content[:500] + "..."
196
+ formatted.append(f"{role}: {content}")
197
+
198
+ prompt = "Generate a title for this conversation:\n\n" + "\n\n".join(formatted)
199
+
200
+ swarm = await self._get_swarm()
201
+
202
+ response, _ = await swarm.post_message_and_run(
203
+ body=prompt,
204
+ subject="Summarize",
205
+ show_events=False,
206
+ )
207
+
208
+ return self._parse_title(response) # type: ignore
209
+
210
+ async def shutdown(self):
211
+ """Shutdown the swarm."""
212
+ if self._swarm is not None:
213
+ await self._swarm.shutdown()
214
+ self._swarm = None
215
+
216
+
217
async def summarize_task(messages: list[dict], max_messages: int = 10) -> str | None:
    """
    Generate a title for a conversation using a one-off ``TaskSummarizer``.

    Each call builds its own summarizer (and so its own swarm), so concurrent
    callers never share swarm state. The summarizer is always shut down, even
    when summarization raises.

    Args:
        messages: List of message dicts with 'role' and 'content' keys
        max_messages: Maximum recent messages to include

    Returns:
        A short title string, or None if generation failed
    """
    one_shot = TaskSummarizer()
    try:
        title = await one_shot.summarize(messages, max_messages)
    finally:
        await one_shot.shutdown()
    return title