isa-model 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- isa_model/__init__.py +30 -1
- isa_model/client.py +770 -0
- isa_model/core/config/__init__.py +16 -0
- isa_model/core/config/config_manager.py +514 -0
- isa_model/core/config.py +426 -0
- isa_model/core/models/model_billing_tracker.py +476 -0
- isa_model/core/models/model_manager.py +399 -0
- isa_model/core/models/model_repo.py +343 -0
- isa_model/core/pricing_manager.py +426 -0
- isa_model/core/services/__init__.py +19 -0
- isa_model/core/services/intelligent_model_selector.py +547 -0
- isa_model/core/types.py +291 -0
- isa_model/deployment/__init__.py +2 -0
- isa_model/deployment/cloud/__init__.py +9 -0
- isa_model/deployment/cloud/modal/__init__.py +10 -0
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +766 -0
- isa_model/deployment/cloud/modal/isa_vision_table_service.py +532 -0
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +406 -0
- isa_model/deployment/cloud/modal/register_models.py +321 -0
- isa_model/deployment/runtime/deployed_service.py +338 -0
- isa_model/deployment/services/__init__.py +9 -0
- isa_model/deployment/services/auto_deploy_vision_service.py +537 -0
- isa_model/deployment/services/model_service.py +332 -0
- isa_model/deployment/services/service_monitor.py +356 -0
- isa_model/deployment/services/service_registry.py +527 -0
- isa_model/eval/__init__.py +80 -44
- isa_model/eval/config/__init__.py +10 -0
- isa_model/eval/config/evaluation_config.py +108 -0
- isa_model/eval/evaluators/__init__.py +18 -0
- isa_model/eval/evaluators/base_evaluator.py +503 -0
- isa_model/eval/evaluators/llm_evaluator.py +472 -0
- isa_model/eval/factory.py +417 -709
- isa_model/eval/infrastructure/__init__.py +24 -0
- isa_model/eval/infrastructure/experiment_tracker.py +466 -0
- isa_model/eval/metrics.py +191 -21
- isa_model/inference/ai_factory.py +187 -387
- isa_model/inference/providers/modal_provider.py +109 -0
- isa_model/inference/providers/yyds_provider.py +108 -0
- isa_model/inference/services/__init__.py +2 -1
- isa_model/inference/services/audio/base_stt_service.py +65 -1
- isa_model/inference/services/audio/base_tts_service.py +75 -1
- isa_model/inference/services/audio/openai_stt_service.py +189 -151
- isa_model/inference/services/audio/openai_tts_service.py +12 -10
- isa_model/inference/services/audio/replicate_tts_service.py +61 -56
- isa_model/inference/services/base_service.py +55 -55
- isa_model/inference/services/embedding/base_embed_service.py +65 -1
- isa_model/inference/services/embedding/ollama_embed_service.py +103 -43
- isa_model/inference/services/embedding/openai_embed_service.py +8 -10
- isa_model/inference/services/helpers/stacked_config.py +148 -0
- isa_model/inference/services/img/__init__.py +18 -0
- isa_model/inference/services/{vision → img}/base_image_gen_service.py +80 -35
- isa_model/inference/services/img/flux_professional_service.py +603 -0
- isa_model/inference/services/img/helpers/base_stacked_service.py +274 -0
- isa_model/inference/services/{vision → img}/replicate_image_gen_service.py +210 -69
- isa_model/inference/services/llm/__init__.py +3 -3
- isa_model/inference/services/llm/base_llm_service.py +519 -35
- isa_model/inference/services/llm/{llm_adapter.py → helpers/llm_adapter.py} +40 -0
- isa_model/inference/services/llm/helpers/llm_prompts.py +258 -0
- isa_model/inference/services/llm/helpers/llm_utils.py +280 -0
- isa_model/inference/services/llm/ollama_llm_service.py +150 -15
- isa_model/inference/services/llm/openai_llm_service.py +134 -31
- isa_model/inference/services/llm/yyds_llm_service.py +255 -0
- isa_model/inference/services/vision/__init__.py +38 -4
- isa_model/inference/services/vision/base_vision_service.py +241 -96
- isa_model/inference/services/vision/disabled/isA_vision_service.py +500 -0
- isa_model/inference/services/vision/doc_analysis_service.py +640 -0
- isa_model/inference/services/vision/helpers/base_stacked_service.py +274 -0
- isa_model/inference/services/vision/helpers/image_utils.py +272 -3
- isa_model/inference/services/vision/helpers/vision_prompts.py +297 -0
- isa_model/inference/services/vision/openai_vision_service.py +109 -170
- isa_model/inference/services/vision/replicate_vision_service.py +508 -0
- isa_model/inference/services/vision/ui_analysis_service.py +823 -0
- isa_model/scripts/register_models.py +370 -0
- isa_model/scripts/register_models_with_embeddings.py +510 -0
- isa_model/serving/__init__.py +19 -0
- isa_model/serving/api/__init__.py +10 -0
- isa_model/serving/api/fastapi_server.py +89 -0
- isa_model/serving/api/middleware/__init__.py +9 -0
- isa_model/serving/api/middleware/request_logger.py +88 -0
- isa_model/serving/api/routes/__init__.py +5 -0
- isa_model/serving/api/routes/health.py +82 -0
- isa_model/serving/api/routes/llm.py +19 -0
- isa_model/serving/api/routes/ui_analysis.py +223 -0
- isa_model/serving/api/routes/unified.py +202 -0
- isa_model/serving/api/routes/vision.py +19 -0
- isa_model/serving/api/schemas/__init__.py +17 -0
- isa_model/serving/api/schemas/common.py +33 -0
- isa_model/serving/api/schemas/ui_analysis.py +78 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/METADATA +4 -1
- isa_model-0.3.6.dist-info/RECORD +147 -0
- isa_model/core/model_manager.py +0 -208
- isa_model/core/model_registry.py +0 -342
- isa_model/inference/billing_tracker.py +0 -406
- isa_model/inference/services/llm/triton_llm_service.py +0 -481
- isa_model/inference/services/vision/ollama_vision_service.py +0 -194
- isa_model-0.3.4.dist-info/RECORD +0 -91
- /isa_model/core/{model_storage.py → models/model_storage.py} +0 -0
- /isa_model/inference/services/{vision → embedding}/helpers/text_splitter.py +0 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/WHEEL +0 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/isa_model/inference/services/vision/replicate_vision_service.py
@@ -0,0 +1,508 @@
+from typing import Dict, Any, Union, List, Optional, BinaryIO
+import base64
+import os
+import replicate
+import re
+import ast
+from isa_model.inference.services.vision.base_vision_service import BaseVisionService
+from isa_model.core.types import ServiceType
+from isa_model.inference.services.vision.helpers.image_utils import prepare_image_data_url
+from isa_model.inference.services.vision.helpers.vision_prompts import VisionPromptMixin
+import logging
+
+logger = logging.getLogger(__name__)
+
+class ReplicateVisionService(BaseVisionService, VisionPromptMixin):
+    """Enhanced Replicate Vision service supporting multiple specialized models"""
+
+    # Supported model configurations
+    MODELS = {
+        "cogvlm": "cjwbw/cogvlm:a5092d718ea77a073e6d8f6969d5c0fb87d0ac7e4cdb7175427331e1798a34ed",
+        "florence-2": "microsoft/florence-2-large:fcdb54e52322b9e6dce7a35e5d8ad173dce30b46ef49a236c1a71bc6b78b5bed",
+        "omniparser": "microsoft/omniparser-v2:49cf3d41b8d3aca1360514e83be4c97131ce8f0d99abfc365526d8384caa88df",
+        "yolov8": "adirik/yolov8:3b21ba0e5da47bb2c69a96f72894a31b7c1e77b3e8a7b6ba43b7eb93b7b2c4f4",
+        "qwen-vl-chat": "lucataco/qwen-vl-chat:50881b153b4d5f72b3db697e2bbad23bb1277ab741c5b52d80cd6ee17ea660e9"
+    }
+
+    def __init__(self, provider_name: str, model_name: str = "cogvlm", **kwargs):
+        # Resolve model name to full model path
+        self.model_key = model_name
+        resolved_model = self.MODELS.get(model_name, model_name)
+        super().__init__(provider_name, resolved_model, **kwargs)
+
+        # Get configuration from centralized config manager
+        provider_config = self.get_provider_config()
+
+        # Initialize Replicate client
+        try:
+            # Get API token - try different possible keys like the image gen service
+            self.api_token = provider_config.get("api_token") or provider_config.get("replicate_api_token") or provider_config.get("api_key")
+
+            if not self.api_token:
+                raise ValueError("Replicate API token not found in provider configuration")
+
+            # Set API token for replicate
+            os.environ["REPLICATE_API_TOKEN"] = self.api_token
+
+            logger.info(f"Initialized ReplicateVisionService with model {self.model_key} ({self.model_name})")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize Replicate client: {e}")
+            raise ValueError(f"Failed to initialize Replicate client. Check your API key configuration: {e}") from e
+
+        self.temperature = provider_config.get('temperature', 0.7)
+
+    def _prepare_image(self, image: Union[str, BinaryIO]) -> str:
+        """Prepare image for Replicate API - convert to URL or base64"""
+        if isinstance(image, str) and image.startswith(('http://', 'https://')):
+            # Already a URL
+            return image
+        else:
+            # Use unified image processing from image_utils
+            return prepare_image_data_url(image)
+
+    # Replicate uses the base class invoke() method; no override is needed.
+    # The corresponding standard methods are implemented directly below.
+
+    async def analyze_image(
+        self,
+        image: Union[str, BinaryIO],
+        prompt: Optional[str] = None,
+        max_tokens: int = 1000
+    ) -> Dict[str, Any]:
+        """
+        Analyze image and provide description or answer questions
+        """
+        try:
+            # Prepare image for API using unified processing
+            image_input = self._prepare_image(image)
+
+            # Use default prompt if none provided
+            if prompt is None:
+                prompt = "Describe this image in detail."
+
+            # Choose input format based on model type
+            if self.model_key == "qwen-vl-chat":
+                # Qwen-VL-Chat uses simple image + prompt format
+                output = replicate.run(
+                    self.model_name,
+                    input={
+                        "image": image_input,
+                        "prompt": prompt
+                    }
+                )
+            else:
+                # CogVLM and other models use VQA format
+                output = replicate.run(
+                    self.model_name,
+                    input={
+                        "vqa": True,  # Visual Question Answering mode
+                        "image": image_input,
+                        "query": prompt
+                    }
+                )
+
+            # CogVLM returns a string response
+            response_text = str(output) if output else ""
+
+            # Track usage for billing
+            await self._track_usage(
+                service_type=ServiceType.VISION,
+                operation="image_analysis",
+                input_tokens=len(prompt.split()) if prompt else 0,
+                output_tokens=len(response_text.split()),
+                metadata={"prompt": prompt[:100] if prompt else "", "model": self.model_name}
+            )
+
+            return {
+                "text": response_text,
+                "confidence": 1.0,  # CogVLM doesn't provide confidence scores
+                "detected_objects": [],  # Would need separate object detection
+                "metadata": {
+                    "model": self.model_name,
+                    "prompt": prompt,
+                    "tokens_used": len(response_text.split())
+                }
+            }
+
+        except Exception as e:
+            logger.error(f"Error in image analysis: {e}")
+            raise
+
+    # ==================== STANDARD INTERFACE: DETECTION / EXTRACTION ====================
+
+    async def detect_ui_elements(
+        self,
+        image: Union[str, BinaryIO],
+        element_types: Optional[List[str]] = None,
+        confidence_threshold: float = 0.5
+    ) -> Dict[str, Any]:
+        """
+        UI element detection - implemented with specialized models
+        """
+        if self.model_key == "omniparser":
+            return await self.run_omniparser(image, box_threshold=confidence_threshold)
+        elif self.model_key == "florence-2":
+            return await self.run_florence2(image, task="<OPEN_VOCABULARY_DETECTION>")
+        else:
+            # Fall back to generic object detection
+            return await self.detect_objects(image, confidence_threshold)
+
+    async def detect_document_elements(
+        self,
+        image: Union[str, BinaryIO],
+        element_types: Optional[List[str]] = None,
+        confidence_threshold: float = 0.5
+    ) -> Dict[str, Any]:
+        """
+        Document structure element detection - implemented with specialized models
+        """
+        if self.model_key == "florence-2":
+            # Florence-2 can detect document structure
+            return await self.run_florence2(image, task="<DETAILED_CAPTION>")
+        else:
+            raise NotImplementedError(f"Document detection not supported for model {self.model_key}")
+
+    async def detect_objects(
+        self,
+        image: Union[str, BinaryIO],
+        confidence_threshold: float = 0.5
+    ) -> Dict[str, Any]:
+        """
+        General object detection - implements the standard interface
+        """
+        if self.model_key == "yolov8":
+            return await self.run_yolo(image, confidence=confidence_threshold)
+        elif self.model_key == "florence-2":
+            return await self.run_florence2(image, task="<OD>")
+        elif self.model_key == "qwen-vl-chat":
+            # Qwen-VL-Chat can do object detection through prompting
+            prompt = self.get_task_prompt("detect_objects", confidence_threshold=confidence_threshold)
+            return await self.analyze_image(image, prompt)
+        else:
+            raise NotImplementedError(f"Object detection not supported for model {self.model_key}")
+
+    # ==================== QWEN-VL-CHAT PROMPT-BASED IMPLEMENTATIONS ====================
+    # Like the OpenAI service, qwen-vl-chat implements all vision capabilities through prompting
+
+    async def describe_image(
+        self,
+        image: Union[str, BinaryIO],
+        detail_level: str = "medium"
+    ) -> Dict[str, Any]:
+        """
+        Image description - implemented via prompting for qwen-vl-chat
+        """
+        if self.model_key == "qwen-vl-chat":
+            prompt = self.get_task_prompt("describe", detail_level=detail_level)
+            return await self.analyze_image(image, prompt)
+        else:
+            raise NotImplementedError(f"describe_image not supported for model {self.model_key}")
+
+    async def extract_text(self, image: Union[str, BinaryIO]) -> Dict[str, Any]:
+        """
+        Text extraction (OCR) - implemented via prompting for qwen-vl-chat
+        """
+        if self.model_key == "qwen-vl-chat":
+            prompt = self.get_task_prompt("extract_text")
+            return await self.analyze_image(image, prompt)
+        else:
+            raise NotImplementedError(f"extract_text not supported for model {self.model_key}")
+
+    async def classify_image(
+        self,
+        image: Union[str, BinaryIO],
+        categories: Optional[List[str]] = None
+    ) -> Dict[str, Any]:
+        """
+        Image classification - implemented via prompting for qwen-vl-chat
+        """
+        if self.model_key == "qwen-vl-chat":
+            prompt = self.get_task_prompt("classify", categories=categories)
+            return await self.analyze_image(image, prompt)
+        else:
+            raise NotImplementedError(f"classify_image not supported for model {self.model_key}")
+
+    async def extract_table_data(
+        self,
+        image: Union[str, BinaryIO],
+        table_format: str = "json",
+        preserve_formatting: bool = True
+    ) -> Dict[str, Any]:
+        """
+        Table data extraction - implemented via prompting for qwen-vl-chat
+        """
+        if self.model_key == "qwen-vl-chat":
+            prompt = self.get_task_prompt("extract_table_data", table_format=table_format, preserve_formatting=preserve_formatting)
+            return await self.analyze_image(image, prompt)
+        else:
+            raise NotImplementedError(f"extract_table_data not supported for model {self.model_key}")
+
+    async def get_object_coordinates(
+        self,
+        image: Union[str, BinaryIO],
+        object_name: str
+    ) -> Dict[str, Any]:
+        """
+        Object coordinate lookup - implemented via prompting for qwen-vl-chat
+        """
+        if self.model_key == "qwen-vl-chat":
+            prompt = self.get_task_prompt("get_coordinates", object_name=object_name)
+            return await self.analyze_image(image, prompt)
+        else:
+            raise NotImplementedError(f"get_object_coordinates not supported for model {self.model_key}")
+
+    # ==================== REPLICATE-SPECIFIC MODEL METHODS ====================
+    # The following methods are Replicate-specific specialized-model implementations, outside the standard interface
+
+    # ==================== MODEL-SPECIFIC METHODS ====================
+
+    async def run_omniparser(
+        self,
+        image: Union[str, BinaryIO],
+        imgsz: int = 640,
+        box_threshold: float = 0.05,
+        iou_threshold: float = 0.1
+    ) -> Dict[str, Any]:
+        """Run OmniParser-v2 for UI element detection"""
+        if self.model_key != "omniparser":
+            # Switch to OmniParser model temporarily
+            original_model = self.model_name
+            self.model_name = self.MODELS["omniparser"]
+
+        try:
+            image_input = self._prepare_image(image)
+
+            output = replicate.run(
+                self.model_name,
+                input={
+                    "image": image_input,
+                    "imgsz": imgsz,
+                    "box_threshold": box_threshold,
+                    "iou_threshold": iou_threshold
+                }
+            )
+
+            # Parse OmniParser output format
+            elements = []
+            if isinstance(output, dict) and 'elements' in output:
+                elements_text = output['elements']
+                elements = self._parse_omniparser_elements(elements_text, image)
+
+            return {
+                "model": "omniparser",
+                "raw_output": output,
+                "parsed_elements": elements,
+                "metadata": {
+                    "imgsz": imgsz,
+                    "box_threshold": box_threshold,
+                    "iou_threshold": iou_threshold
+                }
+            }
+
+        finally:
+            if self.model_key != "omniparser":
+                # Restore original model
+                self.model_name = original_model
+
+    async def run_florence2(
+        self,
+        image: Union[str, BinaryIO],
+        task: str = "<OPEN_VOCABULARY_DETECTION>",
+        text_input: Optional[str] = None
+    ) -> Dict[str, Any]:
+        """Run Florence-2 for object detection and description"""
+        if self.model_key != "florence-2":
+            original_model = self.model_name
+            self.model_name = self.MODELS["florence-2"]
+
+        try:
+            image_input = self._prepare_image(image)
+
+            input_params = {
+                "image": image_input,
+                "task": task
+            }
+            if text_input:
+                input_params["text_input"] = text_input
+
+            output = replicate.run(self.model_name, input=input_params)
+
+            # Parse Florence-2 output
+            parsed_objects = []
+            if isinstance(output, dict):
+                parsed_objects = self._parse_florence2_output(output, image)
+
+            return {
+                "model": "florence-2",
+                "task": task,
+                "raw_output": output,
+                "parsed_objects": parsed_objects,
+                "metadata": {"task": task, "text_input": text_input}
+            }
+
+        finally:
+            if self.model_key != "florence-2":
+                self.model_name = original_model
+
+    async def run_yolo(
+        self,
+        image: Union[str, BinaryIO],
+        confidence: float = 0.5,
+        iou_threshold: float = 0.45
+    ) -> Dict[str, Any]:
+        """Run YOLO for general object detection"""
+        if self.model_key != "yolov8":
+            original_model = self.model_name
+            self.model_name = self.MODELS["yolov8"]
+
+        try:
+            image_input = self._prepare_image(image)
+
+            output = replicate.run(
+                self.model_name,
+                input={
+                    "image": image_input,
+                    "confidence": confidence,
+                    "iou_threshold": iou_threshold
+                }
+            )
+
+            # Parse YOLO output
+            detected_objects = []
+            if output:
+                detected_objects = self._parse_yolo_output(output, image)
+
+            return {
+                "model": "yolov8",
+                "raw_output": output,
+                "detected_objects": detected_objects,
+                "metadata": {
+                    "confidence": confidence,
+                    "iou_threshold": iou_threshold
+                }
+            }
+
+        finally:
+            if self.model_key != "yolov8":
+                self.model_name = original_model
+
+    # ==================== PARSING HELPERS ====================
+
+    def _parse_omniparser_elements(self, elements_text: str, image: Union[str, BinaryIO]) -> List[Dict[str, Any]]:
+        """Parse OmniParser-v2 elements format"""
+        elements = []
+
+        # Get image dimensions for coordinate conversion
+        from PIL import Image as PILImage
+        if isinstance(image, str):
+            img = PILImage.open(image)
+        else:
+            img = PILImage.open(image)
+        img_width, img_height = img.size
+
+        try:
+            # Extract individual icon entries
+            icon_pattern = r"icon (\d+): ({.*?})\n?"
+            matches = re.findall(icon_pattern, elements_text, re.DOTALL)
+
+            for icon_id, icon_data_str in matches:
+                try:
+                    icon_data = ast.literal_eval(icon_data_str)  # Parse the dict literal safely rather than eval() on model output
+
+                    bbox = icon_data.get('bbox', [])
+                    element_type = icon_data.get('type', 'unknown')
+                    interactivity = icon_data.get('interactivity', False)
+                    content = icon_data.get('content', '').strip()
+
+                    if len(bbox) == 4:
+                        # Convert normalized coordinates to pixel coordinates
+                        x1_norm, y1_norm, x2_norm, y2_norm = bbox
+                        x1 = int(x1_norm * img_width)
+                        y1 = int(y1_norm * img_height)
+                        x2 = int(x2_norm * img_width)
+                        y2 = int(y2_norm * img_height)
+
+                        element = {
+                            'id': f'omni_icon_{icon_id}',
+                            'bbox': [x1, y1, x2, y2],
+                            'center': [int((x1 + x2) / 2), int((y1 + y2) / 2)],
+                            'size': [x2 - x1, y2 - y1],
+                            'type': element_type,
+                            'interactivity': interactivity,
+                            'content': content,
+                            'confidence': 0.9
+                        }
+                        elements.append(element)
+
+                except Exception as e:
+                    logger.warning(f"Failed to parse icon {icon_id}: {e}")
+
+        except Exception as e:
+            logger.error(f"Failed to parse OmniParser elements: {e}")
+
+        return elements
+
+    def _parse_florence2_output(self, output: Dict[str, Any], image: Union[str, BinaryIO]) -> List[Dict[str, Any]]:
+        """Parse Florence-2 detection output"""
+        objects = []
+
+        try:
+            # Florence-2 typically returns nested detection data
+            for key, value in output.items():
+                if isinstance(value, dict) and ('bboxes' in value and 'labels' in value):
+                    bboxes = value['bboxes']
+                    labels = value['labels']
+
+                    for i, (label, bbox) in enumerate(zip(labels, bboxes)):
+                        if len(bbox) >= 4:
+                            x1, y1, x2, y2 = bbox[:4]
+                            obj = {
+                                'id': f'florence_{i}',
+                                'label': label,
+                                'bbox': [int(x1), int(y1), int(x2), int(y2)],
+                                'center': [int((x1 + x2) / 2), int((y1 + y2) / 2)],
+                                'size': [int(x2 - x1), int(y2 - y1)],
+                                'confidence': 0.9
+                            }
+                            objects.append(obj)
+
+        except Exception as e:
+            logger.error(f"Failed to parse Florence-2 output: {e}")
+
+        return objects
+
+    def _parse_yolo_output(self, output: Any, image: Union[str, BinaryIO]) -> List[Dict[str, Any]]:
+        """Parse YOLO detection output"""
+        objects = []
+
+        try:
+            # YOLO output format varies, handle common formats
+            if isinstance(output, list):
+                for i, detection in enumerate(output):
+                    if isinstance(detection, dict):
+                        bbox = detection.get('bbox', detection.get('box', []))
+                        label = detection.get('class', detection.get('label', f'object_{i}'))
+                        confidence = detection.get('confidence', detection.get('score', 0.9))
+
+                        if len(bbox) >= 4:
+                            x1, y1, x2, y2 = bbox[:4]
+                            obj = {
+                                'id': f'yolo_{i}',
+                                'label': label,
+                                'bbox': [int(x1), int(y1), int(x2), int(y2)],
+                                'center': [int((x1 + x2) / 2), int((y1 + y2) / 2)],
+                                'size': [int(x2 - x1), int(y2 - y1)],
+                                'confidence': float(confidence)
+                            }
+                            objects.append(obj)
+
+        except Exception as e:
+            logger.error(f"Failed to parse YOLO output: {e}")
+
+        return objects
+
+    async def close(self):
+        """Clean up resources"""
+        # Replicate doesn't need explicit cleanup
+        pass