mail-swarms 1.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mail/__init__.py +35 -0
- mail/api.py +1964 -0
- mail/cli.py +432 -0
- mail/client.py +1657 -0
- mail/config/__init__.py +8 -0
- mail/config/client.py +87 -0
- mail/config/server.py +165 -0
- mail/core/__init__.py +72 -0
- mail/core/actions.py +69 -0
- mail/core/agents.py +73 -0
- mail/core/message.py +366 -0
- mail/core/runtime.py +3537 -0
- mail/core/tasks.py +311 -0
- mail/core/tools.py +1206 -0
- mail/db/__init__.py +0 -0
- mail/db/init.py +182 -0
- mail/db/types.py +65 -0
- mail/db/utils.py +523 -0
- mail/examples/__init__.py +27 -0
- mail/examples/analyst_dummy/__init__.py +15 -0
- mail/examples/analyst_dummy/agent.py +136 -0
- mail/examples/analyst_dummy/prompts.py +44 -0
- mail/examples/consultant_dummy/__init__.py +15 -0
- mail/examples/consultant_dummy/agent.py +136 -0
- mail/examples/consultant_dummy/prompts.py +42 -0
- mail/examples/data_analysis/__init__.py +40 -0
- mail/examples/data_analysis/analyst/__init__.py +9 -0
- mail/examples/data_analysis/analyst/agent.py +67 -0
- mail/examples/data_analysis/analyst/prompts.py +53 -0
- mail/examples/data_analysis/processor/__init__.py +13 -0
- mail/examples/data_analysis/processor/actions.py +293 -0
- mail/examples/data_analysis/processor/agent.py +67 -0
- mail/examples/data_analysis/processor/prompts.py +48 -0
- mail/examples/data_analysis/reporter/__init__.py +10 -0
- mail/examples/data_analysis/reporter/actions.py +187 -0
- mail/examples/data_analysis/reporter/agent.py +67 -0
- mail/examples/data_analysis/reporter/prompts.py +49 -0
- mail/examples/data_analysis/statistics/__init__.py +18 -0
- mail/examples/data_analysis/statistics/actions.py +343 -0
- mail/examples/data_analysis/statistics/agent.py +67 -0
- mail/examples/data_analysis/statistics/prompts.py +60 -0
- mail/examples/mafia/__init__.py +0 -0
- mail/examples/mafia/game.py +1537 -0
- mail/examples/mafia/narrator_tools.py +396 -0
- mail/examples/mafia/personas.py +240 -0
- mail/examples/mafia/prompts.py +489 -0
- mail/examples/mafia/roles.py +147 -0
- mail/examples/mafia/spec.md +350 -0
- mail/examples/math_dummy/__init__.py +23 -0
- mail/examples/math_dummy/actions.py +252 -0
- mail/examples/math_dummy/agent.py +136 -0
- mail/examples/math_dummy/prompts.py +46 -0
- mail/examples/math_dummy/types.py +5 -0
- mail/examples/research/__init__.py +39 -0
- mail/examples/research/researcher/__init__.py +9 -0
- mail/examples/research/researcher/agent.py +67 -0
- mail/examples/research/researcher/prompts.py +54 -0
- mail/examples/research/searcher/__init__.py +10 -0
- mail/examples/research/searcher/actions.py +324 -0
- mail/examples/research/searcher/agent.py +67 -0
- mail/examples/research/searcher/prompts.py +53 -0
- mail/examples/research/summarizer/__init__.py +18 -0
- mail/examples/research/summarizer/actions.py +255 -0
- mail/examples/research/summarizer/agent.py +67 -0
- mail/examples/research/summarizer/prompts.py +55 -0
- mail/examples/research/verifier/__init__.py +10 -0
- mail/examples/research/verifier/actions.py +337 -0
- mail/examples/research/verifier/agent.py +67 -0
- mail/examples/research/verifier/prompts.py +52 -0
- mail/examples/supervisor/__init__.py +11 -0
- mail/examples/supervisor/agent.py +4 -0
- mail/examples/supervisor/prompts.py +93 -0
- mail/examples/support/__init__.py +33 -0
- mail/examples/support/classifier/__init__.py +10 -0
- mail/examples/support/classifier/actions.py +307 -0
- mail/examples/support/classifier/agent.py +68 -0
- mail/examples/support/classifier/prompts.py +56 -0
- mail/examples/support/coordinator/__init__.py +9 -0
- mail/examples/support/coordinator/agent.py +67 -0
- mail/examples/support/coordinator/prompts.py +48 -0
- mail/examples/support/faq/__init__.py +10 -0
- mail/examples/support/faq/actions.py +182 -0
- mail/examples/support/faq/agent.py +67 -0
- mail/examples/support/faq/prompts.py +42 -0
- mail/examples/support/sentiment/__init__.py +15 -0
- mail/examples/support/sentiment/actions.py +341 -0
- mail/examples/support/sentiment/agent.py +67 -0
- mail/examples/support/sentiment/prompts.py +54 -0
- mail/examples/weather_dummy/__init__.py +23 -0
- mail/examples/weather_dummy/actions.py +75 -0
- mail/examples/weather_dummy/agent.py +136 -0
- mail/examples/weather_dummy/prompts.py +35 -0
- mail/examples/weather_dummy/types.py +5 -0
- mail/factories/__init__.py +27 -0
- mail/factories/action.py +223 -0
- mail/factories/base.py +1531 -0
- mail/factories/supervisor.py +241 -0
- mail/net/__init__.py +7 -0
- mail/net/registry.py +712 -0
- mail/net/router.py +728 -0
- mail/net/server_utils.py +114 -0
- mail/net/types.py +247 -0
- mail/server.py +1605 -0
- mail/stdlib/__init__.py +0 -0
- mail/stdlib/anthropic/__init__.py +0 -0
- mail/stdlib/fs/__init__.py +15 -0
- mail/stdlib/fs/actions.py +209 -0
- mail/stdlib/http/__init__.py +19 -0
- mail/stdlib/http/actions.py +333 -0
- mail/stdlib/interswarm/__init__.py +11 -0
- mail/stdlib/interswarm/actions.py +208 -0
- mail/stdlib/mcp/__init__.py +19 -0
- mail/stdlib/mcp/actions.py +294 -0
- mail/stdlib/openai/__init__.py +13 -0
- mail/stdlib/openai/agents.py +451 -0
- mail/summarizer.py +234 -0
- mail/swarms_json/__init__.py +27 -0
- mail/swarms_json/types.py +87 -0
- mail/swarms_json/utils.py +255 -0
- mail/url_scheme.py +51 -0
- mail/utils/__init__.py +53 -0
- mail/utils/auth.py +194 -0
- mail/utils/context.py +17 -0
- mail/utils/logger.py +73 -0
- mail/utils/openai.py +212 -0
- mail/utils/parsing.py +89 -0
- mail/utils/serialize.py +292 -0
- mail/utils/store.py +49 -0
- mail/utils/string_builder.py +119 -0
- mail/utils/version.py +20 -0
- mail_swarms-1.3.2.dist-info/METADATA +237 -0
- mail_swarms-1.3.2.dist-info/RECORD +137 -0
- mail_swarms-1.3.2.dist-info/WHEEL +4 -0
- mail_swarms-1.3.2.dist-info/entry_points.txt +2 -0
- mail_swarms-1.3.2.dist-info/licenses/LICENSE +202 -0
- mail_swarms-1.3.2.dist-info/licenses/NOTICE +10 -0
- mail_swarms-1.3.2.dist-info/licenses/THIRD_PARTY_NOTICES.md +12334 -0
mail/db/utils.py
ADDED
|
@@ -0,0 +1,523 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
# Copyright (c) 2025 Addison Kline
|
|
3
|
+
|
|
4
|
+
import asyncio
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
from typing import Any, Literal
|
|
8
|
+
|
|
9
|
+
import asyncpg
|
|
10
|
+
import dotenv
|
|
11
|
+
|
|
12
|
+
from mail.db.types import AgentHistoriesDB
|
|
13
|
+
|
|
14
|
+
# Logger shared by every database helper in this module.
logger = logging.getLogger("mail.db")

