gptmed 0.3.5__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gptmed/__init__.py +37 -3
- gptmed/api.py +1 -0
- gptmed/configs/train_config.py +5 -0
- gptmed/model/__init__.py +2 -2
- gptmed/observability/__init__.py +43 -0
- gptmed/observability/base.py +369 -0
- gptmed/observability/callbacks.py +397 -0
- gptmed/observability/metrics_tracker.py +544 -0
- gptmed/services/training_service.py +161 -7
- gptmed/training/trainer.py +124 -10
- gptmed/utils/checkpoints.py +1 -1
- {gptmed-0.3.5.dist-info → gptmed-0.4.1.dist-info}/METADATA +180 -43
- {gptmed-0.3.5.dist-info → gptmed-0.4.1.dist-info}/RECORD +17 -13
- {gptmed-0.3.5.dist-info → gptmed-0.4.1.dist-info}/WHEEL +0 -0
- {gptmed-0.3.5.dist-info → gptmed-0.4.1.dist-info}/entry_points.txt +0 -0
- {gptmed-0.3.5.dist-info → gptmed-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {gptmed-0.3.5.dist-info → gptmed-0.4.1.dist-info}/top_level.txt +0 -0
gptmed/observability/callbacks.py (new file, +397 lines)
@@ -0,0 +1,397 @@
"""
|
|
2
|
+
Training Callbacks
|
|
3
|
+
|
|
4
|
+
PURPOSE:
|
|
5
|
+
Concrete observer implementations for common training needs.
|
|
6
|
+
|
|
7
|
+
CALLBACKS INCLUDED:
|
|
8
|
+
- ConsoleCallback: Print progress to console
|
|
9
|
+
- JSONLoggerCallback: Log to JSONL file
|
|
10
|
+
- EarlyStoppingCallback: Stop training if no improvement
|
|
11
|
+
- (Future) TensorBoardCallback: Log to TensorBoard
|
|
12
|
+
- (Future) WandBCallback: Log to Weights & Biases
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import time
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import Dict, Any, Optional
|
|
19
|
+
|
|
20
|
+
from gptmed.observability.base import (
|
|
21
|
+
TrainingObserver,
|
|
22
|
+
StepMetrics,
|
|
23
|
+
ValidationMetrics,
|
|
24
|
+
GradientMetrics,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ConsoleCallback(TrainingObserver):
|
|
29
|
+
"""
|
|
30
|
+
Prints training progress to console.
|
|
31
|
+
|
|
32
|
+
Features:
|
|
33
|
+
- Colored output for warnings
|
|
34
|
+
- Progress bar style step counter
|
|
35
|
+
- Issue detection (NaN, high loss, etc.)
|
|
36
|
+
|
|
37
|
+
Example:
|
|
38
|
+
>>> trainer.add_observer(ConsoleCallback(log_interval=100))
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
def __init__(
|
|
42
|
+
self,
|
|
43
|
+
log_interval: int = 100,
|
|
44
|
+
show_progress_bar: bool = True,
|
|
45
|
+
name: str = "ConsoleCallback",
|
|
46
|
+
):
|
|
47
|
+
"""
|
|
48
|
+
Initialize ConsoleCallback.
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
log_interval: Print every N steps
|
|
52
|
+
show_progress_bar: Whether to show progress indicators
|
|
53
|
+
name: Callback name
|
|
54
|
+
"""
|
|
55
|
+
super().__init__(name=name)
|
|
56
|
+
self.log_interval = log_interval
|
|
57
|
+
self.show_progress_bar = show_progress_bar
|
|
58
|
+
self.total_steps: int = 0
|
|
59
|
+
self.start_time: Optional[float] = None
|
|
60
|
+
|
|
61
|
+
def on_train_start(self, config: Dict[str, Any]) -> None:
|
|
62
|
+
"""Called when training begins."""
|
|
63
|
+
self.start_time = time.time()
|
|
64
|
+
self.total_steps = config.get('max_steps', 0) or config.get('total_steps', 0)
|
|
65
|
+
|
|
66
|
+
print("\n" + "=" * 70)
|
|
67
|
+
print("🚀 Training Started")
|
|
68
|
+
print("=" * 70)
|
|
69
|
+
print(f" Model: {config.get('model_size', 'unknown')}")
|
|
70
|
+
print(f" Device: {config.get('device', 'unknown')}")
|
|
71
|
+
print(f" Batch size: {config.get('batch_size', 'unknown')}")
|
|
72
|
+
print(f" Learning rate: {config.get('learning_rate', 'unknown')}")
|
|
73
|
+
print(f" Total steps: {self.total_steps}")
|
|
74
|
+
print("=" * 70 + "\n")
|
|
75
|
+
|
|
76
|
+
def on_step(self, metrics: StepMetrics) -> None:
|
|
77
|
+
"""Called after each training step."""
|
|
78
|
+
if metrics.step % self.log_interval != 0:
|
|
79
|
+
return
|
|
80
|
+
|
|
81
|
+
# Calculate progress
|
|
82
|
+
progress = ""
|
|
83
|
+
if self.total_steps > 0:
|
|
84
|
+
pct = (metrics.step / self.total_steps) * 100
|
|
85
|
+
progress = f"[{pct:5.1f}%] "
|
|
86
|
+
|
|
87
|
+
# Calculate speed
|
|
88
|
+
elapsed = time.time() - self.start_time if self.start_time else 0
|
|
89
|
+
steps_per_sec = metrics.step / elapsed if elapsed > 0 else 0
|
|
90
|
+
|
|
91
|
+
# Build message
|
|
92
|
+
msg = f"{progress}Step {metrics.step:6d} | "
|
|
93
|
+
msg += f"Loss: {metrics.loss:.4f} | "
|
|
94
|
+
msg += f"PPL: {metrics.perplexity:.2f} | "
|
|
95
|
+
msg += f"LR: {metrics.learning_rate:.2e} | "
|
|
96
|
+
msg += f"Grad: {metrics.grad_norm:.3f} | "
|
|
97
|
+
msg += f"{steps_per_sec:.1f} steps/s"
|
|
98
|
+
|
|
99
|
+
# Check for issues
|
|
100
|
+
warnings = []
|
|
101
|
+
if metrics.loss != metrics.loss: # NaN check
|
|
102
|
+
warnings.append("🔥 NaN LOSS!")
|
|
103
|
+
elif metrics.loss > 100:
|
|
104
|
+
warnings.append("⚠️ High loss")
|
|
105
|
+
|
|
106
|
+
if metrics.grad_norm > 10:
|
|
107
|
+
warnings.append("⚠️ Large grads")
|
|
108
|
+
|
|
109
|
+
if warnings:
|
|
110
|
+
msg += " | " + " ".join(warnings)
|
|
111
|
+
|
|
112
|
+
print(msg)
|
|
113
|
+
|
|
114
|
+
def on_validation(self, metrics: ValidationMetrics) -> None:
|
|
115
|
+
"""Called after validation."""
|
|
116
|
+
print(f"\n{'─' * 50}")
|
|
117
|
+
print(f"📊 Validation @ Step {metrics.step}")
|
|
118
|
+
print(f" Val Loss: {metrics.val_loss:.4f}")
|
|
119
|
+
print(f" Val PPL: {metrics.val_perplexity:.2f}")
|
|
120
|
+
print(f"{'─' * 50}\n")
|
|
121
|
+
|
|
122
|
+
def on_train_end(self, final_metrics: Dict[str, Any]) -> None:
|
|
123
|
+
"""Called when training completes."""
|
|
124
|
+
elapsed = time.time() - self.start_time if self.start_time else 0
|
|
125
|
+
|
|
126
|
+
print("\n" + "=" * 70)
|
|
127
|
+
print("✅ Training Completed")
|
|
128
|
+
print("=" * 70)
|
|
129
|
+
print(f" Total time: {elapsed/60:.2f} minutes")
|
|
130
|
+
print(f" Best val loss: {final_metrics.get('best_val_loss', 'N/A')}")
|
|
131
|
+
print(f" Best checkpoint: {final_metrics.get('best_checkpoint', 'N/A')}")
|
|
132
|
+
print("=" * 70 + "\n")
|
|
133
|
+
|
|
134
|
+
def on_epoch_start(self, epoch: int) -> None:
|
|
135
|
+
"""Called at start of each epoch."""
|
|
136
|
+
print(f"\n📅 Epoch {epoch + 1} starting...")
|
|
137
|
+
|
|
138
|
+
def on_epoch_end(self, epoch: int, metrics: Dict[str, float]) -> None:
|
|
139
|
+
"""Called at end of each epoch."""
|
|
140
|
+
print(f"📅 Epoch {epoch + 1} completed")
|
|
141
|
+
if 'train_loss' in metrics:
|
|
142
|
+
print(f" Avg train loss: {metrics['train_loss']:.4f}")
|
|
143
|
+
if 'val_loss' in metrics:
|
|
144
|
+
print(f" Val loss: {metrics['val_loss']:.4f}")
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class JSONLoggerCallback(TrainingObserver):
|
|
148
|
+
"""
|
|
149
|
+
Logs metrics to JSONL file.
|
|
150
|
+
|
|
151
|
+
Format: One JSON object per line (JSONL)
|
|
152
|
+
Easy to parse and analyze with Python/pandas.
|
|
153
|
+
|
|
154
|
+
Example:
|
|
155
|
+
>>> callback = JSONLoggerCallback(log_dir='logs', experiment_name='exp1')
|
|
156
|
+
>>> trainer.add_observer(callback)
|
|
157
|
+
"""
|
|
158
|
+
|
|
159
|
+
def __init__(
|
|
160
|
+
self,
|
|
161
|
+
log_dir: str = "logs",
|
|
162
|
+
experiment_name: str = "training",
|
|
163
|
+
log_interval: int = 1,
|
|
164
|
+
name: str = "JSONLoggerCallback",
|
|
165
|
+
):
|
|
166
|
+
"""
|
|
167
|
+
Initialize JSONLoggerCallback.
|
|
168
|
+
|
|
169
|
+
Args:
|
|
170
|
+
log_dir: Directory for log files
|
|
171
|
+
experiment_name: Name for log files
|
|
172
|
+
log_interval: Log every N steps
|
|
173
|
+
name: Callback name
|
|
174
|
+
"""
|
|
175
|
+
super().__init__(name=name)
|
|
176
|
+
|
|
177
|
+
self.log_dir = Path(log_dir)
|
|
178
|
+
self.log_dir.mkdir(parents=True, exist_ok=True)
|
|
179
|
+
|
|
180
|
+
self.experiment_name = experiment_name
|
|
181
|
+
self.log_interval = log_interval
|
|
182
|
+
|
|
183
|
+
self.log_file = self.log_dir / f"{experiment_name}_log.jsonl"
|
|
184
|
+
self.start_time: Optional[float] = None
|
|
185
|
+
|
|
186
|
+
def on_train_start(self, config: Dict[str, Any]) -> None:
|
|
187
|
+
"""Called when training begins."""
|
|
188
|
+
self.start_time = time.time()
|
|
189
|
+
|
|
190
|
+
# Log config
|
|
191
|
+
self._write_log({
|
|
192
|
+
"event": "train_start",
|
|
193
|
+
"timestamp": 0,
|
|
194
|
+
"config": config,
|
|
195
|
+
})
|
|
196
|
+
|
|
197
|
+
def on_step(self, metrics: StepMetrics) -> None:
|
|
198
|
+
"""Called after each training step."""
|
|
199
|
+
if metrics.step % self.log_interval != 0:
|
|
200
|
+
return
|
|
201
|
+
|
|
202
|
+
timestamp = time.time() - self.start_time if self.start_time else 0
|
|
203
|
+
|
|
204
|
+
self._write_log({
|
|
205
|
+
"event": "step",
|
|
206
|
+
"timestamp": timestamp,
|
|
207
|
+
**metrics.to_dict(),
|
|
208
|
+
})
|
|
209
|
+
|
|
210
|
+
def on_validation(self, metrics: ValidationMetrics) -> None:
|
|
211
|
+
"""Called after validation."""
|
|
212
|
+
timestamp = time.time() - self.start_time if self.start_time else 0
|
|
213
|
+
|
|
214
|
+
self._write_log({
|
|
215
|
+
"event": "validation",
|
|
216
|
+
"timestamp": timestamp,
|
|
217
|
+
**metrics.to_dict(),
|
|
218
|
+
})
|
|
219
|
+
|
|
220
|
+
def on_train_end(self, final_metrics: Dict[str, Any]) -> None:
|
|
221
|
+
"""Called when training completes."""
|
|
222
|
+
timestamp = time.time() - self.start_time if self.start_time else 0
|
|
223
|
+
|
|
224
|
+
self._write_log({
|
|
225
|
+
"event": "train_end",
|
|
226
|
+
"timestamp": timestamp,
|
|
227
|
+
**final_metrics,
|
|
228
|
+
})
|
|
229
|
+
|
|
230
|
+
def on_checkpoint(self, step: int, checkpoint_path: str) -> None:
|
|
231
|
+
"""Called when a checkpoint is saved."""
|
|
232
|
+
timestamp = time.time() - self.start_time if self.start_time else 0
|
|
233
|
+
|
|
234
|
+
self._write_log({
|
|
235
|
+
"event": "checkpoint",
|
|
236
|
+
"timestamp": timestamp,
|
|
237
|
+
"step": step,
|
|
238
|
+
"checkpoint_path": checkpoint_path,
|
|
239
|
+
})
|
|
240
|
+
|
|
241
|
+
def _write_log(self, data: Dict[str, Any]) -> None:
|
|
242
|
+
"""Write log entry to file."""
|
|
243
|
+
with open(self.log_file, 'a') as f:
|
|
244
|
+
f.write(json.dumps(data, default=str) + '\n')
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
class EarlyStoppingCallback(TrainingObserver):
|
|
248
|
+
"""
|
|
249
|
+
Stops training if validation loss doesn't improve.
|
|
250
|
+
|
|
251
|
+
Features:
|
|
252
|
+
- Patience: Number of validations to wait
|
|
253
|
+
- Min delta: Minimum improvement to count as progress
|
|
254
|
+
- Restore best: Flag to restore best weights (handled by trainer)
|
|
255
|
+
|
|
256
|
+
Example:
|
|
257
|
+
>>> callback = EarlyStoppingCallback(patience=5, min_delta=0.01)
|
|
258
|
+
>>> trainer.add_observer(callback)
|
|
259
|
+
>>> # During training, check: callback.should_stop
|
|
260
|
+
"""
|
|
261
|
+
|
|
262
|
+
def __init__(
|
|
263
|
+
self,
|
|
264
|
+
patience: int = 5,
|
|
265
|
+
min_delta: float = 0.0,
|
|
266
|
+
name: str = "EarlyStoppingCallback",
|
|
267
|
+
):
|
|
268
|
+
"""
|
|
269
|
+
Initialize EarlyStoppingCallback.
|
|
270
|
+
|
|
271
|
+
Args:
|
|
272
|
+
patience: Number of validations without improvement before stopping
|
|
273
|
+
min_delta: Minimum change to qualify as improvement
|
|
274
|
+
name: Callback name
|
|
275
|
+
"""
|
|
276
|
+
super().__init__(name=name)
|
|
277
|
+
|
|
278
|
+
self.patience = patience
|
|
279
|
+
self.min_delta = min_delta
|
|
280
|
+
|
|
281
|
+
self.best_loss: float = float('inf')
|
|
282
|
+
self.best_step: int = 0
|
|
283
|
+
self.wait_count: int = 0
|
|
284
|
+
self.should_stop: bool = False
|
|
285
|
+
|
|
286
|
+
def on_train_start(self, config: Dict[str, Any]) -> None:
|
|
287
|
+
"""Reset state at training start."""
|
|
288
|
+
self.best_loss = float('inf')
|
|
289
|
+
self.best_step = 0
|
|
290
|
+
self.wait_count = 0
|
|
291
|
+
self.should_stop = False
|
|
292
|
+
|
|
293
|
+
def on_step(self, metrics: StepMetrics) -> None:
|
|
294
|
+
"""Not used - we check on validation."""
|
|
295
|
+
pass
|
|
296
|
+
|
|
297
|
+
def on_validation(self, metrics: ValidationMetrics) -> None:
|
|
298
|
+
"""Check if we should stop training."""
|
|
299
|
+
current_loss = metrics.val_loss
|
|
300
|
+
|
|
301
|
+
if current_loss < self.best_loss - self.min_delta:
|
|
302
|
+
# Improvement found
|
|
303
|
+
self.best_loss = current_loss
|
|
304
|
+
self.best_step = metrics.step
|
|
305
|
+
self.wait_count = 0
|
|
306
|
+
else:
|
|
307
|
+
# No improvement
|
|
308
|
+
self.wait_count += 1
|
|
309
|
+
|
|
310
|
+
if self.wait_count >= self.patience:
|
|
311
|
+
self.should_stop = True
|
|
312
|
+
print(f"\n⏹️ Early stopping triggered!")
|
|
313
|
+
print(f" No improvement for {self.patience} validations")
|
|
314
|
+
print(f" Best val loss: {self.best_loss:.4f} at step {self.best_step}")
|
|
315
|
+
|
|
316
|
+
def on_train_end(self, final_metrics: Dict[str, Any]) -> None:
|
|
317
|
+
"""Report final state."""
|
|
318
|
+
if self.should_stop:
|
|
319
|
+
print(f" Training stopped early at step {final_metrics.get('step', 'unknown')}")
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
class LRSchedulerCallback(TrainingObserver):
|
|
323
|
+
"""
|
|
324
|
+
Monitors and can adjust learning rate based on training progress.
|
|
325
|
+
|
|
326
|
+
Features:
|
|
327
|
+
- Reduce LR on plateau
|
|
328
|
+
- Warmup monitoring
|
|
329
|
+
- LR range test (experimental)
|
|
330
|
+
"""
|
|
331
|
+
|
|
332
|
+
def __init__(
|
|
333
|
+
self,
|
|
334
|
+
mode: str = 'monitor', # 'monitor' or 'plateau'
|
|
335
|
+
factor: float = 0.5,
|
|
336
|
+
patience: int = 3,
|
|
337
|
+
min_lr: float = 1e-7,
|
|
338
|
+
name: str = "LRSchedulerCallback",
|
|
339
|
+
):
|
|
340
|
+
"""
|
|
341
|
+
Initialize LRSchedulerCallback.
|
|
342
|
+
|
|
343
|
+
Args:
|
|
344
|
+
mode: 'monitor' (just watch) or 'plateau' (reduce on plateau)
|
|
345
|
+
factor: Factor to reduce LR by
|
|
346
|
+
patience: Validations without improvement before reducing
|
|
347
|
+
min_lr: Minimum learning rate
|
|
348
|
+
name: Callback name
|
|
349
|
+
"""
|
|
350
|
+
super().__init__(name=name)
|
|
351
|
+
|
|
352
|
+
self.mode = mode
|
|
353
|
+
self.factor = factor
|
|
354
|
+
self.patience = patience
|
|
355
|
+
self.min_lr = min_lr
|
|
356
|
+
|
|
357
|
+
self.best_loss: float = float('inf')
|
|
358
|
+
self.wait_count: int = 0
|
|
359
|
+
self.lr_history: list = []
|
|
360
|
+
self.suggested_lr: Optional[float] = None
|
|
361
|
+
|
|
362
|
+
def on_train_start(self, config: Dict[str, Any]) -> None:
|
|
363
|
+
"""Reset state."""
|
|
364
|
+
self.best_loss = float('inf')
|
|
365
|
+
self.wait_count = 0
|
|
366
|
+
self.lr_history = []
|
|
367
|
+
self.suggested_lr = None
|
|
368
|
+
|
|
369
|
+
def on_step(self, metrics: StepMetrics) -> None:
|
|
370
|
+
"""Track learning rate."""
|
|
371
|
+
self.lr_history.append((metrics.step, metrics.learning_rate))
|
|
372
|
+
|
|
373
|
+
def on_validation(self, metrics: ValidationMetrics) -> None:
|
|
374
|
+
"""Check for plateau."""
|
|
375
|
+
if self.mode != 'plateau':
|
|
376
|
+
return
|
|
377
|
+
|
|
378
|
+
if metrics.val_loss < self.best_loss:
|
|
379
|
+
self.best_loss = metrics.val_loss
|
|
380
|
+
self.wait_count = 0
|
|
381
|
+
else:
|
|
382
|
+
self.wait_count += 1
|
|
383
|
+
|
|
384
|
+
if self.wait_count >= self.patience:
|
|
385
|
+
if self.lr_history:
|
|
386
|
+
current_lr = self.lr_history[-1][1]
|
|
387
|
+
new_lr = max(current_lr * self.factor, self.min_lr)
|
|
388
|
+
self.suggested_lr = new_lr
|
|
389
|
+
print(f"\n📉 Suggest reducing LR: {current_lr:.2e} → {new_lr:.2e}")
|
|
390
|
+
self.wait_count = 0
|
|
391
|
+
|
|
392
|
+
def on_train_end(self, final_metrics: Dict[str, Any]) -> None:
|
|
393
|
+
"""Report LR summary."""
|
|
394
|
+
if self.lr_history:
|
|
395
|
+
initial_lr = self.lr_history[0][1]
|
|
396
|
+
final_lr = self.lr_history[-1][1]
|
|
397
|
+
print(f" LR range: {initial_lr:.2e} → {final_lr:.2e}")
|
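
Together with the trainer and training_service changes in this release, these callbacks are meant to be attached to a training run as observers. The sketch below shows one way they might be wired up; only the callback constructors and the trainer.add_observer(...) call are taken from the docstrings in this diff, while the attach_observability helper and the trainer object it receives are illustrative placeholders, not part of the published gptmed API.

# Minimal sketch, assuming a trainer object that exposes add_observer() as in
# the docstrings above; the helper name and defaults are hypothetical.
from gptmed.observability.callbacks import (
    ConsoleCallback,
    JSONLoggerCallback,
    EarlyStoppingCallback,
)

def attach_observability(trainer, log_dir="logs", experiment_name="exp1"):
    """Attach the 0.4.1 observability callbacks to a trainer with add_observer()."""
    callbacks = [
        ConsoleCallback(log_interval=100),          # console progress every 100 steps
        JSONLoggerCallback(log_dir=log_dir,
                           experiment_name=experiment_name),  # JSONL metrics log
        EarlyStoppingCallback(patience=5, min_delta=0.01),    # sets should_stop on plateau
    ]
    for cb in callbacks:
        trainer.add_observer(cb)
    return callbacks

As the JSONLoggerCallback docstring notes, the resulting JSONL file (e.g. logs/exp1_log.jsonl) is straightforward to load for analysis, for instance with pandas.read_json("logs/exp1_log.jsonl", lines=True).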