isa-model 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117) hide show
  1. isa_model/__init__.py +5 -0
  2. isa_model/core/model_manager.py +143 -0
  3. isa_model/core/model_registry.py +115 -0
  4. isa_model/core/model_router.py +226 -0
  5. isa_model/core/model_storage.py +133 -0
  6. isa_model/core/model_version.py +0 -0
  7. isa_model/core/resource_manager.py +202 -0
  8. isa_model/core/storage/hf_storage.py +0 -0
  9. isa_model/core/storage/local_storage.py +0 -0
  10. isa_model/core/storage/minio_storage.py +0 -0
  11. isa_model/deployment/mlflow_gateway/__init__.py +8 -0
  12. isa_model/deployment/mlflow_gateway/start_gateway.py +65 -0
  13. isa_model/deployment/unified_multimodal_client.py +341 -0
  14. isa_model/inference/__init__.py +11 -0
  15. isa_model/inference/adapter/triton_adapter.py +453 -0
  16. isa_model/inference/adapter/unified_api.py +248 -0
  17. isa_model/inference/ai_factory.py +354 -0
  18. isa_model/inference/backends/Pytorch/bge_embed_backend.py +188 -0
  19. isa_model/inference/backends/Pytorch/gemma_backend.py +167 -0
  20. isa_model/inference/backends/Pytorch/llama_backend.py +166 -0
  21. isa_model/inference/backends/Pytorch/whisper_backend.py +194 -0
  22. isa_model/inference/backends/__init__.py +53 -0
  23. isa_model/inference/backends/base_backend_client.py +26 -0
  24. isa_model/inference/backends/container_services.py +104 -0
  25. isa_model/inference/backends/local_services.py +72 -0
  26. isa_model/inference/backends/openai_client.py +130 -0
  27. isa_model/inference/backends/replicate_client.py +197 -0
  28. isa_model/inference/backends/third_party_services.py +239 -0
  29. isa_model/inference/backends/triton_client.py +97 -0
  30. isa_model/inference/base.py +46 -0
  31. isa_model/inference/client_sdk/__init__.py +0 -0
  32. isa_model/inference/client_sdk/client.py +134 -0
  33. isa_model/inference/client_sdk/client_data_std.py +34 -0
  34. isa_model/inference/client_sdk/client_sdk_schema.py +16 -0
  35. isa_model/inference/client_sdk/exceptions.py +0 -0
  36. isa_model/inference/engine/triton/model_repository/bge/1/model.py +174 -0
  37. isa_model/inference/engine/triton/model_repository/gemma/1/model.py +250 -0
  38. isa_model/inference/engine/triton/model_repository/llama/1/model.py +76 -0
  39. isa_model/inference/engine/triton/model_repository/whisper/1/model.py +195 -0
  40. isa_model/inference/providers/__init__.py +19 -0
  41. isa_model/inference/providers/base_provider.py +30 -0
  42. isa_model/inference/providers/model_cache_manager.py +341 -0
  43. isa_model/inference/providers/ollama_provider.py +73 -0
  44. isa_model/inference/providers/openai_provider.py +87 -0
  45. isa_model/inference/providers/replicate_provider.py +94 -0
  46. isa_model/inference/providers/triton_provider.py +439 -0
  47. isa_model/inference/providers/vllm_provider.py +0 -0
  48. isa_model/inference/providers/yyds_provider.py +83 -0
  49. isa_model/inference/services/__init__.py +14 -0
  50. isa_model/inference/services/audio/fish_speech/handler.py +215 -0
  51. isa_model/inference/services/audio/runpod_tts_fish_service.py +212 -0
  52. isa_model/inference/services/audio/triton_speech_service.py +138 -0
  53. isa_model/inference/services/audio/whisper_service.py +186 -0
  54. isa_model/inference/services/audio/yyds_audio_service.py +71 -0
  55. isa_model/inference/services/base_service.py +106 -0
  56. isa_model/inference/services/base_tts_service.py +66 -0
  57. isa_model/inference/services/embedding/bge_service.py +183 -0
  58. isa_model/inference/services/embedding/ollama_embed_service.py +85 -0
  59. isa_model/inference/services/embedding/ollama_rerank_service.py +118 -0
  60. isa_model/inference/services/embedding/onnx_rerank_service.py +73 -0
  61. isa_model/inference/services/llm/__init__.py +16 -0
  62. isa_model/inference/services/llm/gemma_service.py +143 -0
  63. isa_model/inference/services/llm/llama_service.py +143 -0
  64. isa_model/inference/services/llm/ollama_llm_service.py +108 -0
  65. isa_model/inference/services/llm/openai_llm_service.py +129 -0
  66. isa_model/inference/services/llm/replicate_llm_service.py +179 -0
  67. isa_model/inference/services/llm/triton_llm_service.py +230 -0
  68. isa_model/inference/services/others/table_transformer_service.py +61 -0
  69. isa_model/inference/services/vision/__init__.py +12 -0
  70. isa_model/inference/services/vision/helpers/image_utils.py +58 -0
  71. isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
  72. isa_model/inference/services/vision/ollama_vision_service.py +60 -0
  73. isa_model/inference/services/vision/replicate_vision_service.py +241 -0
  74. isa_model/inference/services/vision/triton_vision_service.py +199 -0
  75. isa_model/inference/services/vision/yyds_vision_service.py +80 -0
  76. isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
  77. isa_model/inference/utils/conversion/onnx_converter.py +0 -0
  78. isa_model/inference/utils/conversion/torch_converter.py +0 -0
  79. isa_model/scripts/inference_tracker.py +283 -0
  80. isa_model/scripts/mlflow_manager.py +379 -0
  81. isa_model/scripts/model_registry.py +465 -0
  82. isa_model/scripts/start_mlflow.py +95 -0
  83. isa_model/scripts/training_tracker.py +257 -0
  84. isa_model/training/engine/llama_factory/__init__.py +39 -0
  85. isa_model/training/engine/llama_factory/config.py +115 -0
  86. isa_model/training/engine/llama_factory/data_adapter.py +284 -0
  87. isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
  88. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
  89. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
  90. isa_model/training/engine/llama_factory/factory.py +331 -0
  91. isa_model/training/engine/llama_factory/rl.py +254 -0
  92. isa_model/training/engine/llama_factory/trainer.py +171 -0
  93. isa_model/training/image_model/configs/create_config.py +37 -0
  94. isa_model/training/image_model/configs/create_flux_config.py +26 -0
  95. isa_model/training/image_model/configs/create_lora_config.py +21 -0
  96. isa_model/training/image_model/prepare_massed_compute.py +97 -0
  97. isa_model/training/image_model/prepare_upload.py +17 -0
  98. isa_model/training/image_model/raw_data/create_captions.py +16 -0
  99. isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
  100. isa_model/training/image_model/raw_data/pre_processing.py +200 -0
  101. isa_model/training/image_model/train/train.py +42 -0
  102. isa_model/training/image_model/train/train_flux.py +41 -0
  103. isa_model/training/image_model/train/train_lora.py +57 -0
  104. isa_model/training/image_model/train_main.py +25 -0
  105. isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
  106. isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
  107. isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
  108. isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
  109. isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
  110. isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
  111. isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
  112. isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
  113. isa_model-0.1.0.dist-info/METADATA +116 -0
  114. isa_model-0.1.0.dist-info/RECORD +117 -0
  115. isa_model-0.1.0.dist-info/WHEEL +5 -0
  116. isa_model-0.1.0.dist-info/licenses/LICENSE +21 -0
  117. isa_model-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,257 @@