# Global connection pool. It starts as None, is created by get_pool(), and is
# reset to None by close_pool() and by the reconnect path in _db_execute().
_pool: asyncpg.Pool | None = None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
async def get_pool() -> asyncpg.Pool:
    """
    Get or create the global connection pool.

    On first use the environment is loaded with `dotenv.load_dotenv()` and a
    pool is created from the `DATABASE_URL` environment variable. Subsequent
    calls return the cached pool.

    Returns:
        The shared `asyncpg.Pool`.

    Raises:
        ValueError: If `DATABASE_URL` is not set.
    """
    global _pool

    if _pool is not None:
        return _pool

    from urllib.parse import urlsplit

    dotenv.load_dotenv()
    database_url = os.getenv("DATABASE_URL")
    if database_url is None:
        raise ValueError("DATABASE_URL is not set")

    # Never write the raw DSN to the log: it normally embeds the user name and
    # password (and may carry secrets in its query string). Log only the
    # scheme, host, port and database path.
    try:
        parts = urlsplit(database_url)
        host = parts.hostname or "<unknown host>"
        port = f":{parts.port}" if parts.port else ""
        safe_target = f"{parts.scheme}://{host}{port}{parts.path}"
    except ValueError:
        # urlsplit / .port raise ValueError on malformed hosts or ports
        safe_target = "<unparseable DATABASE_URL>"

    logger.info(f"creating new connection pool to {safe_target}")
    _pool = await asyncpg.create_pool(
        database_url,
        min_size=5,
        max_size=20,
        command_timeout=60,
        server_settings={"application_name": "mail-server"},
    )
    logger.info("connection pool created")

    return _pool
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
async def close_pool() -> None:
    """
    Shut down the global connection pool if one is open.

    After the pool has been closed the module-level reference is reset to
    `None`. When no pool exists the call only logs that fact.
    """
    global _pool

    # guard clause: nothing to close
    if _pool is None:
        logger.info("connection pool already closed")
        return

    logger.info("closing connection pool")
    await _pool.close()
    _pool = None
    logger.info("connection pool closed")
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
async def _db_execute(
    query: str,
    max_retries: int = 3,
    retry_delay: float = 1.0,
) -> Any:
    """
    Run `query` through the pool and return the fetched rows, retrying on failure.

    Args:
        query: SQL text passed to `connection.fetch`; no parameters are bound.
        max_retries: Total number of attempts before giving up.
        retry_delay: Seconds slept between consecutive attempts.

    Returns:
        Whatever `connection.fetch` returns for the query.

    Raises:
        Exception: The error raised by the final attempt is re-raised as-is.
        RuntimeError: If the loop ends without returning or raising, which
            only happens when `max_retries` is less than 1.
    """
    global _pool

    pool = await get_pool()
    final_attempt = max_retries - 1

    for attempt in range(max_retries):
        out_of_retries = attempt >= final_attempt
        try:
            async with pool.acquire() as connection:
                return await connection.fetch(query)
        except asyncpg.ConnectionDoesNotExistError:
            if out_of_retries:
                logger.error(
                    f"failed to reconnect to database after {max_retries} attempts"
                )
                raise
            logger.warning(
                f"database connection lost, retrying... ({attempt + 1}/{max_retries})"
            )
            # forget the cached pool so that get_pool() builds a fresh one
            # NOTE(review): the discarded pool is not closed here - confirm
            # whether that is intentional.
            _pool = None
            await asyncio.sleep(retry_delay)
            pool = await get_pool()
        except asyncpg.ConnectionFailureError as e:
            if out_of_retries:
                logger.error(
                    f"failed to reconnect to database after {max_retries} attempts: {e}"
                )
                raise
            logger.warning(
                f"database connection failure (attempt {attempt + 1}/{max_retries}): {e}"
            )
            await asyncio.sleep(retry_delay)
        except Exception as e:
            # every other error is retried as well, on the same pool
            if out_of_retries:
                logger.error(
                    f"failed to execute query after {max_retries} attempts: {e}"
                )
                raise
            logger.warning(
                f"database query failed (attempt {attempt + 1}/{max_retries}): {e}"
            )
            await asyncio.sleep(retry_delay)

    raise RuntimeError(f"failed to execute query after {max_retries} attempts")
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
async def create_agent_history(
    swarm_name: str,
    caller_role: Literal["admin", "agent", "user"],
    caller_id: str,
    tool_format: Literal["completions", "responses"],
    task_id: str,
    agent_name: str,
    history: list[dict[str, Any]],
) -> None:
    """
    Insert one row into `agent_histories`.

    The `history` list is serialized with `json.dumps` before it is bound to
    the statement; every other argument is bound unchanged.
    """
    import json

    pool = await get_pool()
    insert_sql = """
        INSERT INTO agent_histories (swarm_name, caller_role, caller_id, tool_format, task_id, agent_name, history)
        VALUES ($1, $2, $3, $4, $5, $6, $7)
    """

    # positional values in the same order as the column list above
    values = (
        swarm_name,
        caller_role,
        caller_id,
        tool_format,
        task_id,
        agent_name,
        json.dumps(history),
    )

    async with pool.acquire() as connection:
        await connection.execute(insert_sql, *values)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
