gptmed 0.3.5__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gptmed/__init__.py CHANGED
@@ -3,7 +3,14 @@ GptMed: A lightweight GPT-based language model framework
3
3
 
4
4
  A domain-agnostic framework for training custom question-answering models.
5
5
  Train your own GPT model on any Q&A dataset - medical, technical support,
6
- education, or any other domain.
6
+ education, legal, customer service, or any other domain.
7
+
8
+ Key Features:
9
+ - Simple 3-step training: config → train → generate
10
+ - Built-in training observability with loss curves and metrics
11
+ - Flexible model sizes (tiny, small, medium)
12
+ - Device-agnostic (CPU/CUDA with auto-detection)
13
+ - XAI-ready architecture for model interpretability
7
14
 
8
15
  Quick Start:
9
16
  >>> import gptmed
@@ -13,7 +20,7 @@ Quick Start:
13
20
  >>>
14
21
  >>> # 2. Edit my_config.yaml with your settings
15
22
  >>>
16
- >>> # 3. Train your model
23
+ >>> # 3. Train your model (with automatic metrics tracking)
17
24
  >>> results = gptmed.train_from_config('my_config.yaml')
18
25
  >>>
19
26
  >>> # 4. Generate answers
@@ -23,6 +30,16 @@ Quick Start:
23
30
  ... prompt='Your question here?'
24
31
  ... )
25
32
 
33
+ Observability (v0.4.0+):
34
+ >>> from gptmed import MetricsTracker, EarlyStoppingCallback
35
+ >>>
36
+ >>> # Training automatically tracks metrics and generates reports:
37
+ >>> # - Loss curves (train/val)
38
+ >>> # - Gradient norms
39
+ >>> # - Learning rate schedule
40
+ >>> # - Perplexity
41
+ >>> # - Training health checks
42
+
26
43
  Advanced Usage:
27
44
  >>> from gptmed.model.architecture import GPTTransformer
28
45
  >>> from gptmed.model.configs.model_config import get_small_config
@@ -32,7 +49,7 @@ Advanced Usage:
32
49
  >>> model = GPTTransformer(config)
33
50
  """
34
51
 
35
- __version__ = "0.3.3"
52
+ __version__ = "0.4.0"
36
53
  __author__ = "Sanjog Sigdel"
37
54
  __email__ = "sigdelsanjog@gmail.com"
38
55
 
@@ -47,6 +64,16 @@ from gptmed.api import (
47
64
  from gptmed.model.architecture import GPTTransformer
48
65
  from gptmed.model.configs.model_config import ModelConfig, get_small_config, get_tiny_config
49
66
 
67
+ # Observability module
68
+ from gptmed.observability import (
69
+ TrainingObserver,
70
+ ObserverManager,
71
+ MetricsTracker,
72
+ ConsoleCallback,
73
+ JSONLoggerCallback,
74
+ EarlyStoppingCallback,
75
+ )
76
+
50
77
  __all__ = [
51
78
  # Simple API
52
79
  "create_config",
@@ -57,4 +84,11 @@ __all__ = [
57
84
  "ModelConfig",
58
85
  "get_small_config",
59
86
  "get_tiny_config",
87
+ # Observability
88
+ "TrainingObserver",
89
+ "ObserverManager",
90
+ "MetricsTracker",
91
+ "ConsoleCallback",
92
+ "JSONLoggerCallback",
93
+ "EarlyStoppingCallback",
60
94
  ]
gptmed/model/__init__.py CHANGED
@@ -1,7 +1,7 @@
1
1
  """
2
- MedLLM Model Package
2
+ GptMed Model Package
3
3
 
