prompture-0.0.36.dev1-py3-none-any.whl → prompture-0.0.37.dev1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +120 -2
- prompture/_version.py +2 -2
- prompture/agent.py +925 -0
- prompture/agent_types.py +156 -0
- prompture/async_agent.py +879 -0
- prompture/async_conversation.py +199 -17
- prompture/async_driver.py +24 -0
- prompture/async_groups.py +551 -0
- prompture/conversation.py +213 -18
- prompture/core.py +30 -12
- prompture/discovery.py +24 -1
- prompture/driver.py +38 -0
- prompture/drivers/__init__.py +5 -1
- prompture/drivers/async_azure_driver.py +7 -1
- prompture/drivers/async_claude_driver.py +7 -1
- prompture/drivers/async_google_driver.py +24 -4
- prompture/drivers/async_grok_driver.py +7 -1
- prompture/drivers/async_groq_driver.py +7 -1
- prompture/drivers/async_lmstudio_driver.py +59 -3
- prompture/drivers/async_ollama_driver.py +7 -0
- prompture/drivers/async_openai_driver.py +7 -1
- prompture/drivers/async_openrouter_driver.py +7 -1
- prompture/drivers/async_registry.py +5 -1
- prompture/drivers/azure_driver.py +7 -1
- prompture/drivers/claude_driver.py +7 -1
- prompture/drivers/google_driver.py +24 -4
- prompture/drivers/grok_driver.py +7 -1
- prompture/drivers/groq_driver.py +7 -1
- prompture/drivers/lmstudio_driver.py +58 -6
- prompture/drivers/ollama_driver.py +7 -0
- prompture/drivers/openai_driver.py +7 -1
- prompture/drivers/openrouter_driver.py +7 -1
- prompture/drivers/vision_helpers.py +153 -0
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/serialization.py +218 -0
- prompture/settings.py +1 -0
- {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/METADATA +1 -1
- prompture-0.0.37.dev1.dist-info/RECORD +77 -0
- prompture-0.0.36.dev1.dist-info/RECORD +0 -66
- {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/WHEEL +0 -0
- {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/top_level.txt +0 -0
prompture/groups.py
ADDED
@@ -0,0 +1,530 @@
"""Synchronous multi-agent group coordination.

Provides :class:`SequentialGroup`, :class:`LoopGroup`,
:class:`RouterAgent`, and :class:`GroupAsAgent` for composing
multiple agents into deterministic workflows.
"""

from __future__ import annotations

import logging
import re
import time
from collections.abc import Callable
from typing import Any

from .agent_types import AgentResult, AgentState
from .group_types import (
    AgentError,
    ErrorPolicy,
    GroupCallbacks,
    GroupResult,
    GroupStep,
    _aggregate_usage,
)

logger = logging.getLogger("prompture.groups")


# ------------------------------------------------------------------
# State injection helper
# ------------------------------------------------------------------


def _inject_state(template: str, state: dict[str, Any]) -> str:
    """Replace ``{key}`` placeholders with state values.

    Unknown keys pass through unchanged so downstream agents can
    still see the literal placeholder.
    """

    def _replacer(m: re.Match[str]) -> str:
        key = m.group(1)
        if key in state:
            return str(state[key])
        return m.group(0)  # leave unchanged

    return re.sub(r"\{(\w+)\}", _replacer, template)


# ------------------------------------------------------------------
# Agent entry normalisation
# ------------------------------------------------------------------

AgentEntry = Any  # Agent | tuple[Agent, str]


def _normalise_agents(agents: list[Any]) -> list[tuple[Any, str | None]]:
    """Convert a mixed list of ``Agent`` or ``(Agent, prompt_template)`` to uniform tuples."""
    result: list[tuple[Any, str | None]] = []
    for item in agents:
        if isinstance(item, tuple):
            result.append((item[0], item[1]))
        else:
            result.append((item, None))
    return result


def _agent_name(agent: Any, index: int) -> str:
    """Determine a display name for an agent."""
    name = getattr(agent, "name", "") or ""
    return name if name else f"agent_{index}"


# ------------------------------------------------------------------
# SequentialGroup
# ------------------------------------------------------------------


class SequentialGroup:
    """Execute agents in sequence, passing state between them.

    Each agent's ``output_key`` (if set) writes its output text
    into the shared state dict, making it available as ``{key}``
    in subsequent agent prompts.

    Args:
        agents: List of agents or ``(agent, prompt_template)`` tuples.
        state: Initial shared state dict.
        error_policy: How to handle agent failures.
        max_total_turns: Limit on total agent runs across the sequence.
        callbacks: Observability hooks.
        max_total_cost: Budget cap in USD.
    """

    def __init__(
        self,
        agents: list[Any],
        *,
        state: dict[str, Any] | None = None,
        error_policy: ErrorPolicy = ErrorPolicy.fail_fast,
        max_total_turns: int | None = None,
        callbacks: GroupCallbacks | None = None,
        max_total_cost: float | None = None,
    ) -> None:
        self._agents = _normalise_agents(agents)
        self._state: dict[str, Any] = dict(state) if state else {}
        self._error_policy = error_policy
        self._max_total_turns = max_total_turns
        self._callbacks = callbacks or GroupCallbacks()
        self._max_total_cost = max_total_cost
        self._stop_requested = False

    def stop(self) -> None:
        """Request graceful shutdown after the current agent finishes."""
        self._stop_requested = True

    def save(self, path: str) -> None:
        """Run and save result to file. Convenience wrapper."""
        result = self.run()
        result.save(path)

    def run(self, prompt: str = "") -> GroupResult:
        """Execute all agents in order."""
        self._stop_requested = False
        t0 = time.perf_counter()
        timeline: list[GroupStep] = []
        agent_results: dict[str, Any] = {}
        errors: list[AgentError] = []
        usage_summaries: list[dict[str, Any]] = []
        turns = 0

        for idx, (agent, custom_prompt) in enumerate(self._agents):
            if self._stop_requested:
                break

            name = _agent_name(agent, idx)

            # Build effective prompt
            if custom_prompt is not None:
                effective = _inject_state(custom_prompt, self._state)
            elif prompt:
                effective = _inject_state(prompt, self._state)
            else:
                effective = ""

            # Check budget
            if self._max_total_cost is not None:
                total_so_far = sum(s.get("total_cost", 0.0) for s in usage_summaries)
                if total_so_far >= self._max_total_cost:
                    logger.debug("Budget exceeded, stopping group")
                    break

            # Check max turns
            if self._max_total_turns is not None and turns >= self._max_total_turns:
                logger.debug("Max total turns reached")
                break

            # Fire callback
            if self._callbacks.on_agent_start:
                self._callbacks.on_agent_start(name, effective)

            step_t0 = time.perf_counter()
            try:
                result = agent.run(effective)
                duration_ms = (time.perf_counter() - step_t0) * 1000
                turns += 1

                agent_results[name] = result
                usage = getattr(result, "run_usage", {})
                usage_summaries.append(usage)

                # Write to shared state
                output_key = getattr(agent, "output_key", None)
                if output_key:
                    self._state[output_key] = result.output_text
                    if self._callbacks.on_state_update:
                        self._callbacks.on_state_update(output_key, result.output_text)

                timeline.append(
                    GroupStep(
                        agent_name=name,
                        step_type="agent_run",
                        timestamp=step_t0,
                        duration_ms=duration_ms,
                        usage_delta=usage,
                    )
                )

                if self._callbacks.on_agent_complete:
                    self._callbacks.on_agent_complete(name, result)

            except Exception as exc:
                duration_ms = (time.perf_counter() - step_t0) * 1000
                turns += 1
                err = AgentError(
                    agent_name=name,
                    error=exc,
                    output_key=getattr(agent, "output_key", None),
                )
                errors.append(err)
                timeline.append(
                    GroupStep(
                        agent_name=name,
                        step_type="agent_error",
                        timestamp=step_t0,
                        duration_ms=duration_ms,
                        error=str(exc),
                    )
                )

                if self._callbacks.on_agent_error:
                    self._callbacks.on_agent_error(name, exc)

                if self._error_policy == ErrorPolicy.fail_fast:
                    break
                # continue_on_error / retry_failed: continue to next agent

        elapsed_ms = (time.perf_counter() - t0) * 1000
        return GroupResult(
            agent_results=agent_results,
            aggregate_usage=_aggregate_usage(*usage_summaries),
            shared_state=dict(self._state),
            elapsed_ms=elapsed_ms,
            timeline=timeline,
            errors=errors,
            success=len(errors) == 0,
        )
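Usage sketch (illustrative, not part of the diff): a two-step pipeline driven by a hypothetical stub agent. SequentialGroup only relies on an agent exposing name, output_key, description, and a run(prompt) method whose result has output_text and run_usage, so the StubAgent below stands in for prompture's real Agent class.

    from types import SimpleNamespace

    from prompture.groups import SequentialGroup


    class StubAgent:
        """Hypothetical stand-in for prompture's Agent (only the attributes the group reads)."""

        def __init__(self, name, output_key=None):
            self.name = name
            self.output_key = output_key
            self.description = ""

        def run(self, prompt):
            # The group only needs .output_text and .run_usage on the result.
            return SimpleNamespace(
                output_text=f"[{self.name}] {prompt}",
                run_usage={"total_cost": 0.0},
            )


    group = SequentialGroup(
        [
            StubAgent("researcher", output_key="notes"),
            # Per-agent prompt template; {notes} is filled from shared state.
            (StubAgent("writer", output_key="draft"), "Summarise: {notes}"),
        ],
        state={"topic": "solar panels"},
    )
    result = group.run("Collect facts about {topic}")  # {topic} injected from state
    print(result.success, result.shared_state["draft"])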
# ------------------------------------------------------------------
# LoopGroup
# ------------------------------------------------------------------


class LoopGroup:
    """Repeat a sequence of agents until an exit condition is met.

    Args:
        agents: List of agents or ``(agent, prompt_template)`` tuples.
        exit_condition: Callable ``(state, iteration) -> bool``.
            When it returns ``True`` the loop stops.
        max_iterations: Hard cap on loop iterations.
        state: Initial shared state dict.
        error_policy: How to handle agent failures.
        callbacks: Observability hooks.
    """

    def __init__(
        self,
        agents: list[Any],
        *,
        exit_condition: Callable[[dict[str, Any], int], bool],
        max_iterations: int = 10,
        state: dict[str, Any] | None = None,
        error_policy: ErrorPolicy = ErrorPolicy.fail_fast,
        callbacks: GroupCallbacks | None = None,
    ) -> None:
        self._agents = _normalise_agents(agents)
        self._exit_condition = exit_condition
        self._max_iterations = max_iterations
        self._state: dict[str, Any] = dict(state) if state else {}
        self._error_policy = error_policy
        self._callbacks = callbacks or GroupCallbacks()
        self._stop_requested = False

    def stop(self) -> None:
        """Request graceful shutdown."""
        self._stop_requested = True

    def run(self, prompt: str = "") -> GroupResult:
        """Execute the loop."""
        self._stop_requested = False
        t0 = time.perf_counter()
        timeline: list[GroupStep] = []
        agent_results: dict[str, Any] = {}
        errors: list[AgentError] = []
        usage_summaries: list[dict[str, Any]] = []

        for iteration in range(self._max_iterations):
            if self._stop_requested:
                break
            if self._exit_condition(self._state, iteration):
                break

            for idx, (agent, custom_prompt) in enumerate(self._agents):
                if self._stop_requested:
                    break

                name = _agent_name(agent, idx)
                result_key = f"{name}_iter{iteration}"

                if custom_prompt is not None:
                    effective = _inject_state(custom_prompt, self._state)
                elif prompt:
                    effective = _inject_state(prompt, self._state)
                else:
                    effective = ""

                if self._callbacks.on_agent_start:
                    self._callbacks.on_agent_start(name, effective)

                step_t0 = time.perf_counter()
                try:
                    result = agent.run(effective)
                    duration_ms = (time.perf_counter() - step_t0) * 1000

                    agent_results[result_key] = result
                    usage = getattr(result, "run_usage", {})
                    usage_summaries.append(usage)

                    output_key = getattr(agent, "output_key", None)
                    if output_key:
                        self._state[output_key] = result.output_text
                        if self._callbacks.on_state_update:
                            self._callbacks.on_state_update(output_key, result.output_text)

                    timeline.append(
                        GroupStep(
                            agent_name=name,
                            step_type="agent_run",
                            timestamp=step_t0,
                            duration_ms=duration_ms,
                            usage_delta=usage,
                        )
                    )

                    if self._callbacks.on_agent_complete:
                        self._callbacks.on_agent_complete(name, result)

                except Exception as exc:
                    duration_ms = (time.perf_counter() - step_t0) * 1000
                    err = AgentError(
                        agent_name=name,
                        error=exc,
                        output_key=getattr(agent, "output_key", None),
                    )
                    errors.append(err)
                    timeline.append(
                        GroupStep(
                            agent_name=name,
                            step_type="agent_error",
                            timestamp=step_t0,
                            duration_ms=duration_ms,
                            error=str(exc),
                        )
                    )

                    if self._callbacks.on_agent_error:
                        self._callbacks.on_agent_error(name, exc)

                    if self._error_policy == ErrorPolicy.fail_fast:
                        break

            # Check if error caused early exit
            if errors and self._error_policy == ErrorPolicy.fail_fast:
                break

        elapsed_ms = (time.perf_counter() - t0) * 1000
        return GroupResult(
            agent_results=agent_results,
            aggregate_usage=_aggregate_usage(*usage_summaries),
            shared_state=dict(self._state),
            elapsed_ms=elapsed_ms,
            timeline=timeline,
            errors=errors,
            success=len(errors) == 0,
        )
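Usage sketch (illustrative): a bounded draft-and-review loop. The exit condition inspects shared state at the top of each iteration, and max_iterations is the safety cap. StubAgent is the hypothetical helper from the previous sketch.

    from prompture.groups import LoopGroup

    drafter = StubAgent("drafter", output_key="draft")
    critic = StubAgent("critic", output_key="review")

    loop = LoopGroup(
        [
            (drafter, "Improve the draft. Reviewer said: {review}"),
            (critic, "Review this draft: {draft}"),
        ],
        # Stop once the critic's latest output contains APPROVED.
        exit_condition=lambda state, iteration: "APPROVED" in state.get("review", ""),
        max_iterations=5,
        state={"draft": "", "review": ""},
    )
    result = loop.run()
    print(len(result.timeline), result.aggregate_usage)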
# ------------------------------------------------------------------
# RouterAgent
# ------------------------------------------------------------------

_DEFAULT_ROUTING_PROMPT = """Given these specialists:
{agent_list}

Which should handle this? Reply with ONLY the name.

Request: {prompt}"""


class RouterAgent:
    """LLM-driven router that delegates to the best-matching agent.

    Args:
        model: Model string for the routing LLM call.
        agents: List of agents to route between.
        routing_prompt: Custom prompt template (must include ``{agent_list}``
            and ``{prompt}`` placeholders).
        fallback: Agent to use when routing fails.
        driver: Pre-built driver instance for the routing call.
    """

    def __init__(
        self,
        model: str = "",
        *,
        agents: list[Any],
        routing_prompt: str | None = None,
        fallback: Any | None = None,
        driver: Any | None = None,
        name: str = "",
        description: str = "",
        output_key: str | None = None,
    ) -> None:
        self._model = model
        self._driver = driver
        self._agents = {_agent_name(a, i): a for i, a in enumerate(agents)}
        self._routing_prompt = routing_prompt or _DEFAULT_ROUTING_PROMPT
        self._fallback = fallback
        self.name = name
        self.description = description
        self.output_key = output_key

    def run(self, prompt: str, *, deps: Any = None) -> AgentResult:
        """Route the prompt to the best agent and return its result."""
        from .conversation import Conversation

        # Build agent list for routing prompt
        agent_lines = []
        for name, agent in self._agents.items():
            desc = getattr(agent, "description", "") or ""
            agent_lines.append(f"- {name}: {desc}" if desc else f"- {name}")
        agent_list = "\n".join(agent_lines)

        routing_text = self._routing_prompt.replace("{agent_list}", agent_list).replace("{prompt}", prompt)

        # Single LLM call for routing
        kwargs: dict[str, Any] = {}
        if self._driver is not None:
            kwargs["driver"] = self._driver
        else:
            kwargs["model_name"] = self._model

        conv = Conversation(**kwargs)
        route_response = conv.ask(routing_text)

        # Fuzzy match against known agent names
        selected = self._fuzzy_match(route_response.strip())

        if selected is not None:
            return selected.run(prompt, deps=deps) if deps is not None else selected.run(prompt)
        elif self._fallback is not None:
            return self._fallback.run(prompt, deps=deps) if deps is not None else self._fallback.run(prompt)
        else:
            # Return routing response as fallback
            return AgentResult(
                output=route_response,
                output_text=route_response,
                messages=conv.messages,
                usage=conv.usage,
                state=AgentState.idle,
            )

    def _fuzzy_match(self, response: str) -> Any | None:
        """Find the best matching agent name in the LLM response."""
        response_lower = response.lower().strip()

        # Exact match
        for name, agent in self._agents.items():
            if name.lower() == response_lower:
                return agent

        # Substring match
        for name, agent in self._agents.items():
            if name.lower() in response_lower:
                return agent

        # Word-level match
        response_words = set(response_lower.split())
        for name, agent in self._agents.items():
            name_words = set(name.lower().replace("_", " ").split())
            if name_words & response_words:
                return agent

        return None
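Usage sketch (illustrative): routing between two specialists. The model string is a placeholder resolved by prompture's Conversation, and the specialists reuse the hypothetical StubAgent interface; their name and description are what the routing prompt shows to the LLM.

    from prompture.groups import RouterAgent

    billing = StubAgent("billing", output_key="answer")
    billing.description = "Invoices, refunds, payment issues"
    support = StubAgent("tech_support", output_key="answer")
    support.description = "Bugs, outages, troubleshooting"

    router = RouterAgent(
        "your-model-name",            # placeholder model string for the routing call
        agents=[billing, support],
        fallback=support,             # used when no agent name is matched
    )
    result = router.run("My March invoice total looks wrong.")
    print(result.output_text)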
# ------------------------------------------------------------------
# GroupAsAgent
# ------------------------------------------------------------------


class GroupAsAgent:
    """Adapter that makes a group behave like an Agent for composability.

    Allows nesting groups inside other groups by presenting the same
    ``run(prompt) -> AgentResult`` interface.

    Args:
        group: The group to wrap (SequentialGroup, LoopGroup, etc.).
        name: Agent identity name.
        output_key: Shared state key for writing output.
    """

    def __init__(
        self,
        group: Any,
        *,
        name: str = "",
        output_key: str | None = None,
    ) -> None:
        self._group = group
        self.name = name
        self.output_key = output_key
        self.description = ""

    def run(self, prompt: str, **kwargs: Any) -> AgentResult:
        """Run the wrapped group and return an AgentResult."""
        group_result = self._group.run(prompt)

        # Use the last agent's output text, or the shared state
        output_text = ""
        if group_result.agent_results:
            last_result = list(group_result.agent_results.values())[-1]
            output_text = getattr(last_result, "output_text", str(last_result))

        return AgentResult(
            output=output_text,
            output_text=output_text,
            messages=[],
            usage=group_result.aggregate_usage,
            state=AgentState.idle,
            run_usage=group_result.aggregate_usage,
        )

    def stop(self) -> None:
        """Propagate stop to the wrapped group."""
        if hasattr(self._group, "stop"):
            self._group.stop()
prompture/image.py
ADDED
@@ -0,0 +1,180 @@
"""Image handling utilities for vision-capable LLM drivers."""

from __future__ import annotations

import base64
import mimetypes
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Union


@dataclass(frozen=True)
class ImageContent:
    """Normalized image representation for vision-capable drivers.

    Attributes:
        data: Base64-encoded image data.
        media_type: MIME type (e.g. ``"image/png"``, ``"image/jpeg"``).
        source_type: How the image is delivered — ``"base64"`` or ``"url"``.
        url: Original URL when ``source_type`` is ``"url"``.
    """

    data: str
    media_type: str
    source_type: str = "base64"
    url: str | None = None


# Public type alias accepted by all image-aware APIs.
ImageInput = Union[bytes, str, Path, ImageContent]

# Known data-URI prefix pattern
_DATA_URI_RE = re.compile(r"^data:(image/[a-zA-Z0-9.+-]+);base64,(.+)$", re.DOTALL)

# Base64 detection heuristic — must look like pure base64 of reasonable length
_BASE64_RE = re.compile(r"^[A-Za-z0-9+/\n\r]+=*$")

_MIME_FROM_EXT: dict[str, str] = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
    ".gif": "image/gif",
    ".webp": "image/webp",
    ".bmp": "image/bmp",
    ".svg": "image/svg+xml",
    ".tiff": "image/tiff",
    ".tif": "image/tiff",
}

_MAGIC_BYTES: list[tuple[bytes, str]] = [
    (b"\x89PNG", "image/png"),
    (b"\xff\xd8\xff", "image/jpeg"),
    (b"GIF87a", "image/gif"),
    (b"GIF89a", "image/gif"),
    (b"RIFF", "image/webp"),  # WebP starts with RIFF...WEBP
    (b"BM", "image/bmp"),
]


def _guess_media_type_from_bytes(data: bytes) -> str:
    """Guess MIME type from the first few bytes of image data."""
    for magic, mime in _MAGIC_BYTES:
        if data[: len(magic)] == magic:
            return mime
    return "image/png"  # safe fallback


def _guess_media_type(path: str) -> str:
    """Guess MIME type from a file path or URL."""
    # Strip query strings for URLs
    clean = path.split("?")[0].split("#")[0]
    ext = Path(clean).suffix.lower()
    if ext in _MIME_FROM_EXT:
        return _MIME_FROM_EXT[ext]
    guessed = mimetypes.guess_type(clean)[0]
    return guessed or "image/png"


# ------------------------------------------------------------------
# Constructor functions
# ------------------------------------------------------------------


def image_from_bytes(data: bytes, media_type: str | None = None) -> ImageContent:
    """Create an :class:`ImageContent` from raw bytes.

    Args:
        data: Raw image bytes.
        media_type: MIME type. Auto-detected from magic bytes when *None*.
    """
    if not data:
        raise ValueError("Image data cannot be empty")
    b64 = base64.b64encode(data).decode("ascii")
    mt = media_type or _guess_media_type_from_bytes(data)
    return ImageContent(data=b64, media_type=mt)


def image_from_base64(b64: str, media_type: str = "image/png") -> ImageContent:
    """Create an :class:`ImageContent` from a base64-encoded string.

    Accepts both raw base64 and ``data:`` URIs.
    """
    m = _DATA_URI_RE.match(b64)
    if m:
        return ImageContent(data=m.group(2), media_type=m.group(1))
    return ImageContent(data=b64, media_type=media_type)


def image_from_file(path: str | Path, media_type: str | None = None) -> ImageContent:
    """Create an :class:`ImageContent` by reading a local file.

    Args:
        path: Path to an image file.
        media_type: MIME type. Guessed from extension when *None*.
    """
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"Image file not found: {p}")
    raw = p.read_bytes()
    mt = media_type or _guess_media_type(str(p))
    return image_from_bytes(raw, mt)


def image_from_url(url: str, media_type: str | None = None) -> ImageContent:
    """Create an :class:`ImageContent` referencing a remote URL.

    The image is **not** downloaded — the URL is stored directly so
    drivers that accept URL-based images can pass it through. For
    drivers that require base64, the URL is embedded as a data URI.

    Args:
        url: Publicly-accessible image URL.
        media_type: MIME type. Guessed from the URL when *None*.
    """
    mt = media_type or _guess_media_type(url)
    return ImageContent(data="", media_type=mt, source_type="url", url=url)


# ------------------------------------------------------------------
# Smart constructor
# ------------------------------------------------------------------


def make_image(source: ImageInput) -> ImageContent:
    """Auto-detect the source type and return an :class:`ImageContent`.

    Accepts:
    - ``ImageContent`` — returned as-is.
    - ``bytes`` — base64-encoded with auto-detected MIME.
    - ``str`` — tries (in order): data URI, URL, file path, raw base64.
    - ``pathlib.Path`` — read from disk.
    """
    if isinstance(source, ImageContent):
        return source

    if isinstance(source, bytes):
        return image_from_bytes(source)

    if isinstance(source, Path):
        return image_from_file(source)

    if isinstance(source, str):
        # 1. data URI
        if source.startswith("data:"):
            return image_from_base64(source)

        # 2. URL
        if source.startswith(("http://", "https://")):
            return image_from_url(source)

        # 3. File path (if exists on disk)
        p = Path(source)
        if p.exists():
            return image_from_file(p)

        # 4. Assume raw base64
        return image_from_base64(source)

    raise TypeError(f"Unsupported image source type: {type(source).__name__}")
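Usage sketch (illustrative): every supported source flavour normalises to an ImageContent via make_image; the file path below is a placeholder assumed to exist and to be a real PNG.

    from pathlib import Path

    from prompture.image import ImageContent, make_image

    from_path = make_image(Path("chart.png"))                    # read from disk, MIME from extension
    from_bytes = make_image(Path("chart.png").read_bytes())      # MIME sniffed from magic bytes
    from_url = make_image("https://example.com/photo.jpg")       # kept as a URL reference
    from_data_uri = make_image("data:image/png;base64,iVBORw0KGgo=")

    assert isinstance(from_path, ImageContent)
    assert from_url.source_type == "url"
    assert from_bytes.media_type == "image/png"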