isa-model 0.0.2__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/core/model_manager.py +69 -4
- isa_model/core/model_registry.py +273 -46
- isa_model/core/storage/hf_storage.py +419 -0
- isa_model/deployment/__init__.py +52 -0
- isa_model/deployment/core/__init__.py +34 -0
- isa_model/deployment/core/deployment_config.py +356 -0
- isa_model/deployment/core/deployment_manager.py +549 -0
- isa_model/deployment/core/isa_deployment_service.py +401 -0
- isa_model/eval/factory.py +381 -140
- isa_model/inference/ai_factory.py +427 -236
- isa_model/inference/billing_tracker.py +406 -0
- isa_model/inference/providers/base_provider.py +51 -4
- isa_model/inference/providers/ml_provider.py +50 -0
- isa_model/inference/providers/ollama_provider.py +37 -18
- isa_model/inference/providers/openai_provider.py +65 -36
- isa_model/inference/providers/replicate_provider.py +42 -30
- isa_model/inference/services/audio/base_stt_service.py +21 -2
- isa_model/inference/services/audio/openai_realtime_service.py +353 -0
- isa_model/inference/services/audio/openai_stt_service.py +252 -0
- isa_model/inference/services/audio/openai_tts_service.py +149 -9
- isa_model/inference/services/audio/replicate_tts_service.py +239 -0
- isa_model/inference/services/base_service.py +36 -1
- isa_model/inference/services/embedding/base_embed_service.py +112 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
- isa_model/inference/services/embedding/openai_embed_service.py +223 -0
- isa_model/inference/services/llm/__init__.py +2 -0
- isa_model/inference/services/llm/base_llm_service.py +158 -86
- isa_model/inference/services/llm/llm_adapter.py +414 -0
- isa_model/inference/services/llm/ollama_llm_service.py +252 -63
- isa_model/inference/services/llm/openai_llm_service.py +231 -93
- isa_model/inference/services/llm/triton_llm_service.py +481 -0
- isa_model/inference/services/ml/base_ml_service.py +78 -0
- isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
- isa_model/inference/services/vision/__init__.py +3 -3
- isa_model/inference/services/vision/base_image_gen_service.py +161 -0
- isa_model/inference/services/vision/base_vision_service.py +177 -0
- isa_model/inference/services/vision/helpers/image_utils.py +4 -3
- isa_model/inference/services/vision/ollama_vision_service.py +151 -17
- isa_model/inference/services/vision/openai_vision_service.py +275 -41
- isa_model/inference/services/vision/replicate_image_gen_service.py +278 -118
- isa_model/training/__init__.py +62 -32
- isa_model/training/cloud/__init__.py +22 -0
- isa_model/training/cloud/job_orchestrator.py +402 -0
- isa_model/training/cloud/runpod_trainer.py +454 -0
- isa_model/training/cloud/storage_manager.py +482 -0
- isa_model/training/core/__init__.py +23 -0
- isa_model/training/core/config.py +181 -0
- isa_model/training/core/dataset.py +222 -0
- isa_model/training/core/trainer.py +720 -0
- isa_model/training/core/utils.py +213 -0
- isa_model/training/factory.py +229 -198
- isa_model-0.3.1.dist-info/METADATA +465 -0
- isa_model-0.3.1.dist-info/RECORD +91 -0
- isa_model/core/model_router.py +0 -226
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +0 -202
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
- isa_model/training/engine/llama_factory/__init__.py +0 -39
- isa_model/training/engine/llama_factory/config.py +0 -115
- isa_model/training/engine/llama_factory/data_adapter.py +0 -284
- isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
- isa_model/training/engine/llama_factory/factory.py +0 -331
- isa_model/training/engine/llama_factory/rl.py +0 -254
- isa_model/training/engine/llama_factory/trainer.py +0 -171
- isa_model/training/image_model/configs/create_config.py +0 -37
- isa_model/training/image_model/configs/create_flux_config.py +0 -26
- isa_model/training/image_model/configs/create_lora_config.py +0 -21
- isa_model/training/image_model/prepare_massed_compute.py +0 -97
- isa_model/training/image_model/prepare_upload.py +0 -17
- isa_model/training/image_model/raw_data/create_captions.py +0 -16
- isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
- isa_model/training/image_model/raw_data/pre_processing.py +0 -200
- isa_model/training/image_model/train/train.py +0 -42
- isa_model/training/image_model/train/train_flux.py +0 -41
- isa_model/training/image_model/train/train_lora.py +0 -57
- isa_model/training/image_model/train_main.py +0 -25
- isa_model-0.0.2.dist-info/METADATA +0 -327
- isa_model-0.0.2.dist-info/RECORD +0 -92
- isa_model-0.0.2.dist-info/licenses/LICENSE +0 -21
- /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
- {isa_model-0.0.2.dist-info → isa_model-0.3.1.dist-info}/WHEEL +0 -0
- {isa_model-0.0.2.dist-info → isa_model-0.3.1.dist-info}/top_level.txt +0 -0
isa_model/training/factory.py
CHANGED
@@ -1,10 +1,8 @@
|
|
1
1
|
"""
|
2
|
-
|
2
|
+
ISA Model Training Factory
|
3
3
|
|
4
|
-
|
5
|
-
|
6
|
-
- Image model training (Flux, LoRA)
|
7
|
-
- Model evaluation and benchmarking
|
4
|
+
A clean, simplified training factory that uses HuggingFace Transformers directly
|
5
|
+
without external dependencies like LlamaFactory.
|
8
6
|
"""
|
9
7
|
|
10
8
|
import os
|
@@ -13,43 +11,49 @@ from typing import Optional, Dict, Any, Union, List
|
|
13
11
|
from pathlib import Path
|
14
12
|
import datetime
|
15
13
|
|
16
|
-
from .
|
17
|
-
|
14
|
+
from .core import (
|
15
|
+
TrainingConfig,
|
16
|
+
LoRAConfig,
|
17
|
+
DatasetConfig,
|
18
|
+
BaseTrainer,
|
19
|
+
SFTTrainer,
|
20
|
+
TrainingUtils,
|
21
|
+
DatasetManager,
|
22
|
+
)
|
23
|
+
from .cloud import TrainingJobOrchestrator
|
18
24
|
|
19
25
|
logger = logging.getLogger(__name__)
|
20
26
|
|
21
27
|
|
22
28
|
class TrainingFactory:
|
23
29
|
"""
|
24
|
-
Unified
|
30
|
+
Unified Training Factory for ISA Model SDK
|
25
31
|
|
26
|
-
|
27
|
-
-
|
28
|
-
-
|
29
|
-
- Model evaluation and
|
32
|
+
Provides a clean interface for:
|
33
|
+
- Local training with SFT (Supervised Fine-Tuning)
|
34
|
+
- Cloud training on RunPod
|
35
|
+
- Model evaluation and management
|
30
36
|
|
31
|
-
Example usage
|
37
|
+
Example usage:
|
32
38
|
```python
|
33
39
|
from isa_model.training import TrainingFactory
|
34
40
|
|
35
41
|
factory = TrainingFactory()
|
36
42
|
|
37
|
-
#
|
38
|
-
model_path = factory.
|
43
|
+
# Local training
|
44
|
+
model_path = factory.train_model(
|
39
45
|
model_name="google/gemma-2-4b-it",
|
40
|
-
dataset_path="
|
41
|
-
training_type="sft",
|
46
|
+
dataset_path="tatsu-lab/alpaca",
|
42
47
|
use_lora=True,
|
43
|
-
num_epochs=3
|
44
|
-
batch_size=4,
|
45
|
-
learning_rate=2e-5
|
48
|
+
num_epochs=3
|
46
49
|
)
|
47
50
|
|
48
|
-
#
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
51
|
+
# Cloud training on RunPod
|
52
|
+
result = factory.train_on_runpod(
|
53
|
+
model_name="google/gemma-2-4b-it",
|
54
|
+
dataset_path="tatsu-lab/alpaca",
|
55
|
+
runpod_api_key="your-api-key",
|
56
|
+
template_id="your-template-id"
|
53
57
|
)
|
54
58
|
```
|
55
59
|
"""
|
@@ -59,32 +63,19 @@ class TrainingFactory:
|
|
59
63
|
Initialize the training factory.
|
60
64
|
|
61
65
|
Args:
|
62
|
-
base_output_dir: Base directory for
|
66
|
+
base_output_dir: Base directory for training outputs
|
63
67
|
"""
|
64
68
|
self.base_output_dir = base_output_dir or os.path.join(os.getcwd(), "training_outputs")
|
65
69
|
os.makedirs(self.base_output_dir, exist_ok=True)
|
66
70
|
|
67
|
-
# Initialize sub-factories
|
68
|
-
self.llm_factory = LlamaFactory(base_output_dir=os.path.join(self.base_output_dir, "llm"))
|
69
|
-
|
70
71
|
logger.info(f"TrainingFactory initialized with output dir: {self.base_output_dir}")
|
71
72
|
|
72
|
-
def
|
73
|
-
"""Generate timestamped output directory."""
|
74
|
-
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
75
|
-
safe_model_name = model_name.replace("/", "_").replace(":", "_")
|
76
|
-
return os.path.join(self.base_output_dir, f"{safe_model_name}_{training_type}_{timestamp}")
|
77
|
-
|
78
|
-
# =================
|
79
|
-
# LLM Training Methods
|
80
|
-
# =================
|
81
|
-
|
82
|
-
def finetune_llm(
|
73
|
+
def train_model(
|
83
74
|
self,
|
84
75
|
model_name: str,
|
85
76
|
dataset_path: str,
|
86
|
-
training_type: str = "sft",
|
87
77
|
output_dir: Optional[str] = None,
|
78
|
+
training_type: str = "sft",
|
88
79
|
dataset_format: str = "alpaca",
|
89
80
|
use_lora: bool = True,
|
90
81
|
batch_size: int = 4,
|
@@ -93,17 +84,17 @@ class TrainingFactory:
|
|
93
84
|
max_length: int = 1024,
|
94
85
|
lora_rank: int = 8,
|
95
86
|
lora_alpha: int = 16,
|
96
|
-
|
87
|
+
validation_split: float = 0.1,
|
97
88
|
**kwargs
|
98
89
|
) -> str:
|
99
90
|
"""
|
100
|
-
|
91
|
+
Train a model locally.
|
101
92
|
|
102
93
|
Args:
|
103
|
-
model_name: Model identifier (e.g., "google/gemma-2-4b-it"
|
104
|
-
dataset_path: Path to
|
105
|
-
training_type: Type of training ("sft", "dpo", "rlhf")
|
94
|
+
model_name: Model identifier (e.g., "google/gemma-2-4b-it")
|
95
|
+
dataset_path: Path to dataset or HuggingFace dataset name
|
106
96
|
output_dir: Custom output directory
|
97
|
+
training_type: Type of training ("sft" supported)
|
107
98
|
dataset_format: Dataset format ("alpaca", "sharegpt", "custom")
|
108
99
|
use_lora: Whether to use LoRA for efficient training
|
109
100
|
batch_size: Training batch size
|
@@ -112,7 +103,7 @@ class TrainingFactory:
|
|
112
103
|
max_length: Maximum sequence length
|
113
104
|
lora_rank: LoRA rank parameter
|
114
105
|
lora_alpha: LoRA alpha parameter
|
115
|
-
|
106
|
+
validation_split: Fraction of data for validation
|
116
107
|
**kwargs: Additional training parameters
|
117
108
|
|
118
109
|
Returns:
|
@@ -120,184 +111,207 @@ class TrainingFactory:
|
|
120
111
|
|
121
112
|
Example:
|
122
113
|
```python
|
123
|
-
|
124
|
-
model_path = factory.finetune_llm(
|
114
|
+
model_path = factory.train_model(
|
125
115
|
model_name="google/gemma-2-4b-it",
|
126
|
-
dataset_path="
|
127
|
-
training_type="sft",
|
116
|
+
dataset_path="tatsu-lab/alpaca",
|
128
117
|
use_lora=True,
|
129
118
|
num_epochs=3,
|
130
119
|
batch_size=4
|
131
120
|
)
|
132
121
|
```
|
133
122
|
"""
|
123
|
+
# Generate output directory if not provided
|
134
124
|
if not output_dir:
|
135
|
-
output_dir =
|
125
|
+
output_dir = TrainingUtils.generate_output_dir(
|
126
|
+
model_name, training_type, self.base_output_dir
|
127
|
+
)
|
136
128
|
|
137
|
-
#
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
dataset_format_enum = format_map.get(dataset_format, DatasetFormat.ALPACA)
|
129
|
+
# Create configurations
|
130
|
+
lora_config = LoRAConfig(
|
131
|
+
use_lora=use_lora,
|
132
|
+
lora_rank=lora_rank,
|
133
|
+
lora_alpha=lora_alpha
|
134
|
+
) if use_lora else None
|
144
135
|
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
dataset_format=dataset_format_enum,
|
152
|
-
use_lora=use_lora,
|
153
|
-
batch_size=batch_size,
|
154
|
-
num_epochs=num_epochs,
|
155
|
-
learning_rate=learning_rate,
|
156
|
-
max_length=max_length,
|
157
|
-
lora_rank=lora_rank,
|
158
|
-
lora_alpha=lora_alpha,
|
159
|
-
**kwargs
|
160
|
-
)
|
161
|
-
else:
|
162
|
-
raise ValueError(f"Training type '{training_type}' not supported yet. Use 'sft' for now.")
|
163
|
-
|
164
|
-
def train_with_preferences(
|
165
|
-
self,
|
166
|
-
model_path: str,
|
167
|
-
preference_data: str,
|
168
|
-
output_dir: Optional[str] = None,
|
169
|
-
reference_model: Optional[str] = None,
|
170
|
-
beta: float = 0.1,
|
171
|
-
use_lora: bool = True,
|
172
|
-
batch_size: int = 4,
|
173
|
-
num_epochs: int = 3,
|
174
|
-
learning_rate: float = 5e-6,
|
175
|
-
val_data: Optional[str] = None,
|
176
|
-
**kwargs
|
177
|
-
) -> str:
|
178
|
-
"""
|
179
|
-
Train model with preference data using DPO.
|
136
|
+
dataset_config = DatasetConfig(
|
137
|
+
dataset_path=dataset_path,
|
138
|
+
dataset_format=dataset_format,
|
139
|
+
max_length=max_length,
|
140
|
+
validation_split=validation_split
|
141
|
+
)
|
180
142
|
|
181
|
-
|
182
|
-
|
183
|
-
preference_data: Path to preference dataset
|
184
|
-
output_dir: Custom output directory
|
185
|
-
reference_model: Reference model for DPO (optional)
|
186
|
-
beta: DPO beta parameter
|
187
|
-
use_lora: Whether to use LoRA
|
188
|
-
batch_size: Training batch size
|
189
|
-
num_epochs: Number of epochs
|
190
|
-
learning_rate: Learning rate
|
191
|
-
val_data: Validation data path
|
192
|
-
**kwargs: Additional parameters
|
193
|
-
|
194
|
-
Returns:
|
195
|
-
Path to the trained model
|
196
|
-
"""
|
197
|
-
if not output_dir:
|
198
|
-
model_name = os.path.basename(model_path)
|
199
|
-
output_dir = self._get_output_dir(model_name, "dpo")
|
200
|
-
|
201
|
-
return self.llm_factory.dpo(
|
202
|
-
model_path=model_path,
|
203
|
-
train_data=preference_data,
|
204
|
-
val_data=val_data,
|
205
|
-
reference_model=reference_model,
|
143
|
+
training_config = TrainingConfig(
|
144
|
+
model_name=model_name,
|
206
145
|
output_dir=output_dir,
|
207
|
-
|
208
|
-
batch_size=batch_size,
|
146
|
+
training_type=training_type,
|
209
147
|
num_epochs=num_epochs,
|
148
|
+
batch_size=batch_size,
|
210
149
|
learning_rate=learning_rate,
|
211
|
-
|
150
|
+
lora_config=lora_config,
|
151
|
+
dataset_config=dataset_config,
|
212
152
|
**kwargs
|
213
153
|
)
|
154
|
+
|
155
|
+
# Print training summary
|
156
|
+
model_info = TrainingUtils.get_model_info(model_name)
|
157
|
+
memory_estimate = TrainingUtils.estimate_memory_usage(
|
158
|
+
model_name, batch_size, max_length, use_lora
|
159
|
+
)
|
160
|
+
|
161
|
+
summary = TrainingUtils.format_training_summary(
|
162
|
+
training_config.to_dict(), model_info, memory_estimate
|
163
|
+
)
|
164
|
+
print(summary)
|
165
|
+
|
166
|
+
# Validate configuration
|
167
|
+
issues = TrainingUtils.validate_training_config(training_config.to_dict())
|
168
|
+
if issues:
|
169
|
+
raise ValueError(f"Training configuration issues: {issues}")
|
170
|
+
|
171
|
+
# Initialize trainer based on training type
|
172
|
+
if training_type.lower() == "sft":
|
173
|
+
trainer = SFTTrainer(training_config)
|
174
|
+
else:
|
175
|
+
raise ValueError(f"Training type '{training_type}' not supported yet")
|
176
|
+
|
177
|
+
# Execute training
|
178
|
+
logger.info(f"Starting {training_type.upper()} training...")
|
179
|
+
result_path = trainer.train()
|
180
|
+
|
181
|
+
logger.info(f"Training completed! Model saved to: {result_path}")
|
182
|
+
return result_path
|
214
183
|
|
215
|
-
def
|
184
|
+
def train_on_runpod(
|
216
185
|
self,
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
) -> str:
|
186
|
+
model_name: str,
|
187
|
+
dataset_path: str,
|
188
|
+
runpod_api_key: str,
|
189
|
+
template_id: str,
|
190
|
+
gpu_type: str = "NVIDIA RTX A6000",
|
191
|
+
storage_config: Optional[Dict[str, Any]] = None,
|
192
|
+
job_name: Optional[str] = None,
|
193
|
+
**training_params
|
194
|
+
) -> Dict[str, Any]:
|
227
195
|
"""
|
228
|
-
Train a
|
196
|
+
Train a model on RunPod cloud infrastructure.
|
229
197
|
|
230
198
|
Args:
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
**kwargs: Additional parameters
|
199
|
+
model_name: Model identifier
|
200
|
+
dataset_path: Dataset path or HuggingFace dataset name
|
201
|
+
runpod_api_key: RunPod API key
|
202
|
+
template_id: RunPod template ID
|
203
|
+
gpu_type: GPU type to use
|
204
|
+
storage_config: Optional cloud storage configuration
|
205
|
+
job_name: Optional job name
|
206
|
+
**training_params: Additional training parameters
|
240
207
|
|
241
208
|
Returns:
|
242
|
-
|
209
|
+
Training job results
|
210
|
+
|
211
|
+
Example:
|
212
|
+
```python
|
213
|
+
result = factory.train_on_runpod(
|
214
|
+
model_name="google/gemma-2-4b-it",
|
215
|
+
dataset_path="tatsu-lab/alpaca",
|
216
|
+
runpod_api_key="your-api-key",
|
217
|
+
template_id="your-template-id",
|
218
|
+
use_lora=True,
|
219
|
+
num_epochs=3
|
220
|
+
)
|
221
|
+
```
|
243
222
|
"""
|
244
|
-
|
245
|
-
|
246
|
-
|
223
|
+
# Import cloud components
|
224
|
+
from .cloud import TrainingJobOrchestrator
|
225
|
+
from .cloud.runpod_trainer import RunPodConfig
|
226
|
+
from .cloud.storage_manager import StorageConfig
|
227
|
+
from .cloud.job_orchestrator import JobConfig
|
247
228
|
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
use_lora=use_lora,
|
254
|
-
batch_size=batch_size,
|
255
|
-
num_epochs=num_epochs,
|
256
|
-
learning_rate=learning_rate,
|
257
|
-
**kwargs
|
229
|
+
# Create RunPod configuration
|
230
|
+
runpod_config = RunPodConfig(
|
231
|
+
api_key=runpod_api_key,
|
232
|
+
template_id=template_id,
|
233
|
+
gpu_type=gpu_type
|
258
234
|
)
|
235
|
+
|
236
|
+
# Create storage configuration if provided
|
237
|
+
storage_cfg = None
|
238
|
+
if storage_config:
|
239
|
+
storage_cfg = StorageConfig(**storage_config)
|
240
|
+
|
241
|
+
# Create job configuration
|
242
|
+
job_config = JobConfig(
|
243
|
+
model_name=model_name,
|
244
|
+
dataset_source=dataset_path,
|
245
|
+
job_name=job_name or f"gemma-training-{int(datetime.datetime.now().timestamp())}",
|
246
|
+
**training_params
|
247
|
+
)
|
248
|
+
|
249
|
+
# Initialize orchestrator and execute training
|
250
|
+
orchestrator = TrainingJobOrchestrator(
|
251
|
+
runpod_config=runpod_config,
|
252
|
+
storage_config=storage_cfg
|
253
|
+
)
|
254
|
+
|
255
|
+
logger.info(f"Starting RunPod training for {model_name}")
|
256
|
+
result = orchestrator.execute_training_workflow(job_config)
|
257
|
+
|
258
|
+
return result
|
259
259
|
|
260
|
-
|
261
|
-
# Image Model Training Methods
|
262
|
-
# =================
|
263
|
-
|
264
|
-
def train_image_model(
|
260
|
+
async def upload_to_huggingface(
|
265
261
|
self,
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
num_epochs: int = 1000,
|
271
|
-
batch_size: int = 1,
|
272
|
-
learning_rate: float = 1e-4,
|
273
|
-
**kwargs
|
262
|
+
model_path: str,
|
263
|
+
hf_model_name: str,
|
264
|
+
hf_token: Optional[str] = None,
|
265
|
+
metadata: Optional[Dict[str, Any]] = None
|
274
266
|
) -> str:
|
275
267
|
"""
|
276
|
-
|
268
|
+
Upload a trained model to HuggingFace Hub using HuggingFaceStorage.
|
277
269
|
|
278
270
|
Args:
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
num_epochs: Training epochs
|
284
|
-
batch_size: Batch size
|
285
|
-
learning_rate: Learning rate
|
286
|
-
**kwargs: Additional parameters
|
271
|
+
model_path: Path to the trained model
|
272
|
+
hf_model_name: Name for the model on HuggingFace Hub
|
273
|
+
hf_token: HuggingFace token
|
274
|
+
metadata: Additional metadata for the model
|
287
275
|
|
288
276
|
Returns:
|
289
|
-
|
277
|
+
URL of the uploaded model
|
290
278
|
"""
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
279
|
+
try:
|
280
|
+
from ..core.storage.hf_storage import HuggingFaceStorage
|
281
|
+
|
282
|
+
logger.info(f"Uploading model to HuggingFace: {hf_model_name}")
|
283
|
+
|
284
|
+
# Initialize HuggingFace storage
|
285
|
+
storage = HuggingFaceStorage(
|
286
|
+
username="xenobordom",
|
287
|
+
token=hf_token
|
288
|
+
)
|
289
|
+
|
290
|
+
# Prepare metadata
|
291
|
+
upload_metadata = metadata or {}
|
292
|
+
upload_metadata.update({
|
293
|
+
"description": f"Fine-tuned model: {hf_model_name}",
|
294
|
+
"training_framework": "ISA Model SDK",
|
295
|
+
"uploaded_from": "training_factory"
|
296
|
+
})
|
297
|
+
|
298
|
+
# Upload model
|
299
|
+
success = await storage.save_model(
|
300
|
+
model_id=hf_model_name,
|
301
|
+
model_path=model_path,
|
302
|
+
metadata=upload_metadata
|
303
|
+
)
|
304
|
+
|
305
|
+
if success:
|
306
|
+
model_url = storage.get_public_url(hf_model_name)
|
307
|
+
logger.info(f"Model uploaded successfully: {model_url}")
|
308
|
+
return model_url
|
309
|
+
else:
|
310
|
+
raise Exception("Failed to upload model")
|
311
|
+
|
312
|
+
except Exception as e:
|
313
|
+
logger.error(f"Failed to upload to HuggingFace: {e}")
|
314
|
+
raise
|
301
315
|
|
302
316
|
def get_training_status(self, output_dir: str) -> Dict[str, Any]:
|
303
317
|
"""
|
@@ -318,6 +332,21 @@ class TrainingFactory:
|
|
318
332
|
if status["exists"]:
|
319
333
|
status["files"] = os.listdir(output_dir)
|
320
334
|
|
335
|
+
# Check for specific files
|
336
|
+
config_path = os.path.join(output_dir, "training_config.json")
|
337
|
+
metrics_path = os.path.join(output_dir, "training_metrics.json")
|
338
|
+
model_path = os.path.join(output_dir, "pytorch_model.bin")
|
339
|
+
|
340
|
+
status["has_config"] = os.path.exists(config_path)
|
341
|
+
status["has_metrics"] = os.path.exists(metrics_path)
|
342
|
+
status["has_model"] = os.path.exists(model_path) or os.path.exists(os.path.join(output_dir, "adapter_model.bin"))
|
343
|
+
|
344
|
+
if status["has_config"]:
|
345
|
+
try:
|
346
|
+
status["config"] = TrainingUtils.load_training_args(output_dir)
|
347
|
+
except:
|
348
|
+
pass
|
349
|
+
|
321
350
|
return status
|
322
351
|
|
323
352
|
def list_trained_models(self) -> List[Dict[str, Any]]:
|
@@ -333,26 +362,28 @@ class TrainingFactory:
|
|
333
362
|
for item in os.listdir(self.base_output_dir):
|
334
363
|
item_path = os.path.join(self.base_output_dir, item)
|
335
364
|
if os.path.isdir(item_path):
|
365
|
+
status = self.get_training_status(item_path)
|
336
366
|
models.append({
|
337
367
|
"name": item,
|
338
368
|
"path": item_path,
|
339
369
|
"created": datetime.datetime.fromtimestamp(
|
340
370
|
os.path.getctime(item_path)
|
341
|
-
).isoformat()
|
371
|
+
).isoformat(),
|
372
|
+
"status": status
|
342
373
|
})
|
343
374
|
|
344
375
|
return sorted(models, key=lambda x: x["created"], reverse=True)
|
345
376
|
|
346
377
|
|
347
378
|
# Convenience functions for quick access
|
348
|
-
def
|
379
|
+
def train_gemma(
|
349
380
|
dataset_path: str,
|
350
381
|
model_size: str = "4b",
|
351
382
|
output_dir: Optional[str] = None,
|
352
383
|
**kwargs
|
353
384
|
) -> str:
|
354
385
|
"""
|
355
|
-
Quick function to
|
386
|
+
Quick function to train Gemma models.
|
356
387
|
|
357
388
|
Args:
|
358
389
|
dataset_path: Path to training dataset
|
@@ -361,14 +392,14 @@ def finetune_gemma(
|
|
361
392
|
**kwargs: Additional training parameters
|
362
393
|
|
363
394
|
Returns:
|
364
|
-
Path to
|
395
|
+
Path to trained model
|
365
396
|
|
366
397
|
Example:
|
367
398
|
```python
|
368
|
-
from isa_model.training import
|
399
|
+
from isa_model.training import train_gemma
|
369
400
|
|
370
|
-
model_path =
|
371
|
-
dataset_path="
|
401
|
+
model_path = train_gemma(
|
402
|
+
dataset_path="tatsu-lab/alpaca",
|
372
403
|
model_size="4b",
|
373
404
|
num_epochs=3,
|
374
405
|
batch_size=4
|
@@ -385,7 +416,7 @@ def finetune_gemma(
|
|
385
416
|
|
386
417
|
model_name = model_map.get(model_size, "google/gemma-2-4b-it")
|
387
418
|
|
388
|
-
return factory.
|
419
|
+
return factory.train_model(
|
389
420
|
model_name=model_name,
|
390
421
|
dataset_path=dataset_path,
|
391
422
|
output_dir=output_dir,
|