agent-runtime-core 0.7.1__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_runtime_core/__init__.py +1 -1
- agent_runtime_core/files/__init__.py +88 -0
- agent_runtime_core/files/base.py +343 -0
- agent_runtime_core/files/ocr.py +406 -0
- agent_runtime_core/files/processors.py +508 -0
- agent_runtime_core/files/tools.py +317 -0
- agent_runtime_core/files/vision.py +360 -0
- {agent_runtime_core-0.7.1.dist-info → agent_runtime_core-0.8.0.dist-info}/METADATA +35 -1
- {agent_runtime_core-0.7.1.dist-info → agent_runtime_core-0.8.0.dist-info}/RECORD +11 -5
- {agent_runtime_core-0.7.1.dist-info → agent_runtime_core-0.8.0.dist-info}/WHEEL +0 -0
- {agent_runtime_core-0.7.1.dist-info → agent_runtime_core-0.8.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
"""
|
|
2
|
+
File read/write tools for agents.
|
|
3
|
+
|
|
4
|
+
Provides sandboxed file access tools that agents can use to read and write files.
|
|
5
|
+
All file operations are restricted to configured allowed directories.
|
|
6
|
+
|
|
7
|
+
Example:
|
|
8
|
+
from agent_runtime_core.files.tools import FileTools, FileToolsConfig
|
|
9
|
+
|
|
10
|
+
config = FileToolsConfig(
|
|
11
|
+
allowed_directories=["/app/uploads", "/app/outputs"],
|
|
12
|
+
max_file_size_bytes=50 * 1024 * 1024, # 50MB
|
|
13
|
+
)
|
|
14
|
+
tools = FileTools(config)
|
|
15
|
+
|
|
16
|
+
# Read a file
|
|
17
|
+
result = await tools.read_file("document.pdf")
|
|
18
|
+
|
|
19
|
+
# Write a file
|
|
20
|
+
await tools.write_file("output.txt", "Hello, world!")
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import os
|
|
24
|
+
from dataclasses import dataclass, field
|
|
25
|
+
from pathlib import Path
|
|
26
|
+
from typing import Any, Optional, Union
|
|
27
|
+
import base64
|
|
28
|
+
|
|
29
|
+
from .base import FileProcessorRegistry, ProcessingOptions, ProcessedFile, FileType
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class FileToolsConfig:
    """Configuration for file tools.

    Controls sandboxing, size limits, default OCR/vision processing
    behavior, and write semantics for FileTools.
    """
    # Sandboxing: every path handed to FileTools must resolve inside one of
    # these directories. Defaults to the current working directory.
    allowed_directories: list[str] = field(default_factory=lambda: ["."])

    # Size limits (applied to reads and writes respectively).
    max_file_size_bytes: int = 100 * 1024 * 1024  # 100MB default
    max_write_size_bytes: int = 100 * 1024 * 1024  # 100MB default

    # Processing options: defaults used by read_file when the caller does not
    # request OCR/vision explicitly. Provider names are resolved elsewhere
    # (e.g. "tesseract"/"google"/"aws"/"azure" for OCR).
    use_ocr: bool = False
    ocr_provider: Optional[str] = None
    use_vision: bool = False
    vision_provider: Optional[str] = None

    # Write options: allow_overwrite permits replacing existing files globally;
    # create_directories makes missing parent directories on write.
    allow_overwrite: bool = False
    create_directories: bool = True
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_file_read_schema() -> dict[str, Any]:
    """Get the tool schema for file_read in OpenAI format."""
    from ..tools import ToolSchemaBuilder

    # Build the schema step by step rather than as one long chain; each
    # builder method returns the builder, so reassignment is equivalent.
    builder = ToolSchemaBuilder("file_read")
    builder = builder.description(
        "Read and process a file. Extracts text content from various file types "
        "including PDF, DOCX, images (with optional OCR), spreadsheets, and text files. "
        "Returns the extracted text and metadata."
    )
    builder = builder.param("path", "string", "Path to the file to read", required=True)
    builder = builder.param("use_ocr", "boolean", "Use OCR for images/scanned documents", default=False)
    builder = builder.param(
        "ocr_provider", "string", "OCR provider to use",
        enum=["tesseract", "google", "aws", "azure"],
    )
    builder = builder.param("use_vision", "boolean", "Use AI vision for image analysis", default=False)
    builder = builder.param(
        "vision_provider", "string", "Vision AI provider to use",
        enum=["openai", "anthropic", "gemini"],
    )
    builder = builder.param("vision_prompt", "string", "Custom prompt for vision analysis")
    return builder.to_openai_format()
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def get_file_write_schema() -> dict[str, Any]:
    """Get the tool schema for file_write in OpenAI format."""
    from ..tools import ToolSchemaBuilder

    # Equivalent to the chained form; each call returns the builder itself.
    builder = ToolSchemaBuilder("file_write")
    builder = builder.description(
        "Write content to a file. Can write text content or base64-encoded binary data. "
        "The file path must be within allowed directories."
    )
    builder = builder.param("path", "string", "Path where the file should be written", required=True)
    builder = builder.param("content", "string", "Content to write (text or base64 for binary)", required=True)
    builder = builder.param("encoding", "string", "Content encoding", enum=["text", "base64"], default="text")
    builder = builder.param("overwrite", "boolean", "Whether to overwrite existing files", default=False)
    return builder.to_openai_format()
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class FileTools:
    """
    File read/write tools for agents with sandboxing.

    All file operations are restricted to configured allowed directories.
    """

    def __init__(
        self,
        config: Optional[FileToolsConfig] = None,
        registry: Optional[FileProcessorRegistry] = None,
    ):
        """
        Args:
            config: Tool configuration; a default FileToolsConfig is used if omitted.
            registry: Processor registry; when omitted, a new registry is
                created and auto-populated with the available processors.
        """
        self.config = config or FileToolsConfig()
        self.registry = registry
        if self.registry is None:
            self.registry = FileProcessorRegistry()
            self.registry.auto_register()

    def _resolve_path(self, path: str) -> Path:
        """Resolve and validate a file path against allowed directories.

        Args:
            path: Absolute or relative path supplied by the caller.

        Returns:
            The fully resolved absolute Path.

        Raises:
            PermissionError: If the resolved path is not inside any allowed directory.
        """
        # Resolve to absolute path. resolve() collapses ".." components and
        # follows symlinks, so traversal tricks cannot escape the sandbox.
        # NOTE(review): relative paths are resolved against the process CWD —
        # confirm that is the intended base directory.
        resolved = Path(path).resolve()

        # Accept the path if it sits under any allowed directory.
        for allowed_dir in self.config.allowed_directories:
            allowed_path = Path(allowed_dir).resolve()
            try:
                resolved.relative_to(allowed_path)
                return resolved
            except ValueError:
                # Not under this allowed directory; try the next one.
                continue

        raise PermissionError(
            f"Access denied: '{path}' is not within allowed directories. "
            f"Allowed: {self.config.allowed_directories}"
        )

    async def read_file(
        self,
        path: str,
        use_ocr: bool = False,
        ocr_provider: Optional[str] = None,
        use_vision: bool = False,
        vision_provider: Optional[str] = None,
        vision_prompt: Optional[str] = None,
    ) -> dict[str, Any]:
        """
        Read and process a file.

        Args:
            path: Path to file
            use_ocr: Whether to use OCR for images/scanned documents
            ocr_provider: OCR provider (tesseract, google, aws, azure)
            use_vision: Whether to use AI vision analysis
            vision_provider: Vision provider (openai, anthropic, gemini)
            vision_prompt: Custom prompt for vision analysis

        Returns:
            Dict with extracted text, metadata, and processing info

        Raises:
            PermissionError: If the path escapes the sandbox.
            FileNotFoundError: If the file does not exist.
            ValueError: If the file exceeds the configured read size limit.
        """
        resolved_path = self._resolve_path(path)

        if not resolved_path.exists():
            raise FileNotFoundError(f"File not found: {path}")

        # Enforce the read size limit before loading the file into memory.
        file_size = resolved_path.stat().st_size
        if file_size > self.config.max_file_size_bytes:
            raise ValueError(
                f"File size ({file_size} bytes) exceeds limit "
                f"({self.config.max_file_size_bytes} bytes)"
            )

        content = resolved_path.read_bytes()

        # Per-call flags take precedence; config values act as defaults.
        options = ProcessingOptions(
            max_size_bytes=self.config.max_file_size_bytes,
            use_ocr=use_ocr or self.config.use_ocr,
            ocr_provider=ocr_provider or self.config.ocr_provider,
            use_vision=use_vision or self.config.use_vision,
            vision_provider=vision_provider or self.config.vision_provider,
            vision_prompt=vision_prompt,
        )

        # Dispatch to the appropriate processor based on the filename.
        result = await self.registry.process(resolved_path.name, content, options)

        # Flatten the ProcessedFile into a plain dict for the tool response.
        return {
            "filename": result.filename,
            "file_type": result.file_type.value,
            "mime_type": result.mime_type,
            "size_bytes": result.size_bytes,
            "text": result.text,
            "metadata": result.metadata,
            "ocr_text": result.ocr_text,
            "vision_description": result.vision_description,
            "warnings": result.warnings,
        }

    async def write_file(
        self,
        path: str,
        content: str,
        encoding: str = "text",
        overwrite: bool = False,
    ) -> dict[str, Any]:
        """
        Write content to a file.

        Args:
            path: Path where the file should be written
            content: Content to write (text or base64 for binary)
            encoding: Content encoding ("text" or "base64")
            overwrite: Whether to overwrite existing files

        Returns:
            Dict with file info after writing

        Raises:
            PermissionError: If the path escapes the sandbox.
            FileExistsError: If the file exists and overwriting is not allowed.
            ValueError: If the content is invalid base64 or exceeds the write limit.
        """
        resolved_path = self._resolve_path(path)

        # Capture existence BEFORE writing: after write_bytes() the path
        # always exists, so checking afterwards would mis-report new files
        # as overwritten (this was a bug in the original implementation).
        file_existed = resolved_path.exists()
        if file_existed and not (overwrite or self.config.allow_overwrite):
            raise FileExistsError(
                f"File already exists: {path}. Set overwrite=True to replace."
            )

        # Decode content according to the declared encoding.
        if encoding == "base64":
            try:
                data = base64.b64decode(content)
            except Exception as e:
                raise ValueError(f"Invalid base64 content: {e}") from e
        else:
            data = content.encode("utf-8")

        # Enforce the write size limit on the decoded payload.
        if len(data) > self.config.max_write_size_bytes:
            raise ValueError(
                f"Content size ({len(data)} bytes) exceeds write limit "
                f"({self.config.max_write_size_bytes} bytes)"
            )

        # Create parent directories if configured to do so.
        if self.config.create_directories:
            resolved_path.parent.mkdir(parents=True, exist_ok=True)

        resolved_path.write_bytes(data)

        return {
            "success": True,
            "path": str(resolved_path),
            "size_bytes": len(data),
            "encoding": encoding,
            # True only when an existing file was actually replaced.
            "overwritten": file_existed,
        }

    async def list_files(
        self,
        directory: str = ".",
        pattern: str = "*",
        recursive: bool = False,
    ) -> dict[str, Any]:
        """
        List files in a directory.

        Args:
            directory: Directory to list (must be within allowed directories)
            pattern: Glob pattern to filter files
            recursive: Whether to search recursively

        Returns:
            Dict with list of files

        Raises:
            PermissionError: If the directory escapes the sandbox.
            NotADirectoryError: If the path is not a directory.
        """
        resolved_dir = self._resolve_path(directory)

        if not resolved_dir.is_dir():
            raise NotADirectoryError(f"Not a directory: {directory}")

        if recursive:
            matches = list(resolved_dir.rglob(pattern))
        else:
            matches = list(resolved_dir.glob(pattern))

        # Keep only regular files; directories matching the glob are dropped.
        files = [f for f in matches if f.is_file()]

        return {
            "directory": str(resolved_dir),
            "pattern": pattern,
            "recursive": recursive,
            "count": len(files),
            "files": [
                {
                    "name": f.name,
                    "path": str(f),
                    "size_bytes": f.stat().st_size,
                    "modified": f.stat().st_mtime,
                }
                for f in files[:100]  # Cap the response at 100 entries
            ],
            "truncated": len(files) > 100,
        }
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def get_file_list_schema() -> dict[str, Any]:
    """Get the tool schema for file_list in OpenAI format."""
    from ..tools import ToolSchemaBuilder

    # Stepwise construction; each builder call returns the builder itself.
    builder = ToolSchemaBuilder("file_list")
    builder = builder.description(
        "List files in a directory. Returns file names, sizes, and modification times. "
        "The directory must be within allowed directories."
    )
    builder = builder.param("directory", "string", "Directory to list", default=".")
    builder = builder.param("pattern", "string", "Glob pattern to filter files", default="*")
    builder = builder.param("recursive", "boolean", "Search recursively", default=False)
    return builder.to_openai_format()
|
|
@@ -0,0 +1,360 @@
|
|
|
1
|
+
"""
|
|
2
|
+
AI Vision providers for image analysis.
|
|
3
|
+
|
|
4
|
+
Supports multiple AI vision providers:
|
|
5
|
+
- OpenAI GPT-4 Vision
|
|
6
|
+
- Anthropic Claude Vision
|
|
7
|
+
- Google Gemini Vision
|
|
8
|
+
|
|
9
|
+
All providers are optional - install the corresponding library to use.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from abc import ABC, abstractmethod
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
from typing import Any
|
|
15
|
+
import base64
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class VisionResult:
    """Result from AI vision analysis.

    Returned by every VisionProvider.analyze_image implementation; only
    ``description`` is mandatory, the remaining fields default to empty.
    """

    description: str
    """Natural language description of the image."""

    labels: list[str] = field(default_factory=list)
    """Detected labels/objects in the image."""

    raw_response: Any = None
    """Raw response from the vision provider."""

    model: str = ""
    """Model used for analysis."""

    usage: dict[str, int] = field(default_factory=dict)
    """Token usage information."""
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class VisionProvider(ABC):
    """Abstract base class for AI vision providers."""

    # Short identifier used by the provider registry; subclasses override.
    name: str = "base"

    @abstractmethod
    async def analyze_image(
        self,
        image_data: bytes,
        prompt: str = "Describe this image in detail.",
        mime_type: str = "image/png",
        **kwargs,
    ) -> VisionResult:
        """
        Analyze an image using AI vision.

        Args:
            image_data: Raw image bytes
            prompt: Question or instruction for the vision model
            mime_type: MIME type of the image
            **kwargs: Provider-specific options

        Returns:
            VisionResult with description and metadata
        """
        ...

    @classmethod
    def is_available(cls) -> bool:
        """Check if this provider's dependencies are installed.

        The base implementation conservatively answers False; concrete
        providers probe for their client library.
        """
        return False

    def _encode_image(self, image_data: bytes) -> str:
        """Encode image data to base64 (returned as a UTF-8 string)."""
        encoded = base64.b64encode(image_data)
        return encoded.decode("utf-8")
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class OpenAIVision(VisionProvider):
    """OpenAI GPT-4 Vision provider."""

    name = "openai"

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "gpt-4o",
        max_tokens: int = 1024,
    ):
        self.api_key = api_key
        self.model = model
        self.max_tokens = max_tokens
        self._client = None  # lazily constructed AsyncOpenAI instance

    @classmethod
    def is_available(cls) -> bool:
        """True when the ``openai`` package can be imported."""
        try:
            import openai  # noqa: F401
        except ImportError:
            return False
        return True

    def _get_client(self):
        """Create the async client on first use and cache it."""
        if self._client is None:
            try:
                from openai import AsyncOpenAI
            except ImportError:
                raise ImportError(
                    "OpenAI library not installed. Install with: pip install openai"
                )
            else:
                self._client = AsyncOpenAI(api_key=self.api_key)
        return self._client

    async def analyze_image(
        self,
        image_data: bytes,
        prompt: str = "Describe this image in detail.",
        mime_type: str = "image/png",
        **kwargs,
    ) -> VisionResult:
        """Send the image plus prompt to the chat completions API."""
        client = self._get_client()

        # Images are passed to the API as base64 data URLs.
        encoded = self._encode_image(image_data)
        data_url = f"data:{mime_type};base64,{encoded}"

        user_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": data_url}},
            ],
        }
        response = await client.chat.completions.create(
            model=kwargs.get("model", self.model),
            max_tokens=kwargs.get("max_tokens", self.max_tokens),
            messages=[user_message],
        )

        usage = response.usage
        return VisionResult(
            description=response.choices[0].message.content or "",
            model=response.model,
            raw_response=response.model_dump(),
            usage={
                "prompt_tokens": usage.prompt_tokens if usage else 0,
                "completion_tokens": usage.completion_tokens if usage else 0,
            },
        )
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class AnthropicVision(VisionProvider):
    """Anthropic Claude Vision provider."""

    name = "anthropic"

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "claude-sonnet-4-20250514",
        max_tokens: int = 1024,
    ):
        self.api_key = api_key
        self.model = model
        self.max_tokens = max_tokens
        self._client = None  # lazily constructed AsyncAnthropic instance

    @classmethod
    def is_available(cls) -> bool:
        """True when the ``anthropic`` package can be imported."""
        try:
            import anthropic  # noqa: F401
        except ImportError:
            return False
        return True

    def _get_client(self):
        """Create the async client on first use and cache it."""
        if self._client is None:
            try:
                from anthropic import AsyncAnthropic
            except ImportError:
                raise ImportError(
                    "Anthropic library not installed. Install with: pip install anthropic"
                )
            else:
                self._client = AsyncAnthropic(api_key=self.api_key)
        return self._client

    async def analyze_image(
        self,
        image_data: bytes,
        prompt: str = "Describe this image in detail.",
        mime_type: str = "image/png",
        **kwargs,
    ) -> VisionResult:
        """Send the image plus prompt to the Messages API."""
        client = self._get_client()

        # Anthropic expects raw base64 with an explicit media type.
        encoded = self._encode_image(image_data)
        image_block = {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": mime_type,
                "data": encoded,
            },
        }
        response = await client.messages.create(
            model=kwargs.get("model", self.model),
            max_tokens=kwargs.get("max_tokens", self.max_tokens),
            messages=[
                {
                    "role": "user",
                    "content": [image_block, {"type": "text", "text": prompt}],
                }
            ],
        )

        # First content block carries the text answer (empty response -> "").
        text = response.content[0].text if response.content else ""

        usage = response.usage
        return VisionResult(
            description=text,
            model=response.model,
            raw_response=response.model_dump(),
            usage={
                "input_tokens": usage.input_tokens if usage else 0,
                "output_tokens": usage.output_tokens if usage else 0,
            },
        )
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
class GeminiVision(VisionProvider):
    """Google Gemini Vision provider."""

    name = "gemini"

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "gemini-2.0-flash",
    ):
        self.api_key = api_key
        self.model = model
        self._client = None  # lazily constructed GenerativeModel instance

    @classmethod
    def is_available(cls) -> bool:
        """True when the ``google.generativeai`` package can be imported."""
        try:
            import google.generativeai  # noqa: F401
        except ImportError:
            return False
        return True

    def _get_client(self):
        """Create the model client on first use and cache it."""
        if self._client is None:
            try:
                import google.generativeai as genai
            except ImportError:
                raise ImportError(
                    "Google Generative AI library not installed. "
                    "Install with: pip install google-generativeai"
                )
            else:
                if self.api_key:
                    genai.configure(api_key=self.api_key)
                self._client = genai.GenerativeModel(self.model)
        return self._client

    async def analyze_image(
        self,
        image_data: bytes,
        prompt: str = "Describe this image in detail.",
        mime_type: str = "image/png",
        **kwargs,
    ) -> VisionResult:
        """Send the image plus prompt to the Gemini generate_content API."""
        client = self._get_client()

        # Gemini accepts inline image parts as a mime-type/bytes mapping.
        image_part = {"mime_type": mime_type, "data": image_data}

        # generate_content is synchronous; run it off the event loop.
        import asyncio
        response = await asyncio.to_thread(
            client.generate_content,
            [prompt, image_part],
        )

        text = response.text if response.text else ""

        token_usage = {}
        if hasattr(response, "usage_metadata") and response.usage_metadata:
            token_usage = {
                "prompt_tokens": response.usage_metadata.prompt_token_count,
                "completion_tokens": response.usage_metadata.candidates_token_count,
            }

        return VisionResult(
            description=text,
            model=kwargs.get("model", self.model),
            raw_response=response,
            usage=token_usage,
        )
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
# Registry of all vision providers.
# Maps the provider name accepted by get_vision_provider() to its class.
# Whether each entry is usable depends on its optional dependency being
# installed; see <Provider>.is_available().
VISION_PROVIDERS: dict[str, type[VisionProvider]] = {
    "openai": OpenAIVision,
    "anthropic": AnthropicVision,
    "gemini": GeminiVision,
}
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def get_vision_provider(
    name: str,
    **kwargs,
) -> VisionProvider:
    """
    Get a vision provider by name.

    Args:
        name: Provider name ('openai', 'anthropic', 'gemini')
        **kwargs: Provider-specific configuration

    Returns:
        Configured VisionProvider instance

    Raises:
        ValueError: If provider name is unknown
        ImportError: If provider dependencies are not installed
    """
    # Look up the provider class; None means the name is not registered.
    provider_class = VISION_PROVIDERS.get(name)
    if provider_class is None:
        available = list(VISION_PROVIDERS.keys())
        raise ValueError(f"Unknown vision provider: {name}. Available: {available}")

    # Registered but missing its optional dependency -> ImportError.
    if not provider_class.is_available():
        raise ImportError(
            f"Vision provider '{name}' dependencies not installed. "
            f"Check the provider documentation for installation instructions."
        )

    return provider_class(**kwargs)
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def get_available_vision_providers() -> list[str]:
    """
    Get list of vision providers that have their dependencies installed.

    Returns:
        List of available provider names
    """
    # Probe each registered provider and keep the ones whose optional
    # dependency imports cleanly.
    available = []
    for provider_name, provider_cls in VISION_PROVIDERS.items():
        if provider_cls.is_available():
            available.append(provider_name)
    return available
|