gptmed 0.3.5__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,544 @@
1
+ """
2
+ Metrics Tracker for Training Observability
3
+
4
+ PURPOSE:
5
+ Enhanced metrics tracking with loss curves, moving averages,
6
+ gradient statistics, and export capabilities.
7
+
8
+ FEATURES:
9
+ - Loss curve history (train & validation)
10
+ - Moving averages for smoothed visualization
11
+ - Perplexity tracking
12
+ - Learning rate schedule visualization
13
+ - Gradient norm monitoring
14
+ - Export to JSON, CSV, and plots
15
+
16
+ WHAT TO LOOK FOR:
17
+ - Train loss ↓, Val loss ↓ → Healthy learning
18
+ - Train loss ↓, Val loss ↑ → Overfitting
19
+ - Loss plateau → Stuck (increase LR or check data)
20
+ - Loss spikes → Instability (reduce LR)
21
+ - Loss = NaN → Exploding gradients
22
+ """
23
+
24
+ import json
25
+ import math
26
+ import time
27
+ from pathlib import Path
28
+ from typing import Dict, List, Optional, Any, Tuple
29
+ from dataclasses import dataclass, field
30
+ from collections import deque
31
+
32
+ from gptmed.observability.base import (
33
+ TrainingObserver,
34
+ StepMetrics,
35
+ ValidationMetrics,
36
+ GradientMetrics,
37
+ )
38
+
39
+
40
@dataclass
class LossCurvePoint:
    """Single point on a loss curve.

    Attributes:
        step: Global training step at which the loss was recorded.
        loss: Loss value at that step.
        timestamp: Seconds elapsed since training start (0 if unknown).
    """
    step: int
    loss: float
    timestamp: float

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict of this point.

        Note: the return annotation is ``Dict[str, Any]`` because ``step``
        is an ``int`` — the previous ``Dict[str, float]`` was too narrow.
        """
        return {
            "step": self.step,
            "loss": self.loss,
            "timestamp": self.timestamp,
        }
53
+
54
+
55
class MetricsTracker(TrainingObserver):
    """
    Comprehensive metrics tracking for training observability.

    Tracks:
    - Training loss curve
    - Validation loss curve
    - Learning rate schedule
    - Gradient norms
    - Perplexity
    - Moving averages

    Example:
        >>> tracker = MetricsTracker(log_dir='logs/experiment1')
        >>> trainer.add_observer(tracker)
        >>> # After training:
        >>> tracker.plot_loss_curves()
        >>> tracker.export_to_csv('metrics.csv')
    """

    def __init__(
        self,
        log_dir: str = "logs",
        experiment_name: str = "training",
        moving_avg_window: int = 100,
        log_interval: int = 10,
        verbose: bool = True,
    ):
        """
        Initialize MetricsTracker.

        Args:
            log_dir: Directory to save logs and exports (created if missing)
            experiment_name: Name for this experiment (used in output filenames)
            moving_avg_window: Window size for the loss moving average
            log_interval: How often to log to file (every N steps)
            verbose: Whether to print progress to stdout
        """
        super().__init__(name="MetricsTracker")

        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        self.experiment_name = experiment_name
        self.moving_avg_window = moving_avg_window
        self.log_interval = log_interval
        self.verbose = verbose

        # Initialize all in-memory metric storage
        self._reset_storage()

        # File paths: JSONL for streaming per-step logs, JSON for the summary
        self.metrics_file = self.log_dir / f"{experiment_name}_metrics.jsonl"
        self.summary_file = self.log_dir / f"{experiment_name}_summary.json"

        if self.verbose:
            print("📊 MetricsTracker initialized")
            print(f" Log directory: {self.log_dir}")
            print(f" Moving average window: {moving_avg_window}")

    def _reset_storage(self) -> None:
        """Reset all metric storage to a pristine state."""
        # Loss curves
        self.train_losses: List[LossCurvePoint] = []
        self.val_losses: List[LossCurvePoint] = []

        # Moving average buffer (bounded deque drops the oldest entries)
        self._loss_buffer: deque = deque(maxlen=self.moving_avg_window)

        # Learning rate history as (step, lr) pairs
        self.learning_rates: List[Tuple[int, float]] = []

        # Gradient norms as (step, norm) pairs
        self.gradient_norms: List[Tuple[int, float]] = []

        # Perplexity as (step, perplexity) pairs
        self.train_perplexities: List[Tuple[int, float]] = []
        self.val_perplexities: List[Tuple[int, float]] = []

        # Timing (start_time set in on_train_start)
        self.start_time: Optional[float] = None
        self.step_times: List[float] = []

        # Training config snapshot (captured in on_train_start)
        self.config: Dict[str, Any] = {}

        # Best validation metrics seen so far
        self.best_val_loss: float = float('inf')
        self.best_val_step: int = 0

    # === TrainingObserver Implementation ===

    def on_train_start(self, config: Dict[str, Any]) -> None:
        """Called when training begins: reset state and snapshot the config."""
        self._reset_storage()
        self.start_time = time.time()
        self.config = config.copy()

        if self.verbose:
            print(f"\n{'='*60}")
            print("📊 Training started - MetricsTracker active")
            print(f"{'='*60}")

        # Persist the config so the run is reproducible from the log dir alone
        config_file = self.log_dir / f"{self.experiment_name}_config.json"
        with open(config_file, 'w') as f:
            json.dump(config, f, indent=2, default=str)

    def on_step(self, metrics: StepMetrics) -> None:
        """Called after each training step; records loss, LR, grad norm, PPL."""
        # Elapsed seconds since training start (0 if on_train_start never ran)
        timestamp = time.time() - self.start_time if self.start_time else 0

        # Store loss
        self.train_losses.append(LossCurvePoint(
            step=metrics.step,
            loss=metrics.loss,
            timestamp=timestamp,
        ))

        # Update moving average buffer
        self._loss_buffer.append(metrics.loss)

        # Store learning rate
        self.learning_rates.append((metrics.step, metrics.learning_rate))

        # Store gradient norm
        self.gradient_norms.append((metrics.step, metrics.grad_norm))

        # Store perplexity
        self.train_perplexities.append((metrics.step, metrics.perplexity))

        # Log to file only every log_interval steps to bound I/O cost
        if metrics.step % self.log_interval == 0:
            self._log_step(metrics, timestamp)

    def on_validation(self, metrics: ValidationMetrics) -> None:
        """Called after validation; records val loss/PPL and tracks the best."""
        timestamp = time.time() - self.start_time if self.start_time else 0

        # Store validation loss
        self.val_losses.append(LossCurvePoint(
            step=metrics.step,
            loss=metrics.val_loss,
            timestamp=timestamp,
        ))

        # Store validation perplexity
        self.val_perplexities.append((metrics.step, metrics.val_perplexity))

        # Track best val loss. NOTE: this must run before _log_validation,
        # whose "is_best" flag compares against the updated best_val_loss.
        if metrics.val_loss < self.best_val_loss:
            self.best_val_loss = metrics.val_loss
            self.best_val_step = metrics.step
            if self.verbose:
                print(f" ⭐ New best val_loss: {metrics.val_loss:.4f}")

        # Log to file
        self._log_validation(metrics, timestamp)

    def on_train_end(self, final_metrics: Dict[str, Any]) -> None:
        """Called when training completes; writes a JSON summary to disk.

        Args:
            final_metrics: Extra key/value pairs merged into the summary
                (they override the computed keys on collision).
        """
        total_time = time.time() - self.start_time if self.start_time else 0

        # Create summary
        summary = {
            "experiment_name": self.experiment_name,
            "total_steps": len(self.train_losses),
            "total_time_seconds": total_time,
            "final_train_loss": self.train_losses[-1].loss if self.train_losses else None,
            "final_val_loss": self.val_losses[-1].loss if self.val_losses else None,
            "best_val_loss": self.best_val_loss,
            "best_val_step": self.best_val_step,
            "final_perplexity": self.train_perplexities[-1][1] if self.train_perplexities else None,
            "config": self.config,
            **final_metrics,
        }

        # Save summary
        with open(self.summary_file, 'w') as f:
            json.dump(summary, f, indent=2, default=str)

        if self.verbose:
            print(f"\n{'='*60}")
            print("📊 Training completed - MetricsTracker summary")
            print(f"{'='*60}")
            print(f" Total steps: {len(self.train_losses)}")
            print(f" Total time: {total_time/60:.2f} minutes")
            # BUGFIX: the original used a conditional *expression* as a
            # statement (`print(...) if summary['final_train_loss'] else ""`),
            # which silently skipped this line whenever the final loss was
            # exactly 0.0 (falsy), not just when it was None.
            if summary['final_train_loss'] is not None:
                print(f" Final train loss: {summary['final_train_loss']:.4f}")
            print(f" Best val loss: {self.best_val_loss:.4f} (step {self.best_val_step})")
            print(f" Summary saved: {self.summary_file}")

    def on_gradient_computed(self, metrics: GradientMetrics) -> None:
        """Called after gradients are computed; intentionally a no-op here
        (per-step grad norms are already captured in on_step)."""
        pass

    # === Metrics Access Methods ===

    def get_train_loss_curve(self) -> List[Tuple[int, float]]:
        """Get training loss curve as (step, loss) pairs."""
        return [(p.step, p.loss) for p in self.train_losses]

    def get_val_loss_curve(self) -> List[Tuple[int, float]]:
        """Get validation loss curve as (step, loss) pairs."""
        return [(p.step, p.loss) for p in self.val_losses]

    def get_moving_average(self) -> float:
        """Get current moving average of training loss (0.0 if no data yet)."""
        if not self._loss_buffer:
            return 0.0
        return sum(self._loss_buffer) / len(self._loss_buffer)

    def get_smoothed_loss_curve(self, window: Optional[int] = None) -> List[Tuple[int, float]]:
        """
        Get smoothed training loss curve using a trailing moving average.

        Args:
            window: Smoothing window size (default: self.moving_avg_window)

        Returns:
            List of (step, smoothed_loss) tuples; falls back to the raw
            curve when there are fewer points than the window size.
        """
        window = window or self.moving_avg_window
        if len(self.train_losses) < window:
            return self.get_train_loss_curve()

        smoothed = []
        losses = [p.loss for p in self.train_losses]
        steps = [p.step for p in self.train_losses]

        # Trailing average: point i averages the window ending at i
        for i in range(window - 1, len(losses)):
            avg = sum(losses[i - window + 1:i + 1]) / window
            smoothed.append((steps[i], avg))

        return smoothed

    def get_loss_at_step(self, step: int) -> Optional[float]:
        """Get training loss at a specific step, or None if not recorded."""
        for point in self.train_losses:
            if point.step == step:
                return point.loss
        return None

    def get_gradient_stats(self) -> Dict[str, float]:
        """Get gradient norm statistics (empty dict if nothing recorded)."""
        if not self.gradient_norms:
            return {}

        norms = [n for _, n in self.gradient_norms]
        return {
            "mean": sum(norms) / len(norms),
            "max": max(norms),
            "min": min(norms),
            "last": norms[-1],
        }

    def detect_issues(self) -> List[str]:
        """
        Detect potential training issues from recorded metrics.

        Heuristics: NaN loss, very high loss, exploding gradient norms,
        monotonically rising validation loss, and stalled training.

        Returns:
            List of warning messages (or a single all-clear message).
        """
        issues = []

        if not self.train_losses:
            return ["No training data recorded yet"]

        # Check for NaN
        if any(math.isnan(p.loss) for p in self.train_losses):
            issues.append("⚠️ NaN loss detected - likely exploding gradients")

        # Check for loss explosion (last 100 steps only)
        recent_losses = [p.loss for p in self.train_losses[-100:]]
        if recent_losses and max(recent_losses) > 100:
            issues.append("⚠️ Very high loss (>100) - check learning rate")

        # Check for gradient explosion
        if self.gradient_norms:
            recent_grads = [n for _, n in self.gradient_norms[-100:]]
            if max(recent_grads) > 100:
                issues.append("⚠️ Large gradient norms (>100) - consider gradient clipping")

        # Check for overfitting: val loss strictly rising over last 3 evals
        if len(self.val_losses) >= 3:
            recent_val = [p.loss for p in self.val_losses[-3:]]
            if all(recent_val[i] > recent_val[i-1] for i in range(1, len(recent_val))):
                issues.append("⚠️ Validation loss increasing - possible overfitting")

        # Check for stalled training: first-100 vs last-100 average nearly equal
        if len(self.train_losses) >= 1000:
            early_avg = sum(p.loss for p in self.train_losses[:100]) / 100
            recent_avg = sum(p.loss for p in self.train_losses[-100:]) / 100
            if abs(early_avg - recent_avg) < 0.01:
                issues.append("⚠️ Loss not improving - training may be stuck")

        return issues if issues else ["✓ No issues detected"]

    # === Export Methods ===

    def export_to_csv(self, filepath: Optional[str] = None) -> str:
        """
        Export metrics to a CSV file.

        Rows are keyed by training step; validation/LR/grad/PPL columns are
        left empty when no value was recorded at that step.

        Args:
            filepath: Output path (default: auto-generated in log_dir)

        Returns:
            Path to saved file
        """
        filepath = filepath or str(self.log_dir / f"{self.experiment_name}_metrics.csv")

        with open(filepath, 'w') as f:
            # Header
            f.write("step,train_loss,val_loss,learning_rate,grad_norm,perplexity,timestamp\n")

            # Create step->value lookups so each row is O(1) to assemble
            val_lookup = {p.step: p.loss for p in self.val_losses}
            lr_lookup = dict(self.learning_rates)
            grad_lookup = dict(self.gradient_norms)
            ppl_lookup = dict(self.train_perplexities)

            for point in self.train_losses:
                val_loss = val_lookup.get(point.step, "")
                lr = lr_lookup.get(point.step, "")
                grad = grad_lookup.get(point.step, "")
                ppl = ppl_lookup.get(point.step, "")
                f.write(f"{point.step},{point.loss},{val_loss},{lr},{grad},{ppl},{point.timestamp}\n")

        if self.verbose:
            print(f"📁 Exported to CSV: {filepath}")
        return filepath

    def export_to_json(self, filepath: Optional[str] = None) -> str:
        """
        Export all metrics to a JSON file.

        Args:
            filepath: Output path (default: auto-generated in log_dir)

        Returns:
            Path to saved file
        """
        filepath = filepath or str(self.log_dir / f"{self.experiment_name}_full_metrics.json")

        data = {
            "experiment_name": self.experiment_name,
            "config": self.config,
            "train_losses": [p.to_dict() for p in self.train_losses],
            "val_losses": [p.to_dict() for p in self.val_losses],
            "learning_rates": self.learning_rates,
            "gradient_norms": self.gradient_norms,
            "train_perplexities": self.train_perplexities,
            "val_perplexities": self.val_perplexities,
            "best_val_loss": self.best_val_loss,
            "best_val_step": self.best_val_step,
        }

        with open(filepath, 'w') as f:
            json.dump(data, f, indent=2)

        if self.verbose:
            print(f"📁 Exported to JSON: {filepath}")
        return filepath

    def plot_loss_curves(
        self,
        filepath: Optional[str] = None,
        show_smoothed: bool = True,
        figsize: Tuple[int, int] = (12, 8),
    ) -> Optional[str]:
        """
        Plot training/validation loss, LR schedule, grad norms, and perplexity
        in a 2x2 grid and save to disk.

        Args:
            filepath: Output path (default: auto-generated in log_dir)
            show_smoothed: Whether to overlay the smoothed loss curve
            figsize: Figure size (width, height)

        Returns:
            Path to saved figure, or None if matplotlib is not available
        """
        # matplotlib is an optional dependency; degrade gracefully
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            print("⚠️ matplotlib not installed. Run: pip install matplotlib")
            return None

        filepath = filepath or str(self.log_dir / f"{self.experiment_name}_loss_curves.png")

        fig, axes = plt.subplots(2, 2, figsize=figsize)

        # === Plot 1: Loss Curves ===
        ax1 = axes[0, 0]

        # Raw training loss (faint so the smoothed curve stands out)
        train_steps = [p.step for p in self.train_losses]
        train_loss = [p.loss for p in self.train_losses]
        ax1.plot(train_steps, train_loss, alpha=0.3, label='Train Loss (raw)', color='blue')

        # Smoothed training loss
        if show_smoothed and len(self.train_losses) > self.moving_avg_window:
            smoothed = self.get_smoothed_loss_curve()
            smooth_steps, smooth_loss = zip(*smoothed)
            ax1.plot(smooth_steps, smooth_loss, label=f'Train Loss (MA-{self.moving_avg_window})', color='blue')

        # Validation loss
        if self.val_losses:
            val_steps = [p.step for p in self.val_losses]
            val_loss = [p.loss for p in self.val_losses]
            ax1.plot(val_steps, val_loss, 'o-', label='Val Loss', color='orange', markersize=4)

        ax1.set_xlabel('Step')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training & Validation Loss')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # === Plot 2: Learning Rate ===
        ax2 = axes[0, 1]
        if self.learning_rates:
            lr_steps, lr_values = zip(*self.learning_rates)
            ax2.plot(lr_steps, lr_values, color='green')
        ax2.set_xlabel('Step')
        ax2.set_ylabel('Learning Rate')
        ax2.set_title('Learning Rate Schedule')
        ax2.grid(True, alpha=0.3)

        # === Plot 3: Gradient Norms ===
        ax3 = axes[1, 0]
        if self.gradient_norms:
            grad_steps, grad_values = zip(*self.gradient_norms)
            ax3.plot(grad_steps, grad_values, alpha=0.5, color='red')
        ax3.set_xlabel('Step')
        ax3.set_ylabel('Gradient Norm')
        ax3.set_title('Gradient Norms')
        ax3.grid(True, alpha=0.3)

        # === Plot 4: Perplexity ===
        ax4 = axes[1, 1]
        if self.train_perplexities:
            ppl_steps, ppl_values = zip(*self.train_perplexities)
            # Cap perplexity at 1000 so early-training spikes don't flatten the plot
            ppl_values = [min(p, 1000) for p in ppl_values]
            ax4.plot(ppl_steps, ppl_values, alpha=0.5, label='Train Perplexity', color='purple')
        if self.val_perplexities:
            val_ppl_steps, val_ppl_values = zip(*self.val_perplexities)
            val_ppl_values = [min(p, 1000) for p in val_ppl_values]
            ax4.plot(val_ppl_steps, val_ppl_values, 'o-', label='Val Perplexity', color='magenta', markersize=4)
        ax4.set_xlabel('Step')
        ax4.set_ylabel('Perplexity')
        ax4.set_title('Perplexity')
        ax4.legend()
        ax4.grid(True, alpha=0.3)

        plt.suptitle(f'Training Metrics: {self.experiment_name}', fontsize=14)
        plt.tight_layout()
        plt.savefig(filepath, dpi=150)
        plt.close()

        if self.verbose:
            print(f"📊 Loss curves saved: {filepath}")

        return filepath

    # === Private Methods ===

    def _log_step(self, metrics: StepMetrics, timestamp: float) -> None:
        """Append one step record to the JSONL metrics file."""
        log_entry = {
            "type": "step",
            "timestamp": timestamp,
            "moving_avg_loss": self.get_moving_average(),
            **metrics.to_dict(),
        }

        with open(self.metrics_file, 'a') as f:
            f.write(json.dumps(log_entry) + '\n')

    def _log_validation(self, metrics: ValidationMetrics, timestamp: float) -> None:
        """Append one validation record to the JSONL metrics file.

        "is_best" uses <= because on_validation has already updated
        best_val_loss before calling this method.
        """
        log_entry = {
            "type": "validation",
            "timestamp": timestamp,
            "is_best": metrics.val_loss <= self.best_val_loss,
            **metrics.to_dict(),
        }

        with open(self.metrics_file, 'a') as f:
            f.write(json.dumps(log_entry) + '\n')