isa-model 0.1.1__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- isa_model/__init__.py +1 -1
- isa_model/core/storage/hf_storage.py +419 -0
- isa_model/deployment/__init__.py +52 -0
- isa_model/deployment/core/__init__.py +34 -0
- isa_model/deployment/core/deployment_config.py +356 -0
- isa_model/deployment/core/deployment_manager.py +549 -0
- isa_model/deployment/core/isa_deployment_service.py +401 -0
- isa_model/eval/factory.py +381 -140
- isa_model/inference/ai_factory.py +142 -240
- isa_model/inference/providers/ml_provider.py +50 -0
- isa_model/inference/services/audio/openai_tts_service.py +104 -3
- isa_model/inference/services/embedding/base_embed_service.py +112 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
- isa_model/inference/services/llm/__init__.py +2 -0
- isa_model/inference/services/llm/base_llm_service.py +111 -1
- isa_model/inference/services/llm/ollama_llm_service.py +234 -26
- isa_model/inference/services/llm/openai_llm_service.py +225 -28
- isa_model/inference/services/llm/triton_llm_service.py +481 -0
- isa_model/inference/services/ml/base_ml_service.py +78 -0
- isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
- isa_model/inference/services/vision/__init__.py +3 -3
- isa_model/inference/services/vision/base_image_gen_service.py +161 -0
- isa_model/inference/services/vision/base_vision_service.py +177 -0
- isa_model/inference/services/vision/ollama_vision_service.py +143 -17
- isa_model/inference/services/vision/replicate_image_gen_service.py +139 -7
- isa_model/training/__init__.py +62 -32
- isa_model/training/cloud/__init__.py +22 -0
- isa_model/training/cloud/job_orchestrator.py +402 -0
- isa_model/training/cloud/runpod_trainer.py +454 -0
- isa_model/training/cloud/storage_manager.py +482 -0
- isa_model/training/core/__init__.py +23 -0
- isa_model/training/core/config.py +181 -0
- isa_model/training/core/dataset.py +222 -0
- isa_model/training/core/trainer.py +720 -0
- isa_model/training/core/utils.py +213 -0
- isa_model/training/factory.py +229 -198
- isa_model-0.2.8.dist-info/METADATA +465 -0
- isa_model-0.2.8.dist-info/RECORD +86 -0
- isa_model/core/model_router.py +0 -226
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +0 -202
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
- isa_model/training/engine/llama_factory/__init__.py +0 -39
- isa_model/training/engine/llama_factory/config.py +0 -115
- isa_model/training/engine/llama_factory/data_adapter.py +0 -284
- isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
- isa_model/training/engine/llama_factory/factory.py +0 -331
- isa_model/training/engine/llama_factory/rl.py +0 -254
- isa_model/training/engine/llama_factory/trainer.py +0 -171
- isa_model/training/image_model/configs/create_config.py +0 -37
- isa_model/training/image_model/configs/create_flux_config.py +0 -26
- isa_model/training/image_model/configs/create_lora_config.py +0 -21
- isa_model/training/image_model/prepare_massed_compute.py +0 -97
- isa_model/training/image_model/prepare_upload.py +0 -17
- isa_model/training/image_model/raw_data/create_captions.py +0 -16
- isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
- isa_model/training/image_model/raw_data/pre_processing.py +0 -200
- isa_model/training/image_model/train/train.py +0 -42
- isa_model/training/image_model/train/train_flux.py +0 -41
- isa_model/training/image_model/train/train_lora.py +0 -57
- isa_model/training/image_model/train_main.py +0 -25
- isa_model-0.1.1.dist-info/METADATA +0 -327
- isa_model-0.1.1.dist-info/RECORD +0 -92
- isa_model-0.1.1.dist-info/licenses/LICENSE +0 -21
- /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
- {isa_model-0.1.1.dist-info → isa_model-0.2.8.dist-info}/WHEEL +0 -0
- {isa_model-0.1.1.dist-info → isa_model-0.2.8.dist-info}/top_level.txt +0 -0
isa_model/training/engine/llama_factory/config.py
@@ -1,115 +0,0 @@
-"""
-Configuration classes for LlamaFactory training and RL.
-"""
-
-from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Union
-from enum import Enum
-import os
-from pathlib import Path
-
-
-class TrainingStrategy(str, Enum):
-    """Training strategies supported by LlamaFactory."""
-
-    SUPERVISED_FINETUNING = "sft"
-    REINFORCEMENT_LEARNING = "rl"
-    PREFERENCE_OPTIMIZATION = "dpo"
-    PREFERENCE_PAIRWISE = "ppo"
-
-
-class DatasetFormat(str, Enum):
-    """Dataset formats supported by LlamaFactory."""
-
-    ALPACA = "alpaca"
-    SHAREGPT = "sharegpt"
-    CUSTOM = "custom"
-
-
-@dataclass
-class LlamaFactoryConfig:
-    """Base configuration for LlamaFactory trainers."""
-
-    model_path: str
-    output_dir: str = field(default_factory=lambda: os.path.join(os.getcwd(), "outputs"))
-
-    # Training parameters
-    batch_size: int = 8
-    num_epochs: int = 3
-    learning_rate: float = 2e-5
-    weight_decay: float = 0.01
-    lr_scheduler_type: str = "cosine"
-    warmup_ratio: float = 0.1
-
-    # Model parameters
-    lora_rank: int = 8
-    lora_alpha: int = 16
-    lora_dropout: float = 0.05
-    use_lora: bool = True
-
-    # Data processing
-    max_length: int = 1024
-    dataset_format: DatasetFormat = DatasetFormat.ALPACA
-
-    # Logging
-    log_with: str = "mlflow"
-    logging_steps: int = 10
-
-    def to_dict(self) -> Dict:
-        """Convert config to dictionary for LlamaFactory CLI."""
-        return {k: v.value if isinstance(v, Enum) else v
-                for k, v in self.__dict__.items()}
-
-
-@dataclass
-class SFTConfig(LlamaFactoryConfig):
-    """Configuration for supervised fine-tuning."""
-
-    strategy: TrainingStrategy = TrainingStrategy.SUPERVISED_FINETUNING
-    train_file: str = ""
-    val_file: Optional[str] = None
-
-    # SFT specific settings
-    cutoff_len: int = 1024
-    normalize_data: bool = True
-
-
-@dataclass
-class RLConfig(LlamaFactoryConfig):
-    """Configuration for reinforcement learning."""
-
-    strategy: TrainingStrategy = TrainingStrategy.REINFORCEMENT_LEARNING
-    reward_model: str = ""
-    train_file: str = ""
-
-    # RL specific settings
-    kl_coef: float = 0.1
-    top_k: int = 0
-    top_p: float = 1.0
-    temperature: float = 1.0
-
-
-@dataclass
-class DPOConfig(LlamaFactoryConfig):
-    """Configuration for direct preference optimization."""
-
-    strategy: TrainingStrategy = TrainingStrategy.PREFERENCE_OPTIMIZATION
-    train_file: str = ""
-    val_file: Optional[str] = None
-
-    # DPO specific settings
-    beta: float = 0.1
-    reference_model: Optional[str] = None
-
-
-def create_default_config(strategy: TrainingStrategy, model_path: str) -> LlamaFactoryConfig:
-    """Create a default configuration based on the training strategy."""
-
-    if strategy == TrainingStrategy.SUPERVISED_FINETUNING:
-        return SFTConfig(model_path=model_path)
-    elif strategy == TrainingStrategy.REINFORCEMENT_LEARNING:
-        return RLConfig(model_path=model_path)
-    elif strategy == TrainingStrategy.PREFERENCE_OPTIMIZATION:
-        return DPOConfig(model_path=model_path)
-    else:
-        raise ValueError(f"Unsupported strategy: {strategy}")
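For context on what is being dropped, the removed 0.1.1 configuration API was typically driven as in the minimal sketch below. This is not part of the diff; it assumes isa-model 0.1.1 with the llama_factory engine still installed, and the model path is a placeholder.

```python
# Minimal sketch, not part of the diff: exercises the 0.1.1 config
# classes removed above. The model path below is a placeholder.
from isa_model.training.engine.llama_factory.config import (
    TrainingStrategy,
    create_default_config,
)

# Build a default SFT config, then tweak the LoRA and batch settings.
config = create_default_config(
    TrainingStrategy.SUPERVISED_FINETUNING,
    model_path="meta-llama/Llama-2-7b-hf",
)
config.lora_rank = 16
config.batch_size = 4

# to_dict() flattens enum values to plain strings for the LlamaFactory CLI.
print(config.to_dict())
```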
isa_model/training/engine/llama_factory/data_adapter.py
@@ -1,284 +0,0 @@
-"""
-Data adapters for LlamaFactory training.
-"""
-
-import os
-import json
-import logging
-from typing import Dict, List, Any, Optional, Union
-
-from .config import DatasetFormat
-
-logger = logging.getLogger(__name__)
-
-
-class DataAdapter:
-    """
-    Base class for adapting datasets to LlamaFactory format.
-
-    This class handles converting data from various formats into
-    the specific format expected by LlamaFactory.
-
-    Example:
-        ```python
-        # Create a custom adapter
-        adapter = CustomAdapter()
-
-        # Convert data
-        converted_path = adapter.convert_data(
-            input_file="path/to/source_data.json",
-            output_file="path/to/converted_data.json"
-        )
-        ```
-    """
-
-    def __init__(self, format_type: DatasetFormat = DatasetFormat.ALPACA):
-        """
-        Initialize the data adapter.
-
-        Args:
-            format_type: The target format to convert to
-        """
-        self.format_type = format_type
-
-    def convert_data(self, input_file: str, output_file: Optional[str] = None) -> str:
-        """
-        Convert data from input format to LlamaFactory format.
-
-        Args:
-            input_file: Path to the input data file
-            output_file: Path to save the converted data
-
-        Returns:
-            Path to the converted data file
-        """
-        if output_file is None:
-            dirname = os.path.dirname(input_file)
-            basename = os.path.basename(input_file)
-            name, ext = os.path.splitext(basename)
-            output_file = os.path.join(dirname, f"{name}_converted{ext}")
-
-        data = self._load_data(input_file)
-        converted_data = self._convert_data(data)
-        self._save_data(converted_data, output_file)
-
-        return output_file
-
-    def _load_data(self, input_file: str) -> Any:
-        """
-        Load data from the input file.
-
-        Args:
-            input_file: Path to the input file
-
-        Returns:
-            Loaded data
-        """
-        with open(input_file, 'r', encoding='utf-8') as f:
-            return json.load(f)
-
-    def _save_data(self, data: Any, output_file: str) -> None:
-        """
-        Save data to the output file.
-
-        Args:
-            data: Data to save
-            output_file: Path to save the data
-        """
-        with open(output_file, 'w', encoding='utf-8') as f:
-            json.dump(data, f, ensure_ascii=False, indent=2)
-
-    def _convert_data(self, data: Any) -> List[Dict[str, Any]]:
-        """
-        Convert data to the target format.
-
-        Args:
-            data: Input data
-
-        Returns:
-            Converted data
-        """
-        raise NotImplementedError("Subclasses must implement this method")
-
-
-class AlpacaAdapter(DataAdapter):
-    """
-    Adapter for Alpaca format data.
-
-    Example:
-        ```python
-        adapter = AlpacaAdapter()
-        converted_path = adapter.convert_data("custom_data.json")
-        ```
-    """
-
-    def __init__(self):
-        """Initialize the Alpaca adapter."""
-        super().__init__(DatasetFormat.ALPACA)
-
-    def _convert_data(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-        """
-        Convert data to Alpaca format.
-
-        Expected output format:
-        [
-            {
-                "instruction": "Task description",
-                "input": "Optional input (context)",
-                "output": "Expected output"
-            },
-            ...
-        ]
-
-        Args:
-            data: Input data
-
-        Returns:
-            Converted data in Alpaca format
-        """
-        result = []
-
-        for item in data:
-            if isinstance(item, dict):
-                # If already in the expected format, just add to result
-                if all(k in item for k in ["instruction", "output"]):
-                    alpaca_item = {
-                        "instruction": item["instruction"],
-                        "input": item.get("input", ""),
-                        "output": item["output"]
-                    }
-                    result.append(alpaca_item)
-                # Otherwise, try to convert from common formats
-                elif "prompt" in item and "response" in item:
-                    alpaca_item = {
-                        "instruction": item["prompt"],
-                        "input": "",
-                        "output": item["response"]
-                    }
-                    result.append(alpaca_item)
-                elif "question" in item and "answer" in item:
-                    alpaca_item = {
-                        "instruction": item["question"],
-                        "input": "",
-                        "output": item["answer"]
-                    }
-                    result.append(alpaca_item)
-                else:
-                    logger.warning(f"Could not convert item: {item}")
-
-        return result
-
-
-class ShareGPTAdapter(DataAdapter):
-    """
-    Adapter for ShareGPT format data.
-
-    Example:
-        ```python
-        adapter = ShareGPTAdapter()
-        converted_path = adapter.convert_data("sharegpt_data.json")
-        ```
-    """
-
-    def __init__(self):
-        """Initialize the ShareGPT adapter."""
-        super().__init__(DatasetFormat.SHAREGPT)
-
-    def _convert_data(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-        """
-        Convert data to ShareGPT format.
-
-        Expected output format:
-        [
-            {
-                "conversations": [
-                    {"from": "human", "value": "Human message"},
-                    {"from": "gpt", "value": "Assistant response"},
-                    ...
-                ]
-            },
-            ...
-        ]
-
-        Args:
-            data: Input data
-
-        Returns:
-            Converted data in ShareGPT format
-        """
-        result = []
-
-        for item in data:
-            if isinstance(item, dict):
-                conversations = []
-
-                # Handle different input formats
-
-                # If already in conversations format
-                if "conversations" in item and isinstance(item["conversations"], list):
-                    # Make sure format is correct
-                    for conv in item["conversations"]:
-                        if "from" in conv and "value" in conv:
-                            # Normalize role names
-                            role = conv["from"].lower()
-                            if role in ["user", "human"]:
-                                role = "human"
-                            elif role in ["assistant", "gpt", "ai"]:
-                                role = "gpt"
-                            else:
-                                logger.warning(f"Unknown role: {role}, skipping message")
-                                continue
-
-                            conversations.append({
-                                "from": role,
-                                "value": conv["value"]
-                            })
-
-                # If in QA format
-                elif "question" in item and "answer" in item:
-                    conversations = [
-                        {"from": "human", "value": item["question"]},
-                        {"from": "gpt", "value": item["answer"]}
-                    ]
-
-                # If in prompt/response format
-                elif "prompt" in item and "response" in item:
-                    conversations = [
-                        {"from": "human", "value": item["prompt"]},
-                        {"from": "gpt", "value": item["response"]}
-                    ]
-
-                if conversations:
-                    result.append({"conversations": conversations})
-
-        return result
-
-
-class DataAdapterFactory:
-    """
-    Factory for creating data adapters.
-
-    Example:
-        ```python
-        # Create an adapter based on format
-        adapter = DataAdapterFactory.create_adapter(DatasetFormat.ALPACA)
-        ```
-    """
-
-    @staticmethod
-    def create_adapter(format_type: DatasetFormat) -> DataAdapter:
-        """
-        Create a data adapter for the specified format.
-
-        Args:
-            format_type: The dataset format
-
-        Returns:
-            A data adapter instance
-        """
-        if format_type == DatasetFormat.ALPACA:
-            return AlpacaAdapter()
-        elif format_type == DatasetFormat.SHAREGPT:
-            return ShareGPTAdapter()
-        else:
-            raise ValueError(f"Unsupported format type: {format_type}")
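The removed adapters normalized heterogeneous prompt/response and question/answer records into Alpaca or ShareGPT form. A minimal sketch of how the 0.1.1 API was used follows; it is not part of the diff, and the file names are placeholders.

```python
# Minimal sketch, not part of the diff: uses the 0.1.1 adapter classes
# removed above. File names are placeholders.
import json

from isa_model.training.engine.llama_factory.config import DatasetFormat
from isa_model.training.engine.llama_factory.data_adapter import DataAdapterFactory

# Write a tiny prompt/response dataset to disk.
with open("raw_data.json", "w", encoding="utf-8") as f:
    json.dump([{"prompt": "What is 2 + 2?", "response": "4"}], f, ensure_ascii=False)

# The factory picks the adapter; AlpacaAdapter maps prompt/response
# pairs onto instruction/input/output records.
adapter = DataAdapterFactory.create_adapter(DatasetFormat.ALPACA)
converted_path = adapter.convert_data("raw_data.json")
print(converted_path)  # raw_data_converted.json, in Alpaca format
```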
isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py
@@ -1,185 +0,0 @@
-"""
-Example of fine-tuning with LlamaFactory and MLflow tracking.
-
-This example demonstrates how to use LlamaFactory for fine-tuning
-and the MLflow tracking system to monitor and log the process.
-"""
-
-import os
-import argparse
-import logging
-from typing import Dict, Any
-
-from app.services.ai.models.training.engine.llama_factory import (
-    LlamaFactory,
-    DatasetFormat,
-    TrainingStrategy
-)
-from app.services.ai.models.mlops import (
-    TrainingTracker,
-    ModelStage
-)
-
-
-# Configure logging
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-
-def parse_args():
-    """Parse command line arguments."""
-    parser = argparse.ArgumentParser(description="Fine-tune with LlamaFactory and MLflow tracking")
-
-    parser.add_argument(
-        "--model_path",
-        type=str,
-        required=True,
-        help="Path or name of the base model"
-    )
-    parser.add_argument(
-        "--train_data",
-        type=str,
-        required=True,
-        help="Path to the training data"
-    )
-    parser.add_argument(
-        "--val_data",
-        type=str,
-        help="Path to the validation data"
-    )
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        default="outputs",
-        help="Directory to save outputs"
-    )
-    parser.add_argument(
-        "--dataset_format",
-        type=str,
-        choices=["alpaca", "sharegpt", "custom"],
-        default="alpaca",
-        help="Format of the dataset"
-    )
-    parser.add_argument(
-        "--use_lora",
-        action="store_true",
-        help="Whether to use LoRA for training"
-    )
-    parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=8,
-        help="Training batch size"
-    )
-    parser.add_argument(
-        "--num_epochs",
-        type=int,
-        default=3,
-        help="Number of training epochs"
-    )
-    parser.add_argument(
-        "--learning_rate",
-        type=float,
-        default=2e-5,
-        help="Learning rate"
-    )
-    parser.add_argument(
-        "--max_length",
-        type=int,
-        default=1024,
-        help="Maximum sequence length"
-    )
-    parser.add_argument(
-        "--lora_rank",
-        type=int,
-        default=8,
-        help="LoRA rank parameter"
-    )
-    parser.add_argument(
-        "--tracking_uri",
-        type=str,
-        help="URI for MLflow tracking server"
-    )
-    parser.add_argument(
-        "--register_model",
-        action="store_true",
-        help="Whether to register the model in the registry"
-    )
-
-    return parser.parse_args()
-
-
-def main():
-    """Run the fine-tuning process with tracking."""
-    args = parse_args()
-
-    # Map dataset format string to enum
-    dataset_format_map = {
-        "alpaca": DatasetFormat.ALPACA,
-        "sharegpt": DatasetFormat.SHAREGPT,
-        "custom": DatasetFormat.CUSTOM
-    }
-    dataset_format = dataset_format_map[args.dataset_format]
-
-    # Create output directory
-    os.makedirs(args.output_dir, exist_ok=True)
-
-    # Initialize LlamaFactory
-    factory = LlamaFactory(base_output_dir=args.output_dir)
-
-    # Initialize training tracker
-    tracker = TrainingTracker(tracking_uri=args.tracking_uri)
-
-    # Get model name from path
-    model_name = os.path.basename(args.model_path)
-
-    # Set up training parameters
-    train_params = {
-        "model_path": args.model_path,
-        "train_data": args.train_data,
-        "val_data": args.val_data,
-        "output_dir": None, # Will be set by factory
-        "dataset_format": dataset_format,
-        "use_lora": args.use_lora,
-        "batch_size": args.batch_size,
-        "num_epochs": args.num_epochs,
-        "learning_rate": args.learning_rate,
-        "max_length": args.max_length,
-        "lora_rank": args.lora_rank
-    }
-
-    # Track the training run with MLflow
-    with tracker.track_training_run(
-        model_name=model_name,
-        training_params=train_params,
-        description=f"Fine-tuning {model_name} with LlamaFactory"
-    ) as run_info:
-        # Run the fine-tuning
-        try:
-            model_path = factory.finetune(**train_params)
-
-            # Log success
-            tracker.log_metrics({"success": 1.0})
-            logger.info(f"Fine-tuning completed successfully. Model saved to {model_path}")
-
-            # Register the model if requested
-            if args.register_model:
-                version = tracker.register_trained_model(
-                    model_path=model_path,
-                    description=f"Fine-tuned {model_name}",
-                    stage=ModelStage.STAGING
-                )
-                logger.info(f"Model registered as version {version}")
-
-        except Exception as e:
-            # Log failure
-            tracker.log_metrics({"success": 0.0})
-            logger.error(f"Fine-tuning failed: {e}")
-            raise
-
-
-if __name__ == "__main__":
-    main()
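The removed example script was a CLI entry point; invoking it in 0.1.1 looked roughly like the sketch below. This is not part of the diff, and the paths and model name are placeholders.

```python
# Minimal sketch, not part of the diff: drives the removed example script
# through its argparse CLI. Paths and model name are placeholders.
import subprocess

subprocess.run(
    [
        "python", "finetune_with_tracking.py",
        "--model_path", "meta-llama/Llama-2-7b-hf",
        "--train_data", "train.json",
        "--dataset_format", "alpaca",
        "--use_lora",
        "--batch_size", "4",
        "--num_epochs", "3",
    ],
    check=True,
)
```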