gptmed 0.3.5__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gptmed/__init__.py +37 -3
- gptmed/api.py +1 -0
- gptmed/configs/train_config.py +5 -0
- gptmed/model/__init__.py +2 -2
- gptmed/observability/__init__.py +43 -0
- gptmed/observability/base.py +369 -0
- gptmed/observability/callbacks.py +397 -0
- gptmed/observability/metrics_tracker.py +544 -0
- gptmed/services/training_service.py +161 -7
- gptmed/training/trainer.py +124 -10
- gptmed/utils/checkpoints.py +1 -1
- {gptmed-0.3.5.dist-info → gptmed-0.4.1.dist-info}/METADATA +180 -43
- {gptmed-0.3.5.dist-info → gptmed-0.4.1.dist-info}/RECORD +17 -13
- {gptmed-0.3.5.dist-info → gptmed-0.4.1.dist-info}/WHEEL +0 -0
- {gptmed-0.3.5.dist-info → gptmed-0.4.1.dist-info}/entry_points.txt +0 -0
- {gptmed-0.3.5.dist-info → gptmed-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {gptmed-0.3.5.dist-info → gptmed-0.4.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,544 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Metrics Tracker for Training Observability
|
|
3
|
+
|
|
4
|
+
PURPOSE:
|
|
5
|
+
Enhanced metrics tracking with loss curves, moving averages,
|
|
6
|
+
gradient statistics, and export capabilities.
|
|
7
|
+
|
|
8
|
+
FEATURES:
|
|
9
|
+
- Loss curve history (train & validation)
|
|
10
|
+
- Moving averages for smoothed visualization
|
|
11
|
+
- Perplexity tracking
|
|
12
|
+
- Learning rate schedule visualization
|
|
13
|
+
- Gradient norm monitoring
|
|
14
|
+
- Export to JSON, CSV, and plots
|
|
15
|
+
|
|
16
|
+
WHAT TO LOOK FOR:
|
|
17
|
+
- Train loss ↓, Val loss ↓ → Healthy learning
|
|
18
|
+
- Train loss ↓, Val loss ↑ → Overfitting
|
|
19
|
+
- Loss plateau → Stuck (increase LR or check data)
|
|
20
|
+
- Loss spikes → Instability (reduce LR)
|
|
21
|
+
- Loss = NaN → Exploding gradients
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
import json
|
|
25
|
+
import math
|
|
26
|
+
import time
|
|
27
|
+
from pathlib import Path
|
|
28
|
+
from typing import Dict, List, Optional, Any, Tuple
|
|
29
|
+
from dataclasses import dataclass, field
|
|
30
|
+
from collections import deque
|
|
31
|
+
|
|
32
|
+
from gptmed.observability.base import (
|
|
33
|
+
TrainingObserver,
|
|
34
|
+
StepMetrics,
|
|
35
|
+
ValidationMetrics,
|
|
36
|
+
GradientMetrics,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class LossCurvePoint:
    """Single point on a loss curve.

    Attributes:
        step: Global training step at which the loss was recorded (int).
        loss: Loss value at that step.
        timestamp: Seconds elapsed since training start when recorded.
    """
    step: int
    loss: float
    timestamp: float

    def to_dict(self) -> Dict[str, Any]:
        # NOTE: annotated Dict[str, Any] (not Dict[str, float]) because
        # "step" is an int, not a float.
        return {
            "step": self.step,
            "loss": self.loss,
            "timestamp": self.timestamp,
        }
+
class MetricsTracker(TrainingObserver):
    """
    Comprehensive metrics tracking for training observability.

    Tracks:
    - Training loss curve
    - Validation loss curve
    - Learning rate schedule
    - Gradient norms
    - Perplexity
    - Moving averages

    Example:
        >>> tracker = MetricsTracker(log_dir='logs/experiment1')
        >>> trainer.add_observer(tracker)
        >>> # After training:
        >>> tracker.plot_loss_curves()
        >>> tracker.export_to_csv('metrics.csv')
    """

    def __init__(
        self,
        log_dir: str = "logs",
        experiment_name: str = "training",
        moving_avg_window: int = 100,
        log_interval: int = 10,
        verbose: bool = True,
    ):
        """
        Initialize MetricsTracker.

        Args:
            log_dir: Directory to save logs and exports
            experiment_name: Name for this experiment
            moving_avg_window: Window size for moving average
            log_interval: How often to log to file (every N steps)
            verbose: Whether to print progress
        """
        super().__init__(name="MetricsTracker")

        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        self.experiment_name = experiment_name
        self.moving_avg_window = moving_avg_window
        self.log_interval = log_interval
        self.verbose = verbose

        # Initialize storage (must run after moving_avg_window is set:
        # _reset_storage sizes the deque from it).
        self._reset_storage()

        # File paths
        self.metrics_file = self.log_dir / f"{experiment_name}_metrics.jsonl"
        self.summary_file = self.log_dir / f"{experiment_name}_summary.json"

        if self.verbose:
            print(f"📊 MetricsTracker initialized")
            print(f" Log directory: {self.log_dir}")
            print(f" Moving average window: {moving_avg_window}")

    def _reset_storage(self) -> None:
        """Reset all metric storage."""
        # Loss curves
        self.train_losses: List[LossCurvePoint] = []
        self.val_losses: List[LossCurvePoint] = []

        # Moving average buffer (bounded; old entries fall off automatically)
        self._loss_buffer: deque = deque(maxlen=self.moving_avg_window)

        # Learning rate history as (step, lr) pairs
        self.learning_rates: List[Tuple[int, float]] = []

        # Gradient norms as (step, norm) pairs
        self.gradient_norms: List[Tuple[int, float]] = []

        # Perplexity as (step, perplexity) pairs
        self.train_perplexities: List[Tuple[int, float]] = []
        self.val_perplexities: List[Tuple[int, float]] = []

        # Timing
        self.start_time: Optional[float] = None
        self.step_times: List[float] = []

        # Training config captured at on_train_start
        self.config: Dict[str, Any] = {}

        # Best metrics
        self.best_val_loss: float = float('inf')
        self.best_val_step: int = 0

    # === TrainingObserver Implementation ===

    def on_train_start(self, config: Dict[str, Any]) -> None:
        """Called when training begins."""
        self._reset_storage()
        self.start_time = time.time()
        self.config = config.copy()

        if self.verbose:
            print(f"\n{'='*60}")
            print(f"📊 Training started - MetricsTracker active")
            print(f"{'='*60}")

        # Log config
        config_file = self.log_dir / f"{self.experiment_name}_config.json"
        with open(config_file, 'w') as f:
            json.dump(config, f, indent=2, default=str)

    def on_step(self, metrics: StepMetrics) -> None:
        """Called after each training step."""
        timestamp = time.time() - self.start_time if self.start_time else 0

        # Store loss
        self.train_losses.append(LossCurvePoint(
            step=metrics.step,
            loss=metrics.loss,
            timestamp=timestamp,
        ))

        # Update moving average buffer
        self._loss_buffer.append(metrics.loss)

        # Store learning rate
        self.learning_rates.append((metrics.step, metrics.learning_rate))

        # Store gradient norm
        self.gradient_norms.append((metrics.step, metrics.grad_norm))

        # Store perplexity
        self.train_perplexities.append((metrics.step, metrics.perplexity))

        # Log to file periodically
        if metrics.step % self.log_interval == 0:
            self._log_step(metrics, timestamp)

    def on_validation(self, metrics: ValidationMetrics) -> None:
        """Called after validation."""
        timestamp = time.time() - self.start_time if self.start_time else 0

        # Store validation loss
        self.val_losses.append(LossCurvePoint(
            step=metrics.step,
            loss=metrics.val_loss,
            timestamp=timestamp,
        ))

        # Store validation perplexity
        self.val_perplexities.append((metrics.step, metrics.val_perplexity))

        # Track best
        if metrics.val_loss < self.best_val_loss:
            self.best_val_loss = metrics.val_loss
            self.best_val_step = metrics.step
            if self.verbose:
                print(f" ⭐ New best val_loss: {metrics.val_loss:.4f}")

        # Log to file
        self._log_validation(metrics, timestamp)

    def on_train_end(self, final_metrics: Dict[str, Any]) -> None:
        """Called when training completes."""
        total_time = time.time() - self.start_time if self.start_time else 0

        # Create summary
        summary = {
            "experiment_name": self.experiment_name,
            "total_steps": len(self.train_losses),
            "total_time_seconds": total_time,
            "final_train_loss": self.train_losses[-1].loss if self.train_losses else None,
            "final_val_loss": self.val_losses[-1].loss if self.val_losses else None,
            "best_val_loss": self.best_val_loss,
            "best_val_step": self.best_val_step,
            "final_perplexity": self.train_perplexities[-1][1] if self.train_perplexities else None,
            "config": self.config,
            **final_metrics,
        }

        # Save summary
        with open(self.summary_file, 'w') as f:
            json.dump(summary, f, indent=2, default=str)

        if self.verbose:
            print(f"\n{'='*60}")
            print(f"📊 Training completed - MetricsTracker summary")
            print(f"{'='*60}")
            print(f" Total steps: {len(self.train_losses)}")
            print(f" Total time: {total_time/60:.2f} minutes")
            # FIX: explicit None check — the old truthiness test skipped
            # the line when the final loss was exactly 0.0 (falsy float).
            if summary['final_train_loss'] is not None:
                print(f" Final train loss: {summary['final_train_loss']:.4f}")
            print(f" Best val loss: {self.best_val_loss:.4f} (step {self.best_val_step})")
            print(f" Summary saved: {self.summary_file}")

    def on_gradient_computed(self, metrics: GradientMetrics) -> None:
        """Called after gradients are computed."""
        # Additional gradient tracking if needed
        pass

    # === Metrics Access Methods ===

    def get_train_loss_curve(self) -> List[Tuple[int, float]]:
        """Get training loss curve as (step, loss) pairs."""
        return [(p.step, p.loss) for p in self.train_losses]

    def get_val_loss_curve(self) -> List[Tuple[int, float]]:
        """Get validation loss curve as (step, loss) pairs."""
        return [(p.step, p.loss) for p in self.val_losses]

    def get_moving_average(self) -> float:
        """Get current moving average of training loss (0.0 before any step)."""
        if not self._loss_buffer:
            return 0.0
        return sum(self._loss_buffer) / len(self._loss_buffer)

    def get_smoothed_loss_curve(self, window: Optional[int] = None) -> List[Tuple[int, float]]:
        """
        Get smoothed training loss curve using moving average.

        Args:
            window: Smoothing window size (default: self.moving_avg_window)

        Returns:
            List of (step, smoothed_loss) tuples; falls back to the raw
            curve when fewer than `window` points have been recorded.
        """
        window = window or self.moving_avg_window
        if len(self.train_losses) < window:
            return self.get_train_loss_curve()

        smoothed = []
        losses = [p.loss for p in self.train_losses]
        steps = [p.step for p in self.train_losses]

        for i in range(window - 1, len(losses)):
            avg = sum(losses[i - window + 1:i + 1]) / window
            smoothed.append((steps[i], avg))

        return smoothed

    def get_loss_at_step(self, step: int) -> Optional[float]:
        """Get training loss at specific step, or None if not recorded."""
        for point in self.train_losses:
            if point.step == step:
                return point.loss
        return None

    def get_gradient_stats(self) -> Dict[str, float]:
        """Get gradient norm statistics (empty dict if none recorded)."""
        if not self.gradient_norms:
            return {}

        norms = [n for _, n in self.gradient_norms]
        return {
            "mean": sum(norms) / len(norms),
            "max": max(norms),
            "min": min(norms),
            "last": norms[-1],
        }

    def detect_issues(self) -> List[str]:
        """
        Detect potential training issues from metrics.

        Returns:
            List of warning messages
        """
        issues = []

        if not self.train_losses:
            return ["No training data recorded yet"]

        # Check for NaN
        if any(math.isnan(p.loss) for p in self.train_losses):
            issues.append("⚠️ NaN loss detected - likely exploding gradients")

        # Check for loss explosion
        recent_losses = [p.loss for p in self.train_losses[-100:]]
        if recent_losses and max(recent_losses) > 100:
            issues.append("⚠️ Very high loss (>100) - check learning rate")

        # Check for gradient explosion
        if self.gradient_norms:
            recent_grads = [n for _, n in self.gradient_norms[-100:]]
            if max(recent_grads) > 100:
                issues.append("⚠️ Large gradient norms (>100) - consider gradient clipping")

        # Check for overfitting (three consecutive increasing val losses)
        if len(self.val_losses) >= 3:
            recent_val = [p.loss for p in self.val_losses[-3:]]
            if all(recent_val[i] > recent_val[i-1] for i in range(1, len(recent_val))):
                issues.append("⚠️ Validation loss increasing - possible overfitting")

        # Check for stalled training (first-100 vs last-100 average unchanged)
        if len(self.train_losses) >= 1000:
            early_avg = sum(p.loss for p in self.train_losses[:100]) / 100
            recent_avg = sum(p.loss for p in self.train_losses[-100:]) / 100
            if abs(early_avg - recent_avg) < 0.01:
                issues.append("⚠️ Loss not improving - training may be stuck")

        return issues if issues else ["✓ No issues detected"]

    # === Export Methods ===

    def export_to_csv(self, filepath: Optional[str] = None) -> str:
        """
        Export metrics to CSV file.

        Args:
            filepath: Output path (default: auto-generated)

        Returns:
            Path to saved file
        """
        filepath = filepath or str(self.log_dir / f"{self.experiment_name}_metrics.csv")

        with open(filepath, 'w') as f:
            # Header
            f.write("step,train_loss,val_loss,learning_rate,grad_norm,perplexity,timestamp\n")

            # Create lookup dicts (val/lr/grad/ppl may be sparse vs train steps)
            val_lookup = {p.step: p.loss for p in self.val_losses}
            lr_lookup = dict(self.learning_rates)
            grad_lookup = dict(self.gradient_norms)
            ppl_lookup = dict(self.train_perplexities)

            for point in self.train_losses:
                val_loss = val_lookup.get(point.step, "")
                lr = lr_lookup.get(point.step, "")
                grad = grad_lookup.get(point.step, "")
                ppl = ppl_lookup.get(point.step, "")
                f.write(f"{point.step},{point.loss},{val_loss},{lr},{grad},{ppl},{point.timestamp}\n")

        if self.verbose:
            print(f"📁 Exported to CSV: {filepath}")
        return filepath

    def export_to_json(self, filepath: Optional[str] = None) -> str:
        """
        Export all metrics to JSON file.

        Args:
            filepath: Output path (default: auto-generated)

        Returns:
            Path to saved file
        """
        filepath = filepath or str(self.log_dir / f"{self.experiment_name}_full_metrics.json")

        data = {
            "experiment_name": self.experiment_name,
            "config": self.config,
            "train_losses": [p.to_dict() for p in self.train_losses],
            "val_losses": [p.to_dict() for p in self.val_losses],
            "learning_rates": self.learning_rates,
            "gradient_norms": self.gradient_norms,
            "train_perplexities": self.train_perplexities,
            "val_perplexities": self.val_perplexities,
            "best_val_loss": self.best_val_loss,
            "best_val_step": self.best_val_step,
        }

        with open(filepath, 'w') as f:
            json.dump(data, f, indent=2)

        if self.verbose:
            print(f"📁 Exported to JSON: {filepath}")
        return filepath

    def plot_loss_curves(
        self,
        filepath: Optional[str] = None,
        show_smoothed: bool = True,
        figsize: Tuple[int, int] = (12, 8),
    ) -> Optional[str]:
        """
        Plot training and validation loss curves.

        Args:
            filepath: Output path (default: auto-generated)
            show_smoothed: Whether to show smoothed curve
            figsize: Figure size (width, height)

        Returns:
            Path to saved figure, or None if matplotlib not available
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            print("⚠️ matplotlib not installed. Run: pip install matplotlib")
            return None

        filepath = filepath or str(self.log_dir / f"{self.experiment_name}_loss_curves.png")

        fig, axes = plt.subplots(2, 2, figsize=figsize)

        # === Plot 1: Loss Curves ===
        ax1 = axes[0, 0]

        # Training loss (raw, faded so the smoothed curve stands out)
        train_steps = [p.step for p in self.train_losses]
        train_loss = [p.loss for p in self.train_losses]
        ax1.plot(train_steps, train_loss, alpha=0.3, label='Train Loss (raw)', color='blue')

        # Smoothed training loss
        if show_smoothed and len(self.train_losses) > self.moving_avg_window:
            smoothed = self.get_smoothed_loss_curve()
            smooth_steps, smooth_loss = zip(*smoothed)
            ax1.plot(smooth_steps, smooth_loss, label=f'Train Loss (MA-{self.moving_avg_window})', color='blue')

        # Validation loss
        if self.val_losses:
            val_steps = [p.step for p in self.val_losses]
            val_loss = [p.loss for p in self.val_losses]
            ax1.plot(val_steps, val_loss, 'o-', label='Val Loss', color='orange', markersize=4)

        ax1.set_xlabel('Step')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training & Validation Loss')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # === Plot 2: Learning Rate ===
        ax2 = axes[0, 1]
        if self.learning_rates:
            lr_steps, lr_values = zip(*self.learning_rates)
            ax2.plot(lr_steps, lr_values, color='green')
        ax2.set_xlabel('Step')
        ax2.set_ylabel('Learning Rate')
        ax2.set_title('Learning Rate Schedule')
        ax2.grid(True, alpha=0.3)

        # === Plot 3: Gradient Norms ===
        ax3 = axes[1, 0]
        if self.gradient_norms:
            grad_steps, grad_values = zip(*self.gradient_norms)
            ax3.plot(grad_steps, grad_values, alpha=0.5, color='red')
        ax3.set_xlabel('Step')
        ax3.set_ylabel('Gradient Norm')
        ax3.set_title('Gradient Norms')
        ax3.grid(True, alpha=0.3)

        # === Plot 4: Perplexity ===
        ax4 = axes[1, 1]
        if self.train_perplexities:
            ppl_steps, ppl_values = zip(*self.train_perplexities)
            # Cap perplexity for visualization
            ppl_values = [min(p, 1000) for p in ppl_values]
            ax4.plot(ppl_steps, ppl_values, alpha=0.5, label='Train Perplexity', color='purple')
        if self.val_perplexities:
            val_ppl_steps, val_ppl_values = zip(*self.val_perplexities)
            val_ppl_values = [min(p, 1000) for p in val_ppl_values]
            ax4.plot(val_ppl_steps, val_ppl_values, 'o-', label='Val Perplexity', color='magenta', markersize=4)
        ax4.set_xlabel('Step')
        ax4.set_ylabel('Perplexity')
        ax4.set_title('Perplexity')
        ax4.legend()
        ax4.grid(True, alpha=0.3)

        plt.suptitle(f'Training Metrics: {self.experiment_name}', fontsize=14)
        plt.tight_layout()
        plt.savefig(filepath, dpi=150)
        plt.close()

        if self.verbose:
            print(f"📊 Loss curves saved: {filepath}")

        return filepath

    # === Private Methods ===

    def _log_step(self, metrics: StepMetrics, timestamp: float) -> None:
        """Append one step record to the JSONL metrics file."""
        log_entry = {
            "type": "step",
            "timestamp": timestamp,
            "moving_avg_loss": self.get_moving_average(),
            **metrics.to_dict(),
        }

        with open(self.metrics_file, 'a') as f:
            f.write(json.dumps(log_entry) + '\n')

    def _log_validation(self, metrics: ValidationMetrics, timestamp: float) -> None:
        """Append one validation record to the JSONL metrics file."""
        log_entry = {
            "type": "validation",
            "timestamp": timestamp,
            # <= (not <) because best_val_loss was already updated to this
            # value in on_validation when this run is the new best.
            "is_best": metrics.val_loss <= self.best_val_loss,
            **metrics.to_dict(),
        }

        with open(self.metrics_file, 'a') as f:
            f.write(json.dumps(log_entry) + '\n')