isa-model 0.3.5__py3-none-any.whl → 0.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +30 -1
- isa_model/client.py +937 -0
- isa_model/core/config/__init__.py +16 -0
- isa_model/core/config/config_manager.py +514 -0
- isa_model/core/config.py +426 -0
- isa_model/core/models/model_billing_tracker.py +476 -0
- isa_model/core/models/model_manager.py +399 -0
- isa_model/core/{storage/supabase_storage.py → models/model_repo.py} +72 -73
- isa_model/core/pricing_manager.py +426 -0
- isa_model/core/services/__init__.py +19 -0
- isa_model/core/services/intelligent_model_selector.py +547 -0
- isa_model/core/types.py +291 -0
- isa_model/deployment/__init__.py +2 -0
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +157 -3
- isa_model/deployment/cloud/modal/isa_vision_table_service.py +532 -0
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +104 -3
- isa_model/deployment/cloud/modal/register_models.py +321 -0
- isa_model/deployment/runtime/deployed_service.py +338 -0
- isa_model/deployment/services/__init__.py +9 -0
- isa_model/deployment/services/auto_deploy_vision_service.py +538 -0
- isa_model/deployment/services/model_service.py +332 -0
- isa_model/deployment/services/service_monitor.py +356 -0
- isa_model/deployment/services/service_registry.py +527 -0
- isa_model/deployment/services/simple_auto_deploy_vision_service.py +275 -0
- isa_model/eval/__init__.py +80 -44
- isa_model/eval/config/__init__.py +10 -0
- isa_model/eval/config/evaluation_config.py +108 -0
- isa_model/eval/evaluators/__init__.py +18 -0
- isa_model/eval/evaluators/base_evaluator.py +503 -0
- isa_model/eval/evaluators/llm_evaluator.py +472 -0
- isa_model/eval/factory.py +417 -709
- isa_model/eval/infrastructure/__init__.py +24 -0
- isa_model/eval/infrastructure/experiment_tracker.py +466 -0
- isa_model/eval/metrics.py +191 -21
- isa_model/inference/ai_factory.py +257 -601
- isa_model/inference/services/audio/base_stt_service.py +65 -1
- isa_model/inference/services/audio/base_tts_service.py +75 -1
- isa_model/inference/services/audio/openai_stt_service.py +189 -151
- isa_model/inference/services/audio/openai_tts_service.py +12 -10
- isa_model/inference/services/audio/replicate_tts_service.py +61 -56
- isa_model/inference/services/base_service.py +55 -17
- isa_model/inference/services/embedding/base_embed_service.py +65 -1
- isa_model/inference/services/embedding/ollama_embed_service.py +103 -43
- isa_model/inference/services/embedding/openai_embed_service.py +8 -10
- isa_model/inference/services/helpers/stacked_config.py +148 -0
- isa_model/inference/services/img/__init__.py +18 -0
- isa_model/inference/services/{vision → img}/base_image_gen_service.py +80 -1
- isa_model/inference/services/{stacked → img}/flux_professional_service.py +25 -1
- isa_model/inference/services/{stacked → img/helpers}/base_stacked_service.py +40 -35
- isa_model/inference/services/{vision → img}/replicate_image_gen_service.py +44 -31
- isa_model/inference/services/llm/__init__.py +3 -3
- isa_model/inference/services/llm/base_llm_service.py +492 -40
- isa_model/inference/services/llm/helpers/llm_prompts.py +258 -0
- isa_model/inference/services/llm/helpers/llm_utils.py +280 -0
- isa_model/inference/services/llm/ollama_llm_service.py +51 -17
- isa_model/inference/services/llm/openai_llm_service.py +70 -19
- isa_model/inference/services/llm/yyds_llm_service.py +24 -23
- isa_model/inference/services/vision/__init__.py +38 -4
- isa_model/inference/services/vision/base_vision_service.py +218 -117
- isa_model/inference/services/vision/{isA_vision_service.py → disabled/isA_vision_service.py} +98 -0
- isa_model/inference/services/{stacked → vision}/doc_analysis_service.py +1 -1
- isa_model/inference/services/vision/helpers/base_stacked_service.py +274 -0
- isa_model/inference/services/vision/helpers/image_utils.py +272 -3
- isa_model/inference/services/vision/helpers/vision_prompts.py +297 -0
- isa_model/inference/services/vision/openai_vision_service.py +104 -307
- isa_model/inference/services/vision/replicate_vision_service.py +140 -325
- isa_model/inference/services/{stacked → vision}/ui_analysis_service.py +2 -498
- isa_model/scripts/register_models.py +370 -0
- isa_model/scripts/register_models_with_embeddings.py +510 -0
- isa_model/serving/api/fastapi_server.py +6 -1
- isa_model/serving/api/routes/unified.py +274 -0
- {isa_model-0.3.5.dist-info → isa_model-0.3.7.dist-info}/METADATA +4 -1
- {isa_model-0.3.5.dist-info → isa_model-0.3.7.dist-info}/RECORD +78 -53
- isa_model/config/__init__.py +0 -9
- isa_model/config/config_manager.py +0 -213
- isa_model/core/model_manager.py +0 -213
- isa_model/core/model_registry.py +0 -375
- isa_model/core/vision_models_init.py +0 -116
- isa_model/inference/billing_tracker.py +0 -406
- isa_model/inference/services/llm/triton_llm_service.py +0 -481
- isa_model/inference/services/stacked/__init__.py +0 -26
- isa_model/inference/services/stacked/config.py +0 -426
- isa_model/inference/services/vision/ollama_vision_service.py +0 -194
- /isa_model/core/{model_storage.py → models/model_storage.py} +0 -0
- /isa_model/inference/services/{vision → embedding}/helpers/text_splitter.py +0 -0
- /isa_model/inference/services/llm/{llm_adapter.py → helpers/llm_adapter.py} +0 -0
- {isa_model-0.3.5.dist-info → isa_model-0.3.7.dist-info}/WHEEL +0 -0
- {isa_model-0.3.5.dist-info → isa_model-0.3.7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,274 @@
|
|
1
|
+
"""
|
2
|
+
Base Stacked Service for orchestrating multiple AI models
|
3
|
+
"""
|
4
|
+
|
5
|
+
from abc import ABC, abstractmethod
|
6
|
+
from typing import Dict, Any, List, Optional, Union, Callable
|
7
|
+
import time
|
8
|
+
import asyncio
|
9
|
+
import logging
|
10
|
+
from dataclasses import dataclass
|
11
|
+
from enum import Enum
|
12
|
+
|
13
|
+
# Import shared types from helpers
|
14
|
+
try:
    # Prefer the layer types shared with the other stacked services.
    # NOTE(review): relative to this package (isa_model.inference.services.vision.helpers),
    # "..helpers" resolves to isa_model.inference.services.vision.helpers.stacked_config,
    # whereas the shared module is packaged as isa_model/inference/services/helpers/stacked_config.py
    # (which would be "...helpers"). If that is unintended, the fallback below is always the
    # definition in effect — confirm before changing, since the two definitions may differ.
    from ..helpers.stacked_config import (
        StackedLayerType as LayerType,
        LayerConfig,
        LayerResult,
    )
except ImportError:
    # Self-contained stand-ins, used whenever the shared module cannot be imported.

    class LayerType(Enum):
        """Category of processing performed by a layer."""

        INTELLIGENCE = "intelligence"
        DETECTION = "detection"
        CLASSIFICATION = "classification"
        VALIDATION = "validation"
        TRANSFORMATION = "transformation"
        GENERATION = "generation"
        ENHANCEMENT = "enhancement"
        CONTROL = "control"
        UPSCALING = "upscaling"

    @dataclass
    class LayerConfig:
        """Declarative settings for one layer of a stacked service."""

        # Layer identifier; results and depends_on refer to layers by this name.
        name: str
        # Category of the layer; its .value is reported in result metadata.
        layer_type: LayerType
        # Factory service backing the layer ('vision', 'llm' or 'image_gen' in BaseStackedService).
        service_type: str
        # Model requested from the factory; "default" selects the factory default.
        model_name: str
        # Layer-specific settings, echoed into result metadata.
        parameters: Dict[str, Any]
        # Names of layers that must have succeeded before this one runs.
        depends_on: List[str]
        # Seconds allowed for the layer (passed to asyncio.wait_for).
        timeout: float = 30.0
        # NOTE(review): not read by BaseStackedService in this module.
        retry_count: int = 1
        # Try execute_fallback on failure; a failure of such a layer does not stop the stack.
        fallback_enabled: bool = True

    @dataclass
    class LayerResult:
        """Outcome of running one layer."""

        layer_name: str
        success: bool
        data: Any
        metadata: Dict[str, Any]
        # Wall-clock duration of the layer, in seconds.
        execution_time: float
        # Failure message, or "Fallback used: ..." when a fallback rescued the layer.
        error: Optional[str] = None
|
52
|
+
|
53
|
+
logger = logging.getLogger(__name__)
|
54
|
+
|
55
|
+
class BaseStackedService(ABC):
    """
    Base class for stacked services that orchestrate multiple AI models.

    A stack is a list of ``LayerConfig`` entries registered with ``add_layer``.
    ``invoke`` obtains the services the layers need from ``ai_factory``, runs the
    layers in dependency order and combines their results. Subclasses implement
    ``execute_layer_logic`` (the work of one layer) and ``generate_final_output``
    (the combined output), and may override ``execute_fallback``.
    """

    def __init__(self, ai_factory, service_name: str):
        """
        Args:
            ai_factory: Factory providing ``get_vision``, ``get_llm`` and ``get_image_gen``.
            service_name: Name used in log messages and in the ``invoke`` result.
        """
        self.ai_factory = ai_factory
        self.service_name = service_name
        self.layers: List[LayerConfig] = []
        # Service instances keyed by "<service_type>_<model_name>", shared by layers using the same model
        self.services: Dict[str, Any] = {}
        # Results of the most recent invoke(), keyed by layer name
        self.results: Dict[str, LayerResult] = {}

    def add_layer(self, config: LayerConfig):
        """Append a processing layer to the stack (execution order is resolved in ``invoke``)."""
        self.layers.append(config)
        logger.info(f"Added layer {config.name} ({config.layer_type.value}) to {self.service_name}")

    @staticmethod
    def _service_key(layer: LayerConfig) -> str:
        """Return the key under which the service for *layer* is cached in ``self.services``."""
        return f"{layer.service_type}_{layer.model_name}"

    def _create_service(self, layer: LayerConfig) -> Any:
        """
        Request the service backing *layer* from the AI factory.

        Raises:
            ValueError: if ``layer.service_type`` is not 'vision', 'llm' or 'image_gen'.
        """
        service_type = layer.service_type
        model_name = layer.model_name

        if service_type == 'vision':
            if model_name == "default":
                return self.ai_factory.get_vision()
            if model_name == "omniparser":
                # OmniParser is requested from the Replicate provider
                return self.ai_factory.get_vision(model_name="omniparser", provider="replicate")
            return self.ai_factory.get_vision(model_name=model_name)

        if service_type == 'llm':
            if model_name == "default":
                return self.ai_factory.get_llm()
            return self.ai_factory.get_llm(model_name=model_name)

        if service_type == 'image_gen':
            if model_name == "default":
                return self.ai_factory.get_image_gen()
            return self.ai_factory.get_image_gen(model_name=model_name)

        raise ValueError(f"Unsupported service type: {layer.service_type}")

    async def initialize_services(self):
        """
        Create the service needed by every registered layer.

        Idempotent: layers whose service is already cached are skipped, so it is
        safe (and cheap) to call again after more layers have been added.

        Raises:
            ValueError: if a layer uses an unsupported service type.
        """
        for layer in self.layers:
            service_key = self._service_key(layer)
            if service_key in self.services:
                continue
            self.services[service_key] = self._create_service(layer)
            logger.info(f"Initialized {service_key} service")

    async def execute_layer(self, layer: LayerConfig, context: Dict[str, Any]) -> LayerResult:
        """
        Run one layer and wrap the outcome in a LayerResult.

        The layer fails (without calling ``execute_layer_logic``) when a dependency
        is missing from ``self.results`` or did not succeed. On failure, if
        ``layer.fallback_enabled`` is set, ``execute_fallback`` is consulted; a
        non-None fallback value turns the result into a success whose ``error``
        records that the fallback was used.

        Args:
            layer: Layer to execute.
            context: Shared invocation context (``{"input": ..., "results": ...}``
                when called from ``invoke``).

        Returns:
            LayerResult: exceptions raised by the layer logic (or its fallback) are
            captured in the result rather than propagated.
        """
        start_time = time.perf_counter()
        metadata = {
            "layer_type": layer.layer_type.value,
            "model": layer.model_name,
            "parameters": layer.parameters,
        }

        try:
            for dep in layer.depends_on:
                dep_result = self.results.get(dep)
                if dep_result is None or not dep_result.success:
                    raise ValueError(f"Dependency {dep} failed or not executed")

            service_key = self._service_key(layer)
            if service_key not in self.services:
                # Allows execute_layer to be used without a prior initialize_services() call
                self.services[service_key] = self._create_service(layer)
            service = self.services[service_key]

            data = await asyncio.wait_for(
                self.execute_layer_logic(layer, service, context),
                timeout=layer.timeout,
            )
        except Exception as e:
            return await self._failed_layer_result(layer, context, e, start_time, metadata)

        execution_time = time.perf_counter() - start_time
        logger.info(f"Layer {layer.name} completed in {execution_time:.2f}s")
        return LayerResult(
            layer_name=layer.name,
            success=True,
            data=data,
            metadata=metadata,
            execution_time=execution_time,
        )

    async def _failed_layer_result(
        self,
        layer: LayerConfig,
        context: Dict[str, Any],
        error: Exception,
        start_time: float,
        metadata: Dict[str, Any],
    ) -> LayerResult:
        """Build the LayerResult for a failed layer, applying the fallback if enabled."""
        execution_time = time.perf_counter() - start_time

        error_msg = str(error)
        if not error_msg:
            # Timeouts from asyncio.wait_for (and some other exceptions) carry no message
            if isinstance(error, asyncio.TimeoutError):
                error_msg = f"Timed out after {layer.timeout}s"
            else:
                error_msg = type(error).__name__

        logger.error(f"Layer {layer.name} failed after {execution_time:.2f}s: {error_msg}")

        result = LayerResult(
            layer_name=layer.name,
            success=False,
            data=None,
            metadata=metadata,
            execution_time=execution_time,
            error=error_msg,
        )

        if layer.fallback_enabled:
            try:
                fallback_result = await self.execute_fallback(layer, context, error_msg)
            except Exception as fallback_error:
                logger.error(f"Fallback for layer {layer.name} failed: {fallback_error}")
                fallback_result = None
            # Compare against None: falsy values such as [] or {} are still valid fallback output
            if fallback_result is not None:
                result.data = fallback_result
                result.success = True
                result.error = f"Fallback used: {error_msg}"

        return result

    @abstractmethod
    async def execute_layer_logic(self, layer: LayerConfig, service: Any, context: Dict[str, Any]) -> Any:
        """Execute the specific logic for a layer - to be implemented by subclasses"""
        pass

    async def execute_fallback(self, layer: LayerConfig, context: Dict[str, Any], error: str) -> Optional[Any]:
        """
        Produce a substitute result for a failed layer - can be overridden by subclasses.

        Return ``None`` to signal that no fallback is available.
        """
        return None

    async def invoke(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run the whole stack of layers on *input_data*.

        Layers run in dependency order; execution stops early when a layer without
        ``fallback_enabled`` fails.

        Returns:
            dict with keys ``service``, ``success`` (True only if every executed layer
            succeeded), ``total_execution_time`` (seconds), ``layer_results`` (a copy of
            the per-layer LayerResult mapping) and ``final_output`` (from
            ``generate_final_output``).

        Raises:
            ValueError: for an unsupported service type or an invalid dependency graph.
        """
        logger.info(f"Starting {self.service_name} stack invocation")
        stack_start_time = time.perf_counter()

        # Always run: initialization is idempotent and picks up layers added since the last call
        await self.initialize_services()

        self.results.clear()
        execution_order = self._build_execution_order()

        context = {"input": input_data, "results": self.results}

        for layer in execution_order:
            result = await self.execute_layer(layer, context)
            self.results[layer.name] = result
            # Re-publish in case layer logic replaced the entry in the shared context
            context["results"] = self.results

            # Layers without a fallback are considered critical
            if not result.success and not layer.fallback_enabled:
                logger.error(f"Critical layer {layer.name} failed, stopping execution")
                break

        total_time = time.perf_counter() - stack_start_time

        final_result = {
            "service": self.service_name,
            "success": all(r.success for r in self.results.values()),
            "total_execution_time": total_time,
            # Copy so the returned mapping survives the next invoke() clearing self.results
            "layer_results": dict(self.results),
            "final_output": self.generate_final_output(self.results),
        }

        logger.info(f"{self.service_name} stack invocation completed in {total_time:.2f}s")
        return final_result

    def _build_execution_order(self) -> List[LayerConfig]:
        """
        Order layers so that each runs after all of its dependencies.

        Layers are released in waves; layers that become ready in the same wave
        keep their registration order.

        Raises:
            ValueError: if a layer depends on an unknown layer name, or if the
                dependencies contain a cycle.
        """
        layer_names = {layer.name for layer in self.layers}
        for layer in self.layers:
            unknown = [dep for dep in layer.depends_on if dep not in layer_names]
            if unknown:
                raise ValueError(
                    f"Layer {layer.name} depends on unknown layer(s): {', '.join(map(str, unknown))}"
                )

        ordered: List[LayerConfig] = []
        placed_names = set()
        remaining = list(self.layers)

        while remaining:
            # A layer is ready once every dependency has already been placed
            ready = [
                layer for layer in remaining
                if all(dep in placed_names for dep in layer.depends_on)
            ]
            if not ready:
                raise ValueError("Circular dependency detected in layer configuration")

            ordered.extend(ready)
            placed_names.update(layer.name for layer in ready)
            ready_ids = {id(layer) for layer in ready}
            remaining = [layer for layer in remaining if id(layer) not in ready_ids]

        return ordered

    @abstractmethod
    def generate_final_output(self, results: Dict[str, LayerResult]) -> Any:
        """Generate final output from all layer results - to be implemented by subclasses"""
        pass

    async def close(self):
        """
        Close every cached service that exposes a ``close`` method, then clear the cache.

        Both coroutine and plain ``close`` methods are supported. A service cached
        under several keys is closed once, and an error closing one service is
        logged without preventing the others from being closed.
        """
        import inspect  # only needed here, to accept both sync and async close()

        closed_ids = set()
        for service_key, service in self.services.items():
            if id(service) in closed_ids:
                continue
            closed_ids.add(id(service))

            close_method = getattr(service, 'close', None)
            if not callable(close_method):
                continue
            try:
                outcome = close_method()
                if inspect.isawaitable(outcome):
                    await outcome
            except Exception as e:
                logger.warning(f"Error closing service {service_key}: {e}")

        self.services.clear()
        logger.info(f"Closed all services for {self.service_name}")

    def get_performance_metrics(self) -> Dict[str, Any]:
        """
        Summarize the most recent ``invoke``; returns an empty dict if nothing has run.

        ``total_execution_time`` here is the sum of per-layer times, excluding
        orchestration overhead.
        """
        if not self.results:
            return {}

        metrics = {
            "total_layers": len(self.results),
            "successful_layers": sum(1 for r in self.results.values() if r.success),
            "failed_layers": sum(1 for r in self.results.values() if not r.success),
            "total_execution_time": sum(r.execution_time for r in self.results.values()),
            "layer_times": {name: r.execution_time for name, r in self.results.items()},
            "layer_success": {name: r.success for name, r in self.results.items()},
        }

        return metrics
|
@@ -1,12 +1,56 @@
|
|
1
1
|
from io import BytesIO
|
2
2
|
from PIL import Image
|
3
|
-
from typing import Union
|
3
|
+
from typing import Union, BinaryIO, Tuple
|
4
4
|
import base64
|
5
|
-
|
5
|
+
import requests
|
6
|
+
import os
|
6
7
|
import logging
|
7
8
|
|
8
9
|
logger = logging.getLogger(__name__)
|
9
10
|
|
11
|
+
def get_image_data(image: Union[str, BinaryIO], timeout: float = 30.0) -> bytes:
    """
    Get raw image bytes from any supported input type (unified image-loading helper).

    Args:
        image: One of
            - an ``http://`` / ``https://`` URL (downloaded with ``requests``),
            - a ``data:`` URL with a base64 payload (e.g. ``data:image/png;base64,...``),
            - a local file path,
            - a file-like object whose ``read()`` returns bytes (read from its
              current position, so the stream is consumed),
            - a bytes-like object (``bytes``, ``bytearray``, ``memoryview``, ...).
        timeout: Seconds to wait for the server when downloading a URL. Without a
            timeout a stalled server would block the caller indefinitely.

    Returns:
        bytes: The raw image data.

    Raises:
        ValueError: for a data URL without a base64 payload, or a file-like object
            whose ``read()`` does not return bytes.
        TypeError: for inputs that cannot be interpreted as image bytes (e.g. ints).
        OSError: if a local file cannot be read.
        requests exceptions: if a URL download fails or returns an error status.
    """
    try:
        if isinstance(image, str):
            if image.startswith(('http://', 'https://')):
                response = requests.get(image, timeout=timeout)
                response.raise_for_status()
                return response.content
            if image.startswith('data:'):
                _, marker, payload = image.partition('base64,')
                if not marker:
                    raise ValueError("Unsupported data URL format")
                return base64.b64decode(payload)
            # Anything else is treated as a local file path
            with open(image, 'rb') as f:
                return f.read()

        if hasattr(image, 'read'):
            data = image.read()
            if not isinstance(data, bytes):
                raise ValueError("File-like object did not return bytes")
            return data

        if isinstance(image, bytes):
            return image
        if isinstance(image, int):
            # bytes(n) would silently produce n zero bytes instead of image data
            raise TypeError(f"Unsupported image input type: {type(image).__name__}")
        return bytes(image)
    except Exception as e:
        logger.error(f"Error getting image data: {e}")
        raise
|
53
|
+
|
10
54
|
def compress_image(image_data: Union[bytes, BytesIO], max_size: int = 1024) -> bytes:
|
11
55
|
"""压缩图片以减小大小
|
12
56
|
|
@@ -56,4 +100,229 @@ def encode_image_to_base64(image_data: bytes) -> str:
|
|
56
100
|
return base64.b64encode(image_data).decode('utf-8')
|
57
101
|
except Exception as e:
|
58
102
|
logger.error(f"Error encoding image to base64: {e}")
|
59
|
-
raise
|
103
|
+
raise
|
104
|
+
|
105
|
+
def prepare_image_base64(image: Union[str, BinaryIO], compress: bool = False, max_size: int = 1024) -> str:
    """
    Encode an image as a base64 string (unified base64 helper).

    Args:
        image: Any input accepted by ``get_image_data``.
        compress: When True, the bytes are passed through ``compress_image`` first.
        max_size: Forwarded to ``compress_image`` as its ``max_size`` when compressing.

    Returns:
        str: Base64 text of the (optionally compressed) image bytes.
    """
    try:
        raw_bytes = get_image_data(image)
        payload = compress_image(raw_bytes, max_size) if compress else raw_bytes
        return encode_image_to_base64(payload)
    except Exception as e:
        logger.error(f"Error preparing image base64: {e}")
        raise
|
127
|
+
|
128
|
+
def prepare_image_data_url(image: Union[str, BinaryIO], compress: bool = False, max_size: int = 1024) -> str:
    """
    Build a ``data:<mime>;base64,<payload>`` URL for an image (unified data-URL helper).

    Args:
        image: Image input (path, URL, data URL, file-like object or bytes).
        compress: Whether to run the bytes through ``compress_image`` first.
        max_size: Forwarded to ``compress_image`` when compressing.

    Returns:
        str: The image as a data URL.
    """
    try:
        # A file-like object can only be read once. Materialize its bytes up front so
        # that both the base64 encoding and the MIME sniffing below see the full image
        # instead of the second consumer reading an exhausted stream. Strings are passed
        # through unchanged so MIME detection for them stays extension-based and I/O-free.
        source = image if isinstance(image, str) else get_image_data(image)
        base64_data = prepare_image_base64(source, compress, max_size)
        # NOTE(review): the MIME type describes the input image; if compress_image
        # re-encodes to another format, the declared type may not match — confirm.
        mime_type = get_image_mime_type(source)
        return f"data:{mime_type};base64,{base64_data}"
    except Exception as e:
        logger.error(f"Error preparing image data URL: {e}")
        raise
|
147
|
+
|
148
|
+
def get_image_mime_type(image: Union[str, BinaryIO]) -> str:
    """
    Determine an image's MIME type (unified MIME-detection helper).

    Strings are classified without any I/O:
      - ``data:`` URLs use the media type declared in their header,
      - ``http(s)`` URLs and file paths are mapped by file extension (the query
        string and fragment of a URL are ignored).
    Other inputs are loaded with ``get_image_data`` and sniffed with Pillow.

    Args:
        image: Image input (path, URL, data URL, file-like object or bytes).

    Returns:
        str: The MIME type, or ``'image/jpeg'`` when it cannot be determined.
    """
    extension_mapping = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.webp': 'image/webp',
        '.bmp': 'image/bmp',
        '.tif': 'image/tiff',
        '.tiff': 'image/tiff',
    }
    format_mapping = {
        'JPEG': 'image/jpeg',
        'PNG': 'image/png',
        'GIF': 'image/gif',
        'WEBP': 'image/webp',
        'BMP': 'image/bmp',
        'TIFF': 'image/tiff',
    }
    try:
        if isinstance(image, str):
            if image.startswith('data:'):
                # Layout: data:[<mediatype>][;base64],<data>
                header = image[len('data:'):].split(',', 1)[0]
                media_type = header.split(';', 1)[0].strip().lower()
                return media_type if media_type.startswith('image/') else 'image/jpeg'

            path = image
            if image.startswith(('http://', 'https://')):
                # Drop fragment and query so "a.png?sig=x" is recognised as .png
                path = image.split('#', 1)[0].split('?', 1)[0]
            ext = os.path.splitext(path)[1].lower()
            return extension_mapping.get(ext, 'image/jpeg')

        # Non-string input: detect from the actual image content
        image_data = get_image_data(image)
        img = Image.open(BytesIO(image_data))
        return format_mapping.get(img.format, 'image/jpeg')
    except Exception:
        # Default fallback
        return 'image/jpeg'
|
188
|
+
|
189
|
+
def get_image_dimensions(image: Union[str, BinaryIO]) -> Tuple[int, int]:
    """
    Return an image's pixel size (unified dimension helper).

    Args:
        image: Any input accepted by ``get_image_data``.

    Returns:
        tuple: ``(width, height)`` as reported by Pillow, or ``(0, 0)`` if the image
        could not be loaded or decoded (the error is logged).
    """
    try:
        with Image.open(BytesIO(get_image_data(image))) as picture:
            width, height = picture.size
    except Exception as e:
        logger.error(f"Error getting image dimensions: {e}")
        return (0, 0)
    return (width, height)
|
206
|
+
|
207
|
+
def validate_image_format(image: Union[str, BinaryIO], supported_formats: list = None) -> bool:
    """
    Check whether an image's format is in the supported list (unified validation helper).

    Args:
        image: Image input. Strings are judged by their extension (``data:`` URLs by
            their declared media subtype; the query string and fragment of http(s)
            URLs are ignored). Other inputs are opened with Pillow and judged by
            their actual format.
        supported_formats: Accepted format names, e.g. ``['png', 'jpeg']``. Matching
            is case-insensitive, a leading dot is ignored, and jpg/jpeg and tif/tiff
            are treated as equivalent. Defaults to common formats.

    Returns:
        bool: True if the format is supported. Also True when the format cannot be
        determined — validation is best-effort and permissive by design.
    """
    if supported_formats is None:
        supported_formats = ['jpg', 'jpeg', 'png', 'gif', 'webp', 'bmp', 'tiff']

    allowed = {str(fmt).lower().lstrip('.') for fmt in supported_formats}
    # Pillow reports 'JPEG'/'TIFF' while file names often use .jpg/.tif
    for aliases in ({'jpg', 'jpeg'}, {'tif', 'tiff'}):
        if allowed & aliases:
            allowed |= aliases

    try:
        if isinstance(image, str):
            if image.startswith('data:'):
                media_type = image[len('data:'):].split(',', 1)[0].split(';', 1)[0].strip().lower()
                return media_type.rpartition('/')[2] in allowed

            path = image
            if image.startswith(('http://', 'https://')):
                path = image.split('#', 1)[0].split('?', 1)[0]
            ext = os.path.splitext(path)[1].lower().lstrip('.')
            return ext in allowed

        # Check the actual image format
        image_data = get_image_data(image)
        img = Image.open(BytesIO(image_data))
        if img.format is None:
            logger.warning("Could not validate image format: format not detected")
            return True  # Allow by default
        return img.format.lower() in allowed
    except Exception as e:
        logger.warning(f"Could not validate image format: {e}")
        return True  # Allow by default
|
233
|
+
|
234
|
+
def parse_coordinates_from_text(text: str) -> list:
    """
    Parse object coordinates out of a free-text model response (unified parsing logic).

    Each line shaped like ``<label>: x=..%, y=..%, width=..%, height=..% - <description>``
    becomes one object. A value is captured only when it is terminated by ``%`` and
    parses as a float; anything else is skipped. Lines without a colon, or containing
    neither ``x=`` nor ``width=``, are ignored.

    Args:
        text: Model response containing coordinate information.

    Returns:
        list: One dict per matching line with keys ``label``, ``confidence`` (always
        1.0), ``coordinates`` (the parsed subset of x/y/width/height -> float) and
        ``description`` (text after the first `` - ``, otherwise the whole detail part).
    """
    def _percent_value(details, key):
        # Number between "<key>=" and the next "%"; None when absent or unparsable.
        marker = key + "="
        pos = details.find(marker)
        if pos == -1:
            return None
        start = pos + len(marker)
        end = details.find('%', start)
        if end <= start:
            return None
        try:
            return float(details[start:end])
        except ValueError:
            return None

    parsed = []
    for raw_line in text.split('\n'):
        line = raw_line.strip()
        if not line or ':' not in line:
            continue
        if 'x=' not in line and 'width=' not in line:
            continue

        try:
            label, _, details = line.partition(':')
            details = details.strip()

            coordinates = {}
            for key in ('x', 'y', 'width', 'height'):
                value = _percent_value(details, key)
                if value is not None:
                    coordinates[key] = value

            _, dash, tail = details.partition(' - ')
            parsed.append({
                "label": label.strip(),
                "confidence": 1.0,
                "coordinates": coordinates,
                "description": tail if dash else details,
            })
        except Exception:
            # Defensive: keep an unexpected line as a raw entry instead of dropping it
            parsed.append({
                "label": line,
                "confidence": 1.0,
                "coordinates": {},
                "description": line,
            })

    return parsed
|
292
|
+
|
293
|
+
def parse_center_coordinates_from_text(text: str) -> tuple:
    """
    Parse a center point from a structured text response (unified parsing logic).

    Expected layout, one field per line (field names are case-insensitive)::

        FOUND: YES
        CENTER: [x, y]
        DESCRIPTION: free text

    ``CENTER`` is only honoured after a ``FOUND`` line whose value contains ``YES``.
    The coordinates may be wrapped in ``[]`` or ``()`` and are truncated to ints;
    malformed coordinates are ignored. Only the field-name prefix is removed from
    each value, so values may themselves contain the field names.

    Args:
        text: Response text in FOUND/CENTER/DESCRIPTION format.

    Returns:
        tuple: (found: bool, center_coords: List[int] | None, description: str)
    """
    found = False
    center_coords = None
    description = ""

    strip_brackets = str.maketrans('', '', '[]()')

    for raw_line in text.split('\n'):
        field, sep, value = raw_line.strip().partition(':')
        if not sep:
            continue
        field = field.strip().upper()
        value = value.strip()

        if field == 'FOUND':
            found = 'YES' in value.upper()
        elif field == 'CENTER' and found:
            coords_text = value.translate(strip_brackets)
            if ',' in coords_text:
                try:
                    x_str, y_str = coords_text.split(',')
                    center_coords = [int(float(x_str.strip())), int(float(y_str.strip()))]
                except (ValueError, OverflowError):
                    # Wrong arity, non-numeric, NaN or infinite values: keep the previous value
                    pass
        elif field == 'DESCRIPTION':
            description = value

    return found, center_coords, description
|