solana-agent 20.1.2__py3-none-any.whl → 31.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. solana_agent/__init__.py +10 -5
  2. solana_agent/adapters/ffmpeg_transcoder.py +375 -0
  3. solana_agent/adapters/mongodb_adapter.py +15 -2
  4. solana_agent/adapters/openai_adapter.py +679 -0
  5. solana_agent/adapters/openai_realtime_ws.py +1813 -0
  6. solana_agent/adapters/pinecone_adapter.py +543 -0
  7. solana_agent/cli.py +128 -0
  8. solana_agent/client/solana_agent.py +180 -20
  9. solana_agent/domains/agent.py +13 -13
  10. solana_agent/domains/routing.py +18 -8
  11. solana_agent/factories/agent_factory.py +239 -38
  12. solana_agent/guardrails/pii.py +107 -0
  13. solana_agent/interfaces/client/client.py +95 -12
  14. solana_agent/interfaces/guardrails/guardrails.py +26 -0
  15. solana_agent/interfaces/plugins/plugins.py +2 -1
  16. solana_agent/interfaces/providers/__init__.py +0 -0
  17. solana_agent/interfaces/providers/audio.py +40 -0
  18. solana_agent/interfaces/providers/data_storage.py +9 -2
  19. solana_agent/interfaces/providers/llm.py +86 -9
  20. solana_agent/interfaces/providers/memory.py +13 -1
  21. solana_agent/interfaces/providers/realtime.py +212 -0
  22. solana_agent/interfaces/providers/vector_storage.py +53 -0
  23. solana_agent/interfaces/services/agent.py +27 -12
  24. solana_agent/interfaces/services/knowledge_base.py +59 -0
  25. solana_agent/interfaces/services/query.py +41 -8
  26. solana_agent/interfaces/services/routing.py +0 -1
  27. solana_agent/plugins/manager.py +37 -16
  28. solana_agent/plugins/registry.py +34 -19
  29. solana_agent/plugins/tools/__init__.py +0 -5
  30. solana_agent/plugins/tools/auto_tool.py +1 -0
  31. solana_agent/repositories/memory.py +332 -111
  32. solana_agent/services/__init__.py +1 -1
  33. solana_agent/services/agent.py +390 -241
  34. solana_agent/services/knowledge_base.py +768 -0
  35. solana_agent/services/query.py +1858 -153
  36. solana_agent/services/realtime.py +626 -0
  37. solana_agent/services/routing.py +104 -51
  38. solana_agent-31.4.0.dist-info/METADATA +1070 -0
  39. solana_agent-31.4.0.dist-info/RECORD +49 -0
  40. {solana_agent-20.1.2.dist-info → solana_agent-31.4.0.dist-info}/WHEEL +1 -1
  41. solana_agent-31.4.0.dist-info/entry_points.txt +3 -0
  42. solana_agent/adapters/llm_adapter.py +0 -160
  43. solana_agent-20.1.2.dist-info/METADATA +0 -464
  44. solana_agent-20.1.2.dist-info/RECORD +0 -35
  45. {solana_agent-20.1.2.dist-info → solana_agent-31.4.0.dist-info/licenses}/LICENSE +0 -0
solana_agent/adapters/openai_adapter.py (+679 -0)
@@ -0,0 +1,679 @@
+ """
+ LLM provider adapters for the Solana Agent system.
+
+ These adapters implement the LLMProvider interface for different LLM services.
+ """
+
+ import logging
+ import base64
+ import io
+ import math
+ from typing import (
+     AsyncGenerator,
+     List,
+     Literal,
+     Optional,
+     Type,
+     TypeVar,
+     Dict,
+     Any,
+     Union,
+ )
+ from PIL import Image
+ from openai import AsyncOpenAI, OpenAIError
+ from pydantic import BaseModel
+ import instructor
+ from instructor import Mode
+ import logfire
+
+ from solana_agent.interfaces.providers.llm import LLMProvider
+
+ # Setup logger for this module
+ logger = logging.getLogger(__name__)
+
+ T = TypeVar("T", bound=BaseModel)
+
+ DEFAULT_CHAT_MODEL = "gpt-4.1"
+ DEFAULT_VISION_MODEL = "gpt-4.1"
+ DEFAULT_PARSE_MODEL = "gpt-4.1"
+ DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large"
+ DEFAULT_EMBEDDING_DIMENSIONS = 3072
+ DEFAULT_TRANSCRIPTION_MODEL = "gpt-4o-mini-transcribe"
+ DEFAULT_TTS_MODEL = "tts-1"
+
+ # Image constants
+ SUPPORTED_IMAGE_FORMATS = {"PNG", "JPEG", "WEBP", "GIF"}
+ MAX_IMAGE_SIZE_MB = 20
+ MAX_TOTAL_IMAGE_SIZE_MB = 50
+ MAX_IMAGE_COUNT = 500
+ GPT41_PATCH_SIZE = 32
+ GPT41_MAX_PATCHES = 1536
+ GPT41_MINI_MULTIPLIER = 1.62
+ GPT41_NANO_MULTIPLIER = 2.46
+
+
+ class OpenAIAdapter(LLMProvider):
+     """OpenAI implementation of LLMProvider with web search capabilities."""
+
+     def __init__(
+         self,
+         api_key: str,
+         base_url: Optional[str] = None,
+         model: Optional[str] = None,
+         logfire_api_key: Optional[str] = None,
+     ):
+         self.api_key = api_key
+         self.base_url = base_url
+
+         # Create client with base_url if provided (for Grok support)
+         if base_url:
+             self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+         else:
+             self.client = AsyncOpenAI(api_key=api_key)
+
+         self.logfire = False
+         if logfire_api_key:
+             try:
+                 logfire.configure(token=logfire_api_key)
+                 self.logfire = True
+                 # Instrument the main client immediately after configuring logfire
+                 logfire.instrument_openai(self.client)
+                 logger.info(
+                     "Logfire configured and OpenAI client instrumented successfully."
+                 )
+             except Exception as e:
+                 logger.error(f"Failed to configure Logfire: {e}")
+                 self.logfire = False
+
+         # Use provided model or defaults (for Grok or OpenAI)
+         if model:
+             # Custom model provided (e.g., from Grok config)
+             self.parse_model = model
+             self.text_model = model
+             self.vision_model = model
+         else:
+             # Use OpenAI defaults
+             self.parse_model = DEFAULT_PARSE_MODEL
+             self.text_model = DEFAULT_CHAT_MODEL
+             self.vision_model = DEFAULT_VISION_MODEL
+
+         # These remain OpenAI-specific
+         self.transcription_model = DEFAULT_TRANSCRIPTION_MODEL
+         self.tts_model = DEFAULT_TTS_MODEL
+         self.embedding_model = DEFAULT_EMBEDDING_MODEL
+         self.embedding_dimensions = DEFAULT_EMBEDDING_DIMENSIONS
+
+     def get_api_key(self) -> Optional[str]:  # pragma: no cover
+         """Return the API key used to configure the OpenAI client."""
+         return getattr(self, "api_key", None)
+
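
A construction sketch (not from the diffed file; key values, endpoint, and model name are placeholders). Passing base_url routes chat, vision, and parse calls to any OpenAI-compatible API, while transcription, TTS, and embeddings keep their OpenAI defaults:

    # Hypothetical usage -- not part of the package source.
    adapter = OpenAIAdapter(api_key="sk-...")  # plain OpenAI setup

    grok = OpenAIAdapter(
        api_key="xai-...",
        base_url="https://api.x.ai/v1",  # assumed OpenAI-compatible endpoint
        model="grok-3",                  # placeholder model name
    )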
+     async def tts(
+         self,
+         text: str,
+         instructions: str = "You speak in a friendly and helpful manner.",
+         voice: Literal[
+             "alloy",
+             "ash",
+             "ballad",
+             "coral",
+             "echo",
+             "fable",
+             "onyx",
+             "nova",
+             "sage",
+             "shimmer",
+         ] = "nova",
+         response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = "aac",
+     ) -> AsyncGenerator[bytes, None]:  # pragma: no cover
+         """Stream text-to-speech audio from OpenAI models.
+
+         Args:
+             text: Text to convert to speech
+             instructions: Not used in this implementation
+             voice: Voice to use for synthesis
+             response_format: Audio format
+
+         Yields:
+             Audio bytes as they become available
+         """
+         try:
+             if self.logfire:  # Instrument only if logfire is enabled
+                 logfire.instrument_openai(self.client)
+             async with self.client.audio.speech.with_streaming_response.create(
+                 model=self.tts_model,
+                 voice=voice,
+                 input=text,
+                 response_format=response_format,
+             ) as stream:
+                 # Stream the bytes in 16KB chunks
+                 async for chunk in stream.iter_bytes(chunk_size=1024 * 16):
+                     yield chunk
+
+         except Exception as e:
+             # Log the exception with traceback
+             logger.exception(f"Error in text_to_speech: {e}")
+             yield b""  # Return empty bytes on error
+
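
A consumption sketch (not from the package; the output path is a placeholder). Since tts() yields b"" on failure instead of raising, a caller that must distinguish errors from silence should watch for empty chunks:

    # Hypothetical caller -- not part of the package source.
    async def save_speech(adapter: OpenAIAdapter) -> None:
        with open("speech.aac", "wb") as f:  # placeholder path
            async for chunk in adapter.tts("Hello from Solana Agent"):
                f.write(chunk)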
+     async def transcribe_audio(
+         self,
+         audio_bytes: bytes,
+         input_format: Literal[
+             "flac", "mp3", "mp4", "mpeg", "mpga", "m4a", "ogg", "wav", "webm"
+         ] = "mp4",
+     ) -> AsyncGenerator[str, None]:  # pragma: no cover
+         """Stream transcription of an audio file.
+
+         Args:
+             audio_bytes: Audio file bytes
+             input_format: Format of the input audio file
+
+         Yields:
+             Transcript text chunks as they become available
+         """
+         try:
+             if self.logfire:  # Instrument only if logfire is enabled
+                 logfire.instrument_openai(self.client)
+             async with self.client.audio.transcriptions.with_streaming_response.create(
+                 model=self.transcription_model,
+                 file=(f"file.{input_format}", audio_bytes),
+                 response_format="text",
+             ) as stream:
+                 # Stream the text in 16KB chunks
+                 async for chunk in stream.iter_text(chunk_size=1024 * 16):
+                     yield chunk
+
+         except Exception as e:
+             # Log the exception with traceback
+             logger.exception(f"Error in transcribe_audio: {e}")
+             yield f"I apologize, but I encountered an error transcribing the audio: {str(e)}"
+
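
A transcription sketch (not from the package). Note that errors are yielded in-band as apology text, so they surface to callers as transcript content:

    # Hypothetical caller -- not part of the package source.
    async def transcribe_file(adapter: OpenAIAdapter, path: str) -> str:
        with open(path, "rb") as f:  # e.g., a recorded .mp4 clip
            audio = f.read()
        pieces = []
        async for piece in adapter.transcribe_audio(audio, input_format="mp4"):
            pieces.append(piece)
        return "".join(pieces)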
+     async def generate_text(
+         self,
+         prompt: str,
+         system_prompt: str = "",
+         api_key: Optional[str] = None,
+         base_url: Optional[str] = None,
+         model: Optional[str] = None,
+         tools: Optional[List[Dict[str, Any]]] = None,
+     ) -> Any:  # pragma: no cover
+         """Generate text or a function call from OpenAI models.
+
+         Returns the raw ChatCompletion response (so callers can inspect tool
+         calls), or None if the request fails.
+         """
+         messages = []
+         if system_prompt:
+             messages.append({"role": "system", "content": system_prompt})
+         messages.append({"role": "user", "content": prompt})
+
+         request_params = {
+             "messages": messages,
+             "model": model or self.text_model,
+         }
+         if tools:
+             request_params["tools"] = tools
+
+         if api_key and base_url:
+             client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+         else:
+             client = self.client
+
+         if self.logfire:
+             logfire.instrument_openai(client)
+
+         try:
+             response = await client.chat.completions.create(**request_params)
+             return response
+         except OpenAIError as e:
+             logger.error(f"OpenAI API error during text generation: {e}")
+             return None
+         except Exception as e:
+             logger.exception(f"Error in generate_text: {e}")
+             return None
+
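
Despite its name, generate_text() resolves to the raw completion object. A caller that only wants the text (a sketch assuming the OpenAI response shape):

    # Hypothetical caller -- not part of the package source.
    async def ask(adapter: OpenAIAdapter, question: str) -> str:
        response = await adapter.generate_text(question, system_prompt="Be concise.")
        if response is None or not response.choices:
            return ""
        return response.choices[0].message.content or ""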
+     def _calculate_gpt41_image_cost(self, width: int, height: int, model: str) -> int:
+         """Calculates the token cost for an image with GPT-4.1 models."""
+         patches_wide = math.ceil(width / GPT41_PATCH_SIZE)
+         patches_high = math.ceil(height / GPT41_PATCH_SIZE)
+         total_patches_needed = patches_wide * patches_high
+
+         if total_patches_needed > GPT41_MAX_PATCHES:
+             scale_factor = math.sqrt(GPT41_MAX_PATCHES / total_patches_needed)
+             new_width = math.floor(width * scale_factor)
+             new_height = math.floor(height * scale_factor)
+
+             final_patches_wide_scaled = math.ceil(new_width / GPT41_PATCH_SIZE)
+             final_patches_high_scaled = math.ceil(new_height / GPT41_PATCH_SIZE)
+             image_tokens = final_patches_wide_scaled * final_patches_high_scaled
+
+             # Ensure it doesn't exceed the cap due to ceiling operations after scaling
+             image_tokens = min(image_tokens, GPT41_MAX_PATCHES)
+
+             logger.debug(
+                 f"Image scaled down. Original patches: {total_patches_needed}, New dims: ~{new_width}x{new_height}, Final patches: {image_tokens}"
+             )
+
+         else:
+             image_tokens = total_patches_needed
+             logger.debug(f"Image fits within patch limit. Patches: {image_tokens}")
+
+         # Apply model-specific multiplier
+         if "mini" in model:
+             total_tokens = math.ceil(image_tokens * GPT41_MINI_MULTIPLIER)
+         elif "nano" in model:
+             total_tokens = math.ceil(image_tokens * GPT41_NANO_MULTIPLIER)
+         else:  # Assume base gpt-4.1
+             total_tokens = image_tokens
+
+         logger.info(
+             f"Calculated token cost for image ({width}x{height}) with model '{model}': {total_tokens} tokens (base image tokens: {image_tokens})"
+         )
+         return total_tokens
+
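
To make the patch arithmetic concrete (an editorial walkthrough derived from the constants above): a 1024x1024 image needs ceil(1024/32) = 32 patches per side, so 32 x 32 = 1024 patches, under the 1536 cap; that costs 1024 tokens on base gpt-4.1 and ceil(1024 x 1.62) = 1659 tokens on a "mini" model. A 2048x2048 image needs 64 x 64 = 4096 patches, so it is scaled by sqrt(1536/4096), about 0.612, to roughly 1254x1254; that still rounds up to 40 x 40 = 1600 patches, and the min() cap brings the count back down to 1536.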
+     async def generate_text_with_images(
+         self,
+         prompt: str,
+         images: List[Union[str, bytes]],
+         system_prompt: str = "",
+         detail: Literal["low", "high", "auto"] = "auto",
+     ) -> str:  # pragma: no cover
+         """Generate text from OpenAI models using text and image inputs."""
+         if not images:
+             logger.warning(
+                 "generate_text_with_images called with no images. Falling back to generate_text."
+             )
+             return await self.generate_text(prompt, system_prompt)
+
+         target_model = self.vision_model
+         if "gpt-4.1" not in target_model:  # Basic check for vision model
+             logger.warning(
+                 f"Model '{target_model}' might not support vision. Using it anyway."
+             )
+
+         content_list: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
+         total_image_bytes = 0
+         total_image_tokens = 0
+
+         if len(images) > MAX_IMAGE_COUNT:
+             logger.error(
+                 f"Too many images provided ({len(images)}). Maximum is {MAX_IMAGE_COUNT}."
+             )
+             return f"Error: Too many images provided ({len(images)}). Maximum is {MAX_IMAGE_COUNT}."
+
+         for i, image_input in enumerate(images):
+             image_url_data: Dict[str, Any] = {"detail": detail}
+             image_bytes: Optional[bytes] = None
+             image_format: Optional[str] = None
+             width: Optional[int] = None
+             height: Optional[int] = None
+
+             try:
+                 if isinstance(image_input, str):  # It's a URL
+                     logger.debug(f"Processing image URL: {image_input[:50]}...")
+                     image_url_data["url"] = image_input
+                     # Cannot easily validate size/format/dimensions or calculate cost for URLs
+                     logger.warning(
+                         "Cannot validate size/format or calculate token cost for image URLs."
+                     )
+
+                 elif isinstance(image_input, bytes):  # It's image bytes
+                     logger.debug(
+                         f"Processing image bytes (size: {len(image_input)})..."
+                     )
+                     image_bytes = image_input
+                     size_mb = len(image_bytes) / (1024 * 1024)
+                     if size_mb > MAX_IMAGE_SIZE_MB:
+                         logger.error(
+                             f"Image {i + 1} size ({size_mb:.2f}MB) exceeds limit ({MAX_IMAGE_SIZE_MB}MB)."
+                         )
+                         return f"Error: Image {i + 1} size ({size_mb:.2f}MB) exceeds limit ({MAX_IMAGE_SIZE_MB}MB)."
+                     total_image_bytes += len(image_bytes)
+
+                     # Use Pillow to validate format and get dimensions
+                     try:
+                         img = Image.open(io.BytesIO(image_bytes))
+                         image_format = img.format
+                         width, height = img.size
+                         img.verify()  # Verify integrity
+                         # Re-open after verify
+                         img = Image.open(io.BytesIO(image_bytes))
+                         width, height = img.size  # Get dimensions again
+
+                         if image_format not in SUPPORTED_IMAGE_FORMATS:
+                             logger.error(
+                                 f"Unsupported image format '{image_format}' for image {i + 1}."
+                             )
+                             return f"Error: Unsupported image format '{image_format}'. Supported formats: {SUPPORTED_IMAGE_FORMATS}."
+
+                         logger.debug(
+                             f"Image {i + 1}: Format={image_format}, Dimensions={width}x{height}"
+                         )
+
+                         # Calculate cost only if dimensions are available
+                         if width and height and "gpt-4.1" in target_model:
+                             total_image_tokens += self._calculate_gpt41_image_cost(
+                                 width, height, target_model
+                             )
+
+                     except (IOError, SyntaxError) as img_err:
+                         logger.error(
+                             f"Invalid or corrupted image data for image {i + 1}: {img_err}"
+                         )
+                         return f"Error: Invalid or corrupted image data provided for image {i + 1}."
+                     except Exception as pillow_err:
+                         logger.error(
+                             f"Pillow error processing image {i + 1}: {pillow_err}"
+                         )
+                         return f"Error: Could not process image data for image {i + 1}."
+
+                     # Encode to Base64 Data URL
+                     mime_type = Image.MIME.get(image_format)
+                     if not mime_type:
+                         logger.warning(
+                             f"Could not determine MIME type for format {image_format}. Defaulting to image/jpeg."
+                         )
+                         mime_type = "image/jpeg"
+                     base64_image = base64.b64encode(image_bytes).decode("utf-8")
+                     image_url_data["url"] = f"data:{mime_type};base64,{base64_image}"
+
+                 else:
+                     logger.error(
+                         f"Invalid image input type for image {i + 1}: {type(image_input)}"
+                     )
+                     return f"Error: Invalid image input type for image {i + 1}. Must be URL (str) or bytes."
+
+                 content_list.append({"type": "image_url", "image_url": image_url_data})
+
+             except Exception as proc_err:
+                 logger.error(
+                     f"Error processing image {i + 1}: {proc_err}", exc_info=True
+                 )
+                 return f"Error: Failed to process image {i + 1}."
+
+         total_size_mb = total_image_bytes / (1024 * 1024)
+         if total_size_mb > MAX_TOTAL_IMAGE_SIZE_MB:
+             logger.error(
+                 f"Total image size ({total_size_mb:.2f}MB) exceeds limit ({MAX_TOTAL_IMAGE_SIZE_MB}MB)."
+             )
+             return f"Error: Total image size ({total_size_mb:.2f}MB) exceeds limit ({MAX_TOTAL_IMAGE_SIZE_MB}MB)."
+
+         messages: List[Dict[str, Any]] = []
+         if system_prompt:
+             messages.append({"role": "system", "content": system_prompt})
+         messages.append({"role": "user", "content": content_list})
+
+         request_params = {
+             "messages": messages,
+             "model": target_model,
+             # "max_tokens": 300  # Optional: Add max_tokens if needed
+         }
+
+         if self.logfire:
+             logfire.instrument_openai(self.client)
+
+         logger.info(
+             f"Sending request to '{target_model}' with {len(images)} images. Total calculated image tokens (approx): {total_image_tokens}"
+         )
+
+         try:
+             response = await self.client.chat.completions.create(**request_params)
+             if response.choices and response.choices[0].message.content:
+                 # Log actual usage if available
+                 if response.usage:
+                     logger.info(
+                         f"OpenAI API Usage: Prompt={response.usage.prompt_tokens}, Completion={response.usage.completion_tokens}, Total={response.usage.total_tokens}"
+                     )
+                 return response.choices[0].message.content
+             else:
+                 logger.warning("Received vision response with no content.")
+                 return ""
+         except OpenAIError as e:  # Catch specific OpenAI errors
+             logger.error(f"OpenAI API error during vision request: {e}")
+             return f"I apologize, but I encountered an API error: {e}"
+         except Exception as e:
+             logger.exception(f"Error in generate_text_with_images: {e}")
+             return f"I apologize, but I encountered an unexpected error: {e}"
+
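
A vision-call sketch (not from the package; the image path is a placeholder). Validation failures come back as "Error: ..." strings rather than exceptions, so callers should check the prefix:

    # Hypothetical caller -- not part of the package source.
    async def describe(adapter: OpenAIAdapter, image_path: str) -> str:
        with open(image_path, "rb") as f:  # PNG/JPEG/WEBP/GIF, <= 20MB each
            image_bytes = f.read()
        result = await adapter.generate_text_with_images(
            "Describe this image in one sentence.",
            images=[image_bytes],
            detail="low",
        )
        if result.startswith("Error:"):
            raise ValueError(result)
        return result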
+     async def chat_stream(
+         self,
+         messages: List[Dict[str, Any]],
+         model: Optional[str] = None,
+         tools: Optional[List[Dict[str, Any]]] = None,
+         api_key: Optional[str] = None,
+         base_url: Optional[str] = None,
+     ) -> AsyncGenerator[Dict[str, Any], None]:  # pragma: no cover
+         """Stream chat completions with optional tool calls, yielding normalized events."""
+         try:
+             request_params: Dict[str, Any] = {
+                 "messages": messages,
+                 "model": model or self.text_model,
+                 "stream": True,
+             }
+             if tools:
+                 request_params["tools"] = tools
+
+             # Use custom client if api_key and base_url provided, otherwise use default
+             if api_key and base_url:
+                 client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+             else:
+                 client = self.client
+
+             if self.logfire:
+                 logfire.instrument_openai(client)
+
+             stream = await client.chat.completions.create(**request_params)
+             async for chunk in stream:
+                 try:
+                     if not chunk or not getattr(chunk, "choices", None):
+                         continue
+                     ch = chunk.choices[0]
+                     delta = getattr(ch, "delta", None)
+                     if delta is None:
+                         # Some SDKs use 'message' instead of 'delta'
+                         delta = getattr(ch, "message", None)
+                     if delta is None:
+                         # Finish event
+                         finish = getattr(ch, "finish_reason", None)
+                         if finish:
+                             yield {"type": "message_end", "finish_reason": finish}
+                         continue
+
+                     # Content delta
+                     content_piece = getattr(delta, "content", None)
+                     if content_piece:
+                         yield {"type": "content", "delta": content_piece}
+
+                     # Tool call deltas
+                     tool_calls = getattr(delta, "tool_calls", None)
+                     if tool_calls:
+                         for idx, tc in enumerate(tool_calls):
+                             try:
+                                 tc_id = getattr(tc, "id", None)
+                                 func = getattr(tc, "function", None)
+                                 name = getattr(func, "name", None) if func else None
+                                 args_piece = (
+                                     getattr(func, "arguments", "") if func else ""
+                                 )
+                                 yield {
+                                     "type": "tool_call_delta",
+                                     "id": tc_id,
+                                     "index": getattr(tc, "index", idx),
+                                     "name": name,
+                                     "arguments_delta": args_piece or "",
+                                 }
+                             except Exception:
+                                 continue
+                 except Exception as parse_err:
+                     logger.debug(f"Error parsing stream chunk: {parse_err}")
+                     continue
+             # End of stream (SDK may not emit finish event in all cases)
+             yield {"type": "message_end", "finish_reason": "end_of_stream"}
+         except Exception as e:
+             logger.exception(f"Error in chat_stream: {e}")
+             yield {"type": "error", "error": str(e)}
+
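
A consumer sketch (not from the package) for the four normalized event types. Tool-call arguments arrive as string fragments keyed by index and must be concatenated before JSON-parsing:

    # Hypothetical consumer -- not part of the package source.
    async def run_chat(adapter: OpenAIAdapter) -> None:
        tool_args = {}  # index -> accumulated argument fragments
        async for event in adapter.chat_stream(
            [{"role": "user", "content": "What's the weather in Lisbon?"}]
        ):
            if event["type"] == "content":
                print(event["delta"], end="", flush=True)
            elif event["type"] == "tool_call_delta":
                idx = event["index"]
                tool_args[idx] = tool_args.get(idx, "") + event["arguments_delta"]
            elif event["type"] == "message_end":
                break  # finish_reason "tool_calls" would mean: dispatch tool_args
            elif event["type"] == "error":
                raise RuntimeError(event["error"])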
+     async def parse_structured_output(
+         self,
+         prompt: str,
+         system_prompt: str,
+         model_class: Type[T],
+         api_key: Optional[str] = None,
+         base_url: Optional[str] = None,
+         model: Optional[str] = None,
+         tools: Optional[List[Dict[str, Any]]] = None,
+     ) -> T:  # pragma: no cover
+         """Generate structured output using Pydantic model parsing with Instructor."""
+
+         messages = []
+         messages.append({"role": "system", "content": system_prompt})
+         messages.append({"role": "user", "content": prompt})
+
+         try:
+             if api_key and base_url:
+                 client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+             else:
+                 client = self.client
+
+             if self.logfire:
+                 logfire.instrument_openai(client)
+
+             # Use the provided model or the default parse model
+             current_parse_model = model or self.parse_model
+
+             patched_client = instructor.from_openai(client, mode=Mode.TOOLS_STRICT)
+
+             create_args = {
+                 "model": current_parse_model,
+                 "messages": messages,
+                 "response_model": model_class,
+                 "max_retries": 2,  # Automatically retry on validation errors
+             }
+             if tools:
+                 create_args["tools"] = tools
+
+             response = await patched_client.chat.completions.create(**create_args)
+             return response
+         except Exception as e:
+             logger.warning(
+                 f"Instructor parsing (TOOLS_STRICT mode) failed: {e}"
+             )  # Log warning
+
+             try:
+                 # Determine client again for fallback
+                 if api_key and base_url:
+                     client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+                 else:
+                     client = self.client
+
+                 if self.logfire:  # Instrument again if needed
+                     logfire.instrument_openai(client)
+
+                 # Use the provided model or the default parse model
+                 current_parse_model = model or self.parse_model
+
+                 # First fallback: Try regular JSON mode
+                 logger.info("Falling back to instructor JSON mode.")  # Log info
+                 patched_client = instructor.from_openai(client, mode=Mode.JSON)
+                 response = await patched_client.chat.completions.create(
+                     model=current_parse_model,  # Use the determined model
+                     messages=messages,
+                     response_model=model_class,
+                     max_retries=1,
+                 )
+                 return response
+             except Exception as json_error:
+                 logger.warning(
+                     f"Instructor JSON mode fallback also failed: {json_error}"
+                 )  # Log warning
+
+                 try:
+                     # Determine client again for final fallback
+                     if api_key and base_url:
+                         client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+                     else:
+                         client = self.client
+
+                     if self.logfire:  # Instrument again if needed
+                         logfire.instrument_openai(client)
+
+                     # Use the provided model or the default parse model
+                     current_parse_model = model or self.parse_model
+
+                     # Final fallback: Manual extraction with a detailed prompt
+                     logger.info("Falling back to manual JSON extraction.")  # Log info
+                     fallback_system_prompt = f"""
+                     {system_prompt}
+
+                     You must respond with valid JSON that can be parsed as the following Pydantic model:
+                     {model_class.model_json_schema()}
+
+                     Ensure the response contains ONLY the JSON object and nothing else.
+                     """
+
+                     # Regular completion without instructor
+                     completion = await client.chat.completions.create(
+                         model=current_parse_model,  # Use the determined model
+                         messages=[
+                             {"role": "system", "content": fallback_system_prompt},
+                             {"role": "user", "content": prompt},
+                         ],
+                         response_format={"type": "json_object"},
+                     )
+
+                     # Extract and parse the JSON response
+                     json_str = completion.choices[0].message.content
+
+                     # Use Pydantic to parse and validate
+                     return model_class.model_validate_json(json_str)
+
+                 except Exception as fallback_error:
+                     # Log the final exception with traceback
+                     logger.exception(
+                         f"All structured output fallback methods failed: {fallback_error}"
+                     )
+                     raise ValueError(
+                         f"Failed to generate structured output: {e}. All fallbacks failed."
+                     ) from e
+
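
A structured-output sketch (the model class is invented for illustration). The three-tier fallback, strict tool calling, then Instructor JSON mode, then a raw json_object completion validated by Pydantic, is invisible to callers:

    # Hypothetical caller -- not part of the package source.
    from pydantic import BaseModel

    class RouteDecision(BaseModel):  # invented example model
        agent_name: str
        confidence: float

    async def route(adapter: OpenAIAdapter, query: str) -> RouteDecision:
        return await adapter.parse_structured_output(
            prompt=query,
            system_prompt="Pick the best agent for this query.",
            model_class=RouteDecision,
        )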
+     async def embed_text(
+         self, text: str, model: Optional[str] = None, dimensions: Optional[int] = None
+     ) -> List[float]:  # pragma: no cover
+         """Generate an embedding for the given text using OpenAI.
+
+         Args:
+             text: The text to embed.
+             model: The embedding model to use (defaults to text-embedding-3-large).
+             dimensions: Desired output dimensions for the embedding.
+
+         Returns:
+             A list of floats representing the embedding vector.
+         """
+         if not text:
+             # Log before raising so the empty-input failure is visible in logs
+             logger.error("Attempted to embed empty text.")
+             raise ValueError("Text cannot be empty")
+
+         try:
+             # Use provided model/dimensions or fall back to defaults
+             embedding_model = model or self.embedding_model
+             embedding_dimensions = dimensions or self.embedding_dimensions
+
+             # Replace newlines with spaces as recommended by OpenAI
+             text = text.replace("\n", " ")
+
+             if self.logfire:  # Instrument only if logfire is enabled
+                 logfire.instrument_openai(self.client)
+
+             response = await self.client.embeddings.create(
+                 input=[text], model=embedding_model, dimensions=embedding_dimensions
+             )
+
+             if response.data and response.data[0].embedding:
+                 return response.data[0].embedding
+             else:
+                 # Log warning about unexpected response structure
+                 logger.warning(
+                     "Failed to retrieve embedding from OpenAI response structure."
+                 )
+                 raise ValueError("Failed to retrieve embedding from OpenAI response")
+
+         except Exception as e:
+             # Log the exception with traceback before raising
+             logger.exception(f"Error generating embedding: {e}")
+             raise  # Re-raise the original exception
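
A closing sketch (not from the package): the embedding dimension must match the vector index it feeds; the Pinecone adapter added in this release is a plausible consumer, though that pairing is an assumption here:

    # Hypothetical caller -- not part of the package source.
    async def embed_for_index(adapter: OpenAIAdapter, doc: str) -> list:
        # 3072 matches DEFAULT_EMBEDDING_DIMENSIONS for text-embedding-3-large;
        # passing a smaller value (e.g., 1024) shortens vectors for smaller indexes.
        return await adapter.embed_text(doc, dimensions=3072)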