gptmed 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gptmed/__init__.py +37 -3
- gptmed/model/__init__.py +2 -2
- gptmed/observability/__init__.py +43 -0
- gptmed/observability/base.py +369 -0
- gptmed/observability/callbacks.py +397 -0
- gptmed/observability/metrics_tracker.py +544 -0
- gptmed/services/__init__.py +15 -0
- gptmed/services/device_manager.py +252 -0
- gptmed/services/training_service.py +489 -0
- gptmed/training/trainer.py +124 -10
- gptmed/utils/checkpoints.py +1 -1
- {gptmed-0.3.4.dist-info → gptmed-0.4.0.dist-info}/METADATA +180 -43
- {gptmed-0.3.4.dist-info → gptmed-0.4.0.dist-info}/RECORD +17 -10
- {gptmed-0.3.4.dist-info → gptmed-0.4.0.dist-info}/WHEEL +0 -0
- {gptmed-0.3.4.dist-info → gptmed-0.4.0.dist-info}/entry_points.txt +0 -0
- {gptmed-0.3.4.dist-info → gptmed-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {gptmed-0.3.4.dist-info → gptmed-0.4.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,489 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Training Service
|
|
3
|
+
|
|
4
|
+
PURPOSE:
|
|
5
|
+
Encapsulates training logic following Service Layer Pattern.
|
|
6
|
+
Provides a high-level interface for model training with device flexibility.
|
|
7
|
+
|
|
8
|
+
DESIGN PATTERNS:
|
|
9
|
+
- Service Layer Pattern: Business logic separated from API layer
|
|
10
|
+
- Dependency Injection: DeviceManager injected for flexibility
|
|
11
|
+
- Single Responsibility: Only handles training orchestration
|
|
12
|
+
- Open/Closed Principle: Extensible without modification
|
|
13
|
+
|
|
14
|
+
WHAT THIS FILE DOES:
|
|
15
|
+
1. Orchestrates the training process
|
|
16
|
+
2. Manages device configuration via DeviceManager
|
|
17
|
+
3. Coordinates model, data, optimizer, and trainer
|
|
18
|
+
4. Provides clean interface for training operations
|
|
19
|
+
|
|
20
|
+
PACKAGES USED:
|
|
21
|
+
- torch: PyTorch training
|
|
22
|
+
- pathlib: Path handling
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
import torch
|
|
26
|
+
import random
|
|
27
|
+
import numpy as np
|
|
28
|
+
from pathlib import Path
|
|
29
|
+
from typing import Dict, Any, Optional, List
|
|
30
|
+
|
|
31
|
+
from gptmed.services.device_manager import DeviceManager
|
|
32
|
+
from gptmed.model.architecture import GPTTransformer
|
|
33
|
+
from gptmed.model.configs.model_config import get_tiny_config, get_small_config, get_medium_config
|
|
34
|
+
from gptmed.configs.train_config import TrainingConfig
|
|
35
|
+
from gptmed.training.dataset import create_dataloaders
|
|
36
|
+
from gptmed.training.trainer import Trainer
|
|
37
|
+
|
|
38
|
+
# Observability imports
|
|
39
|
+
from gptmed.observability import (
|
|
40
|
+
TrainingObserver,
|
|
41
|
+
MetricsTracker,
|
|
42
|
+
ConsoleCallback,
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class TrainingService:
    """
    High-level service for model training.

    Implements the Service Layer Pattern to encapsulate training logic and
    receives its DeviceManager via dependency injection.

    Example:
        >>> device_manager = DeviceManager(preferred_device='cpu')
        >>> service = TrainingService(device_manager=device_manager, verbose=True)
        >>> results = service.train(
        ...     model_size='tiny',
        ...     train_data_path='path/to/train_data',
        ...     val_data_path='path/to/val_data',
        ... )
    """

    def __init__(
        self,
        device_manager: Optional[DeviceManager] = None,
        verbose: bool = True
    ):
        """
        Initialize TrainingService.

        Args:
            device_manager: DeviceManager instance. If None, a default
                DeviceManager with preferred_device='cuda' is created.
            verbose: Whether to print training information.
        """
        self.device_manager = (
            device_manager
            if device_manager is not None
            else DeviceManager(preferred_device='cuda')
        )
        self.verbose = verbose

    def set_seed(self, seed: int) -> None:
        """
        Seed Python, NumPy and PyTorch RNGs for reproducibility.

        When CUDA is available, this also seeds all CUDA devices. It forces
        cuDNN into deterministic mode, which trades speed for repeatability.

        Args:
            seed: Random seed value
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

    def create_model(self, model_size: str) -> GPTTransformer:
        """
        Create a model from a named size preset.

        Args:
            model_size: Model size ('tiny', 'small', or 'medium')

        Returns:
            GPTTransformer model instance

        Raises:
            ValueError: If model_size is not one of the known presets
        """
        config_factories = {
            'tiny': get_tiny_config,
            'small': get_small_config,
            'medium': get_medium_config,
        }
        factory = config_factories.get(model_size)
        if factory is None:
            raise ValueError(f"Unknown model size: {model_size}")

        return GPTTransformer(factory())

    def prepare_training(
        self,
        model: GPTTransformer,
        train_config: TrainingConfig,
        device: str
    ) -> tuple:
        """
        Build the data loaders and the AdamW optimizer for training.

        Args:
            model: Model to train (its parameters are handed to the optimizer)
            train_config: Training configuration (data paths, batch size,
                optimizer hyper-parameters)
            device: Device to use. Currently unused here; device placement is
                handled by the Trainer. Kept for interface compatibility.

        Returns:
            Tuple of (train_loader, val_loader, optimizer)
        """
        # Load data
        if self.verbose:
            print("\n📊 Loading data...")
            print(f" Train: {train_config.train_data_path}")
            print(f" Val: {train_config.val_data_path}")

        train_loader, val_loader = create_dataloaders(
            train_path=Path(train_config.train_data_path),
            val_path=Path(train_config.val_data_path),
            batch_size=train_config.batch_size,
            num_workers=0,  # single-process loading
        )

        if self.verbose:
            print(f" Train batches: {len(train_loader)}")
            print(f" Val batches: {len(val_loader)}")

        # Create optimizer
        if self.verbose:
            print("\n⚙️ Setting up optimizer...")
            print(f" Learning rate: {train_config.learning_rate}")
            print(f" Weight decay: {train_config.weight_decay}")

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=train_config.learning_rate,
            betas=train_config.betas,
            eps=train_config.eps,
            weight_decay=train_config.weight_decay,
        )

        return train_loader, val_loader, optimizer

    def execute_training(
        self,
        model: GPTTransformer,
        train_loader,
        val_loader,
        optimizer,
        train_config: TrainingConfig,
        device: str,
        model_config_dict: dict,
        observers: Optional[List[TrainingObserver]] = None,
    ) -> Dict[str, Any]:
        """
        Run the training loop and collect results.

        A Ctrl+C (KeyboardInterrupt) during training is caught. A checkpoint
        is saved and the method still returns results, with
        ``interrupted=True``. Any other exception raised by the trainer
        propagates to the caller.

        Args:
            model: Model to train
            train_loader: Training data loader
            val_loader: Validation data loader
            optimizer: Optimizer
            train_config: Training configuration
            device: Device to use
            model_config_dict: Model configuration as dictionary (stored in
                the interrupt checkpoint)
            observers: Optional list of TrainingObserver instances.
                If None, default observers (MetricsTracker) will be used.
                An explicit empty list disables observers.

        Returns:
            Dictionary with training results. It includes 'training_issues'
            when a MetricsTracker is among the observers.
        """
        # Set up default observers if none provided
        if observers is None:
            observers = self._create_default_observers(train_config)

        # Create trainer with observers
        if self.verbose:
            print("\n🎯 Initializing trainer...")
            print(f" Observers: {len(observers)} ({', '.join(o.name for o in observers)})")

        trainer = Trainer(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            config=train_config,
            device=device,
            observers=observers,
        )

        # Resume if requested. getattr tolerates configs without these fields.
        resume_from = getattr(train_config, 'resume_from', None)
        checkpoint_dir = getattr(train_config, 'checkpoint_dir', None)
        if resume_from is not None:
            if self.verbose:
                print(f"\n📥 Resuming from checkpoint: {train_config.resume_from}")
            trainer.resume_from_checkpoint(Path(resume_from))
        elif checkpoint_dir:
            # A leftover resume_from.pt is only announced, never loaded
            # automatically. Resuming requires an explicit resume_from.
            resume_path = Path(checkpoint_dir) / "resume_from.pt"
            if resume_path.exists() and self.verbose:
                print(f"\n📥 Found checkpoint to resume: {resume_path}")
                print(" Set resume_from in config to resume from it.")

        # Start training
        if self.verbose:
            print(f"\n{'='*60}")
            print("🚀 Starting Training!")
            print(f"{'='*60}\n")

        interrupted = False
        try:
            trainer.train()
        except KeyboardInterrupt:
            interrupted = True
            if self.verbose:
                print("\n\n⏸️ Training interrupted by user")
                print("💾 Saving checkpoint...")
            trainer.checkpoint_manager.save_checkpoint(
                model=model,
                optimizer=optimizer,
                step=trainer.global_step,
                epoch=trainer.current_epoch,
                val_loss=trainer.best_val_loss,
                model_config=model_config_dict,
                train_config=train_config.to_dict(),
            )
            if self.verbose:
                print("✓ Checkpoint saved. Resume with resume_from in config.")

        # Generate observability reports (on normal completion or Ctrl+C).
        # Returns the detected issues, or None when no MetricsTracker is present.
        issues = self._generate_observability_reports(
            observers=observers,
            train_config=train_config,
            trainer=trainer,
            interrupted=interrupted,
        )

        # Return results
        # NOTE(review): when interrupted, the checkpoint saved above may not be
        # named final_model.pt; confirm against CheckpointManager naming.
        final_checkpoint = Path(train_config.checkpoint_dir) / "final_model.pt"

        results = {
            'final_checkpoint': str(final_checkpoint),
            'best_checkpoint': str(final_checkpoint),  # Alias for backward compatibility
            'final_val_loss': trainer.best_val_loss,
            'total_epochs': trainer.current_epoch,
            'total_steps': trainer.global_step,
            'checkpoint_dir': train_config.checkpoint_dir,
            'log_dir': train_config.log_dir,
            'interrupted': interrupted,
        }

        # Training issues from the metrics tracker (computed once, above)
        if issues is not None:
            results['training_issues'] = issues

        if self.verbose:
            status = "⏸️ Training Interrupted" if interrupted else "✅ Training Complete!"
            print(f"\n{'='*60}")
            print(status)
            print(f"{'='*60}")
            print("\n📁 Results:")
            print(f" Final checkpoint: {results['final_checkpoint']}")
            print(f" Best val loss: {results['final_val_loss']:.4f}")
            print(f" Total steps: {results['total_steps']}")
            print(f" Total epochs: {results['total_epochs']}")
            print(f" Logs: {results['log_dir']}")

        return results

    def _generate_observability_reports(
        self,
        observers: List[TrainingObserver],
        train_config: TrainingConfig,
        trainer,
        interrupted: bool = False,
    ) -> Optional[list]:
        """
        Export metrics (CSV, JSON, loss plot) and run the training health check.

        Called after normal completion and after a Ctrl+C interrupt. Each
        export step is best-effort: a failure is reported when verbose and
        does not stop the remaining steps.

        Args:
            observers: List of training observers
            train_config: Training configuration
            trainer: Trainer instance (used for the interrupt step count)
            interrupted: Whether training was interrupted

        Returns:
            The result of ``MetricsTracker.detect_issues()``, or None if no
            MetricsTracker is among the observers.
        """
        metrics_tracker = self._get_metrics_tracker(observers)

        if metrics_tracker is None:
            if self.verbose:
                print("\n⚠️ No MetricsTracker found - skipping observability reports")
            return None

        if self.verbose:
            print(f"\n{'='*60}")
            print("📊 Generating Observability Reports")
            print(f"{'='*60}")

        try:
            # Export metrics to CSV
            csv_path = metrics_tracker.export_to_csv()
            if self.verbose:
                print(f" ✓ CSV exported: {csv_path}")
        except Exception as e:
            if self.verbose:
                print(f" ✗ CSV export failed: {e}")

        try:
            # Export metrics to JSON
            json_path = metrics_tracker.export_to_json()
            if self.verbose:
                print(f" ✓ JSON exported: {json_path}")
        except Exception as e:
            if self.verbose:
                print(f" ✗ JSON export failed: {e}")

        try:
            # Generate loss curve plots (matplotlib is optional)
            plot_path = metrics_tracker.plot_loss_curves()
            if plot_path and self.verbose:
                print(f" ✓ Loss curves plotted: {plot_path}")
        except ImportError:
            if self.verbose:
                print(" ⚠️ Plotting skipped (matplotlib not installed)")
        except Exception as e:
            if self.verbose:
                print(f" ✗ Plotting failed: {e}")

        # Training health check
        issues = metrics_tracker.detect_issues()
        if self.verbose:
            print("\n📋 Training Health Check:")
            for issue in issues:
                print(f" {issue}")

        # Add interrupted notice if applicable
        if interrupted and self.verbose:
            print(f"\n⚠️ Note: Training was interrupted at step {trainer.global_step}")
            print(" Reports reflect partial training data only.")

        return issues

    def _create_default_observers(self, train_config: TrainingConfig) -> List[TrainingObserver]:
        """
        Create the default observers for training.

        Currently this is a single MetricsTracker writing to
        ``train_config.log_dir``. A ConsoleCallback is deliberately not added
        because Trainer already prints console progress. Callers wanting it
        can pass their own observers list to execute_training.

        Args:
            train_config: Training configuration

        Returns:
            List of default TrainingObserver instances
        """
        # MetricsTracker - comprehensive metrics logging
        metrics_tracker = MetricsTracker(
            log_dir=train_config.log_dir,
            experiment_name="gptmed_training",
            moving_avg_window=100,
            log_interval=train_config.log_interval,
            verbose=self.verbose,
        )
        return [metrics_tracker]

    def _get_metrics_tracker(self, observers: List[TrainingObserver]) -> Optional[MetricsTracker]:
        """
        Return the first MetricsTracker in ``observers``, or None if absent.

        Args:
            observers: List of observers

        Returns:
            MetricsTracker instance or None
        """
        return next(
            (obs for obs in observers if isinstance(obs, MetricsTracker)),
            None,
        )

    def train(
        self,
        model_size: str,
        train_data_path: str,
        val_data_path: str,
        batch_size: int = 16,
        learning_rate: float = 3e-4,
        num_epochs: int = 10,
        checkpoint_dir: str = "./model/checkpoints",
        log_dir: str = "./logs",
        seed: int = 42,
        **kwargs
    ) -> Dict[str, Any]:
        """
        High-level training interface: seed, build model and data, then train.

        Args:
            model_size: Model size ('tiny', 'small', 'medium')
            train_data_path: Path to training data
            val_data_path: Path to validation data
            batch_size: Training batch size
            learning_rate: Learning rate
            num_epochs: Number of training epochs
            checkpoint_dir: Directory for checkpoints
            log_dir: Directory for logs
            seed: Random seed
            **kwargs: Additional TrainingConfig parameters. Keys that are not
                attributes of TrainingConfig are ignored (reported when verbose).

        Returns:
            Dictionary with training results (see execute_training)

        Raises:
            ValueError: If model_size is unknown
        """
        # Set seed
        if self.verbose:
            print(f"\n🎲 Setting random seed: {seed}")
        self.set_seed(seed)

        # Get device
        device = self.device_manager.get_device()
        self.device_manager.print_device_info(verbose=self.verbose)

        # Create model
        if self.verbose:
            print(f"\n🧠 Creating model: {model_size}")

        model = self.create_model(model_size)
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        if self.verbose:
            print(f" Model size: {model_size}")
            print(f" Parameters: {total_params:,}")
            # Rough estimate: 4 bytes per trainable parameter (fp32 weights only)
            print(f" Memory: ~{total_params * 4 / 1024 / 1024:.2f} MB")

        # Only forward kwargs that TrainingConfig knows about.
        # NOTE(review): hasattr on the class only sees fields that have a
        # class-level default; confirm all optional TrainingConfig fields do.
        config_overrides = {k: v for k, v in kwargs.items() if hasattr(TrainingConfig, k)}
        ignored = sorted(set(kwargs) - set(config_overrides))
        if ignored and self.verbose:
            print(f"\n⚠️ Ignoring unknown training options: {', '.join(ignored)}")

        # Create training config
        train_config = TrainingConfig(
            train_data_path=train_data_path,
            val_data_path=val_data_path,
            batch_size=batch_size,
            learning_rate=learning_rate,
            num_epochs=num_epochs,
            checkpoint_dir=checkpoint_dir,
            log_dir=log_dir,
            device=device,
            seed=seed,
            **config_overrides
        )

        # Prepare training components
        train_loader, val_loader, optimizer = self.prepare_training(
            model, train_config, device
        )

        # Execute training
        return self.execute_training(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            train_config=train_config,
            device=device,
            model_config_dict=model.config.to_dict()
        )
|