prompture 0.0.29.dev8-py3-none-any.whl → 0.0.38.dev2-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- prompture/__init__.py +264 -23
- prompture/_version.py +34 -0
- prompture/agent.py +924 -0
- prompture/agent_types.py +156 -0
- prompture/aio/__init__.py +74 -0
- prompture/async_agent.py +880 -0
- prompture/async_conversation.py +789 -0
- prompture/async_core.py +803 -0
- prompture/async_driver.py +193 -0
- prompture/async_groups.py +551 -0
- prompture/cache.py +469 -0
- prompture/callbacks.py +55 -0
- prompture/cli.py +63 -4
- prompture/conversation.py +826 -0
- prompture/core.py +894 -263
- prompture/cost_mixin.py +51 -0
- prompture/discovery.py +187 -0
- prompture/driver.py +206 -5
- prompture/drivers/__init__.py +175 -67
- prompture/drivers/airllm_driver.py +109 -0
- prompture/drivers/async_airllm_driver.py +26 -0
- prompture/drivers/async_azure_driver.py +123 -0
- prompture/drivers/async_claude_driver.py +113 -0
- prompture/drivers/async_google_driver.py +316 -0
- prompture/drivers/async_grok_driver.py +97 -0
- prompture/drivers/async_groq_driver.py +90 -0
- prompture/drivers/async_hugging_driver.py +61 -0
- prompture/drivers/async_lmstudio_driver.py +148 -0
- prompture/drivers/async_local_http_driver.py +44 -0
- prompture/drivers/async_ollama_driver.py +135 -0
- prompture/drivers/async_openai_driver.py +102 -0
- prompture/drivers/async_openrouter_driver.py +102 -0
- prompture/drivers/async_registry.py +133 -0
- prompture/drivers/azure_driver.py +42 -9
- prompture/drivers/claude_driver.py +257 -34
- prompture/drivers/google_driver.py +295 -42
- prompture/drivers/grok_driver.py +35 -32
- prompture/drivers/groq_driver.py +33 -26
- prompture/drivers/hugging_driver.py +6 -6
- prompture/drivers/lmstudio_driver.py +97 -19
- prompture/drivers/local_http_driver.py +6 -6
- prompture/drivers/ollama_driver.py +168 -23
- prompture/drivers/openai_driver.py +184 -9
- prompture/drivers/openrouter_driver.py +37 -25
- prompture/drivers/registry.py +306 -0
- prompture/drivers/vision_helpers.py +153 -0
- prompture/field_definitions.py +106 -96
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/logging.py +80 -0
- prompture/model_rates.py +217 -0
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/runner.py +49 -47
- prompture/scaffold/__init__.py +1 -0
- prompture/scaffold/generator.py +84 -0
- prompture/scaffold/templates/Dockerfile.j2 +12 -0
- prompture/scaffold/templates/README.md.j2 +41 -0
- prompture/scaffold/templates/config.py.j2 +21 -0
- prompture/scaffold/templates/env.example.j2 +8 -0
- prompture/scaffold/templates/main.py.j2 +86 -0
- prompture/scaffold/templates/models.py.j2 +40 -0
- prompture/scaffold/templates/requirements.txt.j2 +5 -0
- prompture/serialization.py +218 -0
- prompture/server.py +183 -0
- prompture/session.py +117 -0
- prompture/settings.py +19 -1
- prompture/tools.py +219 -267
- prompture/tools_schema.py +254 -0
- prompture/validator.py +3 -3
- prompture-0.0.38.dev2.dist-info/METADATA +369 -0
- prompture-0.0.38.dev2.dist-info/RECORD +77 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +1 -1
- prompture-0.0.29.dev8.dist-info/METADATA +0 -368
- prompture-0.0.29.dev8.dist-info/RECORD +0 -27
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
prompture/cost_mixin.py
ADDED
@@ -0,0 +1,51 @@
"""Shared cost-calculation mixin for LLM drivers."""

from __future__ import annotations

from typing import Any


class CostMixin:
    """Mixin that provides ``_calculate_cost`` to sync and async drivers.

    Drivers that charge per-token should inherit from this mixin alongside
    their base class (``Driver`` or ``AsyncDriver``). Free/local drivers
    (Ollama, LM Studio, LocalHTTP, HuggingFace, AirLLM) can skip it.
    """

    # Subclasses must define MODEL_PRICING as a class attribute.
    MODEL_PRICING: dict[str, dict[str, Any]] = {}

    # Divisor for hardcoded MODEL_PRICING rates.
    # Most drivers use per-1K-token pricing (1_000).
    # Grok uses per-1M-token pricing (1_000_000).
    # Google uses per-1M-character pricing (1_000_000).
    _PRICING_UNIT: int = 1_000

    def _calculate_cost(
        self,
        provider: str,
        model: str,
        prompt_tokens: int | float,
        completion_tokens: int | float,
    ) -> float:
        """Calculate USD cost for a generation call.

        Resolution order:
        1. Live rates from ``model_rates.get_model_rates()`` (per 1M tokens).
        2. Hardcoded ``MODEL_PRICING`` on the driver class (unit set by ``_PRICING_UNIT``).
        3. Zero if neither source has data.
        """
        from .model_rates import get_model_rates

        live_rates = get_model_rates(provider, model)
        if live_rates:
            prompt_cost = (prompt_tokens / 1_000_000) * live_rates["input"]
            completion_cost = (completion_tokens / 1_000_000) * live_rates["output"]
        else:
            unit = self._PRICING_UNIT
            model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
            prompt_cost = (prompt_tokens / unit) * model_pricing["prompt"]
            completion_cost = (completion_tokens / unit) * model_pricing["completion"]

        return round(prompt_cost + completion_cost, 6)
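As a quick orientation (not part of the diff), here is a minimal sketch of how a per-token driver might combine this mixin with the Driver base class. ExampleDriver, its provider name, pricing figures, and token counts are hypothetical; only CostMixin, Driver, MODEL_PRICING, _PRICING_UNIT, and _calculate_cost come from the package, and the sketch assumes the live-rate lookup has no entry for the made-up provider, so the hardcoded fallback applies.

# Hypothetical sketch: a per-token driver built on CostMixin + Driver.
from typing import Any

from prompture.cost_mixin import CostMixin
from prompture.driver import Driver


class ExampleDriver(CostMixin, Driver):
    # Hypothetical per-1K-token rates (the default _PRICING_UNIT of 1_000 applies).
    MODEL_PRICING = {"example-model": {"prompt": 0.0005, "completion": 0.0015}}

    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
        # A real driver would call its provider API here; this stub fakes a reply.
        usage = {"prompt_tokens": 12, "completion_tokens": 20}
        # Assuming model_rates has no live entry for this made-up provider,
        # _calculate_cost falls back to MODEL_PRICING above:
        # (12 / 1000) * 0.0005 + (20 / 1000) * 0.0015 = 0.000036
        cost = self._calculate_cost(
            "example", "example-model", usage["prompt_tokens"], usage["completion_tokens"]
        )
        return {"text": "stub reply", "meta": {**usage, "cost": cost}}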
prompture/discovery.py
ADDED
@@ -0,0 +1,187 @@
"""Discovery module for auto-detecting available models."""

import logging
import os

import requests

from .drivers import (
    AzureDriver,
    ClaudeDriver,
    GoogleDriver,
    GrokDriver,
    GroqDriver,
    LMStudioDriver,
    LocalHTTPDriver,
    OllamaDriver,
    OpenAIDriver,
    OpenRouterDriver,
)
from .settings import settings

logger = logging.getLogger(__name__)


def get_available_models() -> list[str]:
    """
    Auto-detects all available models based on configured drivers and environment variables.

    Iterates through supported providers and checks if they are configured (e.g. API key present).
    For static drivers, returns models from their MODEL_PRICING keys.
    For dynamic drivers (like Ollama), attempts to fetch available models from the endpoint.

    Returns:
        A list of unique model strings in the format "provider/model_id".
    """
    available_models: set[str] = set()
    configured_providers: set[str] = set()

    # Map of provider name to driver class
    # We need to map the registry keys to the actual classes to check MODEL_PRICING
    # and instantiate for dynamic checks if needed.
    provider_classes = {
        "openai": OpenAIDriver,
        "azure": AzureDriver,
        "claude": ClaudeDriver,
        "google": GoogleDriver,
        "groq": GroqDriver,
        "openrouter": OpenRouterDriver,
        "grok": GrokDriver,
        "ollama": OllamaDriver,
        "lmstudio": LMStudioDriver,
        "local_http": LocalHTTPDriver,
    }

    for provider, driver_cls in provider_classes.items():
        try:
            # 1. Check if the provider is configured (has API key or endpoint)
            # We can check this by looking at the settings or env vars that the driver uses.
            # A simple way is to try to instantiate it with defaults, but that might fail if keys are missing.
            # Instead, let's check the specific requirements for each known provider.

            is_configured = False

            if provider == "openai":
                if settings.openai_api_key or os.getenv("OPENAI_API_KEY"):
                    is_configured = True
            elif provider == "azure":
                if (
                    (settings.azure_api_key or os.getenv("AZURE_API_KEY"))
                    and (settings.azure_api_endpoint or os.getenv("AZURE_API_ENDPOINT"))
                    and (settings.azure_deployment_id or os.getenv("AZURE_DEPLOYMENT_ID"))
                ):
                    is_configured = True
            elif provider == "claude":
                if settings.claude_api_key or os.getenv("CLAUDE_API_KEY"):
                    is_configured = True
            elif provider == "google":
                if settings.google_api_key or os.getenv("GOOGLE_API_KEY"):
                    is_configured = True
            elif provider == "groq":
                if settings.groq_api_key or os.getenv("GROQ_API_KEY"):
                    is_configured = True
            elif provider == "openrouter":
                if settings.openrouter_api_key or os.getenv("OPENROUTER_API_KEY"):
                    is_configured = True
            elif provider == "grok":
                if settings.grok_api_key or os.getenv("GROK_API_KEY"):
                    is_configured = True
            elif provider == "ollama":
                # Ollama is always considered "configured" as it defaults to localhost
                # We will check connectivity later
                is_configured = True
            elif provider == "lmstudio":
                # LM Studio is similar to Ollama, defaults to localhost
                is_configured = True
            elif provider == "local_http" and (settings.local_http_endpoint or os.getenv("LOCAL_HTTP_ENDPOINT")):
                is_configured = True

            if not is_configured:
                continue

            configured_providers.add(provider)

            # 2. Static Detection: Get models from MODEL_PRICING
            if hasattr(driver_cls, "MODEL_PRICING"):
                pricing = driver_cls.MODEL_PRICING
                for model_id in pricing:
                    # Skip "default" or generic keys if they exist
                    if model_id == "default":
                        continue

                    # For Azure, the model_id in pricing is usually the base model name,
                    # but the user needs to use the deployment ID.
                    # However, our Azure driver implementation uses the deployment_id from init
                    # as the "model" for the request, but expects the user to pass a model name
                    # that maps to pricing?
                    # Looking at AzureDriver:
                    # kwargs = {"model": self.deployment_id, ...}
                    # model = options.get("model", self.model) -> used for pricing lookup
                    # So we should list the keys in MODEL_PRICING as available "models"
                    # even though for Azure specifically it's a bit weird because of deployment IDs.
                    # But for general discovery, listing supported models is correct.

                    available_models.add(f"{provider}/{model_id}")

            # 3. Dynamic Detection: Specific logic for Ollama
            if provider == "ollama":
                try:
                    endpoint = settings.ollama_endpoint or os.getenv(
                        "OLLAMA_ENDPOINT", "http://localhost:11434/api/generate"
                    )
                    # We need the base URL for tags, usually http://localhost:11434/api/tags
                    # The configured endpoint might be .../api/generate or .../api/chat
                    base_url = endpoint.split("/api/")[0]
                    tags_url = f"{base_url}/api/tags"

                    resp = requests.get(tags_url, timeout=2)
                    if resp.status_code == 200:
                        data = resp.json()
                        models = data.get("models", [])
                        for model in models:
                            name = model.get("name")
                            if name:
                                # Ollama model names often include tags like "llama3:latest"
                                # We can keep them as is.
                                available_models.add(f"ollama/{name}")
                except Exception as e:
                    logger.debug(f"Failed to fetch Ollama models: {e}")

            # Dynamic Detection: LM Studio loaded models
            if provider == "lmstudio":
                try:
                    endpoint = settings.lmstudio_endpoint or os.getenv(
                        "LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions"
                    )
                    base_url = endpoint.split("/v1/")[0]
                    models_url = f"{base_url}/v1/models"

                    headers: dict[str, str] = {}
                    api_key = settings.lmstudio_api_key or os.getenv("LMSTUDIO_API_KEY")
                    if api_key:
                        headers["Authorization"] = f"Bearer {api_key}"

                    resp = requests.get(models_url, headers=headers, timeout=2)
                    if resp.status_code == 200:
                        data = resp.json()
                        models = data.get("data", [])
                        for model in models:
                            model_id = model.get("id")
                            if model_id:
                                available_models.add(f"lmstudio/{model_id}")
                except Exception as e:
                    logger.debug(f"Failed to fetch LM Studio models: {e}")

        except Exception as e:
            logger.warning(f"Error detecting models for provider {provider}: {e}")
            continue

    # Enrich with live model list from models.dev cache
    from .model_rates import PROVIDER_MAP, get_all_provider_models

    for prompture_name, api_name in PROVIDER_MAP.items():
        if prompture_name in configured_providers:
            for model_id in get_all_provider_models(api_name):
                available_models.add(f"{prompture_name}/{model_id}")

    return sorted(list(available_models))
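A usage sketch (not part of the diff): the function takes no arguments and returns provider-prefixed model ids, so calling it is a one-liner. The example output is illustrative only and depends entirely on which API keys and local servers the environment has configured.

# Hypothetical sketch: list whatever models discovery can see in this environment.
from prompture.discovery import get_available_models

for model in get_available_models():
    print(model)  # e.g. "ollama/llama3:latest" or "openai/gpt-4o" (illustrative only)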
prompture/driver.py
CHANGED
@@ -1,7 +1,16 @@
-"""Driver base class for LLM adapters.
-"""
+"""Driver base class for LLM adapters."""
+
 from __future__ import annotations
-
+
+import logging
+import time
+from collections.abc import Iterator
+from typing import Any
+
+from .callbacks import DriverCallbacks
+
+logger = logging.getLogger("prompture.driver")
+
 
 class Driver:
     """Adapter base. Implementar generate(prompt, options) -> {"text": ... , "meta": {...}}
@@ -20,5 +29,197 @@ class Driver:
     additional provider-specific metadata while the core fields provide
     standardized access to token usage and cost information.
     """
-
-
+
+    supports_json_mode: bool = False
+    supports_json_schema: bool = False
+    supports_messages: bool = False
+    supports_tool_use: bool = False
+    supports_streaming: bool = False
+    supports_vision: bool = False
+
+    callbacks: DriverCallbacks | None = None
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        raise NotImplementedError
+
+    def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
+        """Generate a response from a list of conversation messages.
+
+        Each message is a dict with ``"role"`` (``"system"``, ``"user"``, or
+        ``"assistant"``) and ``"content"`` keys.
+
+        The default implementation flattens the messages into a single prompt
+        string and delegates to :meth:`generate`. Drivers that natively
+        support message arrays should override this method and set
+        ``supports_messages = True``.
+        """
+        self._check_vision_support(messages)
+        prompt = self._flatten_messages(messages)
+        return self.generate(prompt, options)
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls.
+
+        Returns a dict with keys: ``text``, ``meta``, ``tool_calls``, ``stop_reason``.
+        ``tool_calls`` is a list of ``{"id": str, "name": str, "arguments": dict}``.
+
+        Drivers that support tool use should override this method and set
+        ``supports_tool_use = True``.
+        """
+        raise NotImplementedError(f"{self.__class__.__name__} does not support tool use")
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> Iterator[dict[str, Any]]:
+        """Yield response chunks incrementally.
+
+        Each chunk is a dict:
+        - ``{"type": "delta", "text": str}`` for content fragments
+        - ``{"type": "done", "text": str, "meta": dict}`` for the final summary
+
+        Drivers that support streaming should override this method and set
+        ``supports_streaming = True``.
+        """
+        raise NotImplementedError(f"{self.__class__.__name__} does not support streaming")
+
+    # ------------------------------------------------------------------
+    # Hook-aware wrappers
+    # ------------------------------------------------------------------
+
+    def generate_with_hooks(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        """Wrap :meth:`generate` with on_request / on_response / on_error callbacks."""
+        driver_name = getattr(self, "model", self.__class__.__name__)
+        self._fire_callback(
+            "on_request",
+            {"prompt": prompt, "messages": None, "options": options, "driver": driver_name},
+        )
+        t0 = time.perf_counter()
+        try:
+            resp = self.generate(prompt, options)
+        except Exception as exc:
+            self._fire_callback(
+                "on_error",
+                {"error": exc, "prompt": prompt, "messages": None, "options": options, "driver": driver_name},
+            )
+            raise
+        elapsed_ms = (time.perf_counter() - t0) * 1000
+        self._fire_callback(
+            "on_response",
+            {
+                "text": resp.get("text", ""),
+                "meta": resp.get("meta", {}),
+                "driver": driver_name,
+                "elapsed_ms": elapsed_ms,
+            },
+        )
+        return resp
+
+    def generate_messages_with_hooks(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
+        """Wrap :meth:`generate_messages` with callbacks."""
+        driver_name = getattr(self, "model", self.__class__.__name__)
+        self._fire_callback(
+            "on_request",
+            {"prompt": None, "messages": messages, "options": options, "driver": driver_name},
+        )
+        t0 = time.perf_counter()
+        try:
+            resp = self.generate_messages(messages, options)
+        except Exception as exc:
+            self._fire_callback(
+                "on_error",
+                {"error": exc, "prompt": None, "messages": messages, "options": options, "driver": driver_name},
+            )
+            raise
+        elapsed_ms = (time.perf_counter() - t0) * 1000
+        self._fire_callback(
+            "on_response",
+            {
+                "text": resp.get("text", ""),
+                "meta": resp.get("meta", {}),
+                "driver": driver_name,
+                "elapsed_ms": elapsed_ms,
+            },
+        )
+        return resp
+
+    # ------------------------------------------------------------------
+    # Internal helpers
+    # ------------------------------------------------------------------
+
+    def _fire_callback(self, event: str, payload: dict[str, Any]) -> None:
+        """Invoke a single callback, swallowing and logging any exception."""
+        if self.callbacks is None:
+            return
+        cb = getattr(self.callbacks, event, None)
+        if cb is None:
+            return
+        try:
+            cb(payload)
+        except Exception:
+            logger.exception("Callback %s raised an exception", event)
+
+    def _check_vision_support(self, messages: list[dict[str, Any]]) -> None:
+        """Raise if messages contain image blocks and the driver lacks vision support."""
+        if self.supports_vision:
+            return
+        for msg in messages:
+            content = msg.get("content")
+            if isinstance(content, list):
+                for block in content:
+                    if isinstance(block, dict) and block.get("type") == "image":
+                        raise NotImplementedError(
+                            f"{self.__class__.__name__} does not support vision/image inputs. "
+                            "Use a vision-capable model."
+                        )
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        """Transform universal message format into provider-specific wire format.
+
+        Vision-capable drivers override this to convert the universal image
+        blocks into their provider-specific format. The base implementation
+        validates vision support and returns messages unchanged.
+        """
+        self._check_vision_support(messages)
+        return messages
+
+    @staticmethod
+    def _flatten_messages(messages: list[dict[str, Any]]) -> str:
+        """Join messages into a single prompt string with role prefixes."""
+        parts: list[str] = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+            # Handle content that is a list of blocks (vision messages)
+            if isinstance(content, list):
+                text_parts = []
+                for block in content:
+                    if isinstance(block, dict):
+                        if block.get("type") == "text":
+                            text_parts.append(block.get("text", ""))
+                        elif block.get("type") == "image":
+                            text_parts.append("[image]")
+                    elif isinstance(block, str):
+                        text_parts.append(block)
+                content = " ".join(text_parts)
+            if role == "system":
+                parts.append(f"[System]: {content}")
+            elif role == "assistant":
+                parts.append(f"[Assistant]: {content}")
+            else:
+                parts.append(f"[User]: {content}")
+        return "\n\n".join(parts)
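A minimal sketch (not part of the diff) of the new default generate_messages() path: a subclass only needs to implement generate(), and the base class checks vision support, flattens the message list with role prefixes, and delegates. EchoDriver and the sample messages are hypothetical.

# Hypothetical sketch: rely on the base-class message flattening added in this release.
from typing import Any

from prompture.driver import Driver


class EchoDriver(Driver):
    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
        # Echo the flattened prompt back so the transformation is visible.
        return {"text": prompt, "meta": {}}


driver = EchoDriver()
resp = driver.generate_messages(
    [
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Hello"},
    ],
    {},
)
print(resp["text"])
# Expected output, per _flatten_messages:
# [System]: You are terse.
#
# [User]: Hello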