1
+ """
2
+ MLflow tracker for training workflows.
3
+ """
4
+
5
+ import os
6
+ import json
7
+ import logging
8
+ from typing import Dict, List, Optional, Any, Union
9
+ from contextlib import contextmanager
10
+
11
+ from .mlflow_manager import MLflowManager, ExperimentType
12
+ from .model_registry import ModelRegistry, ModelStage
13
+
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class TrainingTracker:
    """
    Tracker for model training workflows.

    Combines an ``MLflowManager`` (run tracking) with a ``ModelRegistry``
    (model registration) so that a training run can be tracked and the
    resulting model registered through a single object.

    Example:
    ```python
    # Initialize tracker
    tracker = TrainingTracker(
        tracking_uri="http://localhost:5000",
        registry_uri="http://localhost:5000"
    )

    # Start tracking training
    with tracker.track_training_run(
        model_name="llama-7b",
        training_params={
            "learning_rate": 2e-5,
            "batch_size": 8,
            "epochs": 3
        }
    ) as run_info:
        # Train the model...

        # Log metrics during training
        tracker.log_metrics({
            "train_loss": 0.1,
            "val_loss": 0.2
        })

        # After training completes
        model_path = "/path/to/trained_model"

        # Register the model
        tracker.register_trained_model(
            model_path=model_path,
            metrics={
                "accuracy": 0.95,
                "f1": 0.92
            },
            stage=ModelStage.STAGING
        )
    ```
    """

    def __init__(
        self,
        tracking_uri: Optional[str] = None,
        artifact_uri: Optional[str] = None,
        registry_uri: Optional[str] = None
    ):
        """
        Initialize the training tracker.

        Args:
            tracking_uri: URI for MLflow tracking server
            artifact_uri: URI for MLflow artifacts
            registry_uri: URI for MLflow model registry
        """
        self.mlflow_manager = MLflowManager(
            tracking_uri=tracking_uri,
            artifact_uri=artifact_uri,
            registry_uri=registry_uri
        )
        self.model_registry = ModelRegistry(
            tracking_uri=tracking_uri,
            registry_uri=registry_uri
        )
        # Bookkeeping for the run currently inside ``track_training_run``;
        # an empty dict means "no active run".
        self.current_run_info = {}

    @contextmanager
    def track_training_run(
        self,
        model_name: str,
        training_params: Dict[str, Any],
        description: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        experiment_type: ExperimentType = ExperimentType.TRAINING
    ):
        """
        Track a training run with MLflow.

        Args:
            model_name: Name of the model being trained
            training_params: Parameters for the training run
            description: Description of the training run; when given it is
                stored under the ``"description"`` tag
            tags: Tags for the training run (never modified by this method)
            experiment_type: Type of experiment

        Yields:
            Dictionary with run information. It holds ``model_name``,
            ``params``, ``metrics``, ``run_id``, ``experiment_id`` and
            ``status`` ("running", then "completed" or "failed"); on
            failure an ``error`` entry carries the exception text.

        Raises:
            Exception: Any exception raised inside the ``with`` body is
                re-raised after the run has been marked as failed.
        """
        run_info = {
            "model_name": model_name,
            "params": training_params,
            "metrics": {}
        }

        # Work on a copy so the caller's dict is not mutated when the
        # description tag is added.
        run_tags = dict(tags) if tags else {}
        if description:
            run_tags["description"] = description

        # Start the MLflow run
        with self.mlflow_manager.start_run(
            experiment_type=experiment_type,
            model_name=model_name,
            tags=run_tags
        ) as run:
            run_info["run_id"] = run.info.run_id
            run_info["experiment_id"] = run.info.experiment_id
            run_info["status"] = "running"

            # Save parameters
            self.mlflow_manager.log_params(training_params)

            self.current_run_info = run_info
            try:
                yield run_info
                # Reached only when the caller's block finished cleanly.
                run_info["status"] = "completed"
            except Exception as e:
                # Record the failure locally and on the run, then propagate.
                run_info["status"] = "failed"
                run_info["error"] = str(e)
                self.mlflow_manager.set_tracking_tag("error", str(e))
                raise
            finally:
                # Whatever happened, there is no active run any more.
                self.current_run_info = {}

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        """
        Log metrics to the current run.

        The metrics are forwarded to the MLflow manager and, when a run is
        active, merged into ``current_run_info["metrics"]``.

        Args:
            metrics: Dictionary of metrics to log
            step: Step value for the metrics
        """
        self.mlflow_manager.log_metrics(metrics, step)

        if self.current_run_info:
            # Only the latest value per metric name is kept locally.
            self.current_run_info.setdefault("metrics", {}).update(metrics)

    def log_artifacts(self, local_dir: str, artifact_path: Optional[str] = None) -> None:
        """
        Log artifacts to the current run.

        Args:
            local_dir: Local directory containing artifacts
            artifact_path: Path for the artifacts in MLflow
        """
        self.mlflow_manager.log_artifacts(local_dir, artifact_path)

    def register_trained_model(
        self,
        model_path: str,
        metrics: Optional[Dict[str, float]] = None,
        description: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        stage: Optional[ModelStage] = None,
        flavor: str = "pyfunc"
    ) -> Optional[str]:
        """
        Register a trained model with MLflow.

        Must be called while a run started by ``track_training_run`` is
        active; otherwise a warning is logged and ``None`` is returned.

        Args:
            model_path: Path to the trained model
            metrics: Evaluation metrics for the model; logged to the run
                before registration
            description: Description of the model
            tags: Tags for the model (never modified by this method). The
                run ID and one ``metric.<name>`` tag per tracked metric
                are added to the tags sent to the registry.
            stage: Stage to register the model in
            flavor: MLflow model flavor

        Returns:
            Version of the registered model or None if registration failed
        """
        if not self.current_run_info:
            logger.warning("No active run. Model cannot be registered.")
            return None

        model_name = self.current_run_info.get("model_name")
        if not model_name:
            logger.warning("Model name not available in run info. Using generic name.")
            model_name = "unnamed_model"

        # Log final metrics if provided
        if metrics:
            self.log_metrics(metrics)

        run_id = self.current_run_info.get("run_id")

        # Build the registry tags on a copy so the caller's dict is left
        # untouched.
        model_tags = dict(tags) if tags else {}
        model_tags["run_id"] = run_id if run_id is not None else ""
        for metric_name, metric_value in self.current_run_info.get("metrics", {}).items():
            model_tags[f"metric.{metric_name}"] = str(metric_value)

        # Log model to MLflow
        artifact_path = self.mlflow_manager.log_model(
            model_path=model_path,
            name=model_name,
            flavor=flavor
        )

        if not artifact_path:
            logger.error("Failed to log model to MLflow.")
            return None

        # The registry source is the artifact location relative to the run.
        model_uri = f"runs:/{run_id}/{artifact_path}"

        # Register the model
        version = self.model_registry.register_model(
            name=model_name,
            source=model_uri,
            description=description,
            tags=model_tags
        )

        # Transition to the specified stage if provided
        if version and stage:
            self.model_registry.transition_model_version_stage(
                name=model_name,
                version=version,
                stage=stage
            )

        return version
