gptmed 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,489 @@
+"""
+Training Service
+
+PURPOSE:
+Encapsulates training logic following the Service Layer Pattern.
+Provides a high-level interface for model training with device flexibility.
+
+DESIGN PATTERNS:
+- Service Layer Pattern: Business logic separated from the API layer
+- Dependency Injection: DeviceManager injected for flexibility
+- Single Responsibility: Only handles training orchestration
+- Open/Closed Principle: Extensible without modification
+
+WHAT THIS FILE DOES:
+1. Orchestrates the training process
+2. Manages device configuration via DeviceManager
+3. Coordinates model, data, optimizer, and trainer
+4. Provides a clean interface for training operations
+
+PACKAGES USED:
+- torch: PyTorch training
+- pathlib: Path handling
+"""
+
+import torch
+import random
+import numpy as np
+from pathlib import Path
+from typing import Dict, Any, Optional, List
+
+from gptmed.services.device_manager import DeviceManager
+from gptmed.model.architecture import GPTTransformer
+from gptmed.model.configs.model_config import get_tiny_config, get_small_config, get_medium_config
+from gptmed.configs.train_config import TrainingConfig
+from gptmed.training.dataset import create_dataloaders
+from gptmed.training.trainer import Trainer
+
+# Observability imports
+from gptmed.observability import (
+    TrainingObserver,
+    MetricsTracker,
+    ConsoleCallback,
+)
+
+
+class TrainingService:
+    """
+    High-level service for model training.
+
+    Implements the Service Layer Pattern to encapsulate training logic.
+    Uses Dependency Injection for the DeviceManager.
+
+    Example:
+        >>> device_manager = DeviceManager(preferred_device='cpu')
+        >>> service = TrainingService(device_manager=device_manager, verbose=True)
+        >>> results = service.train(
+        ...     model_size='tiny',
+        ...     train_data_path='path/to/train_data',
+        ...     val_data_path='path/to/val_data',
+        ... )
+    """
+
+    def __init__(
+        self,
+        device_manager: Optional[DeviceManager] = None,
+        verbose: bool = True
+    ):
+        """
+        Initialize TrainingService.
+
+        Args:
+            device_manager: DeviceManager instance (if None, creates default)
+            verbose: Whether to print training information
+        """
+        self.device_manager = device_manager or DeviceManager(preferred_device='cuda')
+        self.verbose = verbose
+
+    def set_seed(self, seed: int) -> None:
+        """
+        Set random seeds for reproducibility.
+
+        Args:
+            seed: Random seed value
+        """
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed(seed)
+            torch.cuda.manual_seed_all(seed)
+            torch.backends.cudnn.deterministic = True
+            torch.backends.cudnn.benchmark = False
+
+    def create_model(self, model_size: str) -> GPTTransformer:
+        """
+        Create model based on size specification.
+
+        Args:
+            model_size: Model size ('tiny', 'small', or 'medium')
+
+        Returns:
+            GPTTransformer model instance
+
+        Raises:
+            ValueError: If model_size is invalid
+        """
+        if model_size == 'tiny':
+            model_config = get_tiny_config()
+        elif model_size == 'small':
+            model_config = get_small_config()
+        elif model_size == 'medium':
+            model_config = get_medium_config()
+        else:
+            raise ValueError(f"Unknown model size: {model_size}")
+
+        return GPTTransformer(model_config)
+
+    def prepare_training(
+        self,
+        model: GPTTransformer,
+        train_config: TrainingConfig,
+        device: str
+    ) -> tuple:
+        """
+        Prepare components for training.
+
+        Args:
+            model: Model to train
+            train_config: Training configuration
+            device: Device to use
+
+        Returns:
+            Tuple of (train_loader, val_loader, optimizer)
+        """
+        # Load data
+        if self.verbose:
+            print(f"\n📊 Loading data...")
+            print(f" Train: {train_config.train_data_path}")
+            print(f" Val: {train_config.val_data_path}")
+
+        train_loader, val_loader = create_dataloaders(
+            train_path=Path(train_config.train_data_path),
+            val_path=Path(train_config.val_data_path),
+            batch_size=train_config.batch_size,
+            num_workers=0,
+        )
+
+        if self.verbose:
+            print(f" Train batches: {len(train_loader)}")
+            print(f" Val batches: {len(val_loader)}")
+
+        # Create optimizer
+        if self.verbose:
+            print(f"\n⚙️ Setting up optimizer...")
+            print(f" Learning rate: {train_config.learning_rate}")
+            print(f" Weight decay: {train_config.weight_decay}")
+
+        optimizer = torch.optim.AdamW(
+            model.parameters(),
+            lr=train_config.learning_rate,
+            betas=train_config.betas,
+            eps=train_config.eps,
+            weight_decay=train_config.weight_decay,
+        )
+
+        return train_loader, val_loader, optimizer
+
+    def execute_training(
+        self,
+        model: GPTTransformer,
+        train_loader,
+        val_loader,
+        optimizer,
+        train_config: TrainingConfig,
+        device: str,
+        model_config_dict: dict,
+        observers: Optional[List[TrainingObserver]] = None,
+    ) -> Dict[str, Any]:
+        """
+        Execute the training process.
+
+        Args:
+            model: Model to train
+            train_loader: Training data loader
+            val_loader: Validation data loader
+            optimizer: Optimizer
+            train_config: Training configuration
+            device: Device to use
+            model_config_dict: Model configuration as dictionary
+            observers: Optional list of TrainingObserver instances.
+                If None, default observers (MetricsTracker) will be used.
+
+        Returns:
+            Dictionary with training results
+        """
+        # Set up default observers if none provided
+        if observers is None:
+            observers = self._create_default_observers(train_config)
+
+        # Create trainer with observers
+        if self.verbose:
+            print(f"\n🎯 Initializing trainer...")
+            print(f" Observers: {len(observers)} ({', '.join(o.name for o in observers)})")
+
+        trainer = Trainer(
+            model=model,
+            train_loader=train_loader,
+            val_loader=val_loader,
+            optimizer=optimizer,
+            config=train_config,
+            device=device,
+            observers=observers,
+        )
+
+        # Resume if requested
+        if hasattr(train_config, 'resume_from') and train_config.resume_from is not None:
+            if self.verbose:
+                print(f"\n📥 Resuming from checkpoint: {train_config.resume_from}")
+            trainer.resume_from_checkpoint(Path(train_config.resume_from))
+        elif hasattr(train_config, 'checkpoint_dir') and train_config.checkpoint_dir:
+            # Inform the user if a resumable checkpoint exists in the checkpoint dir
+            # (this branch only reports it; it does not trigger an automatic resume)
+            resume_path = Path(train_config.checkpoint_dir) / "resume_from.pt"
+            if resume_path.exists() and self.verbose:
+                print(f"\n📥 Found checkpoint to resume: {resume_path}")
+
+        # Start training
+        if self.verbose:
+            print(f"\n{'='*60}")
+            print("🚀 Starting Training!")
+            print(f"{'='*60}\n")
+
+        interrupted = False
+        try:
+            trainer.train()
+        except KeyboardInterrupt:
+            interrupted = True
+            if self.verbose:
+                print("\n\n⏸️ Training interrupted by user")
+                print("💾 Saving checkpoint...")
+            trainer.checkpoint_manager.save_checkpoint(
+                model=model,
+                optimizer=optimizer,
+                step=trainer.global_step,
+                epoch=trainer.current_epoch,
+                val_loss=trainer.best_val_loss,
+                model_config=model_config_dict,
+                train_config=train_config.to_dict(),
+            )
+            if self.verbose:
+                print("✓ Checkpoint saved. Resume with resume_from in config.")
+
+        # Generate observability reports (on BOTH normal and abnormal exit)
+        self._generate_observability_reports(
+            observers=observers,
+            train_config=train_config,
+            trainer=trainer,
+            interrupted=interrupted,
+        )
+
+        # Return results
+        final_checkpoint = Path(train_config.checkpoint_dir) / "final_model.pt"
+
+        results = {
+            'final_checkpoint': str(final_checkpoint),
+            'best_checkpoint': str(final_checkpoint),  # Alias for backward compatibility
+            'final_val_loss': trainer.best_val_loss,
+            'total_epochs': trainer.current_epoch,
+            'total_steps': trainer.global_step,
+            'checkpoint_dir': train_config.checkpoint_dir,
+            'log_dir': train_config.log_dir,
+            'interrupted': interrupted,
+        }
+
+        # Get training issues from metrics tracker
+        metrics_tracker = self._get_metrics_tracker(observers)
+        if metrics_tracker:
+            results['training_issues'] = metrics_tracker.detect_issues()
+
+        if self.verbose:
+            status = "⏸️ Training Interrupted" if interrupted else "✅ Training Complete!"
+            print(f"\n{'='*60}")
+            print(status)
+            print(f"{'='*60}")
+            print(f"\n📁 Results:")
+            print(f" Final checkpoint: {results['final_checkpoint']}")
+            print(f" Best val loss: {results['final_val_loss']:.4f}")
+            print(f" Total steps: {results['total_steps']}")
+            print(f" Total epochs: {results['total_epochs']}")
+            print(f" Logs: {results['log_dir']}")
+
+        return results
+
+    def _generate_observability_reports(
+        self,
+        observers: List[TrainingObserver],
+        train_config: TrainingConfig,
+        trainer,
+        interrupted: bool = False,
+    ) -> None:
+        """
+        Generate observability reports from metrics tracker.
+
+        Called on both normal completion and abnormal exit (Ctrl+C).
+
+        Args:
+            observers: List of training observers
+            train_config: Training configuration
+            trainer: Trainer instance
+            interrupted: Whether training was interrupted
+        """
+        metrics_tracker = self._get_metrics_tracker(observers)
+
+        if not metrics_tracker:
+            if self.verbose:
+                print("\n⚠️ No MetricsTracker found - skipping observability reports")
+            return
+
+        if self.verbose:
+            print(f"\n{'='*60}")
+            print("📊 Generating Observability Reports")
+            print(f"{'='*60}")
+
+        try:
+            # Export metrics to CSV
+            csv_path = metrics_tracker.export_to_csv()
+            if self.verbose:
+                print(f" ✓ CSV exported: {csv_path}")
+        except Exception as e:
+            if self.verbose:
+                print(f" ✗ CSV export failed: {e}")
+
+        try:
+            # Export metrics to JSON
+            json_path = metrics_tracker.export_to_json()
+            if self.verbose:
+                print(f" ✓ JSON exported: {json_path}")
+        except Exception as e:
+            if self.verbose:
+                print(f" ✗ JSON export failed: {e}")
+
+        try:
+            # Generate loss curve plots
+            plot_path = metrics_tracker.plot_loss_curves()
+            if plot_path and self.verbose:
+                print(f" ✓ Loss curves plotted: {plot_path}")
+        except ImportError:
+            if self.verbose:
+                print(f" ⚠️ Plotting skipped (matplotlib not installed)")
+        except Exception as e:
+            if self.verbose:
+                print(f" ✗ Plotting failed: {e}")
+
+        # Training health check
+        issues = metrics_tracker.detect_issues()
+        if self.verbose:
+            print(f"\n📋 Training Health Check:")
+            for issue in issues:
+                print(f" {issue}")
+
+        # Add interrupted notice if applicable
+        if interrupted and self.verbose:
+            print(f"\n⚠️ Note: Training was interrupted at step {trainer.global_step}")
+            print(f" Reports reflect partial training data only.")
+
+    def _create_default_observers(self, train_config: TrainingConfig) -> List[TrainingObserver]:
+        """
+        Create default observers for training.
+
+        Args:
+            train_config: Training configuration
+
+        Returns:
+            List of default TrainingObserver instances
+        """
+        observers = []
+
+        # MetricsTracker - comprehensive metrics logging
+        metrics_tracker = MetricsTracker(
+            log_dir=train_config.log_dir,
+            experiment_name="gptmed_training",
+            moving_avg_window=100,
+            log_interval=train_config.log_interval,
+            verbose=self.verbose,
+        )
+        observers.append(metrics_tracker)
+
+        # Note: ConsoleCallback is optional since Trainer already has console output.
+        # Uncomment if you want additional formatted console output:
+        # console_callback = ConsoleCallback(log_interval=train_config.log_interval)
+        # observers.append(console_callback)
+
+        return observers
+
+    def _get_metrics_tracker(self, observers: List[TrainingObserver]) -> Optional[MetricsTracker]:
+        """
+        Get MetricsTracker from observers list if present.
+
+        Args:
+            observers: List of observers
+
+        Returns:
+            MetricsTracker instance or None
+        """
+        for obs in observers:
+            if isinstance(obs, MetricsTracker):
+                return obs
+        return None
+
+    def train(
+        self,
+        model_size: str,
+        train_data_path: str,
+        val_data_path: str,
+        batch_size: int = 16,
+        learning_rate: float = 3e-4,
+        num_epochs: int = 10,
+        checkpoint_dir: str = "./model/checkpoints",
+        log_dir: str = "./logs",
+        seed: int = 42,
+        **kwargs
+    ) -> Dict[str, Any]:
+        """
+        High-level training interface.
+
+        Args:
+            model_size: Model size ('tiny', 'small', 'medium')
+            train_data_path: Path to training data
+            val_data_path: Path to validation data
+            batch_size: Training batch size
+            learning_rate: Learning rate
+            num_epochs: Number of training epochs
+            checkpoint_dir: Directory for checkpoints
+            log_dir: Directory for logs
+            seed: Random seed
+            **kwargs: Additional training config parameters
+
+        Returns:
+            Dictionary with training results
+        """
+        # Set seed
+        if self.verbose:
+            print(f"\n🎲 Setting random seed: {seed}")
+        self.set_seed(seed)
+
+        # Get device
+        device = self.device_manager.get_device()
+        self.device_manager.print_device_info(verbose=self.verbose)
+
+        # Create model
+        if self.verbose:
+            print(f"\n🧠 Creating model: {model_size}")
+
+        model = self.create_model(model_size)
+        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+        if self.verbose:
+            print(f" Model size: {model_size}")
+            print(f" Parameters: {total_params:,}")
+            print(f" Memory: ~{total_params * 4 / 1024 / 1024:.2f} MB")
+
+        # Create training config
+        train_config = TrainingConfig(
+            train_data_path=train_data_path,
+            val_data_path=val_data_path,
+            batch_size=batch_size,
+            learning_rate=learning_rate,
+            num_epochs=num_epochs,
+            checkpoint_dir=checkpoint_dir,
+            log_dir=log_dir,
+            device=device,
+            seed=seed,
+            **{k: v for k, v in kwargs.items() if hasattr(TrainingConfig, k)}
+        )
+
+        # Prepare training components
+        train_loader, val_loader, optimizer = self.prepare_training(
+            model, train_config, device
+        )
+
+        # Execute training
+        results = self.execute_training(
+            model=model,
+            train_loader=train_loader,
+            val_loader=val_loader,
+            optimizer=optimizer,
+            train_config=train_config,
+            device=device,
+            model_config_dict=model.config.to_dict()
+        )
+
+        return results
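
For reference, the snippet below is a minimal usage sketch of the TrainingService added in 0.4.0, based only on the signatures visible in this diff. The module path of TrainingService, the data paths, and the 'tiny' model size are illustrative assumptions, not values confirmed by the package.

from gptmed.services.device_manager import DeviceManager
from gptmed.services.training_service import TrainingService  # module path assumed; this diff does not show the file name

# Create a device manager and inject it, mirroring the class docstring example
device_manager = DeviceManager(preferred_device='cpu')
service = TrainingService(device_manager=device_manager, verbose=True)

# High-level training call; keyword arguments map to TrainingConfig fields
results = service.train(
    model_size='tiny',                     # 'tiny', 'small', or 'medium'
    train_data_path='path/to/train_data',  # placeholder path
    val_data_path='path/to/val_data',      # placeholder path
    batch_size=16,
    num_epochs=10,
)

# Keys taken from the results dict assembled in execute_training()
print(results['final_checkpoint'], results['final_val_loss'])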