gptmed 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,397 @@
+ """
+ Training Callbacks
+
+ PURPOSE:
+ Concrete observer implementations for common training needs.
+
+ CALLBACKS INCLUDED:
+ - ConsoleCallback: Print progress to console
+ - JSONLoggerCallback: Log to JSONL file
+ - EarlyStoppingCallback: Stop training if no improvement
+ - (Future) TensorBoardCallback: Log to TensorBoard
+ - (Future) WandBCallback: Log to Weights & Biases
+ """
+
+ import json
+ import time
+ from pathlib import Path
+ from typing import Dict, Any, Optional
+
+ from gptmed.observability.base import (
+     TrainingObserver,
+     StepMetrics,
+     ValidationMetrics,
+     GradientMetrics,
+ )
+
+
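Editor's note, for orientation only (not part of the package diff): the callbacks below assume a trainer that keeps a list of observers and fans events out to them. The sketch here shows that handshake; `MiniTrainer`, the keyword construction of `StepMetrics`, and the `gptmed.observability.callbacks` import path are illustrative assumptions, while the hook names and the metric attributes come from the code in this file.

# Hypothetical dispatch loop; a sketch, not the gptmed Trainer.
import math

from gptmed.observability.base import StepMetrics          # field names assumed below
from gptmed.observability.callbacks import ConsoleCallback  # module path assumed

class MiniTrainer:
    def __init__(self):
        self.observers = []

    def add_observer(self, obs):
        self.observers.append(obs)

    def fit(self, num_steps: int):
        for obs in self.observers:
            obs.on_train_start({"max_steps": num_steps, "batch_size": 8})
        for step in range(1, num_steps + 1):
            loss = 1.0 / step  # stand-in for a real training step
            metrics = StepMetrics(  # keyword names assumed from the attributes read in this module
                step=step, loss=loss, perplexity=math.exp(loss),
                learning_rate=3e-4, grad_norm=0.5,
            )
            for obs in self.observers:
                obs.on_step(metrics)
        for obs in self.observers:
            obs.on_train_end({"best_val_loss": loss})

trainer = MiniTrainer()
trainer.add_observer(ConsoleCallback(log_interval=10))
trainer.fit(100)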
+ class ConsoleCallback(TrainingObserver):
+     """
+     Prints training progress to console.
+
+     Features:
+     - Emoji warning flags for detected issues
+     - Progress percentage and throughput on each logged line
+     - Issue detection (NaN loss, high loss, large gradients)
+
+     Example:
+         >>> trainer.add_observer(ConsoleCallback(log_interval=100))
+     """
+
+     def __init__(
+         self,
+         log_interval: int = 100,
+         show_progress_bar: bool = True,
+         name: str = "ConsoleCallback",
+     ):
+         """
+         Initialize ConsoleCallback.
+
+         Args:
+             log_interval: Print every N steps
+             show_progress_bar: Whether to show progress indicators
+             name: Callback name
+         """
+         super().__init__(name=name)
+         self.log_interval = log_interval
+         self.show_progress_bar = show_progress_bar
+         self.total_steps: int = 0
+         self.start_time: Optional[float] = None
+
+     def on_train_start(self, config: Dict[str, Any]) -> None:
+         """Called when training begins."""
+         self.start_time = time.time()
+         self.total_steps = config.get('max_steps', 0) or config.get('total_steps', 0)
+
+         print("\n" + "=" * 70)
+         print("🚀 Training Started")
+         print("=" * 70)
+         print(f" Model: {config.get('model_size', 'unknown')}")
+         print(f" Device: {config.get('device', 'unknown')}")
+         print(f" Batch size: {config.get('batch_size', 'unknown')}")
+         print(f" Learning rate: {config.get('learning_rate', 'unknown')}")
+         print(f" Total steps: {self.total_steps}")
+         print("=" * 70 + "\n")
+
+     def on_step(self, metrics: StepMetrics) -> None:
+         """Called after each training step."""
+         if metrics.step % self.log_interval != 0:
+             return
+
+         # Calculate progress (only when enabled and total steps are known)
+         progress = ""
+         if self.show_progress_bar and self.total_steps > 0:
+             pct = (metrics.step / self.total_steps) * 100
+             progress = f"[{pct:5.1f}%] "
+
+         # Calculate speed
+         elapsed = time.time() - self.start_time if self.start_time else 0
+         steps_per_sec = metrics.step / elapsed if elapsed > 0 else 0
+
+         # Build message
+         msg = f"{progress}Step {metrics.step:6d} | "
+         msg += f"Loss: {metrics.loss:.4f} | "
+         msg += f"PPL: {metrics.perplexity:.2f} | "
+         msg += f"LR: {metrics.learning_rate:.2e} | "
+         msg += f"Grad: {metrics.grad_norm:.3f} | "
+         msg += f"{steps_per_sec:.1f} steps/s"
+
+         # Check for issues
+         warnings = []
+         if metrics.loss != metrics.loss:  # NaN check: NaN is the only float not equal to itself
+             warnings.append("🔥 NaN LOSS!")
+         elif metrics.loss > 100:
+             warnings.append("⚠️ High loss")
+
+         if metrics.grad_norm > 10:
+             warnings.append("⚠️ Large grads")
+
+         if warnings:
+             msg += " | " + " ".join(warnings)
+
+         print(msg)
+
+     def on_validation(self, metrics: ValidationMetrics) -> None:
+         """Called after validation."""
+         print(f"\n{'─' * 50}")
+         print(f"📊 Validation @ Step {metrics.step}")
+         print(f" Val Loss: {metrics.val_loss:.4f}")
+         print(f" Val PPL: {metrics.val_perplexity:.2f}")
+         print(f"{'─' * 50}\n")
+
+     def on_train_end(self, final_metrics: Dict[str, Any]) -> None:
+         """Called when training completes."""
+         elapsed = time.time() - self.start_time if self.start_time else 0
+
+         print("\n" + "=" * 70)
+         print("✅ Training Completed")
+         print("=" * 70)
+         print(f" Total time: {elapsed/60:.2f} minutes")
+         print(f" Best val loss: {final_metrics.get('best_val_loss', 'N/A')}")
+         print(f" Best checkpoint: {final_metrics.get('best_checkpoint', 'N/A')}")
+         print("=" * 70 + "\n")
+
+     def on_epoch_start(self, epoch: int) -> None:
+         """Called at start of each epoch."""
+         print(f"\n📅 Epoch {epoch + 1} starting...")
+
+     def on_epoch_end(self, epoch: int, metrics: Dict[str, float]) -> None:
+         """Called at end of each epoch."""
+         print(f"📅 Epoch {epoch + 1} completed")
+         if 'train_loss' in metrics:
+             print(f" Avg train loss: {metrics['train_loss']:.4f}")
+         if 'val_loss' in metrics:
+             print(f" Val loss: {metrics['val_loss']:.4f}")
+
+
+ class JSONLoggerCallback(TrainingObserver):
+     """
+     Logs metrics to JSONL file.
+
+     Format: One JSON object per line (JSONL)
+     Easy to parse and analyze with Python/pandas.
+
+     Example:
+         >>> callback = JSONLoggerCallback(log_dir='logs', experiment_name='exp1')
+         >>> trainer.add_observer(callback)
+     """
+
+     def __init__(
+         self,
+         log_dir: str = "logs",
+         experiment_name: str = "training",
+         log_interval: int = 1,
+         name: str = "JSONLoggerCallback",
+     ):
+         """
+         Initialize JSONLoggerCallback.
+
+         Args:
+             log_dir: Directory for log files
+             experiment_name: Name for log files
+             log_interval: Log every N steps
+             name: Callback name
+         """
+         super().__init__(name=name)
+
+         self.log_dir = Path(log_dir)
+         self.log_dir.mkdir(parents=True, exist_ok=True)
+
+         self.experiment_name = experiment_name
+         self.log_interval = log_interval
+
+         self.log_file = self.log_dir / f"{experiment_name}_log.jsonl"
+         self.start_time: Optional[float] = None
+
+     def on_train_start(self, config: Dict[str, Any]) -> None:
+         """Called when training begins."""
+         self.start_time = time.time()
+
+         # Log config
+         self._write_log({
+             "event": "train_start",
+             "timestamp": 0,
+             "config": config,
+         })
+
+     def on_step(self, metrics: StepMetrics) -> None:
+         """Called after each training step."""
+         if metrics.step % self.log_interval != 0:
+             return
+
+         timestamp = time.time() - self.start_time if self.start_time else 0
+
+         self._write_log({
+             "event": "step",
+             "timestamp": timestamp,
+             **metrics.to_dict(),
+         })
+
+     def on_validation(self, metrics: ValidationMetrics) -> None:
+         """Called after validation."""
+         timestamp = time.time() - self.start_time if self.start_time else 0
+
+         self._write_log({
+             "event": "validation",
+             "timestamp": timestamp,
+             **metrics.to_dict(),
+         })
+
+     def on_train_end(self, final_metrics: Dict[str, Any]) -> None:
+         """Called when training completes."""
+         timestamp = time.time() - self.start_time if self.start_time else 0
+
+         self._write_log({
+             "event": "train_end",
+             "timestamp": timestamp,
+             **final_metrics,
+         })
+
+     def on_checkpoint(self, step: int, checkpoint_path: str) -> None:
+         """Called when a checkpoint is saved."""
+         timestamp = time.time() - self.start_time if self.start_time else 0
+
+         self._write_log({
+             "event": "checkpoint",
+             "timestamp": timestamp,
+             "step": step,
+             "checkpoint_path": checkpoint_path,
+         })
+
+     def _write_log(self, data: Dict[str, Any]) -> None:
+         """Write log entry to file."""
+         with open(self.log_file, 'a') as f:
+             f.write(json.dumps(data, default=str) + '\n')
+
+
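Editor's note: because the log is one JSON object per line, it can be read back directly for analysis, as the class docstring suggests. A minimal reading sketch follows; the path matches the default `logs/training_log.jsonl` built above, and the step columns depend on whatever `StepMetrics.to_dict()` emits, which is assumed here.

# Sketch: load a JSONL training log with pandas (pandas is not a gptmed requirement).
import pandas as pd

df = pd.read_json("logs/training_log.jsonl", lines=True)   # default path from JSONLoggerCallback
steps = df[df["event"] == "step"]                           # one row per logged training step
validations = df[df["event"] == "validation"]               # one row per validation event
print(steps.tail())   # columns beyond "event"/"timestamp" come from StepMetrics.to_dict()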
+ class EarlyStoppingCallback(TrainingObserver):
+     """
+     Stops training if validation loss doesn't improve.
+
+     Features:
+     - Patience: Number of validations to wait without improvement
+     - Min delta: Minimum improvement to count as progress
+     - Best tracking: Records the best loss and step so the trainer can restore weights
+
+     Example:
+         >>> callback = EarlyStoppingCallback(patience=5, min_delta=0.01)
+         >>> trainer.add_observer(callback)
+         >>> # During training, check: callback.should_stop
+     """
+
+     def __init__(
+         self,
+         patience: int = 5,
+         min_delta: float = 0.0,
+         name: str = "EarlyStoppingCallback",
+     ):
+         """
+         Initialize EarlyStoppingCallback.
+
+         Args:
+             patience: Number of validations without improvement before stopping
+             min_delta: Minimum change to qualify as improvement
+             name: Callback name
+         """
+         super().__init__(name=name)
+
+         self.patience = patience
+         self.min_delta = min_delta
+
+         self.best_loss: float = float('inf')
+         self.best_step: int = 0
+         self.wait_count: int = 0
+         self.should_stop: bool = False
+
+     def on_train_start(self, config: Dict[str, Any]) -> None:
+         """Reset state at training start."""
+         self.best_loss = float('inf')
+         self.best_step = 0
+         self.wait_count = 0
+         self.should_stop = False
+
+     def on_step(self, metrics: StepMetrics) -> None:
+         """Not used - we check on validation."""
+         pass
+
+     def on_validation(self, metrics: ValidationMetrics) -> None:
+         """Check if we should stop training."""
+         current_loss = metrics.val_loss
+
+         if current_loss < self.best_loss - self.min_delta:
+             # Improvement found
+             self.best_loss = current_loss
+             self.best_step = metrics.step
+             self.wait_count = 0
+         else:
+             # No improvement
+             self.wait_count += 1
+
+             if self.wait_count >= self.patience:
+                 self.should_stop = True
+                 print("\n⏹️ Early stopping triggered!")
+                 print(f" No improvement for {self.patience} validations")
+                 print(f" Best val loss: {self.best_loss:.4f} at step {self.best_step}")
+
+     def on_train_end(self, final_metrics: Dict[str, Any]) -> None:
+         """Report final state."""
+         if self.should_stop:
+             print(f" Training stopped early at step {final_metrics.get('step', 'unknown')}")
+
+
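Editor's note: the callback only raises a flag; the training loop has to poll `should_stop` after each validation, as the docstring hints. A minimal sketch of that handshake follows. The loop, the fake loss sequence, and the keyword construction of `ValidationMetrics` are illustrative assumptions; only the hooks and the `should_stop` flag come from the class above.

# Sketch: a loop that polls EarlyStoppingCallback.should_stop after each validation.
from gptmed.observability.base import ValidationMetrics          # keyword names assumed
from gptmed.observability.callbacks import EarlyStoppingCallback  # module path assumed
import math

early_stop = EarlyStoppingCallback(patience=3, min_delta=0.01)
early_stop.on_train_start(config={})

fake_val_losses = [2.0, 1.5, 1.5, 1.5, 1.5, 1.5]  # plateaus after the second validation
for step, val_loss in enumerate(fake_val_losses, start=1):
    # ... run the training steps between validations here ...
    metrics = ValidationMetrics(step=step, val_loss=val_loss, val_perplexity=math.exp(val_loss))
    early_stop.on_validation(metrics)
    if early_stop.should_stop:
        print(f"stopping at validation {step}")
        break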
+ class LRSchedulerCallback(TrainingObserver):
+     """
+     Monitors the learning rate and suggests adjustments based on training progress.
+
+     Features:
+     - Suggest LR reduction on plateau (applying it is left to the trainer)
+     - Warmup monitoring via the recorded LR history
+     - LR range test (experimental)
+     """
+
+     def __init__(
+         self,
+         mode: str = 'monitor',  # 'monitor' or 'plateau'
+         factor: float = 0.5,
+         patience: int = 3,
+         min_lr: float = 1e-7,
+         name: str = "LRSchedulerCallback",
+     ):
+         """
+         Initialize LRSchedulerCallback.
+
+         Args:
+             mode: 'monitor' (just watch) or 'plateau' (suggest reduction on plateau)
+             factor: Factor to reduce LR by
+             patience: Validations without improvement before reducing
+             min_lr: Minimum learning rate
+             name: Callback name
+         """
+         super().__init__(name=name)
+
+         self.mode = mode
+         self.factor = factor
+         self.patience = patience
+         self.min_lr = min_lr
+
+         self.best_loss: float = float('inf')
+         self.wait_count: int = 0
+         self.lr_history: list = []
+         self.suggested_lr: Optional[float] = None
+
+     def on_train_start(self, config: Dict[str, Any]) -> None:
+         """Reset state."""
+         self.best_loss = float('inf')
+         self.wait_count = 0
+         self.lr_history = []
+         self.suggested_lr = None
+
+     def on_step(self, metrics: StepMetrics) -> None:
+         """Track learning rate."""
+         self.lr_history.append((metrics.step, metrics.learning_rate))
+
+     def on_validation(self, metrics: ValidationMetrics) -> None:
+         """Check for plateau."""
+         if self.mode != 'plateau':
+             return
+
+         if metrics.val_loss < self.best_loss:
+             self.best_loss = metrics.val_loss
+             self.wait_count = 0
+         else:
+             self.wait_count += 1
+
+             if self.wait_count >= self.patience:
+                 if self.lr_history:
+                     current_lr = self.lr_history[-1][1]
+                     new_lr = max(current_lr * self.factor, self.min_lr)
+                     self.suggested_lr = new_lr
+                     print(f"\n📉 Suggest reducing LR: {current_lr:.2e} → {new_lr:.2e}")
+                 self.wait_count = 0
+
+     def on_train_end(self, final_metrics: Dict[str, Any]) -> None:
+         """Report LR summary."""
+         if self.lr_history:
+             initial_lr = self.lr_history[0][1]
+             final_lr = self.lr_history[-1][1]
+             print(f" LR range: {initial_lr:.2e} → {final_lr:.2e}")