isa-model 0.3.91__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/client.py +1166 -584
- isa_model/core/cache/redis_cache.py +410 -0
- isa_model/core/config/config_manager.py +282 -12
- isa_model/core/config.py +91 -1
- isa_model/core/database/__init__.py +1 -0
- isa_model/core/database/direct_db_client.py +114 -0
- isa_model/core/database/migration_manager.py +563 -0
- isa_model/core/database/migrations.py +297 -0
- isa_model/core/database/supabase_client.py +258 -0
- isa_model/core/dependencies.py +316 -0
- isa_model/core/discovery/__init__.py +19 -0
- isa_model/core/discovery/consul_discovery.py +190 -0
- isa_model/core/logging/__init__.py +54 -0
- isa_model/core/logging/influx_logger.py +523 -0
- isa_model/core/logging/loki_logger.py +160 -0
- isa_model/core/models/__init__.py +46 -0
- isa_model/core/models/config_models.py +625 -0
- isa_model/core/models/deployment_billing_tracker.py +430 -0
- isa_model/core/models/model_billing_tracker.py +60 -88
- isa_model/core/models/model_manager.py +66 -25
- isa_model/core/models/model_metadata.py +690 -0
- isa_model/core/models/model_repo.py +217 -55
- isa_model/core/models/model_statistics_tracker.py +234 -0
- isa_model/core/models/model_storage.py +0 -1
- isa_model/core/models/model_version_manager.py +959 -0
- isa_model/core/models/system_models.py +857 -0
- isa_model/core/pricing_manager.py +2 -249
- isa_model/core/repositories/__init__.py +9 -0
- isa_model/core/repositories/config_repository.py +912 -0
- isa_model/core/resilience/circuit_breaker.py +366 -0
- isa_model/core/security/secrets.py +358 -0
- isa_model/core/services/__init__.py +2 -4
- isa_model/core/services/intelligent_model_selector.py +479 -370
- isa_model/core/storage/hf_storage.py +2 -2
- isa_model/core/types.py +8 -0
- isa_model/deployment/__init__.py +5 -48
- isa_model/deployment/core/__init__.py +2 -31
- isa_model/deployment/core/deployment_manager.py +1278 -368
- isa_model/deployment/local/__init__.py +31 -0
- isa_model/deployment/local/config.py +248 -0
- isa_model/deployment/local/gpu_gateway.py +607 -0
- isa_model/deployment/local/health_checker.py +428 -0
- isa_model/deployment/local/provider.py +586 -0
- isa_model/deployment/local/tensorrt_service.py +621 -0
- isa_model/deployment/local/transformers_service.py +644 -0
- isa_model/deployment/local/vllm_service.py +527 -0
- isa_model/deployment/modal/__init__.py +8 -0
- isa_model/deployment/modal/config.py +136 -0
- isa_model/deployment/modal/deployer.py +894 -0
- isa_model/deployment/modal/services/__init__.py +3 -0
- isa_model/deployment/modal/services/audio/__init__.py +1 -0
- isa_model/deployment/modal/services/audio/isa_audio_chatTTS_service.py +520 -0
- isa_model/deployment/modal/services/audio/isa_audio_openvoice_service.py +758 -0
- isa_model/deployment/modal/services/audio/isa_audio_service_v2.py +1044 -0
- isa_model/deployment/modal/services/embedding/__init__.py +1 -0
- isa_model/deployment/modal/services/embedding/isa_embed_rerank_service.py +296 -0
- isa_model/deployment/modal/services/llm/__init__.py +1 -0
- isa_model/deployment/modal/services/llm/isa_llm_service.py +424 -0
- isa_model/deployment/modal/services/video/__init__.py +1 -0
- isa_model/deployment/modal/services/video/isa_video_hunyuan_service.py +423 -0
- isa_model/deployment/modal/services/vision/__init__.py +1 -0
- isa_model/deployment/modal/services/vision/isa_vision_ocr_service.py +519 -0
- isa_model/deployment/modal/services/vision/isa_vision_qwen25_service.py +709 -0
- isa_model/deployment/modal/services/vision/isa_vision_table_service.py +676 -0
- isa_model/deployment/modal/services/vision/isa_vision_ui_service.py +833 -0
- isa_model/deployment/modal/services/vision/isa_vision_ui_service_optimized.py +660 -0
- isa_model/deployment/models/org-org-acme-corp-tenant-a-service-llm-20250825-225822/tenant-a-service_modal_service.py +48 -0
- isa_model/deployment/models/org-test-org-123-prefix-test-service-llm-20250825-225822/prefix-test-service_modal_service.py +48 -0
- isa_model/deployment/models/test-llm-service-llm-20250825-204442/test-llm-service_modal_service.py +48 -0
- isa_model/deployment/models/test-monitoring-gpt2-llm-20250825-212906/test-monitoring-gpt2_modal_service.py +48 -0
- isa_model/deployment/models/test-monitoring-gpt2-llm-20250825-213009/test-monitoring-gpt2_modal_service.py +48 -0
- isa_model/deployment/storage/__init__.py +5 -0
- isa_model/deployment/storage/deployment_repository.py +824 -0
- isa_model/deployment/triton/__init__.py +10 -0
- isa_model/deployment/triton/config.py +196 -0
- isa_model/deployment/triton/configs/__init__.py +1 -0
- isa_model/deployment/triton/provider.py +512 -0
- isa_model/deployment/triton/scripts/__init__.py +1 -0
- isa_model/deployment/triton/templates/__init__.py +1 -0
- isa_model/inference/__init__.py +47 -1
- isa_model/inference/ai_factory.py +179 -16
- isa_model/inference/legacy_services/__init__.py +21 -0
- isa_model/inference/legacy_services/model_evaluation.py +637 -0
- isa_model/inference/legacy_services/model_service.py +573 -0
- isa_model/inference/legacy_services/model_serving.py +717 -0
- isa_model/inference/legacy_services/model_training.py +561 -0
- isa_model/inference/models/__init__.py +21 -0
- isa_model/inference/models/inference_config.py +551 -0
- isa_model/inference/models/inference_record.py +675 -0
- isa_model/inference/models/performance_models.py +714 -0
- isa_model/inference/repositories/__init__.py +9 -0
- isa_model/inference/repositories/inference_repository.py +828 -0
- isa_model/inference/services/audio/__init__.py +21 -0
- isa_model/inference/services/audio/base_realtime_service.py +225 -0
- isa_model/inference/services/audio/base_stt_service.py +184 -11
- isa_model/inference/services/audio/isa_tts_service.py +0 -0
- isa_model/inference/services/audio/openai_realtime_service.py +320 -124
- isa_model/inference/services/audio/openai_stt_service.py +53 -11
- isa_model/inference/services/base_service.py +17 -1
- isa_model/inference/services/custom_model_manager.py +277 -0
- isa_model/inference/services/embedding/__init__.py +13 -0
- isa_model/inference/services/embedding/base_embed_service.py +111 -8
- isa_model/inference/services/embedding/isa_embed_service.py +305 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +15 -3
- isa_model/inference/services/embedding/openai_embed_service.py +2 -4
- isa_model/inference/services/embedding/resilient_embed_service.py +285 -0
- isa_model/inference/services/embedding/tests/test_embedding.py +222 -0
- isa_model/inference/services/img/__init__.py +2 -2
- isa_model/inference/services/img/base_image_gen_service.py +24 -7
- isa_model/inference/services/img/replicate_image_gen_service.py +84 -422
- isa_model/inference/services/img/services/replicate_face_swap.py +193 -0
- isa_model/inference/services/img/services/replicate_flux.py +226 -0
- isa_model/inference/services/img/services/replicate_flux_kontext.py +219 -0
- isa_model/inference/services/img/services/replicate_sticker_maker.py +249 -0
- isa_model/inference/services/img/tests/test_img_client.py +297 -0
- isa_model/inference/services/llm/__init__.py +10 -2
- isa_model/inference/services/llm/base_llm_service.py +361 -26
- isa_model/inference/services/llm/cerebras_llm_service.py +628 -0
- isa_model/inference/services/llm/helpers/llm_adapter.py +71 -12
- isa_model/inference/services/llm/helpers/llm_prompts.py +342 -0
- isa_model/inference/services/llm/helpers/llm_utils.py +321 -23
- isa_model/inference/services/llm/huggingface_llm_service.py +581 -0
- isa_model/inference/services/llm/local_llm_service.py +747 -0
- isa_model/inference/services/llm/ollama_llm_service.py +11 -3
- isa_model/inference/services/llm/openai_llm_service.py +670 -56
- isa_model/inference/services/llm/yyds_llm_service.py +10 -3
- isa_model/inference/services/vision/__init__.py +27 -6
- isa_model/inference/services/vision/base_vision_service.py +118 -185
- isa_model/inference/services/vision/blip_vision_service.py +359 -0
- isa_model/inference/services/vision/helpers/image_utils.py +19 -10
- isa_model/inference/services/vision/isa_vision_service.py +634 -0
- isa_model/inference/services/vision/openai_vision_service.py +19 -10
- isa_model/inference/services/vision/tests/test_ocr_client.py +284 -0
- isa_model/inference/services/vision/vgg16_vision_service.py +257 -0
- isa_model/serving/api/cache_manager.py +245 -0
- isa_model/serving/api/dependencies/__init__.py +1 -0
- isa_model/serving/api/dependencies/auth.py +194 -0
- isa_model/serving/api/dependencies/database.py +139 -0
- isa_model/serving/api/error_handlers.py +284 -0
- isa_model/serving/api/fastapi_server.py +240 -18
- isa_model/serving/api/middleware/auth.py +317 -0
- isa_model/serving/api/middleware/security.py +268 -0
- isa_model/serving/api/middleware/tenant_context.py +414 -0
- isa_model/serving/api/routes/analytics.py +489 -0
- isa_model/serving/api/routes/config.py +645 -0
- isa_model/serving/api/routes/deployment_billing.py +315 -0
- isa_model/serving/api/routes/deployments.py +475 -0
- isa_model/serving/api/routes/gpu_gateway.py +440 -0
- isa_model/serving/api/routes/health.py +32 -12
- isa_model/serving/api/routes/inference_monitoring.py +486 -0
- isa_model/serving/api/routes/local_deployments.py +448 -0
- isa_model/serving/api/routes/logs.py +430 -0
- isa_model/serving/api/routes/settings.py +582 -0
- isa_model/serving/api/routes/tenants.py +575 -0
- isa_model/serving/api/routes/unified.py +992 -171
- isa_model/serving/api/routes/webhooks.py +479 -0
- isa_model/serving/api/startup.py +318 -0
- isa_model/serving/modal_proxy_server.py +249 -0
- isa_model/utils/gpu_utils.py +311 -0
- {isa_model-0.3.91.dist-info → isa_model-0.4.3.dist-info}/METADATA +76 -22
- isa_model-0.4.3.dist-info/RECORD +193 -0
- isa_model/deployment/cloud/__init__.py +0 -9
- isa_model/deployment/cloud/modal/__init__.py +0 -10
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +0 -766
- isa_model/deployment/cloud/modal/isa_vision_table_service.py +0 -532
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +0 -406
- isa_model/deployment/cloud/modal/register_models.py +0 -321
- isa_model/deployment/core/deployment_config.py +0 -356
- isa_model/deployment/core/isa_deployment_service.py +0 -401
- isa_model/deployment/gpu_int8_ds8/app/server.py +0 -66
- isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +0 -43
- isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +0 -35
- isa_model/deployment/runtime/deployed_service.py +0 -338
- isa_model/deployment/services/__init__.py +0 -9
- isa_model/deployment/services/auto_deploy_vision_service.py +0 -538
- isa_model/deployment/services/model_service.py +0 -332
- isa_model/deployment/services/service_monitor.py +0 -356
- isa_model/deployment/services/service_registry.py +0 -527
- isa_model/eval/__init__.py +0 -92
- isa_model/eval/benchmarks.py +0 -469
- isa_model/eval/config/__init__.py +0 -10
- isa_model/eval/config/evaluation_config.py +0 -108
- isa_model/eval/evaluators/__init__.py +0 -18
- isa_model/eval/evaluators/base_evaluator.py +0 -503
- isa_model/eval/evaluators/llm_evaluator.py +0 -472
- isa_model/eval/factory.py +0 -531
- isa_model/eval/infrastructure/__init__.py +0 -24
- isa_model/eval/infrastructure/experiment_tracker.py +0 -466
- isa_model/eval/metrics.py +0 -798
- isa_model/inference/adapter/unified_api.py +0 -248
- isa_model/inference/services/helpers/stacked_config.py +0 -148
- isa_model/inference/services/img/flux_professional_service.py +0 -603
- isa_model/inference/services/img/helpers/base_stacked_service.py +0 -274
- isa_model/inference/services/others/table_transformer_service.py +0 -61
- isa_model/inference/services/vision/doc_analysis_service.py +0 -640
- isa_model/inference/services/vision/helpers/base_stacked_service.py +0 -274
- isa_model/inference/services/vision/ui_analysis_service.py +0 -823
- isa_model/scripts/inference_tracker.py +0 -283
- isa_model/scripts/mlflow_manager.py +0 -379
- isa_model/scripts/model_registry.py +0 -465
- isa_model/scripts/register_models.py +0 -370
- isa_model/scripts/register_models_with_embeddings.py +0 -510
- isa_model/scripts/start_mlflow.py +0 -95
- isa_model/scripts/training_tracker.py +0 -257
- isa_model/training/__init__.py +0 -74
- isa_model/training/annotation/annotation_schema.py +0 -47
- isa_model/training/annotation/processors/annotation_processor.py +0 -126
- isa_model/training/annotation/storage/dataset_manager.py +0 -131
- isa_model/training/annotation/storage/dataset_schema.py +0 -44
- isa_model/training/annotation/tests/test_annotation_flow.py +0 -109
- isa_model/training/annotation/tests/test_minio copy.py +0 -113
- isa_model/training/annotation/tests/test_minio_upload.py +0 -43
- isa_model/training/annotation/views/annotation_controller.py +0 -158
- isa_model/training/cloud/__init__.py +0 -22
- isa_model/training/cloud/job_orchestrator.py +0 -402
- isa_model/training/cloud/runpod_trainer.py +0 -454
- isa_model/training/cloud/storage_manager.py +0 -482
- isa_model/training/core/__init__.py +0 -23
- isa_model/training/core/config.py +0 -181
- isa_model/training/core/dataset.py +0 -222
- isa_model/training/core/trainer.py +0 -720
- isa_model/training/core/utils.py +0 -213
- isa_model/training/factory.py +0 -424
- isa_model-0.3.91.dist-info/RECORD +0 -138
- /isa_model/{core/storage/minio_storage.py → deployment/modal/services/audio/isa_audio_fish_service.py} +0 -0
- /isa_model/deployment/{services → modal/services/vision}/simple_auto_deploy_vision_service.py +0 -0
- {isa_model-0.3.91.dist-info → isa_model-0.4.3.dist-info}/WHEEL +0 -0
- {isa_model-0.3.91.dist-info → isa_model-0.4.3.dist-info}/top_level.txt +0 -0
@@ -1,257 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
MLflow tracker for training workflows.
|
3
|
-
"""
|
4
|
-
|
5
|
-
import os
|
6
|
-
import json
|
7
|
-
import logging
|
8
|
-
from typing import Dict, List, Optional, Any, Union
|
9
|
-
from contextlib import contextmanager
|
10
|
-
|
11
|
-
from .mlflow_manager import MLflowManager, ExperimentType
|
12
|
-
from .model_registry import ModelRegistry, ModelStage
|
13
|
-
|
14
|
-
|
15
|
-
logger = logging.getLogger(__name__)
|
16
|
-
|
17
|
-
|
18
|
-
class TrainingTracker:
|
19
|
-
"""
|
20
|
-
Tracker for model training workflows.
|
21
|
-
|
22
|
-
This class provides utilities to track model training using MLflow
|
23
|
-
and register trained models in the model registry.
|
24
|
-
|
25
|
-
Example:
|
26
|
-
```python
|
27
|
-
# Initialize tracker
|
28
|
-
tracker = TrainingTracker(
|
29
|
-
tracking_uri="http://localhost:5000",
|
30
|
-
registry_uri="http://localhost:5000"
|
31
|
-
)
|
32
|
-
|
33
|
-
# Start tracking training
|
34
|
-
with tracker.track_training_run(
|
35
|
-
model_name="llama-7b",
|
36
|
-
training_params={
|
37
|
-
"learning_rate": 2e-5,
|
38
|
-
"batch_size": 8,
|
39
|
-
"epochs": 3
|
40
|
-
}
|
41
|
-
) as run_info:
|
42
|
-
# Train the model...
|
43
|
-
|
44
|
-
# Log metrics during training
|
45
|
-
tracker.log_metrics({
|
46
|
-
"train_loss": 0.1,
|
47
|
-
"val_loss": 0.2
|
48
|
-
})
|
49
|
-
|
50
|
-
# After training completes
|
51
|
-
model_path = "/path/to/trained_model"
|
52
|
-
|
53
|
-
# Register the model
|
54
|
-
tracker.register_trained_model(
|
55
|
-
model_path=model_path,
|
56
|
-
metrics={
|
57
|
-
"accuracy": 0.95,
|
58
|
-
"f1": 0.92
|
59
|
-
},
|
60
|
-
stage=ModelStage.STAGING
|
61
|
-
)
|
62
|
-
```
|
63
|
-
"""
|
64
|
-
|
65
|
-
def __init__(
|
66
|
-
self,
|
67
|
-
tracking_uri: Optional[str] = None,
|
68
|
-
artifact_uri: Optional[str] = None,
|
69
|
-
registry_uri: Optional[str] = None
|
70
|
-
):
|
71
|
-
"""
|
72
|
-
Initialize the training tracker.
|
73
|
-
|
74
|
-
Args:
|
75
|
-
tracking_uri: URI for MLflow tracking server
|
76
|
-
artifact_uri: URI for MLflow artifacts
|
77
|
-
registry_uri: URI for MLflow model registry
|
78
|
-
"""
|
79
|
-
self.mlflow_manager = MLflowManager(
|
80
|
-
tracking_uri=tracking_uri,
|
81
|
-
artifact_uri=artifact_uri,
|
82
|
-
registry_uri=registry_uri
|
83
|
-
)
|
84
|
-
self.model_registry = ModelRegistry(
|
85
|
-
tracking_uri=tracking_uri,
|
86
|
-
registry_uri=registry_uri
|
87
|
-
)
|
88
|
-
self.current_run_info = {}
|
89
|
-
|
90
|
-
@contextmanager
|
91
|
-
def track_training_run(
|
92
|
-
self,
|
93
|
-
model_name: str,
|
94
|
-
training_params: Dict[str, Any],
|
95
|
-
description: Optional[str] = None,
|
96
|
-
tags: Optional[Dict[str, str]] = None,
|
97
|
-
experiment_type: ExperimentType = ExperimentType.TRAINING
|
98
|
-
):
|
99
|
-
"""
|
100
|
-
Track a training run with MLflow.
|
101
|
-
|
102
|
-
Args:
|
103
|
-
model_name: Name of the model being trained
|
104
|
-
training_params: Parameters for the training run
|
105
|
-
description: Description of the training run
|
106
|
-
tags: Tags for the training run
|
107
|
-
experiment_type: Type of experiment
|
108
|
-
|
109
|
-
Yields:
|
110
|
-
Dictionary with run information
|
111
|
-
"""
|
112
|
-
run_info = {
|
113
|
-
"model_name": model_name,
|
114
|
-
"params": training_params,
|
115
|
-
"metrics": {}
|
116
|
-
}
|
117
|
-
|
118
|
-
# Add description to tags if provided
|
119
|
-
if tags is None:
|
120
|
-
tags = {}
|
121
|
-
|
122
|
-
if description:
|
123
|
-
tags["description"] = description
|
124
|
-
|
125
|
-
# Start the MLflow run
|
126
|
-
with self.mlflow_manager.start_run(
|
127
|
-
experiment_type=experiment_type,
|
128
|
-
model_name=model_name,
|
129
|
-
tags=tags
|
130
|
-
) as run:
|
131
|
-
run_info["run_id"] = run.info.run_id
|
132
|
-
run_info["experiment_id"] = run.info.experiment_id
|
133
|
-
run_info["status"] = "running"
|
134
|
-
|
135
|
-
# Save parameters
|
136
|
-
self.mlflow_manager.log_params(training_params)
|
137
|
-
|
138
|
-
self.current_run_info = run_info
|
139
|
-
try:
|
140
|
-
yield run_info
|
141
|
-
# Mark as successful if no exceptions
|
142
|
-
run_info["status"] = "completed"
|
143
|
-
except Exception as e:
|
144
|
-
# Mark as failed if exception occurred
|
145
|
-
run_info["status"] = "failed"
|
146
|
-
run_info["error"] = str(e)
|
147
|
-
self.mlflow_manager.set_tracking_tag("error", str(e))
|
148
|
-
raise
|
149
|
-
finally:
|
150
|
-
self.current_run_info = {}
|
151
|
-
|
152
|
-
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
|
153
|
-
"""
|
154
|
-
Log metrics to the current run.
|
155
|
-
|
156
|
-
Args:
|
157
|
-
metrics: Dictionary of metrics to log
|
158
|
-
step: Step value for the metrics
|
159
|
-
"""
|
160
|
-
self.mlflow_manager.log_metrics(metrics, step)
|
161
|
-
|
162
|
-
if self.current_run_info:
|
163
|
-
if "metrics" not in self.current_run_info:
|
164
|
-
self.current_run_info["metrics"] = {}
|
165
|
-
|
166
|
-
# Only keep the latest metrics
|
167
|
-
self.current_run_info["metrics"].update(metrics)
|
168
|
-
|
169
|
-
def log_artifacts(self, local_dir: str, artifact_path: Optional[str] = None) -> None:
|
170
|
-
"""
|
171
|
-
Log artifacts to the current run.
|
172
|
-
|
173
|
-
Args:
|
174
|
-
local_dir: Local directory containing artifacts
|
175
|
-
artifact_path: Path for the artifacts in MLflow
|
176
|
-
"""
|
177
|
-
self.mlflow_manager.log_artifacts(local_dir, artifact_path)
|
178
|
-
|
179
|
-
def register_trained_model(
|
180
|
-
self,
|
181
|
-
model_path: str,
|
182
|
-
metrics: Optional[Dict[str, float]] = None,
|
183
|
-
description: Optional[str] = None,
|
184
|
-
tags: Optional[Dict[str, str]] = None,
|
185
|
-
stage: Optional[ModelStage] = None,
|
186
|
-
flavor: str = "pyfunc"
|
187
|
-
) -> Optional[str]:
|
188
|
-
"""
|
189
|
-
Register a trained model with MLflow.
|
190
|
-
|
191
|
-
Args:
|
192
|
-
model_path: Path to the trained model
|
193
|
-
metrics: Evaluation metrics for the model
|
194
|
-
description: Description of the model
|
195
|
-
tags: Tags for the model
|
196
|
-
stage: Stage to register the model in
|
197
|
-
flavor: MLflow model flavor
|
198
|
-
|
199
|
-
Returns:
|
200
|
-
Version of the registered model or None if registration failed
|
201
|
-
"""
|
202
|
-
if not self.current_run_info:
|
203
|
-
logger.warning("No active run. Model cannot be registered.")
|
204
|
-
return None
|
205
|
-
|
206
|
-
model_name = self.current_run_info.get("model_name")
|
207
|
-
if not model_name:
|
208
|
-
logger.warning("Model name not available in run info. Using generic name.")
|
209
|
-
model_name = "unnamed_model"
|
210
|
-
|
211
|
-
# Log final metrics if provided
|
212
|
-
if metrics:
|
213
|
-
self.log_metrics(metrics)
|
214
|
-
|
215
|
-
# Prepare model tags
|
216
|
-
if tags is None:
|
217
|
-
tags = {}
|
218
|
-
|
219
|
-
# Add run ID to tags
|
220
|
-
tags["run_id"] = self.current_run_info.get("run_id", "")
|
221
|
-
|
222
|
-
# Add metrics to tags
|
223
|
-
for k, v in self.current_run_info.get("metrics", {}).items():
|
224
|
-
tags[f"metric.{k}"] = str(v)
|
225
|
-
|
226
|
-
# Log model to MLflow
|
227
|
-
artifact_path = self.mlflow_manager.log_model(
|
228
|
-
model_path=model_path,
|
229
|
-
name=model_name,
|
230
|
-
flavor=flavor
|
231
|
-
)
|
232
|
-
|
233
|
-
if not artifact_path:
|
234
|
-
logger.error("Failed to log model to MLflow.")
|
235
|
-
return None
|
236
|
-
|
237
|
-
# Get model URI
|
238
|
-
run_id = self.current_run_info.get("run_id")
|
239
|
-
model_uri = f"runs:/{run_id}/{artifact_path}"
|
240
|
-
|
241
|
-
# Register the model
|
242
|
-
version = self.model_registry.register_model(
|
243
|
-
name=model_name,
|
244
|
-
source=model_uri,
|
245
|
-
description=description,
|
246
|
-
tags=tags
|
247
|
-
)
|
248
|
-
|
249
|
-
# Transition to the specified stage if provided
|
250
|
-
if version and stage:
|
251
|
-
self.model_registry.transition_model_version_stage(
|
252
|
-
name=model_name,
|
253
|
-
version=version,
|
254
|
-
stage=stage
|
255
|
-
)
|
256
|
-
|
257
|
-
return version
|
isa_model/training/__init__.py
DELETED
@@ -1,74 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
ISA Model Training Module
|
3
|
-
|
4
|
-
Provides unified training capabilities for AI models including:
|
5
|
-
- Local training with SFT (Supervised Fine-Tuning)
|
6
|
-
- Cloud training on RunPod
|
7
|
-
- Model evaluation and management
|
8
|
-
- HuggingFace integration
|
9
|
-
|
10
|
-
Example usage:
|
11
|
-
```python
|
12
|
-
from isa_model.training import TrainingFactory, train_gemma
|
13
|
-
|
14
|
-
# Quick Gemma training
|
15
|
-
model_path = train_gemma(
|
16
|
-
dataset_path="tatsu-lab/alpaca",
|
17
|
-
model_size="4b",
|
18
|
-
num_epochs=3
|
19
|
-
)
|
20
|
-
|
21
|
-
# Advanced training with custom configuration
|
22
|
-
factory = TrainingFactory()
|
23
|
-
model_path = factory.train_model(
|
24
|
-
model_name="google/gemma-2-4b-it",
|
25
|
-
dataset_path="your-dataset.json",
|
26
|
-
use_lora=True,
|
27
|
-
batch_size=4,
|
28
|
-
num_epochs=3
|
29
|
-
)
|
30
|
-
```
|
31
|
-
"""
|
32
|
-
|
33
|
-
# Import the new clean factory
|
34
|
-
from .factory import TrainingFactory, train_gemma
|
35
|
-
|
36
|
-
# Import core components
|
37
|
-
from .core import (
|
38
|
-
TrainingConfig,
|
39
|
-
LoRAConfig,
|
40
|
-
DatasetConfig,
|
41
|
-
BaseTrainer,
|
42
|
-
SFTTrainer,
|
43
|
-
TrainingUtils,
|
44
|
-
DatasetManager
|
45
|
-
)
|
46
|
-
|
47
|
-
# Import cloud training components
|
48
|
-
from .cloud import (
|
49
|
-
RunPodConfig,
|
50
|
-
StorageConfig,
|
51
|
-
JobConfig,
|
52
|
-
TrainingJobOrchestrator
|
53
|
-
)
|
54
|
-
|
55
|
-
__all__ = [
|
56
|
-
# Main factory
|
57
|
-
'TrainingFactory',
|
58
|
-
'train_gemma',
|
59
|
-
|
60
|
-
# Core components
|
61
|
-
'TrainingConfig',
|
62
|
-
'LoRAConfig',
|
63
|
-
'DatasetConfig',
|
64
|
-
'BaseTrainer',
|
65
|
-
'SFTTrainer',
|
66
|
-
'TrainingUtils',
|
67
|
-
'DatasetManager',
|
68
|
-
|
69
|
-
# Cloud components
|
70
|
-
'RunPodConfig',
|
71
|
-
'StorageConfig',
|
72
|
-
'JobConfig',
|
73
|
-
'TrainingJobOrchestrator'
|
74
|
-
]
|
@@ -1,47 +0,0 @@
|
|
1
|
-
# app/services/llm_model/tracing/annotation/annotation_schema.py
|
2
|
-
from enum import Enum
|
3
|
-
from pydantic import BaseModel, Field
|
4
|
-
from typing import Dict, Any, List, Optional
|
5
|
-
from datetime import datetime
|
6
|
-
|
7
|
-
class AnnotationType(str, Enum):
|
8
|
-
ACCURACY = "accuracy"
|
9
|
-
HELPFULNESS = "helpfulness"
|
10
|
-
TOXICITY = "toxicity"
|
11
|
-
CUSTOM = "custom"
|
12
|
-
|
13
|
-
class RatingScale(int, Enum):
|
14
|
-
POOR = 1
|
15
|
-
FAIR = 2
|
16
|
-
GOOD = 3
|
17
|
-
EXCELLENT = 4
|
18
|
-
|
19
|
-
class AnnotationAspects(BaseModel):
|
20
|
-
factually_correct: bool = True
|
21
|
-
relevant: bool = True
|
22
|
-
harmful: bool = False
|
23
|
-
biased: bool = False
|
24
|
-
complete: bool = True
|
25
|
-
efficient: bool = True
|
26
|
-
|
27
|
-
class BetterResponse(BaseModel):
|
28
|
-
content: str
|
29
|
-
reason: Optional[str]
|
30
|
-
metadata: Optional[Dict[str, Any]] = {}
|
31
|
-
|
32
|
-
class AnnotationFeedback(BaseModel):
|
33
|
-
rating: RatingScale
|
34
|
-
category: AnnotationType
|
35
|
-
aspects: AnnotationAspects
|
36
|
-
better_response: Optional[BetterResponse]
|
37
|
-
comment: Optional[str]
|
38
|
-
metadata: Optional[Dict[str, Any]] = {}
|
39
|
-
is_selected_for_training: bool = False
|
40
|
-
|
41
|
-
class ItemAnnotation(BaseModel):
|
42
|
-
item_id: str
|
43
|
-
feedback: Optional[AnnotationFeedback]
|
44
|
-
status: str = "pending"
|
45
|
-
annotated_at: Optional[datetime]
|
46
|
-
annotator_id: Optional[str]
|
47
|
-
training_status: Optional[str] = None
|
@@ -1,126 +0,0 @@
|
|
1
|
-
from typing import Dict, Any, List
|
2
|
-
from datetime import datetime
|
3
|
-
from app.config.config_manager import config_manager
|
4
|
-
from app.services.training.llm_model.annotation.annotation_schema import AnnotationFeedback, RatingScale, AnnotationAspects
|
5
|
-
from bson.objectid import ObjectId
|
6
|
-
from app.services.training.llm_model.annotation.storage.dataset_manager import DatasetManager
|
7
|
-
|
8
|
-
class AnnotationProcessor:
|
9
|
-
def __init__(self):
|
10
|
-
self.logger = config_manager.get_logger(__name__)
|
11
|
-
self.dataset_manager = DatasetManager()
|
12
|
-
self.batch_size = 1000 # Configure as needed
|
13
|
-
|
14
|
-
async def process_queue(self) -> None:
|
15
|
-
"""Process pending items and create datasets"""
|
16
|
-
db = await config_manager.get_db('mongodb')
|
17
|
-
queue = db['training_queue']
|
18
|
-
|
19
|
-
# Process SFT items
|
20
|
-
sft_items = await self._get_pending_items("sft")
|
21
|
-
if len(sft_items) >= self.batch_size:
|
22
|
-
await self._create_sft_dataset(sft_items)
|
23
|
-
|
24
|
-
# Process RLHF items
|
25
|
-
rlhf_items = await self._get_pending_items("rlhf")
|
26
|
-
if len(rlhf_items) >= self.batch_size:
|
27
|
-
await self._create_rlhf_dataset(rlhf_items)
|
28
|
-
|
29
|
-
async def _create_sft_dataset(self, items: List[Dict[str, Any]]):
|
30
|
-
"""Create and upload SFT dataset"""
|
31
|
-
dataset = await self.dataset_manager.create_dataset(
|
32
|
-
name=f"sft_dataset_v{datetime.now().strftime('%Y%m%d')}",
|
33
|
-
type="sft",
|
34
|
-
version=datetime.now().strftime("%Y%m%d"),
|
35
|
-
source_annotations=[item["annotation_id"] for item in items]
|
36
|
-
)
|
37
|
-
|
38
|
-
formatted_data = [
|
39
|
-
await self._process_sft_item(item)
|
40
|
-
for item in items
|
41
|
-
]
|
42
|
-
|
43
|
-
await self.dataset_manager.upload_dataset_file(
|
44
|
-
dataset.id,
|
45
|
-
formatted_data
|
46
|
-
)
|
47
|
-
|
48
|
-
async def _process_sft_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
|
49
|
-
"""Process item for SFT dataset generation
|
50
|
-
Format follows HF conversation format for SFT training
|
51
|
-
"""
|
52
|
-
db = await config_manager.get_db('mongodb')
|
53
|
-
annotations = db['annotations']
|
54
|
-
|
55
|
-
# Get full annotation context
|
56
|
-
annotation = await annotations.find_one({"_id": ObjectId(item["annotation_id"])})
|
57
|
-
target_item = next(i for i in annotation["items"] if i["item_id"] == item["item_id"])
|
58
|
-
|
59
|
-
# Format as conversation
|
60
|
-
messages = [
|
61
|
-
{
|
62
|
-
"role": "system",
|
63
|
-
"content": "You are a helpful AI assistant that provides accurate and relevant information."
|
64
|
-
},
|
65
|
-
{
|
66
|
-
"role": "user",
|
67
|
-
"content": target_item["input"]["messages"][0]["content"]
|
68
|
-
},
|
69
|
-
{
|
70
|
-
"role": "assistant",
|
71
|
-
"content": target_item["output"]["content"]
|
72
|
-
}
|
73
|
-
]
|
74
|
-
|
75
|
-
return {
|
76
|
-
"messages": messages,
|
77
|
-
"metadata": {
|
78
|
-
"rating": item["feedback"]["rating"],
|
79
|
-
"aspects": item["feedback"]["aspects"],
|
80
|
-
"category": item["feedback"]["category"]
|
81
|
-
}
|
82
|
-
}
|
83
|
-
|
84
|
-
async def _process_rlhf_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
|
85
|
-
"""Process item for RLHF dataset generation
|
86
|
-
Format follows preference pairs structure for RLHF training
|
87
|
-
"""
|
88
|
-
db = await config_manager.get_db('mongodb')
|
89
|
-
annotations = db['annotations']
|
90
|
-
|
91
|
-
# Get full annotation context
|
92
|
-
annotation = await annotations.find_one({"_id": ObjectId(item["annotation_id"])})
|
93
|
-
target_item = next(i for i in annotation["items"] if i["item_id"] == item["item_id"])
|
94
|
-
|
95
|
-
# Format as preference pairs
|
96
|
-
return {
|
97
|
-
"prompt": target_item["input"]["messages"][0]["content"],
|
98
|
-
"chosen": item["feedback"]["better_response"]["content"],
|
99
|
-
"rejected": target_item["output"]["content"],
|
100
|
-
"metadata": {
|
101
|
-
"reason": item["feedback"]["better_response"]["reason"],
|
102
|
-
"category": item["feedback"]["category"]
|
103
|
-
}
|
104
|
-
}
|
105
|
-
|
106
|
-
async def get_training_data(
|
107
|
-
self,
|
108
|
-
data_type: str,
|
109
|
-
limit: int = 1000
|
110
|
-
) -> List[Dict[str, Any]]:
|
111
|
-
"""Retrieve formatted training data"""
|
112
|
-
db = await config_manager.get_db('mongodb')
|
113
|
-
training_data = db['training_data']
|
114
|
-
|
115
|
-
data = await training_data.find(
|
116
|
-
{"type": data_type}
|
117
|
-
).limit(limit).to_list(length=limit)
|
118
|
-
|
119
|
-
if data_type == "sft":
|
120
|
-
return [item["data"]["messages"] for item in data]
|
121
|
-
else: # rlhf
|
122
|
-
return [{
|
123
|
-
"prompt": item["data"]["prompt"],
|
124
|
-
"chosen": item["data"]["chosen"],
|
125
|
-
"rejected": item["data"]["rejected"]
|
126
|
-
} for item in data]
|
@@ -1,131 +0,0 @@
|
|
1
|
-
# app/services/llm_model/annotation/dataset/dataset_manager.py
|
2
|
-
from typing import Dict, Any, List
|
3
|
-
from datetime import datetime
|
4
|
-
import json
|
5
|
-
import io
|
6
|
-
from app.config.config_manager import config_manager
|
7
|
-
from .dataset_schema import Dataset, DatasetType, DatasetStatus, DatasetFiles, DatasetStats
|
8
|
-
from bson import ObjectId
|
9
|
-
|
10
|
-
class DatasetManager:
    """Create training-dataset records and upload their JSONL files.

    Records live in the ``training_datasets`` collection of the database
    returned by ``config_manager.get_db('mongodb')``. File contents are
    written through the storage client returned by
    ``config_manager.get_storage_client()`` into ``self.bucket_name``.
    """

    def __init__(self):
        self.bucket_name = "training-datasets"
        # The storage client is fetched lazily by _ensure_minio_client().
        self.minio_client = None
        self.logger = config_manager.get_logger(__name__)

    async def _ensure_minio_client(self):
        """Fetch the storage client on first use and keep it on the instance."""
        if self.minio_client:
            return
        self.minio_client = await config_manager.get_storage_client()

    async def create_dataset(
        self,
        name: str,
        type: DatasetType,
        version: str,
        source_annotations: List[str]
    ) -> Dataset:
        """Insert a new, empty dataset record and return it.

        The record starts in ``DatasetStatus.PENDING`` with zeroed stats, a
        ``train.jsonl`` file entry and a storage path of
        ``datasets/<type value>/<version>``.

        Args:
            name: Name stored on the record.
            type: Dataset type; its ``value`` becomes part of the storage path.
            version: Version string; also part of the storage path.
            source_annotations: Stored unchanged on the record.

        Returns:
            A ``Dataset`` carrying the inserted document's ``_id``.
        """
        database = await config_manager.get_db('mongodb')
        datasets = database['training_datasets']

        storage_path = f"datasets/{type.value}/{version}"
        initial_files = DatasetFiles(train="train.jsonl", eval=None, test=None)
        initial_stats = DatasetStats(
            total_examples=0,
            avg_length=0.0,
            num_conversations=0,
            additional_metrics={},
        )
        record = Dataset(
            name=name,
            type=type,
            version=version,
            storage_path=storage_path,
            files=initial_files,
            stats=initial_stats,
            source_annotations=source_annotations,
            created_at=datetime.utcnow(),
            status=DatasetStatus.PENDING,
            metadata={},
        )

        # The id is excluded on insert so the database assigns one.
        insert_result = await datasets.insert_one(record.dict(exclude={'id'}))

        fields = record.dict()
        fields['_id'] = insert_result.inserted_id
        return Dataset(**fields)

    async def upload_dataset_file(
        self,
        dataset_id: str,
        data: List[Dict[str, Any]],
        file_type: str = "train"
    ) -> bool:
        """Serialize ``data`` as JSONL, upload it, and mark the dataset ready.

        The object is written to ``<storage_path>/<file_type>.jsonl`` in
        ``self.bucket_name``. Afterwards the record's file entry, stats and
        status are updated.

        Args:
            dataset_id: String form of the record's ObjectId.
            data: Items to write, one JSON document per line.
            file_type: Name used for the file and for the ``files.<file_type>``
                field on the record.

        Returns:
            True on success. False if the record does not exist or any step
            raised; the failure is logged rather than propagated.
        """
        try:
            await self._ensure_minio_client()
            database = await config_manager.get_db('mongodb')

            object_id = ObjectId(dataset_id)
            datasets = database['training_datasets']
            record = await datasets.find_one({"_id": object_id})
            if not record:
                self.logger.error(f"Dataset not found with id: {dataset_id}")
                return False

            # One JSON document per line, each line newline-terminated.
            jsonl_text = "".join(json.dumps(item) + "\n" for item in data)

            base_path = record['storage_path'].rstrip('/')
            file_path = f"{base_path}/{file_type}.jsonl"
            payload = jsonl_text.encode()

            self.logger.debug(f"Uploading to MinIO path: {file_path}")
            self.minio_client.put_object(
                self.bucket_name,
                file_path,
                io.BytesIO(payload),
                len(payload),
            )

            # Average length of str(item), not of the serialized JSON line.
            if data:
                avg_length = sum(len(str(item)) for item in data) / len(data)
            else:
                avg_length = 0

            updates = {
                f"files.{file_type}": f"{file_type}.jsonl",
                "stats.total_examples": len(data),
                "stats.avg_length": avg_length,
                "stats.num_conversations": len(data),
                "status": DatasetStatus.READY,
            }
            await datasets.update_one({"_id": object_id}, {"$set": updates})
            return True

        except Exception as e:
            self.logger.error(f"Failed to upload dataset: {e}")
            return False

    async def get_dataset_info(self, dataset_id: str) -> Dict[str, Any]:
        """Return the stored dataset document with ``_id`` as a string.

        Args:
            dataset_id: String form of the record's ObjectId.

        Returns:
            The document as a dict. Despite the annotation, None is returned
            when the record is missing or the lookup raises; both cases are
            logged.
        """
        try:
            database = await config_manager.get_db('mongodb')
            object_id = ObjectId(dataset_id)
            record = await database['training_datasets'].find_one({"_id": object_id})
            if not record:
                self.logger.error(f"Dataset not found with id: {dataset_id}")
                return None

            # ObjectId is stringified so the result can be JSON-serialized.
            record['_id'] = str(record['_id'])
            return record

        except Exception as e:
            self.logger.error(f"Failed to get dataset info: {e}")
            return None
