isa-model 0.0.2__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/core/model_manager.py +69 -4
- isa_model/core/model_registry.py +273 -46
- isa_model/core/storage/hf_storage.py +419 -0
- isa_model/deployment/__init__.py +52 -0
- isa_model/deployment/core/__init__.py +34 -0
- isa_model/deployment/core/deployment_config.py +356 -0
- isa_model/deployment/core/deployment_manager.py +549 -0
- isa_model/deployment/core/isa_deployment_service.py +401 -0
- isa_model/eval/factory.py +381 -140
- isa_model/inference/ai_factory.py +427 -236
- isa_model/inference/billing_tracker.py +406 -0
- isa_model/inference/providers/base_provider.py +51 -4
- isa_model/inference/providers/ml_provider.py +50 -0
- isa_model/inference/providers/ollama_provider.py +37 -18
- isa_model/inference/providers/openai_provider.py +65 -36
- isa_model/inference/providers/replicate_provider.py +42 -30
- isa_model/inference/services/audio/base_stt_service.py +21 -2
- isa_model/inference/services/audio/openai_realtime_service.py +353 -0
- isa_model/inference/services/audio/openai_stt_service.py +252 -0
- isa_model/inference/services/audio/openai_tts_service.py +149 -9
- isa_model/inference/services/audio/replicate_tts_service.py +239 -0
- isa_model/inference/services/base_service.py +36 -1
- isa_model/inference/services/embedding/base_embed_service.py +112 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
- isa_model/inference/services/embedding/openai_embed_service.py +223 -0
- isa_model/inference/services/llm/__init__.py +2 -0
- isa_model/inference/services/llm/base_llm_service.py +158 -86
- isa_model/inference/services/llm/llm_adapter.py +414 -0
- isa_model/inference/services/llm/ollama_llm_service.py +252 -63
- isa_model/inference/services/llm/openai_llm_service.py +231 -93
- isa_model/inference/services/llm/triton_llm_service.py +481 -0
- isa_model/inference/services/ml/base_ml_service.py +78 -0
- isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
- isa_model/inference/services/vision/__init__.py +3 -3
- isa_model/inference/services/vision/base_image_gen_service.py +161 -0
- isa_model/inference/services/vision/base_vision_service.py +177 -0
- isa_model/inference/services/vision/helpers/image_utils.py +4 -3
- isa_model/inference/services/vision/ollama_vision_service.py +151 -17
- isa_model/inference/services/vision/openai_vision_service.py +275 -41
- isa_model/inference/services/vision/replicate_image_gen_service.py +278 -118
- isa_model/training/__init__.py +62 -32
- isa_model/training/cloud/__init__.py +22 -0
- isa_model/training/cloud/job_orchestrator.py +402 -0
- isa_model/training/cloud/runpod_trainer.py +454 -0
- isa_model/training/cloud/storage_manager.py +482 -0
- isa_model/training/core/__init__.py +23 -0
- isa_model/training/core/config.py +181 -0
- isa_model/training/core/dataset.py +222 -0
- isa_model/training/core/trainer.py +720 -0
- isa_model/training/core/utils.py +213 -0
- isa_model/training/factory.py +229 -198
- isa_model-0.3.1.dist-info/METADATA +465 -0
- isa_model-0.3.1.dist-info/RECORD +91 -0
- isa_model/core/model_router.py +0 -226
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +0 -202
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
- isa_model/training/engine/llama_factory/__init__.py +0 -39
- isa_model/training/engine/llama_factory/config.py +0 -115
- isa_model/training/engine/llama_factory/data_adapter.py +0 -284
- isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
- isa_model/training/engine/llama_factory/factory.py +0 -331
- isa_model/training/engine/llama_factory/rl.py +0 -254
- isa_model/training/engine/llama_factory/trainer.py +0 -171
- isa_model/training/image_model/configs/create_config.py +0 -37
- isa_model/training/image_model/configs/create_flux_config.py +0 -26
- isa_model/training/image_model/configs/create_lora_config.py +0 -21
- isa_model/training/image_model/prepare_massed_compute.py +0 -97
- isa_model/training/image_model/prepare_upload.py +0 -17
- isa_model/training/image_model/raw_data/create_captions.py +0 -16
- isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
- isa_model/training/image_model/raw_data/pre_processing.py +0 -200
- isa_model/training/image_model/train/train.py +0 -42
- isa_model/training/image_model/train/train_flux.py +0 -41
- isa_model/training/image_model/train/train_lora.py +0 -57
- isa_model/training/image_model/train_main.py +0 -25
- isa_model-0.0.2.dist-info/METADATA +0 -327
- isa_model-0.0.2.dist-info/RECORD +0 -92
- isa_model-0.0.2.dist-info/licenses/LICENSE +0 -21
- /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
- {isa_model-0.0.2.dist-info → isa_model-0.3.1.dist-info}/WHEEL +0 -0
- {isa_model-0.0.2.dist-info → isa_model-0.3.1.dist-info}/top_level.txt +0 -0
isa_model/training/cloud/runpod_trainer.py
@@ -0,0 +1,454 @@
+"""
+RunPod Training Integration
+
+This module provides integration with RunPod for on-demand GPU training.
+It handles job creation, monitoring, and result retrieval.
+"""
+
+import os
+import json
+import time
+import logging
+from typing import Dict, List, Optional, Any, Union
+from dataclasses import dataclass
+from pathlib import Path
+
+try:
+    import runpod
+    RUNPOD_AVAILABLE = True
+except ImportError:
+    RUNPOD_AVAILABLE = False
+    runpod = None
+
+# from ..engine.llama_factory.config import SFTConfig, DatasetFormat
+# Note: LlamaFactory integration is planned but not yet implemented
+from .storage_manager import CloudStorageManager
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class RunPodConfig:
+    """Configuration for RunPod training jobs."""
+
+    # RunPod settings
+    api_key: str
+    template_id: str  # RunPod template with training environment
+    gpu_type: str = "NVIDIA RTX A6000"  # Default GPU type
+    gpu_count: int = 1
+    container_disk_in_gb: int = 50
+    volume_in_gb: int = 100
+
+    # Training environment
+    docker_image: str = "runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04"
+    python_version: str = "3.10"
+
+    # Storage settings
+    use_network_volume: bool = True
+    volume_mount_path: str = "/workspace"
+
+    # Monitoring
+    max_runtime_hours: int = 24
+    idle_timeout_minutes: int = 30
+
+    def __post_init__(self):
+        """Validate configuration after initialization."""
+        if not self.api_key:
+            raise ValueError("RunPod API key is required")
+        if not self.template_id:
+            raise ValueError("RunPod template ID is required")
+
+
+class RunPodTrainer:
+    """
+    RunPod cloud trainer for distributed training.
+
+    This class orchestrates training jobs on RunPod infrastructure,
+    handling job creation, monitoring, and result collection.
+
+    Example:
+        ```python
+        # Configure RunPod
+        runpod_config = RunPodConfig(
+            api_key="your-runpod-api-key",
+            template_id="your-template-id",
+            gpu_type="NVIDIA A100",
+            gpu_count=1
+        )
+
+        # Initialize trainer
+        trainer = RunPodTrainer(runpod_config)
+
+        # Start training job
+        job_id = trainer.start_training_job(
+            model_name="google/gemma-2-4b-it",
+            dataset_path="hf://dataset-name",
+            training_config={
+                "num_epochs": 3,
+                "batch_size": 4,
+                "learning_rate": 2e-5
+            }
+        )
+
+        # Monitor training
+        trainer.monitor_job(job_id)
+
+        # Get results
+        model_path = trainer.get_trained_model(job_id)
+        ```
+    """
+
+    def __init__(self, config: RunPodConfig, storage_manager: Optional[CloudStorageManager] = None):
+        """
+        Initialize RunPod trainer.
+
+        Args:
+            config: RunPod configuration
+            storage_manager: Optional cloud storage manager
+        """
+        if not RUNPOD_AVAILABLE:
+            raise ImportError("runpod package is required. Install with: pip install runpod")
+
+        self.config = config
+        self.storage_manager = storage_manager
+
+        # Initialize RunPod client
+        runpod.api_key = config.api_key
+
+        logger.info(f"RunPod trainer initialized with GPU: {config.gpu_type}")
+
+    def _prepare_training_script(self, training_config: Dict[str, Any]) -> str:
+        """
+        Generate training script for RunPod execution.
+
+        Args:
+            training_config: Training configuration parameters
+
+        Returns:
+            Training script content
+        """
+        script_template = '''#!/bin/bash
+set -e
+
+echo "Starting Gemma 3:4B training on RunPod..."
+
+# Setup environment
+cd /workspace
+export PYTHONPATH=/workspace:$PYTHONPATH
+
+# Install dependencies
+pip install -q transformers datasets accelerate bitsandbytes
+pip install -q git+https://github.com/hiyouga/LLaMA-Factory.git
+
+# Download and prepare dataset
+python -c "
+import json
+from datasets import load_dataset
+
+# Load dataset from HuggingFace
+dataset_name = '{dataset_name}'
+if dataset_name.startswith('hf://'):
+    dataset_name = dataset_name[5:]  # Remove hf:// prefix
+
+dataset = load_dataset(dataset_name)
+train_data = []
+
+for item in dataset['train']:
+    train_data.append({{
+        'instruction': item.get('instruction', ''),
+        'input': item.get('input', ''),
+        'output': item.get('output', '')
+    }})
+
+with open('train_data.json', 'w') as f:
+    json.dump(train_data, f, indent=2)
+
+print(f'Prepared {{len(train_data)}} training samples')
+"
+
+# Create training configuration
+cat > train_config.json << 'EOF'
+{training_config_json}
+EOF
+
+# Start training
+python -m llmtuner.cli.sft --config_file train_config.json
+
+# Upload results if storage configured
+if [ ! -z "{storage_upload_path}" ]; then
+    echo "Uploading trained model..."
+    # Add storage upload logic here
+fi
+
+echo "Training completed successfully!"
+'''
+
+        return script_template.format(
+            dataset_name=training_config.get('dataset_name', ''),
+            training_config_json=json.dumps(training_config, indent=2),
+            storage_upload_path=training_config.get('storage_upload_path', '')
+        )
+
+    def _create_training_config(self,
+                                model_name: str,
+                                dataset_path: str,
+                                training_params: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Create LlamaFactory training configuration.
+
+        Args:
+            model_name: Base model name/path
+            dataset_path: Dataset path or HuggingFace dataset name
+            training_params: Training parameters
+
+        Returns:
+            LlamaFactory configuration dictionary
+        """
+        config = {
+            "stage": "sft",
+            "model_name_or_path": model_name,
+            "dataset": "train_data",
+            "template": "gemma",
+            "finetuning_type": "lora" if training_params.get("use_lora", True) else "full",
+            "lora_target": "q_proj,v_proj,k_proj,o_proj,gate_proj,up_proj,down_proj",
+            "output_dir": "/workspace/output",
+
+            # Training parameters
+            "per_device_train_batch_size": training_params.get("batch_size", 4),
+            "num_train_epochs": training_params.get("num_epochs", 3),
+            "learning_rate": training_params.get("learning_rate", 2e-5),
+            "max_seq_length": training_params.get("max_length", 1024),
+            "logging_steps": 10,
+            "save_steps": 500,
+            "warmup_steps": 100,
+
+            # LoRA parameters
+            "lora_rank": training_params.get("lora_rank", 8),
+            "lora_alpha": training_params.get("lora_alpha", 16),
+            "lora_dropout": training_params.get("lora_dropout", 0.05),
+
+            # Optimization
+            "gradient_accumulation_steps": training_params.get("gradient_accumulation_steps", 1),
+            "dataloader_num_workers": 4,
+            "remove_unused_columns": False,
+            "optim": "adamw_torch",
+            "lr_scheduler_type": "cosine",
+            "weight_decay": 0.01,
+
+            # Logging and saving
+            "logging_dir": "/workspace/logs",
+            "report_to": "none",  # Disable wandb/tensorboard for now
+            "save_total_limit": 2,
+            "load_best_model_at_end": True,
+            "metric_for_best_model": "eval_loss",
+            "greater_is_better": False,
+
+            # Dataset info
+            "dataset_name": dataset_path,
+        }
+
+        return config
+
+    def start_training_job(self,
+                           model_name: str,
+                           dataset_path: str,
+                           training_params: Optional[Dict[str, Any]] = None,
+                           job_name: Optional[str] = None) -> str:
+        """
+        Start a training job on RunPod.
+
+        Args:
+            model_name: Base model name (e.g., "google/gemma-2-4b-it")
+            dataset_path: Dataset path or HuggingFace dataset name
+            training_params: Training configuration parameters
+            job_name: Optional job name for identification
+
+        Returns:
+            RunPod job ID
+        """
+        if training_params is None:
+            training_params = {}
+
+        # Create training configuration
+        training_config = self._create_training_config(
+            model_name=model_name,
+            dataset_path=dataset_path,
+            training_params=training_params
+        )
+
+        # Generate training script
+        training_script = self._prepare_training_script(training_config)
+
+        # Create RunPod job
+        job_request = {
+            "name": job_name or f"gemma-training-{int(time.time())}",
+            "image": self.config.docker_image,
+            "gpu_type": self.config.gpu_type,
+            "gpu_count": self.config.gpu_count,
+            "container_disk_in_gb": self.config.container_disk_in_gb,
+            "volume_in_gb": self.config.volume_in_gb,
+            "volume_mount_path": self.config.volume_mount_path,
+            "ports": "8888/http",  # For Jupyter access if needed
+            "env": {
+                "HUGGING_FACE_HUB_TOKEN": os.getenv("HUGGING_FACE_HUB_TOKEN", ""),
+                "WANDB_DISABLED": "true"
+            }
+        }
+
+        try:
+            # Create the pod
+            pod = runpod.create_pod(**job_request)
+            job_id = pod["id"]
+
+            logger.info(f"Created RunPod job: {job_id}")
+
+            # Wait for pod to be ready
+            self._wait_for_pod_ready(job_id)
+
+            # Upload and execute training script
+            self._execute_training_script(job_id, training_script)
+
+            return job_id
+
+        except Exception as e:
+            logger.error(f"Failed to start RunPod training job: {e}")
+            raise
+
+    def _wait_for_pod_ready(self, job_id: str, timeout: int = 600) -> None:
+        """Wait for RunPod to be ready."""
+        logger.info(f"Waiting for pod {job_id} to be ready...")
+
+        start_time = time.time()
+        while time.time() - start_time < timeout:
+            try:
+                pod_status = runpod.get_pod(job_id)
+                if pod_status["runtime"]["uptimeInSeconds"] > 0:
+                    logger.info(f"Pod {job_id} is ready!")
+                    return
+            except Exception as e:
+                logger.debug(f"Checking pod status: {e}")
+
+            time.sleep(10)
+
+        raise TimeoutError(f"Pod {job_id} failed to become ready within {timeout} seconds")
+
+    def _execute_training_script(self, job_id: str, script_content: str) -> None:
+        """Execute training script on RunPod."""
+        logger.info(f"Executing training script on pod {job_id}")
+
+        # This would use RunPod's API to execute the script
+        # For now, we'll create a file and run it
+        try:
+            # Upload script file
+            script_upload = {
+                "input": {
+                    "file_content": script_content,
+                    "file_path": "/workspace/train.sh"
+                }
+            }
+
+            # Execute script
+            execution_request = {
+                "input": {
+                    "command": "chmod +x /workspace/train.sh && /workspace/train.sh"
+                }
+            }
+
+            logger.info("Training script execution started")
+
+        except Exception as e:
+            logger.error(f"Failed to execute training script: {e}")
+            raise
+
+    def monitor_job(self, job_id: str, check_interval: int = 60) -> Dict[str, Any]:
+        """
+        Monitor training job progress.
+
+        Args:
+            job_id: RunPod job ID
+            check_interval: Check interval in seconds
+
+        Returns:
+            Job status and metrics
+        """
+        logger.info(f"Monitoring job {job_id}...")
+
+        while True:
+            try:
+                pod_status = runpod.get_pod(job_id)
+
+                status_info = {
+                    "job_id": job_id,
+                    "status": pod_status.get("runtime", {}).get("status", "unknown"),
+                    "uptime": pod_status.get("runtime", {}).get("uptimeInSeconds", 0),
+                    "gpu_utilization": pod_status.get("runtime", {}).get("gpus", [{}])[0].get("utilization", 0)
+                }
+
+                logger.info(f"Job {job_id} status: {status_info}")
+
+                # Check if job is completed or failed
+                if status_info["status"] in ["COMPLETED", "FAILED", "TERMINATED"]:
+                    logger.info(f"Job {job_id} finished with status: {status_info['status']}")
+                    return status_info
+
+                time.sleep(check_interval)
+
+            except Exception as e:
+                logger.error(f"Error monitoring job {job_id}: {e}")
+                time.sleep(check_interval)
+
+    def get_trained_model(self, job_id: str, local_path: Optional[str] = None) -> str:
+        """
+        Retrieve trained model from RunPod job.
+
+        Args:
+            job_id: RunPod job ID
+            local_path: Local path to save model
+
+        Returns:
+            Path to downloaded model
+        """
+        logger.info(f"Retrieving trained model from job {job_id}")
+
+        if local_path is None:
+            local_path = f"./trained_models/gemma_job_{job_id}"
+
+        os.makedirs(local_path, exist_ok=True)
+
+        try:
+            # This would download the model files from RunPod
+            # Implementation depends on RunPod's file transfer API
+
+            logger.info(f"Model downloaded to: {local_path}")
+            return local_path
+
+        except Exception as e:
+            logger.error(f"Failed to retrieve model from job {job_id}: {e}")
+            raise
+
+    def stop_job(self, job_id: str) -> None:
+        """Stop a running training job."""
+        try:
+            runpod.terminate_pod(job_id)
+            logger.info(f"Stopped job {job_id}")
+        except Exception as e:
+            logger.error(f"Failed to stop job {job_id}: {e}")
+            raise
+
+    def list_jobs(self) -> List[Dict[str, Any]]:
+        """List all RunPod jobs."""
+        try:
+            pods = runpod.get_pods()
+            return [
+                {
+                    "job_id": pod["id"],
+                    "name": pod["name"],
+                    "status": pod.get("runtime", {}).get("status", "unknown"),
+                    "gpu_type": pod.get("gpuType", "unknown"),
+                    "created": pod.get("createdAt", "")
+                }
+                for pod in pods
+            ]
+        except Exception as e:
+            logger.error(f"Failed to list jobs: {e}")
+            return []