@@ -0,0 +1,39 @@
1
+ """
2
+ LlamaFactory engine for fine-tuning and reinforcement learning.
3
+
4
+ This package provides interfaces for using LlamaFactory to:
5
+ - Fine-tune models with various datasets
6
+ - Perform reinforcement learning from human feedback (RLHF)
7
+ - Support instruction tuning and preference optimization
8
+ """
9
+
10
+ from .config import (
11
+ LlamaFactoryConfig,
12
+ SFTConfig,
13
+ RLConfig,
14
+ DPOConfig,
15
+ TrainingStrategy,
16
+ DatasetFormat,
17
+ create_default_config
18
+ )
19
+ from .trainer import LlamaFactoryTrainer
20
+ from .rl import LlamaFactoryRL
21
+ from .factory import LlamaFactory
22
+ from .data_adapter import DataAdapter, AlpacaAdapter, ShareGPTAdapter, DataAdapterFactory
23
+
24
+ __all__ = [
25
+ "LlamaFactoryTrainer",
26
+ "LlamaFactoryRL",
27
+ "LlamaFactoryConfig",
28
+ "SFTConfig",
29
+ "RLConfig",
30
+ "DPOConfig",
31
+ "TrainingStrategy",
32
+ "DatasetFormat",
33
+ "create_default_config",
34
+ "LlamaFactory",
35
+ "DataAdapter",
36
+ "AlpacaAdapter",
37
+ "ShareGPTAdapter",
38
+ "DataAdapterFactory"
39
+ ]
@@ -0,0 +1,115 @@
1
+ """
2
+ Configuration classes for LlamaFactory training and RL.
3
+ """
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Dict, List, Optional, Union
7
+ from enum import Enum
8
+ import os
9
+ from pathlib import Path
10
+
11
+
12
class TrainingStrategy(str, Enum):
    """
    Identifiers for the training strategies used with LlamaFactory.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value (for example ``"sft"``).
    """

    # Supervised fine-tuning
    SUPERVISED_FINETUNING = "sft"
    # Reinforcement learning
    REINFORCEMENT_LEARNING = "rl"
    # Preference optimization
    PREFERENCE_OPTIMIZATION = "dpo"
    # Declared here, but create_default_config() has no branch for it and
    # raises ValueError when it is requested.
    PREFERENCE_PAIRWISE = "ppo"