|
@@ -1,44 +0,0 @@
|
|
1
|
-
# app/services/llm_model/annotation/dataset/dataset_schema.py
|
2
|
-
from enum import Enum
|
3
|
-
from pydantic import BaseModel, Field
|
4
|
-
from typing import Dict, List, Optional
|
5
|
-
from datetime import datetime
|
6
|
-
from bson import ObjectId
|
7
|
-
|
8
|
-
class DatasetType(str, Enum):
    """Kind of training dataset; members are also plain strings."""

    # The string values are what gets stored and compared.
    SFT = "sft"
    RLHF = "rlhf"
|
11
|
-
|
12
|
-
class DatasetStatus(str, Enum):
    """Lifecycle state of a dataset record; members are also plain strings."""

    # Declaration order is kept as in the original definition.
    PENDING = "pending"
    PROCESSING = "processing"
    READY = "ready"
    ERROR = "error"
|
17
|
-
|
18
|
-
class DatasetFiles(BaseModel):
    """File names belonging to one dataset, one per split."""

    # The training split is always required.
    train: str
    # Explicit None defaults: under pydantic v2 an Optional annotation alone
    # leaves the field required, so documents lacking these keys would fail
    # validation. Under pydantic v1 the behaviour is unchanged.
    eval: Optional[str] = None
    test: Optional[str] = None
