gptmed 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gptmed/__init__.py +37 -3
- gptmed/model/__init__.py +2 -2
- gptmed/observability/__init__.py +43 -0
- gptmed/observability/base.py +369 -0
- gptmed/observability/callbacks.py +397 -0
- gptmed/observability/metrics_tracker.py +544 -0
- gptmed/services/__init__.py +15 -0
- gptmed/services/device_manager.py +252 -0
- gptmed/services/training_service.py +489 -0
- gptmed/training/trainer.py +124 -10
- gptmed/utils/checkpoints.py +1 -1
- {gptmed-0.3.4.dist-info → gptmed-0.4.0.dist-info}/METADATA +180 -43
- {gptmed-0.3.4.dist-info → gptmed-0.4.0.dist-info}/RECORD +17 -10
- {gptmed-0.3.4.dist-info → gptmed-0.4.0.dist-info}/WHEEL +0 -0
- {gptmed-0.3.4.dist-info → gptmed-0.4.0.dist-info}/entry_points.txt +0 -0
- {gptmed-0.3.4.dist-info → gptmed-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {gptmed-0.3.4.dist-info → gptmed-0.4.0.dist-info}/top_level.txt +0 -0
gptmed/training/trainer.py
CHANGED
|
@@ -49,7 +49,7 @@ import torch.nn as nn
|
|
|
49
49
|
from torch.utils.data import DataLoader
|
|
50
50
|
import time
|
|
51
51
|
from pathlib import Path
|
|
52
|
-
from typing import Optional
|
|
52
|
+
from typing import Optional, List
|
|
53
53
|
|
|
54
54
|
from gptmed.model.architecture import GPTTransformer
|
|
55
55
|
from gptmed.training.utils import (
|
|
@@ -62,6 +62,15 @@ from gptmed.training.utils import (
|
|
|
62
62
|
from gptmed.utils.logging import MetricsLogger, log_training_step, log_validation
|
|
63
63
|
from gptmed.utils.checkpoints import CheckpointManager
|
|
64
64
|
|
|
65
|
+
# New observability imports
|
|
66
|
+
from gptmed.observability.base import (
|
|
67
|
+
TrainingObserver,
|
|
68
|
+
ObserverManager,
|
|
69
|
+
StepMetrics,
|
|
70
|
+
ValidationMetrics,
|
|
71
|
+
GradientMetrics,
|
|
72
|
+
)
|
|
73
|
+
|
|
65
74
|
|
|
66
75
|
class Trainer:
|
|
67
76
|
"""
|
|
@@ -83,6 +92,7 @@ class Trainer:
|
|
|
83
92
|
optimizer: torch.optim.Optimizer,
|
|
84
93
|
config, # TrainingConfig
|
|
85
94
|
device: str = "cuda",
|
|
95
|
+
observers: List[TrainingObserver] = None,
|
|
86
96
|
):
|
|
87
97
|
"""
|
|
88
98
|
Args:
|
|
@@ -92,6 +102,7 @@ class Trainer:
|
|
|
92
102
|
optimizer: Optimizer (e.g., AdamW)
|
|
93
103
|
config: TrainingConfig object
|
|
94
104
|
device: Device to train on
|
|
105
|
+
observers: List of TrainingObserver instances for monitoring
|
|
95
106
|
"""
|
|
96
107
|
self.model = model.to(device)
|
|
97
108
|
self.train_loader = train_loader
|
|
@@ -100,7 +111,13 @@ class Trainer:
|
|
|
100
111
|
self.config = config
|
|
101
112
|
self.device = device
|
|
102
113
|
|
|
103
|
-
# Initialize
|
|
114
|
+
# Initialize observability
|
|
115
|
+
self.observer_manager = ObserverManager()
|
|
116
|
+
if observers:
|
|
117
|
+
for obs in observers:
|
|
118
|
+
self.observer_manager.add(obs)
|
|
119
|
+
|
|
120
|
+
# Initialize utilities (keep for backward compatibility)
|
|
104
121
|
self.logger = MetricsLogger(log_dir=config.log_dir, experiment_name="gpt_training")
|
|
105
122
|
|
|
106
123
|
self.checkpoint_manager = CheckpointManager(
|
|
@@ -124,17 +141,32 @@ class Trainer:
|
|
|
124
141
|
print(f" Total steps: {self.total_steps}")
|
|
125
142
|
print(f" Steps per epoch: {steps_per_epoch}")
|
|
126
143
|
print(f" Num epochs: {config.num_epochs}")
|
|
144
|
+
print(f" Observers: {len(self.observer_manager.observers)}")
|
|
145
|
+
|
|
146
|
+
def add_observer(self, observer: TrainingObserver) -> None:
|
|
147
|
+
"""
|
|
148
|
+
Add an observer for training monitoring.
|
|
149
|
+
|
|
150
|
+
Args:
|
|
151
|
+
observer: TrainingObserver instance
|
|
152
|
+
"""
|
|
153
|
+
self.observer_manager.add(observer)
|
|
154
|
+
print(f" Added observer: {observer.name}")
|
|
127
155
|
|
|
128
|
-
def train_step(self, batch: tuple) -> dict:
|
|
156
|
+
def train_step(self, batch: tuple, step: int = 0, lr: float = 0.0) -> dict:
|
|
129
157
|
"""
|
|
130
158
|
Single training step.
|
|
131
159
|
|
|
132
160
|
Args:
|
|
133
161
|
batch: (input_ids, target_ids) tuple
|
|
162
|
+
step: Current global step (for observer metrics)
|
|
163
|
+
lr: Current learning rate (for observer metrics)
|
|
134
164
|
|
|
135
165
|
Returns:
|
|
136
166
|
Dictionary with step metrics
|
|
137
167
|
"""
|
|
168
|
+
step_start_time = time.time()
|
|
169
|
+
|
|
138
170
|
# Move batch to device
|
|
139
171
|
input_ids, target_ids = batch
|
|
140
172
|
input_ids = input_ids.to(self.device)
|
|
@@ -163,14 +195,34 @@ class Trainer:
|
|
|
163
195
|
# Optimizer step
|
|
164
196
|
self.optimizer.step()
|
|
165
197
|
|
|
166
|
-
#
|
|
167
|
-
|
|
198
|
+
# Calculate tokens per second
|
|
199
|
+
step_time = time.time() - step_start_time
|
|
200
|
+
tokens_per_sec = (batch_size * seq_len) / step_time if step_time > 0 else 0
|
|
201
|
+
|
|
202
|
+
# Create metrics dict (for backward compatibility)
|
|
203
|
+
metrics_dict = {
|
|
168
204
|
"loss": loss.item(),
|
|
169
205
|
"grad_norm": grad_norm,
|
|
170
206
|
"batch_size": batch_size,
|
|
171
207
|
"seq_len": seq_len,
|
|
208
|
+
"tokens_per_sec": tokens_per_sec,
|
|
172
209
|
}
|
|
173
210
|
|
|
211
|
+
# Notify observers with StepMetrics
|
|
212
|
+
step_metrics = StepMetrics(
|
|
213
|
+
step=step,
|
|
214
|
+
loss=loss.item(),
|
|
215
|
+
learning_rate=lr,
|
|
216
|
+
grad_norm=grad_norm,
|
|
217
|
+
batch_size=batch_size,
|
|
218
|
+
seq_len=seq_len,
|
|
219
|
+
tokens_per_sec=tokens_per_sec,
|
|
220
|
+
)
|
|
221
|
+
self.observer_manager.notify_step(step_metrics)
|
|
222
|
+
|
|
223
|
+
# Return metrics
|
|
224
|
+
return metrics_dict
|
|
225
|
+
|
|
174
226
|
def evaluate(self) -> dict:
|
|
175
227
|
"""
|
|
176
228
|
Evaluate on validation set.
|
|
@@ -188,6 +240,14 @@ class Trainer:
|
|
|
188
240
|
|
|
189
241
|
log_validation(self.global_step, val_loss, val_perplexity)
|
|
190
242
|
|
|
243
|
+
# Notify observers
|
|
244
|
+
val_metrics = ValidationMetrics(
|
|
245
|
+
step=self.global_step,
|
|
246
|
+
val_loss=val_loss,
|
|
247
|
+
val_perplexity=val_perplexity,
|
|
248
|
+
)
|
|
249
|
+
self.observer_manager.notify_validation(val_metrics)
|
|
250
|
+
|
|
191
251
|
return {"val_loss": val_loss, "val_perplexity": val_perplexity}
|
|
192
252
|
|
|
193
253
|
def train(self):
|
|
@@ -200,17 +260,37 @@ class Trainer:
|
|
|
200
260
|
print("Starting Training")
|
|
201
261
|
print("=" * 60)
|
|
202
262
|
|
|
263
|
+
# Notify observers of training start
|
|
264
|
+
train_config = {
|
|
265
|
+
"model_size": getattr(self.model.config, 'model_size', 'unknown'),
|
|
266
|
+
"device": self.device,
|
|
267
|
+
"batch_size": self.config.batch_size,
|
|
268
|
+
"learning_rate": self.config.learning_rate,
|
|
269
|
+
"num_epochs": self.config.num_epochs,
|
|
270
|
+
"max_steps": self.config.max_steps,
|
|
271
|
+
"total_steps": self.total_steps,
|
|
272
|
+
"warmup_steps": self.config.warmup_steps,
|
|
273
|
+
"grad_clip": self.config.grad_clip,
|
|
274
|
+
"weight_decay": self.config.weight_decay,
|
|
275
|
+
}
|
|
276
|
+
self.observer_manager.notify_train_start(train_config)
|
|
277
|
+
|
|
203
278
|
self.model.train()
|
|
204
279
|
|
|
205
280
|
# Training loop
|
|
206
281
|
for epoch in range(self.config.num_epochs):
|
|
207
282
|
self.current_epoch = epoch
|
|
208
283
|
|
|
284
|
+
# Notify observers of epoch start
|
|
285
|
+
self.observer_manager.notify_epoch_start(epoch)
|
|
286
|
+
|
|
209
287
|
print(f"\n{'='*60}")
|
|
210
288
|
print(f"Epoch {epoch + 1}/{self.config.num_epochs}")
|
|
211
289
|
print(f"{'='*60}")
|
|
212
290
|
|
|
213
291
|
epoch_start_time = time.time()
|
|
292
|
+
epoch_loss_sum = 0.0
|
|
293
|
+
epoch_steps = 0
|
|
214
294
|
|
|
215
295
|
for batch_idx, batch in enumerate(self.train_loader):
|
|
216
296
|
step_start_time = time.time()
|
|
@@ -226,8 +306,12 @@ class Trainer:
|
|
|
226
306
|
)
|
|
227
307
|
set_learning_rate(self.optimizer, lr)
|
|
228
308
|
|
|
229
|
-
# Training step
|
|
230
|
-
metrics = self.train_step(batch)
|
|
309
|
+
# Training step (now with step and lr for observers)
|
|
310
|
+
metrics = self.train_step(batch, step=self.global_step, lr=lr)
|
|
311
|
+
|
|
312
|
+
# Track epoch loss
|
|
313
|
+
epoch_loss_sum += metrics["loss"]
|
|
314
|
+
epoch_steps += 1
|
|
231
315
|
|
|
232
316
|
# Calculate tokens per second
|
|
233
317
|
step_time = time.time() - step_start_time
|
|
@@ -243,7 +327,7 @@ class Trainer:
|
|
|
243
327
|
tokens_per_sec=tokens_per_sec,
|
|
244
328
|
)
|
|
245
329
|
|
|
246
|
-
# Log metrics
|
|
330
|
+
# Log metrics (legacy logger)
|
|
247
331
|
self.logger.log(
|
|
248
332
|
self.global_step,
|
|
249
333
|
{
|
|
@@ -269,7 +353,7 @@ class Trainer:
|
|
|
269
353
|
is_best = False
|
|
270
354
|
|
|
271
355
|
# Save checkpoint
|
|
272
|
-
self.checkpoint_manager.save_checkpoint(
|
|
356
|
+
checkpoint_path = self.checkpoint_manager.save_checkpoint(
|
|
273
357
|
model=self.model,
|
|
274
358
|
optimizer=self.optimizer,
|
|
275
359
|
step=self.global_step,
|
|
@@ -280,8 +364,19 @@ class Trainer:
|
|
|
280
364
|
is_best=is_best,
|
|
281
365
|
)
|
|
282
366
|
|
|
367
|
+
# Notify observers of checkpoint
|
|
368
|
+
if checkpoint_path:
|
|
369
|
+
self.observer_manager.notify_checkpoint(self.global_step, str(checkpoint_path))
|
|
370
|
+
|
|
283
371
|
self.model.train() # Back to training mode
|
|
284
372
|
|
|
373
|
+
# Check for early stopping (if any observer requests it)
|
|
374
|
+
for obs in self.observer_manager.observers:
|
|
375
|
+
if hasattr(obs, 'should_stop') and obs.should_stop:
|
|
376
|
+
print(f"\nEarly stopping requested by {obs.name}")
|
|
377
|
+
self._finish_training()
|
|
378
|
+
return
|
|
379
|
+
|
|
285
380
|
# Save checkpoint periodically
|
|
286
381
|
if self.global_step % self.config.save_interval == 0 and self.global_step > 0:
|
|
287
382
|
self.checkpoint_manager.save_checkpoint(
|
|
@@ -299,17 +394,36 @@ class Trainer:
|
|
|
299
394
|
# Check if reached max steps
|
|
300
395
|
if self.config.max_steps > 0 and self.global_step >= self.config.max_steps:
|
|
301
396
|
print(f"\nReached max_steps ({self.config.max_steps}). Stopping training.")
|
|
397
|
+
self._finish_training()
|
|
302
398
|
return
|
|
303
399
|
|
|
304
|
-
# End of epoch
|
|
400
|
+
# End of epoch - notify observers
|
|
305
401
|
epoch_time = time.time() - epoch_start_time
|
|
402
|
+
epoch_avg_loss = epoch_loss_sum / epoch_steps if epoch_steps > 0 else 0
|
|
403
|
+
self.observer_manager.notify_epoch_end(epoch, {
|
|
404
|
+
"train_loss": epoch_avg_loss,
|
|
405
|
+
"epoch_time": epoch_time,
|
|
406
|
+
})
|
|
306
407
|
print(f"\nEpoch {epoch + 1} completed in {epoch_time:.2f}s")
|
|
307
408
|
|
|
409
|
+
self._finish_training()
|
|
410
|
+
|
|
411
|
+
def _finish_training(self):
|
|
412
|
+
"""Finalize training and notify observers."""
|
|
308
413
|
print("\n" + "=" * 60)
|
|
309
414
|
print("Training Complete!")
|
|
310
415
|
print("=" * 60)
|
|
311
416
|
print(f"Best validation loss: {self.best_val_loss:.4f}")
|
|
312
417
|
|
|
418
|
+
# Notify observers of training end
|
|
419
|
+
final_metrics = {
|
|
420
|
+
"best_val_loss": self.best_val_loss,
|
|
421
|
+
"total_steps": self.global_step,
|
|
422
|
+
"final_epoch": self.current_epoch,
|
|
423
|
+
"final_checkpoint": str(self.checkpoint_manager.checkpoint_dir / "final_model.pt"),
|
|
424
|
+
}
|
|
425
|
+
self.observer_manager.notify_train_end(final_metrics)
|
|
426
|
+
|
|
313
427
|
def resume_from_checkpoint(self, checkpoint_path: Optional[Path] = None):
|
|
314
428
|
"""
|
|
315
429
|
Resume training from a checkpoint.
|
gptmed/utils/checkpoints.py
CHANGED
|
@@ -108,7 +108,7 @@ class CheckpointManager:
|
|
|
108
108
|
# Save as best if applicable
|
|
109
109
|
if is_best or val_loss < self.best_val_loss:
|
|
110
110
|
self.best_val_loss = val_loss
|
|
111
|
-
best_path = self.checkpoint_dir / "
|
|
111
|
+
best_path = self.checkpoint_dir / "final_model.pt"
|
|
112
112
|
torch.save(checkpoint, best_path)
|
|
113
113
|
print(f"Best model saved: {best_path} (val_loss: {val_loss:.4f})")
|
|
114
114
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: gptmed
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.4.0
|
|
4
4
|
Summary: A lightweight GPT-based language model framework for training custom question-answering models on any domain
|
|
5
5
|
Author-email: Sanjog Sigdel <sigdelsanjog@gmail.com>
|
|
6
6
|
Maintainer-email: Sanjog Sigdel <sigdelsanjog@gmail.com>
|
|
@@ -10,7 +10,7 @@ Project-URL: Documentation, https://github.com/sigdelsanjog/gptmed#readme
|
|
|
10
10
|
Project-URL: Repository, https://github.com/sigdelsanjog/gptmed
|
|
11
11
|
Project-URL: Issues, https://github.com/sigdelsanjog/gptmed/issues
|
|
12
12
|
Keywords: nlp,language-model,transformer,gpt,pytorch,qa,question-answering,training,deep-learning,custom-model
|
|
13
|
-
Classifier: Development Status ::
|
|
13
|
+
Classifier: Development Status :: 4 - Beta
|
|
14
14
|
Classifier: Intended Audience :: Developers
|
|
15
15
|
Classifier: Intended Audience :: Science/Research
|
|
16
16
|
Classifier: Intended Audience :: Education
|
|
@@ -38,28 +38,64 @@ Requires-Dist: mypy>=0.950; extra == "dev"
|
|
|
38
38
|
Provides-Extra: training
|
|
39
39
|
Requires-Dist: tensorboard>=2.10.0; extra == "training"
|
|
40
40
|
Requires-Dist: wandb>=0.13.0; extra == "training"
|
|
41
|
+
Provides-Extra: visualization
|
|
42
|
+
Requires-Dist: matplotlib>=3.5.0; extra == "visualization"
|
|
43
|
+
Requires-Dist: seaborn>=0.12.0; extra == "visualization"
|
|
44
|
+
Provides-Extra: xai
|
|
45
|
+
Requires-Dist: matplotlib>=3.5.0; extra == "xai"
|
|
46
|
+
Requires-Dist: seaborn>=0.12.0; extra == "xai"
|
|
47
|
+
Requires-Dist: captum>=0.6.0; extra == "xai"
|
|
48
|
+
Requires-Dist: scikit-learn>=1.0.0; extra == "xai"
|
|
41
49
|
Dynamic: license-file
|
|
42
50
|
|
|
43
51
|
# GptMed 🤖
|
|
44
52
|
|
|
45
|
-
|
|
46
|
-
|
|
53
|
+
[Downloads](https://pepy.tech/project/gptmed)
|
|
54
|
+
[Monthly Downloads](https://pepy.tech/project/gptmed)
|
|
47
55
|
[PyPI version](https://badge.fury.io/py/gptmed)
|
|
48
56
|
[Python](https://www.python.org/downloads/)
|
|
49
57
|
[License: MIT](https://opensource.org/licenses/MIT)
|
|
50
58
|
|
|
51
|
-
|
|
59
|
+
A lightweight GPT-based language model framework for training custom question-answering models on any domain. This package provides a transformer-based GPT architecture that you can train on your own Q&A datasets - whether it's casual conversations, technical support, education, or any other domain.
|
|
52
60
|
|
|
53
|
-
|
|
61
|
+
## Citation
|
|
54
62
|
|
|
55
|
-
|
|
63
|
+
If you use this model in your research, please cite:
|
|
56
64
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
65
|
+
```bibtex
|
|
66
|
+
@software{gptmed_2026,
|
|
67
|
+
author = {Sanjog Sigdel},
|
|
68
|
+
title = {GptMed: A custom causal question answering general purpose GPT Transformer Architecture Model},
|
|
69
|
+
year = {2026},
|
|
70
|
+
url = {https://github.com/sigdelsanjog/gptmed}
|
|
71
|
+
}
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
## Table of Contents
|
|
75
|
+
|
|
76
|
+
- [Installation](#installation)
|
|
77
|
+
- [From PyPI (Recommended)](#from-pypi-recommended)
|
|
78
|
+
- [From Source](#from-source)
|
|
79
|
+
- [With Optional Dependencies](#with-optional-dependencies)
|
|
80
|
+
- [Quick Start](#quick-start)
|
|
81
|
+
- [Using the High-Level API](#using-the-high-level-api)
|
|
82
|
+
- [Inference (Generate Answers)](#inference-generate-answers)
|
|
83
|
+
- [Using Command Line](#using-command-line)
|
|
84
|
+
- [Training Your Own Model](#training-your-own-model)
|
|
85
|
+
- [Model Architecture](#model-architecture)
|
|
86
|
+
- [Configuration](#configuration)
|
|
87
|
+
- [Model Sizes](#model-sizes)
|
|
88
|
+
- [Training Configuration](#training-configuration)
|
|
89
|
+
- [Observability](#observability)
|
|
90
|
+
- [Project Structure](#project-structure)
|
|
91
|
+
- [Requirements](#requirements)
|
|
92
|
+
- [Documentation](#documentation)
|
|
93
|
+
- [Performance](#performance)
|
|
94
|
+
- [Examples](#examples)
|
|
95
|
+
- [Contributing](#contributing)
|
|
96
|
+
- [Citation](#citation)
|
|
97
|
+
- [License](#license)
|
|
98
|
+
- [Support](#support)
|
|
63
99
|
|
|
64
100
|
## Installation
|
|
65
101
|
|
|
@@ -83,15 +119,49 @@ pip install -e .
|
|
|
83
119
|
# For development
|
|
84
120
|
pip install gptmed[dev]
|
|
85
121
|
|
|
86
|
-
# For training
|
|
122
|
+
# For training with logging integrations
|
|
87
123
|
pip install gptmed[training]
|
|
88
124
|
|
|
125
|
+
# For visualization (loss curves, metrics plots)
|
|
126
|
+
pip install gptmed[visualization]
|
|
127
|
+
|
|
128
|
+
# For Explainable AI features
|
|
129
|
+
pip install gptmed[xai]
|
|
130
|
+
|
|
89
131
|
# All dependencies
|
|
90
|
-
pip install gptmed[dev,training]
|
|
132
|
+
pip install gptmed[dev,training,visualization,xai]
|
|
91
133
|
```
|
|
92
134
|
|
|
93
135
|
## Quick Start
|
|
94
136
|
|
|
137
|
+
### Using the High-Level API
|
|
138
|
+
|
|
139
|
+
The easiest way to use GptMed is through the high-level API:
|
|
140
|
+
|
|
141
|
+
```python
|
|
142
|
+
import gptmed
|
|
143
|
+
|
|
144
|
+
# 1. Create a training configuration
|
|
145
|
+
gptmed.create_config('my_config.yaml')
|
|
146
|
+
|
|
147
|
+
# 2. Edit my_config.yaml with your settings (data paths, model size, etc.)
|
|
148
|
+
|
|
149
|
+
# 3. Train the model
|
|
150
|
+
gptmed.train_from_config('my_config.yaml')
|
|
151
|
+
|
|
152
|
+
# 4. Generate answers
|
|
153
|
+
answer = gptmed.generate(
|
|
154
|
+
checkpoint='model/checkpoints/final_model.pt',
|
|
155
|
+
tokenizer='tokenizer/my_tokenizer.model',
|
|
156
|
+
prompt='What is machine learning?',
|
|
157
|
+
max_length=150,
|
|
158
|
+
temperature=0.7
|
|
159
|
+
)
|
|
160
|
+
print(answer)
|
|
161
|
+
```
|
|
162
|
+
|
|
163
|
+
For a complete API testing workflow, see the [gptmed-api folder](https://github.com/sigdelsanjog/gptmed/tree/main/gptmed-api) with ready-to-run examples.
|
|
164
|
+
|
|
95
165
|
### Inference (Generate Answers)
|
|
96
166
|
|
|
97
167
|
```python
|
|
@@ -187,6 +257,50 @@ config = TrainingConfig(
|
|
|
187
257
|
)
|
|
188
258
|
```
|
|
189
259
|
|
|
260
|
+
## Observability
|
|
261
|
+
|
|
262
|
+
**New in v0.4.0**: Built-in training monitoring based on the Observer pattern.
|
|
263
|
+
|
|
264
|
+
### Features
|
|
265
|
+
|
|
266
|
+
- 📊 **Loss Curves**: Track training/validation loss over time
|
|
267
|
+
- 📈 **Metrics Tracking**: Perplexity, gradient norms, learning rates
|
|
268
|
+
- 🔔 **Callbacks**: Console output, JSON logging, early stopping
|
|
269
|
+
- 📁 **Export**: CSV export, matplotlib visualizations
|
|
270
|
+
- 🔌 **Extensible**: Add custom observers for integrations (W&B, TensorBoard)
|
|
271
|
+
|
|
272
|
+
### Quick Example
|
|
273
|
+
|
|
274
|
+
```python
|
|
275
|
+
from gptmed.observability import MetricsTracker, ConsoleCallback, EarlyStoppingCallback
|
|
276
|
+
|
|
277
|
+
# Create observers
|
|
278
|
+
tracker = MetricsTracker(output_dir='./metrics')
|
|
279
|
+
console = ConsoleCallback(print_every=50)
|
|
280
|
+
early_stop = EarlyStoppingCallback(patience=3)
|
|
281
|
+
|
|
282
|
+
# Use with TrainingService (automatic)
|
|
283
|
+
from gptmed.services import TrainingService
|
|
284
|
+
service = TrainingService(config_path='config.yaml')
|
|
285
|
+
service.train() # Automatically creates MetricsTracker
|
|
286
|
+
|
|
287
|
+
# Or use with Trainer directly
|
|
288
|
+
trainer = Trainer(model, train_loader, config, observers=[tracker, console])
|
|
289
|
+
trainer.train()
|
|
290
|
+
```
|
|
291
|
+
|
|
292
|
+
### Available Observers
|
|
293
|
+
|
|
294
|
+
| Observer | Description |
|
|
295
|
+
| ----------------------- | --------------------------------------------------------- |
|
|
296
|
+
| `MetricsTracker` | Comprehensive metrics collection with export capabilities |
|
|
297
|
+
| `ConsoleCallback` | Real-time console output with progress bars |
|
|
298
|
+
| `JSONLoggerCallback` | Structured JSON logging for analysis |
|
|
299
|
+
| `EarlyStoppingCallback` | Stop training when validation loss plateaus |
|
|
300
|
+
| `LRSchedulerCallback` | Learning rate scheduling integration |
|
|
301
|
+
|
|
302
|
+
See [XAI.md](XAI.md) for the roadmap of planned Explainable AI features.
|
|
303
|
+
|
|
190
304
|
## Project Structure
|
|
191
305
|
|
|
192
306
|
```
|
|
@@ -201,10 +315,16 @@ gptmed/
|
|
|
201
315
|
│ ├── train.py # Training script
|
|
202
316
|
│ ├── trainer.py # Training loop
|
|
203
317
|
│ └── dataset.py # Data loading
|
|
318
|
+
├── observability/ # Training monitoring & XAI (v0.4.0+)
|
|
319
|
+
│ ├── base.py # Observer pattern interfaces
|
|
320
|
+
│ ├── metrics_tracker.py # Loss curves & metrics
|
|
321
|
+
│ └── callbacks.py # Console, JSON, early stopping
|
|
204
322
|
├── tokenizer/
|
|
205
323
|
│ └── train_tokenizer.py # SentencePiece tokenizer
|
|
206
324
|
├── configs/
|
|
207
325
|
│ └── train_config.py # Training configurations
|
|
326
|
+
├── services/
|
|
327
|
+
│ └── training_service.py # High-level training orchestration
|
|
208
328
|
└── utils/
|
|
209
329
|
├── checkpoints.py # Model checkpointing
|
|
210
330
|
└── logging.py # Training logging
|
|
@@ -226,6 +346,7 @@ gptmed/
|
|
|
226
346
|
|
|
227
347
|
- [User Manual](USER_MANUAL.md) - **Start here!** Complete training pipeline guide
|
|
228
348
|
- [Architecture Guide](ARCHITECTURE_EXTENSION_GUIDE.md) - Understanding the model architecture
|
|
349
|
+
- [XAI Roadmap](XAI.md) - Explainable AI features & implementation guide
|
|
229
350
|
- [Deployment Guide](DEPLOYMENT_GUIDE.md) - Publishing to PyPI
|
|
230
351
|
- [Changelog](CHANGELOG.md) - Version history
|
|
231
352
|
|
|
@@ -241,20 +362,53 @@ _Tested on GTX 1080 8GB_
|
|
|
241
362
|
|
|
242
363
|
## Examples
|
|
243
364
|
|
|
244
|
-
###
|
|
365
|
+
### Domain-Agnostic Usage
|
|
366
|
+
|
|
367
|
+
GptMed works with **any domain** - just train on your own Q&A data:
|
|
245
368
|
|
|
246
369
|
```python
|
|
247
|
-
#
|
|
248
|
-
question = "
|
|
370
|
+
# Technical Support Bot
|
|
371
|
+
question = "How do I reset my WiFi router?"
|
|
249
372
|
answer = generator.generate(question, temperature=0.7)
|
|
250
373
|
|
|
251
|
-
#
|
|
252
|
-
question = "
|
|
374
|
+
# Educational Assistant
|
|
375
|
+
question = "Explain the water cycle in simple terms"
|
|
253
376
|
answer = generator.generate(question, temperature=0.6)
|
|
254
377
|
|
|
255
|
-
#
|
|
256
|
-
question = "What is
|
|
378
|
+
# Customer Service
|
|
379
|
+
question = "What is your return policy?"
|
|
257
380
|
answer = generator.generate(question, temperature=0.5)
|
|
381
|
+
|
|
382
|
+
# Medical Q&A (example domain)
|
|
383
|
+
question = "What are the symptoms of flu?"
|
|
384
|
+
answer = generator.generate(question, temperature=0.7)
|
|
385
|
+
```
|
|
386
|
+
|
|
387
|
+
### Training Observability (v0.4.0+)
|
|
388
|
+
|
|
389
|
+
Monitor your training with built-in observability:
|
|
390
|
+
|
|
391
|
+
```python
|
|
392
|
+
from gptmed.observability import MetricsTracker, ConsoleCallback
|
|
393
|
+
|
|
394
|
+
# Create observers
|
|
395
|
+
tracker = MetricsTracker(output_dir='./metrics')
|
|
396
|
+
console = ConsoleCallback(print_every=10)
|
|
397
|
+
|
|
398
|
+
# Train with observability
|
|
399
|
+
gptmed.train_from_config(
|
|
400
|
+
'my_config.yaml',
|
|
401
|
+
observers=[tracker, console]
|
|
402
|
+
)
|
|
403
|
+
|
|
404
|
+
# After training - get the report
|
|
405
|
+
report = tracker.get_report()
|
|
406
|
+
print(f"Final Loss: {report['final_loss']:.4f}")
|
|
407
|
+
print(f"Total Steps: {report['total_steps']}")
|
|
408
|
+
|
|
409
|
+
# Export metrics
|
|
410
|
+
tracker.export_to_csv('training_metrics.csv')
|
|
411
|
+
tracker.plot_loss_curves('loss_curves.png') # Requires matplotlib
|
|
258
412
|
```
|
|
259
413
|
|
|
260
414
|
## Contributing
|
|
@@ -267,19 +421,6 @@ Contributions are welcome! Please feel free to submit a Pull Request.
|
|
|
267
421
|
4. Push to the branch (`git push origin feature/AmazingFeature`)
|
|
268
422
|
5. Open a Pull Request
|
|
269
423
|
|
|
270
|
-
## Citation
|
|
271
|
-
|
|
272
|
-
If you use this model in your research, please cite:
|
|
273
|
-
|
|
274
|
-
```bibtex
|
|
275
|
-
@software{llm_med_2026,
|
|
276
|
-
author = {Sanjog Sigdel},
|
|
277
|
-
title = {GptMed: A custom causal question answering general purpose GPT Transformer Architecture Model},
|
|
278
|
-
year = {2026},
|
|
279
|
-
url = {https://github.com/sigdelsanjog/gptmed}
|
|
280
|
-
}
|
|
281
|
-
```
|
|
282
|
-
|
|
283
424
|
## License
|
|
284
425
|
|
|
285
426
|
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
|
@@ -289,16 +430,12 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
|
|
289
430
|
- MedQuAD dataset creators
|
|
290
431
|
- PyTorch team
|
|
291
432
|
|
|
292
|
-
## Disclaimer
|
|
293
|
-
|
|
294
|
-
⚠️ **Medical Disclaimer**: This model is for research and educational purposes only. It should NOT be used for actual medical diagnosis or treatment decisions. Always consult qualified healthcare professionals for medical advice.
|
|
295
|
-
|
|
296
433
|
## Support
|
|
297
434
|
|
|
298
|
-
-
|
|
299
|
-
-
|
|
435
|
+
- 📫 **[User Manual](USER_MANUAL.md)** - Complete step-by-step training guide
|
|
436
|
+
- 📫 Issues: [GitHub Issues](https://github.com/sigdelsanjog/gptmed/issues)
|
|
300
437
|
- 💬 Discussions: [GitHub Discussions](https://github.com/sigdelsanjog/gptmed/discussions)
|
|
301
|
-
- 📧 Email: sanjog.sigdel@ku.edu.np
|
|
438
|
+
- 📧 Email: sigdelsanjog@gmail.com | sanjog.sigdel@ku.edu.np
|
|
302
439
|
|
|
303
440
|
## Changelog
|
|
304
441
|
|
|
@@ -306,4 +443,4 @@ See [CHANGELOG.md](CHANGELOG.md) for version history.
|
|
|
306
443
|
|
|
307
444
|
---
|
|
308
445
|
|
|
309
|
-
Made with ❤️
|
|
446
|
+
#### Made with ❤️ from Nepal
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
gptmed/__init__.py,sha256=
|
|
1
|
+
gptmed/__init__.py,sha256=lSCUt0jmB81dEG0UroQdrk8TMG9Hv-_a14nAvB6yYiQ,2725
|
|
2
2
|
gptmed/api.py,sha256=k9a_1F2h__xgKnH2l0FaJqAqu-iTYt5tu_VfVO0UhrA,9806
|
|
3
3
|
gptmed/configs/__init__.py,sha256=yRa-zgPQ-OCzu8fvCrfWMG-CjF3dru3PZzknzm0oUaQ,23
|
|
4
4
|
gptmed/configs/config_loader.py,sha256=3GQ1iCNpdJ5yALWXA3SPPHRkaUO-117vdArEL6u7sK8,6354
|
|
@@ -13,7 +13,7 @@ gptmed/inference/decoding_utils.py,sha256=zTDZYdl2jcGwSrcINXMw-5uoYuF4A9TSushhPx
|
|
|
13
13
|
gptmed/inference/generation_config.py,sha256=hpPyZUk1K6qGSBAoQx3Jm0_ZrrYld77ACxbIlCCCcVU,2813
|
|
14
14
|
gptmed/inference/generator.py,sha256=6JFmDPQF4btau_Gp5pfk8a5G0Iyg6QsB9Y8Oo4ygH-4,7884
|
|
15
15
|
gptmed/inference/sampling.py,sha256=B6fRlJafypuBMKJ0rTbsk6k8KXloXiIvroi7rN6ekBA,7947
|
|
16
|
-
gptmed/model/__init__.py,sha256=
|
|
16
|
+
gptmed/model/__init__.py,sha256=j5-KIrO_-913r1exnNzrbuuhuVmRJMmssRWrQjj5rdw,199
|
|
17
17
|
gptmed/model/architecture/__init__.py,sha256=9MpSAYwwZY-t1vBLIupuRtLD7CaOLJRENMh3zKx3M-4,970
|
|
18
18
|
gptmed/model/architecture/attention.py,sha256=Qk1eGl9glKWQbhcXJWmFkO5U3VHBq7OrsjVG0tPmgnY,6420
|
|
19
19
|
gptmed/model/architecture/decoder_block.py,sha256=n-Uo09TDcirKeWTWTNumldGOrx-b2Elb25lbF6cTYwg,3879
|
|
@@ -22,20 +22,27 @@ gptmed/model/architecture/feedforward.py,sha256=uJ5QOlWX0ritKDQLUE7GPmMojelR9-sT
|
|
|
22
22
|
gptmed/model/architecture/transformer.py,sha256=H1njPoy0Uam59JbA24C0olEDwPfhh3ev4HsUFRIC_0Y,6626
|
|
23
23
|
gptmed/model/configs/__init__.py,sha256=LDCWhlCDOU7490wcfSId_jXBPfQrtYQEw8FoD67rqBs,275
|
|
24
24
|
gptmed/model/configs/model_config.py,sha256=wI-i2Dw_pTdIKCDe1pqLvP3ky3YedEy7DwZYN5lwmKE,4673
|
|
25
|
+
gptmed/observability/__init__.py,sha256=AtGf0D8jEx2LGQ0Ro-Eh0SFDuA5ZjZkot7D1Y8j1jiM,1180
|
|
26
|
+
gptmed/observability/base.py,sha256=Mi3F95bJ9Tw5scoSyw9AtKlcu9aG444G1UlycIIGCtI,10748
|
|
27
|
+
gptmed/observability/callbacks.py,sha256=1b84_e86mfyt2EQGzf-6K2Sba3bZJt4I3bBJb52TAbA,13170
|
|
28
|
+
gptmed/observability/metrics_tracker.py,sha256=Bs6tppQYG9AOb3rj2T1lhWKDyOw4R4ZG6nFGRiek8FQ,19441
|
|
29
|
+
gptmed/services/__init__.py,sha256=FtM7NQ_S4VOfl2n6A6cLcOxG9-w7BK7DicQsUvOMmGE,369
|
|
30
|
+
gptmed/services/device_manager.py,sha256=RSsu0RlsexCIO-p4eejOZAPLgpaVA0y9niTg8wf1luY,7513
|
|
31
|
+
gptmed/services/training_service.py,sha256=cF3yYo8aZe7BfQ-paTN-l7EYs9h8L_JUyRhiI0GEP4E,16921
|
|
25
32
|
gptmed/tokenizer/__init__.py,sha256=KhLAHPmQyoWhnKDenyIJRxgFflKI7xklip28j4cKfKw,157
|
|
26
33
|
gptmed/tokenizer/tokenize_data.py,sha256=KgMtMfaz_RtOhN_CrvC267k9ujxRdO89rToVJ6nzdwg,9139
|
|
27
34
|
gptmed/tokenizer/train_tokenizer.py,sha256=f0Hucyft9e8LU2RtpTqg8h_0SpOC_oMABl0_me-wfL8,7068
|
|
28
35
|
gptmed/training/__init__.py,sha256=6G0_gdlwBnQBG8wZlTm2NtgkXZJcXRfLMDQ2iu6O3U4,24
|
|
29
36
|
gptmed/training/dataset.py,sha256=QbNVTN4Og5gqMAV2ckjRX8W_k9aUc9IZJDcu0u9U8t0,5347
|
|
30
37
|
gptmed/training/train.py,sha256=sp4-1WpEXUTA9V0GUYAgSvMd2aaPkt1aq2PepQFLXD8,8142
|
|
31
|
-
gptmed/training/trainer.py,sha256=
|
|
38
|
+
gptmed/training/trainer.py,sha256=F1K9pkDMU-TbvmGGQ3JjRek4ZoFI6GHIFFEg7yYfIGM,15206
|
|
32
39
|
gptmed/training/utils.py,sha256=pJxCwneNr2STITIYwIDCxRzIICDFOxOMzK8DT7ck2oQ,5651
|
|
33
40
|
gptmed/utils/__init__.py,sha256=XuMhIqOXF7mjnog_6Iky-hSbwvFb0iK42B4iDUpgi0U,44
|
|
34
|
-
gptmed/utils/checkpoints.py,sha256=
|
|
41
|
+
gptmed/utils/checkpoints.py,sha256=jPKJtO0YRZieGmpwqotgDkBzd__s_raDxS1kLpfjBJE,7113
|
|
35
42
|
gptmed/utils/logging.py,sha256=7dJc1tayMxCBjFSDXe4r9ACUTpoPTTGsJ0UZMTqZIDY,5303
|
|
36
|
-
gptmed-0.
|
|
37
|
-
gptmed-0.
|
|
38
|
-
gptmed-0.
|
|
39
|
-
gptmed-0.
|
|
40
|
-
gptmed-0.
|
|
41
|
-
gptmed-0.
|
|
43
|
+
gptmed-0.4.0.dist-info/licenses/LICENSE,sha256=v2spsd7N1pKFFh2G8wGP_45iwe5S0DYiJzG4im8Rupc,1066
|
|
44
|
+
gptmed-0.4.0.dist-info/METADATA,sha256=kVsL6zbBoGw1jrlaDiPkBAr_D7YedPCSwZkjGCFz04c,13832
|
|
45
|
+
gptmed-0.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
46
|
+
gptmed-0.4.0.dist-info/entry_points.txt,sha256=ATqOzTtPVdUiFX5ZSeo3n9JkUCqocUxEXTgy1CfNRZE,110
|
|
47
|
+
gptmed-0.4.0.dist-info/top_level.txt,sha256=mhyEq3rG33t21ziJz5w3TPgx0RjPf4zXMNUx2JTiNmE,7
|
|
48
|
+
gptmed-0.4.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|