isa_model-0.3.4-py3-none-any.whl → isa_model-0.3.5-py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- isa_model/config/__init__.py +9 -0
- isa_model/config/config_manager.py +213 -0
- isa_model/core/model_manager.py +5 -0
- isa_model/core/model_registry.py +39 -6
- isa_model/core/storage/supabase_storage.py +344 -0
- isa_model/core/vision_models_init.py +116 -0
- isa_model/deployment/cloud/__init__.py +9 -0
- isa_model/deployment/cloud/modal/__init__.py +10 -0
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +612 -0
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +305 -0
- isa_model/inference/ai_factory.py +238 -14
- isa_model/inference/providers/modal_provider.py +109 -0
- isa_model/inference/providers/yyds_provider.py +108 -0
- isa_model/inference/services/__init__.py +2 -1
- isa_model/inference/services/base_service.py +0 -38
- isa_model/inference/services/llm/base_llm_service.py +32 -0
- isa_model/inference/services/llm/llm_adapter.py +40 -0
- isa_model/inference/services/llm/ollama_llm_service.py +104 -3
- isa_model/inference/services/llm/openai_llm_service.py +67 -15
- isa_model/inference/services/llm/yyds_llm_service.py +254 -0
- isa_model/inference/services/stacked/__init__.py +26 -0
- isa_model/inference/services/stacked/base_stacked_service.py +269 -0
- isa_model/inference/services/stacked/config.py +426 -0
- isa_model/inference/services/stacked/doc_analysis_service.py +640 -0
- isa_model/inference/services/stacked/flux_professional_service.py +579 -0
- isa_model/inference/services/stacked/ui_analysis_service.py +1319 -0
- isa_model/inference/services/vision/base_image_gen_service.py +0 -34
- isa_model/inference/services/vision/base_vision_service.py +46 -2
- isa_model/inference/services/vision/isA_vision_service.py +402 -0
- isa_model/inference/services/vision/openai_vision_service.py +151 -9
- isa_model/inference/services/vision/replicate_image_gen_service.py +166 -38
- isa_model/inference/services/vision/replicate_vision_service.py +693 -0
- isa_model/serving/__init__.py +19 -0
- isa_model/serving/api/__init__.py +10 -0
- isa_model/serving/api/fastapi_server.py +84 -0
- isa_model/serving/api/middleware/__init__.py +9 -0
- isa_model/serving/api/middleware/request_logger.py +88 -0
- isa_model/serving/api/routes/__init__.py +5 -0
- isa_model/serving/api/routes/health.py +82 -0
- isa_model/serving/api/routes/llm.py +19 -0
- isa_model/serving/api/routes/ui_analysis.py +223 -0
- isa_model/serving/api/routes/vision.py +19 -0
- isa_model/serving/api/schemas/__init__.py +17 -0
- isa_model/serving/api/schemas/common.py +33 -0
- isa_model/serving/api/schemas/ui_analysis.py +78 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.5.dist-info}/METADATA +1 -1
- {isa_model-0.3.4.dist-info → isa_model-0.3.5.dist-info}/RECORD +49 -17
- {isa_model-0.3.4.dist-info → isa_model-0.3.5.dist-info}/WHEEL +0 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.5.dist-info}/top_level.txt +0 -0
isa_model/inference/services/vision/replicate_vision_service.py
@@ -0,0 +1,693 @@
+from typing import Dict, Any, Union, List, Optional, BinaryIO
+import base64
+import os
+import replicate
+import re
+import ast
+from isa_model.inference.services.vision.base_vision_service import BaseVisionService
+from isa_model.inference.providers.base_provider import BaseProvider
+from isa_model.inference.billing_tracker import ServiceType
+import logging
+
+logger = logging.getLogger(__name__)
+
+class ReplicateVisionService(BaseVisionService):
+    """Enhanced Replicate Vision service supporting multiple specialized models"""
+
+    # Supported model configurations
+    MODELS = {
+        "cogvlm": "cjwbw/cogvlm:a5092d718ea77a073e6d8f6969d5c0fb87d0ac7e4cdb7175427331e1798a34ed",
+        "florence-2": "microsoft/florence-2-large:fcdb54e52322b9e6dce7a35e5d8ad173dce30b46ef49a236c1a71bc6b78b5bed",
+        "omniparser": "microsoft/omniparser-v2:49cf3d41b8d3aca1360514e83be4c97131ce8f0d99abfc365526d8384caa88df",
+        "yolov8": "adirik/yolov8:3b21ba0e5da47bb2c69a96f72894a31b7c1e77b3e8a7b6ba43b7eb93b7b2c4f4"
+    }
+
+    def __init__(self, provider: 'BaseProvider', model_name: str = "cogvlm"):
+        # Resolve model name to full model path
+        self.model_key = model_name
+        resolved_model = self.MODELS.get(model_name, model_name)
+        super().__init__(provider, resolved_model)
+
+        # Get full configuration from provider
+        provider_config = provider.get_full_config()
+
+        # Initialize Replicate client
+        try:
+            # Get API token - try different possible keys like the image gen service
+            self.api_token = provider_config.get("api_token") or provider_config.get("replicate_api_token") or provider_config.get("api_key")
+
+            if not self.api_token:
+                raise ValueError("Replicate API token not found in provider configuration")
+
+            # Set API token for replicate
+            os.environ["REPLICATE_API_TOKEN"] = self.api_token
+
+            logger.info(f"Initialized ReplicateVisionService with model {self.model_key} ({self.model_name})")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize Replicate client: {e}")
+            raise ValueError(f"Failed to initialize Replicate client. Check your API key configuration: {e}") from e
+
+        self.temperature = provider_config.get('temperature', 0.7)
+
+    def _prepare_image(self, image: Union[str, BinaryIO]) -> str:
+        """Prepare image for Replicate API - convert to URL or base64"""
+        if isinstance(image, str):
+            if image.startswith(('http://', 'https://')):
+                # Already a URL
+                return image
+            else:
+                # Local file path - need to convert to base64 data URL
+                with open(image, "rb") as f:
+                    image_data = f.read()
+                image_b64 = base64.b64encode(image_data).decode()
+                # Determine file extension for MIME type
+                ext = os.path.splitext(image)[1].lower()
+                mime_type = {
+                    '.jpg': 'image/jpeg',
+                    '.jpeg': 'image/jpeg',
+                    '.png': 'image/png',
+                    '.gif': 'image/gif',
+                    '.webp': 'image/webp'
+                }.get(ext, 'image/jpeg')
+                return f"data:{mime_type};base64,{image_b64}"
+        else:
+            # BinaryIO or bytes data - convert to base64 data URL
+            if hasattr(image, 'read'):
+                image_data = image.read()
+                if isinstance(image_data, bytes):
+                    image_b64 = base64.b64encode(image_data).decode()
+                else:
+                    raise ValueError("File-like object did not return bytes")
+            else:
+                # Assume it's bytes
+                image_b64 = base64.b64encode(image).decode() # type: ignore
+            return f"data:image/jpeg;base64,{image_b64}"
+
+    async def invoke(
+        self,
+        image: Union[str, BinaryIO],
+        prompt: Optional[str] = None,
+        task: Optional[str] = None,
+        **kwargs
+    ) -> Dict[str, Any]:
+        """
+        Unified invoke method for all vision operations
+        """
+        task = task or "analyze"
+
+        if task == "analyze":
+            return await self.analyze_image(image, prompt, kwargs.get("max_tokens", 1000))
+        elif task == "element_detection":
+            if self.model_key == "omniparser":
+                return await self.run_omniparser(image, **kwargs)
+            elif self.model_key == "florence-2":
+                return await self.run_florence2(image, **kwargs)
+            elif self.model_key == "yolov8":
+                return await self.run_yolo(image, **kwargs)
+            else:
+                return await self.detect_objects(image, kwargs.get("confidence_threshold", 0.5))
+        elif task == "describe":
+            return await self.describe_image(image, kwargs.get("detail_level", "medium"))
+        elif task == "extract_text":
+            return await self.extract_text(image)
+        elif task == "detect_objects":
+            return await self.detect_objects(image, kwargs.get("confidence_threshold", 0.5))
+        elif task == "classify":
+            return await self.classify_image(image, kwargs.get("categories"))
+        else:
+            # Default to analyze_image for unknown tasks
+            return await self.analyze_image(image, prompt, kwargs.get("max_tokens", 1000))
+
+    async def analyze_image(
+        self,
+        image: Union[str, BinaryIO],
+        prompt: Optional[str] = None,
+        max_tokens: int = 1000
+    ) -> Dict[str, Any]:
+        """
+        Analyze image and provide description or answer questions
+        """
+        try:
+            # Prepare image for API
+            image_input = self._prepare_image(image)
+
+            # Use default prompt if none provided
+            if prompt is None:
+                prompt = "Describe this image in detail."
+
+            # Run CogVLM model
+            output = replicate.run(
+                self.model_name,
+                input={
+                    "vqa": True, # Visual Question Answering mode
+                    "image": image_input,
+                    "query": prompt
+                }
+            )
+
+            # CogVLM returns a string response
+            response_text = str(output) if output else ""
+
+            # Track usage for billing
+            self._track_usage(
+                service_type=ServiceType.VISION,
+                operation="image_analysis",
+                input_tokens=len(prompt.split()) if prompt else 0,
+                output_tokens=len(response_text.split()),
+                metadata={"prompt": prompt[:100] if prompt else "", "model": self.model_name}
+            )
+
+            return {
+                "text": response_text,
+                "confidence": 1.0, # CogVLM doesn't provide confidence scores
+                "detected_objects": [], # Would need separate object detection
+                "metadata": {
+                    "model": self.model_name,
+                    "prompt": prompt,
+                    "tokens_used": len(response_text.split())
+                }
+            }
+
+        except Exception as e:
+            logger.error(f"Error in image analysis: {e}")
+            raise
+
+    async def analyze_images(
+        self,
+        images: List[Union[str, BinaryIO]],
+        prompt: Optional[str] = None,
+        max_tokens: int = 1000
+    ) -> List[Dict[str, Any]]:
+        """Analyze multiple images"""
+        results = []
+        for image in images:
+            result = await self.analyze_image(image, prompt, max_tokens)
+            results.append(result)
+        return results
+
+    async def describe_image(
+        self,
+        image: Union[str, BinaryIO],
+        detail_level: str = "medium"
+    ) -> Dict[str, Any]:
+        """Generate detailed description of image"""
+        detail_prompts = {
+            "low": "Briefly describe what you see in this image.",
+            "medium": "Describe what you see in this image in detail, including objects, colors, and scene.",
+            "high": "Provide a comprehensive and detailed description of this image, including all visible objects, their positions, colors, textures, lighting, composition, and any text or symbols present."
+        }
+
+        prompt = detail_prompts.get(detail_level, detail_prompts["medium"])
+        result = await self.analyze_image(image, prompt, 1500)
+
+        return {
+            "description": result["text"],
+            "objects": [], # Would need object detection API
+            "scene": result["text"], # Use same description
+            "colors": [], # Would need color analysis
+            "detail_level": detail_level,
+            "metadata": result["metadata"]
+        }
+
+    async def extract_text(self, image: Union[str, BinaryIO]) -> Dict[str, Any]:
+        """Extract text from image (OCR)"""
+        prompt = "Extract all text visible in this image. Provide only the text content, maintaining the original structure and formatting as much as possible."
+        result = await self.analyze_image(image, prompt, 1000)
+
+        return {
+            "text": result["text"],
+            "confidence": 1.0,
+            "bounding_boxes": [], # CogVLM doesn't provide bounding boxes
+            "language": "unknown", # Would need language detection
+            "metadata": result["metadata"]
+        }
+
+    async def detect_objects(
+        self,
+        image: Union[str, BinaryIO],
+        confidence_threshold: float = 0.5
+    ) -> Dict[str, Any]:
+        """Detect objects in image"""
+        prompt = """Analyze this image and identify all distinct objects, UI elements, or regions. For each element you identify, provide its location and size as percentages.
+
+Look carefully at the image and identify distinct visual elements like:
+- Text regions, buttons, input fields, images
+- Distinct objects, shapes, or regions
+- Interactive elements like buttons or form controls
+
+For each element, respond in this EXACT format:
+ElementName: x=X%, y=Y%, width=W%, height=H% - Description
+
+Where:
+- x% = horizontal position from left edge (0-100%)
+- y% = vertical position from top edge (0-100%)
+- width% = element width as percentage of image width (0-100%)
+- height% = element height as percentage of image height (0-100%)
+
+Be precise about the actual visual boundaries of each element.
+
+Example: "Submit Button: x=25%, y=60%, width=15%, height=5% - Blue rectangular button with white text"
+        """
+        result = await self.analyze_image(image, prompt, 1500)
+
+        # Parse the response to extract object information with coordinates
+        objects = []
+        bounding_boxes = []
+        lines = result["text"].split('\n')
+
+        for line in lines:
+            line = line.strip()
+            if line and ':' in line and ('x=' in line or 'width=' in line):
+                try:
+                    # Extract object name and details
+                    parts = line.split(':', 1)
+                    if len(parts) == 2:
+                        object_name = parts[0].strip()
+                        details = parts[1].strip()
+
+                        # Extract coordinates using regex-like parsing
+                        coords = {}
+                        for param in ['x', 'y', 'width', 'height']:
+                            param_pattern = f"{param}="
+                            if param_pattern in details:
+                                start_idx = details.find(param_pattern) + len(param_pattern)
+                                end_idx = details.find('%', start_idx)
+                                if end_idx > start_idx:
+                                    try:
+                                        value = float(details[start_idx:end_idx])
+                                        coords[param] = value
+                                    except ValueError:
+                                        continue
+
+                        # Extract description (after the coordinates)
+                        desc_start = details.find(' - ')
+                        description = details[desc_start + 3:] if desc_start != -1 else details
+
+                        objects.append({
+                            "label": object_name,
+                            "confidence": 1.0,
+                            "coordinates": coords,
+                            "description": description
+                        })
+
+                        # Add bounding box if we have coordinates
+                        if all(k in coords for k in ['x', 'y', 'width', 'height']):
+                            bounding_boxes.append({
+                                "label": object_name,
+                                "x_percent": coords['x'],
+                                "y_percent": coords['y'],
+                                "width_percent": coords['width'],
+                                "height_percent": coords['height']
+                            })
+
+                except Exception:
+                    # Fallback for objects that don't match expected format
+                    objects.append({
+                        "label": line,
+                        "confidence": 1.0,
+                        "coordinates": {},
+                        "description": line
+                    })
+
+        return {
+            "objects": objects,
+            "count": len(objects),
+            "bounding_boxes": bounding_boxes,
+            "metadata": result["metadata"]
+        }
+
+    async def get_object_coordinates(
+        self,
+        image: Union[str, BinaryIO],
+        object_name: str
+    ) -> Dict[str, Any]:
+        """Get coordinates of a specific object in the image"""
+        prompt = f"""Locate the {object_name} in this image and return its center coordinates as [x, y] pixels.
+
+Look carefully at the image to find the exact element described. Be very precise about the location.
+
+Respond in this exact format:
+FOUND: YES/NO
+CENTER: [x, y]
+DESCRIPTION: [Brief description]
+
+If found, provide the pixel coordinates of the center point.
+If not found, explain why.
+
+Example:
+FOUND: YES
+CENTER: [640, 360]
+DESCRIPTION: Blue login button in the center-left area
+        """
+
+        result = await self.analyze_image(image, prompt, 300)
+        response_text = result["text"]
+
+        # Parse the structured response
+        found = False
+        center_coords = None
+        description = ""
+
+        lines = response_text.split('\n')
+        for line in lines:
+            line = line.strip()
+            if line.startswith('FOUND:'):
+                found = 'YES' in line.upper()
+            elif line.startswith('CENTER:') and found:
+                # Extract center coordinates [x, y]
+                coords_text = line.replace('CENTER:', '').strip()
+                try:
+                    # Remove brackets and split
+                    coords_text = coords_text.replace('[', '').replace(']', '')
+                    if ',' in coords_text:
+                        x_str, y_str = coords_text.split(',')
+                        x = int(float(x_str.strip()))
+                        y = int(float(y_str.strip()))
+                        center_coords = [x, y]
+                except (ValueError, IndexError):
+                    pass
+            elif line.startswith('DESCRIPTION:'):
+                description = line.replace('DESCRIPTION:', '').strip()
+
+        return {
+            "found": found,
+            "center_coordinates": center_coords,
+            "confidence": 1.0 if found else 0.0,
+            "description": description,
+            "metadata": result["metadata"]
+        }
+
+    async def classify_image(
+        self,
+        image: Union[str, BinaryIO],
+        categories: Optional[List[str]] = None
+    ) -> Dict[str, Any]:
+        """Classify image into categories"""
+        if categories:
+            category_list = ", ".join(categories)
+            prompt = f"Classify this image into one of these categories: {category_list}. Respond with only the most appropriate category name."
+        else:
+            prompt = "What category best describes this image? Provide a single category name."
+
+        result = await self.analyze_image(image, prompt, 100)
+        category = result["text"].strip()
+
+        return {
+            "category": category,
+            "confidence": 1.0,
+            "all_predictions": [{"category": category, "confidence": 1.0}],
+            "metadata": result["metadata"]
+        }
+
+    async def compare_images(
+        self,
+        image1: Union[str, BinaryIO],
+        image2: Union[str, BinaryIO]
+    ) -> Dict[str, Any]:
+        """Compare two images for similarity"""
+        # For now, analyze both images separately and compare descriptions
+        result1 = await self.analyze_image(image1, "Describe this image in detail.")
+        result2 = await self.analyze_image(image2, "Describe this image in detail.")
+
+        # Use another CogVLM call to compare the descriptions
+        comparison_prompt = f"Compare these two image descriptions and provide a similarity analysis:\n\nImage 1: {result1['text']}\n\nImage 2: {result2['text']}\n\nProvide: 1) A similarity score from 0.0 to 1.0, 2) Key differences, 3) Common elements."
+
+        # Create a simple text prompt for comparison
+        comparison_result = await self.analyze_image(image1, comparison_prompt)
+
+        comparison_text = comparison_result["text"]
+
+        return {
+            "similarity_score": 0.5, # Would need better parsing to extract actual score
+            "differences": comparison_text,
+            "common_elements": comparison_text,
+            "metadata": {
+                "model": self.model_name,
+                "comparison_method": "description_based"
+            }
+        }
+
+    def get_supported_formats(self) -> List[str]:
+        """Get list of supported image formats"""
+        return ['jpg', 'jpeg', 'png', 'gif', 'webp']
+
+    def get_max_image_size(self) -> Dict[str, int]:
+        """Get maximum supported image dimensions"""
+        return {
+            "width": 2048,
+            "height": 2048,
+            "file_size_mb": 10
+        }
+
+    # ==================== MODEL-SPECIFIC METHODS ====================
+
+    async def run_omniparser(
+        self,
+        image: Union[str, BinaryIO],
+        imgsz: int = 640,
+        box_threshold: float = 0.05,
+        iou_threshold: float = 0.1
+    ) -> Dict[str, Any]:
+        """Run OmniParser-v2 for UI element detection"""
+        if self.model_key != "omniparser":
+            # Switch to OmniParser model temporarily
+            original_model = self.model_name
+            self.model_name = self.MODELS["omniparser"]
+
+        try:
+            image_input = self._prepare_image(image)
+
+            output = replicate.run(
+                self.model_name,
+                input={
+                    "image": image_input,
+                    "imgsz": imgsz,
+                    "box_threshold": box_threshold,
+                    "iou_threshold": iou_threshold
+                }
+            )
+
+            # Parse OmniParser output format
+            elements = []
+            if isinstance(output, dict) and 'elements' in output:
+                elements_text = output['elements']
+                elements = self._parse_omniparser_elements(elements_text, image)
+
+            return {
+                "model": "omniparser",
+                "raw_output": output,
+                "parsed_elements": elements,
+                "metadata": {
+                    "imgsz": imgsz,
+                    "box_threshold": box_threshold,
+                    "iou_threshold": iou_threshold
+                }
+            }
+
+        finally:
+            if self.model_key != "omniparser":
+                # Restore original model
+                self.model_name = original_model
+
+    async def run_florence2(
+        self,
+        image: Union[str, BinaryIO],
+        task: str = "<OPEN_VOCABULARY_DETECTION>",
+        text_input: Optional[str] = None
+    ) -> Dict[str, Any]:
+        """Run Florence-2 for object detection and description"""
+        if self.model_key != "florence-2":
+            original_model = self.model_name
+            self.model_name = self.MODELS["florence-2"]
+
+        try:
+            image_input = self._prepare_image(image)
+
+            input_params = {
+                "image": image_input,
+                "task": task
+            }
+            if text_input:
+                input_params["text_input"] = text_input
+
+            output = replicate.run(self.model_name, input=input_params)
+
+            # Parse Florence-2 output
+            parsed_objects = []
+            if isinstance(output, dict):
+                parsed_objects = self._parse_florence2_output(output, image)
+
+            return {
+                "model": "florence-2",
+                "task": task,
+                "raw_output": output,
+                "parsed_objects": parsed_objects,
+                "metadata": {"task": task, "text_input": text_input}
+            }
+
+        finally:
+            if self.model_key != "florence-2":
+                self.model_name = original_model
+
+    async def run_yolo(
+        self,
+        image: Union[str, BinaryIO],
+        confidence: float = 0.5,
+        iou_threshold: float = 0.45
+    ) -> Dict[str, Any]:
+        """Run YOLO for general object detection"""
+        if self.model_key != "yolov8":
+            original_model = self.model_name
+            self.model_name = self.MODELS["yolov8"]
+
+        try:
+            image_input = self._prepare_image(image)
+
+            output = replicate.run(
+                self.model_name,
+                input={
+                    "image": image_input,
+                    "confidence": confidence,
+                    "iou_threshold": iou_threshold
+                }
+            )
+
+            # Parse YOLO output
+            detected_objects = []
+            if output:
+                detected_objects = self._parse_yolo_output(output, image)
+
+            return {
+                "model": "yolov8",
+                "raw_output": output,
+                "detected_objects": detected_objects,
+                "metadata": {
+                    "confidence": confidence,
+                    "iou_threshold": iou_threshold
+                }
+            }
+
+        finally:
+            if self.model_key != "yolov8":
+                self.model_name = original_model
+
+    # ==================== PARSING HELPERS ====================
+
+    def _parse_omniparser_elements(self, elements_text: str, image: Union[str, BinaryIO]) -> List[Dict[str, Any]]:
+        """Parse OmniParser-v2 elements format"""
+        elements = []
+
+        # Get image dimensions for coordinate conversion
+        from PIL import Image as PILImage
+        if isinstance(image, str):
+            img = PILImage.open(image)
+        else:
+            img = PILImage.open(image)
+        img_width, img_height = img.size
+
+        try:
+            # Extract individual icon entries
+            icon_pattern = r"icon (\d+): ({.*?})\n?"
+            matches = re.findall(icon_pattern, elements_text, re.DOTALL)
+
+            for icon_id, icon_data_str in matches:
+                try:
+                    icon_data = eval(icon_data_str) # Safe since we control the source
+
+                    bbox = icon_data.get('bbox', [])
+                    element_type = icon_data.get('type', 'unknown')
+                    interactivity = icon_data.get('interactivity', False)
+                    content = icon_data.get('content', '').strip()
+
+                    if len(bbox) == 4:
+                        # Convert normalized coordinates to pixel coordinates
+                        x1_norm, y1_norm, x2_norm, y2_norm = bbox
+                        x1 = int(x1_norm * img_width)
+                        y1 = int(y1_norm * img_height)
+                        x2 = int(x2_norm * img_width)
+                        y2 = int(y2_norm * img_height)
+
+                        element = {
+                            'id': f'omni_icon_{icon_id}',
+                            'bbox': [x1, y1, x2, y2],
+                            'center': [int((x1 + x2) / 2), int((y1 + y2) / 2)],
+                            'size': [x2 - x1, y2 - y1],
+                            'type': element_type,
+                            'interactivity': interactivity,
+                            'content': content,
+                            'confidence': 0.9
+                        }
+                        elements.append(element)
+
+                except Exception as e:
+                    logger.warning(f"Failed to parse icon {icon_id}: {e}")
+
+        except Exception as e:
+            logger.error(f"Failed to parse OmniParser elements: {e}")
+
+        return elements
+
+    def _parse_florence2_output(self, output: Dict[str, Any], image: Union[str, BinaryIO]) -> List[Dict[str, Any]]:
+        """Parse Florence-2 detection output"""
+        objects = []
+
+        try:
+            # Florence-2 typically returns nested detection data
+            for key, value in output.items():
+                if isinstance(value, dict) and ('bboxes' in value and 'labels' in value):
+                    bboxes = value['bboxes']
+                    labels = value['labels']
+
+                    for i, (label, bbox) in enumerate(zip(labels, bboxes)):
+                        if len(bbox) >= 4:
+                            x1, y1, x2, y2 = bbox[:4]
+                            obj = {
+                                'id': f'florence_{i}',
+                                'label': label,
+                                'bbox': [int(x1), int(y1), int(x2), int(y2)],
+                                'center': [int((x1 + x2) / 2), int((y1 + y2) / 2)],
+                                'size': [int(x2 - x1), int(y2 - y1)],
+                                'confidence': 0.9
+                            }
+                            objects.append(obj)
+
+        except Exception as e:
+            logger.error(f"Failed to parse Florence-2 output: {e}")
+
+        return objects
+
+    def _parse_yolo_output(self, output: Any, image: Union[str, BinaryIO]) -> List[Dict[str, Any]]:
+        """Parse YOLO detection output"""
+        objects = []
+
+        try:
+            # YOLO output format varies, handle common formats
+            if isinstance(output, list):
+                for i, detection in enumerate(output):
+                    if isinstance(detection, dict):
+                        bbox = detection.get('bbox', detection.get('box', []))
+                        label = detection.get('class', detection.get('label', f'object_{i}'))
+                        confidence = detection.get('confidence', detection.get('score', 0.9))
+
+                        if len(bbox) >= 4:
+                            x1, y1, x2, y2 = bbox[:4]
+                            obj = {
+                                'id': f'yolo_{i}',
+                                'label': label,
+                                'bbox': [int(x1), int(y1), int(x2), int(y2)],
+                                'center': [int((x1 + x2) / 2), int((y1 + y2) / 2)],
+                                'size': [int(x2 - x1), int(y2 - y1)],
+                                'confidence': float(confidence)
+                            }
+                            objects.append(obj)
+
+        except Exception as e:
+            logger.error(f"Failed to parse YOLO output: {e}")
+
+        return objects
+
+    async def close(self):
+        """Clean up resources"""
+        # Replicate doesn't need explicit cleanup
+        pass