isa-model 0.3.9__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. isa_model/__init__.py +1 -1
  2. isa_model/client.py +732 -565
  3. isa_model/core/cache/redis_cache.py +401 -0
  4. isa_model/core/config/config_manager.py +53 -10
  5. isa_model/core/config.py +1 -1
  6. isa_model/core/database/__init__.py +1 -0
  7. isa_model/core/database/migrations.py +277 -0
  8. isa_model/core/database/supabase_client.py +123 -0
  9. isa_model/core/models/__init__.py +37 -0
  10. isa_model/core/models/model_billing_tracker.py +60 -88
  11. isa_model/core/models/model_manager.py +36 -18
  12. isa_model/core/models/model_repo.py +44 -38
  13. isa_model/core/models/model_statistics_tracker.py +234 -0
  14. isa_model/core/models/model_storage.py +0 -1
  15. isa_model/core/models/model_version_manager.py +959 -0
  16. isa_model/core/pricing_manager.py +2 -249
  17. isa_model/core/resilience/circuit_breaker.py +366 -0
  18. isa_model/core/security/secrets.py +358 -0
  19. isa_model/core/services/__init__.py +2 -4
  20. isa_model/core/services/intelligent_model_selector.py +101 -370
  21. isa_model/core/storage/hf_storage.py +1 -1
  22. isa_model/core/types.py +7 -0
  23. isa_model/deployment/cloud/modal/isa_audio_chatTTS_service.py +520 -0
  24. isa_model/deployment/cloud/modal/isa_audio_fish_service.py +0 -0
  25. isa_model/deployment/cloud/modal/isa_audio_openvoice_service.py +758 -0
  26. isa_model/deployment/cloud/modal/isa_audio_service_v2.py +1044 -0
  27. isa_model/deployment/cloud/modal/isa_embed_rerank_service.py +296 -0
  28. isa_model/deployment/cloud/modal/isa_video_hunyuan_service.py +423 -0
  29. isa_model/deployment/cloud/modal/isa_vision_ocr_service.py +519 -0
  30. isa_model/deployment/cloud/modal/isa_vision_qwen25_service.py +709 -0
  31. isa_model/deployment/cloud/modal/isa_vision_table_service.py +467 -323
  32. isa_model/deployment/cloud/modal/isa_vision_ui_service.py +607 -180
  33. isa_model/deployment/cloud/modal/isa_vision_ui_service_optimized.py +660 -0
  34. isa_model/deployment/core/deployment_manager.py +6 -4
  35. isa_model/deployment/services/auto_hf_modal_deployer.py +894 -0
  36. isa_model/eval/benchmarks/__init__.py +27 -0
  37. isa_model/eval/benchmarks/multimodal_datasets.py +460 -0
  38. isa_model/eval/benchmarks.py +244 -12
  39. isa_model/eval/evaluators/__init__.py +8 -2
  40. isa_model/eval/evaluators/audio_evaluator.py +727 -0
  41. isa_model/eval/evaluators/embedding_evaluator.py +742 -0
  42. isa_model/eval/evaluators/vision_evaluator.py +564 -0
  43. isa_model/eval/example_evaluation.py +395 -0
  44. isa_model/eval/factory.py +272 -5
  45. isa_model/eval/isa_benchmarks.py +700 -0
  46. isa_model/eval/isa_integration.py +582 -0
  47. isa_model/eval/metrics.py +159 -6
  48. isa_model/eval/tests/unit/test_basic.py +396 -0
  49. isa_model/inference/ai_factory.py +44 -8
  50. isa_model/inference/services/audio/__init__.py +21 -0
  51. isa_model/inference/services/audio/base_realtime_service.py +225 -0
  52. isa_model/inference/services/audio/isa_tts_service.py +0 -0
  53. isa_model/inference/services/audio/openai_realtime_service.py +320 -124
  54. isa_model/inference/services/audio/openai_stt_service.py +32 -6
  55. isa_model/inference/services/base_service.py +17 -1
  56. isa_model/inference/services/embedding/__init__.py +13 -0
  57. isa_model/inference/services/embedding/base_embed_service.py +111 -8
  58. isa_model/inference/services/embedding/isa_embed_service.py +305 -0
  59. isa_model/inference/services/embedding/openai_embed_service.py +2 -4
  60. isa_model/inference/services/embedding/tests/test_embedding.py +222 -0
  61. isa_model/inference/services/img/__init__.py +2 -2
  62. isa_model/inference/services/img/base_image_gen_service.py +24 -7
  63. isa_model/inference/services/img/replicate_image_gen_service.py +84 -422
  64. isa_model/inference/services/img/services/replicate_face_swap.py +193 -0
  65. isa_model/inference/services/img/services/replicate_flux.py +226 -0
  66. isa_model/inference/services/img/services/replicate_flux_kontext.py +219 -0
  67. isa_model/inference/services/img/services/replicate_sticker_maker.py +249 -0
  68. isa_model/inference/services/img/tests/test_img_client.py +297 -0
  69. isa_model/inference/services/llm/base_llm_service.py +30 -6
  70. isa_model/inference/services/llm/helpers/llm_adapter.py +63 -9
  71. isa_model/inference/services/llm/ollama_llm_service.py +2 -1
  72. isa_model/inference/services/llm/openai_llm_service.py +652 -55
  73. isa_model/inference/services/llm/yyds_llm_service.py +2 -1
  74. isa_model/inference/services/vision/__init__.py +5 -5
  75. isa_model/inference/services/vision/base_vision_service.py +118 -185
  76. isa_model/inference/services/vision/helpers/image_utils.py +11 -5
  77. isa_model/inference/services/vision/isa_vision_service.py +573 -0
  78. isa_model/inference/services/vision/tests/test_ocr_client.py +284 -0
  79. isa_model/serving/api/fastapi_server.py +88 -16
  80. isa_model/serving/api/middleware/auth.py +311 -0
  81. isa_model/serving/api/middleware/security.py +278 -0
  82. isa_model/serving/api/routes/analytics.py +486 -0
  83. isa_model/serving/api/routes/deployments.py +339 -0
  84. isa_model/serving/api/routes/evaluations.py +579 -0
  85. isa_model/serving/api/routes/logs.py +430 -0
  86. isa_model/serving/api/routes/settings.py +582 -0
  87. isa_model/serving/api/routes/unified.py +324 -165
  88. isa_model/serving/api/startup.py +304 -0
  89. isa_model/serving/modal_proxy_server.py +249 -0
  90. isa_model/training/__init__.py +100 -6
  91. isa_model/training/core/__init__.py +4 -1
  92. isa_model/training/examples/intelligent_training_example.py +281 -0
  93. isa_model/training/intelligent/__init__.py +25 -0
  94. isa_model/training/intelligent/decision_engine.py +643 -0
  95. isa_model/training/intelligent/intelligent_factory.py +888 -0
  96. isa_model/training/intelligent/knowledge_base.py +751 -0
  97. isa_model/training/intelligent/resource_optimizer.py +839 -0
  98. isa_model/training/intelligent/task_classifier.py +576 -0
  99. isa_model/training/storage/__init__.py +24 -0
  100. isa_model/training/storage/core_integration.py +439 -0
  101. isa_model/training/storage/training_repository.py +552 -0
  102. isa_model/training/storage/training_storage.py +628 -0
  103. {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/METADATA +13 -1
  104. isa_model-0.4.0.dist-info/RECORD +182 -0
  105. isa_model/deployment/cloud/modal/isa_vision_doc_service.py +0 -766
  106. isa_model/deployment/cloud/modal/register_models.py +0 -321
  107. isa_model/inference/adapter/unified_api.py +0 -248
  108. isa_model/inference/services/helpers/stacked_config.py +0 -148
  109. isa_model/inference/services/img/flux_professional_service.py +0 -603
  110. isa_model/inference/services/img/helpers/base_stacked_service.py +0 -274
  111. isa_model/inference/services/others/table_transformer_service.py +0 -61
  112. isa_model/inference/services/vision/doc_analysis_service.py +0 -640
  113. isa_model/inference/services/vision/helpers/base_stacked_service.py +0 -274
  114. isa_model/inference/services/vision/ui_analysis_service.py +0 -823
  115. isa_model/scripts/inference_tracker.py +0 -283
  116. isa_model/scripts/mlflow_manager.py +0 -379
  117. isa_model/scripts/model_registry.py +0 -465
  118. isa_model/scripts/register_models.py +0 -370
  119. isa_model/scripts/register_models_with_embeddings.py +0 -510
  120. isa_model/scripts/start_mlflow.py +0 -95
  121. isa_model/scripts/training_tracker.py +0 -257
  122. isa_model-0.3.9.dist-info/RECORD +0 -138
  123. {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/WHEEL +0 -0
  124. {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,193 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ Replicate Face Swap Service
6
+ Specialized service for face swapping using the easel/advanced-face-swap model
7
+ """
8
+
9
+ import os
10
+ import logging
11
+ from typing import Dict, Any, Union
12
+ import replicate
13
+
14
+ from ..base_image_gen_service import BaseImageGenService
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
class ReplicateFaceSwapService(BaseImageGenService):
    """
    Replicate Face Swap Service - $0.04 per generation
    Advanced face swapping with hair source control and gender options

    Wraps the ``easel/advanced-face-swap`` Replicate model. Only ``face_swap``
    is supported; the text-to-image / image-to-image entry points required by
    ``BaseImageGenService`` raise ``NotImplementedError``.
    """

    # Flat price per generated image in USD; single source for cost, stats and model info.
    COST_PER_GENERATION = 0.04

    def __init__(self, provider_name: str, model_name: str, **kwargs):
        """Read the Replicate API token from the provider config and prepare the service.

        Raises:
            ValueError: if neither ``api_key`` nor ``replicate_api_token`` is
                configured, or if any other initialization step fails.
        """
        super().__init__(provider_name, model_name, **kwargs)

        # Get configuration from centralized config manager
        provider_config = self.get_provider_config()

        try:
            self.api_token = provider_config.get("api_key") or provider_config.get("replicate_api_token")

            if not self.api_token:
                raise ValueError("Replicate API token not found in provider configuration")

            # Presumably picked up by replicate.async_run's default client.
            # NOTE(review): this is process-global; services configured with
            # different tokens overwrite each other.
            os.environ["REPLICATE_API_TOKEN"] = self.api_token

            # Replicate model identifier
            self.model_path = "easel/advanced-face-swap"

            # Number of images produced by this instance (drives get_generation_stats)
            self.total_generation_count = 0

            logger.info(f"Initialized ReplicateFaceSwapService with model '{self.model_name}'")

        except Exception as e:
            logger.error(f"Failed to initialize Replicate Face Swap client: {e}")
            raise ValueError(f"Failed to initialize Replicate Face Swap client: {e}") from e

    async def face_swap(
        self,
        swap_image: Union[str, Any],
        target_image: Union[str, Any],
        hair_source: str = "target",
        user_gender: str = "default",
        user_b_gender: str = "default"
    ) -> Dict[str, Any]:
        """Perform face swap between two images.

        All arguments are forwarded unchanged as model inputs of the same name.
        NOTE(review): accepted image types (URL string vs. file object) are
        decided by Replicate — confirm against the model's input schema.

        Args:
            swap_image: Sent as the model's ``swap_image`` input.
            target_image: Sent as the model's ``target_image`` input.
            hair_source: Sent as ``hair_source`` (see get_model_info for listed options).
            user_gender: Sent as ``user_gender``.
            user_b_gender: Sent as ``user_b_gender``.

        Returns:
            The result dict built by ``_generate_internal`` (``urls``, ``url``,
            ``format``, ``count``, ``cost_usd``, ``metadata``).
        """
        input_data = {
            "swap_image": swap_image,
            "target_image": target_image,
            "hair_source": hair_source,
            "user_gender": user_gender,
            "user_b_gender": user_b_gender
        }

        return await self._generate_internal(input_data)

    @staticmethod
    def _output_to_urls(output: Any) -> list[str]:
        """Normalize a ``replicate.async_run`` result into a list of URL strings.

        The output may be a single item or a list/tuple of items; items exposing
        a ``.url`` attribute (e.g. FileOutput objects) are converted via that
        attribute, everything else via ``str``. ``None`` entries are dropped so
        an empty result never yields the bogus URL ``"None"``.
        """
        if output is None:
            return []
        items = output if isinstance(output, (list, tuple)) else [output]
        return [
            str(item.url) if hasattr(item, "url") else str(item)
            for item in items
            if item is not None
        ]

    async def _generate_internal(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the model, record statistics/billing, and build the result dict.

        Raises:
            Exception: any error from the Replicate call or usage tracking is
                logged (with traceback) and re-raised unchanged.
        """
        try:
            logger.info("Starting face swap generation...")

            # Call Replicate API
            output = await replicate.async_run(self.model_path, input=input_data)
            urls = self._output_to_urls(output)

            # Update statistics
            self.total_generation_count += len(urls)

            # Calculate cost
            cost = self._calculate_cost(len(urls))

            # Track billing information
            await self._track_usage(
                service_type="image_generation",
                operation="face_swap",
                input_tokens=0,
                output_tokens=0,
                input_units=2,  # Two input images
                output_units=len(urls),  # Generated images count
                metadata={
                    "model": self.model_name,
                    "generation_type": "face_swap",
                    "image_count": len(urls),
                    "cost_usd": cost,
                    "hair_source": input_data.get("hair_source", "target")
                }
            )

            result = {
                "urls": urls,
                "url": urls[0] if urls else None,
                # NOTE(review): hard-coded default, not derived from the model output
                "format": "jpg",
                "count": len(urls),
                "cost_usd": cost,
                "metadata": {
                    "model": self.model_name,
                    "input": input_data,
                    "generation_count": len(urls)
                }
            }

            logger.info(f"Face swap completed: {len(urls)} images, cost: ${cost:.6f}")
            return result

        except Exception as e:
            logger.exception(f"Face swap failed: {e}")
            raise

    def _calculate_cost(self, image_count: int) -> float:
        """Calculate generation cost - $0.04 per generation"""
        return image_count * self.COST_PER_GENERATION

    def get_generation_stats(self) -> Dict[str, Any]:
        """Return cumulative image count and cost for this instance."""
        return {
            "total_generation_count": self.total_generation_count,
            "total_cost_usd": self._calculate_cost(self.total_generation_count),
            "cost_per_generation": self.COST_PER_GENERATION,
            "model": self.model_name
        }

    def get_model_info(self) -> Dict[str, Any]:
        """Return static capability/pricing information for this service."""
        return {
            "name": self.model_name,
            "type": "face_swap",
            "cost_per_generation": self.COST_PER_GENERATION,
            "supports_hair_source": True,
            "supports_gender_control": True,
            "hair_source_options": ["target", "swap"],
            "gender_options": ["default", "male", "female"]
        }

    async def load(self) -> None:
        """Check that an API token is present; nothing is loaded locally.

        Raises:
            ValueError: if the API token is missing.
        """
        if not self.api_token:
            raise ValueError("Missing Replicate API token")
        logger.info(f"Replicate Face Swap service ready with model: {self.model_name}")

    async def unload(self) -> None:
        """Log the unload; there are no local resources to release."""
        logger.info(f"Unloading Replicate Face Swap service: {self.model_name}")

    async def close(self):
        """Close the service (delegates to ``unload``)."""
        await self.unload()

    # Abstract method implementations (not supported by face swap)
    async def generate_image(self, prompt: str, negative_prompt=None, width: int = 512, height: int = 512, num_inference_steps: int = 20, guidance_scale: float = 7.5, seed=None) -> Dict[str, Any]:
        """Not supported - use face_swap instead"""
        raise NotImplementedError("Face swap requires two images - use face_swap method")

    async def generate_images(self, prompt: str, num_images: int = 1, negative_prompt=None, width: int = 512, height: int = 512, num_inference_steps: int = 20, guidance_scale: float = 7.5, seed=None) -> list[Dict[str, Any]]:
        """Not supported - use face_swap instead"""
        raise NotImplementedError("Face swap requires two images - use face_swap method")

    async def image_to_image(self, prompt: str, init_image, strength: float = 0.8, negative_prompt=None, num_inference_steps: int = 20, guidance_scale: float = 7.5, seed=None) -> Dict[str, Any]:
        """Not supported - use face_swap instead"""
        raise NotImplementedError("Face swap requires specific face swap method")

    def get_supported_sizes(self) -> list[Dict[str, int]]:
        """Get supported image sizes"""
        return [{"width": 512, "height": 512}, {"width": 768, "height": 768}, {"width": 1024, "height": 1024}]
@@ -0,0 +1,226 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ Replicate FLUX Schnell Service
6
+ Specialized service for text-to-image generation using FLUX Schnell model
7
+ """
8
+
9
+ import os
10
+ import logging
11
+ from typing import Dict, Any, Optional
12
+ import replicate
13
+
14
+ from ..base_image_gen_service import BaseImageGenService
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
class ReplicateFluxService(BaseImageGenService):
    """
    Replicate FLUX Schnell Service - $3 per 1000 images
    Ultra-fast text-to-image generation for rapid prototyping

    Wraps the ``black-forest-labs/flux-schnell`` Replicate model. Only
    text-to-image is supported; ``image_to_image`` raises NotImplementedError.
    """

    # Price in USD per 1000 generated images; single source for cost, stats and model info.
    COST_PER_1000_IMAGES = 3.0
    # Upper bound on denoising steps, as advertised by get_model_info.
    MAX_INFERENCE_STEPS = 4
    # Used whenever the requested width/height does not map onto a supported ratio.
    DEFAULT_ASPECT_RATIO = "1:1"

    def __init__(self, provider_name: str, model_name: str, **kwargs):
        """Read the Replicate API token from the provider config and prepare the service.

        Raises:
            ValueError: if neither ``api_key`` nor ``replicate_api_token`` is
                configured, or if any other initialization step fails.
        """
        super().__init__(provider_name, model_name, **kwargs)

        # Get configuration from centralized config manager
        provider_config = self.get_provider_config()

        try:
            self.api_token = provider_config.get("api_key") or provider_config.get("replicate_api_token")

            if not self.api_token:
                raise ValueError("Replicate API token not found in provider configuration")

            # Presumably picked up by replicate.async_run's default client.
            # NOTE(review): this is process-global; services configured with
            # different tokens overwrite each other.
            os.environ["REPLICATE_API_TOKEN"] = self.api_token

            # Replicate model identifier
            self.model_path = "black-forest-labs/flux-schnell"

            # Number of images produced by this instance (drives get_generation_stats)
            self.total_generation_count = 0

            logger.info(f"Initialized ReplicateFluxService with model '{self.model_name}'")

        except Exception as e:
            logger.error(f"Failed to initialize Replicate FLUX client: {e}")
            raise ValueError(f"Failed to initialize Replicate FLUX client: {e}") from e

    async def generate_image(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        width: int = 512,
        height: int = 512,
        num_inference_steps: int = 4,
        guidance_scale: float = 7.5,
        seed: Optional[int] = None
    ) -> Dict[str, Any]:
        """Generate one image from a text prompt using FLUX Schnell.

        Args:
            prompt: Text prompt.
            negative_prompt: Ignored; not sent to the model
                (get_model_info reports ``supports_negative_prompt: False``).
            width: Together with ``height``, selects the aspect ratio when
                width:height equals one of ``get_supported_aspect_ratios()``
                (e.g. 1024x576 -> "16:9"); otherwise "1:1" is used. Output
                resolution itself is controlled by the ``megapixels`` input.
            height: See ``width``.
            num_inference_steps: Clamped to ``[1, MAX_INFERENCE_STEPS]``.
            guidance_scale: Ignored; not sent to the model.
            seed: Forwarded to the model when not None (0 is a valid seed).

        Returns:
            The result dict built by ``_generate_internal``.
        """
        input_data = {
            "prompt": prompt,
            "go_fast": True,
            "megapixels": "1",
            "num_outputs": 1,
            "aspect_ratio": self._aspect_ratio_for(width, height),
            "output_format": "jpg",
            "output_quality": 90,
            "num_inference_steps": self._clamp_steps(num_inference_steps)
        }
        if seed is not None:
            input_data["seed"] = seed

        return await self._generate_internal(input_data)

    def _aspect_ratio_for(self, width: Any, height: Any) -> str:
        """Return the supported aspect-ratio string equal to width:height, else the default."""
        is_number = isinstance(width, (int, float)) and isinstance(height, (int, float))
        if not is_number or width <= 0 or height <= 0:
            return self.DEFAULT_ASPECT_RATIO
        for ratio in self.get_supported_aspect_ratios():
            ratio_w, ratio_h = (int(part) for part in ratio.split(":"))
            # Cross-multiplication avoids float division when comparing ratios.
            if width * ratio_h == height * ratio_w:
                return ratio
        return self.DEFAULT_ASPECT_RATIO

    def _clamp_steps(self, num_inference_steps: Optional[int]) -> int:
        """Clamp a step count into ``[1, MAX_INFERENCE_STEPS]``; None means the maximum."""
        if num_inference_steps is None:
            return self.MAX_INFERENCE_STEPS
        return max(1, min(int(num_inference_steps), self.MAX_INFERENCE_STEPS))

    @staticmethod
    def _output_to_urls(output: Any) -> list[str]:
        """Normalize a ``replicate.async_run`` result into a list of URL strings.

        The output may be a single item or a list/tuple of items; items exposing
        a ``.url`` attribute (e.g. FileOutput objects) are converted via that
        attribute, everything else via ``str``. ``None`` entries are dropped so
        an empty result never yields the bogus URL ``"None"``.
        """
        if output is None:
            return []
        items = output if isinstance(output, (list, tuple)) else [output]
        return [
            str(item.url) if hasattr(item, "url") else str(item)
            for item in items
            if item is not None
        ]

    async def _generate_internal(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the model, record statistics/billing, and build the result dict.

        Raises:
            Exception: any error from the Replicate call or usage tracking is
                logged (with traceback) and re-raised unchanged.
        """
        try:
            logger.info(f"Starting FLUX generation with prompt: {input_data.get('prompt', '')[:50]}...")

            # Call Replicate API
            output = await replicate.async_run(self.model_path, input=input_data)
            urls = self._output_to_urls(output)

            # Update statistics
            self.total_generation_count += len(urls)

            # Calculate cost
            cost = self._calculate_cost(len(urls))

            # Track billing information
            await self._track_usage(
                service_type="image_generation",
                operation="text_to_image",
                input_tokens=0,
                output_tokens=0,
                input_units=1,  # Input prompt
                output_units=len(urls),  # Generated images count
                metadata={
                    "model": self.model_name,
                    "prompt": input_data.get("prompt", "")[:100],
                    "generation_type": "t2i",
                    "image_count": len(urls),
                    "cost_usd": cost
                }
            )

            result = {
                "urls": urls,
                "url": urls[0] if urls else None,
                "format": input_data.get("output_format", "jpg"),
                "aspect_ratio": input_data.get("aspect_ratio", "1:1"),
                "count": len(urls),
                "cost_usd": cost,
                "metadata": {
                    "model": self.model_name,
                    "input": input_data,
                    "generation_count": len(urls)
                }
            }

            logger.info(f"FLUX generation completed: {len(urls)} images, cost: ${cost:.6f}")
            return result

        except Exception as e:
            logger.exception(f"FLUX generation failed: {e}")
            raise

    def _calculate_cost(self, image_count: int) -> float:
        """Calculate generation cost - $3 per 1000 images"""
        return (image_count / 1000) * self.COST_PER_1000_IMAGES

    def get_generation_stats(self) -> Dict[str, Any]:
        """Return cumulative image count and cost for this instance."""
        return {
            "total_generation_count": self.total_generation_count,
            "total_cost_usd": self._calculate_cost(self.total_generation_count),
            "cost_per_1000_images": self.COST_PER_1000_IMAGES,
            "model": self.model_name
        }

    def get_supported_aspect_ratios(self) -> list[str]:
        """Get supported aspect ratios"""
        return ["1:1", "16:9", "9:16", "4:3", "3:4", "21:9", "9:21"]

    def get_model_info(self) -> Dict[str, Any]:
        """Return static capability/pricing information for this service."""
        return {
            "name": self.model_name,
            "type": "text_to_image",
            "cost_per_1000_images": self.COST_PER_1000_IMAGES,
            "supports_negative_prompt": False,
            "max_inference_steps": self.MAX_INFERENCE_STEPS,
            "supported_formats": ["jpg", "png", "webp"],
            "supported_aspect_ratios": self.get_supported_aspect_ratios()
        }

    async def load(self) -> None:
        """Check that an API token is present; nothing is loaded locally.

        Raises:
            ValueError: if the API token is missing.
        """
        if not self.api_token:
            raise ValueError("Missing Replicate API token")
        logger.info(f"Replicate FLUX service ready with model: {self.model_name}")

    async def unload(self) -> None:
        """Log the unload; there are no local resources to release."""
        logger.info(f"Unloading Replicate FLUX service: {self.model_name}")

    async def close(self):
        """Close the service (delegates to ``unload``)."""
        await self.unload()

    async def generate_images(
        self,
        prompt: str,
        num_images: int = 1,
        negative_prompt: Optional[str] = None,
        width: int = 512,
        height: int = 512,
        num_inference_steps: int = 4,
        guidance_scale: float = 7.5,
        seed: Optional[int] = None
    ) -> list[Dict[str, Any]]:
        """Generate ``num_images`` images, one ``generate_image`` call each.

        When ``seed`` is given, image ``i`` uses ``seed + i`` so results are
        reproducible yet distinct; ``seed=0`` is honored. Returns one result
        dict per image, in order.
        """
        results = []
        for offset in range(num_images):
            image_seed = seed + offset if seed is not None else None
            result = await self.generate_image(
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                seed=image_seed
            )
            results.append(result)
        return results

    def get_supported_sizes(self) -> list[Dict[str, int]]:
        """Get supported image sizes"""
        return [
            {"width": 512, "height": 512},
            {"width": 768, "height": 768},
            {"width": 1024, "height": 1024},
        ]

    async def image_to_image(self, prompt: str, init_image, strength: float = 0.8, negative_prompt=None, num_inference_steps: int = 20, guidance_scale: float = 7.5, seed=None) -> Dict[str, Any]:
        """Not supported by FLUX Schnell - text-to-image only"""
        raise NotImplementedError("FLUX Schnell only supports text-to-image generation")
@@ -0,0 +1,219 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ Replicate FLUX Kontext Pro Service
6
+ Specialized service for image-to-image generation using FLUX Kontext Pro model
7
+ """
8
+
9
+ import os
10
+ import logging
11
+ from typing import Dict, Any, Union, Optional
12
+ import replicate
13
+
14
+ from ..base_image_gen_service import BaseImageGenService
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
class ReplicateFluxKontextService(BaseImageGenService):
    """
    Replicate FLUX Kontext Pro Service - $0.04 per image
    Advanced image-to-image generation with superior control and quality

    Wraps the ``black-forest-labs/flux-kontext-pro`` Replicate model. Only
    ``image_to_image`` is supported; the text-only entry points raise
    ``NotImplementedError``.
    """

    # Flat price per generated image in USD; single source for cost, stats and model info.
    COST_PER_IMAGE = 0.04

    def __init__(self, provider_name: str, model_name: str, **kwargs):
        """Read the Replicate API token from the provider config and prepare the service.

        Raises:
            ValueError: if neither ``api_key`` nor ``replicate_api_token`` is
                configured, or if any other initialization step fails.
        """
        super().__init__(provider_name, model_name, **kwargs)

        # Get configuration from centralized config manager
        provider_config = self.get_provider_config()

        try:
            self.api_token = provider_config.get("api_key") or provider_config.get("replicate_api_token")

            if not self.api_token:
                raise ValueError("Replicate API token not found in provider configuration")

            # Presumably picked up by replicate.async_run's default client.
            # NOTE(review): this is process-global; services configured with
            # different tokens overwrite each other.
            os.environ["REPLICATE_API_TOKEN"] = self.api_token

            # Replicate model identifier
            self.model_path = "black-forest-labs/flux-kontext-pro"

            # Number of images produced by this instance (drives get_generation_stats)
            self.total_generation_count = 0

            logger.info(f"Initialized ReplicateFluxKontextService with model '{self.model_name}'")

        except Exception as e:
            logger.error(f"Failed to initialize Replicate FLUX Kontext client: {e}")
            raise ValueError(f"Failed to initialize Replicate FLUX Kontext client: {e}") from e

    async def image_to_image(
        self,
        prompt: str,
        init_image: Union[str, Any],
        strength: float = 0.8,
        negative_prompt=None,
        num_inference_steps: int = 20,
        guidance_scale: float = 7.5,
        seed=None
    ) -> Dict[str, Any]:
        """Generate an image from an input image and a prompt using FLUX Kontext Pro.

        Args:
            prompt: Edit/generation instruction.
            init_image: Sent as the model's ``input_image``. NOTE(review):
                accepted types (URL string vs. file object) are decided by
                Replicate — confirm against the model's input schema.
            strength: Ignored; not sent to the model.
            negative_prompt: Ignored; not sent to the model
                (get_model_info reports ``supports_negative_prompt: False``).
            num_inference_steps: Ignored; not sent to the model.
            guidance_scale: Ignored; not sent to the model.
            seed: Forwarded to the model when not None (0 is a valid seed).

        Returns:
            The result dict built by ``_generate_internal``.
        """
        input_data = {
            "prompt": prompt,
            "input_image": init_image,
            "aspect_ratio": "match_input_image",
            "output_format": "jpg",
            "safety_tolerance": 2
        }
        if seed is not None:
            input_data["seed"] = seed

        return await self._generate_internal(input_data)

    @staticmethod
    def _output_to_urls(output: Any) -> list[str]:
        """Normalize a ``replicate.async_run`` result into a list of URL strings.

        The output may be a single item or a list/tuple of items; items exposing
        a ``.url`` attribute (e.g. FileOutput objects) are converted via that
        attribute, everything else via ``str``. ``None`` entries are dropped so
        an empty result never yields the bogus URL ``"None"``.
        """
        if output is None:
            return []
        items = output if isinstance(output, (list, tuple)) else [output]
        return [
            str(item.url) if hasattr(item, "url") else str(item)
            for item in items
            if item is not None
        ]

    async def _generate_internal(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the model, record statistics/billing, and build the result dict.

        Raises:
            Exception: any error from the Replicate call or usage tracking is
                logged (with traceback) and re-raised unchanged.
        """
        try:
            logger.info(f"Starting FLUX Kontext i2i with prompt: {input_data.get('prompt', '')[:50]}...")

            # Call Replicate API
            output = await replicate.async_run(self.model_path, input=input_data)
            urls = self._output_to_urls(output)

            # Update statistics
            self.total_generation_count += len(urls)

            # Calculate cost
            cost = self._calculate_cost(len(urls))

            # Track billing information
            await self._track_usage(
                service_type="image_generation",
                operation="image_to_image",
                input_tokens=0,
                output_tokens=0,
                input_units=1,  # Input image + prompt
                output_units=len(urls),  # Generated images count
                metadata={
                    "model": self.model_name,
                    "prompt": input_data.get("prompt", "")[:100],
                    "generation_type": "i2i",
                    "image_count": len(urls),
                    "cost_usd": cost
                }
            )

            result = {
                "urls": urls,
                "url": urls[0] if urls else None,
                "format": input_data.get("output_format", "jpg"),
                "aspect_ratio": input_data.get("aspect_ratio", "match_input_image"),
                "count": len(urls),
                "cost_usd": cost,
                "metadata": {
                    "model": self.model_name,
                    "input": input_data,
                    "generation_count": len(urls)
                }
            }

            logger.info(f"FLUX Kontext i2i completed: {len(urls)} images, cost: ${cost:.6f}")
            return result

        except Exception as e:
            logger.exception(f"FLUX Kontext i2i failed: {e}")
            raise

    def _calculate_cost(self, image_count: int) -> float:
        """Calculate generation cost - $0.04 per image"""
        return image_count * self.COST_PER_IMAGE

    def get_generation_stats(self) -> Dict[str, Any]:
        """Return cumulative image count and cost for this instance."""
        return {
            "total_generation_count": self.total_generation_count,
            "total_cost_usd": self._calculate_cost(self.total_generation_count),
            "cost_per_image": self.COST_PER_IMAGE,
            "model": self.model_name
        }

    def get_supported_aspect_ratios(self) -> list[str]:
        """Get supported aspect ratios"""
        return ["match_input_image", "1:1", "16:9", "9:16", "4:3", "3:4", "21:9", "9:21"]

    def get_model_info(self) -> Dict[str, Any]:
        """Return static capability/pricing information for this service."""
        return {
            "name": self.model_name,
            "type": "image_to_image",
            "cost_per_image": self.COST_PER_IMAGE,
            "supports_negative_prompt": False,
            "supports_img2img": True,
            "supported_formats": ["jpg", "png", "webp"],
            "supported_aspect_ratios": self.get_supported_aspect_ratios(),
            "safety_tolerance_range": [1, 2, 3, 4, 5]
        }

    async def load(self) -> None:
        """Check that an API token is present; nothing is loaded locally.

        Raises:
            ValueError: if the API token is missing.
        """
        if not self.api_token:
            raise ValueError("Missing Replicate API token")
        logger.info(f"Replicate FLUX Kontext service ready with model: {self.model_name}")

    async def unload(self) -> None:
        """Log the unload; there are no local resources to release."""
        logger.info(f"Unloading Replicate FLUX Kontext service: {self.model_name}")

    async def close(self):
        """Close the service (delegates to ``unload``)."""
        await self.unload()

    async def generate_image(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        width: int = 512,
        height: int = 512,
        num_inference_steps: int = 20,
        guidance_scale: float = 7.5,
        seed: Optional[int] = None
    ) -> Dict[str, Any]:
        """Not supported - use image_to_image instead"""
        raise NotImplementedError("FLUX Kontext Pro requires an input image - use image_to_image method")

    async def generate_images(
        self,
        prompt: str,
        num_images: int = 1,
        negative_prompt: Optional[str] = None,
        width: int = 512,
        height: int = 512,
        num_inference_steps: int = 20,
        guidance_scale: float = 7.5,
        seed: Optional[int] = None
    ) -> list[Dict[str, Any]]:
        """Not supported - use image_to_image instead"""
        raise NotImplementedError("FLUX Kontext Pro requires an input image - use image_to_image method")

    def get_supported_sizes(self) -> list[Dict[str, int]]:
        """Get supported image sizes"""
        return [
            {"width": 512, "height": 512},
            {"width": 768, "height": 768},
            {"width": 1024, "height": 1024},
        ]