4
- This package contains the GPT-based transformer architecture for medical QA.
4
+ This package contains the GPT-based transformer architecture for general purpose QA.
5
5
  """
6
6
 
7
7
  from gptmed.model.architecture import GPTTransformer
@@ -0,0 +1,43 @@
1
+ """
2
+ Observability Module for GptMed
3
+
4
+ PURPOSE:
5
+ Provides training observability, metrics tracking, and XAI capabilities.
6
+ Implements Observer Pattern for decoupled monitoring.
7
+
8
+ DESIGN PATTERNS:
9
+ - Observer Pattern: Trainer emits events, observers react independently
10
+ - Strategy Pattern: Swap between different logging backends
11
+ - Open/Closed Principle: Add new observers without modifying Trainer
12
+
13
+ COMPONENTS:
14
+ - TrainingObserver: Abstract base class for all observers
15
+ - MetricsTracker: Enhanced loss curves and training metrics
16
+ - Callbacks: TensorBoard, W&B, Early Stopping, etc.
17
+
18
+ FUTURE EXTENSIONS:
19
+ - Attention visualization
20
+ - Saliency maps / input attribution
21
+ - Embedding space analysis
22
+ - Gradient flow analysis
23
+ """
24
+
25
+ from gptmed.observability.base import TrainingObserver, ObserverManager
26
+ from gptmed.observability.metrics_tracker import MetricsTracker
27
+ from gptmed.observability.callbacks import (
28
+ ConsoleCallback,
29
+ JSONLoggerCallback,
30
+ EarlyStoppingCallback,
31
+ )
32
+
33
+ __all__ = [
34
+ # Base classes
35
+ "TrainingObserver",
36
+ "ObserverManager",
37
+ # Metrics
38
+ "MetricsTracker",
39
+ # Callbacks
40
+ "ConsoleCallback",
41
+ "JSONLoggerCallback",
42
+ "EarlyStoppingCallback",
43
+ ]
@@ -0,0 +1,369 @@
1
+ """
2
+ Base Observer Classes
3
+
4
+ PURPOSE:
5
+ Define abstract interfaces for training observation.
6
+ Implements Observer Pattern for decoupled monitoring.
7
+
8
+ DESIGN PATTERNS:
9
+ - Observer Pattern: Subjects (Trainer) notify observers of state changes
10
+ - Template Method: Base class defines interface, subclasses implement
11
+ - Dependency Inversion: Trainer depends on abstraction, not concrete observers
12
+
13
+ WHY OBSERVER PATTERN:
14
+ 1. Decoupling: Trainer doesn't know about logging implementations
15
+ 2. Extensibility: Add new observers without modifying Trainer
16
+ 3. Single Responsibility: Each observer handles one concern
17
+ 4. Open/Closed: Open for extension, closed for modification
18
+ """
19
+
20
+ from abc import ABC, abstractmethod
21
+ from typing import Dict, Any, List, Optional
22
+ from dataclasses import dataclass, field
23
+ from enum import Enum
24
+
25
+
class TrainingEvent(Enum):
    """
    Enumeration of training lifecycle events observers can subscribe to.

    Each member's value is the name of the corresponding observer hook
    method, which lets dispatch code look the handler up with getattr.
    """
    # Whole-run boundaries
    TRAIN_START = "on_train_start"
    TRAIN_END = "on_train_end"
    # Per-epoch boundaries
    EPOCH_START = "on_epoch_start"
    EPOCH_END = "on_epoch_end"
    # Inner-loop events
    STEP = "on_step"
    VALIDATION = "on_validation"
    CHECKPOINT = "on_checkpoint"
    GRADIENT_COMPUTED = "on_gradient_computed"
37
+
@dataclass
class StepMetrics:
    """
    Metrics collected at each training step.

    Using dataclass for:
    - Type safety
    - Clear documentation of expected metrics
    - Easy serialization

    Attributes:
        step: Global training step index.
        loss: Training loss at this step.
        learning_rate: Optimizer learning rate at this step.
        grad_norm: Total gradient norm at this step.
        batch_size: Number of sequences in the batch.
        seq_len: Sequence length of the batch.
        tokens_per_sec: Throughput; 0.0 means "not measured".
        perplexity: exp(loss); derived automatically when left at 0.0.
    """
    step: int
    loss: float
    learning_rate: float
    grad_norm: float
    batch_size: int
    seq_len: int
    tokens_per_sec: float = 0.0
    perplexity: float = 0.0

    def __post_init__(self):
        """Derive perplexity from loss unless the caller supplied one."""
        # BUGFIX: use `>= 0` so a loss of exactly 0.0 yields the correct
        # perplexity of 1.0 (exp(0)) instead of being skipped and left at
        # the 0.0 sentinel. Negative losses (invalid for cross-entropy)
        # are still left alone.
        if self.perplexity == 0.0 and self.loss >= 0:
            import math
            # Cap the exponent so math.exp cannot overflow on divergent losses.
            self.perplexity = math.exp(min(self.loss, 100))

    def to_dict(self) -> Dict[str, float]:
        """Convert to dictionary for logging."""
        return {
            "step": self.step,
            "loss": self.loss,
            "learning_rate": self.learning_rate,
            "grad_norm": self.grad_norm,
            "batch_size": self.batch_size,
            "seq_len": self.seq_len,
            "tokens_per_sec": self.tokens_per_sec,
            "perplexity": self.perplexity,
        }
76
+
@dataclass
class ValidationMetrics:
    """
    Metrics collected during validation.

    Attributes:
        step: Training step at which validation was run.
        val_loss: Validation loss.
        val_perplexity: exp(val_loss); derived automatically when left at 0.0.
    """
    step: int
    val_loss: float
    val_perplexity: float = 0.0

    def __post_init__(self):
        """Derive validation perplexity from the loss unless supplied."""
        # BUGFIX: use `>= 0` so a validation loss of exactly 0.0 produces
        # the correct perplexity of 1.0 (exp(0)) rather than staying at the
        # 0.0 sentinel.
        if self.val_perplexity == 0.0 and self.val_loss >= 0:
            import math
            # Cap the exponent so math.exp cannot overflow on divergent losses.
            self.val_perplexity = math.exp(min(self.val_loss, 100))

    def to_dict(self) -> Dict[str, float]:
        """Convert to dictionary for logging."""
        return {
            "step": self.step,
            "val_loss": self.val_loss,
            "val_perplexity": self.val_perplexity,
        }
98
+
@dataclass
class GradientMetrics:
    """
    Gradient statistics for observability.

    Used for detecting:
    - Vanishing gradients (norm → 0)
    - Exploding gradients (norm → ∞)
    - Dead neurons (high zero fraction)
    """
    # Training step these statistics were captured at.
    step: int
    # Global gradient norm across all parameters.
    total_norm: float
    # Per-layer gradient norms keyed by layer name; each instance gets
    # its own dict via default_factory (never a shared mutable default).
    layer_norms: Dict[str, float] = field(default_factory=dict)
    # Extremes over all gradient elements.
    max_grad: float = 0.0
    min_grad: float = 0.0
    # Fraction of gradient elements that are exactly zero.
    zero_fraction: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize every statistic into a plain dict for logging."""
        payload = {
            "step": self.step,
            "total_norm": self.total_norm,
            "layer_norms": self.layer_norms,
            "max_grad": self.max_grad,
            "min_grad": self.min_grad,
            "zero_fraction": self.zero_fraction,
        }
        return payload
