gptmed 0.3.5__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -26,7 +26,7 @@ import torch
  import random
  import numpy as np
  from pathlib import Path
- from typing import Dict, Any, Optional
+ from typing import Dict, Any, Optional, List

  from gptmed.services.device_manager import DeviceManager
  from gptmed.model.architecture import GPTTransformer
@@ -35,6 +35,13 @@ from gptmed.configs.train_config import TrainingConfig
  from gptmed.training.dataset import create_dataloaders
  from gptmed.training.trainer import Trainer

+ # Observability imports
+ from gptmed.observability import (
+     TrainingObserver,
+     MetricsTracker,
+     ConsoleCallback,
+ )
+

  class TrainingService:
      """
@@ -162,7 +169,8 @@ class TrainingService:
          optimizer,
          train_config: TrainingConfig,
          device: str,
-         model_config_dict: dict
+         model_config_dict: dict,
+         observers: Optional[List[TrainingObserver]] = None,
      ) -> Dict[str, Any]:
          """
          Execute the training process.
@@ -175,13 +183,20 @@ class TrainingService:
              train_config: Training configuration
              device: Device to use
              model_config_dict: Model configuration as dictionary
+             observers: Optional list of TrainingObserver instances.
+                 If None, default observers (MetricsTracker) will be used.

          Returns:
              Dictionary with training results
          """
-         # Create trainer
+         # Set up default observers if none provided
+         if observers is None:
+             observers = self._create_default_observers(train_config)
+
+         # Create trainer with observers
          if self.verbose:
              print(f"\n🎯 Initializing trainer...")
+             print(f" Observers: {len(observers)} ({', '.join(o.name for o in observers)})")

          trainer = Trainer(
              model=model,
@@ -190,6 +205,7 @@ class TrainingService:
              optimizer=optimizer,
              config=train_config,
              device=device,
+             observers=observers,
          )

          # Resume if requested
@@ -209,9 +225,11 @@ class TrainingService:
              print("🚀 Starting Training!")
              print(f"{'='*60}\n")

+         interrupted = False
          try:
              trainer.train()
          except KeyboardInterrupt:
+             interrupted = True
              if self.verbose:
                  print("\n\n⏸️ Training interrupted by user")
                  print("💾 Saving checkpoint...")
@@ -227,29 +245,165 @@ class TrainingService:
              if self.verbose:
                  print("✓ Checkpoint saved. Resume with resume_from in config.")

+         # Generate observability reports (on BOTH normal and abnormal exit)
+         self._generate_observability_reports(
+             observers=observers,
+             train_config=train_config,
+             trainer=trainer,
+             interrupted=interrupted,
+         )
+
          # Return results
-         best_checkpoint = Path(train_config.checkpoint_dir) / "best_model.pt"
+         final_checkpoint = Path(train_config.checkpoint_dir) / "final_model.pt"

          results = {
-             'best_checkpoint': str(best_checkpoint),
+             'final_checkpoint': str(final_checkpoint),
+             'best_checkpoint': str(final_checkpoint),  # Alias for backward compatibility
              'final_val_loss': trainer.best_val_loss,
              'total_epochs': trainer.current_epoch,
+             'total_steps': trainer.global_step,
              'checkpoint_dir': train_config.checkpoint_dir,
              'log_dir': train_config.log_dir,
+             'interrupted': interrupted,
          }

+         # Get training issues from metrics tracker
+         metrics_tracker = self._get_metrics_tracker(observers)
+         if metrics_tracker:
+             results['training_issues'] = metrics_tracker.detect_issues()
+
          if self.verbose:
+             status = "⏸️ Training Interrupted" if interrupted else "✅ Training Complete!"
              print(f"\n{'='*60}")
-             print("✅ Training Complete!")
+             print(status)
              print(f"{'='*60}")
              print(f"\n📁 Results:")
-             print(f" Best checkpoint: {results['best_checkpoint']}")
+             print(f" Final checkpoint: {results['final_checkpoint']}")
              print(f" Best val loss: {results['final_val_loss']:.4f}")
+             print(f" Total steps: {results['total_steps']}")
              print(f" Total epochs: {results['total_epochs']}")
              print(f" Logs: {results['log_dir']}")

          return results

+     def _generate_observability_reports(
+         self,
+         observers: List[TrainingObserver],
+         train_config: TrainingConfig,
+         trainer,
+         interrupted: bool = False,
+     ) -> None:
+         """
+         Generate observability reports from metrics tracker.
+
+         Called on both normal completion and abnormal exit (Ctrl+C).
+
+         Args:
+             observers: List of training observers
+             train_config: Training configuration
+             trainer: Trainer instance
+             interrupted: Whether training was interrupted
+         """
+         metrics_tracker = self._get_metrics_tracker(observers)
+
+         if not metrics_tracker:
+             if self.verbose:
+                 print("\n⚠️ No MetricsTracker found - skipping observability reports")
+             return
+
+         if self.verbose:
+             print(f"\n{'='*60}")
+             print("📊 Generating Observability Reports")
+             print(f"{'='*60}")
+
+         try:
+             # Export metrics to CSV
+             csv_path = metrics_tracker.export_to_csv()
+             if self.verbose:
+                 print(f" ✓ CSV exported: {csv_path}")
+         except Exception as e:
+             if self.verbose:
+                 print(f" ✗ CSV export failed: {e}")
+
+         try:
+             # Export metrics to JSON
+             json_path = metrics_tracker.export_to_json()
+             if self.verbose:
+                 print(f" ✓ JSON exported: {json_path}")
+         except Exception as e:
+             if self.verbose:
+                 print(f" ✗ JSON export failed: {e}")
+
+         try:
+             # Generate loss curve plots
+             plot_path = metrics_tracker.plot_loss_curves()
+             if plot_path and self.verbose:
+                 print(f" ✓ Loss curves plotted: {plot_path}")
+         except ImportError:
+             if self.verbose:
+                 print(f" ⚠️ Plotting skipped (matplotlib not installed)")
+         except Exception as e:
+             if self.verbose:
+                 print(f" ✗ Plotting failed: {e}")
+
+         # Training health check
+         issues = metrics_tracker.detect_issues()
+         if self.verbose:
+             print(f"\n📋 Training Health Check:")
+             for issue in issues:
+                 print(f" {issue}")
+
+         # Add interrupted notice if applicable
+         if interrupted and self.verbose:
+             print(f"\n⚠️ Note: Training was interrupted at step {trainer.global_step}")
+             print(f" Reports reflect partial training data only.")
+
+     def _create_default_observers(self, train_config: TrainingConfig) -> List[TrainingObserver]:
+         """
+         Create default observers for training.
+
+         Args:
+             train_config: Training configuration
+
+         Returns:
+             List of default TrainingObserver instances
+         """
+         observers = []
+
+         # MetricsTracker - comprehensive metrics logging
+         metrics_tracker = MetricsTracker(
+             log_dir=train_config.log_dir,
+             experiment_name="gptmed_training",
+             moving_avg_window=100,
+             log_interval=train_config.log_interval,
+             verbose=self.verbose,
+         )
+         observers.append(metrics_tracker)
+
+         # Note: ConsoleCallback is optional since Trainer already has console output
+         # Uncomment if you want additional formatted console output:
+         # console_callback = ConsoleCallback(log_interval=train_config.log_interval)
+         # observers.append(console_callback)
+
+         return observers
+
+     def _get_metrics_tracker(self, observers: List[TrainingObserver]) -> Optional[MetricsTracker]:
+         """
+         Get MetricsTracker from observers list if present.
+
+         Args:
+             observers: List of observers
+
+         Returns:
+             MetricsTracker instance or None
+         """
+         for obs in observers:
+             if isinstance(obs, MetricsTracker):
+                 return obs
+         return None
+
+         return results
+
      def train(
          self,
          model_size: str,
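At the service level, the new keyword is backward compatible: existing callers get a default MetricsTracker, while callers who want extra monitoring can inject their own list. The fragment below is a usage sketch, not gptmed documentation: the training entry point's name, the TrainingService/model/dataloader/optimizer setup, and the model_config_dict construction are all outside the hunks above and are assumed; only the observers keyword, ConsoleCallback(log_interval=...), and the results keys come from this diff.

# Sketch only: setup of `service`, `model`, loaders, optimizer and the entry-point
# name are assumptions; adjust to the actual TrainingService API in your code.
from gptmed.observability import MetricsTracker, ConsoleCallback

observers = [
    MetricsTracker(log_dir="logs/", experiment_name="my_run"),
    ConsoleCallback(log_interval=50),  # optional extra console output (see diff)
]

results = service.run_training(          # method name not shown in the hunk
    optimizer=optimizer,
    train_config=train_config,
    device="cuda",
    model_config_dict=model_config_dict,
    observers=observers,                 # omit to get the default MetricsTracker
)

if results["interrupted"]:
    print("Run stopped early; partial reports written to", results["log_dir"])
print("Checkpoint:", results["final_checkpoint"])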
@@ -49,7 +49,7 @@ import torch.nn as nn
  from torch.utils.data import DataLoader
  import time
  from pathlib import Path
- from typing import Optional
+ from typing import Optional, List

  from gptmed.model.architecture import GPTTransformer
  from gptmed.training.utils import (
@@ -62,6 +62,15 @@ from gptmed.training.utils import (
  from gptmed.utils.logging import MetricsLogger, log_training_step, log_validation
  from gptmed.utils.checkpoints import CheckpointManager

+ # New observability imports
+ from gptmed.observability.base import (
+     TrainingObserver,
+     ObserverManager,
+     StepMetrics,
+     ValidationMetrics,
+     GradientMetrics,
+ )
+

  class Trainer:
      """
@@ -83,6 +92,7 @@ class Trainer:
          optimizer: torch.optim.Optimizer,
          config,  # TrainingConfig
          device: str = "cuda",
+         observers: List[TrainingObserver] = None,
      ):
          """
          Args:
@@ -92,6 +102,7 @@ class Trainer:
              optimizer: Optimizer (e.g., AdamW)
              config: TrainingConfig object
              device: Device to train on
+             observers: List of TrainingObserver instances for monitoring
          """
          self.model = model.to(device)
          self.train_loader = train_loader
@@ -100,7 +111,13 @@ class Trainer:
          self.config = config
          self.device = device

-         # Initialize utilities
+         # Initialize observability
+         self.observer_manager = ObserverManager()
+         if observers:
+             for obs in observers:
+                 self.observer_manager.add(obs)
+
+         # Initialize utilities (keep for backward compatibility)
          self.logger = MetricsLogger(log_dir=config.log_dir, experiment_name="gpt_training")

          self.checkpoint_manager = CheckpointManager(
@@ -124,17 +141,32 @@ class Trainer:
          print(f" Total steps: {self.total_steps}")
          print(f" Steps per epoch: {steps_per_epoch}")
          print(f" Num epochs: {config.num_epochs}")
+         print(f" Observers: {len(self.observer_manager.observers)}")
+
+     def add_observer(self, observer: TrainingObserver) -> None:
+         """
+         Add an observer for training monitoring.
+
+         Args:
+             observer: TrainingObserver instance
+         """
+         self.observer_manager.add(observer)
+         print(f" Added observer: {observer.name}")

-     def train_step(self, batch: tuple) -> dict:
+     def train_step(self, batch: tuple, step: int = 0, lr: float = 0.0) -> dict:
          """
          Single training step.

          Args:
              batch: (input_ids, target_ids) tuple
+             step: Current global step (for observer metrics)
+             lr: Current learning rate (for observer metrics)

          Returns:
              Dictionary with step metrics
          """
+         step_start_time = time.time()
+
          # Move batch to device
          input_ids, target_ids = batch
          input_ids = input_ids.to(self.device)
@@ -163,14 +195,34 @@ class Trainer:
          # Optimizer step
          self.optimizer.step()

-         # Return metrics
-         return {
+         # Calculate tokens per second
+         step_time = time.time() - step_start_time
+         tokens_per_sec = (batch_size * seq_len) / step_time if step_time > 0 else 0
+
+         # Create metrics dict (for backward compatibility)
+         metrics_dict = {
              "loss": loss.item(),
              "grad_norm": grad_norm,
              "batch_size": batch_size,
              "seq_len": seq_len,
+             "tokens_per_sec": tokens_per_sec,
          }

+         # Notify observers with StepMetrics
+         step_metrics = StepMetrics(
+             step=step,
+             loss=loss.item(),
+             learning_rate=lr,
+             grad_norm=grad_norm,
+             batch_size=batch_size,
+             seq_len=seq_len,
+             tokens_per_sec=tokens_per_sec,
+         )
+         self.observer_manager.notify_step(step_metrics)
+
+         # Return metrics
+         return metrics_dict
+
      def evaluate(self) -> dict:
          """
          Evaluate on validation set.
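Per-step metrics now flow to any registered observer via ObserverManager.notify_step, using the StepMetrics fields shown above (step, loss, learning_rate, grad_norm, batch_size, seq_len, tokens_per_sec). The sketch below shows what a custom observer might look like; the hook names on_step/on_validation and the `name` attribute convention are assumptions, since the TrainingObserver base class itself is not part of this diff.

# Hypothetical custom observer; real hook names on TrainingObserver may differ.
from gptmed.observability.base import TrainingObserver, StepMetrics, ValidationMetrics

class LossSpikeObserver(TrainingObserver):
    """Warn when the training loss jumps sharply between consecutive steps."""

    name = "loss_spike_observer"  # the Trainer prints observer names via .name

    def __init__(self, spike_factor: float = 2.0):
        self.spike_factor = spike_factor
        self.last_loss = None

    def on_step(self, metrics: StepMetrics) -> None:
        # Attribute access assumes StepMetrics is a dataclass-like container.
        if self.last_loss is not None and metrics.loss > self.spike_factor * self.last_loss:
            print(f"⚠️ Loss spike at step {metrics.step}: {self.last_loss:.4f} -> {metrics.loss:.4f}")
        self.last_loss = metrics.loss

    def on_validation(self, metrics: ValidationMetrics) -> None:
        print(f"val @ step {metrics.step}: loss={metrics.val_loss:.4f}, ppl={metrics.val_perplexity:.2f}")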
@@ -188,6 +240,14 @@ class Trainer:

          log_validation(self.global_step, val_loss, val_perplexity)

+         # Notify observers
+         val_metrics = ValidationMetrics(
+             step=self.global_step,
+             val_loss=val_loss,
+             val_perplexity=val_perplexity,
+         )
+         self.observer_manager.notify_validation(val_metrics)
+
          return {"val_loss": val_loss, "val_perplexity": val_perplexity}

      def train(self):
@@ -200,17 +260,37 @@ class Trainer:
          print("Starting Training")
          print("=" * 60)

+         # Notify observers of training start
+         train_config = {
+             "model_size": getattr(self.model.config, 'model_size', 'unknown'),
+             "device": self.device,
+             "batch_size": self.config.batch_size,
+             "learning_rate": self.config.learning_rate,
+             "num_epochs": self.config.num_epochs,
+             "max_steps": self.config.max_steps,
+             "total_steps": self.total_steps,
+             "warmup_steps": self.config.warmup_steps,
+             "grad_clip": self.config.grad_clip,
+             "weight_decay": self.config.weight_decay,
+         }
+         self.observer_manager.notify_train_start(train_config)
+
          self.model.train()

          # Training loop
          for epoch in range(self.config.num_epochs):
              self.current_epoch = epoch

+             # Notify observers of epoch start
+             self.observer_manager.notify_epoch_start(epoch)
+
              print(f"\n{'='*60}")
              print(f"Epoch {epoch + 1}/{self.config.num_epochs}")
              print(f"{'='*60}")

              epoch_start_time = time.time()
+             epoch_loss_sum = 0.0
+             epoch_steps = 0

              for batch_idx, batch in enumerate(self.train_loader):
                  step_start_time = time.time()
@@ -226,8 +306,12 @@ class Trainer:
                  )
                  set_learning_rate(self.optimizer, lr)

-                 # Training step
-                 metrics = self.train_step(batch)
+                 # Training step (now with step and lr for observers)
+                 metrics = self.train_step(batch, step=self.global_step, lr=lr)
+
+                 # Track epoch loss
+                 epoch_loss_sum += metrics["loss"]
+                 epoch_steps += 1

                  # Calculate tokens per second
                  step_time = time.time() - step_start_time
@@ -243,7 +327,7 @@ class Trainer:
                      tokens_per_sec=tokens_per_sec,
                  )

-                 # Log metrics
+                 # Log metrics (legacy logger)
                  self.logger.log(
                      self.global_step,
                      {
@@ -269,7 +353,7 @@ class Trainer:
                          is_best = False

                      # Save checkpoint
-                     self.checkpoint_manager.save_checkpoint(
+                     checkpoint_path = self.checkpoint_manager.save_checkpoint(
                          model=self.model,
                          optimizer=self.optimizer,
                          step=self.global_step,
@@ -280,8 +364,19 @@ class Trainer:
                          is_best=is_best,
                      )

+                     # Notify observers of checkpoint
+                     if checkpoint_path:
+                         self.observer_manager.notify_checkpoint(self.global_step, str(checkpoint_path))
+
                      self.model.train()  # Back to training mode

+                     # Check for early stopping (if any observer requests it)
+                     for obs in self.observer_manager.observers:
+                         if hasattr(obs, 'should_stop') and obs.should_stop:
+                             print(f"\nEarly stopping requested by {obs.name}")
+                             self._finish_training()
+                             return
+
                  # Save checkpoint periodically
                  if self.global_step % self.config.save_interval == 0 and self.global_step > 0:
                      self.checkpoint_manager.save_checkpoint(
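The early-stopping check above is duck-typed: after each evaluation block the Trainer scans its observers for a truthy should_stop attribute, finalizes via _finish_training(), and returns. Only should_stop and name are guaranteed by this diff; the on_validation hook name in the sketch below is an assumption about the observer interface.

# Sketch of an observer that uses the duck-typed early-stopping contract.
from gptmed.observability.base import TrainingObserver, ValidationMetrics

class PatienceEarlyStopping(TrainingObserver):
    """Request early stopping when val_loss has not improved for `patience` evals."""

    name = "patience_early_stopping"

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_evals = 0
        self.should_stop = False  # polled by the Trainer after each evaluation

    def on_validation(self, metrics: ValidationMetrics) -> None:  # assumed hook name
        if metrics.val_loss < self.best - self.min_delta:
            self.best = metrics.val_loss
            self.bad_evals = 0
        else:
            self.bad_evals += 1
            self.should_stop = self.bad_evals >= self.patience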
@@ -299,17 +394,36 @@ class Trainer:
                  # Check if reached max steps
                  if self.config.max_steps > 0 and self.global_step >= self.config.max_steps:
                      print(f"\nReached max_steps ({self.config.max_steps}). Stopping training.")
+                     self._finish_training()
                      return

-             # End of epoch
+             # End of epoch - notify observers
              epoch_time = time.time() - epoch_start_time
+             epoch_avg_loss = epoch_loss_sum / epoch_steps if epoch_steps > 0 else 0
+             self.observer_manager.notify_epoch_end(epoch, {
+                 "train_loss": epoch_avg_loss,
+                 "epoch_time": epoch_time,
+             })
              print(f"\nEpoch {epoch + 1} completed in {epoch_time:.2f}s")

+         self._finish_training()
+
+     def _finish_training(self):
+         """Finalize training and notify observers."""
          print("\n" + "=" * 60)
          print("Training Complete!")
          print("=" * 60)
          print(f"Best validation loss: {self.best_val_loss:.4f}")

+         # Notify observers of training end
+         final_metrics = {
+             "best_val_loss": self.best_val_loss,
+             "total_steps": self.global_step,
+             "final_epoch": self.current_epoch,
+             "final_checkpoint": str(self.checkpoint_manager.checkpoint_dir / "final_model.pt"),
+         }
+         self.observer_manager.notify_train_end(final_metrics)
+
      def resume_from_checkpoint(self, checkpoint_path: Optional[Path] = None):
          """
          Resume training from a checkpoint.
@@ -108,7 +108,7 @@ class CheckpointManager:
          # Save as best if applicable
          if is_best or val_loss < self.best_val_loss:
              self.best_val_loss = val_loss
-             best_path = self.checkpoint_dir / "best_model.pt"
+             best_path = self.checkpoint_dir / "final_model.pt"
              torch.save(checkpoint, best_path)
              print(f"Best model saved: {best_path} (val_loss: {val_loss:.4f})")
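With 0.4.1 the best checkpoint is written as final_model.pt rather than best_model.pt, and the results dict keeps 'best_checkpoint' only as an alias pointing at the new file. Code that loaded the old filename directly needs to follow the rename; the sketch below is a generic PyTorch loading pattern under that assumption, not gptmed API (only the two filenames come from the diff).

# Sketch: resolve the renamed checkpoint, falling back to the pre-0.4 name.
from pathlib import Path
import torch

def resolve_best_checkpoint(checkpoint_dir: str) -> Path:
    """Prefer the 0.4.1 filename, fall back to the pre-0.4 one if present."""
    new_path = Path(checkpoint_dir) / "final_model.pt"
    old_path = Path(checkpoint_dir) / "best_model.pt"
    return new_path if new_path.exists() else old_path

checkpoint = torch.load(resolve_best_checkpoint("checkpoints/"), map_location="cpu")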