isa-model 0.3.9__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/client.py +732 -565
- isa_model/core/cache/redis_cache.py +401 -0
- isa_model/core/config/config_manager.py +53 -10
- isa_model/core/config.py +1 -1
- isa_model/core/database/__init__.py +1 -0
- isa_model/core/database/migrations.py +277 -0
- isa_model/core/database/supabase_client.py +123 -0
- isa_model/core/models/__init__.py +37 -0
- isa_model/core/models/model_billing_tracker.py +60 -88
- isa_model/core/models/model_manager.py +36 -18
- isa_model/core/models/model_repo.py +44 -38
- isa_model/core/models/model_statistics_tracker.py +234 -0
- isa_model/core/models/model_storage.py +0 -1
- isa_model/core/models/model_version_manager.py +959 -0
- isa_model/core/pricing_manager.py +2 -249
- isa_model/core/resilience/circuit_breaker.py +366 -0
- isa_model/core/security/secrets.py +358 -0
- isa_model/core/services/__init__.py +2 -4
- isa_model/core/services/intelligent_model_selector.py +101 -370
- isa_model/core/storage/hf_storage.py +1 -1
- isa_model/core/types.py +7 -0
- isa_model/deployment/cloud/modal/isa_audio_chatTTS_service.py +520 -0
- isa_model/deployment/cloud/modal/isa_audio_fish_service.py +0 -0
- isa_model/deployment/cloud/modal/isa_audio_openvoice_service.py +758 -0
- isa_model/deployment/cloud/modal/isa_audio_service_v2.py +1044 -0
- isa_model/deployment/cloud/modal/isa_embed_rerank_service.py +296 -0
- isa_model/deployment/cloud/modal/isa_video_hunyuan_service.py +423 -0
- isa_model/deployment/cloud/modal/isa_vision_ocr_service.py +519 -0
- isa_model/deployment/cloud/modal/isa_vision_qwen25_service.py +709 -0
- isa_model/deployment/cloud/modal/isa_vision_table_service.py +467 -323
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +607 -180
- isa_model/deployment/cloud/modal/isa_vision_ui_service_optimized.py +660 -0
- isa_model/deployment/core/deployment_manager.py +6 -4
- isa_model/deployment/services/auto_hf_modal_deployer.py +894 -0
- isa_model/eval/benchmarks/__init__.py +27 -0
- isa_model/eval/benchmarks/multimodal_datasets.py +460 -0
- isa_model/eval/benchmarks.py +244 -12
- isa_model/eval/evaluators/__init__.py +8 -2
- isa_model/eval/evaluators/audio_evaluator.py +727 -0
- isa_model/eval/evaluators/embedding_evaluator.py +742 -0
- isa_model/eval/evaluators/vision_evaluator.py +564 -0
- isa_model/eval/example_evaluation.py +395 -0
- isa_model/eval/factory.py +272 -5
- isa_model/eval/isa_benchmarks.py +700 -0
- isa_model/eval/isa_integration.py +582 -0
- isa_model/eval/metrics.py +159 -6
- isa_model/eval/tests/unit/test_basic.py +396 -0
- isa_model/inference/ai_factory.py +44 -8
- isa_model/inference/services/audio/__init__.py +21 -0
- isa_model/inference/services/audio/base_realtime_service.py +225 -0
- isa_model/inference/services/audio/isa_tts_service.py +0 -0
- isa_model/inference/services/audio/openai_realtime_service.py +320 -124
- isa_model/inference/services/audio/openai_stt_service.py +32 -6
- isa_model/inference/services/base_service.py +17 -1
- isa_model/inference/services/embedding/__init__.py +13 -0
- isa_model/inference/services/embedding/base_embed_service.py +111 -8
- isa_model/inference/services/embedding/isa_embed_service.py +305 -0
- isa_model/inference/services/embedding/openai_embed_service.py +2 -4
- isa_model/inference/services/embedding/tests/test_embedding.py +222 -0
- isa_model/inference/services/img/__init__.py +2 -2
- isa_model/inference/services/img/base_image_gen_service.py +24 -7
- isa_model/inference/services/img/replicate_image_gen_service.py +84 -422
- isa_model/inference/services/img/services/replicate_face_swap.py +193 -0
- isa_model/inference/services/img/services/replicate_flux.py +226 -0
- isa_model/inference/services/img/services/replicate_flux_kontext.py +219 -0
- isa_model/inference/services/img/services/replicate_sticker_maker.py +249 -0
- isa_model/inference/services/img/tests/test_img_client.py +297 -0
- isa_model/inference/services/llm/base_llm_service.py +30 -6
- isa_model/inference/services/llm/helpers/llm_adapter.py +63 -9
- isa_model/inference/services/llm/ollama_llm_service.py +2 -1
- isa_model/inference/services/llm/openai_llm_service.py +652 -55
- isa_model/inference/services/llm/yyds_llm_service.py +2 -1
- isa_model/inference/services/vision/__init__.py +5 -5
- isa_model/inference/services/vision/base_vision_service.py +118 -185
- isa_model/inference/services/vision/helpers/image_utils.py +11 -5
- isa_model/inference/services/vision/isa_vision_service.py +573 -0
- isa_model/inference/services/vision/tests/test_ocr_client.py +284 -0
- isa_model/serving/api/fastapi_server.py +88 -16
- isa_model/serving/api/middleware/auth.py +311 -0
- isa_model/serving/api/middleware/security.py +278 -0
- isa_model/serving/api/routes/analytics.py +486 -0
- isa_model/serving/api/routes/deployments.py +339 -0
- isa_model/serving/api/routes/evaluations.py +579 -0
- isa_model/serving/api/routes/logs.py +430 -0
- isa_model/serving/api/routes/settings.py +582 -0
- isa_model/serving/api/routes/unified.py +324 -165
- isa_model/serving/api/startup.py +304 -0
- isa_model/serving/modal_proxy_server.py +249 -0
- isa_model/training/__init__.py +100 -6
- isa_model/training/core/__init__.py +4 -1
- isa_model/training/examples/intelligent_training_example.py +281 -0
- isa_model/training/intelligent/__init__.py +25 -0
- isa_model/training/intelligent/decision_engine.py +643 -0
- isa_model/training/intelligent/intelligent_factory.py +888 -0
- isa_model/training/intelligent/knowledge_base.py +751 -0
- isa_model/training/intelligent/resource_optimizer.py +839 -0
- isa_model/training/intelligent/task_classifier.py +576 -0
- isa_model/training/storage/__init__.py +24 -0
- isa_model/training/storage/core_integration.py +439 -0
- isa_model/training/storage/training_repository.py +552 -0
- isa_model/training/storage/training_storage.py +628 -0
- {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/METADATA +13 -1
- isa_model-0.4.0.dist-info/RECORD +182 -0
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +0 -766
- isa_model/deployment/cloud/modal/register_models.py +0 -321
- isa_model/inference/adapter/unified_api.py +0 -248
- isa_model/inference/services/helpers/stacked_config.py +0 -148
- isa_model/inference/services/img/flux_professional_service.py +0 -603
- isa_model/inference/services/img/helpers/base_stacked_service.py +0 -274
- isa_model/inference/services/others/table_transformer_service.py +0 -61
- isa_model/inference/services/vision/doc_analysis_service.py +0 -640
- isa_model/inference/services/vision/helpers/base_stacked_service.py +0 -274
- isa_model/inference/services/vision/ui_analysis_service.py +0 -823
- isa_model/scripts/inference_tracker.py +0 -283
- isa_model/scripts/mlflow_manager.py +0 -379
- isa_model/scripts/model_registry.py +0 -465
- isa_model/scripts/register_models.py +0 -370
- isa_model/scripts/register_models_with_embeddings.py +0 -510
- isa_model/scripts/start_mlflow.py +0 -95
- isa_model/scripts/training_tracker.py +0 -257
- isa_model-0.3.9.dist-info/RECORD +0 -138
- {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/WHEEL +0 -0
- {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,249 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
|
4
|
+
"""
|
5
|
+
Replicate Sticker Maker Service
|
6
|
+
Specialized service for generating stickers using the fofr/sticker-maker model
|
7
|
+
"""
|
8
|
+
|
9
|
+
import os
|
10
|
+
import logging
|
11
|
+
from typing import Dict, Any, Optional
|
12
|
+
import replicate
|
13
|
+
|
14
|
+
from ..base_image_gen_service import BaseImageGenService
|
15
|
+
|
16
|
+
# Module-level logger for the sticker-maker service, named after this module's import path.
logger = logging.getLogger(__name__)
|
17
|
+
|
18
|
+
class ReplicateStickerMakerService(BaseImageGenService):
    """
    Replicate Sticker Maker Service - $0.0024 per generation.

    Specialized for creating cute stickers from text prompts using the pinned
    ``fofr/sticker-maker`` model version on Replicate.
    """

    # USD charged per generated sticker; every cost figure in this class derives from it.
    COST_PER_GENERATION = 0.0024

    # Pinned Replicate model version used for every request.
    MODEL_PATH = (
        "fofr/sticker-maker:"
        "4acb778eb059772225ec213948f0660867b2e03f277448f18cf1800b96a65a1a"
    )

    def __init__(self, provider_name: str, model_name: str, **kwargs):
        """
        Initialize the service from the centralized provider configuration.

        Args:
            provider_name: Provider identifier, forwarded to the base service.
            model_name: Model identifier, forwarded to the base service.
            **kwargs: Forwarded unchanged to ``BaseImageGenService.__init__``.

        Raises:
            ValueError: If no Replicate API token is configured, or if any other
                step of the setup fails (the original error is chained).
        """
        super().__init__(provider_name, model_name, **kwargs)

        # Get configuration from centralized config manager
        provider_config = self.get_provider_config()

        try:
            self.api_token = provider_config.get("api_key") or provider_config.get("replicate_api_token")

            if not self.api_token:
                raise ValueError("Replicate API token not found in provider configuration")

            # Exported via the environment, presumably so the module-level
            # `replicate` client picks it up. NOTE: this mutates process-wide state.
            os.environ["REPLICATE_API_TOKEN"] = self.api_token

            self.model_path = self.MODEL_PATH

            # Running number of stickers produced by this instance.
            self.total_generation_count = 0

            logger.info(f"Initialized ReplicateStickerMakerService with model '{self.model_name}'")

        except Exception as e:
            logger.error(f"Failed to initialize Replicate Sticker Maker client: {e}")
            raise ValueError(f"Failed to initialize Replicate Sticker Maker client: {e}") from e

    async def generate_sticker(
        self,
        prompt: str,
        steps: int = 17,
        width: int = 1152,
        height: int = 1152,
        output_format: str = "webp",
        output_quality: int = 100,
        negative_prompt: str = "",
        number_of_images: int = 1
    ) -> Dict[str, Any]:
        """
        Generate sticker(s) from a text prompt.

        Args:
            prompt: Text description of the sticker.
            steps: Number of inference steps sent to the model.
            width: Output width in pixels.
            height: Output height in pixels.
            output_format: Image format requested from the model (e.g. "webp").
            output_quality: Output quality value sent to the model.
            negative_prompt: Things the model should avoid.
            number_of_images: How many stickers to request in this single call.

        Returns:
            The result dict produced by ``_generate_internal`` (urls, url, format,
            width, height, count, cost_usd, metadata).
        """
        input_data = {
            "steps": steps,
            "width": width,
            "height": height,
            "prompt": prompt,
            "output_format": output_format,
            "output_quality": output_quality,
            "negative_prompt": negative_prompt,
            "number_of_images": number_of_images
        }

        return await self._generate_internal(input_data)

    @staticmethod
    def _extract_urls(output: Any) -> list[str]:
        """
        Normalize a ``replicate.async_run`` result into a list of URL strings.

        A list/tuple is treated as several outputs; anything else as a single
        output. Items exposing a ``.url`` attribute (e.g. FileOutput objects)
        contribute that attribute as a string, other items are stringified.
        ``None`` -- as the whole output or as an element -- contributes nothing,
        so a literal "None" URL is never counted, billed, or returned.
        """
        if output is None:
            return []
        items = output if isinstance(output, (list, tuple)) else [output]
        return [
            str(item.url) if hasattr(item, "url") else str(item)
            for item in items
            if item is not None
        ]

    async def _generate_internal(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run the model, record statistics and billing, and package the result.

        Args:
            input_data: Model input payload (see ``generate_sticker``).

        Returns:
            Dict with ``urls`` (all URLs), ``url`` (first URL or None), ``format``,
            ``width``, ``height``, ``count``, ``cost_usd`` and ``metadata``.

        Raises:
            Any exception raised by the Replicate call or usage tracking is
            logged and re-raised unchanged.
        """
        try:
            logger.info(f"Starting sticker generation with prompt: {input_data.get('prompt', '')[:50]}...")

            # Call Replicate API and convert its output to plain URL strings
            output = await replicate.async_run(self.model_path, input=input_data)
            urls = self._extract_urls(output)

            # Update statistics
            self.total_generation_count += len(urls)

            # Calculate cost
            cost = self._calculate_cost(len(urls))

            # Track billing information
            await self._track_usage(
                service_type="image_generation",
                operation="sticker_generation",
                input_tokens=0,
                output_tokens=0,
                input_units=1,  # Input prompt
                output_units=len(urls),  # Generated stickers count
                metadata={
                    "model": self.model_name,
                    "prompt": input_data.get("prompt", "")[:100],
                    "generation_type": "sticker",
                    "image_count": len(urls),
                    "cost_usd": cost
                }
            )

            result = {
                "urls": urls,
                "url": urls[0] if urls else None,
                "format": input_data.get("output_format", "webp"),
                "width": input_data.get("width", 1152),
                "height": input_data.get("height", 1152),
                "count": len(urls),
                "cost_usd": cost,
                "metadata": {
                    "model": self.model_name,
                    "input": input_data,
                    "generation_count": len(urls)
                }
            }

            logger.info(f"Sticker generation completed: {len(urls)} stickers, cost: ${cost:.6f}")
            return result

        except Exception as e:
            logger.error(f"Sticker generation failed: {e}")
            raise

    def _calculate_cost(self, image_count: int) -> float:
        """Return the USD cost of ``image_count`` generations."""
        return image_count * self.COST_PER_GENERATION

    def get_generation_stats(self) -> Dict[str, Any]:
        """Return cumulative generation count and cost for this instance."""
        return {
            "total_generation_count": self.total_generation_count,
            "total_cost_usd": self._calculate_cost(self.total_generation_count),
            "cost_per_generation": self.COST_PER_GENERATION,
            "model": self.model_name
        }

    def get_model_info(self) -> Dict[str, Any]:
        """Return static capability information about the sticker model."""
        return {
            "name": self.model_name,
            "type": "sticker_generation",
            "cost_per_generation": self.COST_PER_GENERATION,
            "supports_negative_prompt": True,
            "max_width": 1152,
            "max_height": 1152,
            "output_formats": ["webp", "jpg", "png"]
        }

    async def load(self) -> None:
        """
        Verify the service is usable.

        Raises:
            ValueError: If the API token is missing.
        """
        if not self.api_token:
            raise ValueError("Missing Replicate API token")
        logger.info(f"Replicate Sticker Maker service ready with model: {self.model_name}")

    async def unload(self) -> None:
        """Unload service (only logs; no resources are held)."""
        logger.info(f"Unloading Replicate Sticker Maker service: {self.model_name}")

    async def close(self):
        """Close service by delegating to ``unload``."""
        await self.unload()

    async def generate_image(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        width: int = 512,
        height: int = 512,
        num_inference_steps: int = 17,
        guidance_scale: float = 7.5,
        seed: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Generate a single sticker (BaseImageGenService interface).

        ``guidance_scale`` and ``seed`` are accepted for interface compatibility
        but are not forwarded to the model.
        """
        return await self.generate_sticker(
            prompt=prompt,
            steps=num_inference_steps,
            width=width,
            height=height,
            negative_prompt=negative_prompt or ""
        )

    async def generate_images(
        self,
        prompt: str,
        num_images: int = 1,
        negative_prompt: Optional[str] = None,
        width: int = 512,
        height: int = 512,
        num_inference_steps: int = 17,
        guidance_scale: float = 7.5,
        seed: Optional[int] = None
    ) -> list[Dict[str, Any]]:
        """
        Generate ``num_images`` stickers, one sequential API call per sticker.

        Returns one result dict per call. ``guidance_scale`` and ``seed`` are
        accepted for interface compatibility but are not forwarded to the model.
        """
        results = []
        for _ in range(num_images):
            result = await self.generate_sticker(
                prompt=prompt,
                steps=num_inference_steps,
                width=width,
                height=height,
                negative_prompt=negative_prompt or "",
                number_of_images=1
            )
            results.append(result)
        return results

    async def image_to_image(
        self,
        prompt: str,
        init_image,
        strength: float = 0.8,
        negative_prompt: Optional[str] = None,
        num_inference_steps: int = 17,
        guidance_scale: float = 7.5,
        seed: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Not supported by sticker maker.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Sticker maker does not support image-to-image generation")

    def get_supported_sizes(self) -> list[Dict[str, int]]:
        """Return the square image sizes advertised for this model."""
        return [
            {"width": 1152, "height": 1152},
            {"width": 1024, "height": 1024},
            {"width": 768, "height": 768},
        ]
|
@@ -0,0 +1,297 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
|
4
|
+
"""
|
5
|
+
Test Image Generation Client using ISA Model Client
|
6
|
+
Tests the four specialized image generation services through the unified client:
|
7
|
+
1. FLUX Schnell (text-to-image)
|
8
|
+
2. FLUX Kontext Pro (image-to-image)
|
9
|
+
3. Sticker Maker
|
10
|
+
4. Face Swap
|
11
|
+
"""
|
12
|
+
|
13
|
+
import asyncio
|
14
|
+
import logging
|
15
|
+
from typing import Dict, Any
|
16
|
+
|
17
|
+
from isa_model.client import ISAModelClient
|
18
|
+
|
19
|
+
# Set up logging
|
20
|
+
# Configure root logging at INFO level (executed at import time) so test progress is printed.
logging.basicConfig(level=logging.INFO)
|
21
|
+
# Module-level logger used by the tester class and main().
logger = logging.getLogger(__name__)
|
22
|
+
|
23
|
+
class ImageGenerationTester:
    """
    Test client for image generation services using ISA Model Client.

    Each ``test_*`` coroutine calls ``ISAModelClient.invoke`` for one service
    and returns ``{"status": "success", "result": ..., "metadata": ...}`` or
    ``{"status": "error", "error": ...}``; it never raises.
    """

    def __init__(self):
        self.client = ISAModelClient()

        # Test configurations for each service
        self.test_configs = {
            "flux_schnell": {
                "model": "flux-schnell",
                "provider": "replicate",
                "task": "generate",
                "prompt": "a cute cat sitting in a garden"
            },
            "flux_kontext": {
                "model": "flux-kontext-pro",
                "provider": "replicate",
                "task": "img2img",
                "prompt": "transform this into a futuristic cityscape",
                "init_image": "https://replicate.delivery/pbxt/Mb44XIUHkUrmyyH1OP5K1WmFN7SNN0eUSU16A8rBtuXe7eYV/cyberpunk_80s_example.png"
            },
            "sticker_maker": {
                "model": "sticker-maker",
                "provider": "replicate",
                "task": "generate",
                "prompt": "a cute cat"
            },
            "face_swap": {
                "model": "face-swap",
                "provider": "replicate",
                "task": "face_swap",
                "swap_image": "https://replicate.delivery/pbxt/Mb44Wp0W7Xfa1Pp91zcxDzSSQQz8GusUmXQXi3GGzRxDvoCI/0_1.webp",
                "target_image": "https://replicate.delivery/pbxt/Mb44XIUHkUrmyyH1OP5K1WmFN7SNN0eUSU16A8rBtuXe7eYV/cyberpunk_80s_example.png"
            }
        }

    @staticmethod
    def _report_outcome(result: Dict[str, Any], label: str, unit: str) -> Dict[str, Any]:
        """
        Log and convert an ``invoke`` result into the tester's outcome dict.

        Args:
            result: Dict returned by ``ISAModelClient.invoke``.
            label: Human-readable operation name used in log lines.
            unit: Plural noun for the generated items ("images", "stickers").

        Returns:
            ``{"status": "success", "result", "metadata"}`` when ``result["success"]``
            is truthy, otherwise ``{"status": "error", "error"}``.
        """
        if result.get("success"):
            response = result["result"]
            logger.info(f"{label} successful: {response.get('count', 0)} {unit}")
            # `or 0` keeps a present-but-None cost from crashing the float format
            logger.info(f"Cost: ${response.get('cost_usd') or 0:.6f}")
            logger.info(f"URL: {response.get('url', 'N/A')}")

            return {
                "status": "success",
                "result": response,
                "metadata": result.get("metadata", {})
            }

        error_msg = result.get("error", "Unknown error")
        logger.error(f"{label} failed: {error_msg}")
        return {"status": "error", "error": error_msg}

    async def test_flux_text_to_image(self) -> Dict[str, Any]:
        """Test FLUX Schnell text-to-image generation"""
        logger.info("Testing FLUX Schnell text-to-image...")

        try:
            config = self.test_configs["flux_schnell"]

            result = await self.client.invoke(
                input_data=config["prompt"],
                task=config["task"],
                service_type="image",
                model=config["model"],
                provider=config["provider"],
                width=1024,
                height=1024,
                num_inference_steps=4
            )
            return self._report_outcome(result, "FLUX generation", "images")

        except Exception as e:
            logger.error(f"FLUX generation failed with exception: {e}")
            return {"status": "error", "error": str(e)}

    async def test_flux_kontext_image_to_image(self) -> Dict[str, Any]:
        """Test FLUX Kontext Pro image-to-image generation"""
        logger.info("Testing FLUX Kontext Pro image-to-image...")

        try:
            config = self.test_configs["flux_kontext"]

            result = await self.client.invoke(
                input_data=config["prompt"],
                task=config["task"],
                service_type="image",
                model=config["model"],
                provider=config["provider"],
                init_image=config["init_image"],
                strength=0.8
            )
            return self._report_outcome(result, "FLUX Kontext generation", "images")

        except Exception as e:
            logger.error(f"FLUX Kontext generation failed with exception: {e}")
            return {"status": "error", "error": str(e)}

    async def test_sticker_generation(self) -> Dict[str, Any]:
        """Test sticker generation"""
        logger.info("Testing Sticker Maker...")

        try:
            config = self.test_configs["sticker_maker"]

            result = await self.client.invoke(
                input_data=config["prompt"],
                task=config["task"],
                service_type="image",
                model=config["model"],
                provider=config["provider"],
                steps=17,
                width=1152,
                height=1152,
                output_format="webp",
                output_quality=100
            )
            return self._report_outcome(result, "Sticker generation", "stickers")

        except Exception as e:
            logger.error(f"Sticker generation failed with exception: {e}")
            return {"status": "error", "error": str(e)}

    async def test_face_swap(self) -> Dict[str, Any]:
        """Test face swap"""
        logger.info("Testing Face Swap...")

        try:
            config = self.test_configs["face_swap"]

            result = await self.client.invoke(
                input_data=config["swap_image"],  # Use swap_image as input_data
                task=config["task"],
                service_type="image",
                model=config["model"],
                provider=config["provider"],
                target_image=config["target_image"],
                hair_source="target",
                user_gender="default",
                user_b_gender="default"
            )
            return self._report_outcome(result, "Face swap", "images")

        except Exception as e:
            logger.error(f"Face swap failed with exception: {e}")
            return {"status": "error", "error": str(e)}

    async def test_all_services(self) -> Dict[str, Dict[str, Any]]:
        """
        Run every service test sequentially and log a pass/fail summary.

        Returns:
            Mapping of test name to that test's outcome dict.
        """
        logger.info("Starting comprehensive image generation tests using ISA Model Client...")

        results = {}

        # Test each service
        tests = [
            ("flux_text_to_image", self.test_flux_text_to_image),
            ("flux_image_to_image", self.test_flux_kontext_image_to_image),
            ("sticker_generation", self.test_sticker_generation),
            ("face_swap", self.test_face_swap)
        ]

        for test_name, test_func in tests:
            logger.info(f"\n{'='*50}")
            logger.info(f"Running test: {test_name}")
            logger.info(f"{'='*50}")

            try:
                result = await test_func()
                results[test_name] = result

                if result.get("status") == "success":
                    logger.info(f"✅ {test_name} PASSED")
                else:
                    logger.error(f"❌ {test_name} FAILED: {result.get('error', 'Unknown error')}")

            except Exception as e:
                logger.error(f"❌ {test_name} FAILED with exception: {e}")
                results[test_name] = {"status": "error", "error": str(e)}

        # Summary
        logger.info(f"\n{'='*50}")
        logger.info("TEST SUMMARY")
        logger.info(f"{'='*50}")

        passed = sum(1 for r in results.values() if r.get("status") == "success")
        total = len(results)

        logger.info(f"Passed: {passed}/{total}")

        for test_name, result in results.items():
            status = "✅ PASS" if result.get("status") == "success" else "❌ FAIL"
            logger.info(f"{test_name}: {status}")

        return results

    async def get_service_health(self) -> Dict[str, Any]:
        """Return the client's health check result, or an error dict if it raises."""
        logger.info("Checking service health...")

        try:
            health = await self.client.health_check()
            return health
        except Exception as e:
            logger.error(f"Health check failed: {e}")
            return {"status": "error", "error": str(e)}
|
271
|
+
|
272
|
+
async def main():
    """
    Run the health check and all image-generation tests, then log the total cost.

    Returns:
        Mapping of test name to outcome dict, as returned by
        ``ImageGenerationTester.test_all_services``.
    """
    tester = ImageGenerationTester()

    # Get service health
    logger.info("Checking service health...")
    health = await tester.get_service_health()
    logger.info(f"Service health: {health}")

    # Run all tests
    results = await tester.test_all_services()

    # Sum the cost of successful runs; a missing/None `result` or `cost_usd`
    # counts as zero instead of crashing the summary.
    total_cost = sum(
        ((outcome.get("result") or {}).get("cost_usd") or 0.0
         for outcome in results.values()
         if outcome.get("status") == "success"),
        0.0,
    )

    logger.info(f"\nTotal cost for all tests: ${total_cost:.6f}")

    return results
|
294
|
+
|
295
|
+
if __name__ == "__main__":
    # Script entry point: drive the async test suite to completion and keep
    # the per-test outcomes in a module-level name for interactive inspection.
    results = asyncio.run(main())
|
@@ -28,6 +28,7 @@ class BaseLLMService(BaseService):
|
|
28
28
|
self,
|
29
29
|
input_data: Union[str, List[Dict[str, str]], Any],
|
30
30
|
task: Optional[str] = None,
|
31
|
+
show_reasoning: bool = False,
|
31
32
|
**kwargs
|
32
33
|
) -> Dict[str, Any]:
|
33
34
|
"""
|
@@ -48,7 +49,7 @@ class BaseLLMService(BaseService):
|
|
48
49
|
|
49
50
|
# ==================== 对话类任务 ====================
|
50
51
|
if task == "chat":
|
51
|
-
return await self.chat(input_data, kwargs.get("max_tokens", self.max_tokens))
|
52
|
+
return await self.chat(input_data, kwargs.get("max_tokens", self.max_tokens), show_reasoning=show_reasoning)
|
52
53
|
elif task == "complete":
|
53
54
|
return await self.complete_text(input_data, kwargs.get("max_tokens", self.max_tokens))
|
54
55
|
elif task == "instruct":
|
@@ -62,7 +63,10 @@ class BaseLLMService(BaseService):
|
|
62
63
|
elif task == "summarize":
|
63
64
|
return await self.summarize_text(input_data, kwargs.get("max_length"), kwargs.get("style"))
|
64
65
|
elif task == "translate":
|
65
|
-
|
66
|
+
target_language = kwargs.get("target_language")
|
67
|
+
if not target_language:
|
68
|
+
raise ValueError("target_language is required for translate task")
|
69
|
+
return await self.translate_text(input_data, target_language, kwargs.get("source_language"))
|
66
70
|
|
67
71
|
# ==================== 分析类任务 ====================
|
68
72
|
elif task == "analyze":
|
@@ -91,12 +95,17 @@ class BaseLLMService(BaseService):
|
|
91
95
|
return await self.solve_problem(input_data, kwargs.get("problem_type"))
|
92
96
|
elif task == "plan":
|
93
97
|
return await self.create_plan(input_data, kwargs.get("plan_type"))
|
98
|
+
elif task == "deep_research":
|
99
|
+
return await self.deep_research(input_data, kwargs.get("research_type"), kwargs.get("search_enabled", True))
|
94
100
|
|
95
101
|
# ==================== 工具调用类任务 ====================
|
96
102
|
elif task == "tool_call":
|
97
103
|
return await self.call_tools(input_data, kwargs.get("available_tools"))
|
98
104
|
elif task == "function_call":
|
99
|
-
|
105
|
+
function_name = kwargs.get("function_name")
|
106
|
+
if not function_name:
|
107
|
+
raise ValueError("function_name is required for function_call task")
|
108
|
+
return await self.call_function(input_data, function_name, kwargs.get("parameters"))
|
100
109
|
|
101
110
|
else:
|
102
111
|
raise NotImplementedError(f"{self.__class__.__name__} does not support task: {task}")
|
@@ -106,7 +115,8 @@ class BaseLLMService(BaseService):
|
|
106
115
|
async def chat(
|
107
116
|
self,
|
108
117
|
input_data: Union[str, List[Dict[str, str]], Any],
|
109
|
-
max_tokens: Optional[int] = None
|
118
|
+
max_tokens: Optional[int] = None,
|
119
|
+
show_reasoning: bool = False
|
110
120
|
) -> Dict[str, Any]:
|
111
121
|
"""
|
112
122
|
对话聊天 - Provider必须实现
|
@@ -114,6 +124,7 @@ class BaseLLMService(BaseService):
|
|
114
124
|
Args:
|
115
125
|
input_data: 输入消息
|
116
126
|
max_tokens: 最大生成token数
|
127
|
+
show_reasoning: 是否显示推理过程
|
117
128
|
|
118
129
|
Returns:
|
119
130
|
Dict containing chat response
|
@@ -303,6 +314,17 @@ class BaseLLMService(BaseService):
|
|
303
314
|
"""
|
304
315
|
raise NotImplementedError(f"{self.__class__.__name__} does not support create_plan task")
|
305
316
|
|
317
|
+
async def deep_research(
    self,
    input_data: Union[str, Any],
    research_type: Optional[str] = None,
    search_enabled: bool = True
) -> Dict[str, Any]:
    """
    Deep research - a task intended for O-series models, supporting web search and in-depth analysis.

    The base implementation does not support this task; providers that do must override it.

    Raises:
        NotImplementedError: Always, naming the concrete service class.
    """
    service_name = type(self).__name__
    raise NotImplementedError(f"{service_name} does not support deep_research task")
|
327
|
+
|
306
328
|
# ==================== 工具调用类方法 ====================
|
307
329
|
|
308
330
|
async def call_tools(
|
@@ -354,7 +376,7 @@ class BaseLLMService(BaseService):
|
|
354
376
|
"""使用适配器管理器转换消息格式"""
|
355
377
|
return self.adapter_manager.convert_messages(input_data)
|
356
378
|
|
357
|
-
def _format_response(self, response: str, original_input: Any) -> Union[str, Any]:
|
379
|
+
def _format_response(self, response: Union[str, Any], original_input: Any) -> Union[str, Any]:
|
358
380
|
"""使用适配器管理器格式化响应"""
|
359
381
|
return self.adapter_manager.format_response(response, original_input)
|
360
382
|
|
@@ -379,7 +401,7 @@ class BaseLLMService(BaseService):
|
|
379
401
|
pass
|
380
402
|
|
381
403
|
@abstractmethod
|
382
|
-
async def ainvoke(self, input_data: Union[str, List[Dict[str, str]], Any]) -> Union[str, Any]:
|
404
|
+
async def ainvoke(self, input_data: Union[str, List[Dict[str, str]], Any], show_reasoning: bool = False) -> Union[str, Any]:
|
383
405
|
"""
|
384
406
|
Universal async invocation method that handles different input types
|
385
407
|
|
@@ -388,6 +410,7 @@ class BaseLLMService(BaseService):
|
|
388
410
|
- str: Simple text prompt
|
389
411
|
- list: Message history like [{"role": "user", "content": "hello"}]
|
390
412
|
- Any: LangChain message objects or other formats
|
413
|
+
show_reasoning: If True and model supports it, show reasoning process
|
391
414
|
|
392
415
|
Returns:
|
393
416
|
Model response (string for simple cases, object for complex cases)
|
@@ -527,6 +550,7 @@ class BaseLLMService(BaseService):
|
|
527
550
|
'reason_about': 'reason',
|
528
551
|
'solve_problem': 'solve',
|
529
552
|
'create_plan': 'plan',
|
553
|
+
'deep_research': 'deep_research',
|
530
554
|
# 工具调用类
|
531
555
|
'call_tools': 'tool_call',
|
532
556
|
'call_function': 'function_call'
|