prompture 0.0.36.dev1__py3-none-any.whl → 0.0.37.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. prompture/__init__.py +120 -2
  2. prompture/_version.py +2 -2
  3. prompture/agent.py +925 -0
  4. prompture/agent_types.py +156 -0
  5. prompture/async_agent.py +879 -0
  6. prompture/async_conversation.py +199 -17
  7. prompture/async_driver.py +24 -0
  8. prompture/async_groups.py +551 -0
  9. prompture/conversation.py +213 -18
  10. prompture/core.py +30 -12
  11. prompture/discovery.py +24 -1
  12. prompture/driver.py +38 -0
  13. prompture/drivers/__init__.py +5 -1
  14. prompture/drivers/async_azure_driver.py +7 -1
  15. prompture/drivers/async_claude_driver.py +7 -1
  16. prompture/drivers/async_google_driver.py +24 -4
  17. prompture/drivers/async_grok_driver.py +7 -1
  18. prompture/drivers/async_groq_driver.py +7 -1
  19. prompture/drivers/async_lmstudio_driver.py +59 -3
  20. prompture/drivers/async_ollama_driver.py +7 -0
  21. prompture/drivers/async_openai_driver.py +7 -1
  22. prompture/drivers/async_openrouter_driver.py +7 -1
  23. prompture/drivers/async_registry.py +5 -1
  24. prompture/drivers/azure_driver.py +7 -1
  25. prompture/drivers/claude_driver.py +7 -1
  26. prompture/drivers/google_driver.py +24 -4
  27. prompture/drivers/grok_driver.py +7 -1
  28. prompture/drivers/groq_driver.py +7 -1
  29. prompture/drivers/lmstudio_driver.py +58 -6
  30. prompture/drivers/ollama_driver.py +7 -0
  31. prompture/drivers/openai_driver.py +7 -1
  32. prompture/drivers/openrouter_driver.py +7 -1
  33. prompture/drivers/vision_helpers.py +153 -0
  34. prompture/group_types.py +147 -0
  35. prompture/groups.py +530 -0
  36. prompture/image.py +180 -0
  37. prompture/persistence.py +254 -0
  38. prompture/persona.py +482 -0
  39. prompture/serialization.py +218 -0
  40. prompture/settings.py +1 -0
  41. {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/METADATA +1 -1
  42. prompture-0.0.37.dev1.dist-info/RECORD +77 -0
  43. prompture-0.0.36.dev1.dist-info/RECORD +0 -66
  44. {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/WHEEL +0 -0
  45. {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/entry_points.txt +0 -0
  46. {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/licenses/LICENSE +0 -0
  47. {prompture-0.0.36.dev1.dist-info → prompture-0.0.37.dev1.dist-info}/top_level.txt +0 -0
prompture/groups.py ADDED
@@ -0,0 +1,530 @@
1
+ """Synchronous multi-agent group coordination.
2
+
3
+ Provides :class:`SequentialGroup`, :class:`LoopGroup`,
4
+ :class:`RouterAgent`, and :class:`GroupAsAgent` for composing
5
+ multiple agents into deterministic workflows.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import logging
11
+ import re
12
+ import time
13
+ from collections.abc import Callable
14
+ from typing import Any
15
+
16
+ from .agent_types import AgentResult, AgentState
17
+ from .group_types import (
18
+ AgentError,
19
+ ErrorPolicy,
20
+ GroupCallbacks,
21
+ GroupResult,
22
+ GroupStep,
23
+ _aggregate_usage,
24
+ )
25
+
26
# Module logger; used for debug messages when a group stops early.
logger = logging.getLogger(name="prompture.groups")
27
+
28
+
29
+ # ------------------------------------------------------------------
30
+ # State injection helper
31
+ # ------------------------------------------------------------------
32
+
33
+
34
def _inject_state(template: str, state: dict[str, Any]) -> str:
    """Fill ``{key}`` placeholders in *template* from *state*.

    A placeholder is ``{`` followed by one or more word characters and a
    closing ``}``. Each placeholder whose key is present in *state* is
    replaced by ``str(state[key])``. Keys absent from *state* are left
    exactly as written, so downstream agents still see the literal
    placeholder.
    """
    placeholder = re.compile(r"\{(\w+)\}")

    def _substitute(match: re.Match[str]) -> str:
        key = match.group(1)
        return str(state[key]) if key in state else match.group(0)

    return placeholder.sub(_substitute, template)
48
+
49
+
50
+ # ------------------------------------------------------------------
51
+ # Agent entry normalisation
52
+ # ------------------------------------------------------------------
53
+
54
AgentEntry = Any  # Agent | tuple[Agent, str]


def _normalise_agents(agents: list[Any]) -> list[tuple[Any, str | None]]:
    """Turn a mixed list of agents into uniform ``(agent, template)`` pairs.

    A tuple entry contributes its first two items as the agent and its
    prompt template. Any other entry is paired with ``None`` (no custom
    template).
    """
    return [
        (entry[0], entry[1]) if isinstance(entry, tuple) else (entry, None)
        for entry in agents
    ]
66
+
67
+
68
def _agent_name(agent: Any, index: int) -> str:
    """Return *agent*'s ``name`` attribute, or ``agent_<index>`` if it is missing or falsy."""
    return getattr(agent, "name", None) or f"agent_{index}"
72
+
73
+
74
+ # ------------------------------------------------------------------
75
+ # SequentialGroup
76
+ # ------------------------------------------------------------------
77
+
78
+
79
class SequentialGroup:
    """Execute agents in sequence, passing state between them.

    Each agent's ``output_key`` (if set) writes its output text
    into the shared state dict, making it available as ``{key}``
    in subsequent agent prompts.

    Args:
        agents: List of agents or ``(agent, prompt_template)`` tuples.
        state: Initial shared state dict. It is copied, so the caller's dict
            is not mutated.
        error_policy: How to handle agent failures. ``ErrorPolicy.fail_fast``
            stops at the first failing agent. Any other policy records the
            error and moves on to the next agent; no retry is performed here.
        max_total_turns: Limit on total agent runs across the sequence. Both
            successful and failed runs count.
        callbacks: Observability hooks.
        max_total_cost: Budget cap in USD, summed from each result's
            ``run_usage["total_cost"]``. It is checked before each agent
            starts, so the agent that crosses the cap still runs to completion.
    """

    def __init__(
        self,
        agents: list[Any],
        *,
        state: dict[str, Any] | None = None,
        error_policy: ErrorPolicy = ErrorPolicy.fail_fast,
        max_total_turns: int | None = None,
        callbacks: GroupCallbacks | None = None,
        max_total_cost: float | None = None,
    ) -> None:
        self._agents = _normalise_agents(agents)
        self._state: dict[str, Any] = dict(state) if state else {}
        self._error_policy = error_policy
        self._max_total_turns = max_total_turns
        self._callbacks = callbacks or GroupCallbacks()
        self._max_total_cost = max_total_cost
        self._stop_requested = False

    def stop(self) -> None:
        """Request graceful shutdown after the current agent finishes.

        This only affects a run already in progress: :meth:`run` clears the
        flag when it starts.
        """
        self._stop_requested = True

    def save(self, path: str, prompt: str = "") -> None:
        """Run the group with *prompt* and save the resulting ``GroupResult`` to *path*."""
        result = self.run(prompt)
        result.save(path)

    def run(self, prompt: str = "") -> GroupResult:
        """Execute all agents in order.

        Args:
            prompt: Fallback prompt template for agents registered without
                their own template. ``{key}`` placeholders are filled from the
                shared state at the moment each agent starts.

        Returns:
            A ``GroupResult`` whose ``success`` is ``True`` only when no
            agent's ``run`` raised.

        Exceptions raised by ``agent.run`` are recorded as ``AgentError``
        entries. Exceptions raised by the callbacks themselves propagate to
        the caller.
        """
        self._stop_requested = False
        t0 = time.perf_counter()
        timeline: list[GroupStep] = []
        agent_results: dict[str, Any] = {}
        errors: list[AgentError] = []
        usage_summaries: list[dict[str, Any]] = []
        turns = 0

        for idx, (agent, custom_prompt) in enumerate(self._agents):
            if self._stop_requested:
                break

            # Budget and turn limits are enforced before the next agent starts.
            if self._max_total_cost is not None:
                total_so_far = sum(s.get("total_cost", 0.0) for s in usage_summaries)
                if total_so_far >= self._max_total_cost:
                    logger.debug("Budget exceeded, stopping group")
                    break
            if self._max_total_turns is not None and turns >= self._max_total_turns:
                logger.debug("Max total turns reached")
                break

            name = _agent_name(agent, idx)
            template = custom_prompt if custom_prompt is not None else prompt
            effective = _inject_state(template, self._state) if template else ""

            if self._callbacks.on_agent_start:
                self._callbacks.on_agent_start(name, effective)

            step_t0 = time.perf_counter()
            # Only the agent call is guarded. Previously a failing
            # on_state_update/on_agent_complete callback was recorded as a
            # failure of an agent that had in fact succeeded.
            try:
                result = agent.run(effective)
            except Exception as exc:
                duration_ms = (time.perf_counter() - step_t0) * 1000
                turns += 1
                errors.append(
                    AgentError(
                        agent_name=name,
                        error=exc,
                        output_key=getattr(agent, "output_key", None),
                    )
                )
                timeline.append(
                    GroupStep(
                        agent_name=name,
                        step_type="agent_error",
                        timestamp=step_t0,
                        duration_ms=duration_ms,
                        error=str(exc),
                    )
                )
                if self._callbacks.on_agent_error:
                    self._callbacks.on_agent_error(name, exc)
                if self._error_policy == ErrorPolicy.fail_fast:
                    break
                # continue_on_error / retry_failed: move on to the next agent
                continue

            duration_ms = (time.perf_counter() - step_t0) * 1000
            turns += 1

            agent_results[name] = result
            usage = getattr(result, "run_usage", {})
            usage_summaries.append(usage)

            # Publish the output for later agents' {output_key} placeholders.
            output_key = getattr(agent, "output_key", None)
            if output_key:
                self._state[output_key] = result.output_text
                if self._callbacks.on_state_update:
                    self._callbacks.on_state_update(output_key, result.output_text)

            timeline.append(
                GroupStep(
                    agent_name=name,
                    step_type="agent_run",
                    timestamp=step_t0,
                    duration_ms=duration_ms,
                    usage_delta=usage,
                )
            )

            if self._callbacks.on_agent_complete:
                self._callbacks.on_agent_complete(name, result)

        elapsed_ms = (time.perf_counter() - t0) * 1000
        return GroupResult(
            agent_results=agent_results,
            aggregate_usage=_aggregate_usage(*usage_summaries),
            shared_state=dict(self._state),
            elapsed_ms=elapsed_ms,
            timeline=timeline,
            errors=errors,
            success=len(errors) == 0,
        )
228
+
229
+
230
+ # ------------------------------------------------------------------
231
+ # LoopGroup
232
+ # ------------------------------------------------------------------
233
+
234
+
235
class LoopGroup:
    """Repeat a sequence of agents until an exit condition is met.

    Each run's result is stored in the returned ``GroupResult.agent_results``
    under ``"<agent name>_iter<iteration>"``, so every pass is kept.

    Args:
        agents: List of agents or ``(agent, prompt_template)`` tuples.
        exit_condition: Callable ``(state, iteration) -> bool``, evaluated
            before each iteration (including the first). When it returns
            ``True`` the loop stops.
        max_iterations: Hard cap on loop iterations.
        state: Initial shared state dict. It is copied, so the caller's dict
            is not mutated.
        error_policy: How to handle agent failures. ``ErrorPolicy.fail_fast``
            ends the whole loop at the first failure. Any other policy records
            the error and continues with the next agent.
        callbacks: Observability hooks.
    """

    def __init__(
        self,
        agents: list[Any],
        *,
        exit_condition: Callable[[dict[str, Any], int], bool],
        max_iterations: int = 10,
        state: dict[str, Any] | None = None,
        error_policy: ErrorPolicy = ErrorPolicy.fail_fast,
        callbacks: GroupCallbacks | None = None,
    ) -> None:
        self._agents = _normalise_agents(agents)
        self._exit_condition = exit_condition
        self._max_iterations = max_iterations
        self._state: dict[str, Any] = dict(state) if state else {}
        self._error_policy = error_policy
        self._callbacks = callbacks or GroupCallbacks()
        self._stop_requested = False

    def stop(self) -> None:
        """Request graceful shutdown.

        This only affects a run already in progress: :meth:`run` clears the
        flag when it starts.
        """
        self._stop_requested = True

    def save(self, path: str, prompt: str = "") -> None:
        """Run the loop with *prompt* and save the resulting ``GroupResult`` to *path*."""
        result = self.run(prompt)
        result.save(path)

    def run(self, prompt: str = "") -> GroupResult:
        """Execute the loop.

        Args:
            prompt: Fallback prompt template for agents registered without
                their own template. ``{key}`` placeholders are re-filled from
                the shared state on every pass.

        Returns:
            A ``GroupResult`` whose ``success`` is ``True`` only when no
            agent's ``run`` raised.

        Exceptions raised by ``agent.run`` are recorded as ``AgentError``
        entries. Exceptions raised by the callbacks or by ``exit_condition``
        propagate to the caller.
        """
        self._stop_requested = False
        t0 = time.perf_counter()
        timeline: list[GroupStep] = []
        agent_results: dict[str, Any] = {}
        errors: list[AgentError] = []
        usage_summaries: list[dict[str, Any]] = []

        for iteration in range(self._max_iterations):
            if self._stop_requested:
                break
            if self._exit_condition(self._state, iteration):
                break

            for idx, (agent, custom_prompt) in enumerate(self._agents):
                if self._stop_requested:
                    break

                name = _agent_name(agent, idx)
                result_key = f"{name}_iter{iteration}"

                template = custom_prompt if custom_prompt is not None else prompt
                effective = _inject_state(template, self._state) if template else ""

                if self._callbacks.on_agent_start:
                    self._callbacks.on_agent_start(name, effective)

                step_t0 = time.perf_counter()
                # Only the agent call is guarded, so a failing
                # on_state_update/on_agent_complete callback is not
                # misreported as a failure of an agent that succeeded.
                try:
                    result = agent.run(effective)
                except Exception as exc:
                    duration_ms = (time.perf_counter() - step_t0) * 1000
                    errors.append(
                        AgentError(
                            agent_name=name,
                            error=exc,
                            output_key=getattr(agent, "output_key", None),
                        )
                    )
                    timeline.append(
                        GroupStep(
                            agent_name=name,
                            step_type="agent_error",
                            timestamp=step_t0,
                            duration_ms=duration_ms,
                            error=str(exc),
                        )
                    )
                    if self._callbacks.on_agent_error:
                        self._callbacks.on_agent_error(name, exc)
                    if self._error_policy == ErrorPolicy.fail_fast:
                        break
                    continue

                duration_ms = (time.perf_counter() - step_t0) * 1000

                agent_results[result_key] = result
                usage = getattr(result, "run_usage", {})
                usage_summaries.append(usage)

                output_key = getattr(agent, "output_key", None)
                if output_key:
                    self._state[output_key] = result.output_text
                    if self._callbacks.on_state_update:
                        self._callbacks.on_state_update(output_key, result.output_text)

                timeline.append(
                    GroupStep(
                        agent_name=name,
                        step_type="agent_run",
                        timestamp=step_t0,
                        duration_ms=duration_ms,
                        usage_delta=usage,
                    )
                )

                if self._callbacks.on_agent_complete:
                    self._callbacks.on_agent_complete(name, result)

            # Under fail_fast an error ends the whole loop, not just this pass.
            if errors and self._error_policy == ErrorPolicy.fail_fast:
                break

        elapsed_ms = (time.perf_counter() - t0) * 1000
        return GroupResult(
            agent_results=agent_results,
            aggregate_usage=_aggregate_usage(*usage_summaries),
            shared_state=dict(self._state),
            elapsed_ms=elapsed_ms,
            timeline=timeline,
            errors=errors,
            success=len(errors) == 0,
        )
368
+
369
+
370
+ # ------------------------------------------------------------------
371
+ # RouterAgent
372
+ # ------------------------------------------------------------------
373
+
374
_DEFAULT_ROUTING_PROMPT = """Given these specialists:
{agent_list}

Which should handle this? Reply with ONLY the name.

Request: {prompt}"""

# The two placeholders a routing template may contain. Substitution happens
# in a single pass, so text inserted for one placeholder (e.g. an agent
# description that literally contains "{prompt}") is never rewritten by the
# other.
_ROUTING_PLACEHOLDER_RE = re.compile(r"\{(agent_list|prompt)\}")


class RouterAgent:
    """LLM-driven router that delegates to the best-matching agent.

    Args:
        model: Model string for the routing LLM call. Ignored when *driver*
            is given.
        agents: List of agents to route between. Each is keyed by its
            ``name``, or ``agent_<index>`` when unnamed. A later agent with a
            duplicate name replaces the earlier one.
        routing_prompt: Custom prompt template (must include ``{agent_list}``
            and ``{prompt}`` placeholders).
        fallback: Agent to use when routing fails.
        driver: Pre-built driver instance for the routing call.
        name: Identity of this router when it is nested in a group.
        description: Description shown when this router is itself listed in
            another router.
        output_key: Shared-state key a group writes this router's output to.
    """

    def __init__(
        self,
        model: str = "",
        *,
        agents: list[Any],
        routing_prompt: str | None = None,
        fallback: Any | None = None,
        driver: Any | None = None,
        name: str = "",
        description: str = "",
        output_key: str | None = None,
    ) -> None:
        self._model = model
        self._driver = driver
        self._agents = {_agent_name(a, i): a for i, a in enumerate(agents)}
        self._routing_prompt = routing_prompt or _DEFAULT_ROUTING_PROMPT
        self._fallback = fallback
        self.name = name
        self.description = description
        self.output_key = output_key

    def run(self, prompt: str, *, deps: Any = None) -> AgentResult:
        """Route *prompt* to the best agent and return that agent's result.

        One routing call is made through a fresh ``Conversation``. Its reply
        is matched against the agent names by :meth:`_fuzzy_match`.

        - If an agent matches, that agent runs.
        - If nothing matches, the *fallback* agent runs.
        - Without a fallback, the raw routing reply is returned as an
          ``AgentResult``.

        ``deps`` is only passed through to the chosen agent when it is not
        ``None``.
        """
        from .conversation import Conversation

        routing_text = self._build_routing_prompt(prompt)

        # Single LLM call for routing
        kwargs: dict[str, Any] = {}
        if self._driver is not None:
            kwargs["driver"] = self._driver
        else:
            kwargs["model_name"] = self._model

        conv = Conversation(**kwargs)
        route_response = conv.ask(routing_text)

        selected = self._fuzzy_match(route_response.strip())

        if selected is not None:
            return selected.run(prompt, deps=deps) if deps is not None else selected.run(prompt)
        if self._fallback is not None:
            return self._fallback.run(prompt, deps=deps) if deps is not None else self._fallback.run(prompt)
        # No match and no fallback: surface the routing reply itself.
        return AgentResult(
            output=route_response,
            output_text=route_response,
            messages=conv.messages,
            usage=conv.usage,
            state=AgentState.idle,
        )

    def _build_routing_prompt(self, prompt: str) -> str:
        """Render the routing template with the agent roster and the user *prompt*."""
        agent_lines = []
        for name, agent in self._agents.items():
            desc = getattr(agent, "description", "") or ""
            agent_lines.append(f"- {name}: {desc}" if desc else f"- {name}")
        values = {"agent_list": "\n".join(agent_lines), "prompt": prompt}
        # A function replacement also keeps backslashes in the prompt literal.
        return _ROUTING_PLACEHOLDER_RE.sub(lambda m: values[m.group(1)], self._routing_prompt)

    def _fuzzy_match(self, response: str) -> Any | None:
        """Find the agent whose name best matches the LLM *response*.

        The stages are tried in order, each case-insensitive:

        1. Exact name match.
        2. A name contained in the response. The longest name wins, so
           ``writer`` cannot shadow ``writer_pro`` when the reply carries
           extra punctuation.
        3. Word overlap, with underscores in names treated as spaces. The
           agent sharing the most words wins; ties go to the agent registered
           first.

        Returns ``None`` when nothing matches.
        """
        response_lower = response.lower().strip()

        # Exact match
        for name, agent in self._agents.items():
            if name.lower() == response_lower:
                return agent

        # Substring match, most specific (longest) name first. sorted() is
        # stable, so equal-length names keep registration order.
        for name in sorted(self._agents, key=len, reverse=True):
            if name.lower() in response_lower:
                return self._agents[name]

        # Word-level match: pick the largest overlap.
        response_words = set(response_lower.split())
        best_agent: Any | None = None
        best_score = 0
        for name, agent in self._agents.items():
            name_words = set(name.lower().replace("_", " ").split())
            score = len(name_words & response_words)
            if score > best_score:
                best_agent, best_score = agent, score

        return best_agent
477
+
478
+
479
+ # ------------------------------------------------------------------
480
+ # GroupAsAgent
481
+ # ------------------------------------------------------------------
482
+
483
+
484
class GroupAsAgent:
    """Adapter that makes a group behave like an Agent for composability.

    Groups can then be nested inside other groups, or listed in a
    ``RouterAgent``, because the adapter presents the same
    ``run(prompt) -> AgentResult`` interface.

    Args:
        group: The group to wrap (SequentialGroup, LoopGroup, etc.).
        name: Agent identity name.
        output_key: Shared state key for writing output.
        description: Human-readable description. ``RouterAgent`` includes an
            agent's ``description`` in its routing prompt.
    """

    def __init__(
        self,
        group: Any,
        *,
        name: str = "",
        output_key: str | None = None,
        description: str = "",
    ) -> None:
        self._group = group
        self.name = name
        self.output_key = output_key
        self.description = description

    def run(self, prompt: str, **kwargs: Any) -> AgentResult:
        """Run the wrapped group and adapt its result to an ``AgentResult``.

        The output text is taken from the last entry in
        ``group_result.agent_results``: its ``output_text``, or its ``str()``
        if it has no such attribute. It is ``""`` when no agent produced a
        result.

        Extra keyword arguments (e.g. ``deps``) are accepted for interface
        compatibility but are not forwarded, because the wrapped group's
        ``run`` only takes a prompt. Errors recorded in the group result are
        not reflected in the returned ``AgentResult``.
        """
        group_result = self._group.run(prompt)

        # Use the most recently recorded agent result's text.
        output_text = ""
        if group_result.agent_results:
            last_result = list(group_result.agent_results.values())[-1]
            output_text = getattr(last_result, "output_text", str(last_result))

        return AgentResult(
            output=output_text,
            output_text=output_text,
            messages=[],
            usage=group_result.aggregate_usage,
            state=AgentState.idle,
            run_usage=group_result.aggregate_usage,
        )

    def stop(self) -> None:
        """Propagate stop to the wrapped group, if it supports ``stop()``."""
        if hasattr(self._group, "stop"):
            self._group.stop()
prompture/image.py ADDED
@@ -0,0 +1,180 @@
1
+ """Image handling utilities for vision-capable LLM drivers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import base64
6
+ import mimetypes
7
+ import re
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Union
11
+
12
+
13
@dataclass(frozen=True)
class ImageContent:
    """Normalized image representation for vision-capable drivers.

    Instances are immutable (``frozen=True``).

    Attributes:
        data: Base64-encoded image data when ``source_type`` is ``"base64"``.
            When ``source_type`` is ``"url"`` this is an empty string: the
            image is referenced by ``url``, not embedded.
        media_type: MIME type (e.g. ``"image/png"``, ``"image/jpeg"``).
        source_type: How the image is delivered — ``"base64"`` or ``"url"``.
        url: Original URL when ``source_type`` is ``"url"``, otherwise ``None``.
    """

    data: str
    media_type: str
    source_type: str = "base64"
    url: str | None = None


# Public type alias accepted by all image-aware APIs.
ImageInput = Union[bytes, str, Path, ImageContent]
32
+
33
# Known data-URI prefix pattern
_DATA_URI_RE = re.compile(r"^data:(image/[a-zA-Z0-9.+-]+);base64,(.+)$", re.DOTALL)

# Base64 detection heuristic — the string may only contain base64 alphabet
# characters (plus line breaks) followed by optional '=' padding.
_BASE64_RE = re.compile(r"^[A-Za-z0-9+/\n\r]+=*$")

_MIME_FROM_EXT: dict[str, str] = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
    ".gif": "image/gif",
    ".webp": "image/webp",
    ".bmp": "image/bmp",
    ".svg": "image/svg+xml",
    ".tiff": "image/tiff",
    ".tif": "image/tiff",
}

# (leading bytes, MIME type), checked in order.
_MAGIC_BYTES: list[tuple[bytes, str]] = [
    (b"\x89PNG", "image/png"),
    (b"\xff\xd8\xff", "image/jpeg"),
    (b"GIF87a", "image/gif"),
    (b"GIF89a", "image/gif"),
    # RIFF is a generic container (WAV, AVI, ...). It is only WebP when bytes
    # 8-12 are b"WEBP"; _guess_media_type_from_bytes checks that.
    (b"RIFF", "image/webp"),
    (b"BM", "image/bmp"),
    (b"II*\x00", "image/tiff"),  # little-endian TIFF
    (b"MM\x00*", "image/tiff"),  # big-endian TIFF
]


def _guess_media_type_from_bytes(data: bytes) -> str:
    """Guess the MIME type from the leading bytes of image data.

    Falls back to ``"image/png"`` when no known signature matches.
    """
    for magic, mime in _MAGIC_BYTES:
        if data[: len(magic)] != magic:
            continue
        if magic == b"RIFF" and data[8:12] != b"WEBP":
            continue  # RIFF container that is not a WebP image
        return mime
    return "image/png"  # safe fallback


def _guess_media_type(path: str) -> str:
    """Guess the MIME type of a file path or URL from its extension.

    Lookup order is the local extension table, then :mod:`mimetypes`, then
    ``"image/png"`` as the fallback.
    """
    clean = path
    if "://" in path:
        # Query strings and fragments only exist in URLs. Local file names may
        # legitimately contain '?' or '#' (e.g. "img#1.gif").
        clean = path.split("?", 1)[0].split("#", 1)[0]
    ext = Path(clean).suffix.lower()
    if ext in _MIME_FROM_EXT:
        return _MIME_FROM_EXT[ext]
    guessed = mimetypes.guess_type(clean)[0]
    return guessed or "image/png"
78
+
79
+
80
+ # ------------------------------------------------------------------
81
+ # Constructor functions
82
+ # ------------------------------------------------------------------
83
+
84
+
85
def image_from_bytes(data: bytes, media_type: str | None = None) -> ImageContent:
    """Wrap raw image bytes in an :class:`ImageContent`.

    Args:
        data: Raw (undecoded) image bytes.
        media_type: Explicit MIME type. When falsy, it is sniffed from the
            leading magic bytes of *data*.

    Raises:
        ValueError: If *data* is empty.
    """
    if not data:
        raise ValueError("Image data cannot be empty")
    encoded = base64.b64encode(data).decode("ascii")
    if not media_type:
        media_type = _guess_media_type_from_bytes(data)
    return ImageContent(data=encoded, media_type=media_type)
97
+
98
+
99
def image_from_base64(b64: str, media_type: str = "image/png") -> ImageContent:
    """Create an :class:`ImageContent` from a base64-encoded string.

    Accepts both raw base64 and ``data:image/<type>;base64,<payload>`` URIs.
    For a data URI, the MIME type embedded in the URI wins over *media_type*.

    Raises:
        ValueError: If *b64* is empty, or is a ``data:`` URI that is not a
            base64-encoded image. Storing such a URI verbatim would send its
            ``data:`` prefix as if it were image payload.
    """
    if not b64:
        raise ValueError("Base64 image data cannot be empty")
    m = _DATA_URI_RE.match(b64)
    if m:
        return ImageContent(data=m.group(2), media_type=m.group(1))
    if b64.startswith("data:"):
        raise ValueError("Unsupported data URI: expected 'data:image/<type>;base64,<data>'")
    return ImageContent(data=b64, media_type=media_type)
108
+
109
+
110
def image_from_file(path: str | Path, media_type: str | None = None) -> ImageContent:
    """Load a local image file into an :class:`ImageContent`.

    Args:
        path: Location of the image file (``str`` or :class:`~pathlib.Path`).
        media_type: Explicit MIME type. When falsy, it is inferred from the
            file extension.

    Raises:
        FileNotFoundError: If nothing exists at *path*.
    """
    file_path = Path(path)
    if not file_path.exists():
        raise FileNotFoundError(f"Image file not found: {file_path}")
    # Arguments evaluate left to right: the file is read before the MIME guess.
    return image_from_bytes(
        file_path.read_bytes(),
        media_type or _guess_media_type(str(file_path)),
    )
123
+
124
+
125
def image_from_url(url: str, media_type: str | None = None) -> ImageContent:
    """Create an :class:`ImageContent` that references a remote URL.

    The image is **not** downloaded and no base64 payload is produced. The
    returned instance has ``source_type="url"``, ``url`` set to *url*, and an
    empty ``data`` string. A consumer that needs inline base64 must obtain
    the image bytes separately; this function does not fetch them.

    Args:
        url: Publicly-accessible image URL.
        media_type: MIME type. Guessed from the URL's extension when falsy.
    """
    mt = media_type or _guess_media_type(url)
    return ImageContent(data="", media_type=mt, source_type="url", url=url)
138
+
139
+
140
+ # ------------------------------------------------------------------
141
+ # Smart constructor
142
+ # ------------------------------------------------------------------
143
+
144
+
145
def _path_exists(candidate: str) -> bool:
    """Return whether *candidate* names an existing filesystem entry.

    Strings the OS cannot treat as a path count as "does not exist". This
    covers names that are too long (typical of a large base64 payload) and
    names with an embedded NUL. On some Python versions and platforms,
    ``Path.exists()`` raises ``OSError``/``ValueError`` for these instead of
    returning ``False``.
    """
    try:
        return Path(candidate).exists()
    except (OSError, ValueError):
        return False


def make_image(source: ImageInput) -> ImageContent:
    """Auto-detect the source type and return an :class:`ImageContent`.

    Accepts:
    - ``ImageContent`` — returned as-is.
    - ``bytes`` / ``bytearray`` / ``memoryview`` — base64-encoded with
      auto-detected MIME.
    - ``str`` — tries, in order: data URI, URL, existing file path, raw base64.
    - ``pathlib.Path`` — read from disk.

    Raises:
        ValueError: If a string is empty, or is none of a data URI, an
            http(s) URL, an existing path, or base64-looking text. A
            misspelled file name such as ``"cat.png"`` is an example.
        TypeError: For any other input type.
    """
    if isinstance(source, ImageContent):
        return source

    if isinstance(source, (bytes, bytearray, memoryview)):
        return image_from_bytes(bytes(source))

    if isinstance(source, Path):
        return image_from_file(source)

    if isinstance(source, str):
        if not source:
            raise ValueError("Image source string cannot be empty")

        # 1. data URI
        if source.startswith("data:"):
            return image_from_base64(source)

        # 2. URL
        if source.startswith(("http://", "https://")):
            return image_from_url(source)

        # 3. File path (if it exists on disk). The check is guarded because a
        #    long base64 string can make Path.exists() raise.
        if _path_exists(source):
            return image_from_file(source)

        # 4. Raw base64. Reject text that cannot be base64 instead of
        #    silently shipping it to the model as image data.
        if not _BASE64_RE.match(source):
            raise ValueError(
                "Cannot interpret string as an image: not a data URI, http(s) URL, "
                f"existing file, or base64 data: {source[:60]!r}"
            )
        return image_from_base64(source)

    raise TypeError(f"Unsupported image source type: {type(source).__name__}")