19
+
20
+
21
class DatasetFormat(str, Enum):
    """
    Identifiers for the dataset layouts used with LlamaFactory.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value (for example ``"alpaca"``).
    """

    # Instruction / input / output records
    ALPACA = "alpaca"
    # Multi-turn "conversations" records
    SHAREGPT = "sharegpt"
    # Any other layout
    CUSTOM = "custom"
27
+
28
+
29
@dataclass
class LlamaFactoryConfig:
    """
    Settings shared by every LlamaFactory trainer configuration.

    Only ``model_path`` is required; every other field has a default.
    """

    model_path: str
    # Defaults to an "outputs" folder under the working directory at the
    # moment the config object is created.
    output_dir: str = field(default_factory=lambda: os.path.join(os.getcwd(), "outputs"))

    # --- Training parameters ---
    batch_size: int = 8
    num_epochs: int = 3
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    lr_scheduler_type: str = "cosine"
    warmup_ratio: float = 0.1

    # --- Model parameters ---
    lora_rank: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    use_lora: bool = True

    # --- Data processing ---
    max_length: int = 1024
    dataset_format: DatasetFormat = DatasetFormat.ALPACA

    # --- Logging ---
    log_with: str = "mlflow"
    logging_steps: int = 10

    def to_dict(self) -> Dict:
        """
        Convert config to dictionary for LlamaFactory CLI.

        Returns:
            A new dict holding every instance attribute, where attributes
            that are ``Enum`` members are replaced by their ``.value``.
        """
        serialized = {}
        for attr_name, attr_value in vars(self).items():
            if isinstance(attr_value, Enum):
                serialized[attr_name] = attr_value.value
            else:
                serialized[attr_name] = attr_value
        return serialized
