isa-model 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +30 -1
- isa_model/client.py +770 -0
- isa_model/core/config/__init__.py +16 -0
- isa_model/core/config/config_manager.py +514 -0
- isa_model/core/config.py +426 -0
- isa_model/core/models/model_billing_tracker.py +476 -0
- isa_model/core/models/model_manager.py +399 -0
- isa_model/core/models/model_repo.py +343 -0
- isa_model/core/pricing_manager.py +426 -0
- isa_model/core/services/__init__.py +19 -0
- isa_model/core/services/intelligent_model_selector.py +547 -0
- isa_model/core/types.py +291 -0
- isa_model/deployment/__init__.py +2 -0
- isa_model/deployment/cloud/__init__.py +9 -0
- isa_model/deployment/cloud/modal/__init__.py +10 -0
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +766 -0
- isa_model/deployment/cloud/modal/isa_vision_table_service.py +532 -0
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +406 -0
- isa_model/deployment/cloud/modal/register_models.py +321 -0
- isa_model/deployment/runtime/deployed_service.py +338 -0
- isa_model/deployment/services/__init__.py +9 -0
- isa_model/deployment/services/auto_deploy_vision_service.py +537 -0
- isa_model/deployment/services/model_service.py +332 -0
- isa_model/deployment/services/service_monitor.py +356 -0
- isa_model/deployment/services/service_registry.py +527 -0
- isa_model/eval/__init__.py +80 -44
- isa_model/eval/config/__init__.py +10 -0
- isa_model/eval/config/evaluation_config.py +108 -0
- isa_model/eval/evaluators/__init__.py +18 -0
- isa_model/eval/evaluators/base_evaluator.py +503 -0
- isa_model/eval/evaluators/llm_evaluator.py +472 -0
- isa_model/eval/factory.py +417 -709
- isa_model/eval/infrastructure/__init__.py +24 -0
- isa_model/eval/infrastructure/experiment_tracker.py +466 -0
- isa_model/eval/metrics.py +191 -21
- isa_model/inference/ai_factory.py +187 -387
- isa_model/inference/providers/modal_provider.py +109 -0
- isa_model/inference/providers/yyds_provider.py +108 -0
- isa_model/inference/services/__init__.py +2 -1
- isa_model/inference/services/audio/base_stt_service.py +65 -1
- isa_model/inference/services/audio/base_tts_service.py +75 -1
- isa_model/inference/services/audio/openai_stt_service.py +189 -151
- isa_model/inference/services/audio/openai_tts_service.py +12 -10
- isa_model/inference/services/audio/replicate_tts_service.py +61 -56
- isa_model/inference/services/base_service.py +55 -55
- isa_model/inference/services/embedding/base_embed_service.py +65 -1
- isa_model/inference/services/embedding/ollama_embed_service.py +103 -43
- isa_model/inference/services/embedding/openai_embed_service.py +8 -10
- isa_model/inference/services/helpers/stacked_config.py +148 -0
- isa_model/inference/services/img/__init__.py +18 -0
- isa_model/inference/services/{vision → img}/base_image_gen_service.py +80 -35
- isa_model/inference/services/img/flux_professional_service.py +603 -0
- isa_model/inference/services/img/helpers/base_stacked_service.py +274 -0
- isa_model/inference/services/{vision → img}/replicate_image_gen_service.py +210 -69
- isa_model/inference/services/llm/__init__.py +3 -3
- isa_model/inference/services/llm/base_llm_service.py +519 -35
- isa_model/inference/services/llm/{llm_adapter.py → helpers/llm_adapter.py} +40 -0
- isa_model/inference/services/llm/helpers/llm_prompts.py +258 -0
- isa_model/inference/services/llm/helpers/llm_utils.py +280 -0
- isa_model/inference/services/llm/ollama_llm_service.py +150 -15
- isa_model/inference/services/llm/openai_llm_service.py +134 -31
- isa_model/inference/services/llm/yyds_llm_service.py +255 -0
- isa_model/inference/services/vision/__init__.py +38 -4
- isa_model/inference/services/vision/base_vision_service.py +241 -96
- isa_model/inference/services/vision/disabled/isA_vision_service.py +500 -0
- isa_model/inference/services/vision/doc_analysis_service.py +640 -0
- isa_model/inference/services/vision/helpers/base_stacked_service.py +274 -0
- isa_model/inference/services/vision/helpers/image_utils.py +272 -3
- isa_model/inference/services/vision/helpers/vision_prompts.py +297 -0
- isa_model/inference/services/vision/openai_vision_service.py +109 -170
- isa_model/inference/services/vision/replicate_vision_service.py +508 -0
- isa_model/inference/services/vision/ui_analysis_service.py +823 -0
- isa_model/scripts/register_models.py +370 -0
- isa_model/scripts/register_models_with_embeddings.py +510 -0
- isa_model/serving/__init__.py +19 -0
- isa_model/serving/api/__init__.py +10 -0
- isa_model/serving/api/fastapi_server.py +89 -0
- isa_model/serving/api/middleware/__init__.py +9 -0
- isa_model/serving/api/middleware/request_logger.py +88 -0
- isa_model/serving/api/routes/__init__.py +5 -0
- isa_model/serving/api/routes/health.py +82 -0
- isa_model/serving/api/routes/llm.py +19 -0
- isa_model/serving/api/routes/ui_analysis.py +223 -0
- isa_model/serving/api/routes/unified.py +202 -0
- isa_model/serving/api/routes/vision.py +19 -0
- isa_model/serving/api/schemas/__init__.py +17 -0
- isa_model/serving/api/schemas/common.py +33 -0
- isa_model/serving/api/schemas/ui_analysis.py +78 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/METADATA +4 -1
- isa_model-0.3.6.dist-info/RECORD +147 -0
- isa_model/core/model_manager.py +0 -208
- isa_model/core/model_registry.py +0 -342
- isa_model/inference/billing_tracker.py +0 -406
- isa_model/inference/services/llm/triton_llm_service.py +0 -481
- isa_model/inference/services/vision/ollama_vision_service.py +0 -194
- isa_model-0.3.4.dist-info/RECORD +0 -91
- /isa_model/core/{model_storage.py → models/model_storage.py} +0 -0
- /isa_model/inference/services/{vision → embedding}/helpers/text_splitter.py +0 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/WHEEL +0 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,274 @@
|
|
1
|
+
"""
|
2
|
+
Base Stacked Service for orchestrating multiple AI models
|
3
|
+
"""
|
4
|
+
|
5
|
+
from abc import ABC, abstractmethod
|
6
|
+
from typing import Dict, Any, List, Optional, Union, Callable
|
7
|
+
import time
|
8
|
+
import asyncio
|
9
|
+
import logging
|
10
|
+
from dataclasses import dataclass
|
11
|
+
from enum import Enum
|
12
|
+
|
13
|
+
# Import shared types from helpers
|
14
|
+
try:
    from ..helpers.stacked_config import StackedLayerType as LayerType, LayerConfig, LayerResult
