isa_model-0.1.0-py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- isa_model/__init__.py +5 -0
- isa_model/core/model_manager.py +143 -0
- isa_model/core/model_registry.py +115 -0
- isa_model/core/model_router.py +226 -0
- isa_model/core/model_storage.py +133 -0
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +202 -0
- isa_model/core/storage/hf_storage.py +0 -0
- isa_model/core/storage/local_storage.py +0 -0
- isa_model/core/storage/minio_storage.py +0 -0
- isa_model/deployment/mlflow_gateway/__init__.py +8 -0
- isa_model/deployment/mlflow_gateway/start_gateway.py +65 -0
- isa_model/deployment/unified_multimodal_client.py +341 -0
- isa_model/inference/__init__.py +11 -0
- isa_model/inference/adapter/triton_adapter.py +453 -0
- isa_model/inference/adapter/unified_api.py +248 -0
- isa_model/inference/ai_factory.py +354 -0
- isa_model/inference/backends/Pytorch/bge_embed_backend.py +188 -0
- isa_model/inference/backends/Pytorch/gemma_backend.py +167 -0
- isa_model/inference/backends/Pytorch/llama_backend.py +166 -0
- isa_model/inference/backends/Pytorch/whisper_backend.py +194 -0
- isa_model/inference/backends/__init__.py +53 -0
- isa_model/inference/backends/base_backend_client.py +26 -0
- isa_model/inference/backends/container_services.py +104 -0
- isa_model/inference/backends/local_services.py +72 -0
- isa_model/inference/backends/openai_client.py +130 -0
- isa_model/inference/backends/replicate_client.py +197 -0
- isa_model/inference/backends/third_party_services.py +239 -0
- isa_model/inference/backends/triton_client.py +97 -0
- isa_model/inference/base.py +46 -0
- isa_model/inference/client_sdk/__init__.py +0 -0
- isa_model/inference/client_sdk/client.py +134 -0
- isa_model/inference/client_sdk/client_data_std.py +34 -0
- isa_model/inference/client_sdk/client_sdk_schema.py +16 -0
- isa_model/inference/client_sdk/exceptions.py +0 -0
- isa_model/inference/engine/triton/model_repository/bge/1/model.py +174 -0
- isa_model/inference/engine/triton/model_repository/gemma/1/model.py +250 -0
- isa_model/inference/engine/triton/model_repository/llama/1/model.py +76 -0
- isa_model/inference/engine/triton/model_repository/whisper/1/model.py +195 -0
- isa_model/inference/providers/__init__.py +19 -0
- isa_model/inference/providers/base_provider.py +30 -0
- isa_model/inference/providers/model_cache_manager.py +341 -0
- isa_model/inference/providers/ollama_provider.py +73 -0
- isa_model/inference/providers/openai_provider.py +87 -0
- isa_model/inference/providers/replicate_provider.py +94 -0
- isa_model/inference/providers/triton_provider.py +439 -0
- isa_model/inference/providers/vllm_provider.py +0 -0
- isa_model/inference/providers/yyds_provider.py +83 -0
- isa_model/inference/services/__init__.py +14 -0
- isa_model/inference/services/audio/fish_speech/handler.py +215 -0
- isa_model/inference/services/audio/runpod_tts_fish_service.py +212 -0
- isa_model/inference/services/audio/triton_speech_service.py +138 -0
- isa_model/inference/services/audio/whisper_service.py +186 -0
- isa_model/inference/services/audio/yyds_audio_service.py +71 -0
- isa_model/inference/services/base_service.py +106 -0
- isa_model/inference/services/base_tts_service.py +66 -0
- isa_model/inference/services/embedding/bge_service.py +183 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +85 -0
- isa_model/inference/services/embedding/ollama_rerank_service.py +118 -0
- isa_model/inference/services/embedding/onnx_rerank_service.py +73 -0
- isa_model/inference/services/llm/__init__.py +16 -0
- isa_model/inference/services/llm/gemma_service.py +143 -0
- isa_model/inference/services/llm/llama_service.py +143 -0
- isa_model/inference/services/llm/ollama_llm_service.py +108 -0
- isa_model/inference/services/llm/openai_llm_service.py +129 -0
- isa_model/inference/services/llm/replicate_llm_service.py +179 -0
- isa_model/inference/services/llm/triton_llm_service.py +230 -0
- isa_model/inference/services/others/table_transformer_service.py +61 -0
- isa_model/inference/services/vision/__init__.py +12 -0
- isa_model/inference/services/vision/helpers/image_utils.py +58 -0
- isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
- isa_model/inference/services/vision/ollama_vision_service.py +60 -0
- isa_model/inference/services/vision/replicate_vision_service.py +241 -0
- isa_model/inference/services/vision/triton_vision_service.py +199 -0
- isa_model/inference/services/vision/yyds_vision_service.py +80 -0
- isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
- isa_model/inference/utils/conversion/onnx_converter.py +0 -0
- isa_model/inference/utils/conversion/torch_converter.py +0 -0
- isa_model/scripts/inference_tracker.py +283 -0
- isa_model/scripts/mlflow_manager.py +379 -0
- isa_model/scripts/model_registry.py +465 -0
- isa_model/scripts/start_mlflow.py +95 -0
- isa_model/scripts/training_tracker.py +257 -0
- isa_model/training/engine/llama_factory/__init__.py +39 -0
- isa_model/training/engine/llama_factory/config.py +115 -0
- isa_model/training/engine/llama_factory/data_adapter.py +284 -0
- isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
- isa_model/training/engine/llama_factory/factory.py +331 -0
- isa_model/training/engine/llama_factory/rl.py +254 -0
- isa_model/training/engine/llama_factory/trainer.py +171 -0
- isa_model/training/image_model/configs/create_config.py +37 -0
- isa_model/training/image_model/configs/create_flux_config.py +26 -0
- isa_model/training/image_model/configs/create_lora_config.py +21 -0
- isa_model/training/image_model/prepare_massed_compute.py +97 -0
- isa_model/training/image_model/prepare_upload.py +17 -0
- isa_model/training/image_model/raw_data/create_captions.py +16 -0
- isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
- isa_model/training/image_model/raw_data/pre_processing.py +200 -0
- isa_model/training/image_model/train/train.py +42 -0
- isa_model/training/image_model/train/train_flux.py +41 -0
- isa_model/training/image_model/train/train_lora.py +57 -0
- isa_model/training/image_model/train_main.py +25 -0
- isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
- isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
- isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
- isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
- isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
- isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
- isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
- isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
- isa_model-0.1.0.dist-info/METADATA +116 -0
- isa_model-0.1.0.dist-info/RECORD +117 -0
- isa_model-0.1.0.dist-info/WHEEL +5 -0
- isa_model-0.1.0.dist-info/licenses/LICENSE +21 -0
- isa_model-0.1.0.dist-info/top_level.txt +1 -0
isa_model/scripts/training_tracker.py
@@ -0,0 +1,257 @@
"""
MLflow tracker for training workflows.
"""

import os
import json
import logging
from typing import Dict, List, Optional, Any, Union
from contextlib import contextmanager

from .mlflow_manager import MLflowManager, ExperimentType
from .model_registry import ModelRegistry, ModelStage


logger = logging.getLogger(__name__)


class TrainingTracker:
    """
    Tracker for model training workflows.

    This class provides utilities to track model training using MLflow
    and register trained models in the model registry.

    Example:
        ```python
        # Initialize tracker
        tracker = TrainingTracker(
            tracking_uri="http://localhost:5000",
            registry_uri="http://localhost:5000"
        )

        # Start tracking training
        with tracker.track_training_run(
            model_name="llama-7b",
            training_params={
                "learning_rate": 2e-5,
                "batch_size": 8,
                "epochs": 3
            }
        ) as run_info:
            # Train the model...

            # Log metrics during training
            tracker.log_metrics({
                "train_loss": 0.1,
                "val_loss": 0.2
            })

            # After training completes
            model_path = "/path/to/trained_model"

            # Register the model
            tracker.register_trained_model(
                model_path=model_path,
                metrics={
                    "accuracy": 0.95,
                    "f1": 0.92
                },
                stage=ModelStage.STAGING
            )
        ```
    """

    def __init__(
        self,
        tracking_uri: Optional[str] = None,
        artifact_uri: Optional[str] = None,
        registry_uri: Optional[str] = None
    ):
        """
        Initialize the training tracker.

        Args:
            tracking_uri: URI for MLflow tracking server
            artifact_uri: URI for MLflow artifacts
            registry_uri: URI for MLflow model registry
        """
        self.mlflow_manager = MLflowManager(
            tracking_uri=tracking_uri,
            artifact_uri=artifact_uri,
            registry_uri=registry_uri
        )
        self.model_registry = ModelRegistry(
            tracking_uri=tracking_uri,
            registry_uri=registry_uri
        )
        self.current_run_info = {}

    @contextmanager
    def track_training_run(
        self,
        model_name: str,
        training_params: Dict[str, Any],
        description: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        experiment_type: ExperimentType = ExperimentType.TRAINING
    ):
        """
        Track a training run with MLflow.

        Args:
            model_name: Name of the model being trained
            training_params: Parameters for the training run
            description: Description of the training run
            tags: Tags for the training run
            experiment_type: Type of experiment

        Yields:
            Dictionary with run information
        """
        run_info = {
            "model_name": model_name,
            "params": training_params,
            "metrics": {}
        }

        # Add description to tags if provided
        if tags is None:
            tags = {}

        if description:
            tags["description"] = description

        # Start the MLflow run
        with self.mlflow_manager.start_run(
            experiment_type=experiment_type,
            model_name=model_name,
            tags=tags
        ) as run:
            run_info["run_id"] = run.info.run_id
            run_info["experiment_id"] = run.info.experiment_id
            run_info["status"] = "running"

            # Save parameters
            self.mlflow_manager.log_params(training_params)

            self.current_run_info = run_info
            try:
                yield run_info
                # Mark as successful if no exceptions
                run_info["status"] = "completed"
            except Exception as e:
                # Mark as failed if exception occurred
                run_info["status"] = "failed"
                run_info["error"] = str(e)
                self.mlflow_manager.set_tracking_tag("error", str(e))
                raise
            finally:
                self.current_run_info = {}

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        """
        Log metrics to the current run.

        Args:
            metrics: Dictionary of metrics to log
            step: Step value for the metrics
        """
        self.mlflow_manager.log_metrics(metrics, step)

        if self.current_run_info:
            if "metrics" not in self.current_run_info:
                self.current_run_info["metrics"] = {}

            # Only keep the latest metrics
            self.current_run_info["metrics"].update(metrics)

    def log_artifacts(self, local_dir: str, artifact_path: Optional[str] = None) -> None:
        """
        Log artifacts to the current run.

        Args:
            local_dir: Local directory containing artifacts
            artifact_path: Path for the artifacts in MLflow
        """
        self.mlflow_manager.log_artifacts(local_dir, artifact_path)

    def register_trained_model(
        self,
        model_path: str,
        metrics: Optional[Dict[str, float]] = None,
        description: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        stage: Optional[ModelStage] = None,
        flavor: str = "pyfunc"
    ) -> Optional[str]:
        """
        Register a trained model with MLflow.

        Args:
            model_path: Path to the trained model
            metrics: Evaluation metrics for the model
            description: Description of the model
            tags: Tags for the model
            stage: Stage to register the model in
            flavor: MLflow model flavor

        Returns:
            Version of the registered model or None if registration failed
        """
        if not self.current_run_info:
            logger.warning("No active run. Model cannot be registered.")
            return None

        model_name = self.current_run_info.get("model_name")
        if not model_name:
            logger.warning("Model name not available in run info. Using generic name.")
            model_name = "unnamed_model"

        # Log final metrics if provided
        if metrics:
            self.log_metrics(metrics)

        # Prepare model tags
        if tags is None:
            tags = {}

        # Add run ID to tags
        tags["run_id"] = self.current_run_info.get("run_id", "")

        # Add metrics to tags
        for k, v in self.current_run_info.get("metrics", {}).items():
            tags[f"metric.{k}"] = str(v)

        # Log model to MLflow
        artifact_path = self.mlflow_manager.log_model(
            model_path=model_path,
            name=model_name,
            flavor=flavor
        )

        if not artifact_path:
            logger.error("Failed to log model to MLflow.")
            return None

        # Get model URI
        run_id = self.current_run_info.get("run_id")
        model_uri = f"runs:/{run_id}/{artifact_path}"

        # Register the model
        version = self.model_registry.register_model(
            name=model_name,
            source=model_uri,
            description=description,
            tags=tags
        )

        # Transition to the specified stage if provided
        if version and stage:
            self.model_registry.transition_model_version_stage(
                name=model_name,
                version=version,
                stage=stage
            )

        return version
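A note on the failure path in track_training_run above: if the body raises, the tracker marks the run failed, records the error text both on the yielded dict and as an MLflow tag, re-raises, and clears current_run_info in the finally block. A minimal sketch of that behavior, assuming the wheel is installed and an MLflow server is reachable at the illustrative URI:

from isa_model.scripts.training_tracker import TrainingTracker

tracker = TrainingTracker(tracking_uri="http://localhost:5000")

try:
    with tracker.track_training_run(
        model_name="llama-7b",
        training_params={"learning_rate": 2e-5},
    ) as run_info:
        raise RuntimeError("simulated training failure")
except RuntimeError:
    pass

# The yielded dict outlives the context manager, so the caller can still
# inspect the final status; the tracker's own snapshot is cleared on exit.
assert run_info["status"] == "failed"
assert run_info["error"] == "simulated training failure"
assert tracker.current_run_info == {}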
isa_model/training/engine/llama_factory/__init__.py
@@ -0,0 +1,39 @@
"""
LlamaFactory engine for fine-tuning and reinforcement learning.

This package provides interfaces for using LlamaFactory to:
- Fine-tune models with various datasets
- Perform reinforcement learning from human feedback (RLHF)
- Support instruction tuning and preference optimization
"""

from .config import (
    LlamaFactoryConfig,
    SFTConfig,
    RLConfig,
    DPOConfig,
    TrainingStrategy,
    DatasetFormat,
    create_default_config
)
from .trainer import LlamaFactoryTrainer
from .rl import LlamaFactoryRL
from .factory import LlamaFactory
from .data_adapter import DataAdapter, AlpacaAdapter, ShareGPTAdapter, DataAdapterFactory

__all__ = [
    "LlamaFactoryTrainer",
    "LlamaFactoryRL",
    "LlamaFactoryConfig",
    "SFTConfig",
    "RLConfig",
    "DPOConfig",
    "TrainingStrategy",
    "DatasetFormat",
    "create_default_config",
    "LlamaFactory",
    "DataAdapter",
    "AlpacaAdapter",
    "ShareGPTAdapter",
    "DataAdapterFactory"
]
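Because this __init__ re-exports the public names, the subpackage root is the intended import surface; a one-line sketch, assuming the wheel is installed:

# All names listed in __all__ resolve from the subpackage root, so callers
# need not know which submodule defines each class.
from isa_model.training.engine.llama_factory import (
    LlamaFactory,
    SFTConfig,
    DataAdapterFactory,
)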
isa_model/training/engine/llama_factory/config.py
@@ -0,0 +1,115 @@
"""
Configuration classes for LlamaFactory training and RL.
"""

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
from enum import Enum
import os
from pathlib import Path


class TrainingStrategy(str, Enum):
    """Training strategies supported by LlamaFactory."""

    SUPERVISED_FINETUNING = "sft"
    REINFORCEMENT_LEARNING = "rl"
    PREFERENCE_OPTIMIZATION = "dpo"
    PREFERENCE_PAIRWISE = "ppo"


class DatasetFormat(str, Enum):
    """Dataset formats supported by LlamaFactory."""

    ALPACA = "alpaca"
    SHAREGPT = "sharegpt"
    CUSTOM = "custom"


@dataclass
class LlamaFactoryConfig:
    """Base configuration for LlamaFactory trainers."""

    model_path: str
    output_dir: str = field(default_factory=lambda: os.path.join(os.getcwd(), "outputs"))

    # Training parameters
    batch_size: int = 8
    num_epochs: int = 3
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    lr_scheduler_type: str = "cosine"
    warmup_ratio: float = 0.1

    # Model parameters
    lora_rank: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    use_lora: bool = True

    # Data processing
    max_length: int = 1024
    dataset_format: DatasetFormat = DatasetFormat.ALPACA

    # Logging
    log_with: str = "mlflow"
    logging_steps: int = 10

    def to_dict(self) -> Dict:
        """Convert config to dictionary for LlamaFactory CLI."""
        return {k: v.value if isinstance(v, Enum) else v
                for k, v in self.__dict__.items()}


@dataclass
class SFTConfig(LlamaFactoryConfig):
    """Configuration for supervised fine-tuning."""

    strategy: TrainingStrategy = TrainingStrategy.SUPERVISED_FINETUNING
    train_file: str = ""
    val_file: Optional[str] = None

    # SFT specific settings
    cutoff_len: int = 1024
    normalize_data: bool = True


@dataclass
class RLConfig(LlamaFactoryConfig):
    """Configuration for reinforcement learning."""

    strategy: TrainingStrategy = TrainingStrategy.REINFORCEMENT_LEARNING
    reward_model: str = ""
    train_file: str = ""

    # RL specific settings
    kl_coef: float = 0.1
    top_k: int = 0
    top_p: float = 1.0
    temperature: float = 1.0


@dataclass
class DPOConfig(LlamaFactoryConfig):
    """Configuration for direct preference optimization."""

    strategy: TrainingStrategy = TrainingStrategy.PREFERENCE_OPTIMIZATION
    train_file: str = ""
    val_file: Optional[str] = None

    # DPO specific settings
    beta: float = 0.1
    reference_model: Optional[str] = None


def create_default_config(strategy: TrainingStrategy, model_path: str) -> LlamaFactoryConfig:
    """Create a default configuration based on the training strategy."""

    if strategy == TrainingStrategy.SUPERVISED_FINETUNING:
        return SFTConfig(model_path=model_path)
    elif strategy == TrainingStrategy.REINFORCEMENT_LEARNING:
        return RLConfig(model_path=model_path)
    elif strategy == TrainingStrategy.PREFERENCE_OPTIMIZATION:
        return DPOConfig(model_path=model_path)
    else:
        raise ValueError(f"Unsupported strategy: {strategy}")
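Two details of this config module are worth noting: create_default_config covers three of the four strategies (PREFERENCE_PAIRWISE falls through to the ValueError branch), and to_dict flattens Enum fields to their string values for the LlamaFactory CLI. A short sketch of the round trip, with an illustrative model path:

from isa_model.training.engine.llama_factory.config import (
    TrainingStrategy,
    create_default_config,
)

config = create_default_config(
    strategy=TrainingStrategy.PREFERENCE_OPTIMIZATION,
    model_path="/models/base",  # placeholder path
)

args = config.to_dict()
assert args["strategy"] == "dpo"           # Enum members flatten to their values
assert args["dataset_format"] == "alpaca"  # inherited default, also flattened
assert args["beta"] == 0.1                 # plain fields pass through unchanged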
isa_model/training/engine/llama_factory/data_adapter.py
@@ -0,0 +1,284 @@
"""
Data adapters for LlamaFactory training.
"""

import os
import json
import logging
from typing import Dict, List, Any, Optional, Union

from .config import DatasetFormat

logger = logging.getLogger(__name__)


class DataAdapter:
    """
    Base class for adapting datasets to LlamaFactory format.

    This class handles converting data from various formats into
    the specific format expected by LlamaFactory.

    Example:
        ```python
        # Create a custom adapter
        adapter = CustomAdapter()

        # Convert data
        converted_path = adapter.convert_data(
            input_file="path/to/source_data.json",
            output_file="path/to/converted_data.json"
        )
        ```
    """

    def __init__(self, format_type: DatasetFormat = DatasetFormat.ALPACA):
        """
        Initialize the data adapter.

        Args:
            format_type: The target format to convert to
        """
        self.format_type = format_type

    def convert_data(self, input_file: str, output_file: Optional[str] = None) -> str:
        """
        Convert data from input format to LlamaFactory format.

        Args:
            input_file: Path to the input data file
            output_file: Path to save the converted data

        Returns:
            Path to the converted data file
        """
        if output_file is None:
            dirname = os.path.dirname(input_file)
            basename = os.path.basename(input_file)
            name, ext = os.path.splitext(basename)
            output_file = os.path.join(dirname, f"{name}_converted{ext}")

        data = self._load_data(input_file)
        converted_data = self._convert_data(data)
        self._save_data(converted_data, output_file)

        return output_file

    def _load_data(self, input_file: str) -> Any:
        """
        Load data from the input file.

        Args:
            input_file: Path to the input file

        Returns:
            Loaded data
        """
        with open(input_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _save_data(self, data: Any, output_file: str) -> None:
        """
        Save data to the output file.

        Args:
            data: Data to save
            output_file: Path to save the data
        """
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    def _convert_data(self, data: Any) -> List[Dict[str, Any]]:
        """
        Convert data to the target format.

        Args:
            data: Input data

        Returns:
            Converted data
        """
        raise NotImplementedError("Subclasses must implement this method")


class AlpacaAdapter(DataAdapter):
    """
    Adapter for Alpaca format data.

    Example:
        ```python
        adapter = AlpacaAdapter()
        converted_path = adapter.convert_data("custom_data.json")
        ```
    """

    def __init__(self):
        """Initialize the Alpaca adapter."""
        super().__init__(DatasetFormat.ALPACA)

    def _convert_data(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Convert data to Alpaca format.

        Expected output format:
        [
            {
                "instruction": "Task description",
                "input": "Optional input (context)",
                "output": "Expected output"
            },
            ...
        ]

        Args:
            data: Input data

        Returns:
            Converted data in Alpaca format
        """
        result = []

        for item in data:
            if isinstance(item, dict):
                # If already in the expected format, just add to result
                if all(k in item for k in ["instruction", "output"]):
                    alpaca_item = {
                        "instruction": item["instruction"],
                        "input": item.get("input", ""),
                        "output": item["output"]
                    }
                    result.append(alpaca_item)
                # Otherwise, try to convert from common formats
                elif "prompt" in item and "response" in item:
                    alpaca_item = {
                        "instruction": item["prompt"],
                        "input": "",
                        "output": item["response"]
                    }
                    result.append(alpaca_item)
                elif "question" in item and "answer" in item:
                    alpaca_item = {
                        "instruction": item["question"],
                        "input": "",
                        "output": item["answer"]
                    }
                    result.append(alpaca_item)
                else:
                    logger.warning(f"Could not convert item: {item}")

        return result


class ShareGPTAdapter(DataAdapter):
    """
    Adapter for ShareGPT format data.

    Example:
        ```python
        adapter = ShareGPTAdapter()
        converted_path = adapter.convert_data("sharegpt_data.json")
        ```
    """

    def __init__(self):
        """Initialize the ShareGPT adapter."""
        super().__init__(DatasetFormat.SHAREGPT)

    def _convert_data(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Convert data to ShareGPT format.

        Expected output format:
        [
            {
                "conversations": [
                    {"from": "human", "value": "Human message"},
                    {"from": "gpt", "value": "Assistant response"},
                    ...
                ]
            },
            ...
        ]

        Args:
            data: Input data

        Returns:
            Converted data in ShareGPT format
        """
        result = []

        for item in data:
            if isinstance(item, dict):
                conversations = []

                # Handle different input formats

                # If already in conversations format
                if "conversations" in item and isinstance(item["conversations"], list):
                    # Make sure format is correct
                    for conv in item["conversations"]:
                        if "from" in conv and "value" in conv:
                            # Normalize role names
                            role = conv["from"].lower()
                            if role in ["user", "human"]:
                                role = "human"
                            elif role in ["assistant", "gpt", "ai"]:
                                role = "gpt"
                            else:
                                logger.warning(f"Unknown role: {role}, skipping message")
                                continue

                            conversations.append({
                                "from": role,
                                "value": conv["value"]
                            })

                # If in QA format
                elif "question" in item and "answer" in item:
                    conversations = [
                        {"from": "human", "value": item["question"]},
                        {"from": "gpt", "value": item["answer"]}
                    ]

                # If in prompt/response format
                elif "prompt" in item and "response" in item:
                    conversations = [
                        {"from": "human", "value": item["prompt"]},
                        {"from": "gpt", "value": item["response"]}
                    ]

                if conversations:
                    result.append({"conversations": conversations})

        return result


class DataAdapterFactory:
    """
    Factory for creating data adapters.

    Example:
        ```python
        # Create an adapter based on format
        adapter = DataAdapterFactory.create_adapter(DatasetFormat.ALPACA)
        ```
    """

    @staticmethod
    def create_adapter(format_type: DatasetFormat) -> DataAdapter:
        """
        Create a data adapter for the specified format.

        Args:
            format_type: The dataset format

        Returns:
            A data adapter instance
        """
        if format_type == DatasetFormat.ALPACA:
            return AlpacaAdapter()
        elif format_type == DatasetFormat.SHAREGPT:
            return ShareGPTAdapter()
        else:
            raise ValueError(f"Unsupported format type: {format_type}")
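Finally, an end-to-end sketch of the adapter flow above: write a small prompt/response file, convert it with the factory, and read back the Alpaca-format result (paths and data are illustrative):

import json
import os
import tempfile

from isa_model.training.engine.llama_factory.config import DatasetFormat
from isa_model.training.engine.llama_factory.data_adapter import DataAdapterFactory

# A source file in the prompt/response shape the adapters know how to map.
records = [{"prompt": "Translate 'bonjour' to English.", "response": "Hello."}]
workdir = tempfile.mkdtemp()
src = os.path.join(workdir, "raw.json")
with open(src, "w", encoding="utf-8") as f:
    json.dump(records, f)

adapter = DataAdapterFactory.create_adapter(DatasetFormat.ALPACA)
out_path = adapter.convert_data(src)  # defaults to raw_converted.json next to the input

with open(out_path, encoding="utf-8") as f:
    converted = json.load(f)

# prompt/response pairs become instruction/input/output triples.
assert converted == [
    {"instruction": "Translate 'bonjour' to English.", "input": "", "output": "Hello."}
]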