62
+
63
+
64
@dataclass
class SFTConfig(LlamaFactoryConfig):
    """
    Supervised fine-tuning configuration.

    Extends the base configuration with the data files and the
    SFT-specific options below.
    """

    strategy: TrainingStrategy = TrainingStrategy.SUPERVISED_FINETUNING
    # Training data file; empty string until the caller sets it.
    train_file: str = ""
    # Optional validation data file.
    val_file: Optional[str] = None

    # --- SFT specific settings ---
    cutoff_len: int = 1024
    normalize_data: bool = True
75
+
76
+
77
@dataclass
class RLConfig(LlamaFactoryConfig):
    """
    Reinforcement learning configuration.

    Extends the base configuration with the reward model, the training
    data file and the RL-specific options below.
    """

    strategy: TrainingStrategy = TrainingStrategy.REINFORCEMENT_LEARNING
    # Reward model reference; empty string until the caller sets it.
    reward_model: str = ""
    # Training data file; empty string until the caller sets it.
    train_file: str = ""

    # --- RL specific settings ---
    kl_coef: float = 0.1
    top_k: int = 0
    top_p: float = 1.0
    temperature: float = 1.0
90
+
91
+
92
@dataclass
class DPOConfig(LlamaFactoryConfig):
    """
    Direct preference optimization configuration.

    Extends the base configuration with the data files and the
    DPO-specific options below.
    """

    strategy: TrainingStrategy = TrainingStrategy.PREFERENCE_OPTIMIZATION
    # Training data file; empty string until the caller sets it.
    train_file: str = ""
    # Optional validation data file.
    val_file: Optional[str] = None

    # --- DPO specific settings ---
    beta: float = 0.1
    # Optional reference model; None when not supplied.
    reference_model: Optional[str] = None
