isa-model 0.3.91__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (228)
  1. isa_model/client.py +1166 -584
  2. isa_model/core/cache/redis_cache.py +410 -0
  3. isa_model/core/config/config_manager.py +282 -12
  4. isa_model/core/config.py +91 -1
  5. isa_model/core/database/__init__.py +1 -0
  6. isa_model/core/database/direct_db_client.py +114 -0
  7. isa_model/core/database/migration_manager.py +563 -0
  8. isa_model/core/database/migrations.py +297 -0
  9. isa_model/core/database/supabase_client.py +258 -0
  10. isa_model/core/dependencies.py +316 -0
  11. isa_model/core/discovery/__init__.py +19 -0
  12. isa_model/core/discovery/consul_discovery.py +190 -0
  13. isa_model/core/logging/__init__.py +54 -0
  14. isa_model/core/logging/influx_logger.py +523 -0
  15. isa_model/core/logging/loki_logger.py +160 -0
  16. isa_model/core/models/__init__.py +46 -0
  17. isa_model/core/models/config_models.py +625 -0
  18. isa_model/core/models/deployment_billing_tracker.py +430 -0
  19. isa_model/core/models/model_billing_tracker.py +60 -88
  20. isa_model/core/models/model_manager.py +66 -25
  21. isa_model/core/models/model_metadata.py +690 -0
  22. isa_model/core/models/model_repo.py +217 -55
  23. isa_model/core/models/model_statistics_tracker.py +234 -0
  24. isa_model/core/models/model_storage.py +0 -1
  25. isa_model/core/models/model_version_manager.py +959 -0
  26. isa_model/core/models/system_models.py +857 -0
  27. isa_model/core/pricing_manager.py +2 -249
  28. isa_model/core/repositories/__init__.py +9 -0
  29. isa_model/core/repositories/config_repository.py +912 -0
  30. isa_model/core/resilience/circuit_breaker.py +366 -0
  31. isa_model/core/security/secrets.py +358 -0
  32. isa_model/core/services/__init__.py +2 -4
  33. isa_model/core/services/intelligent_model_selector.py +479 -370
  34. isa_model/core/storage/hf_storage.py +2 -2
  35. isa_model/core/types.py +8 -0
  36. isa_model/deployment/__init__.py +5 -48
  37. isa_model/deployment/core/__init__.py +2 -31
  38. isa_model/deployment/core/deployment_manager.py +1278 -368
  39. isa_model/deployment/local/__init__.py +31 -0
  40. isa_model/deployment/local/config.py +248 -0
  41. isa_model/deployment/local/gpu_gateway.py +607 -0
  42. isa_model/deployment/local/health_checker.py +428 -0
  43. isa_model/deployment/local/provider.py +586 -0
  44. isa_model/deployment/local/tensorrt_service.py +621 -0
  45. isa_model/deployment/local/transformers_service.py +644 -0
  46. isa_model/deployment/local/vllm_service.py +527 -0
  47. isa_model/deployment/modal/__init__.py +8 -0
  48. isa_model/deployment/modal/config.py +136 -0
  49. isa_model/deployment/modal/deployer.py +894 -0
  50. isa_model/deployment/modal/services/__init__.py +3 -0
  51. isa_model/deployment/modal/services/audio/__init__.py +1 -0
  52. isa_model/deployment/modal/services/audio/isa_audio_chatTTS_service.py +520 -0
  53. isa_model/deployment/modal/services/audio/isa_audio_openvoice_service.py +758 -0
  54. isa_model/deployment/modal/services/audio/isa_audio_service_v2.py +1044 -0
  55. isa_model/deployment/modal/services/embedding/__init__.py +1 -0
  56. isa_model/deployment/modal/services/embedding/isa_embed_rerank_service.py +296 -0
  57. isa_model/deployment/modal/services/llm/__init__.py +1 -0
  58. isa_model/deployment/modal/services/llm/isa_llm_service.py +424 -0
  59. isa_model/deployment/modal/services/video/__init__.py +1 -0
  60. isa_model/deployment/modal/services/video/isa_video_hunyuan_service.py +423 -0
  61. isa_model/deployment/modal/services/vision/__init__.py +1 -0
  62. isa_model/deployment/modal/services/vision/isa_vision_ocr_service.py +519 -0
  63. isa_model/deployment/modal/services/vision/isa_vision_qwen25_service.py +709 -0
  64. isa_model/deployment/modal/services/vision/isa_vision_table_service.py +676 -0
  65. isa_model/deployment/modal/services/vision/isa_vision_ui_service.py +833 -0
  66. isa_model/deployment/modal/services/vision/isa_vision_ui_service_optimized.py +660 -0
  67. isa_model/deployment/models/org-org-acme-corp-tenant-a-service-llm-20250825-225822/tenant-a-service_modal_service.py +48 -0
  68. isa_model/deployment/models/org-test-org-123-prefix-test-service-llm-20250825-225822/prefix-test-service_modal_service.py +48 -0
  69. isa_model/deployment/models/test-llm-service-llm-20250825-204442/test-llm-service_modal_service.py +48 -0
  70. isa_model/deployment/models/test-monitoring-gpt2-llm-20250825-212906/test-monitoring-gpt2_modal_service.py +48 -0
  71. isa_model/deployment/models/test-monitoring-gpt2-llm-20250825-213009/test-monitoring-gpt2_modal_service.py +48 -0
  72. isa_model/deployment/storage/__init__.py +5 -0
  73. isa_model/deployment/storage/deployment_repository.py +824 -0
  74. isa_model/deployment/triton/__init__.py +10 -0
  75. isa_model/deployment/triton/config.py +196 -0
  76. isa_model/deployment/triton/configs/__init__.py +1 -0
  77. isa_model/deployment/triton/provider.py +512 -0
  78. isa_model/deployment/triton/scripts/__init__.py +1 -0
  79. isa_model/deployment/triton/templates/__init__.py +1 -0
  80. isa_model/inference/__init__.py +47 -1
  81. isa_model/inference/ai_factory.py +179 -16
  82. isa_model/inference/legacy_services/__init__.py +21 -0
  83. isa_model/inference/legacy_services/model_evaluation.py +637 -0
  84. isa_model/inference/legacy_services/model_service.py +573 -0
  85. isa_model/inference/legacy_services/model_serving.py +717 -0
  86. isa_model/inference/legacy_services/model_training.py +561 -0
  87. isa_model/inference/models/__init__.py +21 -0
  88. isa_model/inference/models/inference_config.py +551 -0
  89. isa_model/inference/models/inference_record.py +675 -0
  90. isa_model/inference/models/performance_models.py +714 -0
  91. isa_model/inference/repositories/__init__.py +9 -0
  92. isa_model/inference/repositories/inference_repository.py +828 -0
  93. isa_model/inference/services/audio/__init__.py +21 -0
  94. isa_model/inference/services/audio/base_realtime_service.py +225 -0
  95. isa_model/inference/services/audio/base_stt_service.py +184 -11
  96. isa_model/inference/services/audio/isa_tts_service.py +0 -0
  97. isa_model/inference/services/audio/openai_realtime_service.py +320 -124
  98. isa_model/inference/services/audio/openai_stt_service.py +53 -11
  99. isa_model/inference/services/base_service.py +17 -1
  100. isa_model/inference/services/custom_model_manager.py +277 -0
  101. isa_model/inference/services/embedding/__init__.py +13 -0
  102. isa_model/inference/services/embedding/base_embed_service.py +111 -8
  103. isa_model/inference/services/embedding/isa_embed_service.py +305 -0
  104. isa_model/inference/services/embedding/ollama_embed_service.py +15 -3
  105. isa_model/inference/services/embedding/openai_embed_service.py +2 -4
  106. isa_model/inference/services/embedding/resilient_embed_service.py +285 -0
  107. isa_model/inference/services/embedding/tests/test_embedding.py +222 -0
  108. isa_model/inference/services/img/__init__.py +2 -2
  109. isa_model/inference/services/img/base_image_gen_service.py +24 -7
  110. isa_model/inference/services/img/replicate_image_gen_service.py +84 -422
  111. isa_model/inference/services/img/services/replicate_face_swap.py +193 -0
  112. isa_model/inference/services/img/services/replicate_flux.py +226 -0
  113. isa_model/inference/services/img/services/replicate_flux_kontext.py +219 -0
  114. isa_model/inference/services/img/services/replicate_sticker_maker.py +249 -0
  115. isa_model/inference/services/img/tests/test_img_client.py +297 -0
  116. isa_model/inference/services/llm/__init__.py +10 -2
  117. isa_model/inference/services/llm/base_llm_service.py +361 -26
  118. isa_model/inference/services/llm/cerebras_llm_service.py +628 -0
  119. isa_model/inference/services/llm/helpers/llm_adapter.py +71 -12
  120. isa_model/inference/services/llm/helpers/llm_prompts.py +342 -0
  121. isa_model/inference/services/llm/helpers/llm_utils.py +321 -23
  122. isa_model/inference/services/llm/huggingface_llm_service.py +581 -0
  123. isa_model/inference/services/llm/local_llm_service.py +747 -0
  124. isa_model/inference/services/llm/ollama_llm_service.py +11 -3
  125. isa_model/inference/services/llm/openai_llm_service.py +670 -56
  126. isa_model/inference/services/llm/yyds_llm_service.py +10 -3
  127. isa_model/inference/services/vision/__init__.py +27 -6
  128. isa_model/inference/services/vision/base_vision_service.py +118 -185
  129. isa_model/inference/services/vision/blip_vision_service.py +359 -0
  130. isa_model/inference/services/vision/helpers/image_utils.py +19 -10
  131. isa_model/inference/services/vision/isa_vision_service.py +634 -0
  132. isa_model/inference/services/vision/openai_vision_service.py +19 -10
  133. isa_model/inference/services/vision/tests/test_ocr_client.py +284 -0
  134. isa_model/inference/services/vision/vgg16_vision_service.py +257 -0
  135. isa_model/serving/api/cache_manager.py +245 -0
  136. isa_model/serving/api/dependencies/__init__.py +1 -0
  137. isa_model/serving/api/dependencies/auth.py +194 -0
  138. isa_model/serving/api/dependencies/database.py +139 -0
  139. isa_model/serving/api/error_handlers.py +284 -0
  140. isa_model/serving/api/fastapi_server.py +240 -18
  141. isa_model/serving/api/middleware/auth.py +317 -0
  142. isa_model/serving/api/middleware/security.py +268 -0
  143. isa_model/serving/api/middleware/tenant_context.py +414 -0
  144. isa_model/serving/api/routes/analytics.py +489 -0
  145. isa_model/serving/api/routes/config.py +645 -0
  146. isa_model/serving/api/routes/deployment_billing.py +315 -0
  147. isa_model/serving/api/routes/deployments.py +475 -0
  148. isa_model/serving/api/routes/gpu_gateway.py +440 -0
  149. isa_model/serving/api/routes/health.py +32 -12
  150. isa_model/serving/api/routes/inference_monitoring.py +486 -0
  151. isa_model/serving/api/routes/local_deployments.py +448 -0
  152. isa_model/serving/api/routes/logs.py +430 -0
  153. isa_model/serving/api/routes/settings.py +582 -0
  154. isa_model/serving/api/routes/tenants.py +575 -0
  155. isa_model/serving/api/routes/unified.py +992 -171
  156. isa_model/serving/api/routes/webhooks.py +479 -0
  157. isa_model/serving/api/startup.py +318 -0
  158. isa_model/serving/modal_proxy_server.py +249 -0
  159. isa_model/utils/gpu_utils.py +311 -0
  160. {isa_model-0.3.91.dist-info → isa_model-0.4.3.dist-info}/METADATA +76 -22
  161. isa_model-0.4.3.dist-info/RECORD +193 -0
  162. isa_model/deployment/cloud/__init__.py +0 -9
  163. isa_model/deployment/cloud/modal/__init__.py +0 -10
  164. isa_model/deployment/cloud/modal/isa_vision_doc_service.py +0 -766
  165. isa_model/deployment/cloud/modal/isa_vision_table_service.py +0 -532
  166. isa_model/deployment/cloud/modal/isa_vision_ui_service.py +0 -406
  167. isa_model/deployment/cloud/modal/register_models.py +0 -321
  168. isa_model/deployment/core/deployment_config.py +0 -356
  169. isa_model/deployment/core/isa_deployment_service.py +0 -401
  170. isa_model/deployment/gpu_int8_ds8/app/server.py +0 -66
  171. isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +0 -43
  172. isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +0 -35
  173. isa_model/deployment/runtime/deployed_service.py +0 -338
  174. isa_model/deployment/services/__init__.py +0 -9
  175. isa_model/deployment/services/auto_deploy_vision_service.py +0 -538
  176. isa_model/deployment/services/model_service.py +0 -332
  177. isa_model/deployment/services/service_monitor.py +0 -356
  178. isa_model/deployment/services/service_registry.py +0 -527
  179. isa_model/eval/__init__.py +0 -92
  180. isa_model/eval/benchmarks.py +0 -469
  181. isa_model/eval/config/__init__.py +0 -10
  182. isa_model/eval/config/evaluation_config.py +0 -108
  183. isa_model/eval/evaluators/__init__.py +0 -18
  184. isa_model/eval/evaluators/base_evaluator.py +0 -503
  185. isa_model/eval/evaluators/llm_evaluator.py +0 -472
  186. isa_model/eval/factory.py +0 -531
  187. isa_model/eval/infrastructure/__init__.py +0 -24
  188. isa_model/eval/infrastructure/experiment_tracker.py +0 -466
  189. isa_model/eval/metrics.py +0 -798
  190. isa_model/inference/adapter/unified_api.py +0 -248
  191. isa_model/inference/services/helpers/stacked_config.py +0 -148
  192. isa_model/inference/services/img/flux_professional_service.py +0 -603
  193. isa_model/inference/services/img/helpers/base_stacked_service.py +0 -274
  194. isa_model/inference/services/others/table_transformer_service.py +0 -61
  195. isa_model/inference/services/vision/doc_analysis_service.py +0 -640
  196. isa_model/inference/services/vision/helpers/base_stacked_service.py +0 -274
  197. isa_model/inference/services/vision/ui_analysis_service.py +0 -823
  198. isa_model/scripts/inference_tracker.py +0 -283
  199. isa_model/scripts/mlflow_manager.py +0 -379
  200. isa_model/scripts/model_registry.py +0 -465
  201. isa_model/scripts/register_models.py +0 -370
  202. isa_model/scripts/register_models_with_embeddings.py +0 -510
  203. isa_model/scripts/start_mlflow.py +0 -95
  204. isa_model/scripts/training_tracker.py +0 -257
  205. isa_model/training/__init__.py +0 -74
  206. isa_model/training/annotation/annotation_schema.py +0 -47
  207. isa_model/training/annotation/processors/annotation_processor.py +0 -126
  208. isa_model/training/annotation/storage/dataset_manager.py +0 -131
  209. isa_model/training/annotation/storage/dataset_schema.py +0 -44
  210. isa_model/training/annotation/tests/test_annotation_flow.py +0 -109
  211. isa_model/training/annotation/tests/test_minio copy.py +0 -113
  212. isa_model/training/annotation/tests/test_minio_upload.py +0 -43
  213. isa_model/training/annotation/views/annotation_controller.py +0 -158
  214. isa_model/training/cloud/__init__.py +0 -22
  215. isa_model/training/cloud/job_orchestrator.py +0 -402
  216. isa_model/training/cloud/runpod_trainer.py +0 -454
  217. isa_model/training/cloud/storage_manager.py +0 -482
  218. isa_model/training/core/__init__.py +0 -23
  219. isa_model/training/core/config.py +0 -181
  220. isa_model/training/core/dataset.py +0 -222
  221. isa_model/training/core/trainer.py +0 -720
  222. isa_model/training/core/utils.py +0 -213
  223. isa_model/training/factory.py +0 -424
  224. isa_model-0.3.91.dist-info/RECORD +0 -138
  225. /isa_model/{core/storage/minio_storage.py → deployment/modal/services/audio/isa_audio_fish_service.py} +0 -0
  226. /isa_model/deployment/{services → modal/services/vision}/simple_auto_deploy_vision_service.py +0 -0
  227. {isa_model-0.3.91.dist-info → isa_model-0.4.3.dist-info}/WHEEL +0 -0
  228. {isa_model-0.3.91.dist-info → isa_model-0.4.3.dist-info}/top_level.txt +0 -0
