isa-model 0.2.0__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. isa_model/__init__.py +1 -1
  2. isa_model/core/storage/hf_storage.py +419 -0
  3. isa_model/deployment/__init__.py +52 -0
  4. isa_model/deployment/core/__init__.py +34 -0
  5. isa_model/deployment/core/deployment_config.py +356 -0
  6. isa_model/deployment/core/deployment_manager.py +549 -0
  7. isa_model/deployment/core/isa_deployment_service.py +401 -0
  8. isa_model/eval/factory.py +381 -140
  9. isa_model/inference/ai_factory.py +142 -240
  10. isa_model/inference/providers/ml_provider.py +50 -0
  11. isa_model/inference/services/audio/openai_tts_service.py +104 -3
  12. isa_model/inference/services/embedding/base_embed_service.py +112 -0
  13. isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
  14. isa_model/inference/services/llm/__init__.py +2 -0
  15. isa_model/inference/services/llm/base_llm_service.py +111 -1
  16. isa_model/inference/services/llm/ollama_llm_service.py +234 -26
  17. isa_model/inference/services/llm/openai_llm_service.py +225 -28
  18. isa_model/inference/services/llm/triton_llm_service.py +481 -0
  19. isa_model/inference/services/ml/base_ml_service.py +78 -0
  20. isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
  21. isa_model/inference/services/vision/__init__.py +3 -3
  22. isa_model/inference/services/vision/base_image_gen_service.py +161 -0
  23. isa_model/inference/services/vision/base_vision_service.py +177 -0
  24. isa_model/inference/services/vision/ollama_vision_service.py +143 -17
  25. isa_model/inference/services/vision/replicate_image_gen_service.py +139 -7
  26. isa_model/training/__init__.py +62 -32
  27. isa_model/training/cloud/__init__.py +22 -0
  28. isa_model/training/cloud/job_orchestrator.py +402 -0
  29. isa_model/training/cloud/runpod_trainer.py +454 -0
  30. isa_model/training/cloud/storage_manager.py +482 -0
  31. isa_model/training/core/__init__.py +23 -0
  32. isa_model/training/core/config.py +181 -0
  33. isa_model/training/core/dataset.py +222 -0
  34. isa_model/training/core/trainer.py +720 -0
  35. isa_model/training/core/utils.py +213 -0
  36. isa_model/training/factory.py +229 -198
  37. isa_model-0.2.8.dist-info/METADATA +465 -0
  38. isa_model-0.2.8.dist-info/RECORD +86 -0
  39. isa_model/core/model_router.py +0 -226
  40. isa_model/core/model_version.py +0 -0
  41. isa_model/core/resource_manager.py +0 -202
  42. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
  43. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
  44. isa_model/training/engine/llama_factory/__init__.py +0 -39
  45. isa_model/training/engine/llama_factory/config.py +0 -115
  46. isa_model/training/engine/llama_factory/data_adapter.py +0 -284
  47. isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
  48. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
  49. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
  50. isa_model/training/engine/llama_factory/factory.py +0 -331
  51. isa_model/training/engine/llama_factory/rl.py +0 -254
  52. isa_model/training/engine/llama_factory/trainer.py +0 -171
  53. isa_model/training/image_model/configs/create_config.py +0 -37
  54. isa_model/training/image_model/configs/create_flux_config.py +0 -26
  55. isa_model/training/image_model/configs/create_lora_config.py +0 -21
  56. isa_model/training/image_model/prepare_massed_compute.py +0 -97
  57. isa_model/training/image_model/prepare_upload.py +0 -17
  58. isa_model/training/image_model/raw_data/create_captions.py +0 -16
  59. isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
  60. isa_model/training/image_model/raw_data/pre_processing.py +0 -200
  61. isa_model/training/image_model/train/train.py +0 -42
  62. isa_model/training/image_model/train/train_flux.py +0 -41
  63. isa_model/training/image_model/train/train_lora.py +0 -57
  64. isa_model/training/image_model/train_main.py +0 -25
  65. isa_model-0.2.0.dist-info/METADATA +0 -327
  66. isa_model-0.2.0.dist-info/RECORD +0 -92
  67. isa_model-0.2.0.dist-info/licenses/LICENSE +0 -21
  68. /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
  69. /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
  70. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
  71. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
  72. /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
  73. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
  74. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
  75. /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
  76. {isa_model-0.2.0.dist-info → isa_model-0.2.8.dist-info}/WHEEL +0 -0
  77. {isa_model-0.2.0.dist-info → isa_model-0.2.8.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,8 @@
1
1
  """
2
- Unified Training Factory for ISA Model Framework
2
+ ISA Model Training Factory
3
3
 
4
- This factory provides a single interface for all training operations:
5
- - LLM fine-tuning (SFT, DPO, RLHF)
6
- - Image model training (Flux, LoRA)
7
- - Model evaluation and benchmarking
4
+ A clean, simplified training factory that uses HuggingFace Transformers directly
5
+ without external dependencies like LlamaFactory.
8
6
  """
9
7
 
10
8
  import os
@@ -13,43 +11,49 @@ from typing import Optional, Dict, Any, Union, List
13
11
  from pathlib import Path
14
12
  import datetime
15
13
 
16
- from .engine.llama_factory import LlamaFactory, TrainingStrategy, DatasetFormat
17
- from .engine.llama_factory.config import SFTConfig, RLConfig, DPOConfig
14
+ from .core import (
15
+ TrainingConfig,
16
+ LoRAConfig,
17
+ DatasetConfig,
18
+ BaseTrainer,
19
+ SFTTrainer,
20
+ TrainingUtils,
21
+ DatasetManager,
22
+ )
23
+ from .cloud import TrainingJobOrchestrator
18
24
 
19
25
  logger = logging.getLogger(__name__)
20
26
 
21
27
 
22
28
  class TrainingFactory:
23
29
  """
24
- Unified factory for all AI model training operations.
30
+ Unified Training Factory for ISA Model SDK
25
31
 
26
- This class provides simplified interfaces for:
27
- - LLM training using LlamaFactory
28
- - Image model training using Flux/LoRA
29
- - Model evaluation and benchmarking
32
+ Provides a clean interface for:
33
+ - Local training with SFT (Supervised Fine-Tuning)
34
+ - Cloud training on RunPod
35
+ - Model evaluation and management
30
36
 
31
- Example usage for fine-tuning Gemma 3:4B:
37
+ Example usage:
32
38
  ```python
33
39
  from isa_model.training import TrainingFactory
34
40
 
35
41
  factory = TrainingFactory()
36
42
 
37
- # Fine-tune with your dataset
38
- model_path = factory.finetune_llm(
43
+ # Local training
44
+ model_path = factory.train_model(
39
45
  model_name="google/gemma-2-4b-it",
40
- dataset_path="path/to/your/data.json",
41
- training_type="sft",
46
+ dataset_path="tatsu-lab/alpaca",
42
47
  use_lora=True,
43
- num_epochs=3,
44
- batch_size=4,
45
- learning_rate=2e-5
48
+ num_epochs=3
46
49
  )
47
50
 
48
- # Train with DPO for preference optimization
49
- dpo_model = factory.train_with_preferences(
50
- model_path=model_path,
51
- preference_data="path/to/preferences.json",
52
- beta=0.1
51
+ # Cloud training on RunPod
52
+ result = factory.train_on_runpod(
53
+ model_name="google/gemma-2-4b-it",
54
+ dataset_path="tatsu-lab/alpaca",
55
+ runpod_api_key="your-api-key",
56
+ template_id="your-template-id"
53
57
  )
54
58
  ```
55
59
  """
@@ -59,32 +63,19 @@ class TrainingFactory:
59
63
  Initialize the training factory.
60
64
 
61
65
  Args:
62
- base_output_dir: Base directory for all training outputs
66
+ base_output_dir: Base directory for training outputs
63
67
  """
64
68
  self.base_output_dir = base_output_dir or os.path.join(os.getcwd(), "training_outputs")
65
69
  os.makedirs(self.base_output_dir, exist_ok=True)
66
70
 
67
- # Initialize sub-factories
68
- self.llm_factory = LlamaFactory(base_output_dir=os.path.join(self.base_output_dir, "llm"))
69
-
70
71
  logger.info(f"TrainingFactory initialized with output dir: {self.base_output_dir}")
71
72
 
72
- def _get_output_dir(self, model_name: str, training_type: str) -> str:
73
- """Generate timestamped output directory."""
74
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
75
- safe_model_name = model_name.replace("/", "_").replace(":", "_")
76
- return os.path.join(self.base_output_dir, f"{safe_model_name}_{training_type}_{timestamp}")
77
-
78
- # =================
79
- # LLM Training Methods
80
- # =================
81
-
82
- def finetune_llm(
73
+ def train_model(
83
74
  self,
84
75
  model_name: str,
85
76
  dataset_path: str,
86
- training_type: str = "sft",
87
77
  output_dir: Optional[str] = None,
78
+ training_type: str = "sft",
88
79
  dataset_format: str = "alpaca",
89
80
  use_lora: bool = True,
90
81
  batch_size: int = 4,
@@ -93,17 +84,17 @@ class TrainingFactory:
93
84
  max_length: int = 1024,
94
85
  lora_rank: int = 8,
95
86
  lora_alpha: int = 16,
96
- val_dataset_path: Optional[str] = None,
87
+ validation_split: float = 0.1,
97
88
  **kwargs
98
89
  ) -> str:
99
90
  """
100
- Fine-tune an LLM model.
91
+ Train a model locally.
101
92
 
102
93
  Args:
103
- model_name: Model identifier (e.g., "google/gemma-2-4b-it", "meta-llama/Llama-2-7b-hf")
104
- dataset_path: Path to training dataset
105
- training_type: Type of training ("sft", "dpo", "rlhf")
94
+ model_name: Model identifier (e.g., "google/gemma-2-4b-it")
95
+ dataset_path: Path to dataset or HuggingFace dataset name
106
96
  output_dir: Custom output directory
97
+ training_type: Type of training ("sft" supported)
107
98
  dataset_format: Dataset format ("alpaca", "sharegpt", "custom")
108
99
  use_lora: Whether to use LoRA for efficient training
109
100
  batch_size: Training batch size
@@ -112,7 +103,7 @@ class TrainingFactory:
112
103
  max_length: Maximum sequence length
113
104
  lora_rank: LoRA rank parameter
114
105
  lora_alpha: LoRA alpha parameter
115
- val_dataset_path: Path to validation dataset (optional)
106
+ validation_split: Fraction of data for validation
116
107
  **kwargs: Additional training parameters
117
108
 
118
109
  Returns:
@@ -120,184 +111,207 @@ class TrainingFactory:
120
111
 
121
112
  Example:
122
113
  ```python
123
- # Fine-tune Gemma 3:4B with your dataset
124
- model_path = factory.finetune_llm(
114
+ model_path = factory.train_model(
125
115
  model_name="google/gemma-2-4b-it",
126
- dataset_path="my_training_data.json",
127
- training_type="sft",
116
+ dataset_path="tatsu-lab/alpaca",
128
117
  use_lora=True,
129
118
  num_epochs=3,
130
119
  batch_size=4
131
120
  )
132
121
  ```
133
122
  """
123
+ # Generate output directory if not provided
134
124
  if not output_dir:
135
- output_dir = self._get_output_dir(model_name, training_type)
125
+ output_dir = TrainingUtils.generate_output_dir(
126
+ model_name, training_type, self.base_output_dir
127
+ )
136
128
 
137
- # Convert format string to enum
138
- format_map = {
139
- "alpaca": DatasetFormat.ALPACA,
140
- "sharegpt": DatasetFormat.SHAREGPT,
141
- "custom": DatasetFormat.CUSTOM
142
- }
143
- dataset_format_enum = format_map.get(dataset_format, DatasetFormat.ALPACA)
129
+ # Create configurations
130
+ lora_config = LoRAConfig(
131
+ use_lora=use_lora,
132
+ lora_rank=lora_rank,
133
+ lora_alpha=lora_alpha
134
+ ) if use_lora else None
144
135
 
145
- if training_type.lower() == "sft":
146
- return self.llm_factory.finetune(
147
- model_path=model_name,
148
- train_data=dataset_path,
149
- val_data=val_dataset_path,
150
- output_dir=output_dir,
151
- dataset_format=dataset_format_enum,
152
- use_lora=use_lora,
153
- batch_size=batch_size,
154
- num_epochs=num_epochs,
155
- learning_rate=learning_rate,
156
- max_length=max_length,
157
- lora_rank=lora_rank,
158
- lora_alpha=lora_alpha,
159
- **kwargs
160
- )
161
- else:
162
- raise ValueError(f"Training type '{training_type}' not supported yet. Use 'sft' for now.")
163
-
164
- def train_with_preferences(
165
- self,
166
- model_path: str,
167
- preference_data: str,
168
- output_dir: Optional[str] = None,
169
- reference_model: Optional[str] = None,
170
- beta: float = 0.1,
171
- use_lora: bool = True,
172
- batch_size: int = 4,
173
- num_epochs: int = 3,
174
- learning_rate: float = 5e-6,
175
- val_data: Optional[str] = None,
176
- **kwargs
177
- ) -> str:
178
- """
179
- Train model with preference data using DPO.
136
+ dataset_config = DatasetConfig(
137
+ dataset_path=dataset_path,
138
+ dataset_format=dataset_format,
139
+ max_length=max_length,
140
+ validation_split=validation_split
141
+ )
180
142
 
181
- Args:
182
- model_path: Path to the base model
183
- preference_data: Path to preference dataset
184
- output_dir: Custom output directory
185
- reference_model: Reference model for DPO (optional)
186
- beta: DPO beta parameter
187
- use_lora: Whether to use LoRA
188
- batch_size: Training batch size
189
- num_epochs: Number of epochs
190
- learning_rate: Learning rate
191
- val_data: Validation data path
192
- **kwargs: Additional parameters
193
-
194
- Returns:
195
- Path to the trained model
196
- """
197
- if not output_dir:
198
- model_name = os.path.basename(model_path)
199
- output_dir = self._get_output_dir(model_name, "dpo")
200
-
201
- return self.llm_factory.dpo(
202
- model_path=model_path,
203
- train_data=preference_data,
204
- val_data=val_data,
205
- reference_model=reference_model,
143
+ training_config = TrainingConfig(
144
+ model_name=model_name,
206
145
  output_dir=output_dir,
207
- use_lora=use_lora,
208
- batch_size=batch_size,
146
+ training_type=training_type,
209
147
  num_epochs=num_epochs,
148
+ batch_size=batch_size,
210
149
  learning_rate=learning_rate,
211
- beta=beta,
150
+ lora_config=lora_config,
151
+ dataset_config=dataset_config,
212
152
  **kwargs
213
153
  )
154
+
155
+ # Print training summary
156
+ model_info = TrainingUtils.get_model_info(model_name)
157
+ memory_estimate = TrainingUtils.estimate_memory_usage(
158
+ model_name, batch_size, max_length, use_lora
159
+ )
160
+
161
+ summary = TrainingUtils.format_training_summary(
162
+ training_config.to_dict(), model_info, memory_estimate
163
+ )
164
+ print(summary)
165
+
166
+ # Validate configuration
167
+ issues = TrainingUtils.validate_training_config(training_config.to_dict())
168
+ if issues:
169
+ raise ValueError(f"Training configuration issues: {issues}")
170
+
171
+ # Initialize trainer based on training type
172
+ if training_type.lower() == "sft":
173
+ trainer = SFTTrainer(training_config)
174
+ else:
175
+ raise ValueError(f"Training type '{training_type}' not supported yet")
176
+
177
+ # Execute training
178
+ logger.info(f"Starting {training_type.upper()} training...")
179
+ result_path = trainer.train()
180
+
181
+ logger.info(f"Training completed! Model saved to: {result_path}")
182
+ return result_path
214
183
 
215
- def train_reward_model(
184
+ def train_on_runpod(
216
185
  self,
217
- model_path: str,
218
- reward_data: str,
219
- output_dir: Optional[str] = None,
220
- use_lora: bool = True,
221
- batch_size: int = 8,
222
- num_epochs: int = 3,
223
- learning_rate: float = 1e-5,
224
- val_data: Optional[str] = None,
225
- **kwargs
226
- ) -> str:
186
+ model_name: str,
187
+ dataset_path: str,
188
+ runpod_api_key: str,
189
+ template_id: str,
190
+ gpu_type: str = "NVIDIA RTX A6000",
191
+ storage_config: Optional[Dict[str, Any]] = None,
192
+ job_name: Optional[str] = None,
193
+ **training_params
194
+ ) -> Dict[str, Any]:
227
195
  """
228
- Train a reward model for RLHF.
196
+ Train a model on RunPod cloud infrastructure.
229
197
 
230
198
  Args:
231
- model_path: Base model path
232
- reward_data: Reward training data
233
- output_dir: Output directory
234
- use_lora: Whether to use LoRA
235
- batch_size: Batch size
236
- num_epochs: Number of epochs
237
- learning_rate: Learning rate
238
- val_data: Validation data
239
- **kwargs: Additional parameters
199
+ model_name: Model identifier
200
+ dataset_path: Dataset path or HuggingFace dataset name
201
+ runpod_api_key: RunPod API key
202
+ template_id: RunPod template ID
203
+ gpu_type: GPU type to use
204
+ storage_config: Optional cloud storage configuration
205
+ job_name: Optional job name
206
+ **training_params: Additional training parameters
240
207
 
241
208
  Returns:
242
- Path to trained reward model
209
+ Training job results
210
+
211
+ Example:
212
+ ```python
213
+ result = factory.train_on_runpod(
214
+ model_name="google/gemma-2-4b-it",
215
+ dataset_path="tatsu-lab/alpaca",
216
+ runpod_api_key="your-api-key",
217
+ template_id="your-template-id",
218
+ use_lora=True,
219
+ num_epochs=3
220
+ )
221
+ ```
243
222
  """
244
- if not output_dir:
245
- model_name = os.path.basename(model_path)
246
- output_dir = self._get_output_dir(model_name, "reward")
223
+ # Import cloud components
224
+ from .cloud import TrainingJobOrchestrator
225
+ from .cloud.runpod_trainer import RunPodConfig
226
+ from .cloud.storage_manager import StorageConfig
227
+ from .cloud.job_orchestrator import JobConfig
247
228
 
248
- return self.llm_factory.train_reward_model(
249
- model_path=model_path,
250
- train_data=reward_data,
251
- val_data=val_data,
252
- output_dir=output_dir,
253
- use_lora=use_lora,
254
- batch_size=batch_size,
255
- num_epochs=num_epochs,
256
- learning_rate=learning_rate,
257
- **kwargs
229
+ # Create RunPod configuration
230
+ runpod_config = RunPodConfig(
231
+ api_key=runpod_api_key,
232
+ template_id=template_id,
233
+ gpu_type=gpu_type
258
234
  )
235
+
236
+ # Create storage configuration if provided
237
+ storage_cfg = None
238
+ if storage_config:
239
+ storage_cfg = StorageConfig(**storage_config)
240
+
241
+ # Create job configuration
242
+ job_config = JobConfig(
243
+ model_name=model_name,
244
+ dataset_source=dataset_path,
245
+ job_name=job_name or f"gemma-training-{int(datetime.datetime.now().timestamp())}",
246
+ **training_params
247
+ )
248
+
249
+ # Initialize orchestrator and execute training
250
+ orchestrator = TrainingJobOrchestrator(
251
+ runpod_config=runpod_config,
252
+ storage_config=storage_cfg
253
+ )
254
+
255
+ logger.info(f"Starting RunPod training for {model_name}")
256
+ result = orchestrator.execute_training_workflow(job_config)
257
+
258
+ return result
259
259
 
260
- # =================
261
- # Image Model Training Methods
262
- # =================
263
-
264
- def train_image_model(
260
+ async def upload_to_huggingface(
265
261
  self,
266
- model_type: str = "flux",
267
- training_images_dir: str = "",
268
- output_dir: Optional[str] = None,
269
- use_lora: bool = True,
270
- num_epochs: int = 1000,
271
- batch_size: int = 1,
272
- learning_rate: float = 1e-4,
273
- **kwargs
262
+ model_path: str,
263
+ hf_model_name: str,
264
+ hf_token: Optional[str] = None,
265
+ metadata: Optional[Dict[str, Any]] = None
274
266
  ) -> str:
275
267
  """
276
- Train an image generation model.
268
+ Upload a trained model to HuggingFace Hub using HuggingFaceStorage.
277
269
 
278
270
  Args:
279
- model_type: Type of model ("flux", "lora")
280
- training_images_dir: Directory containing training images
281
- output_dir: Output directory
282
- use_lora: Whether to use LoRA
283
- num_epochs: Training epochs
284
- batch_size: Batch size
285
- learning_rate: Learning rate
286
- **kwargs: Additional parameters
271
+ model_path: Path to the trained model
272
+ hf_model_name: Name for the model on HuggingFace Hub
273
+ hf_token: HuggingFace token
274
+ metadata: Additional metadata for the model
287
275
 
288
276
  Returns:
289
- Path to trained model
277
+ URL of the uploaded model
290
278
  """
291
- if not output_dir:
292
- output_dir = self._get_output_dir("image_model", model_type)
293
-
294
- # TODO: Implement image model training
295
- logger.warning("Image model training not fully implemented yet")
296
- return output_dir
297
-
298
- # =================
299
- # Utility Methods
300
- # =================
279
+ try:
280
+ from ..core.storage.hf_storage import HuggingFaceStorage
281
+
282
+ logger.info(f"Uploading model to HuggingFace: {hf_model_name}")
283
+
284
+ # Initialize HuggingFace storage
285
+ storage = HuggingFaceStorage(
286
+ username="xenobordom",
287
+ token=hf_token
288
+ )
289
+
290
+ # Prepare metadata
291
+ upload_metadata = metadata or {}
292
+ upload_metadata.update({
293
+ "description": f"Fine-tuned model: {hf_model_name}",
294
+ "training_framework": "ISA Model SDK",
295
+ "uploaded_from": "training_factory"
296
+ })
297
+
298
+ # Upload model
299
+ success = await storage.save_model(
300
+ model_id=hf_model_name,
301
+ model_path=model_path,
302
+ metadata=upload_metadata
303
+ )
304
+
305
+ if success:
306
+ model_url = storage.get_public_url(hf_model_name)
307
+ logger.info(f"Model uploaded successfully: {model_url}")
308
+ return model_url
309
+ else:
310
+ raise Exception("Failed to upload model")
311
+
312
+ except Exception as e:
313
+ logger.error(f"Failed to upload to HuggingFace: {e}")
314
+ raise
301
315
 
302
316
  def get_training_status(self, output_dir: str) -> Dict[str, Any]:
303
317
  """
@@ -318,6 +332,21 @@ class TrainingFactory:
318
332
  if status["exists"]:
319
333
  status["files"] = os.listdir(output_dir)
320
334
 
335
+ # Check for specific files
336
+ config_path = os.path.join(output_dir, "training_config.json")
337
+ metrics_path = os.path.join(output_dir, "training_metrics.json")
338
+ model_path = os.path.join(output_dir, "pytorch_model.bin")
339
+
340
+ status["has_config"] = os.path.exists(config_path)
341
+ status["has_metrics"] = os.path.exists(metrics_path)
342
+ status["has_model"] = os.path.exists(model_path) or os.path.exists(os.path.join(output_dir, "adapter_model.bin"))
343
+
344
+ if status["has_config"]:
345
+ try:
346
+ status["config"] = TrainingUtils.load_training_args(output_dir)
347
+ except:
348
+ pass
349
+
321
350
  return status
322
351
 
323
352
  def list_trained_models(self) -> List[Dict[str, Any]]:
@@ -333,26 +362,28 @@ class TrainingFactory:
333
362
  for item in os.listdir(self.base_output_dir):
334
363
  item_path = os.path.join(self.base_output_dir, item)
335
364
  if os.path.isdir(item_path):
365
+ status = self.get_training_status(item_path)
336
366
  models.append({
337
367
  "name": item,
338
368
  "path": item_path,
339
369
  "created": datetime.datetime.fromtimestamp(
340
370
  os.path.getctime(item_path)
341
- ).isoformat()
371
+ ).isoformat(),
372
+ "status": status
342
373
  })