103
+
104
+
105
def create_default_config(strategy: TrainingStrategy, model_path: str) -> LlamaFactoryConfig:
    """
    Build the default configuration object for a training strategy.

    Args:
        strategy: The training strategy to configure
        model_path: Model path stored on the returned configuration

    Returns:
        An ``SFTConfig``, ``RLConfig`` or ``DPOConfig`` with every field
        other than ``model_path`` left at its default.

    Raises:
        ValueError: If no configuration class is mapped to ``strategy``.
    """
    # Compared with == one pair at a time, in this order, rather than via
    # a dict lookup, so equality semantics stay exactly those of the enum.
    strategy_configs = (
        (TrainingStrategy.SUPERVISED_FINETUNING, SFTConfig),
        (TrainingStrategy.REINFORCEMENT_LEARNING, RLConfig),
        (TrainingStrategy.PREFERENCE_OPTIMIZATION, DPOConfig),
    )
    for supported_strategy, config_cls in strategy_configs:
        if strategy == supported_strategy:
            return config_cls(model_path=model_path)

    raise ValueError(f"Unsupported strategy: {strategy}")
@@ -0,0 +1,284 @@
1
+ """
2
+ Data adapters for LlamaFactory training.
3
+ """
4
+
5
+ import os
6
+ import json
7
+ import logging
8
+ from typing import Dict, List, Any, Optional, Union
9
+
10
+ from .config import DatasetFormat
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class DataAdapter:
    """
    Base class for adapting datasets to LlamaFactory format.

    ``convert_data`` drives the pipeline: load a JSON file, hand the parsed
    content to ``_convert_data`` (implemented by subclasses), and write the
    result back out as JSON.

    Example:
    ```python
    # Create a custom adapter
    adapter = CustomAdapter()

    # Convert data
    converted_path = adapter.convert_data(
        input_file="path/to/source_data.json",
        output_file="path/to/converted_data.json"
    )
    ```
    """

    def __init__(self, format_type: DatasetFormat = DatasetFormat.ALPACA):
        """
        Initialize the data adapter.

        Args:
            format_type: The target format to convert to
        """
        self.format_type = format_type

    def convert_data(self, input_file: str, output_file: Optional[str] = None) -> str:
        """
        Convert data from input format to LlamaFactory format.

        Args:
            input_file: Path to the input data file
            output_file: Path to save the converted data. When omitted, the
                result is written next to the input file as
                ``<name>_converted<ext>``.

        Returns:
            Path to the converted data file
        """
        if output_file is None:
            dirname = os.path.dirname(input_file)
            name, ext = os.path.splitext(os.path.basename(input_file))
            output_file = os.path.join(dirname, f"{name}_converted{ext}")

        data = self._load_data(input_file)
        converted_data = self._convert_data(data)
        self._save_data(converted_data, output_file)

        return output_file

    def _load_data(self, input_file: str) -> Any:
        """
        Load data from the input file.

        Args:
            input_file: Path to the input file (UTF-8 encoded JSON)

        Returns:
            Loaded data
        """
        with open(input_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _save_data(self, data: Any, output_file: str) -> None:
        """
        Save data to the output file.

        The data is written as indented UTF-8 JSON with non-ASCII
        characters kept as-is. Missing parent directories of
        ``output_file`` are created first.

        Args:
            data: Data to save
            output_file: Path to save the data
        """
        # A bare file name has no parent component, and makedirs("") would
        # raise, so only create the directory when there is one.
        parent_dir = os.path.dirname(output_file)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    def _convert_data(self, data: Any) -> List[Dict[str, Any]]:
        """
        Convert data to the target format.

        Args:
            data: Input data

        Returns:
            Converted data

        Raises:
            NotImplementedError: Always; subclasses provide the conversion.
        """
        raise NotImplementedError("Subclasses must implement this method")
102
+
103
+
104
class AlpacaAdapter(DataAdapter):
    """
    Adapter for Alpaca format data.

    Example:
    ```python
    adapter = AlpacaAdapter()
    converted_path = adapter.convert_data("custom_data.json")
    ```
    """

    # (instruction key, output key) pairs accepted as alternatives to the
    # native "instruction"/"output" layout, tried in this order.
    _FALLBACK_KEYS = (
        ("prompt", "response"),
        ("question", "answer"),
    )

    def __init__(self):
        """Initialize the Alpaca adapter."""
        super().__init__(DatasetFormat.ALPACA)

    def _convert_data(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Convert data to Alpaca format.

        Expected output format:
        [
            {
                "instruction": "Task description",
                "input": "Optional input (context)",
                "output": "Expected output"
            },
            ...
        ]

        Items that already carry "instruction" and "output" are copied
        (with "input" defaulting to an empty string). Items using
        "prompt"/"response" or "question"/"answer" are mapped onto the
        Alpaca keys with an empty "input". Anything else, including items
        that are not dicts, is skipped with a warning.

        Args:
            data: Input data

        Returns:
            Converted data in Alpaca format
        """
        result = []

        for item in data:
            if not isinstance(item, dict):
                # Previously dropped without any trace; report it the same
                # way as an unrecognised dict.
                logger.warning(f"Could not convert item: {item}")
                continue

            # Already in the expected format: keep the three Alpaca keys.
            if "instruction" in item and "output" in item:
                result.append({
                    "instruction": item["instruction"],
                    "input": item.get("input", ""),
                    "output": item["output"]
                })
                continue

            # Otherwise, try to convert from common formats.
            for instruction_key, output_key in self._FALLBACK_KEYS:
                if instruction_key in item and output_key in item:
                    result.append({
                        "instruction": item[instruction_key],
                        "input": "",
                        "output": item[output_key]
                    })
                    break
            else:
                logger.warning(f"Could not convert item: {item}")

        return result
170
+
171
+
172
class ShareGPTAdapter(DataAdapter):
    """
    Adapter for ShareGPT format data.

    Example:
    ```python
    adapter = ShareGPTAdapter()
    converted_path = adapter.convert_data("sharegpt_data.json")
    ```
    """

    # Lower-cased role names found in the input -> role written to output.
    _ROLE_ALIASES = {
        "user": "human",
        "human": "human",
        "assistant": "gpt",
        "gpt": "gpt",
        "ai": "gpt",
    }

    def __init__(self):
        """Initialize the ShareGPT adapter."""
        super().__init__(DatasetFormat.SHAREGPT)

    def _convert_data(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Convert data to ShareGPT format.

        Expected output format:
        [
            {
                "conversations": [
                    {"from": "human", "value": "Human message"},
                    {"from": "gpt", "value": "Assistant response"},
                    ...
                ]
            },
            ...
        ]

        Items with a "conversations" list have their messages normalized;
        "question"/"answer" and "prompt"/"response" items become a
        two-message conversation. Items that yield no messages, including
        items that are not dicts, are skipped with a warning.

        Args:
            data: Input data

        Returns:
            Converted data in ShareGPT format
        """
        result = []

        for item in data:
            conversations = []

            if isinstance(item, dict):
                raw_messages = item.get("conversations")

                # If already in conversations format
                if isinstance(raw_messages, list):
                    conversations = self._normalize_messages(raw_messages)

                # If in QA format
                elif "question" in item and "answer" in item:
                    conversations = [
                        {"from": "human", "value": item["question"]},
                        {"from": "gpt", "value": item["answer"]}
                    ]

                # If in prompt/response format
                elif "prompt" in item and "response" in item:
                    conversations = [
                        {"from": "human", "value": item["prompt"]},
                        {"from": "gpt", "value": item["response"]}
                    ]

            if conversations:
                result.append({"conversations": conversations})
            else:
                # Same message AlpacaAdapter uses, so dropped records are
                # visible instead of vanishing silently.
                logger.warning(f"Could not convert item: {item}")

        return result

    def _normalize_messages(self, raw_messages: List[Any]) -> List[Dict[str, Any]]:
        """
        Keep the well-formed messages of a conversation, with unified roles.

        Args:
            raw_messages: Entries of an item's "conversations" list

        Returns:
            Messages as ``{"from": "human" | "gpt", "value": ...}`` dicts,
            in their original order. Entries that are not dicts, lack
            "from"/"value", or carry an unusable role are left out.
        """
        normalized = []

        for message in raw_messages:
            # A bare string would pass an `"from" in message` substring
            # test and then fail on indexing, so require a dict up front.
            if not isinstance(message, dict):
                logger.warning(f"Skipping malformed message: {message}")
                continue
            if "from" not in message or "value" not in message:
                continue

            raw_role = message["from"]
            if not isinstance(raw_role, str):
                logger.warning(f"Unknown role: {raw_role}, skipping message")
                continue

            # Normalize role names
            role = raw_role.lower()
            target_role = self._ROLE_ALIASES.get(role)
            if target_role is None:
                logger.warning(f"Unknown role: {role}, skipping message")
                continue

            normalized.append({
                "from": target_role,
                "value": message["value"]
            })

        return normalized
255
+
256
+
257
class DataAdapterFactory:
    """
    Factory for creating data adapters.

    Example:
    ```python
    # Create an adapter based on format
    adapter = DataAdapterFactory.create_adapter(DatasetFormat.ALPACA)
    ```
    """

    @staticmethod
    def create_adapter(format_type: DatasetFormat) -> DataAdapter:
        """
        Return a new adapter instance matching the requested dataset format.

        Args:
            format_type: The dataset format

        Returns:
            An ``AlpacaAdapter`` or ``ShareGPTAdapter`` instance

        Raises:
            ValueError: If no adapter exists for ``format_type``.
        """
        # Compared with == one pair at a time, in this order, rather than
        # via a dict lookup, so equality semantics stay those of the enum.
        adapters_by_format = (
            (DatasetFormat.ALPACA, AlpacaAdapter),
            (DatasetFormat.SHAREGPT, ShareGPTAdapter),
        )
        for supported_format, adapter_cls in adapters_by_format:
            if format_type == supported_format:
                return adapter_cls()

        raise ValueError(f"Unsupported format type: {format_type}")
@@ -0,0 +1,6 @@
1
+ """
2
+ Example scripts for LlamaFactory fine-tuning and reinforcement learning.
3
+
4
+ This package contains example scripts that demonstrate how to use
5
+ the LlamaFactory engine for various tasks such as fine-tuning and RLHF.
6
+ """