isa-model 0.3.91__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (228)
  1. isa_model/client.py +1166 -584
  2. isa_model/core/cache/redis_cache.py +410 -0
  3. isa_model/core/config/config_manager.py +282 -12
  4. isa_model/core/config.py +91 -1
  5. isa_model/core/database/__init__.py +1 -0
  6. isa_model/core/database/direct_db_client.py +114 -0
  7. isa_model/core/database/migration_manager.py +563 -0
  8. isa_model/core/database/migrations.py +297 -0
  9. isa_model/core/database/supabase_client.py +258 -0
  10. isa_model/core/dependencies.py +316 -0
  11. isa_model/core/discovery/__init__.py +19 -0
  12. isa_model/core/discovery/consul_discovery.py +190 -0
  13. isa_model/core/logging/__init__.py +54 -0
  14. isa_model/core/logging/influx_logger.py +523 -0
  15. isa_model/core/logging/loki_logger.py +160 -0
  16. isa_model/core/models/__init__.py +46 -0
  17. isa_model/core/models/config_models.py +625 -0
  18. isa_model/core/models/deployment_billing_tracker.py +430 -0
  19. isa_model/core/models/model_billing_tracker.py +60 -88
  20. isa_model/core/models/model_manager.py +66 -25
  21. isa_model/core/models/model_metadata.py +690 -0
  22. isa_model/core/models/model_repo.py +217 -55
  23. isa_model/core/models/model_statistics_tracker.py +234 -0
  24. isa_model/core/models/model_storage.py +0 -1
  25. isa_model/core/models/model_version_manager.py +959 -0
  26. isa_model/core/models/system_models.py +857 -0
  27. isa_model/core/pricing_manager.py +2 -249
  28. isa_model/core/repositories/__init__.py +9 -0
  29. isa_model/core/repositories/config_repository.py +912 -0
  30. isa_model/core/resilience/circuit_breaker.py +366 -0
  31. isa_model/core/security/secrets.py +358 -0
  32. isa_model/core/services/__init__.py +2 -4
  33. isa_model/core/services/intelligent_model_selector.py +479 -370
  34. isa_model/core/storage/hf_storage.py +2 -2
  35. isa_model/core/types.py +8 -0
  36. isa_model/deployment/__init__.py +5 -48
  37. isa_model/deployment/core/__init__.py +2 -31
  38. isa_model/deployment/core/deployment_manager.py +1278 -368
  39. isa_model/deployment/local/__init__.py +31 -0
  40. isa_model/deployment/local/config.py +248 -0
  41. isa_model/deployment/local/gpu_gateway.py +607 -0
  42. isa_model/deployment/local/health_checker.py +428 -0
  43. isa_model/deployment/local/provider.py +586 -0
  44. isa_model/deployment/local/tensorrt_service.py +621 -0
  45. isa_model/deployment/local/transformers_service.py +644 -0
  46. isa_model/deployment/local/vllm_service.py +527 -0
  47. isa_model/deployment/modal/__init__.py +8 -0
  48. isa_model/deployment/modal/config.py +136 -0
  49. isa_model/deployment/modal/deployer.py +894 -0
  50. isa_model/deployment/modal/services/__init__.py +3 -0
  51. isa_model/deployment/modal/services/audio/__init__.py +1 -0
  52. isa_model/deployment/modal/services/audio/isa_audio_chatTTS_service.py +520 -0
  53. isa_model/deployment/modal/services/audio/isa_audio_openvoice_service.py +758 -0
  54. isa_model/deployment/modal/services/audio/isa_audio_service_v2.py +1044 -0
  55. isa_model/deployment/modal/services/embedding/__init__.py +1 -0
  56. isa_model/deployment/modal/services/embedding/isa_embed_rerank_service.py +296 -0
  57. isa_model/deployment/modal/services/llm/__init__.py +1 -0
  58. isa_model/deployment/modal/services/llm/isa_llm_service.py +424 -0
  59. isa_model/deployment/modal/services/video/__init__.py +1 -0
  60. isa_model/deployment/modal/services/video/isa_video_hunyuan_service.py +423 -0
  61. isa_model/deployment/modal/services/vision/__init__.py +1 -0
  62. isa_model/deployment/modal/services/vision/isa_vision_ocr_service.py +519 -0
  63. isa_model/deployment/modal/services/vision/isa_vision_qwen25_service.py +709 -0
  64. isa_model/deployment/modal/services/vision/isa_vision_table_service.py +676 -0
  65. isa_model/deployment/modal/services/vision/isa_vision_ui_service.py +833 -0
  66. isa_model/deployment/modal/services/vision/isa_vision_ui_service_optimized.py +660 -0
  67. isa_model/deployment/models/org-org-acme-corp-tenant-a-service-llm-20250825-225822/tenant-a-service_modal_service.py +48 -0
  68. isa_model/deployment/models/org-test-org-123-prefix-test-service-llm-20250825-225822/prefix-test-service_modal_service.py +48 -0
  69. isa_model/deployment/models/test-llm-service-llm-20250825-204442/test-llm-service_modal_service.py +48 -0
  70. isa_model/deployment/models/test-monitoring-gpt2-llm-20250825-212906/test-monitoring-gpt2_modal_service.py +48 -0
  71. isa_model/deployment/models/test-monitoring-gpt2-llm-20250825-213009/test-monitoring-gpt2_modal_service.py +48 -0
  72. isa_model/deployment/storage/__init__.py +5 -0
  73. isa_model/deployment/storage/deployment_repository.py +824 -0
  74. isa_model/deployment/triton/__init__.py +10 -0
  75. isa_model/deployment/triton/config.py +196 -0
  76. isa_model/deployment/triton/configs/__init__.py +1 -0
  77. isa_model/deployment/triton/provider.py +512 -0
  78. isa_model/deployment/triton/scripts/__init__.py +1 -0
  79. isa_model/deployment/triton/templates/__init__.py +1 -0
  80. isa_model/inference/__init__.py +47 -1
  81. isa_model/inference/ai_factory.py +179 -16
  82. isa_model/inference/legacy_services/__init__.py +21 -0
  83. isa_model/inference/legacy_services/model_evaluation.py +637 -0
  84. isa_model/inference/legacy_services/model_service.py +573 -0
  85. isa_model/inference/legacy_services/model_serving.py +717 -0
  86. isa_model/inference/legacy_services/model_training.py +561 -0
  87. isa_model/inference/models/__init__.py +21 -0
  88. isa_model/inference/models/inference_config.py +551 -0
  89. isa_model/inference/models/inference_record.py +675 -0
  90. isa_model/inference/models/performance_models.py +714 -0
  91. isa_model/inference/repositories/__init__.py +9 -0
  92. isa_model/inference/repositories/inference_repository.py +828 -0
  93. isa_model/inference/services/audio/__init__.py +21 -0
  94. isa_model/inference/services/audio/base_realtime_service.py +225 -0
  95. isa_model/inference/services/audio/base_stt_service.py +184 -11
  96. isa_model/inference/services/audio/isa_tts_service.py +0 -0
  97. isa_model/inference/services/audio/openai_realtime_service.py +320 -124
  98. isa_model/inference/services/audio/openai_stt_service.py +53 -11
  99. isa_model/inference/services/base_service.py +17 -1
  100. isa_model/inference/services/custom_model_manager.py +277 -0
  101. isa_model/inference/services/embedding/__init__.py +13 -0
  102. isa_model/inference/services/embedding/base_embed_service.py +111 -8
  103. isa_model/inference/services/embedding/isa_embed_service.py +305 -0
  104. isa_model/inference/services/embedding/ollama_embed_service.py +15 -3
  105. isa_model/inference/services/embedding/openai_embed_service.py +2 -4
  106. isa_model/inference/services/embedding/resilient_embed_service.py +285 -0
  107. isa_model/inference/services/embedding/tests/test_embedding.py +222 -0
  108. isa_model/inference/services/img/__init__.py +2 -2
  109. isa_model/inference/services/img/base_image_gen_service.py +24 -7
  110. isa_model/inference/services/img/replicate_image_gen_service.py +84 -422
  111. isa_model/inference/services/img/services/replicate_face_swap.py +193 -0
  112. isa_model/inference/services/img/services/replicate_flux.py +226 -0
  113. isa_model/inference/services/img/services/replicate_flux_kontext.py +219 -0
  114. isa_model/inference/services/img/services/replicate_sticker_maker.py +249 -0
  115. isa_model/inference/services/img/tests/test_img_client.py +297 -0
  116. isa_model/inference/services/llm/__init__.py +10 -2
  117. isa_model/inference/services/llm/base_llm_service.py +361 -26
  118. isa_model/inference/services/llm/cerebras_llm_service.py +628 -0
  119. isa_model/inference/services/llm/helpers/llm_adapter.py +71 -12
  120. isa_model/inference/services/llm/helpers/llm_prompts.py +342 -0
  121. isa_model/inference/services/llm/helpers/llm_utils.py +321 -23
  122. isa_model/inference/services/llm/huggingface_llm_service.py +581 -0
  123. isa_model/inference/services/llm/local_llm_service.py +747 -0
  124. isa_model/inference/services/llm/ollama_llm_service.py +11 -3
  125. isa_model/inference/services/llm/openai_llm_service.py +670 -56
  126. isa_model/inference/services/llm/yyds_llm_service.py +10 -3
  127. isa_model/inference/services/vision/__init__.py +27 -6
  128. isa_model/inference/services/vision/base_vision_service.py +118 -185
  129. isa_model/inference/services/vision/blip_vision_service.py +359 -0
  130. isa_model/inference/services/vision/helpers/image_utils.py +19 -10
  131. isa_model/inference/services/vision/isa_vision_service.py +634 -0
  132. isa_model/inference/services/vision/openai_vision_service.py +19 -10
  133. isa_model/inference/services/vision/tests/test_ocr_client.py +284 -0
  134. isa_model/inference/services/vision/vgg16_vision_service.py +257 -0
  135. isa_model/serving/api/cache_manager.py +245 -0
  136. isa_model/serving/api/dependencies/__init__.py +1 -0
  137. isa_model/serving/api/dependencies/auth.py +194 -0
  138. isa_model/serving/api/dependencies/database.py +139 -0
  139. isa_model/serving/api/error_handlers.py +284 -0
  140. isa_model/serving/api/fastapi_server.py +240 -18
  141. isa_model/serving/api/middleware/auth.py +317 -0
  142. isa_model/serving/api/middleware/security.py +268 -0
  143. isa_model/serving/api/middleware/tenant_context.py +414 -0
  144. isa_model/serving/api/routes/analytics.py +489 -0
  145. isa_model/serving/api/routes/config.py +645 -0
  146. isa_model/serving/api/routes/deployment_billing.py +315 -0
  147. isa_model/serving/api/routes/deployments.py +475 -0
  148. isa_model/serving/api/routes/gpu_gateway.py +440 -0
  149. isa_model/serving/api/routes/health.py +32 -12
  150. isa_model/serving/api/routes/inference_monitoring.py +486 -0
  151. isa_model/serving/api/routes/local_deployments.py +448 -0
  152. isa_model/serving/api/routes/logs.py +430 -0
  153. isa_model/serving/api/routes/settings.py +582 -0
  154. isa_model/serving/api/routes/tenants.py +575 -0
  155. isa_model/serving/api/routes/unified.py +992 -171
  156. isa_model/serving/api/routes/webhooks.py +479 -0
  157. isa_model/serving/api/startup.py +318 -0
  158. isa_model/serving/modal_proxy_server.py +249 -0
  159. isa_model/utils/gpu_utils.py +311 -0
  160. {isa_model-0.3.91.dist-info → isa_model-0.4.3.dist-info}/METADATA +76 -22
  161. isa_model-0.4.3.dist-info/RECORD +193 -0
  162. isa_model/deployment/cloud/__init__.py +0 -9
  163. isa_model/deployment/cloud/modal/__init__.py +0 -10
  164. isa_model/deployment/cloud/modal/isa_vision_doc_service.py +0 -766
  165. isa_model/deployment/cloud/modal/isa_vision_table_service.py +0 -532
  166. isa_model/deployment/cloud/modal/isa_vision_ui_service.py +0 -406
  167. isa_model/deployment/cloud/modal/register_models.py +0 -321
  168. isa_model/deployment/core/deployment_config.py +0 -356
  169. isa_model/deployment/core/isa_deployment_service.py +0 -401
  170. isa_model/deployment/gpu_int8_ds8/app/server.py +0 -66
  171. isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +0 -43
  172. isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +0 -35
  173. isa_model/deployment/runtime/deployed_service.py +0 -338
  174. isa_model/deployment/services/__init__.py +0 -9
  175. isa_model/deployment/services/auto_deploy_vision_service.py +0 -538
  176. isa_model/deployment/services/model_service.py +0 -332
  177. isa_model/deployment/services/service_monitor.py +0 -356
  178. isa_model/deployment/services/service_registry.py +0 -527
  179. isa_model/eval/__init__.py +0 -92
  180. isa_model/eval/benchmarks.py +0 -469
  181. isa_model/eval/config/__init__.py +0 -10
  182. isa_model/eval/config/evaluation_config.py +0 -108
  183. isa_model/eval/evaluators/__init__.py +0 -18
  184. isa_model/eval/evaluators/base_evaluator.py +0 -503
  185. isa_model/eval/evaluators/llm_evaluator.py +0 -472
  186. isa_model/eval/factory.py +0 -531
  187. isa_model/eval/infrastructure/__init__.py +0 -24
  188. isa_model/eval/infrastructure/experiment_tracker.py +0 -466
  189. isa_model/eval/metrics.py +0 -798
  190. isa_model/inference/adapter/unified_api.py +0 -248
  191. isa_model/inference/services/helpers/stacked_config.py +0 -148
  192. isa_model/inference/services/img/flux_professional_service.py +0 -603
  193. isa_model/inference/services/img/helpers/base_stacked_service.py +0 -274
  194. isa_model/inference/services/others/table_transformer_service.py +0 -61
  195. isa_model/inference/services/vision/doc_analysis_service.py +0 -640
  196. isa_model/inference/services/vision/helpers/base_stacked_service.py +0 -274
  197. isa_model/inference/services/vision/ui_analysis_service.py +0 -823
  198. isa_model/scripts/inference_tracker.py +0 -283
  199. isa_model/scripts/mlflow_manager.py +0 -379
  200. isa_model/scripts/model_registry.py +0 -465
  201. isa_model/scripts/register_models.py +0 -370
  202. isa_model/scripts/register_models_with_embeddings.py +0 -510
  203. isa_model/scripts/start_mlflow.py +0 -95
  204. isa_model/scripts/training_tracker.py +0 -257
  205. isa_model/training/__init__.py +0 -74
  206. isa_model/training/annotation/annotation_schema.py +0 -47
  207. isa_model/training/annotation/processors/annotation_processor.py +0 -126
  208. isa_model/training/annotation/storage/dataset_manager.py +0 -131
  209. isa_model/training/annotation/storage/dataset_schema.py +0 -44
  210. isa_model/training/annotation/tests/test_annotation_flow.py +0 -109
  211. isa_model/training/annotation/tests/test_minio copy.py +0 -113
  212. isa_model/training/annotation/tests/test_minio_upload.py +0 -43
  213. isa_model/training/annotation/views/annotation_controller.py +0 -158
  214. isa_model/training/cloud/__init__.py +0 -22
  215. isa_model/training/cloud/job_orchestrator.py +0 -402
  216. isa_model/training/cloud/runpod_trainer.py +0 -454
  217. isa_model/training/cloud/storage_manager.py +0 -482
  218. isa_model/training/core/__init__.py +0 -23
  219. isa_model/training/core/config.py +0 -181
  220. isa_model/training/core/dataset.py +0 -222
  221. isa_model/training/core/trainer.py +0 -720
  222. isa_model/training/core/utils.py +0 -213
  223. isa_model/training/factory.py +0 -424
  224. isa_model-0.3.91.dist-info/RECORD +0 -138
  225. /isa_model/{core/storage/minio_storage.py → deployment/modal/services/audio/isa_audio_fish_service.py} +0 -0
  226. /isa_model/deployment/{services → modal/services/vision}/simple_auto_deploy_vision_service.py +0 -0
  227. {isa_model-0.3.91.dist-info → isa_model-0.4.3.dist-info}/WHEEL +0 -0
  228. {isa_model-0.3.91.dist-info → isa_model-0.4.3.dist-info}/top_level.txt +0 -0
