isa-model 0.0.3__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. isa_model/__init__.py +1 -1
  2. isa_model/core/model_registry.py +273 -46
  3. isa_model/core/storage/hf_storage.py +419 -0
  4. isa_model/deployment/__init__.py +52 -0
  5. isa_model/deployment/core/__init__.py +34 -0
  6. isa_model/deployment/core/deployment_config.py +356 -0
  7. isa_model/deployment/core/deployment_manager.py +549 -0
  8. isa_model/deployment/core/isa_deployment_service.py +401 -0
  9. isa_model/eval/factory.py +381 -140
  10. isa_model/inference/ai_factory.py +142 -240
  11. isa_model/inference/providers/ml_provider.py +50 -0
  12. isa_model/inference/services/audio/openai_tts_service.py +104 -3
  13. isa_model/inference/services/embedding/base_embed_service.py +112 -0
  14. isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
  15. isa_model/inference/services/llm/__init__.py +2 -0
  16. isa_model/inference/services/llm/base_llm_service.py +111 -1
  17. isa_model/inference/services/llm/ollama_llm_service.py +234 -26
  18. isa_model/inference/services/llm/openai_llm_service.py +180 -26
  19. isa_model/inference/services/llm/triton_llm_service.py +481 -0
  20. isa_model/inference/services/ml/base_ml_service.py +78 -0
  21. isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
  22. isa_model/inference/services/vision/__init__.py +3 -3
  23. isa_model/inference/services/vision/base_image_gen_service.py +161 -0
  24. isa_model/inference/services/vision/base_vision_service.py +177 -0
  25. isa_model/inference/services/vision/ollama_vision_service.py +143 -17
  26. isa_model/inference/services/vision/replicate_image_gen_service.py +139 -7
  27. isa_model/training/__init__.py +62 -32
  28. isa_model/training/cloud/__init__.py +22 -0
  29. isa_model/training/cloud/job_orchestrator.py +402 -0
  30. isa_model/training/cloud/runpod_trainer.py +454 -0
  31. isa_model/training/cloud/storage_manager.py +482 -0
  32. isa_model/training/core/__init__.py +23 -0
  33. isa_model/training/core/config.py +181 -0
  34. isa_model/training/core/dataset.py +222 -0
  35. isa_model/training/core/trainer.py +720 -0
  36. isa_model/training/core/utils.py +213 -0
  37. isa_model/training/factory.py +229 -198
  38. isa_model-0.0.8.dist-info/METADATA +465 -0
  39. isa_model-0.0.8.dist-info/RECORD +86 -0
  40. isa_model/core/model_router.py +0 -226
  41. isa_model/core/model_version.py +0 -0
  42. isa_model/core/resource_manager.py +0 -202
  43. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
  44. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
  45. isa_model/training/engine/llama_factory/__init__.py +0 -39
  46. isa_model/training/engine/llama_factory/config.py +0 -115
  47. isa_model/training/engine/llama_factory/data_adapter.py +0 -284
  48. isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
  49. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
  50. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
  51. isa_model/training/engine/llama_factory/factory.py +0 -331
  52. isa_model/training/engine/llama_factory/rl.py +0 -254
  53. isa_model/training/engine/llama_factory/trainer.py +0 -171
  54. isa_model/training/image_model/configs/create_config.py +0 -37
  55. isa_model/training/image_model/configs/create_flux_config.py +0 -26
  56. isa_model/training/image_model/configs/create_lora_config.py +0 -21
  57. isa_model/training/image_model/prepare_massed_compute.py +0 -97
  58. isa_model/training/image_model/prepare_upload.py +0 -17
  59. isa_model/training/image_model/raw_data/create_captions.py +0 -16
  60. isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
  61. isa_model/training/image_model/raw_data/pre_processing.py +0 -200
  62. isa_model/training/image_model/train/train.py +0 -42
  63. isa_model/training/image_model/train/train_flux.py +0 -41
  64. isa_model/training/image_model/train/train_lora.py +0 -57
  65. isa_model/training/image_model/train_main.py +0 -25
  66. isa_model-0.0.3.dist-info/METADATA +0 -327
  67. isa_model-0.0.3.dist-info/RECORD +0 -92
  68. isa_model-0.0.3.dist-info/licenses/LICENSE +0 -21
  69. /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
  70. /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
  71. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
  72. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
  73. /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
  74. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
  75. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
  76. /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
  77. {isa_model-0.0.3.dist-info → isa_model-0.0.8.dist-info}/WHEEL +0 -0
  78. {isa_model-0.0.3.dist-info → isa_model-0.0.8.dist-info}/top_level.txt +0 -0