|
22
|
-
|
23
|
-
class DatasetStats(BaseModel):
    """Summary statistics stored on a dataset record."""

    total_examples: int
    # Explicit None defaults: under pydantic v2 an Optional annotation alone
    # leaves the field required. Under pydantic v1 the behaviour is unchanged.
    avg_length: Optional[float] = None
    num_conversations: Optional[int] = None
    # Free-form extra metrics; defaults to an empty dict.
    additional_metrics: Optional[Dict] = {}
|
28
|
-
|
29
|
-
class Dataset(BaseModel):
    """A training-dataset record.

    ``id`` is exposed under the alias ``_id``. All other fields describe the
    dataset's identity, storage location, files, statistics and status.
    """

    # Document identifier; None until one has been assigned.
    id: Optional[ObjectId] = Field(None, alias="_id")

    # Identity
    name: str
    type: DatasetType
    version: str

    # Storage location and the files found there
    storage_path: str
    files: DatasetFiles

    # Statistics and provenance
    stats: DatasetStats
    source_annotations: List[str]
    created_at: datetime

    # Lifecycle state and free-form extras
    status: DatasetStatus
    metadata: Optional[Dict] = {}

    class Config:
        # ObjectId is not a pydantic-native type, so arbitrary types are allowed.
        arbitrary_types_allowed = True
        # Lets the model be populated by field name ("id") as well as the alias.
        populate_by_name = True
|