except ImportError:
    # Fallback definitions used when the shared stacked_config module cannot
    # be imported.
    class LayerType(Enum):
        """Types of processing layers"""
        INTELLIGENCE = "intelligence"
        DETECTION = "detection"
        CLASSIFICATION = "classification"
        VALIDATION = "validation"
        TRANSFORMATION = "transformation"
        GENERATION = "generation"
        ENHANCEMENT = "enhancement"
        CONTROL = "control"
        UPSCALING = "upscaling"

    @dataclass
    class LayerConfig:
        """Configuration for a processing layer"""
        name: str                      # unique layer name; used as the key in results
        layer_type: LayerType
        service_type: str              # 'vision', 'llm' or 'image_gen'
        model_name: str                # "default" selects the factory's default model
        parameters: Dict[str, Any]
        depends_on: List[str]          # names of layers that must succeed first
        timeout: float = 30.0          # seconds allowed for execute_layer_logic
        retry_count: int = 1           # NOTE(review): not consulted by BaseStackedService
        fallback_enabled: bool = True  # when False, a failure stops the whole stack

    @dataclass
    class LayerResult:
        """Result from a processing layer"""
        layer_name: str
        success: bool
        data: Any
        metadata: Dict[str, Any]
        execution_time: float          # seconds
        error: Optional[str] = None

logger = logging.getLogger(__name__)