@@ -1,213 +0,0 @@
1
"""
Training utilities.

Stateless helpers used by the training pipeline: output-directory naming,
persisting/loading training arguments, reading model architecture info,
rough memory estimation, config validation and summary formatting.
"""

import datetime
import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

# Module-level logger shared by every helper in this module.
logger = logging.getLogger(__name__)
15
-
16
-
17
class TrainingUtils:
    """Stateless helper functions for training operations.

    Every method is a ``@staticmethod``; the class only serves as a namespace.
    """

    @staticmethod
    def generate_output_dir(model_name: str, training_type: str, base_dir: str = "training_outputs") -> str:
        """Build a timestamped output directory path for a training run.

        The directory is only named here, not created.

        Args:
            model_name: Model identifier. ``/`` and ``:`` are replaced with ``_``
                so the name forms a single path component.
            training_type: Training type tag (e.g. ``"sft"``) embedded in the name.
            base_dir: Parent directory of the run directory.

        Returns:
            ``<base_dir>/<safe_model_name>_<training_type>_<YYYYmmdd_HHMMSS>``.
        """
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        safe_model_name = model_name.replace("/", "_").replace(":", "_")
        return os.path.join(base_dir, f"{safe_model_name}_{training_type}_{timestamp}")

    @staticmethod
    def save_training_args(args: Dict[str, Any], output_dir: str) -> None:
        """Write ``args`` as indented JSON to ``<output_dir>/training_args.json``.

        Parent directories are created as needed. Values that are not JSON
        serializable are written via ``str()`` (``default=str``), so they do not
        round-trip to their original type through ``load_training_args``.
        """
        args_path = Path(output_dir) / "training_args.json"
        args_path.parent.mkdir(parents=True, exist_ok=True)

        with open(args_path, "w", encoding="utf-8") as f:
            json.dump(args, f, indent=2, default=str)

        logger.info(f"Training arguments saved to: {args_path}")

    @staticmethod
    def load_training_args(output_dir: str) -> Dict[str, Any]:
        """Load the JSON written by ``save_training_args``.

        Raises:
            FileNotFoundError: if ``<output_dir>/training_args.json`` is missing.
        """
        args_path = Path(output_dir) / "training_args.json"

        if not args_path.exists():
            raise FileNotFoundError(f"Training args not found: {args_path}")

        with open(args_path, "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def get_model_info(model_name: str) -> Dict[str, Any]:
        """Read basic architecture attributes from a model's config.

        Uses ``transformers.AutoConfig``. Attributes the config does not define
        are reported as ``None``.

        Returns:
            A dict with ``model_name``, ``model_type``, ``vocab_size``,
            ``hidden_size``, ``num_layers``, ``num_attention_heads`` and
            ``max_position_embeddings``. If the config cannot be loaded (including
            when ``transformers`` is not importable), returns
            ``{"model_name": ..., "error": ...}`` instead of raising.
        """
        try:
            from transformers import AutoConfig

            # SECURITY NOTE(review): trust_remote_code=True runs Python code that
            # ships with the model repository; only use it with trusted model IDs.
            config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

            return {
                "model_name": model_name,
                "model_type": config.model_type,
                "vocab_size": getattr(config, "vocab_size", None),
                "hidden_size": getattr(config, "hidden_size", None),
                "num_layers": getattr(config, "num_hidden_layers", None),
                "num_attention_heads": getattr(config, "num_attention_heads", None),
                "max_position_embeddings": getattr(config, "max_position_embeddings", None),
            }

        except Exception as e:
            logger.warning(f"Could not get model info for {model_name}: {e}")
            return {"model_name": model_name, "error": str(e)}

    @staticmethod
    def estimate_memory_usage(
        model_name: str,
        batch_size: int = 1,
        max_length: int = 1024,
        use_lora: bool = True
    ) -> Dict[str, Any]:
        """Return a rough, heuristic estimate of training memory requirements.

        Parameter count is approximated as ``12 * hidden^2 * layers +
        vocab * hidden``. Weights are assumed to be 2 bytes per parameter. A
        fixed multiplier (2x with LoRA, 4x for full fine-tuning) accounts for
        training overhead. When the model config is unavailable or lacks an
        attribute, defaults of hidden=4096, layers=32 and vocab=32000 are used.

        Returns:
            A dict with ``estimated_params_millions``, ``model_memory_gb``,
            ``total_training_memory_gb``, ``recommended_gpu`` and the echoed
            inputs, or ``{"error": ...}`` if the estimate fails.
        """
        try:
            model_info = TrainingUtils.get_model_info(model_name)

            # get_model_info() stores None for attributes missing from the
            # config, so dict.get(key, default) would return None and break the
            # arithmetic. ``or`` applies the default for missing keys and None.
            hidden_size = model_info.get("hidden_size") or 4096
            num_layers = model_info.get("num_layers") or 32
            vocab_size = model_info.get("vocab_size") or 32000

            # Estimated parameter count, in millions.
            param_count = (hidden_size * hidden_size * 12 * num_layers + vocab_size * hidden_size) / 1e6

            # Millions of params * 2 bytes = MB; / 1024 = GB (fp16 weights).
            model_memory_gb = param_count * 2 / 1024

            # Gradients / optimizer states / activations overhead multiplier.
            training_overhead = 2.0 if use_lora else 4.0

            # Activation memory driven by batch size and sequence length, in GB.
            sequence_memory = batch_size * max_length * hidden_size * 2 / (1024 ** 3)

            total_memory_gb = model_memory_gb * training_overhead + sequence_memory

            return {
                "estimated_params_millions": param_count,
                "model_memory_gb": model_memory_gb,
                "total_training_memory_gb": total_memory_gb,
                "recommended_gpu": TrainingUtils._recommend_gpu(total_memory_gb),
                "use_lora": use_lora,
                "batch_size": batch_size,
                "max_length": max_length
            }

        except Exception as e:
            logger.warning(f"Could not estimate memory usage: {e}")
            return {"error": str(e)}

    @staticmethod
    def _recommend_gpu(memory_gb: float) -> str:
        """Map an estimated memory requirement (GB) to a GPU recommendation label."""
        if memory_gb <= 8:
            return "RTX 3080/4070 (8-12GB)"
        elif memory_gb <= 16:
            return "RTX 4080/4090 (16GB)"
        elif memory_gb <= 24:
            return "RTX A6000/4090 (24GB)"
        elif memory_gb <= 40:
            return "A100 40GB"
        elif memory_gb <= 80:
            return "A100 80GB"
        else:
            return "Multiple A100 80GB (Multi-GPU required)"

    @staticmethod
    def validate_training_config(config: Dict[str, Any]) -> List[str]:
        """Validate a training configuration dict.

        A value explicitly set to ``None`` is treated the same as a missing value,
        so it is reported as an issue rather than raising ``TypeError``.

        Returns:
            A list of human-readable issues; empty when the config is valid.
        """
        issues: List[str] = []

        for field in ("model_name", "output_dir"):
            if field not in config:
                issues.append(f"Missing required field: {field}")

        if (config.get("batch_size") or 0) <= 0:
            issues.append("batch_size must be positive")

        lr = config.get("learning_rate") or 0
        if lr <= 0 or lr > 1:
            issues.append("learning_rate should be between 0 and 1")

        if (config.get("num_epochs") or 0) <= 0:
            issues.append("num_epochs must be positive")

        if config.get("use_lora", False):
            lora_rank = config.get("lora_rank")
            # Explicit None check (not ``or``) so an invalid rank of 0 is still flagged.
            if lora_rank is None:
                lora_rank = 8
            if lora_rank <= 0 or lora_rank > 256:
                issues.append("lora_rank should be between 1 and 256")

        return issues

    @staticmethod
    def format_training_summary(
        config: Dict[str, Any],
        model_info: Dict[str, Any],
        memory_estimate: Dict[str, Any]
    ) -> str:
        """Render a multi-line, human-readable training summary.

        Args:
            config: Training configuration dict (missing keys shown with defaults).
            model_info: Output of ``get_model_info``.
            memory_estimate: Output of ``estimate_memory_usage``.

        Returns:
            The summary text framed by ``=`` rulers.
        """
        summary = []
        summary.append("=" * 60)
        summary.append("TRAINING CONFIGURATION SUMMARY")
        summary.append("=" * 60)

        # Model information
        summary.append(f"Model: {config.get('model_name', 'Unknown')}")
        summary.append(f"Model Type: {model_info.get('model_type', 'Unknown')}")
        summary.append(f"Parameters: ~{memory_estimate.get('estimated_params_millions', 0):.1f}M")

        # Training configuration
        summary.append("\nTraining Configuration:")
        summary.append(f" Training Type: {config.get('training_type', 'sft')}")
        summary.append(f" Epochs: {config.get('num_epochs', 3)}")
        summary.append(f" Batch Size: {config.get('batch_size', 4)}")
        summary.append(f" Learning Rate: {config.get('learning_rate', 2e-5)}")
        summary.append(f" Max Length: {config.get('max_length', 1024)}")

        # LoRA configuration (shown unless use_lora is explicitly falsy)
        if config.get('use_lora', True):
            summary.append("\nLoRA Configuration:")
            summary.append(f" LoRA Rank: {config.get('lora_rank', 8)}")
            summary.append(f" LoRA Alpha: {config.get('lora_alpha', 16)}")
            summary.append(f" LoRA Dropout: {config.get('lora_dropout', 0.05)}")

        # Memory estimation
        summary.append("\nMemory Estimation:")
        summary.append(f" Estimated Memory: ~{memory_estimate.get('total_training_memory_gb', 0):.1f}GB")
        summary.append(f" Recommended GPU: {memory_estimate.get('recommended_gpu', 'Unknown')}")

        # Output
        summary.append(f"\nOutput Directory: {config.get('output_dir', 'Unknown')}")

        summary.append("=" * 60)

        return "\n".join(summary)
@@ -1,424 +0,0 @@
1
"""
ISA Model Training Factory.

A small training facade built directly on HuggingFace Transformers (no
LlamaFactory dependency). It wires configuration objects, trainers and the
RunPod cloud orchestrator behind one entry point.
"""

import datetime
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

# NOTE: .core is imported before .cloud, matching the original import order.
from .core import (
    BaseTrainer,
    DatasetConfig,
    DatasetManager,
    LoRAConfig,
    SFTTrainer,
    TrainingConfig,
    TrainingUtils,
)
from .cloud import TrainingJobOrchestrator

# Module-level logger used by TrainingFactory and the convenience helpers.
logger = logging.getLogger(__name__)
26
-
27
-
28
class TrainingFactory:
    """
    Unified Training Factory for ISA Model SDK

    Provides a clean interface for:
    - Local training with SFT (Supervised Fine-Tuning)
    - Cloud training on RunPod
    - Uploading trained models to the HuggingFace Hub
    - Inspecting training output directories

    Example usage:
    ```python
    from isa_model.training import TrainingFactory

    factory = TrainingFactory()

    # Local training
    model_path = factory.train_model(
        model_name="google/gemma-2-2b-it",
        dataset_path="tatsu-lab/alpaca",
        use_lora=True,
        num_epochs=3
    )

    # Cloud training on RunPod
    result = factory.train_on_runpod(
        model_name="google/gemma-2-2b-it",
        dataset_path="tatsu-lab/alpaca",
        runpod_api_key="your-api-key",
        template_id="your-template-id"
    )
    ```
    """

    def __init__(self, base_output_dir: Optional[str] = None):
        """
        Initialize the training factory.

        Args:
            base_output_dir: Base directory for training outputs. Defaults to
                ``training_outputs`` under the current working directory. The
                directory is created if it does not exist.
        """
        self.base_output_dir = base_output_dir or os.path.join(os.getcwd(), "training_outputs")
        os.makedirs(self.base_output_dir, exist_ok=True)

        logger.info(f"TrainingFactory initialized with output dir: {self.base_output_dir}")

    def train_model(
        self,
        model_name: str,
        dataset_path: str,
        output_dir: Optional[str] = None,
        training_type: str = "sft",
        dataset_format: str = "alpaca",
        use_lora: bool = True,
        batch_size: int = 4,
        num_epochs: int = 3,
        learning_rate: float = 2e-5,
        max_length: int = 1024,
        lora_rank: int = 8,
        lora_alpha: int = 16,
        validation_split: float = 0.1,
        **kwargs
    ) -> str:
        """
        Train a model locally.

        Builds LoRA/dataset/training configs, prints a summary (model info and
        memory estimate), validates the config, then runs the trainer.

        Args:
            model_name: Model identifier (e.g. "google/gemma-2-2b-it")
            dataset_path: Path to dataset or HuggingFace dataset name
            output_dir: Custom output directory (a timestamped directory under
                ``base_output_dir`` is generated when omitted)
            training_type: Type of training (only "sft" is supported)
            dataset_format: Dataset format ("alpaca", "sharegpt", "custom")
            use_lora: Whether to use LoRA for efficient training
            batch_size: Training batch size
            num_epochs: Number of training epochs
            learning_rate: Learning rate
            max_length: Maximum sequence length
            lora_rank: LoRA rank parameter
            lora_alpha: LoRA alpha parameter
            validation_split: Fraction of data for validation
            **kwargs: Additional parameters forwarded to ``TrainingConfig``

        Returns:
            Path to the trained model, as returned by the trainer.

        Raises:
            ValueError: if config validation reports issues or the training
                type is not supported.
        """
        if not output_dir:
            output_dir = TrainingUtils.generate_output_dir(
                model_name, training_type, self.base_output_dir
            )

        lora_config = LoRAConfig(
            use_lora=use_lora,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha
        ) if use_lora else None

        dataset_config = DatasetConfig(
            dataset_path=dataset_path,
            dataset_format=dataset_format,
            max_length=max_length,
            validation_split=validation_split
        )

        training_config = TrainingConfig(
            model_name=model_name,
            output_dir=output_dir,
            training_type=training_type,
            num_epochs=num_epochs,
            batch_size=batch_size,
            learning_rate=learning_rate,
            lora_config=lora_config,
            dataset_config=dataset_config,
            **kwargs
        )

        # Print training summary before validation so users can see the config.
        model_info = TrainingUtils.get_model_info(model_name)
        memory_estimate = TrainingUtils.estimate_memory_usage(
            model_name, batch_size, max_length, use_lora
        )

        summary = TrainingUtils.format_training_summary(
            training_config.to_dict(), model_info, memory_estimate
        )
        print(summary)

        issues = TrainingUtils.validate_training_config(training_config.to_dict())
        if issues:
            raise ValueError(f"Training configuration issues: {issues}")

        if training_type.lower() == "sft":
            trainer = SFTTrainer(training_config)
        else:
            raise ValueError(f"Training type '{training_type}' not supported yet")

        logger.info(f"Starting {training_type.upper()} training...")
        result_path = trainer.train()

        logger.info(f"Training completed! Model saved to: {result_path}")
        return result_path

    def train_on_runpod(
        self,
        model_name: str,
        dataset_path: str,
        runpod_api_key: str,
        template_id: str,
        gpu_type: str = "NVIDIA RTX A6000",
        storage_config: Optional[Dict[str, Any]] = None,
        job_name: Optional[str] = None,
        **training_params
    ) -> Dict[str, Any]:
        """
        Train a model on RunPod cloud infrastructure.

        Args:
            model_name: Model identifier
            dataset_path: Dataset path or HuggingFace dataset name
            runpod_api_key: RunPod API key
            template_id: RunPod template ID
            gpu_type: GPU type to use
            storage_config: Optional kwargs for the cloud ``StorageConfig``
            job_name: Optional job name (a timestamped name is generated otherwise)
            **training_params: Additional parameters forwarded to ``JobConfig``

        Returns:
            The result of ``TrainingJobOrchestrator.execute_training_workflow``.
        """
        # Cloud config classes are imported lazily; TrainingJobOrchestrator is
        # already imported at module level.
        from .cloud.runpod_trainer import RunPodConfig
        from .cloud.storage_manager import StorageConfig
        from .cloud.job_orchestrator import JobConfig

        runpod_config = RunPodConfig(
            api_key=runpod_api_key,
            template_id=template_id,
            gpu_type=gpu_type
        )

        storage_cfg = StorageConfig(**storage_config) if storage_config else None

        # NOTE(review): the default job-name prefix says "gemma" regardless of
        # model_name; kept as-is for compatibility.
        job_config = JobConfig(
            model_name=model_name,
            dataset_source=dataset_path,
            job_name=job_name or f"gemma-training-{int(datetime.datetime.now().timestamp())}",
            **training_params
        )

        orchestrator = TrainingJobOrchestrator(
            runpod_config=runpod_config,
            storage_config=storage_cfg
        )

        logger.info(f"Starting RunPod training for {model_name}")
        return orchestrator.execute_training_workflow(job_config)

    async def upload_to_huggingface(
        self,
        model_path: str,
        hf_model_name: str,
        hf_token: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        hf_username: str = "xenobordom"
    ) -> str:
        """
        Upload a trained model to HuggingFace Hub using HuggingFaceStorage.

        Args:
            model_path: Path to the trained model
            hf_model_name: Name for the model on HuggingFace Hub
            hf_token: HuggingFace token
            metadata: Additional metadata for the model. The caller's dict is
                not modified; a copy is extended with upload fields.
            hf_username: HuggingFace account used by the storage backend
                (defaults to the previously hard-coded account)

        Returns:
            Public URL of the uploaded model

        Raises:
            RuntimeError: if the storage backend reports an unsuccessful upload.
            Any exception from the storage backend is logged and re-raised.
        """
        try:
            from ..core.storage.hf_storage import HuggingFaceStorage

            logger.info(f"Uploading model to HuggingFace: {hf_model_name}")

            storage = HuggingFaceStorage(
                username=hf_username,
                token=hf_token
            )

            # Copy so the update below does not mutate the caller's dict.
            upload_metadata = dict(metadata or {})
            upload_metadata.update({
                "description": f"Fine-tuned model: {hf_model_name}",
                "training_framework": "ISA Model SDK",
                "uploaded_from": "training_factory"
            })

            success = await storage.save_model(
                model_id=hf_model_name,
                model_path=model_path,
                metadata=upload_metadata
            )

            if not success:
                raise RuntimeError("Failed to upload model")

            model_url = storage.get_public_url(hf_model_name)
            logger.info(f"Model uploaded successfully: {model_url}")
            return model_url

        except Exception as e:
            logger.error(f"Failed to upload to HuggingFace: {e}")
            raise

    def get_training_status(self, output_dir: str) -> Dict[str, Any]:
        """
        Get training status from output directory.

        Args:
            output_dir: Training output directory

        Returns:
            Dict with ``output_dir``, ``exists`` and ``files``. When the
            directory exists, it also has ``has_config`` (``training_config.json``
            present), ``has_metrics`` (``training_metrics.json`` present) and
            ``has_model`` (any of the full/adapter weight files in ``.bin`` or
            ``.safetensors`` form). When the config file parses, it also has
            ``config`` with the parsed ``training_config.json``.
        """
        import json

        status: Dict[str, Any] = {
            "output_dir": output_dir,
            "exists": os.path.exists(output_dir),
            "files": []
        }

        if not status["exists"]:
            return status

        status["files"] = os.listdir(output_dir)

        config_path = os.path.join(output_dir, "training_config.json")
        metrics_path = os.path.join(output_dir, "training_metrics.json")
        # Full-model and LoRA-adapter weights, in legacy .bin and safetensors form.
        model_files = (
            "pytorch_model.bin",
            "model.safetensors",
            "adapter_model.bin",
            "adapter_model.safetensors",
        )

        status["has_config"] = os.path.exists(config_path)
        status["has_metrics"] = os.path.exists(metrics_path)
        status["has_model"] = any(
            os.path.exists(os.path.join(output_dir, name)) for name in model_files
        )

        if status["has_config"]:
            # Load the same file whose presence was checked. Previously this read
            # training_args.json, so "config" was usually missing.
            try:
                with open(config_path, "r", encoding="utf-8") as f:
                    status["config"] = json.load(f)
            except (OSError, ValueError) as e:
                logger.warning(f"Could not read training config {config_path}: {e}")

        return status

    def list_trained_models(self) -> List[Dict[str, Any]]:
        """
        List all training run directories under ``base_output_dir``.

        Returns:
            One dict per subdirectory (``name``, ``path``, ISO ``created``
            timestamp from ``os.path.getctime``, ``status`` from
            ``get_training_status``), newest first.
        """
        models = []

        if os.path.exists(self.base_output_dir):
            for item in os.listdir(self.base_output_dir):
                item_path = os.path.join(self.base_output_dir, item)
                if os.path.isdir(item_path):
                    models.append({
                        "name": item,
                        "path": item_path,
                        "created": datetime.datetime.fromtimestamp(
                            os.path.getctime(item_path)
                        ).isoformat(),
                        "status": self.get_training_status(item_path)
                    })

        return sorted(models, key=lambda x: x["created"], reverse=True)
376
-
377
-
378
- # Convenience functions for quick access
379
- def train_gemma(
380
- dataset_path: str,
381
- model_size: str = "4b",
382
- output_dir: Optional[str] = None,
383
- **kwargs
384
- ) -> str:
385
- """
386
- Quick function to train Gemma models.
387
-
388
- Args:
389
- dataset_path: Path to training dataset
390
- model_size: Model size ("2b", "4b", "7b")
391
- output_dir: Output directory
392
- **kwargs: Additional training parameters
393
-
394
- Returns:
395
- Path to trained model
396
-
397
- Example:
398
- ```python
399
- from isa_model.training import train_gemma
400
-
401
- model_path = train_gemma(
402
- dataset_path="tatsu-lab/alpaca",
403
- model_size="4b",
404
- num_epochs=3,
405
- batch_size=4
406
- )
407
- ```
408
- """
409
- factory = TrainingFactory()
410
-
411
- model_map = {
412
- "2b": "google/gemma-2-2b-it",
413
- "4b": "google/gemma-2-4b-it",
414
- "7b": "google/gemma-2-7b-it"
415
- }
416
-
417
- model_name = model_map.get(model_size, "google/gemma-2-4b-it")
418
-
419
- return factory.train_model(
420
- model_name=model_name,
421
- dataset_path=dataset_path,
422
- output_dir=output_dir,
423
- **kwargs
424
- )