solana-agent 20.1.2__py3-none-any.whl → 31.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- solana_agent/__init__.py +10 -5
- solana_agent/adapters/ffmpeg_transcoder.py +375 -0
- solana_agent/adapters/mongodb_adapter.py +15 -2
- solana_agent/adapters/openai_adapter.py +679 -0
- solana_agent/adapters/openai_realtime_ws.py +1813 -0
- solana_agent/adapters/pinecone_adapter.py +543 -0
- solana_agent/cli.py +128 -0
- solana_agent/client/solana_agent.py +180 -20
- solana_agent/domains/agent.py +13 -13
- solana_agent/domains/routing.py +18 -8
- solana_agent/factories/agent_factory.py +239 -38
- solana_agent/guardrails/pii.py +107 -0
- solana_agent/interfaces/client/client.py +95 -12
- solana_agent/interfaces/guardrails/guardrails.py +26 -0
- solana_agent/interfaces/plugins/plugins.py +2 -1
- solana_agent/interfaces/providers/__init__.py +0 -0
- solana_agent/interfaces/providers/audio.py +40 -0
- solana_agent/interfaces/providers/data_storage.py +9 -2
- solana_agent/interfaces/providers/llm.py +86 -9
- solana_agent/interfaces/providers/memory.py +13 -1
- solana_agent/interfaces/providers/realtime.py +212 -0
- solana_agent/interfaces/providers/vector_storage.py +53 -0
- solana_agent/interfaces/services/agent.py +27 -12
- solana_agent/interfaces/services/knowledge_base.py +59 -0
- solana_agent/interfaces/services/query.py +41 -8
- solana_agent/interfaces/services/routing.py +0 -1
- solana_agent/plugins/manager.py +37 -16
- solana_agent/plugins/registry.py +34 -19
- solana_agent/plugins/tools/__init__.py +0 -5
- solana_agent/plugins/tools/auto_tool.py +1 -0
- solana_agent/repositories/memory.py +332 -111
- solana_agent/services/__init__.py +1 -1
- solana_agent/services/agent.py +390 -241
- solana_agent/services/knowledge_base.py +768 -0
- solana_agent/services/query.py +1858 -153
- solana_agent/services/realtime.py +626 -0
- solana_agent/services/routing.py +104 -51
- solana_agent-31.4.0.dist-info/METADATA +1070 -0
- solana_agent-31.4.0.dist-info/RECORD +49 -0
- {solana_agent-20.1.2.dist-info → solana_agent-31.4.0.dist-info}/WHEEL +1 -1
- solana_agent-31.4.0.dist-info/entry_points.txt +3 -0
- solana_agent/adapters/llm_adapter.py +0 -160
- solana_agent-20.1.2.dist-info/METADATA +0 -464
- solana_agent-20.1.2.dist-info/RECORD +0 -35
- {solana_agent-20.1.2.dist-info → solana_agent-31.4.0.dist-info/licenses}/LICENSE +0 -0
solana_agent/adapters/openai_adapter.py
@@ -0,0 +1,679 @@
+"""
+LLM provider adapters for the Solana Agent system.
+
+These adapters implement the LLMProvider interface for different LLM services.
+"""
+
+import logging
+import base64
+import io
+import math
+from typing import (
+    AsyncGenerator,
+    List,
+    Literal,
+    Optional,
+    Type,
+    TypeVar,
+    Dict,
+    Any,
+    Union,
+)
+from PIL import Image
+from openai import AsyncOpenAI, OpenAIError
+from pydantic import BaseModel
+import instructor
+from instructor import Mode
+import logfire
+
+from solana_agent.interfaces.providers.llm import LLMProvider
+
+# Setup logger for this module
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T", bound=BaseModel)
+
+DEFAULT_CHAT_MODEL = "gpt-4.1"
+DEFAULT_VISION_MODEL = "gpt-4.1"
+DEFAULT_PARSE_MODEL = "gpt-4.1"
+DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large"
+DEFAULT_EMBEDDING_DIMENSIONS = 3072
+DEFAULT_TRANSCRIPTION_MODEL = "gpt-4o-mini-transcribe"
+DEFAULT_TTS_MODEL = "tts-1"
+
+# Image constants
+SUPPORTED_IMAGE_FORMATS = {"PNG", "JPEG", "WEBP", "GIF"}
+MAX_IMAGE_SIZE_MB = 20
+MAX_TOTAL_IMAGE_SIZE_MB = 50
+MAX_IMAGE_COUNT = 500
+GPT41_PATCH_SIZE = 32
+GPT41_MAX_PATCHES = 1536
+GPT41_MINI_MULTIPLIER = 1.62
+GPT41_NANO_MULTIPLIER = 2.46
+
+
+class OpenAIAdapter(LLMProvider):
+    """OpenAI implementation of LLMProvider with web search capabilities."""
+
+    def __init__(
+        self,
+        api_key: str,
+        base_url: Optional[str] = None,
+        model: Optional[str] = None,
+        logfire_api_key: Optional[str] = None,
+    ):
+        self.api_key = api_key
+        self.base_url = base_url
+
+        # Create client with base_url if provided (for Grok support)
+        if base_url:
+            self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+        else:
+            self.client = AsyncOpenAI(api_key=api_key)
+
+        self.logfire = False
+        if logfire_api_key:
+            try:
+                logfire.configure(token=logfire_api_key)
+                self.logfire = True
+                # Instrument the main client immediately after configuring logfire
+                logfire.instrument_openai(self.client)
+                logger.info(
+                    "Logfire configured and OpenAI client instrumented successfully."
+                )
+            except Exception as e:
+                logger.error(f"Failed to configure Logfire: {e}")
+                self.logfire = False
+
+        # Use provided model or defaults (for Grok or OpenAI)
+        if model:
+            # Custom model provided (e.g., from Grok config)
+            self.parse_model = model
+            self.text_model = model
+            self.vision_model = model
+        else:
+            # Use OpenAI defaults
+            self.parse_model = DEFAULT_PARSE_MODEL
+            self.text_model = DEFAULT_CHAT_MODEL
+            self.vision_model = DEFAULT_VISION_MODEL
+
+        # These remain OpenAI-specific
+        self.transcription_model = DEFAULT_TRANSCRIPTION_MODEL
+        self.tts_model = DEFAULT_TTS_MODEL
+        self.embedding_model = DEFAULT_EMBEDDING_MODEL
+        self.embedding_dimensions = DEFAULT_EMBEDDING_DIMENSIONS
+
+    def get_api_key(self) -> Optional[str]:  # pragma: no cover
+        """Return the API key used to configure the OpenAI client."""
+        return getattr(self, "api_key", None)
+
+    async def tts(
+        self,
+        text: str,
+        instructions: str = "You speak in a friendly and helpful manner.",
+        voice: Literal[
+            "alloy",
+            "ash",
+            "ballad",
+            "coral",
+            "echo",
+            "fable",
+            "onyx",
+            "nova",
+            "sage",
+            "shimmer",
+        ] = "nova",
+        response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = "aac",
+    ) -> AsyncGenerator[bytes, None]:  # pragma: no cover
+        """Stream text-to-speech audio from OpenAI models.
+
+        Args:
+            text: Text to convert to speech
+            instructions: Not used in this implementation
+            voice: Voice to use for synthesis
+            response_format: Audio format
+
+        Yields:
+            Audio bytes as they become available
+        """
+        try:
+            if self.logfire:  # Instrument only if logfire is enabled
+                logfire.instrument_openai(self.client)
+            async with self.client.audio.speech.with_streaming_response.create(
+                model=self.tts_model,
+                voice=voice,
+                input=text,
+                response_format=response_format,
+            ) as stream:
+                # Stream the bytes in 16KB chunks
+                async for chunk in stream.iter_bytes(chunk_size=1024 * 16):
+                    yield chunk
+
+        except Exception as e:
+            # Log the exception with traceback
+            logger.exception(f"Error in text_to_speech: {e}")
+            yield b""  # Return empty bytes on error
+
+    async def transcribe_audio(
+        self,
+        audio_bytes: bytes,
+        input_format: Literal[
+            "flac", "mp3", "mp4", "mpeg", "mpga", "m4a", "ogg", "wav", "webm"
+        ] = "mp4",
+    ) -> AsyncGenerator[str, None]:  # pragma: no cover
+        """Stream transcription of an audio file.
+
+        Args:
+            audio_bytes: Audio file bytes
+            input_format: Format of the input audio file
+
+        Yields:
+            Transcript text chunks as they become available
+        """
+        try:
+            if self.logfire:  # Instrument only if logfire is enabled
+                logfire.instrument_openai(self.client)
+            async with self.client.audio.transcriptions.with_streaming_response.create(
+                model=self.transcription_model,
+                file=(f"file.{input_format}", audio_bytes),
+                response_format="text",
+            ) as stream:
+                # Stream the text in 16KB chunks
+                async for chunk in stream.iter_text(chunk_size=1024 * 16):
+                    yield chunk
+
+        except Exception as e:
+            # Log the exception with traceback
+            logger.exception(f"Error in transcribe_audio: {e}")
+            yield f"I apologize, but I encountered an error transcribing the audio: {str(e)}"
+
+    async def generate_text(
+        self,
+        prompt: str,
+        system_prompt: str = "",
+        api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
+        model: Optional[str] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+    ) -> str:  # pragma: no cover
+        """Generate text or function call from OpenAI models."""
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": prompt})
+
+        request_params = {
+            "messages": messages,
+            "model": model or self.text_model,
+        }
+        if tools:
+            request_params["tools"] = tools
+
+        if api_key and base_url:
+            client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+        else:
+            client = self.client
+
+        if self.logfire:
+            logfire.instrument_openai(client)
+
+        try:
+            response = await client.chat.completions.create(**request_params)
+            return response
+        except OpenAIError as e:
+            logger.error(f"OpenAI API error during text generation: {e}")
+            return None
+        except Exception as e:
+            logger.exception(f"Error in generate_text: {e}")
+            return None
+
+    def _calculate_gpt41_image_cost(self, width: int, height: int, model: str) -> int:
+        """Calculates the token cost for an image with GPT-4.1 models."""
+        patches_wide = math.ceil(width / GPT41_PATCH_SIZE)
+        patches_high = math.ceil(height / GPT41_PATCH_SIZE)
+        total_patches_needed = patches_wide * patches_high
+
+        if total_patches_needed > GPT41_MAX_PATCHES:
+            scale_factor = math.sqrt(GPT41_MAX_PATCHES / total_patches_needed)
+            new_width = math.floor(width * scale_factor)
+            new_height = math.floor(height * scale_factor)
+
+            final_patches_wide_scaled = math.ceil(new_width / GPT41_PATCH_SIZE)
+            final_patches_high_scaled = math.ceil(new_height / GPT41_PATCH_SIZE)
+            image_tokens = final_patches_wide_scaled * final_patches_high_scaled
+
+            # Ensure it doesn't exceed the cap due to ceiling operations after scaling
+            image_tokens = min(image_tokens, GPT41_MAX_PATCHES)
+
+            logger.debug(
+                f"Image scaled down. Original patches: {total_patches_needed}, New dims: ~{new_width}x{new_height}, Final patches: {image_tokens}"
+            )
+
+        else:
+            image_tokens = total_patches_needed
+            logger.debug(f"Image fits within patch limit. Patches: {image_tokens}")
+
+        # Apply model-specific multiplier
+        if "mini" in model:
+            total_tokens = math.ceil(image_tokens * GPT41_MINI_MULTIPLIER)
+        elif "nano" in model:
+            total_tokens = math.ceil(image_tokens * GPT41_NANO_MULTIPLIER)
+        else:  # Assume base gpt-4.1
+            total_tokens = image_tokens
+
+        logger.info(
+            f"Calculated token cost for image ({width}x{height}) with model '{model}': {total_tokens} tokens (base image tokens: {image_tokens})"
+        )
+        return total_tokens
+
+    async def generate_text_with_images(
+        self,
+        prompt: str,
+        images: List[Union[str, bytes]],
+        system_prompt: str = "",
+        detail: Literal["low", "high", "auto"] = "auto",
+    ) -> str:  # pragma: no cover
+        """Generate text from OpenAI models using text and image inputs."""
+        if not images:
+            logger.warning(
+                "generate_text_with_images called with no images. Falling back to generate_text."
+            )
+            return await self.generate_text(prompt, system_prompt)
+
+        target_model = self.vision_model
+        if "gpt-4.1" not in target_model:  # Basic check for vision model
+            logger.warning(
+                f"Model '{target_model}' might not support vision. Using it anyway."
+            )
+
+        content_list: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
+        total_image_bytes = 0
+        total_image_tokens = 0
+
+        if len(images) > MAX_IMAGE_COUNT:
+            logger.error(
+                f"Too many images provided ({len(images)}). Maximum is {MAX_IMAGE_COUNT}."
+            )
+            return f"Error: Too many images provided ({len(images)}). Maximum is {MAX_IMAGE_COUNT}."
+
+        for i, image_input in enumerate(images):
+            image_url_data: Dict[str, Any] = {"detail": detail}
+            image_bytes: Optional[bytes] = None
+            image_format: Optional[str] = None
+            width: Optional[int] = None
+            height: Optional[int] = None
+
+            try:
+                if isinstance(image_input, str):  # It's a URL
+                    logger.debug(f"Processing image URL: {image_input[:50]}...")
+                    image_url_data["url"] = image_input
+                    # Cannot easily validate size/format/dimensions or calculate cost for URLs
+                    logger.warning(
+                        "Cannot validate size/format or calculate token cost for image URLs."
+                    )
+
+                elif isinstance(image_input, bytes):  # It's image bytes
+                    logger.debug(
+                        f"Processing image bytes (size: {len(image_input)})..."
+                    )
+                    image_bytes = image_input
+                    size_mb = len(image_bytes) / (1024 * 1024)
+                    if size_mb > MAX_IMAGE_SIZE_MB:
+                        logger.error(
+                            f"Image {i + 1} size ({size_mb:.2f}MB) exceeds limit ({MAX_IMAGE_SIZE_MB}MB)."
+                        )
+                        return f"Error: Image {i + 1} size ({size_mb:.2f}MB) exceeds limit ({MAX_IMAGE_SIZE_MB}MB)."
+                    total_image_bytes += len(image_bytes)
+
+                    # Use Pillow to validate format and get dimensions
+                    try:
+                        img = Image.open(io.BytesIO(image_bytes))
+                        image_format = img.format
+                        width, height = img.size
+                        img.verify()  # Verify integrity
+                        # Re-open after verify
+                        img = Image.open(io.BytesIO(image_bytes))
+                        width, height = img.size  # Get dimensions again
+
+                        if image_format not in SUPPORTED_IMAGE_FORMATS:
+                            logger.error(
+                                f"Unsupported image format '{image_format}' for image {i + 1}."
+                            )
+                            return f"Error: Unsupported image format '{image_format}'. Supported formats: {SUPPORTED_IMAGE_FORMATS}."
+
+                        logger.debug(
+                            f"Image {i + 1}: Format={image_format}, Dimensions={width}x{height}"
+                        )
+
+                        # Calculate cost only if dimensions are available
+                        if width and height and "gpt-4.1" in target_model:
+                            total_image_tokens += self._calculate_gpt41_image_cost(
+                                width, height, target_model
+                            )
+
+                    except (IOError, SyntaxError) as img_err:
+                        logger.error(
+                            f"Invalid or corrupted image data for image {i + 1}: {img_err}"
+                        )
+                        return f"Error: Invalid or corrupted image data provided for image {i + 1}."
+                    except Exception as pillow_err:
+                        logger.error(
+                            f"Pillow error processing image {i + 1}: {pillow_err}"
+                        )
+                        return f"Error: Could not process image data for image {i + 1}."
+
+                    # Encode to Base64 Data URL
+                    mime_type = Image.MIME.get(image_format)
+                    if not mime_type:
+                        logger.warning(
+                            f"Could not determine MIME type for format {image_format}. Defaulting to image/jpeg."
+                        )
+                        mime_type = "image/jpeg"
+                    base64_image = base64.b64encode(image_bytes).decode("utf-8")
+                    image_url_data["url"] = f"data:{mime_type};base64,{base64_image}"
+
+                else:
+                    logger.error(
+                        f"Invalid image input type for image {i + 1}: {type(image_input)}"
+                    )
+                    return f"Error: Invalid image input type for image {i + 1}. Must be URL (str) or bytes."
+
+                content_list.append({"type": "image_url", "image_url": image_url_data})
+
+            except Exception as proc_err:
+                logger.error(
+                    f"Error processing image {i + 1}: {proc_err}", exc_info=True
+                )
+                return f"Error: Failed to process image {i + 1}."
+
+        total_size_mb = total_image_bytes / (1024 * 1024)
+        if total_size_mb > MAX_TOTAL_IMAGE_SIZE_MB:
+            logger.error(
+                f"Total image size ({total_size_mb:.2f}MB) exceeds limit ({MAX_TOTAL_IMAGE_SIZE_MB}MB)."
+            )
+            return f"Error: Total image size ({total_size_mb:.2f}MB) exceeds limit ({MAX_TOTAL_IMAGE_SIZE_MB}MB)."
+
+        messages: List[Dict[str, Any]] = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": content_list})
+
+        request_params = {
+            "messages": messages,
+            "model": target_model,
+            # "max_tokens": 300  # Optional: Add max_tokens if needed
+        }
+
+        if self.logfire:
+            logfire.instrument_openai(self.client)
+
+        logger.info(
+            f"Sending request to '{target_model}' with {len(images)} images. Total calculated image tokens (approx): {total_image_tokens}"
+        )
+
+        try:
+            response = await self.client.chat.completions.create(**request_params)
+            if response.choices and response.choices[0].message.content:
+                # Log actual usage if available
+                if response.usage:
+                    logger.info(
+                        f"OpenAI API Usage: Prompt={response.usage.prompt_tokens}, Completion={response.usage.completion_tokens}, Total={response.usage.total_tokens}"
+                    )
+                return response.choices[0].message.content
+            else:
+                logger.warning("Received vision response with no content.")
+                return ""
+        except OpenAIError as e:  # Catch specific OpenAI errors
+            logger.error(f"OpenAI API error during vision request: {e}")
+            return f"I apologize, but I encountered an API error: {e}"
+        except Exception as e:
+            logger.exception(f"Error in generate_text_with_images: {e}")
+            return f"I apologize, but I encountered an unexpected error: {e}"
+
+    async def chat_stream(
+        self,
+        messages: List[Dict[str, Any]],
+        model: Optional[str] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
+    ) -> AsyncGenerator[Dict[str, Any], None]:  # pragma: no cover
+        """Stream chat completions with optional tool calls, yielding normalized events."""
+        try:
+            request_params: Dict[str, Any] = {
+                "messages": messages,
+                "model": model or self.text_model,
+                "stream": True,
+            }
+            if tools:
+                request_params["tools"] = tools
+
+            # Use custom client if api_key and base_url provided, otherwise use default
+            if api_key and base_url:
+                client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+            else:
+                client = self.client
+
+            if self.logfire:
+                logfire.instrument_openai(client)
+
+            stream = await client.chat.completions.create(**request_params)
+            async for chunk in stream:
+                try:
+                    if not chunk or not getattr(chunk, "choices", None):
+                        continue
+                    ch = chunk.choices[0]
+                    delta = getattr(ch, "delta", None)
+                    if delta is None:
+                        # Some SDKs use 'message' instead of 'delta'
+                        delta = getattr(ch, "message", None)
+                    if delta is None:
+                        # Finish event
+                        finish = getattr(ch, "finish_reason", None)
+                        if finish:
+                            yield {"type": "message_end", "finish_reason": finish}
+                        continue
+
+                    # Content delta
+                    content_piece = getattr(delta, "content", None)
+                    if content_piece:
+                        yield {"type": "content", "delta": content_piece}
+
+                    # Tool call deltas
+                    tool_calls = getattr(delta, "tool_calls", None)
+                    if tool_calls:
+                        for idx, tc in enumerate(tool_calls):
+                            try:
+                                tc_id = getattr(tc, "id", None)
+                                func = getattr(tc, "function", None)
+                                name = getattr(func, "name", None) if func else None
+                                args_piece = (
+                                    getattr(func, "arguments", "") if func else ""
+                                )
+                                yield {
+                                    "type": "tool_call_delta",
+                                    "id": tc_id,
+                                    "index": getattr(tc, "index", idx),
+                                    "name": name,
+                                    "arguments_delta": args_piece or "",
+                                }
+                            except Exception:
+                                continue
+                except Exception as parse_err:
+                    logger.debug(f"Error parsing stream chunk: {parse_err}")
+                    continue
+            # End of stream (SDK may not emit finish event in all cases)
+            yield {"type": "message_end", "finish_reason": "end_of_stream"}
+        except Exception as e:
+            logger.exception(f"Error in chat_stream: {e}")
+            yield {"type": "error", "error": str(e)}
+
+    async def parse_structured_output(
+        self,
+        prompt: str,
+        system_prompt: str,
+        model_class: Type[T],
+        api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
+        model: Optional[str] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+    ) -> T:  # pragma: no cover
+        """Generate structured output using Pydantic model parsing with Instructor."""
+
+        messages = []
+        messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": prompt})
+
+        try:
+            if api_key and base_url:
+                client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+            else:
+                client = self.client
+
+            if self.logfire:
+                logfire.instrument_openai(client)
+
+            # Use the provided model or the default parse model
+            current_parse_model = model or self.parse_model
+
+            patched_client = instructor.from_openai(client, mode=Mode.TOOLS_STRICT)
+
+            create_args = {
+                "model": current_parse_model,
+                "messages": messages,
+                "response_model": model_class,
+                "max_retries": 2,  # Automatically retry on validation errors
+            }
+            if tools:
+                create_args["tools"] = tools
+
+            response = await patched_client.chat.completions.create(**create_args)
+            return response
+        except Exception as e:
+            logger.warning(
+                f"Instructor parsing (TOOLS_STRICT mode) failed: {e}"
+            )  # Log warning
+
+            try:
+                # Determine client again for fallback
+                if api_key and base_url:
+                    client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+                else:
+                    client = self.client
+
+                if self.logfire:  # Instrument again if needed
+                    logfire.instrument_openai(client)
+
+                # Use the provided model or the default parse model
+                current_parse_model = model or self.parse_model
+
+                # First fallback: Try regular JSON mode
+                logger.info("Falling back to instructor JSON mode.")  # Log info
+                patched_client = instructor.from_openai(client, mode=Mode.JSON)
+                response = await patched_client.chat.completions.create(
+                    model=current_parse_model,  # Use the determined model
+                    messages=messages,
+                    response_model=model_class,
+                    max_retries=1,
+                )
+                return response
+            except Exception as json_error:
+                logger.warning(
+                    f"Instructor JSON mode fallback also failed: {json_error}"
+                )  # Log warning
+
+                try:
+                    # Determine client again for final fallback
+                    if api_key and base_url:
+                        client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+                    else:
+                        client = self.client
+
+                    if self.logfire:  # Instrument again if needed
+                        logfire.instrument_openai(client)
+
+                    # Use the provided model or the default parse model
+                    current_parse_model = model or self.parse_model
+
+                    # Final fallback: Manual extraction with a detailed prompt
+                    logger.info("Falling back to manual JSON extraction.")  # Log info
+                    fallback_system_prompt = f"""
+                    {system_prompt}
+
+                    You must respond with valid JSON that can be parsed as the following Pydantic model:
+                    {model_class.model_json_schema()}
+
+                    Ensure the response contains ONLY the JSON object and nothing else.
+                    """
+
+                    # Regular completion without instructor
+                    completion = await client.chat.completions.create(
+                        model=current_parse_model,  # Use the determined model
+                        messages=[
+                            {"role": "system", "content": fallback_system_prompt},
+                            {"role": "user", "content": prompt},
+                        ],
+                        response_format={"type": "json_object"},
+                    )
+
+                    # Extract and parse the JSON response
+                    json_str = completion.choices[0].message.content
+
+                    # Use Pydantic to parse and validate
+                    return model_class.model_validate_json(json_str)
+
+                except Exception as fallback_error:
+                    # Log the final exception with traceback
+                    logger.exception(
+                        f"All structured output fallback methods failed: {fallback_error}"
+                    )
+                    raise ValueError(
+                        f"Failed to generate structured output: {e}. All fallbacks failed."
+                    ) from e
+
+    async def embed_text(
+        self, text: str, model: Optional[str] = None, dimensions: Optional[int] = None
+    ) -> List[float]:  # pragma: no cover
+        """Generate an embedding for the given text using OpenAI.
+
+        Args:
+            text: The text to embed.
+            model: The embedding model to use (defaults to text-embedding-3-large).
+            dimensions: Desired output dimensions for the embedding.
+
+        Returns:
+            A list of floats representing the embedding vector.
+        """
+        if not text:
+            # Log error instead of raising immediately, let caller handle empty input if needed
+            logger.error("Attempted to embed empty text.")
+            raise ValueError("Text cannot be empty")
+
+        try:
+            # Use provided model/dimensions or fall back to defaults
+            embedding_model = model or self.embedding_model
+            embedding_dimensions = dimensions or self.embedding_dimensions
+
+            # Replace newlines with spaces as recommended by OpenAI
+            text = text.replace("\n", " ")
+
+            if self.logfire:  # Instrument only if logfire is enabled
+                logfire.instrument_openai(self.client)
+
+            response = await self.client.embeddings.create(
+                input=[text], model=embedding_model, dimensions=embedding_dimensions
+            )
+
+            if response.data and response.data[0].embedding:
+                return response.data[0].embedding
+            else:
+                # Log warning about unexpected response structure
+                logger.warning(
+                    "Failed to retrieve embedding from OpenAI response structure."
+                )
+                raise ValueError("Failed to retrieve embedding from OpenAI response")
+
+        except Exception as e:
+            # Log the exception with traceback before raising
+            logger.exception(f"Error generating embedding: {e}")
+            raise  # Re-raise the original exception
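
For orientation, here is a minimal usage sketch of the `OpenAIAdapter` added in the hunk above. It is illustrative only and not part of the diff: it exercises the constructor, `embed_text`, and `parse_structured_output` signatures shown above, and the API key value and the `Answer` model are hypothetical placeholders.

```python
# Illustrative sketch only -- exercises the OpenAIAdapter surface shown in the
# diff above. The API key and the Answer model are hypothetical placeholders.
import asyncio

from pydantic import BaseModel

from solana_agent.adapters.openai_adapter import OpenAIAdapter


class Answer(BaseModel):
    summary: str


async def main() -> None:
    # Defaults to the gpt-4.1 chat/vision/parse models defined in the adapter
    adapter = OpenAIAdapter(api_key="sk-...")

    # Embedding: returns a list of floats (3072 dimensions by default)
    vector = await adapter.embed_text("hello solana agent")
    print(len(vector))

    # Structured output: parsed into the Pydantic model via instructor
    result = await adapter.parse_structured_output(
        prompt="Summarize what an LLM adapter does.",
        system_prompt="You are a concise assistant.",
        model_class=Answer,
    )
    print(result.summary)


asyncio.run(main())
```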