prompture 0.0.35__py3-none-any.whl → 0.0.38.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. prompture/__init__.py +120 -2
  2. prompture/_version.py +2 -2
  3. prompture/agent.py +924 -0
  4. prompture/agent_types.py +156 -0
  5. prompture/async_agent.py +880 -0
  6. prompture/async_conversation.py +199 -17
  7. prompture/async_driver.py +24 -0
  8. prompture/async_groups.py +551 -0
  9. prompture/conversation.py +213 -18
  10. prompture/core.py +30 -12
  11. prompture/discovery.py +24 -1
  12. prompture/driver.py +38 -0
  13. prompture/drivers/__init__.py +5 -1
  14. prompture/drivers/async_azure_driver.py +7 -1
  15. prompture/drivers/async_claude_driver.py +7 -1
  16. prompture/drivers/async_google_driver.py +212 -28
  17. prompture/drivers/async_grok_driver.py +7 -1
  18. prompture/drivers/async_groq_driver.py +7 -1
  19. prompture/drivers/async_lmstudio_driver.py +74 -5
  20. prompture/drivers/async_ollama_driver.py +13 -3
  21. prompture/drivers/async_openai_driver.py +7 -1
  22. prompture/drivers/async_openrouter_driver.py +7 -1
  23. prompture/drivers/async_registry.py +5 -1
  24. prompture/drivers/azure_driver.py +7 -1
  25. prompture/drivers/claude_driver.py +7 -1
  26. prompture/drivers/google_driver.py +217 -33
  27. prompture/drivers/grok_driver.py +7 -1
  28. prompture/drivers/groq_driver.py +7 -1
  29. prompture/drivers/lmstudio_driver.py +73 -8
  30. prompture/drivers/ollama_driver.py +16 -5
  31. prompture/drivers/openai_driver.py +7 -1
  32. prompture/drivers/openrouter_driver.py +7 -1
  33. prompture/drivers/vision_helpers.py +153 -0
  34. prompture/group_types.py +147 -0
  35. prompture/groups.py +530 -0
  36. prompture/image.py +180 -0
  37. prompture/persistence.py +254 -0
  38. prompture/persona.py +482 -0
  39. prompture/serialization.py +218 -0
  40. prompture/settings.py +1 -0
  41. prompture-0.0.38.dev2.dist-info/METADATA +369 -0
  42. prompture-0.0.38.dev2.dist-info/RECORD +77 -0
  43. prompture-0.0.35.dist-info/METADATA +0 -464
  44. prompture-0.0.35.dist-info/RECORD +0 -66
  45. {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +0 -0
  46. {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
  47. {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
  48. {prompture-0.0.35.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,153 @@
+ """Shared helpers for converting universal vision message blocks to provider-specific formats."""
+
+ from __future__ import annotations
+
+ from typing import Any
+
+
+ def _prepare_openai_vision_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+     """Convert universal image blocks to OpenAI-compatible vision format.
+
+     Works for OpenAI, Azure, Groq, Grok, LM Studio, and OpenRouter.
+
+     Universal format::
+
+         {"type": "image", "source": ImageContent(...)}
+
+     OpenAI format::
+
+         {"type": "image_url", "image_url": {"url": "data:mime;base64,..."}}
+     """
+     out: list[dict[str, Any]] = []
+     for msg in messages:
+         content = msg.get("content")
+         if not isinstance(content, list):
+             out.append(msg)
+             continue
+         new_blocks: list[dict[str, Any]] = []
+         for block in content:
+             if isinstance(block, dict) and block.get("type") == "image":
+                 source = block["source"]
+                 if source.source_type == "url" and source.url:
+                     url = source.url
+                 else:
+                     url = f"data:{source.media_type};base64,{source.data}"
+                 new_blocks.append(
+                     {
+                         "type": "image_url",
+                         "image_url": {"url": url},
+                     }
+                 )
+             else:
+                 new_blocks.append(block)
+         out.append({**msg, "content": new_blocks})
+     return out
+
+
+ def _prepare_claude_vision_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+     """Convert universal image blocks to Anthropic Claude format.
+
+     Claude format::
+
+         {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
+     """
+     out: list[dict[str, Any]] = []
+     for msg in messages:
+         content = msg.get("content")
+         if not isinstance(content, list):
+             out.append(msg)
+             continue
+         new_blocks: list[dict[str, Any]] = []
+         for block in content:
+             if isinstance(block, dict) and block.get("type") == "image":
+                 source = block["source"]
+                 if source.source_type == "url" and source.url:
+                     new_blocks.append(
+                         {
+                             "type": "image",
+                             "source": {
+                                 "type": "url",
+                                 "url": source.url,
+                             },
+                         }
+                     )
+                 else:
+                     new_blocks.append(
+                         {
+                             "type": "image",
+                             "source": {
+                                 "type": "base64",
+                                 "media_type": source.media_type,
+                                 "data": source.data,
+                             },
+                         }
+                     )
+             else:
+                 new_blocks.append(block)
+         out.append({**msg, "content": new_blocks})
+     return out
+
+
+ def _prepare_google_vision_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+     """Convert universal image blocks to Google Gemini format.
+
+     Gemini expects ``parts`` arrays containing text and inline_data dicts::
+
+         {"role": "user", "parts": [
+             "text prompt",
+             {"inline_data": {"mime_type": "image/png", "data": "base64..."}},
+         ]}
+     """
+     out: list[dict[str, Any]] = []
+     for msg in messages:
+         content = msg.get("content")
+         if not isinstance(content, list):
+             out.append(msg)
+             continue
+         # Convert content blocks to Gemini parts
+         parts: list[Any] = []
+         for block in content:
+             if isinstance(block, dict) and block.get("type") == "text":
+                 parts.append(block["text"])
+             elif isinstance(block, dict) and block.get("type") == "image":
+                 source = block["source"]
+                 parts.append(
+                     {
+                         "inline_data": {
+                             "mime_type": source.media_type,
+                             "data": source.data,
+                         }
+                     }
+                 )
+             else:
+                 parts.append(block)
+         out.append({**msg, "content": parts, "_vision_parts": True})
+     return out
+
+
+ def _prepare_ollama_vision_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+     """Convert universal image blocks to Ollama format.
+
+     Ollama expects images as a separate field::
+
+         {"role": "user", "content": "text", "images": ["base64..."]}
+     """
+     out: list[dict[str, Any]] = []
+     for msg in messages:
+         content = msg.get("content")
+         if not isinstance(content, list):
+             out.append(msg)
+             continue
+         text_parts: list[str] = []
+         images: list[str] = []
+         for block in content:
+             if isinstance(block, dict) and block.get("type") == "text":
+                 text_parts.append(block["text"])
+             elif isinstance(block, dict) and block.get("type") == "image":
+                 source = block["source"]
+                 images.append(source.data)
+         new_msg = {**msg, "content": " ".join(text_parts)}
+         if images:
+             new_msg["images"] = images
+         out.append(new_msg)
+     return out
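
For orientation, here is a minimal usage sketch of the OpenAI helper above. It assumes the new module is importable as prompture.drivers.vision_helpers (the path from the file list); ImageStub is a hypothetical stand-in that only carries the attributes the helper actually reads (source_type, media_type, data, url), not prompture's real ImageContent class from prompture/image.py.

from dataclasses import dataclass

from prompture.drivers.vision_helpers import _prepare_openai_vision_messages


@dataclass
class ImageStub:
    """Hypothetical stand-in for prompture's ImageContent."""
    source_type: str              # "base64" or "url"
    media_type: str = "image/png"
    data: str = ""                # base64 payload when source_type == "base64"
    url: str = ""                 # remote URL when source_type == "url"


messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image", "source": ImageStub("base64", data="iVBORw0KGgo...")},
        ],
    }
]

# The image block is rewritten to
# {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}};
# text blocks pass through unchanged.
openai_messages = _prepare_openai_vision_messages(messages)
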
@@ -0,0 +1,147 @@
+ """Shared types for multi-agent group coordination.
+
+ Defines enums, dataclasses, and callbacks used by
+ :class:`~prompture.groups.SequentialGroup`,
+ :class:`~prompture.async_groups.ParallelGroup`, and related classes.
+ """
+
+ from __future__ import annotations
+
+ import enum
+ import json
+ import time
+ from collections.abc import Callable
+ from dataclasses import dataclass, field
+ from typing import Any
+
+
+ class ErrorPolicy(enum.Enum):
+     """How a group handles agent failures."""
+
+     fail_fast = "fail_fast"
+     continue_on_error = "continue_on_error"
+     retry_failed = "retry_failed"
+
+
+ @dataclass
+ class GroupStep:
+     """Record of a single agent execution within a group run."""
+
+     agent_name: str
+     step_type: str = "agent_run"
+     timestamp: float = 0.0
+     duration_ms: float = 0.0
+     usage_delta: dict[str, Any] = field(default_factory=dict)
+     error: str | None = None
+
+
+ @dataclass
+ class AgentError:
+     """Captures a failed agent execution within a group."""
+
+     agent_name: str
+     error: Exception
+     error_message: str = ""
+     output_key: str | None = None
+
+     def __post_init__(self) -> None:
+         if not self.error_message:
+             self.error_message = str(self.error)
+
+
+ @dataclass
+ class GroupResult:
+     """Outcome of a group execution.
+
+     Attributes:
+         agent_results: Mapping of agent name/key to their :class:`AgentResult`.
+         aggregate_usage: Combined token/cost totals across all agent runs.
+         shared_state: Final state dict after all agents have written outputs.
+         elapsed_ms: Wall-clock duration of the group run.
+         timeline: Ordered list of :class:`GroupStep` records.
+         errors: List of :class:`AgentError` for any failed agents.
+         success: ``True`` if no errors occurred.
+     """
+
+     agent_results: dict[str, Any] = field(default_factory=dict)
+     aggregate_usage: dict[str, Any] = field(default_factory=dict)
+     shared_state: dict[str, Any] = field(default_factory=dict)
+     elapsed_ms: float = 0.0
+     timeline: list[GroupStep] = field(default_factory=list)
+     errors: list[AgentError] = field(default_factory=list)
+     success: bool = True
+
+     def export(self) -> dict[str, Any]:
+         """Return a JSON-serializable dict representation."""
+         return {
+             "agent_results": {
+                 k: {
+                     "output_text": getattr(v, "output_text", str(v)),
+                     "usage": getattr(v, "run_usage", {}),
+                 }
+                 for k, v in self.agent_results.items()
+             },
+             "aggregate_usage": self.aggregate_usage,
+             "shared_state": self.shared_state,
+             "elapsed_ms": self.elapsed_ms,
+             "timeline": [
+                 {
+                     "agent_name": s.agent_name,
+                     "step_type": s.step_type,
+                     "timestamp": s.timestamp,
+                     "duration_ms": s.duration_ms,
+                     "usage_delta": s.usage_delta,
+                     "error": s.error,
+                 }
+                 for s in self.timeline
+             ],
+             "errors": [
+                 {
+                     "agent_name": e.agent_name,
+                     "error_message": e.error_message,
+                     "output_key": e.output_key,
+                 }
+                 for e in self.errors
+             ],
+             "success": self.success,
+         }
+
+     def save(self, path: str) -> None:
+         """Write the exported dict to a JSON file."""
+         with open(path, "w", encoding="utf-8") as f:
+             json.dump(self.export(), f, indent=2, default=str)
+
+
+ @dataclass
+ class GroupCallbacks:
+     """Observability callbacks for group execution."""
+
+     on_agent_start: Callable[[str, str], None] | None = None
+     on_agent_complete: Callable[[str, Any], None] | None = None
+     on_agent_error: Callable[[str, Exception], None] | None = None
+     on_state_update: Callable[[str, Any], None] | None = None
+
+
+ def _aggregate_usage(*sessions: dict[str, Any]) -> dict[str, Any]:
+     """Merge multiple usage summary dicts into one aggregate."""
+     agg: dict[str, Any] = {
+         "prompt_tokens": 0,
+         "completion_tokens": 0,
+         "total_tokens": 0,
+         "total_cost": 0.0,
+         "call_count": 0,
+         "errors": 0,
+     }
+     for s in sessions:
+         agg["prompt_tokens"] += s.get("prompt_tokens", 0)
+         agg["completion_tokens"] += s.get("completion_tokens", 0)
+         agg["total_tokens"] += s.get("total_tokens", 0)
+         agg["total_cost"] += s.get("total_cost", 0.0)
+         agg["call_count"] += s.get("call_count", 0)
+         agg["errors"] += s.get("errors", 0)
+     return agg
+
+
+ def _now_ms() -> float:
+     """Current time in milliseconds (perf_counter-based)."""
+     return time.perf_counter() * 1000
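
A brief, hedged sketch of how the group result types above might be exercised, assuming the module is importable as prompture.group_types (the path from the file list); the usage figures and the "researcher" agent name are made up for illustration.

from prompture.group_types import GroupResult, GroupStep, _aggregate_usage

# Merge two per-agent usage summaries; keys absent from a summary default to zero.
usage = _aggregate_usage(
    {"prompt_tokens": 120, "completion_tokens": 40, "total_tokens": 160,
     "total_cost": 0.0021, "call_count": 1},
    {"prompt_tokens": 300, "completion_tokens": 90, "total_tokens": 390,
     "total_cost": 0.0057, "call_count": 2, "errors": 1},
)

# Assemble a result by hand and persist it; save() writes export() as indented JSON.
result = GroupResult(
    aggregate_usage=usage,
    shared_state={"summary": "done"},
    elapsed_ms=850.0,
    timeline=[GroupStep(agent_name="researcher", duration_ms=850.0)],
)
result.save("group_run.json")
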