@@ -17,8 +17,8 @@ import replicate # 导入 replicate 库
17
17
  from PIL import Image
18
18
  from io import BytesIO
19
19
 
20
- # 调整 BaseService 的导入路径以匹配您的项目结构
21
- from isa_model.inference.services.base_service import BaseService
20
+ # 调整导入路径以使用正确的基类
21
+ from isa_model.inference.services.vision.base_image_gen_service import BaseImageGenService
22
22
  from isa_model.inference.providers.base_provider import BaseProvider
23
23
  from isa_model.inference.base import ModelType
24
24
 
@@ -26,7 +26,7 @@ from isa_model.inference.base import ModelType
26
26
  logging.basicConfig(level=logging.INFO)
27
27
  logger = logging.getLogger(__name__)
28
28
 
29
- class ReplicateVisionService(BaseService):
29
+ class ReplicateImageGenService(BaseImageGenService):
30
30
  """
31
31
  Replicate Vision服务,用于处理图像生成和分析。
32
32
  经过调整,使用原生异步调用并优化了文件处理。
@@ -76,9 +76,9 @@ class ReplicateVisionService(BaseService):
76
76
  raise
77
77
  return prepared_input, files_to_close
78
78
 
79
- async def generate_image(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
79
+ async def _generate_image_internal(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
80
80
  """
