prompture 0.0.35__py3-none-any.whl → 0.0.38.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. prompture/__init__.py +120 -2
  2. prompture/_version.py +2 -2
  3. prompture/agent.py +924 -0
  4. prompture/agent_types.py +156 -0
  5. prompture/async_agent.py +880 -0
  6. prompture/async_conversation.py +199 -17
  7. prompture/async_driver.py +24 -0
  8. prompture/async_groups.py +551 -0
  9. prompture/conversation.py +213 -18
  10. prompture/core.py +30 -12
  11. prompture/discovery.py +24 -1
  12. prompture/driver.py +38 -0
  13. prompture/drivers/__init__.py +5 -1
  14. prompture/drivers/async_azure_driver.py +7 -1
  15. prompture/drivers/async_claude_driver.py +7 -1
  16. prompture/drivers/async_google_driver.py +212 -28
  17. prompture/drivers/async_grok_driver.py +7 -1
  18. prompture/drivers/async_groq_driver.py +7 -1
  19. prompture/drivers/async_lmstudio_driver.py +74 -5
  20. prompture/drivers/async_ollama_driver.py +13 -3
  21. prompture/drivers/async_openai_driver.py +7 -1
  22. prompture/drivers/async_openrouter_driver.py +7 -1
  23. prompture/drivers/async_registry.py +5 -1
  24. prompture/drivers/azure_driver.py +7 -1
  25. prompture/drivers/claude_driver.py +7 -1
  26. prompture/drivers/google_driver.py +217 -33
  27. prompture/drivers/grok_driver.py +7 -1
  28. prompture/drivers/groq_driver.py +7 -1
  29. prompture/drivers/lmstudio_driver.py +73 -8
  30. prompture/drivers/ollama_driver.py +16 -5
  31. prompture/drivers/openai_driver.py +7 -1
  32. prompture/drivers/openrouter_driver.py +7 -1
  33. prompture/drivers/vision_helpers.py +153 -0
  34. prompture/group_types.py +147 -0
  35. prompture/groups.py +530 -0
  36. prompture/image.py +180 -0
  37. prompture/persistence.py +254 -0
  38. prompture/persona.py +482 -0
  39. prompture/serialization.py +218 -0
  40. prompture/settings.py +1 -0
  41. prompture-0.0.38.dev2.dist-info/METADATA +369 -0
  42. prompture-0.0.38.dev2.dist-info/RECORD +77 -0
  43. prompture-0.0.35.dist-info/METADATA +0 -464
  44. prompture-0.0.35.dist-info/RECORD +0 -66
  45. {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +0 -0
  46. {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
  47. {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
  48. {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,551 @@
1
+ """Async multi-agent group coordination.
2
+
3
+ Provides :class:`ParallelGroup`, :class:`AsyncSequentialGroup`,
4
+ :class:`AsyncLoopGroup`, and :class:`AsyncRouterAgent` for composing
5
+ multiple async agents into deterministic workflows.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import asyncio
11
+ import logging
12
+ import time
13
+ from collections.abc import Callable
14
+ from typing import Any
15
+
16
+ from .agent_types import AgentResult, AgentState
17
+ from .group_types import (
18
+ AgentError,
19
+ ErrorPolicy,
20
+ GroupCallbacks,
21
+ GroupResult,
22
+ GroupStep,
23
+ _aggregate_usage,
24
+ )
25
+ from .groups import _agent_name, _inject_state, _normalise_agents
26
+
27
+ logger = logging.getLogger("prompture.async_groups")
28
+
29
+
30
+ # ------------------------------------------------------------------
31
+ # ParallelGroup
32
+ # ------------------------------------------------------------------
33
+
34
+
35
class ParallelGroup:
    """Execute agents concurrently and collect results.

    Agents read from a frozen snapshot of the shared state taken at
    the start of the run. Output key writes are applied after all
    agents complete, in agent index order.

    Args:
        agents: List of async agents or ``(agent, prompt_template)`` tuples.
        state: Initial shared state dict.
        error_policy: How to handle agent failures.
        timeout_ms: Per-agent timeout in milliseconds.
        callbacks: Observability hooks.
        max_total_cost: Budget cap in USD.
    """

    def __init__(
        self,
        agents: list[Any],
        *,
        state: dict[str, Any] | None = None,
        error_policy: ErrorPolicy = ErrorPolicy.fail_fast,
        timeout_ms: float | None = None,
        callbacks: GroupCallbacks | None = None,
        max_total_cost: float | None = None,
    ) -> None:
        # _normalise_agents yields (agent, prompt_template) pairs.
        self._agents = _normalise_agents(agents)
        # Copied so later output_key writes don't mutate the caller's dict.
        self._state: dict[str, Any] = dict(state) if state else {}
        # NOTE(review): _error_policy and _max_total_cost are stored but never
        # consulted by run_async below — confirm whether parallel runs are
        # meant to enforce them the way the sequential groups do.
        self._error_policy = error_policy
        self._timeout_ms = timeout_ms
        self._callbacks = callbacks or GroupCallbacks()
        self._max_total_cost = max_total_cost
        self._stop_requested = False

    def stop(self) -> None:
        """Request graceful shutdown."""
        # NOTE(review): run_async resets this flag on entry and never reads
        # it afterwards, so stop() currently has no effect on a parallel run.
        self._stop_requested = True

    async def run_async(self, prompt: str = "") -> GroupResult:
        """Execute all agents concurrently.

        Args:
            prompt: Group-level prompt used for agents without their own
                template; state placeholders are injected via ``_inject_state``.

        Returns:
            GroupResult with per-agent results, aggregated usage, the final
            shared state, a per-agent timeline, and any collected errors.
        """
        self._stop_requested = False
        t0 = time.perf_counter()

        # Frozen state snapshot for all agents: every agent sees the state as
        # it was at launch, regardless of what siblings produce concurrently.
        frozen_state = dict(self._state)

        async def _run_one(
            idx: int, agent: Any, custom_prompt: str | None
        ) -> tuple[int, str, AgentResult | None, AgentError | None, GroupStep]:
            # Runs one agent and never raises: failures come back as an
            # AgentError in the tuple so asyncio.gather cannot blow up.
            name = _agent_name(agent, idx)

            # Per-agent template takes precedence over the group prompt.
            if custom_prompt is not None:
                effective = _inject_state(custom_prompt, frozen_state)
            elif prompt:
                effective = _inject_state(prompt, frozen_state)
            else:
                effective = ""

            if self._callbacks.on_agent_start:
                self._callbacks.on_agent_start(name, effective)

            step_t0 = time.perf_counter()
            try:
                coro = agent.run(effective)
                if self._timeout_ms is not None:
                    result = await asyncio.wait_for(coro, timeout=self._timeout_ms / 1000)
                else:
                    result = await coro

                duration_ms = (time.perf_counter() - step_t0) * 1000
                step = GroupStep(
                    agent_name=name,
                    step_type="agent_run",
                    timestamp=step_t0,
                    duration_ms=duration_ms,
                    usage_delta=getattr(result, "run_usage", {}),
                )
                if self._callbacks.on_agent_complete:
                    self._callbacks.on_agent_complete(name, result)
                return idx, name, result, None, step

            except Exception as exc:
                # Catches agent failures and asyncio timeouts alike.
                duration_ms = (time.perf_counter() - step_t0) * 1000
                err = AgentError(
                    agent_name=name,
                    error=exc,
                    output_key=getattr(agent, "output_key", None),
                )
                step = GroupStep(
                    agent_name=name,
                    step_type="agent_error",
                    timestamp=step_t0,
                    duration_ms=duration_ms,
                    error=str(exc),
                )
                if self._callbacks.on_agent_error:
                    self._callbacks.on_agent_error(name, exc)
                return idx, name, None, err, step

        # Launch all agents concurrently
        tasks = [_run_one(idx, agent, custom_prompt) for idx, (agent, custom_prompt) in enumerate(self._agents)]
        completed = await asyncio.gather(*tasks, return_exceptions=False)

        # Sort by original index to maintain deterministic ordering
        completed_sorted = sorted(completed, key=lambda x: x[0])

        agent_results: dict[str, Any] = {}
        errors: list[AgentError] = []
        timeline: list[GroupStep] = []
        usage_summaries: list[dict[str, Any]] = []

        for idx, name, result, err, step in completed_sorted:
            timeline.append(step)
            if err is not None:
                errors.append(err)
            elif result is not None:
                agent_results[name] = result
                usage_summaries.append(getattr(result, "run_usage", {}))

                # Apply output_key writes in order
                agent_obj = self._agents[idx][0]
                output_key = getattr(agent_obj, "output_key", None)
                if output_key:
                    self._state[output_key] = result.output_text
                    if self._callbacks.on_state_update:
                        self._callbacks.on_state_update(output_key, result.output_text)

        elapsed_ms = (time.perf_counter() - t0) * 1000
        return GroupResult(
            agent_results=agent_results,
            aggregate_usage=_aggregate_usage(*usage_summaries),
            shared_state=dict(self._state),
            elapsed_ms=elapsed_ms,
            timeline=timeline,
            errors=errors,
            success=len(errors) == 0,
        )

    def run(self, prompt: str = "") -> GroupResult:
        """Sync wrapper around :meth:`run_async`."""
        return asyncio.run(self.run_async(prompt))
176
+
177
+
178
+ # ------------------------------------------------------------------
179
+ # AsyncSequentialGroup
180
+ # ------------------------------------------------------------------
181
+
182
+
183
class AsyncSequentialGroup:
    """Run a fixed sequence of async agents, threading shared state through.

    Each agent executes once, in list order. When an agent exposes an
    ``output_key``, its output text is written back into the shared state so
    later prompt templates can reference it via placeholder injection.

    Args:
        agents: List of async agents or ``(agent, prompt_template)`` tuples.
        state: Initial shared state dict (copied on construction).
        error_policy: How to handle agent failures.
        max_total_turns: Limit on total agent runs.
        callbacks: Observability hooks.
        max_total_cost: Budget cap in USD.
    """

    def __init__(
        self,
        agents: list[Any],
        *,
        state: dict[str, Any] | None = None,
        error_policy: ErrorPolicy = ErrorPolicy.fail_fast,
        max_total_turns: int | None = None,
        callbacks: GroupCallbacks | None = None,
        max_total_cost: float | None = None,
    ) -> None:
        self._agents = _normalise_agents(agents)
        self._state: dict[str, Any] = {} if not state else dict(state)
        self._error_policy = error_policy
        self._max_total_turns = max_total_turns
        self._max_total_cost = max_total_cost
        self._callbacks = callbacks or GroupCallbacks()
        self._stop_requested = False

    def stop(self) -> None:
        """Ask the running sequence to halt before the next agent starts."""
        self._stop_requested = True

    async def run(self, prompt: str = "") -> GroupResult:
        """Execute every agent once, in order, and aggregate the results."""
        self._stop_requested = False
        started = time.perf_counter()
        steps: list[GroupStep] = []
        results_by_name: dict[str, Any] = {}
        failures: list[AgentError] = []
        usage_log: list[dict[str, Any]] = []
        turns_taken = 0

        for position, (agent, template) in enumerate(self._agents):
            if self._stop_requested:
                break

            name = _agent_name(agent, position)

            # An agent-specific template wins over the group-level prompt.
            if template is not None:
                rendered = _inject_state(template, self._state)
            else:
                rendered = _inject_state(prompt, self._state) if prompt else ""

            # Budget gate: stop once accumulated cost reaches the cap.
            if self._max_total_cost is not None:
                spent = sum(u.get("total_cost", 0.0) for u in usage_log)
                if spent >= self._max_total_cost:
                    break

            # Turn gate: stop once the configured number of runs is used up.
            if self._max_total_turns is not None and turns_taken >= self._max_total_turns:
                break

            if self._callbacks.on_agent_start:
                self._callbacks.on_agent_start(name, rendered)

            tick = time.perf_counter()
            try:
                outcome = await agent.run(rendered)
                elapsed = (time.perf_counter() - tick) * 1000
                turns_taken += 1

                results_by_name[name] = outcome
                usage = getattr(outcome, "run_usage", {})
                usage_log.append(usage)

                # Publish the agent's text into shared state when requested.
                key = getattr(agent, "output_key", None)
                if key:
                    self._state[key] = outcome.output_text
                    if self._callbacks.on_state_update:
                        self._callbacks.on_state_update(key, outcome.output_text)

                steps.append(
                    GroupStep(
                        agent_name=name,
                        step_type="agent_run",
                        timestamp=tick,
                        duration_ms=elapsed,
                        usage_delta=usage,
                    )
                )

                if self._callbacks.on_agent_complete:
                    self._callbacks.on_agent_complete(name, outcome)

            except Exception as exc:
                elapsed = (time.perf_counter() - tick) * 1000
                turns_taken += 1
                failures.append(
                    AgentError(
                        agent_name=name,
                        error=exc,
                        output_key=getattr(agent, "output_key", None),
                    )
                )
                steps.append(
                    GroupStep(
                        agent_name=name,
                        step_type="agent_error",
                        timestamp=tick,
                        duration_ms=elapsed,
                        error=str(exc),
                    )
                )

                if self._callbacks.on_agent_error:
                    self._callbacks.on_agent_error(name, exc)

                if self._error_policy == ErrorPolicy.fail_fast:
                    break

        total_ms = (time.perf_counter() - started) * 1000
        return GroupResult(
            agent_results=results_by_name,
            aggregate_usage=_aggregate_usage(*usage_log),
            shared_state=dict(self._state),
            elapsed_ms=total_ms,
            timeline=steps,
            errors=failures,
            success=not failures,
        )
314
+
315
+
316
+ # ------------------------------------------------------------------
317
+ # AsyncLoopGroup
318
+ # ------------------------------------------------------------------
319
+
320
+
321
class AsyncLoopGroup:
    """Async version of :class:`~prompture.groups.LoopGroup`.

    Runs the agent sequence repeatedly until ``exit_condition`` returns True,
    ``max_iterations`` is reached, or :meth:`stop` is called.

    Args:
        agents: List of async agents or ``(agent, prompt_template)`` tuples.
        exit_condition: Callable ``(state, iteration) -> bool``.
        max_iterations: Hard cap on loop iterations.
        state: Initial shared state dict.
        error_policy: How to handle agent failures.
        callbacks: Observability hooks.
    """

    def __init__(
        self,
        agents: list[Any],
        *,
        exit_condition: Callable[[dict[str, Any], int], bool],
        max_iterations: int = 10,
        state: dict[str, Any] | None = None,
        error_policy: ErrorPolicy = ErrorPolicy.fail_fast,
        callbacks: GroupCallbacks | None = None,
    ) -> None:
        self._agents = _normalise_agents(agents)
        self._exit_condition = exit_condition
        self._max_iterations = max_iterations
        # Copied so output_key writes never mutate the caller's dict.
        self._state: dict[str, Any] = dict(state) if state else {}
        self._error_policy = error_policy
        self._callbacks = callbacks or GroupCallbacks()
        self._stop_requested = False

    def stop(self) -> None:
        # Checked at the top of each iteration and before each agent run.
        self._stop_requested = True

    async def run(self, prompt: str = "") -> GroupResult:
        """Execute the loop (async).

        The exit condition is evaluated *before* each iteration, so it can
        short-circuit even the first pass. Results are keyed as
        ``"{name}_iter{iteration}"`` so later passes don't overwrite earlier
        ones.
        """
        self._stop_requested = False
        t0 = time.perf_counter()
        timeline: list[GroupStep] = []
        agent_results: dict[str, Any] = {}
        errors: list[AgentError] = []
        usage_summaries: list[dict[str, Any]] = []

        for iteration in range(self._max_iterations):
            if self._stop_requested:
                break
            if self._exit_condition(self._state, iteration):
                break

            for idx, (agent, custom_prompt) in enumerate(self._agents):
                if self._stop_requested:
                    break

                name = _agent_name(agent, idx)
                # Per-iteration result key so passes accumulate, not clobber.
                result_key = f"{name}_iter{iteration}"

                # Agent-specific template wins over the group-level prompt.
                if custom_prompt is not None:
                    effective = _inject_state(custom_prompt, self._state)
                elif prompt:
                    effective = _inject_state(prompt, self._state)
                else:
                    effective = ""

                if self._callbacks.on_agent_start:
                    self._callbacks.on_agent_start(name, effective)

                step_t0 = time.perf_counter()
                try:
                    result = await agent.run(effective)
                    duration_ms = (time.perf_counter() - step_t0) * 1000

                    agent_results[result_key] = result
                    usage = getattr(result, "run_usage", {})
                    usage_summaries.append(usage)

                    # Publish the agent's text into shared state when requested.
                    output_key = getattr(agent, "output_key", None)
                    if output_key:
                        self._state[output_key] = result.output_text
                        if self._callbacks.on_state_update:
                            self._callbacks.on_state_update(output_key, result.output_text)

                    timeline.append(
                        GroupStep(
                            agent_name=name,
                            step_type="agent_run",
                            timestamp=step_t0,
                            duration_ms=duration_ms,
                            usage_delta=usage,
                        )
                    )

                    if self._callbacks.on_agent_complete:
                        self._callbacks.on_agent_complete(name, result)

                except Exception as exc:
                    duration_ms = (time.perf_counter() - step_t0) * 1000
                    err = AgentError(
                        agent_name=name,
                        error=exc,
                        output_key=getattr(agent, "output_key", None),
                    )
                    errors.append(err)
                    timeline.append(
                        GroupStep(
                            agent_name=name,
                            step_type="agent_error",
                            timestamp=step_t0,
                            duration_ms=duration_ms,
                            error=str(exc),
                        )
                    )

                    if self._callbacks.on_agent_error:
                        self._callbacks.on_agent_error(name, exc)

                    # fail_fast: abandon the rest of this iteration's agents.
                    if self._error_policy == ErrorPolicy.fail_fast:
                        break

            # fail_fast also terminates the outer loop once any error exists.
            if errors and self._error_policy == ErrorPolicy.fail_fast:
                break

        elapsed_ms = (time.perf_counter() - t0) * 1000
        return GroupResult(
            agent_results=agent_results,
            aggregate_usage=_aggregate_usage(*usage_summaries),
            shared_state=dict(self._state),
            elapsed_ms=elapsed_ms,
            timeline=timeline,
            errors=errors,
            success=len(errors) == 0,
        )
451
+
452
+
453
+ # ------------------------------------------------------------------
454
+ # AsyncRouterAgent
455
+ # ------------------------------------------------------------------
456
+
457
# Default template for the routing LLM call. ``{agent_list}`` and ``{prompt}``
# are substituted with str.replace (not str.format) by AsyncRouterAgent.run.
_DEFAULT_ROUTING_PROMPT = """Given these specialists:
{agent_list}

Which should handle this? Reply with ONLY the name.

Request: {prompt}"""
463
+
464
+
465
class AsyncRouterAgent:
    """Async LLM-driven router that delegates to the best-matching agent.

    One routing call asks the model which registered specialist should
    handle the request; the reply is fuzzy-matched against agent names and
    the winning agent is awaited with the original prompt.

    Args:
        model: Model string for the routing LLM call.
        agents: List of async agents to route between.
        routing_prompt: Custom prompt template.
        fallback: Agent to use when routing fails.
        driver: Pre-built async driver instance.
    """

    def __init__(
        self,
        model: str = "",
        *,
        agents: list[Any],
        routing_prompt: str | None = None,
        fallback: Any | None = None,
        driver: Any | None = None,
        name: str = "",
        description: str = "",
        output_key: str | None = None,
    ) -> None:
        self._model = model
        self._driver = driver
        # Map display name -> agent; names fall back to positional defaults.
        self._agents = {}
        for position, candidate in enumerate(agents):
            self._agents[_agent_name(candidate, position)] = candidate
        self._routing_prompt = routing_prompt or _DEFAULT_ROUTING_PROMPT
        self._fallback = fallback
        self.name = name
        self.description = description
        self.output_key = output_key

    async def run(self, prompt: str, *, deps: Any = None) -> AgentResult:
        """Route the prompt to the best agent (async)."""
        from .async_conversation import AsyncConversation

        def _roster_line(name: str, agent: Any) -> str:
            # One bullet per specialist; omit the colon when no description.
            desc = getattr(agent, "description", "") or ""
            if desc:
                return f"- {name}: {desc}"
            return f"- {name}"

        agent_list = "\n".join(_roster_line(n, a) for n, a in self._agents.items())

        # Plain str.replace keeps literal braces elsewhere in the template safe.
        filled = self._routing_prompt.replace("{agent_list}", agent_list)
        routing_text = filled.replace("{prompt}", prompt)

        conv_kwargs: dict[str, Any] = {}
        if self._driver is None:
            conv_kwargs["model_name"] = self._model
        else:
            conv_kwargs["driver"] = self._driver

        conv = AsyncConversation(**conv_kwargs)
        reply = await conv.ask(routing_text)

        chosen = self._fuzzy_match(reply.strip())
        target = chosen if chosen is not None else self._fallback
        if target is not None:
            if deps is not None:
                return await target.run(prompt, deps=deps)
            return await target.run(prompt)

        # No match and no fallback: surface the raw routing reply.
        return AgentResult(
            output=reply,
            output_text=reply,
            messages=conv.messages,
            usage=conv.usage,
            state=AgentState.idle,
        )

    def _fuzzy_match(self, response: str) -> Any | None:
        """Resolve the routing reply to a registered agent, or ``None``."""
        needle = response.lower().strip()

        # Pass 1: exact case-insensitive name match.
        for agent_name, candidate in self._agents.items():
            if agent_name.lower() == needle:
                return candidate

        # Pass 2: the name appears verbatim inside the reply.
        for agent_name, candidate in self._agents.items():
            if agent_name.lower() in needle:
                return candidate

        # Pass 3: any word shared between the name and the reply.
        reply_words = set(needle.split())
        for agent_name, candidate in self._agents.items():
            name_words = set(agent_name.lower().replace("_", " ").split())
            if name_words & reply_words:
                return candidate

        return None