async def load_agent_histories(
    swarm_name: str,
    caller_role: Literal["admin", "agent", "user"],
    caller_id: str,
) -> dict[str, list[dict[str, Any]]]:
    """
    Fetch the stored agent histories for one swarm / caller role / caller ID.

    Rows are read oldest first (`ORDER BY created_at ASC`). The result maps
    "{task_id}::{agent_name}" to a history list; when several rows share a
    key, later rows are appended to the list of the first one.
    """
    import json

    pool = await get_pool()
    select_sql = """
        SELECT task_id, agent_name, history
        FROM agent_histories
        WHERE swarm_name = $1 AND caller_role = $2 AND caller_id = $3
        ORDER BY created_at ASC
    """

    histories: dict[str, list[dict[str, Any]]] = {}

    async with pool.acquire() as connection:
        records = await connection.fetch(
            select_sql, swarm_name, caller_role, caller_id
        )
        for record in records:
            key = f'{record["task_id"]}::{record["agent_name"]}'
            raw_history = record["history"]

            # the column may arrive as JSON text or as an already-decoded value
            entries = (
                json.loads(raw_history)
                if isinstance(raw_history, str)
                else raw_history
            )

            if key not in histories:
                histories[key] = entries
            else:
                # same task/agent seen before: continue its history
                histories[key].extend(entries)

    logger.info(
        f"loaded {len(histories)} agent history entries for {caller_role}:{caller_id}@{swarm_name}"
    )
    return histories
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
async def create_agent_histories_table() -> Any:
    """
    Ensure the `agent_histories` table exists.

    The statement uses `CREATE TABLE IF NOT EXISTS` and is run through
    `_db_execute`, whose result is returned unchanged.
    """
    ddl = """
        CREATE TABLE IF NOT EXISTS agent_histories (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            swarm_name TEXT NOT NULL,
            caller_role TEXT NOT NULL,
            caller_id TEXT NOT NULL,
            tool_format TEXT NOT NULL,
            task_id TEXT NOT NULL,
            agent_name TEXT NOT NULL,
            history JSONB NOT NULL,
            created_at TIMESTAMP NOT NULL DEFAULT NOW(),
            updated_at TIMESTAMP NOT NULL DEFAULT NOW()
        )
    """
    return await _db_execute(ddl)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