81
- 使用Replicate模型生成图像 (已优化为原生异步)
81
+ 内部方法:使用Replicate模型生成图像 (已优化为原生异步)
82
82
  """
83
83
  prepared_input, files_to_close = await self._prepare_input_files(input_data)
84
84
  try:
@@ -146,7 +146,7 @@ class ReplicateVisionService(BaseService):
146
146
 
147
147
  async def generate_and_save(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
148
148
  """生成图像并保存到本地 (此方法无需修改)"""
149
- result = await self.generate_image(input_data)
149
+ result = await self._generate_image_internal(input_data)
150
150
  saved_paths = []
151
151
  for i, url in enumerate(result["urls"]):
152
152
  timestamp = int(time.time())
@@ -182,4 +182,136 @@ class ReplicateVisionService(BaseService):
182
182
  logger.info(f"Replicate Vision服务已准备就绪,使用模型: {self.model_name}")
183
183
 
184
184
  async def unload(self) -> None:
185
- logger.info(f"卸载Replicate Vision服务: {self.model_name}")
185
+ logger.info(f"卸载Replicate Vision服务: {self.model_name}")
186
+
187
+ # 实现BaseImageGenService的抽象方法
188
+ async def generate_image(
189
+ self,
190
+ prompt: str,
191
+ negative_prompt: Optional[str] = None,
192
+ width: int = 512,
193
+ height: int = 512,
194
+ num_inference_steps: int = 20,
195
+ guidance_scale: float = 7.5,
196
+ seed: Optional[int] = None
197
+ ) -> Dict[str, Any]:
198
+ """Generate a single image from text prompt"""
199
+ input_data = {
200
+ "prompt": prompt,
201
+ "width": width,
202
+ "height": height,
203
+ "num_inference_steps": num_inference_steps,
204
+ "guidance_scale": guidance_scale
205
+ }
206
+
207
+ if negative_prompt:
208
+ input_data["negative_prompt"] = negative_prompt
209
+ if seed:
210
+ input_data["seed"] = seed
211
+
212
+ return await self._generate_image_internal(input_data)
213
+
214
+ async def generate_images(
215
+ self,
216
+ prompt: str,
217
+ num_images: int = 1,
218
+ negative_prompt: Optional[str] = None,
219
+ width: int = 512,
220
+ height: int = 512,
221
+ num_inference_steps: int = 20,
222
+ guidance_scale: float = 7.5,
223
+ seed: Optional[int] = None
224
+ ) -> List[Dict[str, Any]]:
225
+ """Generate multiple images from text prompt"""
226
+ results = []
227
+ for i in range(num_images):
228
+ current_seed = seed + i if seed else None
229
+ result = await self.generate_image(
230
+ prompt, negative_prompt, width, height,
231
+ num_inference_steps, guidance_scale, current_seed
232
+ )
233
+ results.append(result)
234
+ return results
235
+
236
+ async def generate_image_to_file(
237
+ self,
238
+ prompt: str,
239
+ output_path: str,
240
+ negative_prompt: Optional[str] = None,
241
+ width: int = 512,
242
+ height: int = 512,
243
+ num_inference_steps: int = 20,
244
+ guidance_scale: float = 7.5,
245
+ seed: Optional[int] = None
246
+ ) -> Dict[str, Any]:
247
+ """Generate image and save directly to file"""
248
+ result = await self.generate_image(
249
+ prompt, negative_prompt, width, height,
250
+ num_inference_steps, guidance_scale, seed
251
+ )
252
+
253
+ # Save the first generated image to the specified path
254
+ if result.get("urls"):
255
+ url = result["urls"][0]
256
+ url_str = str(url) if hasattr(url, "__str__") else url
257
+ await self._download_image(url_str, output_path)
258
+
259
+ return {
260
+ "file_path": output_path,
261
+ "width": width,
262
+ "height": height,
263
+ "seed": seed
264
+ }
265
+ else:
266
+ raise ValueError("No image generated")
267
+
268
+ async def image_to_image(
269
+ self,
270
+ prompt: str,
271
+ init_image: Union[str, Any],
272
+ strength: float = 0.8,
273
+ negative_prompt: Optional[str] = None,
274
+ num_inference_steps: int = 20,
275
+ guidance_scale: float = 7.5,
276
+ seed: Optional[int] = None
277
+ ) -> Dict[str, Any]:
278
+ """Generate image based on existing image and prompt"""
279
+ input_data = {
280
+ "prompt": prompt,
281
+ "image": init_image,
282
+ "strength": strength,
283
+ "num_inference_steps": num_inference_steps,
284
+ "guidance_scale": guidance_scale
285
+ }
286
+
287
+ if negative_prompt:
288
+ input_data["negative_prompt"] = negative_prompt
289
+ if seed:
290
+ input_data["seed"] = seed
291
+
292
+ return await self._generate_image_internal(input_data)
293
+
294
+ def get_supported_sizes(self) -> List[Dict[str, int]]:
295
+ """Get list of supported image dimensions"""
296
+ return [
297
+ {"width": 512, "height": 512},
298
+ {"width": 768, "height": 768},
299
+ {"width": 1024, "height": 1024},
300
+ {"width": 768, "height": 1344},
301
+ {"width": 1344, "height": 768},
302
+ ]
303
+
304
+ def get_model_info(self) -> Dict[str, Any]:
305
+ """Get information about the image generation model"""
306
+ return {
307
+ "name": self.model_name,
308
+ "max_width": 1344,
309
+ "max_height": 1344,
310
+ "supports_negative_prompt": True,
311
+ "supports_img2img": True
312
+ }
313
+
314
+ async def close(self):
315
+ """Cleanup resources"""
316
+ await self.unload()
317
+
@@ -1,44 +1,74 @@
1
1
  """
2
- ISA Model Training Framework
2
+ ISA Model Training Module
3
3
 
4
- This module provides unified interfaces for training various types of AI models:
5
- - LLM training with LlamaFactory
6
- - Image model training with Flux/LoRA
7
- - Model evaluation and benchmarking
4
+ Provides unified training capabilities for AI models including:
5
+ - Local training with SFT (Supervised Fine-Tuning)
6
+ - Cloud training on RunPod
7
+ - Model evaluation and management
8
+ - HuggingFace integration
8
9
 