127
+
class TrainingObserver(ABC):
    """
    Abstract base class for training observers.

    Implements Observer Pattern - receives notifications from Trainer
    without Trainer knowing the concrete implementation.

    Lifecycle:
        on_train_start → [on_epoch_start → [on_step]* → on_validation → on_epoch_end]* → on_train_end

    Example:
        >>> class MyObserver(TrainingObserver):
        ...     def on_step(self, metrics: StepMetrics) -> None:
        ...         print(f"Loss: {metrics.loss}")
        ...
        >>> trainer.add_observer(MyObserver())
    """

    def __init__(self, name: Optional[str] = None):
        """
        Initialize observer.

        Args:
            name: Human-readable name for this observer. Defaults to the
                concrete class name when omitted.
        """
        # FIX: annotation was `str = None`, which contradicts passing None;
        # `Optional[str]` states the actual contract.
        self.name = name or self.__class__.__name__
        # Observers start enabled; ObserverManager skips disabled ones.
        self._enabled = True

    @property
    def enabled(self) -> bool:
        """Whether this observer is active."""
        return self._enabled

    def enable(self) -> None:
        """Enable this observer."""
        self._enabled = True

    def disable(self) -> None:
        """Disable this observer."""
        self._enabled = False

    # Required methods (must implement)

    @abstractmethod
    def on_train_start(self, config: Dict[str, Any]) -> None:
        """
        Called when training begins.

        Args:
            config: Training configuration dictionary
        """
        pass

    @abstractmethod
    def on_step(self, metrics: StepMetrics) -> None:
        """
        Called after each training step.

        Args:
            metrics: Step metrics (loss, grad_norm, etc.)
        """
        pass

    @abstractmethod
    def on_validation(self, metrics: ValidationMetrics) -> None:
        """
        Called after validation.

        Args:
            metrics: Validation metrics (val_loss, val_perplexity)
        """
        pass

    @abstractmethod
    def on_train_end(self, final_metrics: Dict[str, Any]) -> None:
        """
        Called when training completes.

        Args:
            final_metrics: Final training summary
        """
        pass

    # Optional methods (override if needed) - default to no-ops so
    # subclasses only implement the hooks they care about.

    def on_epoch_start(self, epoch: int) -> None:
        """
        Called at start of each epoch.

        Args:
            epoch: Current epoch number (0-indexed)
        """
        pass

    def on_epoch_end(self, epoch: int, metrics: Dict[str, float]) -> None:
        """
        Called at end of each epoch.

        Args:
            epoch: Current epoch number
            metrics: Epoch summary metrics
        """
        pass

    def on_checkpoint(self, step: int, checkpoint_path: str) -> None:
        """
        Called when a checkpoint is saved.

        Args:
            step: Training step
            checkpoint_path: Path to saved checkpoint
        """
        pass

    def on_gradient_computed(self, metrics: GradientMetrics) -> None:
        """
        Called after gradients are computed (before optimizer step).

        Useful for gradient flow analysis.

        Args:
            metrics: Gradient statistics
        """
        pass
