gptmed 0.3.5__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- gptmed/__init__.py +37 -3
- gptmed/model/__init__.py +2 -2
- gptmed/observability/__init__.py +43 -0
- gptmed/observability/base.py +369 -0
- gptmed/observability/callbacks.py +397 -0
- gptmed/observability/metrics_tracker.py +544 -0
- gptmed/services/training_service.py +161 -7
- gptmed/training/trainer.py +124 -10
- gptmed/utils/checkpoints.py +1 -1
- {gptmed-0.3.5.dist-info → gptmed-0.4.0.dist-info}/METADATA +180 -43
- {gptmed-0.3.5.dist-info → gptmed-0.4.0.dist-info}/RECORD +15 -11
- {gptmed-0.3.5.dist-info → gptmed-0.4.0.dist-info}/WHEEL +0 -0
- {gptmed-0.3.5.dist-info → gptmed-0.4.0.dist-info}/entry_points.txt +0 -0
- {gptmed-0.3.5.dist-info → gptmed-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {gptmed-0.3.5.dist-info → gptmed-0.4.0.dist-info}/top_level.txt +0 -0
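For orientation, here is a minimal sketch of how the new 0.4.0 observer wiring fits together, using only names visible in the hunks below (the `gptmed.observability` exports, `MetricsTracker`'s constructor keywords, and `Trainer`'s new `observers` parameter). The `val_loader` keyword and the surrounding model/optimizer setup are assumptions for illustration, not confirmed by this diff.

```python
# Hedged sketch of the 0.4.0 observer wiring; only the names noted above are
# taken from this diff, the rest is illustrative glue.
from typing import List

from gptmed.observability import MetricsTracker, TrainingObserver
from gptmed.training.trainer import Trainer


def build_default_observers(log_dir: str, log_interval: int) -> List[TrainingObserver]:
    # Mirrors TrainingService._create_default_observers in the diff below.
    return [
        MetricsTracker(
            log_dir=log_dir,
            experiment_name="gptmed_training",
            moving_avg_window=100,
            log_interval=log_interval,
            verbose=True,
        )
    ]


def make_trainer(model, train_loader, val_loader, optimizer, train_config, device: str = "cuda") -> Trainer:
    # `val_loader` keyword is assumed; model/train_loader/optimizer/config/device/observers
    # all appear in the Trainer hunks below.
    return Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        config=train_config,
        device=device,
        observers=build_default_observers(train_config.log_dir, train_config.log_interval),
    )
```

After training, the same `MetricsTracker` instance exposes `export_to_csv()`, `export_to_json()`, `plot_loss_curves()`, and `detect_issues()`, which `TrainingService` now calls from `_generate_observability_reports`.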
gptmed/services/training_service.py
CHANGED

@@ -26,7 +26,7 @@ import torch
 import random
 import numpy as np
 from pathlib import Path
-from typing import Dict, Any, Optional
+from typing import Dict, Any, Optional, List

 from gptmed.services.device_manager import DeviceManager
 from gptmed.model.architecture import GPTTransformer

@@ -35,6 +35,13 @@ from gptmed.configs.train_config import TrainingConfig
 from gptmed.training.dataset import create_dataloaders
 from gptmed.training.trainer import Trainer

+# Observability imports
+from gptmed.observability import (
+    TrainingObserver,
+    MetricsTracker,
+    ConsoleCallback,
+)
+

 class TrainingService:
     """

@@ -162,7 +169,8 @@ class TrainingService:
         optimizer,
         train_config: TrainingConfig,
         device: str,
-        model_config_dict: dict
+        model_config_dict: dict,
+        observers: Optional[List[TrainingObserver]] = None,
     ) -> Dict[str, Any]:
         """
         Execute the training process.

@@ -175,13 +183,20 @@ class TrainingService:
             train_config: Training configuration
             device: Device to use
             model_config_dict: Model configuration as dictionary
+            observers: Optional list of TrainingObserver instances.
+                If None, default observers (MetricsTracker) will be used.

         Returns:
             Dictionary with training results
         """
-        #
+        # Set up default observers if none provided
+        if observers is None:
+            observers = self._create_default_observers(train_config)
+
+        # Create trainer with observers
         if self.verbose:
             print(f"\n🎯 Initializing trainer...")
+            print(f" Observers: {len(observers)} ({', '.join(o.name for o in observers)})")

         trainer = Trainer(
             model=model,

@@ -190,6 +205,7 @@ class TrainingService:
             optimizer=optimizer,
             config=train_config,
             device=device,
+            observers=observers,
         )

         # Resume if requested

@@ -209,9 +225,11 @@ class TrainingService:
             print("🚀 Starting Training!")
             print(f"{'='*60}\n")

+        interrupted = False
         try:
             trainer.train()
         except KeyboardInterrupt:
+            interrupted = True
             if self.verbose:
                 print("\n\n⏸️ Training interrupted by user")
                 print("💾 Saving checkpoint...")

@@ -227,29 +245,165 @@ class TrainingService:
         if self.verbose:
             print("✓ Checkpoint saved. Resume with resume_from in config.")

+        # Generate observability reports (on BOTH normal and abnormal exit)
+        self._generate_observability_reports(
+            observers=observers,
+            train_config=train_config,
+            trainer=trainer,
+            interrupted=interrupted,
+        )
+
         # Return results
-
+        final_checkpoint = Path(train_config.checkpoint_dir) / "final_model.pt"

         results = {
-            '
+            'final_checkpoint': str(final_checkpoint),
+            'best_checkpoint': str(final_checkpoint), # Alias for backward compatibility
             'final_val_loss': trainer.best_val_loss,
             'total_epochs': trainer.current_epoch,
+            'total_steps': trainer.global_step,
             'checkpoint_dir': train_config.checkpoint_dir,
             'log_dir': train_config.log_dir,
+            'interrupted': interrupted,
         }

+        # Get training issues from metrics tracker
+        metrics_tracker = self._get_metrics_tracker(observers)
+        if metrics_tracker:
+            results['training_issues'] = metrics_tracker.detect_issues()
+
         if self.verbose:
+            status = "⏸️ Training Interrupted" if interrupted else "✅ Training Complete!"
             print(f"\n{'='*60}")
-            print(
+            print(status)
             print(f"{'='*60}")
             print(f"\n📁 Results:")
-            print(f"
+            print(f" Final checkpoint: {results['final_checkpoint']}")
             print(f" Best val loss: {results['final_val_loss']:.4f}")
+            print(f" Total steps: {results['total_steps']}")
             print(f" Total epochs: {results['total_epochs']}")
             print(f" Logs: {results['log_dir']}")

         return results

+    def _generate_observability_reports(
+        self,
+        observers: List[TrainingObserver],
+        train_config: TrainingConfig,
+        trainer,
+        interrupted: bool = False,
+    ) -> None:
+        """
+        Generate observability reports from metrics tracker.
+
+        Called on both normal completion and abnormal exit (Ctrl+C).
+
+        Args:
+            observers: List of training observers
+            train_config: Training configuration
+            trainer: Trainer instance
+            interrupted: Whether training was interrupted
+        """
+        metrics_tracker = self._get_metrics_tracker(observers)
+
+        if not metrics_tracker:
+            if self.verbose:
+                print("\n⚠️ No MetricsTracker found - skipping observability reports")
+            return
+
+        if self.verbose:
+            print(f"\n{'='*60}")
+            print("📊 Generating Observability Reports")
+            print(f"{'='*60}")
+
+        try:
+            # Export metrics to CSV
+            csv_path = metrics_tracker.export_to_csv()
+            if self.verbose:
+                print(f" ✓ CSV exported: {csv_path}")
+        except Exception as e:
+            if self.verbose:
+                print(f" ✗ CSV export failed: {e}")
+
+        try:
+            # Export metrics to JSON
+            json_path = metrics_tracker.export_to_json()
+            if self.verbose:
+                print(f" ✓ JSON exported: {json_path}")
+        except Exception as e:
+            if self.verbose:
+                print(f" ✗ JSON export failed: {e}")
+
+        try:
+            # Generate loss curve plots
+            plot_path = metrics_tracker.plot_loss_curves()
+            if plot_path and self.verbose:
+                print(f" ✓ Loss curves plotted: {plot_path}")
+        except ImportError:
+            if self.verbose:
+                print(f" ⚠️ Plotting skipped (matplotlib not installed)")
+        except Exception as e:
+            if self.verbose:
+                print(f" ✗ Plotting failed: {e}")
+
+        # Training health check
+        issues = metrics_tracker.detect_issues()
+        if self.verbose:
+            print(f"\n📋 Training Health Check:")
+            for issue in issues:
+                print(f" {issue}")
+
+        # Add interrupted notice if applicable
+        if interrupted and self.verbose:
+            print(f"\n⚠️ Note: Training was interrupted at step {trainer.global_step}")
+            print(f" Reports reflect partial training data only.")
+
+    def _create_default_observers(self, train_config: TrainingConfig) -> List[TrainingObserver]:
+        """
+        Create default observers for training.
+
+        Args:
+            train_config: Training configuration
+
+        Returns:
+            List of default TrainingObserver instances
+        """
+        observers = []
+
+        # MetricsTracker - comprehensive metrics logging
+        metrics_tracker = MetricsTracker(
+            log_dir=train_config.log_dir,
+            experiment_name="gptmed_training",
+            moving_avg_window=100,
+            log_interval=train_config.log_interval,
+            verbose=self.verbose,
+        )
+        observers.append(metrics_tracker)
+
+        # Note: ConsoleCallback is optional since Trainer already has console output
+        # Uncomment if you want additional formatted console output:
+        # console_callback = ConsoleCallback(log_interval=train_config.log_interval)
+        # observers.append(console_callback)
+
+        return observers
+
+    def _get_metrics_tracker(self, observers: List[TrainingObserver]) -> Optional[MetricsTracker]:
+        """
+        Get MetricsTracker from observers list if present.
+
+        Args:
+            observers: List of observers
+
+        Returns:
+            MetricsTracker instance or None
+        """
+        for obs in observers:
+            if isinstance(obs, MetricsTracker):
+                return obs
+        return None
+
+        return results
+
     def train(
         self,
         model_size: str,
gptmed/training/trainer.py
CHANGED

@@ -49,7 +49,7 @@ import torch.nn as nn
 from torch.utils.data import DataLoader
 import time
 from pathlib import Path
-from typing import Optional
+from typing import Optional, List

 from gptmed.model.architecture import GPTTransformer
 from gptmed.training.utils import (

@@ -62,6 +62,15 @@ from gptmed.training.utils import (
 from gptmed.utils.logging import MetricsLogger, log_training_step, log_validation
 from gptmed.utils.checkpoints import CheckpointManager

+# New observability imports
+from gptmed.observability.base import (
+    TrainingObserver,
+    ObserverManager,
+    StepMetrics,
+    ValidationMetrics,
+    GradientMetrics,
+)
+

 class Trainer:
     """

@@ -83,6 +92,7 @@ class Trainer:
         optimizer: torch.optim.Optimizer,
         config, # TrainingConfig
         device: str = "cuda",
+        observers: List[TrainingObserver] = None,
     ):
         """
         Args:

@@ -92,6 +102,7 @@ class Trainer:
             optimizer: Optimizer (e.g., AdamW)
             config: TrainingConfig object
             device: Device to train on
+            observers: List of TrainingObserver instances for monitoring
         """
         self.model = model.to(device)
         self.train_loader = train_loader

@@ -100,7 +111,13 @@ class Trainer:
         self.config = config
         self.device = device

-        # Initialize
+        # Initialize observability
+        self.observer_manager = ObserverManager()
+        if observers:
+            for obs in observers:
+                self.observer_manager.add(obs)
+
+        # Initialize utilities (keep for backward compatibility)
         self.logger = MetricsLogger(log_dir=config.log_dir, experiment_name="gpt_training")

         self.checkpoint_manager = CheckpointManager(

@@ -124,17 +141,32 @@ class Trainer:
         print(f" Total steps: {self.total_steps}")
         print(f" Steps per epoch: {steps_per_epoch}")
         print(f" Num epochs: {config.num_epochs}")
+        print(f" Observers: {len(self.observer_manager.observers)}")
+
+    def add_observer(self, observer: TrainingObserver) -> None:
+        """
+        Add an observer for training monitoring.
+
+        Args:
+            observer: TrainingObserver instance
+        """
+        self.observer_manager.add(observer)
+        print(f" Added observer: {observer.name}")

-    def train_step(self, batch: tuple) -> dict:
+    def train_step(self, batch: tuple, step: int = 0, lr: float = 0.0) -> dict:
         """
         Single training step.

         Args:
             batch: (input_ids, target_ids) tuple
+            step: Current global step (for observer metrics)
+            lr: Current learning rate (for observer metrics)

         Returns:
             Dictionary with step metrics
         """
+        step_start_time = time.time()
+
         # Move batch to device
         input_ids, target_ids = batch
         input_ids = input_ids.to(self.device)

@@ -163,14 +195,34 @@ class Trainer:
         # Optimizer step
         self.optimizer.step()

-        #
-
+        # Calculate tokens per second
+        step_time = time.time() - step_start_time
+        tokens_per_sec = (batch_size * seq_len) / step_time if step_time > 0 else 0
+
+        # Create metrics dict (for backward compatibility)
+        metrics_dict = {
             "loss": loss.item(),
             "grad_norm": grad_norm,
             "batch_size": batch_size,
             "seq_len": seq_len,
+            "tokens_per_sec": tokens_per_sec,
         }

+        # Notify observers with StepMetrics
+        step_metrics = StepMetrics(
+            step=step,
+            loss=loss.item(),
+            learning_rate=lr,
+            grad_norm=grad_norm,
+            batch_size=batch_size,
+            seq_len=seq_len,
+            tokens_per_sec=tokens_per_sec,
+        )
+        self.observer_manager.notify_step(step_metrics)
+
+        # Return metrics
+        return metrics_dict
+
     def evaluate(self) -> dict:
         """
         Evaluate on validation set.

@@ -188,6 +240,14 @@ class Trainer:

         log_validation(self.global_step, val_loss, val_perplexity)

+        # Notify observers
+        val_metrics = ValidationMetrics(
+            step=self.global_step,
+            val_loss=val_loss,
+            val_perplexity=val_perplexity,
+        )
+        self.observer_manager.notify_validation(val_metrics)
+
         return {"val_loss": val_loss, "val_perplexity": val_perplexity}

     def train(self):

@@ -200,17 +260,37 @@ class Trainer:
         print("Starting Training")
         print("=" * 60)

+        # Notify observers of training start
+        train_config = {
+            "model_size": getattr(self.model.config, 'model_size', 'unknown'),
+            "device": self.device,
+            "batch_size": self.config.batch_size,
+            "learning_rate": self.config.learning_rate,
+            "num_epochs": self.config.num_epochs,
+            "max_steps": self.config.max_steps,
+            "total_steps": self.total_steps,
+            "warmup_steps": self.config.warmup_steps,
+            "grad_clip": self.config.grad_clip,
+            "weight_decay": self.config.weight_decay,
+        }
+        self.observer_manager.notify_train_start(train_config)
+
         self.model.train()

         # Training loop
         for epoch in range(self.config.num_epochs):
             self.current_epoch = epoch

+            # Notify observers of epoch start
+            self.observer_manager.notify_epoch_start(epoch)
+
             print(f"\n{'='*60}")
             print(f"Epoch {epoch + 1}/{self.config.num_epochs}")
             print(f"{'='*60}")

             epoch_start_time = time.time()
+            epoch_loss_sum = 0.0
+            epoch_steps = 0

             for batch_idx, batch in enumerate(self.train_loader):
                 step_start_time = time.time()

@@ -226,8 +306,12 @@ class Trainer:
                 )
                 set_learning_rate(self.optimizer, lr)

-                # Training step
-                metrics = self.train_step(batch)
+                # Training step (now with step and lr for observers)
+                metrics = self.train_step(batch, step=self.global_step, lr=lr)
+
+                # Track epoch loss
+                epoch_loss_sum += metrics["loss"]
+                epoch_steps += 1

                 # Calculate tokens per second
                 step_time = time.time() - step_start_time

@@ -243,7 +327,7 @@ class Trainer:
                     tokens_per_sec=tokens_per_sec,
                 )

-                # Log metrics
+                # Log metrics (legacy logger)
                 self.logger.log(
                     self.global_step,
                     {

@@ -269,7 +353,7 @@ class Trainer:
                     is_best = False

                     # Save checkpoint
-                    self.checkpoint_manager.save_checkpoint(
+                    checkpoint_path = self.checkpoint_manager.save_checkpoint(
                         model=self.model,
                         optimizer=self.optimizer,
                         step=self.global_step,

@@ -280,8 +364,19 @@ class Trainer:
                         is_best=is_best,
                     )

+                    # Notify observers of checkpoint
+                    if checkpoint_path:
+                        self.observer_manager.notify_checkpoint(self.global_step, str(checkpoint_path))
+
                     self.model.train() # Back to training mode

+                # Check for early stopping (if any observer requests it)
+                for obs in self.observer_manager.observers:
+                    if hasattr(obs, 'should_stop') and obs.should_stop:
+                        print(f"\nEarly stopping requested by {obs.name}")
+                        self._finish_training()
+                        return
+
                 # Save checkpoint periodically
                 if self.global_step % self.config.save_interval == 0 and self.global_step > 0:
                     self.checkpoint_manager.save_checkpoint(

@@ -299,17 +394,36 @@ class Trainer:
                 # Check if reached max steps
                 if self.config.max_steps > 0 and self.global_step >= self.config.max_steps:
                     print(f"\nReached max_steps ({self.config.max_steps}). Stopping training.")
+                    self._finish_training()
                     return

-            # End of epoch
+            # End of epoch - notify observers
             epoch_time = time.time() - epoch_start_time
+            epoch_avg_loss = epoch_loss_sum / epoch_steps if epoch_steps > 0 else 0
+            self.observer_manager.notify_epoch_end(epoch, {
+                "train_loss": epoch_avg_loss,
+                "epoch_time": epoch_time,
+            })
             print(f"\nEpoch {epoch + 1} completed in {epoch_time:.2f}s")

+        self._finish_training()
+
+    def _finish_training(self):
+        """Finalize training and notify observers."""
         print("\n" + "=" * 60)
         print("Training Complete!")
         print("=" * 60)
         print(f"Best validation loss: {self.best_val_loss:.4f}")

+        # Notify observers of training end
+        final_metrics = {
+            "best_val_loss": self.best_val_loss,
+            "total_steps": self.global_step,
+            "final_epoch": self.current_epoch,
+            "final_checkpoint": str(self.checkpoint_manager.checkpoint_dir / "final_model.pt"),
+        }
+        self.observer_manager.notify_train_end(final_metrics)
+
     def resume_from_checkpoint(self, checkpoint_path: Optional[Path] = None):
         """
         Resume training from a checkpoint.
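One detail worth calling out in the `Trainer.train()` hunk above: early stopping is duck-typed. The loop only looks for a truthy `should_stop` attribute and a `name` on each registered observer. The standalone sketch below reproduces just that check with a dummy observer; it does not import gptmed and does not claim to match the real `TrainingObserver` interface beyond those two attributes.

```python
# Standalone illustration of the duck-typed early-stop check added in Trainer.train().
# DummyObserver is hypothetical; the real base class lives in gptmed.observability.base.


class DummyObserver:
    name = "loss-plateau-watchdog"  # hypothetical observer name

    def __init__(self) -> None:
        self.should_stop = False


def should_stop_training(observers) -> bool:
    # Same shape as the check in the hunk above: any single observer can request a stop.
    for obs in observers:
        if hasattr(obs, "should_stop") and obs.should_stop:
            print(f"\nEarly stopping requested by {obs.name}")
            return True
    return False


if __name__ == "__main__":
    watchdog = DummyObserver()
    print(should_stop_training([watchdog]))  # False
    watchdog.should_stop = True
    print(should_stop_training([watchdog]))  # True
```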
gptmed/utils/checkpoints.py
CHANGED

@@ -108,7 +108,7 @@ class CheckpointManager:
         # Save as best if applicable
         if is_best or val_loss < self.best_val_loss:
             self.best_val_loss = val_loss
-            best_path = self.checkpoint_dir / "
+            best_path = self.checkpoint_dir / "final_model.pt"
             torch.save(checkpoint, best_path)
             print(f"Best model saved: {best_path} (val_loss: {val_loss:.4f})")