isa-model 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +5 -0
- isa_model/core/model_manager.py +143 -0
- isa_model/core/model_registry.py +115 -0
- isa_model/core/model_router.py +226 -0
- isa_model/core/model_storage.py +133 -0
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +202 -0
- isa_model/core/storage/hf_storage.py +0 -0
- isa_model/core/storage/local_storage.py +0 -0
- isa_model/core/storage/minio_storage.py +0 -0
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
- isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
- isa_model/inference/__init__.py +11 -0
- isa_model/inference/adapter/unified_api.py +248 -0
- isa_model/inference/ai_factory.py +359 -0
- isa_model/inference/base.py +46 -0
- isa_model/inference/providers/__init__.py +19 -0
- isa_model/inference/providers/base_provider.py +30 -0
- isa_model/inference/providers/model_cache_manager.py +341 -0
- isa_model/inference/providers/ollama_provider.py +73 -0
- isa_model/inference/providers/openai_provider.py +101 -0
- isa_model/inference/providers/replicate_provider.py +107 -0
- isa_model/inference/providers/triton_provider.py +439 -0
- isa_model/inference/services/__init__.py +14 -0
- isa_model/inference/services/audio/base_stt_service.py +91 -0
- isa_model/inference/services/audio/base_tts_service.py +136 -0
- isa_model/inference/services/audio/openai_tts_service.py +71 -0
- isa_model/inference/services/base_service.py +106 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +97 -0
- isa_model/inference/services/embedding/openai_embed_service.py +0 -0
- isa_model/inference/services/llm/__init__.py +12 -0
- isa_model/inference/services/llm/base_llm_service.py +134 -0
- isa_model/inference/services/llm/ollama_llm_service.py +99 -0
- isa_model/inference/services/llm/openai_llm_service.py +138 -0
- isa_model/inference/services/others/table_transformer_service.py +61 -0
- isa_model/inference/services/vision/__init__.py +12 -0
- isa_model/inference/services/vision/helpers/image_utils.py +58 -0
- isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
- isa_model/inference/services/vision/ollama_vision_service.py +60 -0
- isa_model/inference/services/vision/openai_vision_service.py +80 -0
- isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
- isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
- isa_model/inference/utils/conversion/onnx_converter.py +0 -0
- isa_model/inference/utils/conversion/torch_converter.py +0 -0
- isa_model/scripts/inference_tracker.py +283 -0
- isa_model/scripts/mlflow_manager.py +379 -0
- isa_model/scripts/model_registry.py +465 -0
- isa_model/scripts/start_mlflow.py +95 -0
- isa_model/scripts/training_tracker.py +257 -0
- isa_model/training/engine/llama_factory/__init__.py +39 -0
- isa_model/training/engine/llama_factory/config.py +115 -0
- isa_model/training/engine/llama_factory/data_adapter.py +284 -0
- isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
- isa_model/training/engine/llama_factory/factory.py +331 -0
- isa_model/training/engine/llama_factory/rl.py +254 -0
- isa_model/training/engine/llama_factory/trainer.py +171 -0
- isa_model/training/image_model/configs/create_config.py +37 -0
- isa_model/training/image_model/configs/create_flux_config.py +26 -0
- isa_model/training/image_model/configs/create_lora_config.py +21 -0
- isa_model/training/image_model/prepare_massed_compute.py +97 -0
- isa_model/training/image_model/prepare_upload.py +17 -0
- isa_model/training/image_model/raw_data/create_captions.py +16 -0
- isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
- isa_model/training/image_model/raw_data/pre_processing.py +200 -0
- isa_model/training/image_model/train/train.py +42 -0
- isa_model/training/image_model/train/train_flux.py +41 -0
- isa_model/training/image_model/train/train_lora.py +57 -0
- isa_model/training/image_model/train_main.py +25 -0
- isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
- isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
- isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
- isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
- isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
- isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
- isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
- isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
- isa_model-0.0.1.dist-info/METADATA +327 -0
- isa_model-0.0.1.dist-info/RECORD +86 -0
- isa_model-0.0.1.dist-info/WHEEL +5 -0
- isa_model-0.0.1.dist-info/licenses/LICENSE +21 -0
- isa_model-0.0.1.dist-info/top_level.txt +1 -0
isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py
@@ -0,0 +1,185 @@
"""
Example of fine-tuning with LlamaFactory and MLflow tracking.

This example demonstrates how to use LlamaFactory for fine-tuning
and the MLflow tracking system to monitor and log the process.
"""

import os
import argparse
import logging
from typing import Dict, Any

from app.services.ai.models.training.engine.llama_factory import (
    LlamaFactory,
    DatasetFormat,
    TrainingStrategy
)
from app.services.ai.models.mlops import (
    TrainingTracker,
    ModelStage
)


# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Fine-tune with LlamaFactory and MLflow tracking")

    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path or name of the base model"
    )
    parser.add_argument(
        "--train_data",
        type=str,
        required=True,
        help="Path to the training data"
    )
    parser.add_argument(
        "--val_data",
        type=str,
        help="Path to the validation data"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs",
        help="Directory to save outputs"
    )
    parser.add_argument(
        "--dataset_format",
        type=str,
        choices=["alpaca", "sharegpt", "custom"],
        default="alpaca",
        help="Format of the dataset"
    )
    parser.add_argument(
        "--use_lora",
        action="store_true",
        help="Whether to use LoRA for training"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Training batch size"
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=3,
        help="Number of training epochs"
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-5,
        help="Learning rate"
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=1024,
        help="Maximum sequence length"
    )
    parser.add_argument(
        "--lora_rank",
        type=int,
        default=8,
        help="LoRA rank parameter"
    )
    parser.add_argument(
        "--tracking_uri",
        type=str,
        help="URI for MLflow tracking server"
    )
    parser.add_argument(
        "--register_model",
        action="store_true",
        help="Whether to register the model in the registry"
    )

    return parser.parse_args()


def main():
    """Run the fine-tuning process with tracking."""
    args = parse_args()

    # Map dataset format string to enum
    dataset_format_map = {
        "alpaca": DatasetFormat.ALPACA,
        "sharegpt": DatasetFormat.SHAREGPT,
        "custom": DatasetFormat.CUSTOM
    }
    dataset_format = dataset_format_map[args.dataset_format]

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Initialize LlamaFactory
    factory = LlamaFactory(base_output_dir=args.output_dir)

    # Initialize training tracker
    tracker = TrainingTracker(tracking_uri=args.tracking_uri)

    # Get model name from path
    model_name = os.path.basename(args.model_path)

    # Set up training parameters
    train_params = {
        "model_path": args.model_path,
        "train_data": args.train_data,
        "val_data": args.val_data,
        "output_dir": None,  # Will be set by factory
        "dataset_format": dataset_format,
        "use_lora": args.use_lora,
        "batch_size": args.batch_size,
        "num_epochs": args.num_epochs,
        "learning_rate": args.learning_rate,
        "max_length": args.max_length,
        "lora_rank": args.lora_rank
    }

    # Track the training run with MLflow
    with tracker.track_training_run(
        model_name=model_name,
        training_params=train_params,
        description=f"Fine-tuning {model_name} with LlamaFactory"
    ) as run_info:
        # Run the fine-tuning
        try:
            model_path = factory.finetune(**train_params)

            # Log success
            tracker.log_metrics({"success": 1.0})
            logger.info(f"Fine-tuning completed successfully. Model saved to {model_path}")

            # Register the model if requested
            if args.register_model:
                version = tracker.register_trained_model(
                    model_path=model_path,
                    description=f"Fine-tuned {model_name}",
                    stage=ModelStage.STAGING
                )
                logger.info(f"Model registered as version {version}")

        except Exception as e:
            # Log failure
            tracker.log_metrics({"success": 0.0})
            logger.error(f"Fine-tuning failed: {e}")
            raise


if __name__ == "__main__":
    main()
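For orientation, `--dataset_format alpaca` implies instruction/input/output records in the usual Alpaca convention. The exact schema this package's `DataAdapterFactory` expects is not visible in this diff, so the following is only a hedged sketch of what a `--train_data` file might contain:

```python
# Hypothetical Alpaca-style training file; field names follow the common
# Alpaca convention and are an assumption, not taken from this package.
import json

records = [
    {
        "instruction": "Summarize the following paragraph.",
        "input": "LlamaFactory wraps supervised fine-tuning behind a single call.",
        "output": "It offers a one-call interface for supervised fine-tuning."
    }
]

with open("train.json", "w", encoding="utf-8") as f:
    json.dump(records, f, indent=2, ensure_ascii=False)
```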
isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py
@@ -0,0 +1,163 @@
"""
Example of RLHF (Reinforcement Learning from Human Feedback) with LlamaFactory and MLflow tracking.

This example demonstrates how to use LlamaFactory for RLHF
and the MLflow tracking system to monitor and log the process.
"""

import os
import argparse
import logging
from typing import Dict, Any

from app.services.ai.models.training.engine.llama_factory import (
    LlamaFactory,
    TrainingStrategy
)
from app.services.ai.models.mlops import (
    TrainingTracker,
    ModelStage
)


# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Run RLHF with LlamaFactory and MLflow tracking")

    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path or name of the base model"
    )
    parser.add_argument(
        "--reward_model",
        type=str,
        required=True,
        help="Path to the reward model"
    )
    parser.add_argument(
        "--train_data",
        type=str,
        required=True,
        help="Path to the training data"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs",
        help="Directory to save outputs"
    )
    parser.add_argument(
        "--use_lora",
        action="store_true",
        help="Whether to use LoRA for training"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="Training batch size"
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=1,
        help="Number of training epochs"
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-5,
        help="Learning rate"
    )
    parser.add_argument(
        "--kl_coef",
        type=float,
        default=0.1,
        help="KL coefficient for PPO"
    )
    parser.add_argument(
        "--tracking_uri",
        type=str,
        help="URI for MLflow tracking server"
    )
    parser.add_argument(
        "--register_model",
        action="store_true",
        help="Whether to register the model in the registry"
    )

    return parser.parse_args()


def main():
    """Run the RLHF process with tracking."""
    args = parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Initialize LlamaFactory
    factory = LlamaFactory(base_output_dir=args.output_dir)

    # Initialize training tracker
    tracker = TrainingTracker(tracking_uri=args.tracking_uri)

    # Get model name from path
    model_name = os.path.basename(args.model_path)

    # Set up RLHF parameters
    rl_params = {
        "model_path": args.model_path,
        "reward_model": args.reward_model,
        "train_data": args.train_data,
        "output_dir": None,  # Will be set by factory
        "use_lora": args.use_lora,
        "batch_size": args.batch_size,
        "num_epochs": args.num_epochs,
        "learning_rate": args.learning_rate,
        "kl_coef": args.kl_coef
    }

    # Track the RLHF run with MLflow
    with tracker.track_training_run(
        model_name=model_name,
        training_params=rl_params,
        description=f"RLHF for {model_name} with LlamaFactory",
        experiment_type="rl"
    ) as run_info:
        # Run the RLHF
        try:
            model_path = factory.rlhf(**rl_params)

            # Log success
            tracker.log_metrics({"success": 1.0})
            logger.info(f"RLHF completed successfully. Model saved to {model_path}")

            # Register the model if requested
            if args.register_model:
                version = tracker.register_trained_model(
                    model_path=model_path,
                    description=f"RLHF-tuned {model_name}",
                    stage=ModelStage.STAGING
                )
                logger.info(f"Model registered as version {version}")

        except Exception as e:
            # Log failure
            tracker.log_metrics({"success": 0.0})
            logger.error(f"RLHF failed: {e}")
            raise


if __name__ == "__main__":
    main()
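The same flow can also be driven without the CLI wrapper. A minimal programmatic sketch, assuming the `app.services...` import paths above resolve in your environment; the model name and data paths are placeholders:

```python
from app.services.ai.models.training.engine.llama_factory import LlamaFactory
from app.services.ai.models.mlops import TrainingTracker

factory = LlamaFactory(base_output_dir="outputs")
tracker = TrainingTracker(tracking_uri=None)  # assumption: None falls back to a default MLflow store

rl_params = {
    "model_path": "meta-llama/Llama-2-7b-hf",  # placeholder base model
    "reward_model": "outputs/reward_model",    # placeholder reward-model path
    "train_data": "prompts.json",              # placeholder prompt dataset
    "kl_coef": 0.1,
}

with tracker.track_training_run(
    model_name="Llama-2-7b-hf",
    training_params=rl_params,
    description="RLHF smoke test",
    experiment_type="rl",
):
    tuned_model_dir = factory.rlhf(**rl_params)
```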
isa_model/training/engine/llama_factory/factory.py
@@ -0,0 +1,331 @@
"""
LlamaFactory interface module.

This module provides the main interface for using LlamaFactory functionality.
"""

import os
import logging
from typing import Optional, Dict, Any, Union, List

from .config import (
    LlamaFactoryConfig,
    SFTConfig,
    RLConfig,
    DPOConfig,
    TrainingStrategy,
    DatasetFormat,
    create_default_config
)
from .trainer import LlamaFactoryTrainer
from .rl import LlamaFactoryRL
from .data_adapter import DataAdapterFactory


logger = logging.getLogger(__name__)


class LlamaFactory:
    """
    Main interface class for LlamaFactory operations.

    This class provides a simplified interface to the LlamaFactory functionality
    for fine-tuning and reinforcement learning.

    Example for fine-tuning:
    ```python
    # Create a LlamaFactory instance
    factory = LlamaFactory()

    # Fine-tune a model
    model_path = factory.finetune(
        model_path="meta-llama/Llama-2-7b-hf",
        train_data="path/to/data.json",
        val_data="path/to/val_data.json",  # Optional
        output_dir="path/to/output",
        dataset_format=DatasetFormat.ALPACA,
        use_lora=True,
        num_epochs=3
    )
    ```

    Example for RL training:
    ```python
    # Create a LlamaFactory instance
    factory = LlamaFactory()

    # Train with DPO
    model_path = factory.dpo(
        model_path="meta-llama/Llama-2-7b-hf",
        train_data="path/to/preferences.json",
        output_dir="path/to/output",
        reference_model="meta-llama/Llama-2-7b-hf",  # Optional
        beta=0.1
    )
    ```
    """

    def __init__(self, base_output_dir: Optional[str] = None):
        """
        Initialize the LlamaFactory interface.

        Args:
            base_output_dir: Base directory for outputs
        """
        self.base_output_dir = base_output_dir or os.path.join(os.getcwd(), "training_outputs")
        os.makedirs(self.base_output_dir, exist_ok=True)

    def _get_output_dir(self, name: str, output_dir: Optional[str] = None) -> str:
        """
        Get the output directory for training.

        Args:
            name: Name for the output directory
            output_dir: Optional specific output directory

        Returns:
            Output directory path
        """
        if output_dir:
            return output_dir

        import datetime
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        return os.path.join(self.base_output_dir, f"{name}_{timestamp}")

    def finetune(
        self,
        model_path: str,
        train_data: str,
        val_data: Optional[str] = None,
        output_dir: Optional[str] = None,
        dataset_format: DatasetFormat = DatasetFormat.ALPACA,
        use_lora: bool = True,
        batch_size: int = 8,
        num_epochs: int = 3,
        learning_rate: float = 2e-5,
        lora_rank: int = 8,
        lora_alpha: int = 16,
        max_length: int = 1024,
        **kwargs
    ) -> str:
        """
        Fine-tune a model using supervised learning.

        Args:
            model_path: Path or name of the base model
            train_data: Path to the training data
            val_data: Path to the validation data (optional)
            output_dir: Directory to save outputs
            dataset_format: Format of the dataset
            use_lora: Whether to use LoRA for training
            batch_size: Training batch size
            num_epochs: Number of training epochs
            learning_rate: Learning rate
            lora_rank: LoRA rank parameter
            lora_alpha: LoRA alpha parameter
            max_length: Maximum sequence length
            **kwargs: Additional parameters for SFTConfig

        Returns:
            Path to the trained model
        """
        # Check if data conversion is needed and convert
        adapter = DataAdapterFactory.create_adapter(dataset_format)
        converted_train_data = adapter.convert_data(train_data)
        converted_val_data = adapter.convert_data(val_data) if val_data else None

        # Create configuration
        output_dir = self._get_output_dir("sft", output_dir)
        config = SFTConfig(
            model_path=model_path,
            train_file=converted_train_data,
            val_file=converted_val_data,
            output_dir=output_dir,
            use_lora=use_lora,
            batch_size=batch_size,
            num_epochs=num_epochs,
            learning_rate=learning_rate,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            max_length=max_length,
            dataset_format=dataset_format,
            **kwargs
        )

        # Initialize and run trainer
        trainer = LlamaFactoryTrainer(config)
        model_dir = trainer.train()

        # Export model if using LoRA
        if use_lora:
            model_dir = trainer.export_model()

        return model_dir

    def rlhf(
        self,
        model_path: str,
        reward_model: str,
        train_data: str,
        output_dir: Optional[str] = None,
        use_lora: bool = True,
        batch_size: int = 4,
        num_epochs: int = 1,
        learning_rate: float = 1e-5,
        kl_coef: float = 0.1,
        **kwargs
    ) -> str:
        """
        Train a model with RLHF (Reinforcement Learning from Human Feedback).

        Args:
            model_path: Path or name of the base model
            reward_model: Path to the reward model
            train_data: Path to the training data
            output_dir: Directory to save outputs
            use_lora: Whether to use LoRA for training
            batch_size: Training batch size
            num_epochs: Number of training epochs
            learning_rate: Learning rate
            kl_coef: KL coefficient for PPO
            **kwargs: Additional parameters for RLConfig

        Returns:
            Path to the trained model
        """
        # Create configuration
        output_dir = self._get_output_dir("rlhf", output_dir)
        config = RLConfig(
            model_path=model_path,
            reward_model=reward_model,
            train_file=train_data,
            output_dir=output_dir,
            use_lora=use_lora,
            batch_size=batch_size,
            num_epochs=num_epochs,
            learning_rate=learning_rate,
            kl_coef=kl_coef,
            **kwargs
        )

        # Initialize and run RL trainer
        rl_trainer = LlamaFactoryRL(config)
        model_dir = rl_trainer.train()

        # Export model if using LoRA
        if use_lora:
            model_dir = rl_trainer.export_model()

        return model_dir

    def dpo(
        self,
        model_path: str,
        train_data: str,
        val_data: Optional[str] = None,
        reference_model: Optional[str] = None,
        output_dir: Optional[str] = None,
        use_lora: bool = True,
        batch_size: int = 4,
        num_epochs: int = 3,
        learning_rate: float = 5e-6,
        beta: float = 0.1,
        **kwargs
    ) -> str:
        """
        Train a model with DPO (Direct Preference Optimization).

        Args:
            model_path: Path or name of the base model
            train_data: Path to the training data
            val_data: Path to the validation data (optional)
            reference_model: Path to the reference model (optional)
            output_dir: Directory to save outputs
            use_lora: Whether to use LoRA for training
            batch_size: Training batch size
            num_epochs: Number of training epochs
            learning_rate: Learning rate
            beta: DPO beta parameter
            **kwargs: Additional parameters for DPOConfig

        Returns:
            Path to the trained model
        """
        # Create configuration
        output_dir = self._get_output_dir("dpo", output_dir)
        config = DPOConfig(
            model_path=model_path,
            train_file=train_data,
            val_file=val_data,
            reference_model=reference_model,
            output_dir=output_dir,
            use_lora=use_lora,
            batch_size=batch_size,
            num_epochs=num_epochs,
            learning_rate=learning_rate,
            beta=beta,
            **kwargs
        )

        # Initialize and run DPO trainer
        dpo_trainer = LlamaFactoryRL(config)
        model_dir = dpo_trainer.train()

        # Export model if using LoRA
        if use_lora:
            model_dir = dpo_trainer.export_model()

        return model_dir

    def train_reward_model(
        self,
        model_path: str,
        train_data: str,
        val_data: Optional[str] = None,
        output_dir: Optional[str] = None,
        use_lora: bool = True,
        batch_size: int = 8,
        num_epochs: int = 3,
        learning_rate: float = 1e-5,
        **kwargs
    ) -> str:
        """
        Train a reward model for RLHF.

        Args:
            model_path: Path or name of the base model
            train_data: Path to the training data with preferences
            val_data: Path to the validation data (optional)
            output_dir: Directory to save outputs
            use_lora: Whether to use LoRA for training
            batch_size: Training batch size
            num_epochs: Number of training epochs
            learning_rate: Learning rate
            **kwargs: Additional parameters for RLConfig

        Returns:
            Path to the trained reward model
        """
        # Create temporary RL config
        output_dir = self._get_output_dir("reward_model", output_dir)
        config = RLConfig(
            model_path=model_path,
            train_file=train_data,
            output_dir=output_dir,
            use_lora=use_lora,
            batch_size=batch_size,
            num_epochs=num_epochs,
            learning_rate=learning_rate,
            **kwargs
        )

        # Initialize RL trainer and train reward model
        rl_trainer = LlamaFactoryRL(config)
        model_dir = rl_trainer.train_reward_model()

        # Export model if using LoRA
        if use_lora:
            model_dir = rl_trainer.export_model()

        return model_dir
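The class docstring covers `finetune` and `dpo`; the reward-model-then-RLHF path follows the same pattern. A minimal sketch based on the signatures above, with placeholder paths and model names, and with the preference-data schema assumed rather than documented in this diff:

```python
factory = LlamaFactory(base_output_dir="training_outputs")

# 1. Train a reward model on preference data (placeholder path; schema assumed)
reward_dir = factory.train_reward_model(
    model_path="meta-llama/Llama-2-7b-hf",
    train_data="path/to/preferences.json",
    use_lora=True,
)

# 2. Run PPO-style RLHF against that reward model (placeholder prompt data)
tuned_dir = factory.rlhf(
    model_path="meta-llama/Llama-2-7b-hf",
    reward_model=reward_dir,
    train_data="path/to/prompts.json",
    kl_coef=0.1,
)
```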