class BaseStackedService(ABC):
    """
    Base class for stacked services that orchestrate multiple AI models.

    Layers are registered with ``add_layer`` and executed by ``invoke`` in
    dependency order. Subclasses implement ``execute_layer_logic`` (what one
    layer does) and ``generate_final_output`` (how layer results are combined),
    and may override ``execute_fallback``.
    """

    def __init__(self, ai_factory, service_name: str):
        """
        Args:
            ai_factory: Object providing ``get_vision``, ``get_llm`` and
                ``get_image_gen``; used to create the per-layer services.
            service_name: Name used in log messages and in the invoke result.
        """
        self.ai_factory = ai_factory
        self.service_name = service_name
        self.layers: List[LayerConfig] = []
        self.services: Dict[str, Any] = {}
        self.results: Dict[str, LayerResult] = {}

    @staticmethod
    def _service_key(layer: LayerConfig) -> str:
        """Return the key under which a layer's service is stored in ``self.services``."""
        return f"{layer.service_type}_{layer.model_name}"

    def _create_service(self, layer: LayerConfig) -> Any:
        """Ask the AI factory for the service matching the layer's service type and model."""
        model_name = layer.model_name

        if layer.service_type == 'vision':
            if model_name == "default":
                # Use the default vision service
                return self.ai_factory.get_vision()
            if model_name == "omniparser":
                # Use the replicate omniparser
                return self.ai_factory.get_vision(model_name="omniparser", provider="replicate")
            # Any other explicitly named model
            return self.ai_factory.get_vision(model_name=model_name)

        if layer.service_type == 'llm':
            if model_name == "default":
                return self.ai_factory.get_llm()
            return self.ai_factory.get_llm(model_name=model_name)

        if layer.service_type == 'image_gen':
            if model_name == "default":
                return self.ai_factory.get_image_gen()
            return self.ai_factory.get_image_gen(model_name=model_name)

        raise ValueError(f"Unsupported service type: {layer.service_type}")

    def add_layer(self, config: LayerConfig):
        """Add a processing layer to the stack"""
        self.layers.append(config)
        logger.info(f"Added layer {config.name} ({config.layer_type.value}) to {self.service_name}")

    async def initialize_services(self):
        """
        Initialize all required services.

        Services already present in ``self.services`` are kept, so calling
        this repeatedly only creates services for newly added layers.

        Raises:
            ValueError: If a layer has an unsupported ``service_type``.
        """
        for layer in self.layers:
            service_key = self._service_key(layer)
            if service_key in self.services:
                continue

            self.services[service_key] = self._create_service(layer)
            logger.info(f"Initialized {service_key} service")

    async def execute_layer(self, layer: LayerConfig, context: Dict[str, Any]) -> LayerResult:
        """
        Execute a single layer.

        Failures of the layer logic (including timeouts and unmet
        dependencies) are not raised; they are reported through a
        ``LayerResult`` with ``success=False``. If the layer has
        ``fallback_enabled`` and ``execute_fallback`` returns something other
        than ``None``, the result is marked successful with the fallback data.

        Args:
            layer: The layer to run.
            context: Shared execution context passed to the layer logic.

        Returns:
            The LayerResult describing the outcome.
        """
        start_time = time.perf_counter()
        metadata = {
            "layer_type": layer.layer_type.value,
            "model": layer.model_name,
            "parameters": layer.parameters
        }

        try:
            # Check dependencies
            for dep in layer.depends_on:
                dep_result = self.results.get(dep)
                if dep_result is None or not dep_result.success:
                    raise ValueError(f"Dependency {dep} failed or not executed")

            # Get the service
            service_key = self._service_key(layer)
            if service_key not in self.services:
                raise RuntimeError(f"Service {service_key} has not been initialized")

            # Execute layer with timeout
            data = await asyncio.wait_for(
                self.execute_layer_logic(layer, self.services[service_key], context),
                timeout=layer.timeout
            )

        except Exception as e:
            execution_time = time.perf_counter() - start_time
            # Some exceptions (e.g. the timeout raised by wait_for) have an
            # empty message; fall back to the exception type name.
            error_msg = str(e) or type(e).__name__

            logger.error(f"Layer {layer.name} failed after {execution_time:.2f}s: {error_msg}")

            result = LayerResult(
                layer_name=layer.name,
                success=False,
                data=None,
                metadata=metadata,
                execution_time=execution_time,
                error=error_msg
            )

            # Try fallback if enabled
            if layer.fallback_enabled:
                fallback_result = await self.execute_fallback(layer, context, error_msg)
                # None means "no fallback"; falsy values such as {} are valid data.
                if fallback_result is not None:
                    result.data = fallback_result
                    result.success = True
                    result.error = f"Fallback used: {error_msg}"

            return result

        execution_time = time.perf_counter() - start_time
        logger.info(f"Layer {layer.name} completed in {execution_time:.2f}s")
        return LayerResult(
            layer_name=layer.name,
            success=True,
            data=data,
            metadata=metadata,
            execution_time=execution_time
        )

    @abstractmethod
    async def execute_layer_logic(self, layer: LayerConfig, service: Any, context: Dict[str, Any]) -> Any:
        """Execute the specific logic for a layer - to be implemented by subclasses"""
        pass

    async def execute_fallback(self, layer: LayerConfig, context: Dict[str, Any], error: str) -> Optional[Any]:
        """
        Execute fallback logic for a failed layer - can be overridden by subclasses.

        Returns:
            Replacement data for the layer, or None when no fallback is available.
        """
        return None

    async def invoke(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Invoke the entire stack of layers.

        Args:
            input_data: Caller input, exposed to layers as ``context["input"]``.

        Returns:
            Dict with keys ``service``, ``success`` (True only if every
            executed layer succeeded), ``total_execution_time``,
            ``layer_results`` and ``final_output``.

        Raises:
            ValueError: If the layer dependencies cannot be ordered or a
                layer has an unsupported service type.
        """
        logger.info(f"Starting {self.service_name} stack invocation")
        stack_start_time = time.perf_counter()

        # Initialize services if not done, including services for layers that
        # were added after an earlier invocation.
        if not self.services or any(
            self._service_key(layer) not in self.services for layer in self.layers
        ):
            await self.initialize_services()

        # Clear previous results
        self.results.clear()

        # Build execution order based on dependencies
        execution_order = self._build_execution_order()

        # context["results"] is the live results dict, so each layer sees the
        # results of every layer executed before it.
        context = {"input": input_data, "results": self.results}

        for layer in execution_order:
            result = await self.execute_layer(layer, context)
            self.results[layer.name] = result

            # Stop if critical layer fails
            if not result.success and not layer.fallback_enabled:
                logger.error(f"Critical layer {layer.name} failed, stopping execution")
                break

        total_time = time.perf_counter() - stack_start_time

        # Generate final result
        final_result = {
            "service": self.service_name,
            "success": all(r.success for r in self.results.values()),
            "total_execution_time": total_time,
            # Copy so that a later invoke() clearing self.results does not
            # empty the mapping handed to this caller.
            "layer_results": dict(self.results),
            "final_output": self.generate_final_output(self.results)
        }

        logger.info(f"{self.service_name} stack invocation completed in {total_time:.2f}s")
        return final_result

    def _build_execution_order(self) -> List[LayerConfig]:
        """
        Build execution order based on dependencies.

        Layers are emitted in waves: each pass collects, in registration
        order, every remaining layer whose dependencies were all placed by an
        earlier pass.

        Raises:
            ValueError: If a layer depends on a name that is not a registered
                layer, or if the dependencies form a cycle.
        """
        ordered: List[LayerConfig] = []
        placed = set()
        remaining = list(self.layers)

        while remaining:
            # Find layers with no unmet dependencies
            ready = [
                layer for layer in remaining
                if all(dep in placed for dep in layer.depends_on)
            ]

            if not ready:
                known = {layer.name for layer in self.layers}
                missing = sorted({
                    dep
                    for layer in remaining
                    for dep in layer.depends_on
                    if dep not in known
                })
                if missing:
                    raise ValueError(
                        "Unknown dependency in layer configuration: " + ", ".join(missing)
                    )
                raise ValueError("Circular dependency detected in layer configuration")

            # Add ready layers to order
            ordered.extend(ready)
            placed.update(layer.name for layer in ready)
            ready_ids = {id(layer) for layer in ready}
            remaining = [layer for layer in remaining if id(layer) not in ready_ids]

        return ordered

    @abstractmethod
    def generate_final_output(self, results: Dict[str, LayerResult]) -> Any:
        """Generate final output from all layer results - to be implemented by subclasses"""
        pass

    async def close(self):
        """
        Close all services.

        Every service exposing ``close`` is closed, whether ``close`` is a
        coroutine function or a plain method. A failure while closing one
        service is logged and does not prevent the remaining services from
        being closed or the registry from being cleared.
        """
        import inspect

        for service_key, service in list(self.services.items()):
            closer = getattr(service, 'close', None)
            if closer is None:
                continue
            try:
                outcome = closer()
                if inspect.isawaitable(outcome):
                    await outcome
            except Exception as e:
                logger.warning(f"Failed to close {service_key} service: {e}")

        self.services.clear()
        logger.info(f"Closed all services for {self.service_name}")

    def get_performance_metrics(self) -> Dict[str, Any]:
        """
        Get performance metrics for the stack.

        Returns:
            An empty dict when no layer has been executed; otherwise counts,
            summed layer execution time and per-layer time/success maps for
            the most recent invocation.
        """
        if not self.results:
            return {}

        outcomes = list(self.results.values())
        successful = sum(1 for r in outcomes if r.success)

        return {
            "total_layers": len(outcomes),
            "successful_layers": successful,
            "failed_layers": len(outcomes) - successful,
            "total_execution_time": sum(r.execution_time for r in outcomes),
            "layer_times": {name: r.execution_time for name, r in self.results.items()},
            "layer_success": {name: r.success for name, r in self.results.items()}
        }
|
@@ -17,42 +17,41 @@ import replicate
|
|
17
17
|
from PIL import Image
|
18
18
|
from io import BytesIO
|
19
19
|
|
20
|
-
from
|
21
|
-
from isa_model.inference.providers.base_provider import BaseProvider
|
20
|
+
from .base_image_gen_service import BaseImageGenService
|
22
21
|
|
23
|
-
# 设置日志记录
|
24
|
-
logging.basicConfig(level=logging.INFO)
|
25
22
|
logger = logging.getLogger(__name__)
|
26
23
|
|
27
24
|
class ReplicateImageGenService(BaseImageGenService):
|
28
25
|
"""
|
29
|
-
Replicate 图像生成服务
|
26
|
+
Replicate 图像生成服务 with unified architecture
|
30
27
|
- flux-schnell: 文生图 (t2i) - $3 per 1000 images
|
31
28
|
- flux-kontext-pro: 图生图 (i2i) - $0.04 per image
|
32
29
|
"""
|
33
30
|
|
34
|
-
def __init__(self,
|
35
|
-
super().__init__(
|
31
|
+
def __init__(self, provider_name: str, model_name: str, **kwargs):
    """
    Initialize the service and configure the Replicate API token.

    Args:
        provider_name: Forwarded unchanged to the base class initializer.
        model_name: Forwarded unchanged to the base class initializer.
        **kwargs: Extra keyword arguments forwarded to the base class.

    Raises:
        ValueError: If no API token is found in the provider configuration,
            or if any other error occurs during setup (the original error is
            chained as the cause).
    """
    super().__init__(provider_name, model_name, **kwargs)

    # Get configuration from centralized config manager
    provider_config = self.get_provider_config()

    try:
        # "api_key" takes precedence; "replicate_api_token" is the alternative key.
        self.api_token = provider_config.get("api_key") or provider_config.get("replicate_api_token")

        if not self.api_token:
            # NOTE: raised inside the try, so the handler below catches it and
            # re-raises it wrapped in the "Failed to initialize" message.
            raise ValueError("Replicate API token not found in provider configuration")

        # Set API token process-wide via the environment
        # (presumably read by the replicate client library - TODO confirm).
        os.environ["REPLICATE_API_TOKEN"] = self.api_token

        # Statistics
        self.last_generation_count = 0
        self.total_generation_count = 0

        logger.info(f"Initialized ReplicateImageGenService with model '{self.model_name}'")

    except Exception as e:
        logger.error(f"Failed to initialize Replicate client: {e}")
        raise ValueError(f"Failed to initialize Replicate client: {e}") from e
|
56
55
|
|
57
56
|
async def generate_image(
|
58
57
|
self,
|
@@ -133,6 +132,161 @@ class ReplicateImageGenService(BaseImageGenService):
|
|
133
132
|
|
134
133
|
return await self._generate_internal(input_data)
|
135
134
|
|
135
|
+
async def instant_id_generation(
    self,
    prompt: str,
    face_image: Union[str, Any],
    negative_prompt: Optional[str] = None,
    num_inference_steps: int = 30,
    guidance_scale: float = 5.0,
    seed: Optional[int] = None,
    identitynet_strength_ratio: float = 0.8,
    adapter_strength_ratio: float = 0.8
) -> Dict[str, Any]:
    """
    InstantID face-consistent generation.

    Builds the model input and delegates to ``_generate_internal``. When the
    configured model name contains "instant-id" the face image is sent under
    the ``image`` key and ``negative_prompt`` is only included when given;
    otherwise a default parameter layout using ``face_image`` is sent.

    Args:
        prompt: Text prompt.
        face_image: Reference face image, passed through unchanged.
        negative_prompt: Optional negative prompt.
        num_inference_steps: Number of inference steps.
        guidance_scale: Guidance scale.
        seed: Optional seed; any integer, including 0, is forwarded.
        identitynet_strength_ratio: Passed through to the model input.
        adapter_strength_ratio: Passed through to the model input.

    Returns:
        Whatever ``_generate_internal`` returns for the built input.
    """
    if "instant-id" in self.model_name:
        input_data = {
            "prompt": prompt,
            "image": face_image,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "identitynet_strength_ratio": identitynet_strength_ratio,
            "adapter_strength_ratio": adapter_strength_ratio
        }

        if negative_prompt:
            input_data["negative_prompt"] = negative_prompt
    else:
        # Default InstantID parameters
        input_data = {
            "prompt": prompt,
            "face_image": face_image,
            "negative_prompt": negative_prompt or "",
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "identitynet_strength_ratio": identitynet_strength_ratio,
            "adapter_strength_ratio": adapter_strength_ratio
        }

    # Compare against None: a truthiness test would silently drop seed=0.
    if seed is not None:
        input_data["seed"] = seed

    return await self._generate_internal(input_data)
|
178
|
+
|
179
|
+
async def consistent_character_generation(
    self,
    subject: Union[str, Any],
    prompt: Optional[str] = None,
    negative_prompt: Optional[str] = None,
    number_of_images: int = 4,
    disable_safety_checker: bool = False
) -> Dict[str, Any]:
    """
    Consistent character generation - several poses and expressions of one character.

    Builds the model input and hands it to ``_generate_internal``. Models
    whose name contains "consistent-character" get the ``subject`` /
    ``number_of_images`` layout with optional prompts; any other model gets a
    default layout in which missing prompts are replaced by stock text.
    """
    if "consistent-character" not in self.model_name:
        # Default consistent-character parameters
        payload = {
            "subject_image": subject,
            "prompt": prompt or "portrait, different poses and expressions",
            "negative_prompt": negative_prompt or "low quality, blurry",
            "num_images": number_of_images
        }
        return await self._generate_internal(payload)

    payload = {
        "subject": subject,
        "number_of_images": number_of_images,
        "disable_safety_checker": disable_safety_checker
    }
    # Prompts are optional for this layout: only send the ones provided.
    for key, text in (("prompt", prompt), ("negative_prompt", negative_prompt)):
        if text:
            payload[key] = text

    return await self._generate_internal(payload)
|
210
|
+
|
211
|
+
async def flux_lora_generation(
    self,
    prompt: str,
    lora_scale: float = 1.0,
    num_outputs: int = 1,
    aspect_ratio: str = "1:1",
    output_format: str = "jpg",
    guidance_scale: float = 3.5,
    output_quality: int = 90,
    num_inference_steps: int = 28,
    disable_safety_checker: bool = False
) -> Dict[str, Any]:
    """
    FLUX LoRA generation - uses pre-trained LoRA weights.

    Builds the model input and hands it to ``_generate_internal``. Models
    whose name contains "flux-dev-lora" or "flux-lora" receive the full
    parameter set; any other model receives a reduced default layout in which
    ``lora_scale`` and ``num_outputs`` are sent as ``lora_strength`` and
    ``num_images``.
    """
    is_flux_lora = False
    for marker in ["flux-dev-lora", "flux-lora"]:
        if marker in self.model_name:
            is_flux_lora = True
            break

    if not is_flux_lora:
        # Default LoRA parameters
        return await self._generate_internal({
            "prompt": prompt,
            "lora_strength": lora_scale,
            "num_images": num_outputs,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps
        })

    return await self._generate_internal({
        "prompt": prompt,
        "lora_scale": lora_scale,
        "num_outputs": num_outputs,
        "aspect_ratio": aspect_ratio,
        "output_format": output_format,
        "guidance_scale": guidance_scale,
        "output_quality": output_quality,
        "num_inference_steps": num_inference_steps,
        "disable_safety_checker": disable_safety_checker
    })
|
248
|
+
|
249
|
+
async def ultimate_upscale(
    self,
    image: Union[str, Any],
    scale: int = 4,
    scheduler: str = "K_EULER_ANCESTRAL",
    num_inference_steps: int = 20,
    guidance_scale: float = 10.0,
    strength: float = 0.55,
    hdr: float = 0.0,
    seed: Optional[int] = None
) -> Dict[str, Any]:
    """
    Ultimate SD Upscaler - professional super-resolution.

    Builds the model input and delegates to ``_generate_internal``. When the
    configured model name contains "ultimate" or "upscal" the full parameter
    set (including ``scheduler`` and ``hdr``) is sent; otherwise a default
    layout is used in which ``scale`` and ``strength`` are sent as
    ``upscale_factor`` and ``denoising_strength``.

    Args:
        image: Image to upscale, passed through unchanged.
        scale: Upscaling factor.
        scheduler: Scheduler name (only sent in the first layout).
        num_inference_steps: Number of inference steps.
        guidance_scale: Guidance scale.
        strength: Denoising strength.
        hdr: HDR amount (only sent in the first layout).
        seed: Optional seed; any integer, including 0, is forwarded.

    Returns:
        Whatever ``_generate_internal`` returns for the built input.
    """
    if "ultimate" in self.model_name or "upscal" in self.model_name:
        input_data = {
            "image": image,
            "scale": scale,
            "scheduler": scheduler,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "strength": strength,
            "hdr": hdr
        }
    else:
        # Default super-resolution parameters
        input_data = {
            "image": image,
            "upscale_factor": scale,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "denoising_strength": strength
        }

    # Compare against None: a truthiness test would silently drop seed=0.
    if seed is not None:
        input_data["seed"] = seed

    return await self._generate_internal(input_data)
|
289
|
+
|
136
290
|
async def _generate_internal(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
|
137
291
|
"""内部生成方法"""
|
138
292
|
try:
|
@@ -141,11 +295,19 @@ class ReplicateImageGenService(BaseImageGenService):
|
|
141
295
|
# 调用 Replicate API
|
142
296
|
output = await replicate.async_run(self.model_name, input=input_data)
|
143
297
|
|
144
|
-
# 处理输出
|
298
|
+
# 处理输出 - 转换FileOutput对象为URL字符串
|
145
299
|
if isinstance(output, list):
|
146
|
-
|
300
|
+
raw_urls = output
|
147
301
|
else:
|
148
|
-
|
302
|
+
raw_urls = [output]
|
303
|
+
|
304
|
+
# 转换为字符串URL
|
305
|
+
urls = []
|
306
|
+
for url in raw_urls:
|
307
|
+
if hasattr(url, 'url'):
|
308
|
+
urls.append(str(url.url)) # type: ignore
|
309
|
+
else:
|
310
|
+
urls.append(str(url))
|
149
311
|
|
150
312
|
# 更新统计
|
151
313
|
self.last_generation_count = len(urls)
|
@@ -154,25 +316,35 @@ class ReplicateImageGenService(BaseImageGenService):
|
|
154
316
|
# 计算成本
|
155
317
|
cost = self._calculate_cost(len(urls))
|
156
318
|
|
157
|
-
#
|
158
|
-
|
159
|
-
|
160
|
-
service_type=ServiceType.IMAGE_GENERATION,
|
319
|
+
# Track billing information
|
320
|
+
await self._track_usage(
|
321
|
+
service_type="image_generation",
|
161
322
|
operation="image_generation",
|
162
|
-
|
323
|
+
input_tokens=0,
|
324
|
+
output_tokens=0,
|
325
|
+
input_units=1, # Input prompt
|
326
|
+
output_units=len(urls), # Generated images count
|
163
327
|
metadata={
|
164
328
|
"model": self.model_name,
|
165
|
-
"prompt": input_data.get("prompt", "")[:100], #
|
166
|
-
"generation_type": "t2i" if "flux-schnell" in self.model_name else "i2i"
|
329
|
+
"prompt": input_data.get("prompt", "")[:100], # Truncate to 100 chars
|
330
|
+
"generation_type": "t2i" if "flux-schnell" in self.model_name else "i2i",
|
331
|
+
"image_count": len(urls),
|
332
|
+
"cost_usd": cost
|
167
333
|
}
|
168
334
|
)
|
169
335
|
|
336
|
+
# Return URLs instead of binary data for HTTP API compatibility
|
170
337
|
result = {
|
171
|
-
"urls": urls,
|
338
|
+
"urls": urls, # Image URLs - primary response
|
339
|
+
"url": urls[0] if urls else None, # First URL for convenience
|
340
|
+
"format": "jpg", # Default format
|
341
|
+
"width": input_data.get("width", 1024),
|
342
|
+
"height": input_data.get("height", 1024),
|
343
|
+
"seed": input_data.get("seed"),
|
172
344
|
"count": len(urls),
|
173
345
|
"cost_usd": cost,
|
174
|
-
"model": self.model_name,
|
175
346
|
"metadata": {
|
347
|
+
"model": self.model_name,
|
176
348
|
"input": input_data,
|
177
349
|
"generation_count": len(urls)
|
178
350
|
}
|
@@ -187,7 +359,7 @@ class ReplicateImageGenService(BaseImageGenService):
|
|
187
359
|
|
188
360
|
def _calculate_cost(self, image_count: int) -> float:
|
189
361
|
"""计算生成成本"""
|
190
|
-
from isa_model.core.model_manager import ModelManager
|
362
|
+
from isa_model.core.models.model_manager import ModelManager
|
191
363
|
|
192
364
|
manager = ModelManager()
|
193
365
|
|
@@ -224,37 +396,6 @@ class ReplicateImageGenService(BaseImageGenService):
|
|
224
396
|
results.append(result)
|
225
397
|
return results
|
226
398
|
|
227
|
-
async def generate_image_to_file(
|
228
|
-
self,
|
229
|
-
prompt: str,
|
230
|
-
output_path: str,
|
231
|
-
negative_prompt: Optional[str] = None,
|
232
|
-
width: int = 512,
|
233
|
-
height: int = 512,
|
234
|
-
num_inference_steps: int = 4,
|
235
|
-
guidance_scale: float = 7.5,
|
236
|
-
seed: Optional[int] = None
|
237
|
-
) -> Dict[str, Any]:
|
238
|
-
"""生成图像并保存到文件"""
|
239
|
-
result = await self.generate_image(
|
240
|
-
prompt, negative_prompt, width, height,
|
241
|
-
num_inference_steps, guidance_scale, seed
|
242
|
-
)
|
243
|
-
|
244
|
-
# 保存第一张图像
|
245
|
-
if result.get("urls"):
|
246
|
-
url = result["urls"][0]
|
247
|
-
url_str = str(url) if hasattr(url, "__str__") else url
|
248
|
-
await self._download_image(url_str, output_path)
|
249
|
-
|
250
|
-
return {
|
251
|
-
"file_path": output_path,
|
252
|
-
"cost_usd": result.get("cost_usd", 0.0),
|
253
|
-
"model": self.model_name
|
254
|
-
}
|
255
|
-
else:
|
256
|
-
raise ValueError("No image generated")
|
257
|
-
|
258
399
|
async def _download_image(self, url: str, save_path: str) -> None:
|
259
400
|
"""下载图像并保存"""
|
260
401
|
try:
|
@@ -5,10 +5,10 @@ LLM Services - Business logic services for Language Models
|
|
5
5
|
# Import LLM services here when created
|
6
6
|
from .ollama_llm_service import OllamaLLMService
|
7
7
|
from .openai_llm_service import OpenAILLMService
|
8
|
-
from .
|
8
|
+
from .yyds_llm_service import YydsLLMService
|
9
9
|
|
10
10
|
__all__ = [
|
11
11
|
"OllamaLLMService",
|
12
|
-
"OpenAILLMService",
|
13
|
-
"
|
12
|
+
"OpenAILLMService",
|
13
|
+
"YydsLLMService"
|
14
14
|
]
|