isa_model/scripts/training_tracker.py (deleted)
@@ -1,257 +0,0 @@
- """
- MLflow tracker for training workflows.
- """
-
- import os
- import json
- import logging
- from typing import Dict, List, Optional, Any, Union
- from contextlib import contextmanager
-
- from .mlflow_manager import MLflowManager, ExperimentType
- from .model_registry import ModelRegistry, ModelStage
-
-
- logger = logging.getLogger(__name__)
-
-
- class TrainingTracker:
-     """
-     Tracker for model training workflows.
-
-     This class provides utilities to track model training using MLflow
-     and register trained models in the model registry.
-
-     Example:
-         ```python
-         # Initialize tracker
-         tracker = TrainingTracker(
-             tracking_uri="http://localhost:5000",
-             registry_uri="http://localhost:5000"
-         )
-
-         # Start tracking training
-         with tracker.track_training_run(
-             model_name="llama-7b",
-             training_params={
-                 "learning_rate": 2e-5,
-                 "batch_size": 8,
-                 "epochs": 3
-             }
-         ) as run_info:
-             # Train the model...
-
-             # Log metrics during training
-             tracker.log_metrics({
-                 "train_loss": 0.1,
-                 "val_loss": 0.2
-             })
-
-         # After training completes
-         model_path = "/path/to/trained_model"
-
-         # Register the model
-         tracker.register_trained_model(
-             model_path=model_path,
-             metrics={
-                 "accuracy": 0.95,
-                 "f1": 0.92
-             },
-             stage=ModelStage.STAGING
-         )
-         ```
-     """
-
-     def __init__(
-         self,
-         tracking_uri: Optional[str] = None,
-         artifact_uri: Optional[str] = None,
-         registry_uri: Optional[str] = None
-     ):
-         """
-         Initialize the training tracker.
-
-         Args:
-             tracking_uri: URI for MLflow tracking server
-             artifact_uri: URI for MLflow artifacts
-             registry_uri: URI for MLflow model registry
-         """
-         self.mlflow_manager = MLflowManager(
-             tracking_uri=tracking_uri,
-             artifact_uri=artifact_uri,
-             registry_uri=registry_uri
-         )
-         self.model_registry = ModelRegistry(
-             tracking_uri=tracking_uri,
-             registry_uri=registry_uri
-         )
-         self.current_run_info = {}
-
-     @contextmanager
-     def track_training_run(
-         self,
-         model_name: str,
-         training_params: Dict[str, Any],
-         description: Optional[str] = None,
-         tags: Optional[Dict[str, str]] = None,
-         experiment_type: ExperimentType = ExperimentType.TRAINING
-     ):
-         """
-         Track a training run with MLflow.
-
-         Args:
-             model_name: Name of the model being trained
-             training_params: Parameters for the training run
-             description: Description of the training run
-             tags: Tags for the training run
-             experiment_type: Type of experiment
-
-         Yields:
-             Dictionary with run information
-         """
-         run_info = {
-             "model_name": model_name,
-             "params": training_params,
-             "metrics": {}
-         }
-
-         # Add description to tags if provided
-         if tags is None:
-             tags = {}
-
-         if description:
-             tags["description"] = description
-
-         # Start the MLflow run
-         with self.mlflow_manager.start_run(
-             experiment_type=experiment_type,
-             model_name=model_name,
-             tags=tags
-         ) as run:
-             run_info["run_id"] = run.info.run_id
-             run_info["experiment_id"] = run.info.experiment_id
-             run_info["status"] = "running"
-
-             # Save parameters
-             self.mlflow_manager.log_params(training_params)
-
-             self.current_run_info = run_info
-             try:
-                 yield run_info
-                 # Mark as successful if no exceptions
-                 run_info["status"] = "completed"
-             except Exception as e:
-                 # Mark as failed if exception occurred
-                 run_info["status"] = "failed"
-                 run_info["error"] = str(e)
-                 self.mlflow_manager.set_tracking_tag("error", str(e))
-                 raise
-             finally:
-                 self.current_run_info = {}
-
-     def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
-         """
-         Log metrics to the current run.
-
-         Args:
-             metrics: Dictionary of metrics to log
-             step: Step value for the metrics
-         """
-         self.mlflow_manager.log_metrics(metrics, step)
-
-         if self.current_run_info:
-             if "metrics" not in self.current_run_info:
-                 self.current_run_info["metrics"] = {}
-
-             # Only keep the latest metrics
-             self.current_run_info["metrics"].update(metrics)
-
-     def log_artifacts(self, local_dir: str, artifact_path: Optional[str] = None) -> None:
-         """
-         Log artifacts to the current run.
-
-         Args:
-             local_dir: Local directory containing artifacts
-             artifact_path: Path for the artifacts in MLflow
-         """
-         self.mlflow_manager.log_artifacts(local_dir, artifact_path)
-
-     def register_trained_model(
-         self,
-         model_path: str,
-         metrics: Optional[Dict[str, float]] = None,
-         description: Optional[str] = None,
-         tags: Optional[Dict[str, str]] = None,
-         stage: Optional[ModelStage] = None,
-         flavor: str = "pyfunc"
-     ) -> Optional[str]:
-         """
-         Register a trained model with MLflow.
-
-         Args:
-             model_path: Path to the trained model
-             metrics: Evaluation metrics for the model
-             description: Description of the model
-             tags: Tags for the model
-             stage: Stage to register the model in
-             flavor: MLflow model flavor
-
-         Returns:
-             Version of the registered model or None if registration failed
-         """
-         if not self.current_run_info:
-             logger.warning("No active run. Model cannot be registered.")
-             return None
-
-         model_name = self.current_run_info.get("model_name")
-         if not model_name:
-             logger.warning("Model name not available in run info. Using generic name.")
-             model_name = "unnamed_model"
-
-         # Log final metrics if provided
-         if metrics:
-             self.log_metrics(metrics)
-
-         # Prepare model tags
-         if tags is None:
-             tags = {}
-
-         # Add run ID to tags
-         tags["run_id"] = self.current_run_info.get("run_id", "")
-
-         # Add metrics to tags
-         for k, v in self.current_run_info.get("metrics", {}).items():
-             tags[f"metric.{k}"] = str(v)
-
-         # Log model to MLflow
-         artifact_path = self.mlflow_manager.log_model(
-             model_path=model_path,
-             name=model_name,
-             flavor=flavor
-         )
-
-         if not artifact_path:
-             logger.error("Failed to log model to MLflow.")
-             return None
-
-         # Get model URI
-         run_id = self.current_run_info.get("run_id")
-         model_uri = f"runs:/{run_id}/{artifact_path}"
-
-         # Register the model
-         version = self.model_registry.register_model(
-             name=model_name,
-             source=model_uri,
-             description=description,
-             tags=tags
-         )
-
-         # Transition to the specified stage if provided
-         if version and stage:
-             self.model_registry.transition_model_version_stage(
-                 name=model_name,
-                 version=version,
-                 stage=stage
-             )
-
-         return version
@@ -1,74 +0,0 @@
1
- """
2
- ISA Model Training Module
3
-
4
- Provides unified training capabilities for AI models including:
5
- - Local training with SFT (Supervised Fine-Tuning)
6
- - Cloud training on RunPod
7
- - Model evaluation and management
8
- - HuggingFace integration
9
-
10
- Example usage:
11
- ```python
12
- from isa_model.training import TrainingFactory, train_gemma
13
-
14
- # Quick Gemma training
15
- model_path = train_gemma(
16
- dataset_path="tatsu-lab/alpaca",
17
- model_size="4b",
18
- num_epochs=3
19
- )
20
-
21
- # Advanced training with custom configuration
22
- factory = TrainingFactory()
23
- model_path = factory.train_model(
24
- model_name="google/gemma-2-4b-it",
25
- dataset_path="your-dataset.json",
26
- use_lora=True,
27
- batch_size=4,
28
- num_epochs=3
29
- )
30
- ```
31
- """
32
-
33
- # Import the new clean factory
34
- from .factory import TrainingFactory, train_gemma
35
-
36
- # Import core components
37
- from .core import (
38
- TrainingConfig,
39
- LoRAConfig,
40
- DatasetConfig,
41
- BaseTrainer,
42
- SFTTrainer,
43
- TrainingUtils,
44
- DatasetManager
45
- )
46
-
47
- # Import cloud training components
48
- from .cloud import (
49
- RunPodConfig,
50
- StorageConfig,
51
- JobConfig,
52
- TrainingJobOrchestrator
53
- )
54
-
55
- __all__ = [
56
- # Main factory
57
- 'TrainingFactory',
58
- 'train_gemma',
59
-
60
- # Core components
61
- 'TrainingConfig',
62
- 'LoRAConfig',
63
- 'DatasetConfig',
64
- 'BaseTrainer',
65
- 'SFTTrainer',
66
- 'TrainingUtils',
67
- 'DatasetManager',
68
-
69
- # Cloud components
70
- 'RunPodConfig',
71
- 'StorageConfig',
72
- 'JobConfig',
73
- 'TrainingJobOrchestrator'
74
- ]
@@ -1,47 +0,0 @@
1
- # app/services/llm_model/tracing/annotation/annotation_schema.py
2
- from enum import Enum
3
- from pydantic import BaseModel, Field
4
- from typing import Dict, Any, List, Optional
5
- from datetime import datetime
6
-
7
- class AnnotationType(str, Enum):
8
- ACCURACY = "accuracy"
9
- HELPFULNESS = "helpfulness"
10
- TOXICITY = "toxicity"
11
- CUSTOM = "custom"
12
-
13
- class RatingScale(int, Enum):
14
- POOR = 1
15
- FAIR = 2
16
- GOOD = 3
17
- EXCELLENT = 4
18
-
19
- class AnnotationAspects(BaseModel):
20
- factually_correct: bool = True
21
- relevant: bool = True
22
- harmful: bool = False
23
- biased: bool = False
24
- complete: bool = True
25
- efficient: bool = True
26
-
27
- class BetterResponse(BaseModel):
28
- content: str
29
- reason: Optional[str]
30
- metadata: Optional[Dict[str, Any]] = {}
31
-
32
- class AnnotationFeedback(BaseModel):
33
- rating: RatingScale
34
- category: AnnotationType
35
- aspects: AnnotationAspects
36
- better_response: Optional[BetterResponse]
37
- comment: Optional[str]
38
- metadata: Optional[Dict[str, Any]] = {}
39
- is_selected_for_training: bool = False
40
-
41
- class ItemAnnotation(BaseModel):
42
- item_id: str
43
- feedback: Optional[AnnotationFeedback]
44
- status: str = "pending"
45
- annotated_at: Optional[datetime]
46
- annotator_id: Optional[str]
47
- training_status: Optional[str] = None
@@ -1,126 +0,0 @@
1
- from typing import Dict, Any, List
2
- from datetime import datetime
3
- from app.config.config_manager import config_manager
4
- from app.services.training.llm_model.annotation.annotation_schema import AnnotationFeedback, RatingScale, AnnotationAspects
5
- from bson.objectid import ObjectId
6
- from app.services.training.llm_model.annotation.storage.dataset_manager import DatasetManager
7
-
8
- class AnnotationProcessor:
9
- def __init__(self):
10
- self.logger = config_manager.get_logger(__name__)
11
- self.dataset_manager = DatasetManager()
12
- self.batch_size = 1000 # Configure as needed
13
-
14
- async def process_queue(self) -> None:
15
- """Process pending items and create datasets"""
16
- db = await config_manager.get_db('mongodb')
17
- queue = db['training_queue']
18
-
19
- # Process SFT items
20
- sft_items = await self._get_pending_items("sft")
21
- if len(sft_items) >= self.batch_size:
22
- await self._create_sft_dataset(sft_items)
23
-
24
- # Process RLHF items
25
- rlhf_items = await self._get_pending_items("rlhf")
26
- if len(rlhf_items) >= self.batch_size:
27
- await self._create_rlhf_dataset(rlhf_items)
28
-
29
- async def _create_sft_dataset(self, items: List[Dict[str, Any]]):
30
- """Create and upload SFT dataset"""
31
- dataset = await self.dataset_manager.create_dataset(
32
- name=f"sft_dataset_v{datetime.now().strftime('%Y%m%d')}",
33
- type="sft",
34
- version=datetime.now().strftime("%Y%m%d"),
35
- source_annotations=[item["annotation_id"] for item in items]
36
- )
37
-
38
- formatted_data = [
39
- await self._process_sft_item(item)
40
- for item in items
41
- ]
42
-
43
- await self.dataset_manager.upload_dataset_file(
44
- dataset.id,
45
- formatted_data
46
- )
47
-
48
- async def _process_sft_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
49
- """Process item for SFT dataset generation
50
- Format follows HF conversation format for SFT training
51
- """
52
- db = await config_manager.get_db('mongodb')
53
- annotations = db['annotations']
54
-
55
- # Get full annotation context
56
- annotation = await annotations.find_one({"_id": ObjectId(item["annotation_id"])})
57
- target_item = next(i for i in annotation["items"] if i["item_id"] == item["item_id"])
58
-
59
- # Format as conversation
60
- messages = [
61
- {
62
- "role": "system",
63
- "content": "You are a helpful AI assistant that provides accurate and relevant information."
64
- },
65
- {
66
- "role": "user",
67
- "content": target_item["input"]["messages"][0]["content"]
68
- },
69
- {
70
- "role": "assistant",
71
- "content": target_item["output"]["content"]
72
- }
73
- ]
74
-
75
- return {
76
- "messages": messages,
77
- "metadata": {
78
- "rating": item["feedback"]["rating"],
79
- "aspects": item["feedback"]["aspects"],
80
- "category": item["feedback"]["category"]
81
- }
82
- }
83
-
84
- async def _process_rlhf_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
85
- """Process item for RLHF dataset generation
86
- Format follows preference pairs structure for RLHF training
87
- """
88
- db = await config_manager.get_db('mongodb')
89
- annotations = db['annotations']
90
-
91
- # Get full annotation context
92
- annotation = await annotations.find_one({"_id": ObjectId(item["annotation_id"])})
93
- target_item = next(i for i in annotation["items"] if i["item_id"] == item["item_id"])
94
-
95
- # Format as preference pairs
96
- return {
97
- "prompt": target_item["input"]["messages"][0]["content"],
98
- "chosen": item["feedback"]["better_response"]["content"],
99
- "rejected": target_item["output"]["content"],
100
- "metadata": {
101
- "reason": item["feedback"]["better_response"]["reason"],
102
- "category": item["feedback"]["category"]
103
- }
104
- }
105
-
106
- async def get_training_data(
107
- self,
108
- data_type: str,
109
- limit: int = 1000
110
- ) -> List[Dict[str, Any]]:
111
- """Retrieve formatted training data"""
112
- db = await config_manager.get_db('mongodb')
113
- training_data = db['training_data']
114
-
115
- data = await training_data.find(
116
- {"type": data_type}
117
- ).limit(limit).to_list(length=limit)
118
-
119
- if data_type == "sft":
120
- return [item["data"]["messages"] for item in data]
121
- else: # rlhf
122
- return [{
123
- "prompt": item["data"]["prompt"],
124
- "chosen": item["data"]["chosen"],
125
- "rejected": item["data"]["rejected"]
126
- } for item in data]
@@ -1,131 +0,0 @@
1
- # app/services/llm_model/annotation/dataset/dataset_manager.py
2
- from typing import Dict, Any, List
3
- from datetime import datetime
4
- import json
5
- import io
6
- from app.config.config_manager import config_manager
7
- from .dataset_schema import Dataset, DatasetType, DatasetStatus, DatasetFiles, DatasetStats
8
- from bson import ObjectId
9
-
10
- class DatasetManager:
11
- def __init__(self):
12
- self.logger = config_manager.get_logger(__name__)
13
- self.minio_client = None
14
- self.bucket_name = "training-datasets"
15
-
16
- async def _ensure_minio_client(self):
17
- if not self.minio_client:
18
- self.minio_client = await config_manager.get_storage_client()
19
-
20
- async def create_dataset(
21
- self,
22
- name: str,
23
- type: DatasetType,
24
- version: str,
25
- source_annotations: List[str]
26
- ) -> Dataset:
27
- """Create a new dataset record"""
28
- db = await config_manager.get_db('mongodb')
29
- collection = db['training_datasets']
30
-
31
- dataset = Dataset(
32
- name=name,
33
- type=type,
34
- version=version,
35
- storage_path=f"datasets/{type.value}/{version}",
36
- files=DatasetFiles(
37
- train="train.jsonl",
38
- eval=None,
39
- test=None
40
- ),
41
- stats=DatasetStats(
42
- total_examples=0,
43
- avg_length=0.0,
44
- num_conversations=0,
45
- additional_metrics={}
46
- ),
47
- source_annotations=source_annotations,
48
- created_at=datetime.utcnow(),
49
- status=DatasetStatus.PENDING,
50
- metadata={}
51
- )
52
-
53
- result = await collection.insert_one(dataset.dict(exclude={'id'}))
54
- return Dataset(**{**dataset.dict(), '_id': result.inserted_id})
55
-
56
- async def upload_dataset_file(
57
- self,
58
- dataset_id: str,
59
- data: List[Dict[str, Any]],
60
- file_type: str = "train"
61
- ) -> bool:
62
- """Upload dataset to MinIO"""
63
- try:
64
- await self._ensure_minio_client()
65
- db = await config_manager.get_db('mongodb')
66
-
67
- object_id = ObjectId(dataset_id)
68
- dataset = await db['training_datasets'].find_one({"_id": object_id})
69
-
70
- if not dataset:
71
- self.logger.error(f"Dataset not found with id: {dataset_id}")
72
- return False
73
-
74
- # Convert to JSONL
75
- buffer = io.StringIO()
76
- for item in data:
77
- buffer.write(json.dumps(item) + "\n")
78
-
79
- storage_path = dataset['storage_path'].rstrip('/')
80
- file_path = f"{storage_path}/{file_type}.jsonl"
81
-
82
- buffer_value = buffer.getvalue().encode()
83
-
84
- self.logger.debug(f"Uploading to MinIO path: {file_path}")
85
-
86
- self.minio_client.put_object(
87
- self.bucket_name,
88
- file_path,
89
- io.BytesIO(buffer_value),
90
- len(buffer_value)
91
- )
92
-
93
- avg_length = sum(len(str(item)) for item in data) / len(data) if data else 0
94
-
95
- await db['training_datasets'].update_one(
96
- {"_id": object_id},
97
- {
98
- "$set": {
99
- f"files.{file_type}": f"{file_type}.jsonl",
100
- "stats.total_examples": len(data),
101
- "stats.avg_length": avg_length,
102
- "stats.num_conversations": len(data),
103
- "status": DatasetStatus.READY
104
- }
105
- }
106
- )
107
-
108
- return True
109
-
110
- except Exception as e:
111
- self.logger.error(f"Failed to upload dataset: {e}")
112
- return False
113
-
114
- async def get_dataset_info(self, dataset_id: str) -> Dict[str, Any]:
115
- """Get dataset information"""
116
- try:
117
- db = await config_manager.get_db('mongodb')
118
- object_id = ObjectId(dataset_id) # Convert string ID to ObjectId
119
- dataset = await db['training_datasets'].find_one({"_id": object_id})
120
-
121
- if not dataset:
122
- self.logger.error(f"Dataset not found with id: {dataset_id}")
123
- return None
124
-
125
- # Convert ObjectId to string for JSON serialization
126
- dataset['_id'] = str(dataset['_id'])
127
- return dataset
128
-
129
- except Exception as e:
130
- self.logger.error(f"Failed to get dataset info: {e}")
131
- return None
@@ -1,44 +0,0 @@
1
- # app/services/llm_model/annotation/dataset/dataset_schema.py
2
- from enum import Enum
3
- from pydantic import BaseModel, Field
4
- from typing import Dict, List, Optional
5
- from datetime import datetime
6
- from bson import ObjectId
7
-
8
- class DatasetType(str, Enum):
9
- SFT = "sft"
10
- RLHF = "rlhf"
11
-
12
- class DatasetStatus(str, Enum):
13
- PENDING = "pending"
14
- PROCESSING = "processing"
15
- READY = "ready"
16
- ERROR = "error"
17
-
18
- class DatasetFiles(BaseModel):
19
- train: str
20
- eval: Optional[str]
21
- test: Optional[str]
22
-
23
- class DatasetStats(BaseModel):
24
- total_examples: int
25
- avg_length: Optional[float]
26
- num_conversations: Optional[int]
27
- additional_metrics: Optional[Dict] = {}
28
-
29
- class Dataset(BaseModel):
30
- id: Optional[ObjectId] = Field(None, alias="_id")
31
- name: str
32
- type: DatasetType
33
- version: str
34
- storage_path: str
35
- files: DatasetFiles
36
- stats: DatasetStats
37
- source_annotations: List[str]
38
- created_at: datetime
39
- status: DatasetStatus
40
- metadata: Optional[Dict] = {}
41
-
42
- class Config:
43
- arbitrary_types_allowed = True
44
- populate_by_name = True