gptmed 0.3.5__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gptmed/__init__.py +37 -3
- gptmed/model/__init__.py +2 -2
- gptmed/observability/__init__.py +43 -0
- gptmed/observability/base.py +369 -0
- gptmed/observability/callbacks.py +397 -0
- gptmed/observability/metrics_tracker.py +544 -0
- gptmed/services/training_service.py +161 -7
- gptmed/training/trainer.py +124 -10
- gptmed/utils/checkpoints.py +1 -1
- {gptmed-0.3.5.dist-info → gptmed-0.4.0.dist-info}/METADATA +180 -43
- {gptmed-0.3.5.dist-info → gptmed-0.4.0.dist-info}/RECORD +15 -11
- {gptmed-0.3.5.dist-info → gptmed-0.4.0.dist-info}/WHEEL +0 -0
- {gptmed-0.3.5.dist-info → gptmed-0.4.0.dist-info}/entry_points.txt +0 -0
- {gptmed-0.3.5.dist-info → gptmed-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {gptmed-0.3.5.dist-info → gptmed-0.4.0.dist-info}/top_level.txt +0 -0
gptmed/__init__.py
CHANGED
@@ -3,7 +3,14 @@ GptMed: A lightweight GPT-based language model framework
 
 A domain-agnostic framework for training custom question-answering models.
 
 Train your own GPT model on any Q&A dataset - medical, technical support,
-education, or any other domain.
+education, legal, customer service, or any other domain.
+
+Key Features:
+- Simple 3-step training: config → train → generate
+- Built-in training observability with loss curves and metrics
+- Flexible model sizes (tiny, small, medium)
+- Device-agnostic (CPU/CUDA with auto-detection)
+- XAI-ready architecture for model interpretability
 
 Quick Start:
 >>> import gptmed
@@ -13,7 +20,7 @@ Quick Start:
 >>>
 >>> # 2. Edit my_config.yaml with your settings
 >>>
->>> # 3. Train your model
+>>> # 3. Train your model (with automatic metrics tracking)
 >>> results = gptmed.train_from_config('my_config.yaml')
 >>>
 >>> # 4. Generate answers
@@ -23,6 +30,16 @@ Quick Start:
 ... prompt='Your question here?'
 ... )
 
+Observability (v0.4.0+):
+>>> from gptmed import MetricsTracker, EarlyStoppingCallback
+>>>
+>>> # Training automatically tracks metrics and generates reports:
+>>> # - Loss curves (train/val)
+>>> # - Gradient norms
+>>> # - Learning rate schedule
+>>> # - Perplexity
+>>> # - Training health checks
+
 Advanced Usage:
 >>> from gptmed.model.architecture import GPTTransformer
 >>> from gptmed.model.configs.model_config import get_small_config
@@ -32,7 +49,7 @@ Advanced Usage:
 >>> model = GPTTransformer(config)
 """
 
-__version__ = "0.3.5"
+__version__ = "0.4.0"
 __author__ = "Sanjog Sigdel"
 __email__ = "sigdelsanjog@gmail.com"
 
@@ -47,6 +64,16 @@ from gptmed.api import (
 from gptmed.model.architecture import GPTTransformer
 from gptmed.model.configs.model_config import ModelConfig, get_small_config, get_tiny_config
 
+# Observability module
+from gptmed.observability import (
+    TrainingObserver,
+    ObserverManager,
+    MetricsTracker,
+    ConsoleCallback,
+    JSONLoggerCallback,
+    EarlyStoppingCallback,
+)
+
 __all__ = [
     # Simple API
     "create_config",
@@ -57,4 +84,11 @@ __all__ = [
     "ModelConfig",
     "get_small_config",
     "get_tiny_config",
+    # Observability
+    "TrainingObserver",
+    "ObserverManager",
+    "MetricsTracker",
+    "ConsoleCallback",
+    "JSONLoggerCallback",
+    "EarlyStoppingCallback",
+]
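The docstring and export changes above define the new v0.4.0 public surface. As a hedged sketch of how those pieces fit together — only `train_from_config`'s call shape is shown verbatim in this diff; the `create_config` argument and the generation call are elided in the hunks above, so they are assumptions here:

```python
# Sketch of the documented 3-step flow; file paths are placeholders.
import gptmed

gptmed.create_config('my_config.yaml')                # 1. scaffold a config (signature assumed)
# 2. edit my_config.yaml with your dataset/model settings, then:
results = gptmed.train_from_config('my_config.yaml')  # 3. train; v0.4.0 tracks metrics automatically

# v0.4.0 additionally re-exports the observability classes at top level:
from gptmed import MetricsTracker, ConsoleCallback, EarlyStoppingCallback
```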
gptmed/model/__init__.py
CHANGED
@@ -1,7 +1,7 @@
 """
-
+GPTMED Model Package
 
-This package contains the GPT-based transformer architecture for
+This package contains the GPT-based transformer architecture for general purpose QA.
 """
 
 from gptmed.model.architecture import GPTTransformer
gptmed/observability/__init__.py
ADDED
@@ -0,0 +1,43 @@
+"""
+Observability Module for GptMed
+
+PURPOSE:
+Provides training observability, metrics tracking, and XAI capabilities.
+Implements Observer Pattern for decoupled monitoring.
+
+DESIGN PATTERNS:
+- Observer Pattern: Trainer emits events, observers react independently
+- Strategy Pattern: Swap between different logging backends
+- Open/Closed Principle: Add new observers without modifying Trainer
+
+COMPONENTS:
+- TrainingObserver: Abstract base class for all observers
+- MetricsTracker: Enhanced loss curves and training metrics
+- Callbacks: TensorBoard, W&B, Early Stopping, etc.
+
+FUTURE EXTENSIONS:
+- Attention visualization
+- Saliency maps / input attribution
+- Embedding space analysis
+- Gradient flow analysis
+"""
+
+from gptmed.observability.base import TrainingObserver, ObserverManager
+from gptmed.observability.metrics_tracker import MetricsTracker
+from gptmed.observability.callbacks import (
+    ConsoleCallback,
+    JSONLoggerCallback,
+    EarlyStoppingCallback,
+)
+
+__all__ = [
+    # Base classes
+    "TrainingObserver",
+    "ObserverManager",
+    # Metrics
+    "MetricsTracker",
+    # Callbacks
+    "ConsoleCallback",
+    "JSONLoggerCallback",
+    "EarlyStoppingCallback",
+]
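Since this new module is the package's single entry point for observability, a minimal composition sketch may help. It leans on the `ObserverManager` docstring that appears in `base.py` below (where the `log_dir` keyword is shown) and assumes the callback constructors accept no required arguments, since `callbacks.py` is not part of this excerpt:

```python
from gptmed.observability import (
    ConsoleCallback,
    JSONLoggerCallback,
    MetricsTracker,
    ObserverManager,
)

# Each observer handles one concern; the trainer only ever talks to the manager.
manager = ObserverManager()
manager.add(ConsoleCallback()).add(JSONLoggerCallback())  # add() returns self, so calls chain
manager.add(MetricsTracker(log_dir='logs'))               # log_dir per base.py's docstring example
```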
gptmed/observability/base.py
ADDED
@@ -0,0 +1,369 @@
+"""
+Base Observer Classes
+
+PURPOSE:
+Define abstract interfaces for training observation.
+Implements Observer Pattern for decoupled monitoring.
+
+DESIGN PATTERNS:
+- Observer Pattern: Subjects (Trainer) notify observers of state changes
+- Template Method: Base class defines interface, subclasses implement
+- Dependency Inversion: Trainer depends on abstraction, not concrete observers
+
+WHY OBSERVER PATTERN:
+1. Decoupling: Trainer doesn't know about logging implementations
+2. Extensibility: Add new observers without modifying Trainer
+3. Single Responsibility: Each observer handles one concern
+4. Open/Closed: Open for extension, closed for modification
+"""
+
+from abc import ABC, abstractmethod
+from typing import Dict, Any, List, Optional
+from dataclasses import dataclass, field
+from enum import Enum
+
+
+class TrainingEvent(Enum):
+    """Training lifecycle events that observers can subscribe to."""
+    TRAIN_START = "on_train_start"
+    TRAIN_END = "on_train_end"
+    EPOCH_START = "on_epoch_start"
+    EPOCH_END = "on_epoch_end"
+    STEP = "on_step"
+    VALIDATION = "on_validation"
+    CHECKPOINT = "on_checkpoint"
+    GRADIENT_COMPUTED = "on_gradient_computed"
+
+
+@dataclass
+class StepMetrics:
+    """
+    Metrics collected at each training step.
+
+    Using dataclass for:
+    - Type safety
+    - Clear documentation of expected metrics
+    - Easy serialization
+    """
+    step: int
+    loss: float
+    learning_rate: float
+    grad_norm: float
+    batch_size: int
+    seq_len: int
+    tokens_per_sec: float = 0.0
+    perplexity: float = 0.0
+
+    def __post_init__(self):
+        """Compute derived metrics."""
+        if self.perplexity == 0.0 and self.loss > 0:
+            import math
+            self.perplexity = math.exp(min(self.loss, 100))  # Cap to avoid overflow
+
+    def to_dict(self) -> Dict[str, float]:
+        """Convert to dictionary for logging."""
+        return {
+            "step": self.step,
+            "loss": self.loss,
+            "learning_rate": self.learning_rate,
+            "grad_norm": self.grad_norm,
+            "batch_size": self.batch_size,
+            "seq_len": self.seq_len,
+            "tokens_per_sec": self.tokens_per_sec,
+            "perplexity": self.perplexity,
+        }
+
+
+@dataclass
+class ValidationMetrics:
+    """Metrics collected during validation."""
+    step: int
+    val_loss: float
+    val_perplexity: float = 0.0
+
+    def __post_init__(self):
+        """Compute derived metrics."""
+        if self.val_perplexity == 0.0 and self.val_loss > 0:
+            import math
+            self.val_perplexity = math.exp(min(self.val_loss, 100))
+
+    def to_dict(self) -> Dict[str, float]:
+        """Convert to dictionary for logging."""
+        return {
+            "step": self.step,
+            "val_loss": self.val_loss,
+            "val_perplexity": self.val_perplexity,
+        }
+
+
+@dataclass
+class GradientMetrics:
+    """
+    Gradient statistics for observability.
+
+    Used for detecting:
+    - Vanishing gradients (norm → 0)
+    - Exploding gradients (norm → ∞)
+    - Dead neurons (high zero fraction)
+    """
+    step: int
+    total_norm: float
+    layer_norms: Dict[str, float] = field(default_factory=dict)
+    max_grad: float = 0.0
+    min_grad: float = 0.0
+    zero_fraction: float = 0.0
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary for logging."""
+        return {
+            "step": self.step,
+            "total_norm": self.total_norm,
+            "layer_norms": self.layer_norms,
+            "max_grad": self.max_grad,
+            "min_grad": self.min_grad,
+            "zero_fraction": self.zero_fraction,
+        }
+
+
+class TrainingObserver(ABC):
+    """
+    Abstract base class for training observers.
+
+    Implements Observer Pattern - receives notifications from Trainer
+    without Trainer knowing the concrete implementation.
+
+    Lifecycle:
+    on_train_start → [on_epoch_start → [on_step]* → on_validation → on_epoch_end]* → on_train_end
+
+    Example:
+    >>> class MyObserver(TrainingObserver):
+    ...     def on_step(self, metrics: StepMetrics) -> None:
+    ...         print(f"Loss: {metrics.loss}")
+    ...
+    >>> trainer.add_observer(MyObserver())
+    """
+
+    def __init__(self, name: str = None):
+        """
+        Initialize observer.
+
+        Args:
+            name: Human-readable name for this observer
+        """
+        self.name = name or self.__class__.__name__
+        self._enabled = True
+
+    @property
+    def enabled(self) -> bool:
+        """Whether this observer is active."""
+        return self._enabled
+
+    def enable(self) -> None:
+        """Enable this observer."""
+        self._enabled = True
+
+    def disable(self) -> None:
+        """Disable this observer."""
+        self._enabled = False
+
+    # Required methods (must implement)
+
+    @abstractmethod
+    def on_train_start(self, config: Dict[str, Any]) -> None:
+        """
+        Called when training begins.
+
+        Args:
+            config: Training configuration dictionary
+        """
+        pass
+
+    @abstractmethod
+    def on_step(self, metrics: StepMetrics) -> None:
+        """
+        Called after each training step.
+
+        Args:
+            metrics: Step metrics (loss, grad_norm, etc.)
+        """
+        pass
+
+    @abstractmethod
+    def on_validation(self, metrics: ValidationMetrics) -> None:
+        """
+        Called after validation.
+
+        Args:
+            metrics: Validation metrics (val_loss, val_perplexity)
+        """
+        pass
+
+    @abstractmethod
+    def on_train_end(self, final_metrics: Dict[str, Any]) -> None:
+        """
+        Called when training completes.
+
+        Args:
+            final_metrics: Final training summary
+        """
+        pass
+
+    # Optional methods (override if needed)
+
+    def on_epoch_start(self, epoch: int) -> None:
+        """
+        Called at start of each epoch.
+
+        Args:
+            epoch: Current epoch number (0-indexed)
+        """
+        pass
+
+    def on_epoch_end(self, epoch: int, metrics: Dict[str, float]) -> None:
+        """
+        Called at end of each epoch.
+
+        Args:
+            epoch: Current epoch number
+            metrics: Epoch summary metrics
+        """
+        pass
+
+    def on_checkpoint(self, step: int, checkpoint_path: str) -> None:
+        """
+        Called when a checkpoint is saved.
+
+        Args:
+            step: Training step
+            checkpoint_path: Path to saved checkpoint
+        """
+        pass
+
+    def on_gradient_computed(self, metrics: GradientMetrics) -> None:
+        """
+        Called after gradients are computed (before optimizer step).
+
+        Useful for gradient flow analysis.
+
+        Args:
+            metrics: Gradient statistics
+        """
+        pass
+
+
+class ObserverManager:
+    """
+    Manages multiple observers and dispatches events.
+
+    Implements Composite pattern - treats collection of observers uniformly.
+
+    Example:
+    >>> manager = ObserverManager()
+    >>> manager.add(ConsoleCallback())
+    >>> manager.add(MetricsTracker(log_dir='logs'))
+    >>> manager.notify_step(step_metrics)
+    """
+
+    def __init__(self):
+        """Initialize empty observer list."""
+        self._observers: List[TrainingObserver] = []
+
+    def add(self, observer: TrainingObserver) -> 'ObserverManager':
+        """
+        Add an observer.
+
+        Args:
+            observer: Observer to add
+
+        Returns:
+            Self for method chaining
+        """
+        self._observers.append(observer)
+        return self
+
+    def remove(self, observer: TrainingObserver) -> bool:
+        """
+        Remove an observer.
+
+        Args:
+            observer: Observer to remove
+
+        Returns:
+            True if removed, False if not found
+        """
+        try:
+            self._observers.remove(observer)
+            return True
+        except ValueError:
+            return False
+
+    def get_observer(self, name: str) -> Optional[TrainingObserver]:
+        """
+        Get observer by name.
+
+        Args:
+            name: Observer name
+
+        Returns:
+            Observer if found, None otherwise
+        """
+        for obs in self._observers:
+            if obs.name == name:
+                return obs
+        return None
+
+    @property
+    def observers(self) -> List[TrainingObserver]:
+        """Get list of all observers."""
+        return self._observers.copy()
+
+    def _notify(self, event: str, *args, **kwargs) -> None:
+        """
+        Dispatch event to all enabled observers.
+
+        Args:
+            event: Event method name
+            *args, **kwargs: Event arguments
+        """
+        for observer in self._observers:
+            if observer.enabled:
+                handler = getattr(observer, event, None)
+                if handler and callable(handler):
+                    try:
+                        handler(*args, **kwargs)
+                    except Exception as e:
+                        print(f"Warning: Observer {observer.name} failed on {event}: {e}")
+
+    # Convenience methods for each event type
+
+    def notify_train_start(self, config: Dict[str, Any]) -> None:
+        """Notify all observers of training start."""
+        self._notify('on_train_start', config)
+
+    def notify_train_end(self, final_metrics: Dict[str, Any]) -> None:
+        """Notify all observers of training end."""
+        self._notify('on_train_end', final_metrics)
+
+    def notify_epoch_start(self, epoch: int) -> None:
+        """Notify all observers of epoch start."""
+        self._notify('on_epoch_start', epoch)
+
+    def notify_epoch_end(self, epoch: int, metrics: Dict[str, float]) -> None:
+        """Notify all observers of epoch end."""
+        self._notify('on_epoch_end', epoch, metrics)
+
+    def notify_step(self, metrics: StepMetrics) -> None:
+        """Notify all observers of training step."""
+        self._notify('on_step', metrics)
+
+    def notify_validation(self, metrics: ValidationMetrics) -> None:
+        """Notify all observers of validation."""
+        self._notify('on_validation', metrics)
+
+    def notify_checkpoint(self, step: int, checkpoint_path: str) -> None:
+        """Notify all observers of checkpoint save."""
+        self._notify('on_checkpoint', step, checkpoint_path)
+
+    def notify_gradient(self, metrics: GradientMetrics) -> None:
+        """Notify all observers of gradient computation."""
+        self._notify('on_gradient_computed', metrics)