253
+
class ObserverManager:
    """
    Manages multiple observers and dispatches events.

    Implements Composite pattern - treats collection of observers uniformly.

    Example:
        >>> manager = ObserverManager()
        >>> manager.add(ConsoleCallback())
        >>> manager.add(MetricsTracker(log_dir='logs'))
        >>> manager.notify_step(step_metrics)
    """

    def __init__(self):
        """Start with no registered observers."""
        self._observers: List[TrainingObserver] = []

    def add(self, observer: TrainingObserver) -> 'ObserverManager':
        """
        Register an observer.

        Args:
            observer: Observer to add

        Returns:
            Self, so registrations can be chained fluently.
        """
        self._observers.append(observer)
        return self

    def remove(self, observer: TrainingObserver) -> bool:
        """
        Unregister an observer.

        Args:
            observer: Observer to remove

        Returns:
            True if it was registered and removed, False otherwise.
        """
        if observer not in self._observers:
            return False
        self._observers.remove(observer)
        return True

    def get_observer(self, name: str) -> Optional[TrainingObserver]:
        """
        Look up a registered observer by its name.

        Args:
            name: Observer name

        Returns:
            The first observer with a matching name, or None.
        """
        return next((obs for obs in self._observers if obs.name == name), None)

    @property
    def observers(self) -> List[TrainingObserver]:
        """A shallow copy of the registered observers."""
        return list(self._observers)

    def _notify(self, event: str, *args, **kwargs) -> None:
        """
        Fan an event out to every enabled observer.

        A failing observer is reported but never interrupts training or
        the remaining observers.

        Args:
            event: Event method name
            *args, **kwargs: Event arguments
        """
        for observer in self._observers:
            if not observer.enabled:
                continue
            # Missing or non-callable handlers are silently skipped so
            # observers may implement only the hooks they need.
            handler = getattr(observer, event, None)
            if not callable(handler):
                continue
            try:
                handler(*args, **kwargs)
            except Exception as e:
                print(f"Warning: Observer {observer.name} failed on {event}: {e}")

    # Convenience methods for each event type

    def notify_train_start(self, config: Dict[str, Any]) -> None:
        """Notify all observers of training start."""
        self._notify('on_train_start', config)

    def notify_train_end(self, final_metrics: Dict[str, Any]) -> None:
        """Notify all observers of training end."""
        self._notify('on_train_end', final_metrics)

    def notify_epoch_start(self, epoch: int) -> None:
        """Notify all observers of epoch start."""
        self._notify('on_epoch_start', epoch)

    def notify_epoch_end(self, epoch: int, metrics: Dict[str, float]) -> None:
        """Notify all observers of epoch end."""
        self._notify('on_epoch_end', epoch, metrics)

    def notify_step(self, metrics: StepMetrics) -> None:
        """Notify all observers of training step."""
        self._notify('on_step', metrics)

    def notify_validation(self, metrics: ValidationMetrics) -> None:
        """Notify all observers of validation."""
        self._notify('on_validation', metrics)

    def notify_checkpoint(self, step: int, checkpoint_path: str) -> None:
        """Notify all observers of checkpoint save."""
        self._notify('on_checkpoint', step, checkpoint_path)

    def notify_gradient(self, metrics: GradientMetrics) -> None:
        """Notify all observers of gradient computation."""
        self._notify('on_gradient_computed', metrics)