9
- Usage:
10
- from isa_model.training import TrainingFactory
10
+ Example usage:
11
+ ```python
12
+ from isa_model.training import TrainingFactory, train_gemma
11
13
 
12
- # Create training factory
13
- factory = TrainingFactory()
14
+ # Quick Gemma training
15
+ model_path = train_gemma(
16
+ dataset_path="tatsu-lab/alpaca",
17
+ model_size="4b",
18
+ num_epochs=3
19
+ )
14
20
 
15
- # Fine-tune Gemma 3:4B
16
- model_path = factory.finetune_llm(
17
- model_name="gemma:4b",
18
- dataset_path="path/to/data.json",
19
- training_type="sft"
21
+ # Advanced training with custom configuration
22
+ factory = TrainingFactory()
23
+ model_path = factory.train_model(
24
+ model_name="google/gemma-2-4b-it",
25
+ dataset_path="your-dataset.json",
26
+ use_lora=True,
27
+ batch_size=4,
28
+ num_epochs=3
20
29
  )
30
+ ```
21
31
  """
22
32
 
23
- from .factory import TrainingFactory, finetune_gemma
24
- from .engine.llama_factory import (
25
- LlamaFactory,
26
- LlamaFactoryConfig,
27
- SFTConfig,
28
- RLConfig,
29
- DPOConfig,
30
- TrainingStrategy,
31
- DatasetFormat
33
+ # Import the new clean factory
34
+ from .factory import TrainingFactory, train_gemma
35
+
36
+ # Import core components
37
+ from .core import (
38
+ TrainingConfig,
39
+ LoRAConfig,
40
+ DatasetConfig,
41
+ BaseTrainer,
42
+ SFTTrainer,
43
+ TrainingUtils,
44
+ DatasetManager
45
+ )
46
+
47
+ # Import cloud training components
48
+ from .cloud import (
49
+ RunPodConfig,
50
+ StorageConfig,
51
+ JobConfig,
52
+ TrainingJobOrchestrator
32
53
  )
33
54
 
34
55
  __all__ = [
35
- "TrainingFactory",
36
- "finetune_gemma",
37
- "LlamaFactory",
38
- "LlamaFactoryConfig",
39
- "SFTConfig",
40
- "RLConfig",
41
- "DPOConfig",
42
- "TrainingStrategy",
43
- "DatasetFormat"
56
+ # Main factory
57
+ 'TrainingFactory',
58
+ 'train_gemma',
59
+
60
+ # Core components
61
+ 'TrainingConfig',
62
+ 'LoRAConfig',
63
+ 'DatasetConfig',
64
+ 'BaseTrainer',
65
+ 'SFTTrainer',
66
+ 'TrainingUtils',
67
+ 'DatasetManager',
68
+
69
+ # Cloud components
70
+ 'RunPodConfig',
71
+ 'StorageConfig',
72
+ 'JobConfig',
73
+ 'TrainingJobOrchestrator'
44
74
  ]
@@ -0,0 +1,22 @@
1
+ """
2
+ Cloud Training Module for ISA Model Framework
3
+
4
+ This module provides cloud training capabilities including:
5
+ - RunPod integration for on-demand GPU training
6
+ - Remote storage management (S3, GCS, etc.)
7
+ - Training job orchestration and monitoring
8
+ - Automatic resource scaling and management
9
+ """
10
+
11
+ from .runpod_trainer import RunPodTrainer
12
+ from .storage_manager import CloudStorageManager
13
+ from .job_orchestrator import TrainingJobOrchestrator
14
+
15
+ # Import config classes - these are defined in each module that needs them
16
+ # from ..core.config import RunPodConfig, StorageConfig, JobConfig
17
+
18
+ __all__ = [
19
+ "RunPodTrainer",
20
+ "CloudStorageManager",
21
+ "TrainingJobOrchestrator"
22
+ ]