343
374
 
344
375
  return sorted(models, key=lambda x: x["created"], reverse=True)
345
376
 
346
377
 
347
378
  # Convenience functions for quick access
348
- def finetune_gemma(
379
+ def train_gemma(
349
380
  dataset_path: str,
350
381
  model_size: str = "4b",
351
382
  output_dir: Optional[str] = None,
352
383
  **kwargs
353
384
  ) -> str:
354
385
  """
355
- Quick function to fine-tune Gemma models.
386
+ Quick function to train Gemma models.
356
387
 
357
388
  Args:
358
389
  dataset_path: Path to training dataset
@@ -361,14 +392,14 @@ def finetune_gemma(
361
392
  **kwargs: Additional training parameters
362
393
 
363
394
  Returns:
364
- Path to fine-tuned model
395
+ Path to trained model
365
396
 
366
397
  Example:
367
398
  ```python
368
- from isa_model.training import finetune_gemma
399
+ from isa_model.training import train_gemma
369
400
 
370
- model_path = finetune_gemma(
371
- dataset_path="my_data.json",
401
+ model_path = train_gemma(
402
+ dataset_path="tatsu-lab/alpaca",
372
403
  model_size="4b",
373
404
  num_epochs=3,
374
405
  batch_size=4
@@ -385,7 +416,7 @@ def finetune_gemma(
385
416
 
386
417
  model_name = model_map.get(model_size, "google/gemma-2-4b-it")
387
418
 
388
- return factory.finetune_llm(
419
+ return factory.train_model(
389
420
  model_name=model_name,
390
421
  dataset_path=dataset_path,
391
422
  output_dir=output_dir,