# =============================================================================
|
|
219
|
+
# Task Persistence Functions
|
|
220
|
+
# =============================================================================
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
async def create_task(
    task_id: str,
    swarm_name: str,
    caller_role: Literal["admin", "agent", "user"],
    caller_id: str,
    task_owner: str,
    task_contributors: list[str],
    remote_swarms: list[str],
    start_time: str,
    is_running: bool = True,
    completed: bool = False,
    title: str | None = None,
) -> None:
    """
    Insert a task row into `tasks`.

    A row that already exists for the same (task_id, swarm_name, caller_role,
    caller_id) is left untouched (`ON CONFLICT ... DO NOTHING`).
    `task_contributors` and `remote_swarms` are stored as JSON text, and a
    string `start_time` is parsed with `datetime.fromisoformat` first.
    """
    import datetime
    import json

    pool = await get_pool()
    insert_sql = """
        INSERT INTO tasks (task_id, swarm_name, caller_role, caller_id, task_owner,
                           task_contributors, remote_swarms, start_time, is_running, completed, title)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
        ON CONFLICT (task_id, swarm_name, caller_role, caller_id) DO NOTHING
    """

    # accept either an ISO-8601 string or an already-built datetime
    started_at = (
        datetime.datetime.fromisoformat(start_time)
        if isinstance(start_time, str)
        else start_time
    )

    # positional values in the same order as the column list above
    values = (
        task_id,
        swarm_name,
        caller_role,
        caller_id,
        task_owner,
        json.dumps(task_contributors),
        json.dumps(remote_swarms),
        started_at,
        is_running,
        completed,
        title,
    )

    async with pool.acquire() as connection:
        await connection.execute(insert_sql, *values)
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
async def update_task(
    task_id: str,
    swarm_name: str,
    caller_role: Literal["admin", "agent", "user"],
    caller_id: str,
    is_running: bool | None = None,
    completed: bool | None = None,
    task_contributors: list[str] | None = None,
    remote_swarms: list[str] | None = None,
    title: str | None = None,
) -> None:
    """
    Update selected columns of an existing task row.

    Only arguments that are not `None` become part of the `SET` clause; if all
    of them are `None` no statement is executed. `updated_at` is set to
    `NOW()` whenever at least one column changes. The row is addressed by
    (task_id, swarm_name, caller_role, caller_id).
    """
    import json

    pool = await get_pool()

    # (column, value) pairs in the order their placeholders are numbered;
    # list-valued columns are stored as JSON text
    candidates = (
        ("is_running", is_running),
        ("completed", completed),
        (
            "task_contributors",
            None if task_contributors is None else json.dumps(task_contributors),
        ),
        (
            "remote_swarms",
            None if remote_swarms is None else json.dumps(remote_swarms),
        ),
        ("title", title),
    )

    assignments: list[str] = []
    params: list[Any] = []
    for column, value in candidates:
        if value is None:
            continue
        params.append(value)
        assignments.append(f"{column} = ${len(params)}")

    if not assignments:
        return  # Nothing to update

    assignments.append("updated_at = NOW()")
    set_clause = ", ".join(assignments)

    # the WHERE placeholders continue right after the SET placeholders
    where_start = len(params) + 1
    query = f"""
        UPDATE tasks
        SET {set_clause}
        WHERE task_id = ${where_start} AND swarm_name = ${where_start + 1}
        AND caller_role = ${where_start + 2} AND caller_id = ${where_start + 3}
    """
    params += [task_id, swarm_name, caller_role, caller_id]

    async with pool.acquire() as connection:
        await connection.execute(query, *params)
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
async def load_tasks(
    swarm_name: str,
    caller_role: Literal["admin", "agent", "user"],
    caller_id: str,
) -> list[dict[str, Any]]:
    """
    Fetch every task stored for one swarm / caller role / caller ID.

    Returns a list of plain dicts ordered by `start_time` ascending.
    `task_contributors` and `remote_swarms` are JSON-decoded when they come
    back as strings, and `start_time` is rendered with `isoformat()` (or
    `None` when the column is empty).
    """
    import json

    def _decoded(value: Any) -> Any:
        # JSON text is parsed; anything else is passed through untouched
        return json.loads(value) if isinstance(value, str) else value

    pool = await get_pool()
    select_sql = """
        SELECT task_id, task_owner, task_contributors, remote_swarms,
               is_running, completed, start_time, title
        FROM tasks
        WHERE swarm_name = $1 AND caller_role = $2 AND caller_id = $3
        ORDER BY start_time ASC
    """

    loaded: list[dict[str, Any]] = []
    async with pool.acquire() as connection:
        records = await connection.fetch(
            select_sql, swarm_name, caller_role, caller_id
        )
        for record in records:
            contributors = _decoded(record["task_contributors"])
            swarms = _decoded(record["remote_swarms"])
            started = record["start_time"]

            loaded.append(
                {
                    "task_id": record["task_id"],
                    "task_owner": record["task_owner"],
                    "task_contributors": contributors,
                    "remote_swarms": swarms,
                    "is_running": record["is_running"],
                    "completed": record["completed"],
                    "start_time": started.isoformat() if started else None,
                    "title": record["title"],
                }
            )

    logger.info(
        f"loaded {len(loaded)} tasks for {caller_role}:{caller_id}@{swarm_name}"
    )
    return loaded
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
async def create_task_event(
    task_id: str,
    swarm_name: str,
    caller_role: Literal["admin", "agent", "user"],
    caller_id: str,
    event_type: str | None,
    event_data: str | None,
    event_id: str | None,
) -> None:
    """
    Insert one row into `task_events`.

    All arguments are bound to the statement unchanged, in column order.
    """
    pool = await get_pool()
    insert_sql = """
        INSERT INTO task_events (task_id, swarm_name, caller_role, caller_id,
                                 event_type, event_data, event_id)
        VALUES ($1, $2, $3, $4, $5, $6, $7)
    """

    # positional values in the same order as the column list above
    values = (
        task_id,
        swarm_name,
        caller_role,
        caller_id,
        event_type,
        event_data,
        event_id,
    )

    async with pool.acquire() as connection:
        await connection.execute(insert_sql, *values)
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
async def load_task_events(
    task_id: str,
    swarm_name: str,
    caller_role: Literal["admin", "agent", "user"],
    caller_id: str,
) -> list[dict[str, Any]]:
    """
    Fetch the events recorded for a single task.

    Rows are ordered by `created_at` ascending. Each one is returned as a
    dict with the keys "event", "data" and "id", taken from the
    `event_type`, `event_data` and `event_id` columns respectively.
    """
    pool = await get_pool()
    select_sql = """
        SELECT event_type, event_data, event_id
        FROM task_events
        WHERE task_id = $1 AND swarm_name = $2 AND caller_role = $3 AND caller_id = $4
        ORDER BY created_at ASC
    """

    async with pool.acquire() as connection:
        records = await connection.fetch(
            select_sql, task_id, swarm_name, caller_role, caller_id
        )
        events = [
            {
                "event": record["event_type"],
                "data": record["event_data"],
                "id": record["event_id"],
            }
            for record in records
        ]

    return events
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
async def create_task_response(
    task_id: str,
    swarm_name: str,
    caller_role: Literal["admin", "agent", "user"],
    caller_id: str,
    response: dict[str, Any],
) -> None:
    """
    Store the response for a task, replacing any earlier one.

    The statement is an upsert: when a row already exists for the same
    (task_id, swarm_name, caller_role, caller_id), its `response` column is
    overwritten. `response` is serialized with `json.dumps` before binding.
    """
    import json

    pool = await get_pool()
    upsert_sql = """
        INSERT INTO task_responses (task_id, swarm_name, caller_role, caller_id, response)
        VALUES ($1, $2, $3, $4, $5)
        ON CONFLICT (task_id, swarm_name, caller_role, caller_id)
        DO UPDATE SET response = $5
    """

    serialized = json.dumps(response)

    async with pool.acquire() as connection:
        await connection.execute(
            upsert_sql, task_id, swarm_name, caller_role, caller_id, serialized
        )
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
async def load_task_responses(
    swarm_name: str,
    caller_role: Literal["admin", "agent", "user"],
    caller_id: str,
) -> dict[str, dict[str, Any]]:
    """
    Fetch every stored task response for one swarm / caller role / caller ID.

    Returns a dict mapping `task_id` to its response; responses that come
    back from the database as strings are JSON-decoded first.
    """
    import json

    pool = await get_pool()
    select_sql = """
        SELECT task_id, response
        FROM task_responses
        WHERE swarm_name = $1 AND caller_role = $2 AND caller_id = $3
    """

    responses: dict[str, dict[str, Any]] = {}
    async with pool.acquire() as connection:
        records = await connection.fetch(
            select_sql, swarm_name, caller_role, caller_id
        )
        for record in records:
            key = record["task_id"]
            payload = record["response"]
            # JSON text is parsed; other values are used as returned
            responses[key] = (
                json.loads(payload) if isinstance(payload, str) else payload
            )

    logger.info(
        f"loaded {len(responses)} task responses for {caller_role}:{caller_id}@{swarm_name}"
    )
    return responses
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from .analyst_dummy import (
|
|
2
|
+
LiteLLMAnalystFunction,
|
|
3
|
+
factory_analyst_dummy,
|
|
4
|
+
)
|
|
5
|
+
from .consultant_dummy import (
|
|
6
|
+
LiteLLMConsultantFunction,
|
|
7
|
+
factory_consultant_dummy,
|
|
8
|
+
)
|
|
9
|
+
from .math_dummy import (
|
|
10
|
+
LiteLLMMathFunction,
|
|
11
|
+
factory_math_dummy,
|
|
12
|
+
)
|
|
13
|
+
from .weather_dummy import (
|
|
14
|
+
LiteLLMWeatherFunction,
|
|
15
|
+
factory_weather_dummy,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
__all__ = [
|
|
19
|
+
"factory_analyst_dummy",
|
|
20
|
+
"factory_consultant_dummy",
|
|
21
|
+
"factory_math_dummy",
|
|
22
|
+
"factory_weather_dummy",
|
|
23
|
+
"LiteLLMAnalystFunction",
|
|
24
|
+
"LiteLLMConsultantFunction",
|
|
25
|
+
"LiteLLMMathFunction",
|
|
26
|
+
"LiteLLMWeatherFunction",
|
|
27
|
+
]
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .agent import (
|
|
2
|
+
LiteLLMAnalystFunction,
|
|
3
|
+
analyst_agent_params,
|
|
4
|
+
factory_analyst_dummy,
|
|
5
|
+
)
|
|
6
|
+
from .prompts import (
|
|
7
|
+
SYSPROMPT as ANALYST_SYSPROMPT,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"factory_analyst_dummy",
|
|
12
|
+
"LiteLLMAnalystFunction",
|
|
13
|
+
"ANALYST_SYSPROMPT",
|
|
14
|
+
"analyst_agent_params",
|
|
15
|
+
]
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
# Copyright (c) 2025 Addison Kline
|
|
3
|
+
|
|
4
|
+
import warnings
|
|
5
|
+
from collections.abc import Awaitable
|
|
6
|
+
from typing import Any, Literal
|
|
7
|
+
|
|
8
|
+
from mail.core.agents import AgentOutput
|
|
9
|
+
from mail.factories import AgentFunction
|
|
10
|
+
from mail.factories.base import LiteLLMAgentFunction
|
|
11
|
+
|
|
12
|
+
# Default parameters for the analyst example agent: the model identifier and
# the import path of its system prompt.
analyst_agent_params = dict(
    llm="openai/gpt-5-mini",
    system="mail.examples.analyst_dummy.prompts:SYSPROMPT",
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def factory_analyst_dummy(
    # REQUIRED
    # top-level params
    comm_targets: list[str],
    tools: list[dict[str, Any]],
    # instance params
    user_token: str,
    # internal params
    llm: str,
    system: str,
    # OPTIONAL
    # top-level params
    name: str = "analyst",
    enable_entrypoint: bool = False,
    enable_interswarm: bool = False,
    can_complete_tasks: bool = False,
    tool_format: Literal["completions", "responses"] = "responses",
    exclude_tools: list[str] | None = None,
    # instance params
    # ...
    # internal params
    reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None,
    thinking_budget: int | None = None,
    max_tokens: int | None = None,
    memory: bool = True,
    use_proxy: bool = True,
) -> AgentFunction:
    """
    Deprecated factory for the analyst example agent.

    Emits a `DeprecationWarning`, builds a `LiteLLMAnalystFunction` from the
    given arguments and returns an async function that forwards to it.

    Args:
        exclude_tools: Tool names to exclude; `None` (the default) is treated
            as an empty list so no list object is shared between calls.
        reasoning_effort: Forwarded unchanged; accepts the same values as
            `LiteLLMAnalystFunction`, including "minimal".

    All remaining arguments are forwarded to `LiteLLMAnalystFunction` as-is.

    Returns:
        An async callable `run(messages, tool_choice="required")` that awaits
        the underlying `LiteLLMAnalystFunction`.
    """
    warnings.warn(
        "`mail.examples.analyst_dummy:factory_analyst_dummy` is deprecated and will be removed in a future version. "
        "Use `mail.examples.analyst_dummy:LiteLLMAnalystFunction` instead.",
        DeprecationWarning,
        stacklevel=2,
    )

    litellm_analyst = LiteLLMAnalystFunction(
        name=name,
        comm_targets=comm_targets,
        tools=tools,
        llm=llm,
        system=system,
        user_token=user_token,
        enable_entrypoint=enable_entrypoint,
        enable_interswarm=enable_interswarm,
        reasoning_effort=reasoning_effort,
        thinking_budget=thinking_budget,
        max_tokens=max_tokens,
        memory=memory,
        use_proxy=use_proxy,
        can_complete_tasks=can_complete_tasks,
        tool_format=tool_format,
        # fresh list per call instead of a shared mutable default
        exclude_tools=[] if exclude_tools is None else exclude_tools,
    )

    async def run(
        messages: list[dict[str, Any]],
        tool_choice: str | dict[str, str] = "required",
    ) -> AgentOutput:
        """
        Execute the LiteLLM-based analyst agent function.
        """
        return await litellm_analyst(messages, tool_choice)

    return run
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class LiteLLMAnalystFunction(LiteLLMAgentFunction):
    """
    Class that represents a LiteLLM-based analyst agent function.

    It adds no behavior of its own: construction and invocation are both
    delegated to `LiteLLMAgentFunction`.
    """

    def __init__(
        self,
        name: str,
        comm_targets: list[str],
        tools: list[dict[str, Any]],
        llm: str,
        system: str,
        user_token: str = "",
        enable_entrypoint: bool = False,
        enable_interswarm: bool = False,
        can_complete_tasks: bool = False,
        tool_format: Literal["completions", "responses"] = "responses",
        exclude_tools: list[str] | None = None,
        reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None,
        thinking_budget: int | None = None,
        max_tokens: int | None = None,
        memory: bool = True,
        use_proxy: bool = True,
        _debug_include_mail_tools: bool = True,
    ) -> None:
        """
        Forward every argument to `LiteLLMAgentFunction.__init__`.

        `exclude_tools=None` (the default) is replaced by a new empty list so
        that instances never share one default list object.
        """
        super().__init__(
            name=name,
            comm_targets=comm_targets,
            tools=tools,
            llm=llm,
            system=system,
            user_token=user_token,
            enable_entrypoint=enable_entrypoint,
            enable_interswarm=enable_interswarm,
            can_complete_tasks=can_complete_tasks,
            tool_format=tool_format,
            exclude_tools=[] if exclude_tools is None else exclude_tools,
            reasoning_effort=reasoning_effort,
            thinking_budget=thinking_budget,
            max_tokens=max_tokens,
            memory=memory,
            use_proxy=use_proxy,
            _debug_include_mail_tools=_debug_include_mail_tools,
        )

    def __call__(
        self,
        messages: list[dict[str, Any]],
        tool_choice: str | dict[str, str] = "required",
    ) -> Awaitable[AgentOutput]:
        """
        Execute the LiteLLM-based analyst agent function.

        Returns the awaitable produced by the parent class's `__call__`.
        """
        return super().__call__(messages, tool_choice)
|