trainpulse 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- trainpulse/__init__.py +33 -0
- trainpulse/_types.py +105 -0
- trainpulse/callbacks.py +175 -0
- trainpulse/cli.py +125 -0
- trainpulse/detectors.py +216 -0
- trainpulse/early_stopping.py +162 -0
- trainpulse/monitor.py +223 -0
- trainpulse/py.typed +0 -0
- trainpulse/report.py +156 -0
- trainpulse/wandb_callback.py +89 -0
- trainpulse-0.2.0.dist-info/METADATA +241 -0
- trainpulse-0.2.0.dist-info/RECORD +14 -0
- trainpulse-0.2.0.dist-info/WHEEL +4 -0
- trainpulse-0.2.0.dist-info/licenses/LICENSE +177 -0
trainpulse/__init__.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""trainpulse — lightweight training health monitor."""
|
|
2
|
+
|
|
3
|
+
__version__ = "0.2.0"
|
|
4
|
+
|
|
5
|
+
from trainpulse._types import (
|
|
6
|
+
Alert,
|
|
7
|
+
AlertSeverity,
|
|
8
|
+
MetricSnapshot,
|
|
9
|
+
MetricType,
|
|
10
|
+
MonitorConfig,
|
|
11
|
+
TrainingReport,
|
|
12
|
+
TrainpulseError,
|
|
13
|
+
)
|
|
14
|
+
from trainpulse.callbacks import TrainingCallback
|
|
15
|
+
from trainpulse.early_stopping import EarlyStopping, EarlyStopResult, recommend_patience
|
|
16
|
+
from trainpulse.monitor import Monitor
|
|
17
|
+
from trainpulse.wandb_callback import WandbCallback
|
|
18
|
+
|
|
19
|
+
__all__ = [
|
|
20
|
+
"Alert",
|
|
21
|
+
"AlertSeverity",
|
|
22
|
+
"EarlyStopping",
|
|
23
|
+
"EarlyStopResult",
|
|
24
|
+
"MetricSnapshot",
|
|
25
|
+
"MetricType",
|
|
26
|
+
"Monitor",
|
|
27
|
+
"MonitorConfig",
|
|
28
|
+
"TrainingCallback",
|
|
29
|
+
"TrainingReport",
|
|
30
|
+
"TrainpulseError",
|
|
31
|
+
"WandbCallback",
|
|
32
|
+
"recommend_patience",
|
|
33
|
+
]
|
trainpulse/_types.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
"""Core types for trainpulse."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class AlertSeverity(str, Enum):
    """Severity level attached to an :class:`Alert`.

    Inherits from ``str`` so members compare equal to their plain string
    values (e.g. ``AlertSeverity.WARNING == "warning"``) and serialize
    cleanly to JSON.
    """

    INFO = "info"
    WARNING = "warning"
    CRITICAL = "critical"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class MetricType(str, Enum):
    """Category of a recorded training metric.

    Inherits from ``str`` for painless JSON serialization and equality
    with plain strings. ``CUSTOM`` is the catch-all for user-defined
    metrics (and is the default on :class:`MetricSnapshot`).
    """

    LOSS = "loss"
    GRADIENT_NORM = "gradient_norm"
    LEARNING_RATE = "learning_rate"
    STEP_TIME = "step_time"
    MEMORY_USED = "memory_used"
    CUSTOM = "custom"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class MetricSnapshot:
    """A single metric recording at a given step."""

    # Training step at which the value was observed.
    step: int
    # Metric name, e.g. "loss" or "grad_norm".
    name: str
    # The observed scalar value.
    value: float
    # Category of the metric; defaults to CUSTOM for user-defined metrics.
    metric_type: MetricType = MetricType.CUSTOM
    # Free-form extra context supplied by the caller (never mutated here).
    metadata: Dict[str, Any] = field(default_factory=dict)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class Alert:
    """An alert triggered by a detector."""

    # Training step at which the anomaly was observed.
    step: int
    # How serious the finding is (info / warning / critical).
    severity: AlertSeverity
    # Identifier of the detector that fired, e.g. "loss_spike".
    detector: str
    # Human-readable description of what happened.
    message: str
    # Name of the offending metric ("" when not tied to one metric).
    metric_name: str = ""
    # Value of the offending metric (0.0 when not applicable).
    metric_value: float = 0.0

    def __str__(self) -> str:
        # e.g. "[CRITICAL] Step 120: loss is NaN"
        return f"[{self.severity.value.upper()}] Step {self.step}: {self.message}"
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass
class MonitorConfig:
    """Configuration for the training monitor.

    Thresholds map one-to-one onto the detectors in
    ``trainpulse.detectors``; override per run as needed.
    """

    # Loss spike detection
    loss_spike_threshold: float = 5.0  # Alert when loss exceeds this multiple of the rolling average
    loss_spike_window: int = 50  # Rolling window size (in recorded loss values)

    # Gradient monitoring
    grad_norm_threshold: float = 100.0  # Norms above this are treated as gradient explosion
    grad_vanish_threshold: float = 1e-7  # Non-zero norms below this are treated as vanishing

    # NaN/Inf detection
    check_nan: bool = True  # When True, metrics are checked for NaN/Inf values

    # Learning rate monitoring
    lr_change_threshold: float = 10.0  # Max acceptable LR ratio change per step (either direction)

    # Step time monitoring
    step_time_spike_threshold: float = 3.0  # Alert when a step takes this multiple of the rolling average
    step_time_window: int = 20  # Rolling window size (in step durations)

    # Plateau detection
    plateau_patience: int = 100  # Steps without improvement before a plateau alert
    plateau_min_delta: float = 1e-5  # Minimum loss decrease that counts as improvement

    # General
    log_interval: int = 1  # Record every N steps
    # NOTE(review): presumably invoked once per alert by the Monitor — confirm in monitor.py.
    alert_callbacks: List[Callable[[Alert], None]] = field(default_factory=list)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@dataclass
class TrainingReport:
    """Summary report of training health."""

    # Number of steps observed during the run.
    total_steps: int
    # Every alert raised by the detectors, in firing order.
    alerts: List[Alert]
    # metric_name -> {min, max, mean, last}
    metrics_summary: Dict[str, Dict[str, float]]
    # 0.0 (terrible) to 1.0 (perfect).
    health_score: float

    @property
    def n_warnings(self) -> int:
        """Number of WARNING-severity alerts."""
        return len([a for a in self.alerts if a.severity == AlertSeverity.WARNING])

    @property
    def n_critical(self) -> int:
        """Number of CRITICAL-severity alerts."""
        return len([a for a in self.alerts if a.severity == AlertSeverity.CRITICAL])

    @property
    def is_healthy(self) -> bool:
        """A run counts as healthy exactly when no critical alert fired."""
        return self.n_critical == 0
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class TrainpulseError(Exception):
    """Base exception for all errors raised by the trainpulse package."""
|
trainpulse/callbacks.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""Framework integration callbacks."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Optional
|
|
6
|
+
|
|
7
|
+
from trainpulse._types import MonitorConfig
|
|
8
|
+
from trainpulse.monitor import Monitor
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TrainingCallback:
    """Generic callback that wraps a Monitor for any training loop.

    Usage::

        cb = TrainingCallback()
        for step in range(num_steps):
            cb.on_step_begin(step)
            loss = train_step()
            grad_norm = get_grad_norm()
            cb.on_step_end(step, loss=loss, grad_norm=grad_norm, lr=optimizer.lr)
        report = cb.report()
    """

    def __init__(self, config: Optional[MonitorConfig] = None) -> None:
        self.monitor = Monitor(config)

    def on_step_begin(self, step: int) -> None:
        """Mark the start of a training step (starts the step timer)."""
        self.monitor.step_start()

    def on_step_end(
        self,
        step: int,
        loss: Optional[float] = None,
        grad_norm: Optional[float] = None,
        lr: Optional[float] = None,
        **extra_metrics: float,
    ) -> None:
        """Close the step timer, then log whichever metrics were supplied.

        Well-known metrics (loss, grad_norm, learning rate) are logged
        first, followed by any keyword extras under their given names.
        """
        self.monitor.step_end(step)
        well_known = (("loss", loss), ("grad_norm", grad_norm), ("learning_rate", lr))
        for metric_name, metric_value in well_known:
            if metric_value is not None:
                self.monitor.log(metric_name, step, metric_value)
        for metric_name, metric_value in extra_metrics.items():
            self.monitor.log(metric_name, step, metric_value)

    def report(self) -> Any:
        """Return the wrapped monitor's summary report."""
        return self.monitor.report()
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def make_pytorch_hooks(
    model: Any,
    monitor: Monitor,
) -> list:
    """Register backward hooks on a PyTorch model to track gradient norms.

    A full-backward hook is attached to every *leaf* module (modules with
    no children) so that no parameter is counted twice. On each backward
    pass, every hooked module logs the L2 norm of its **own** parameters'
    gradients under the metric name ``"grad_norm"`` — i.e. one log entry
    per leaf module per backward pass, not a single model-wide norm.

    Returns a single-element list containing a manager object:
    call ``.remove()`` on it to detach all hooks, and call
    ``.set_step(step)`` each iteration so logged norms carry the right
    step number (the step counter is NOT advanced automatically).

    Usage::

        monitor = Monitor()
        hooks = make_pytorch_hooks(model, monitor)
        # inside the training loop: hooks[0].set_step(step)
        # ... training loop ...
        for h in hooks:
            h.remove()

    Raises:
        ImportError: if PyTorch is not installed.
    """
    try:
        import torch  # type: ignore[import-untyped]  # noqa: F401 — presence check only
    except ImportError as exc:
        # Chain the original error so the true import failure stays visible.
        raise ImportError("PyTorch is required: pip install trainpulse[torch]") from exc

    handles = []
    # Mutable cell shared with the hook closure; updated via set_step().
    _step_counter = {"step": 0}

    def _grad_hook(module: Any, grad_input: Any, grad_output: Any) -> None:
        # L2 norm over this module's own parameter gradients.
        # NOTE(review): p.grad may not yet be accumulated when the module
        # hook fires, in which case the logged norm is 0.0 — confirm timing.
        total_norm = 0.0
        for p in module.parameters():
            if p.grad is not None:
                total_norm += p.grad.data.norm(2).item() ** 2
        total_norm = total_norm**0.5
        monitor.log("grad_norm", _step_counter["step"], total_norm)

    for module in model.modules():
        # Only hook leaf modules to avoid double counting parameters that
        # also appear via parent containers.
        if len(list(module.children())) == 0:
            h = module.register_full_backward_hook(_grad_hook)
            handles.append(h)

    class _HookManager:
        """Bundles all hook handles plus the shared step counter."""

        def __init__(self, handles: list, counter: dict) -> None:
            self._handles = handles
            self._counter = counter

        def set_step(self, step: int) -> None:
            """Record the current training step for subsequent hook logs."""
            self._counter["step"] = step

        def remove(self) -> None:
            """Detach every registered backward hook."""
            for h in self._handles:
                h.remove()

    manager = _HookManager(handles, _step_counter)
    return [manager]
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def make_hf_callback(
    config: Optional[MonitorConfig] = None,
) -> Any:
    """Create a HuggingFace Trainer callback backed by a Monitor.

    The returned object subclasses ``transformers.TrainerCallback``: it
    starts step timing in ``on_step_begin``, records ``loss`` /
    ``learning_rate`` / ``grad_norm`` from each ``on_log`` payload, and
    closes step timing in ``on_step_end``. The underlying monitor is
    exposed as ``trainpulse_monitor`` for pulling a report afterwards.

    Usage::

        from transformers import Trainer
        from trainpulse.callbacks import make_hf_callback

        cb = make_hf_callback()
        trainer = Trainer(..., callbacks=[cb])
        trainer.train()
        report = cb.trainpulse_monitor.report()

    Raises:
        ImportError: if ``transformers`` is not installed.
    """
    try:
        from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments  # type: ignore[import-untyped]
    except ImportError as exc:
        # Chain the original error so the true import failure stays visible.
        raise ImportError(
            "HuggingFace transformers is required: pip install transformers"
        ) from exc

    # One monitor shared by all callback methods via closure.
    monitor = Monitor(config)

    class _TrainpulseCallback(TrainerCallback):
        def __init__(self) -> None:
            # Expose the monitor so callers can pull a report after training.
            self.trainpulse_monitor = monitor

        def on_step_begin(
            self,
            args: TrainingArguments,
            state: TrainerState,
            control: TrainerControl,
            **kwargs: Any,
        ) -> None:
            monitor.step_start()

        def on_log(
            self,
            args: TrainingArguments,
            state: TrainerState,
            control: TrainerControl,
            logs: Optional[dict] = None,
            **kwargs: Any,
        ) -> None:
            if logs is None:
                return
            step = state.global_step
            # Forward only the keys trainpulse's detectors understand.
            if "loss" in logs:
                monitor.log("loss", step, logs["loss"])
            if "learning_rate" in logs:
                monitor.log("learning_rate", step, logs["learning_rate"])
            if "grad_norm" in logs:
                monitor.log("grad_norm", step, logs["grad_norm"])

        def on_step_end(
            self,
            args: TrainingArguments,
            state: TrainerState,
            control: TrainerControl,
            **kwargs: Any,
        ) -> None:
            monitor.step_end(state.global_step)

    return _TrainpulseCallback()
|
trainpulse/cli.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
"""CLI for trainpulse."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import sys
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Optional
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _build_cli():  # type: ignore[no-untyped-def]
    """Build the CLI. Deferred import so click/rich are optional.

    Returns the click command group; raises SystemExit with an install
    hint when click is not available.
    """
    try:
        import click
    except ImportError:
        raise SystemExit(
            "CLI dependencies required: pip install trainpulse[cli]"
        )

    @click.group()
    @click.version_option(package_name="trainpulse")
    def cli() -> None:
        """trainpulse — lightweight training health monitor."""

    @cli.command()
    @click.argument("log_file", type=click.Path(exists=True))
    @click.option("--json-out", "-o", type=click.Path(), default=None, help="Save report as JSON.")
    @click.option("--loss-key", default="loss", help="Key for loss in log entries.")
    @click.option("--grad-key", default="grad_norm", help="Key for gradient norm.")
    @click.option("--lr-key", default="learning_rate", help="Key for learning rate.")
    @click.option("--step-key", default="step", help="Key for step number.")
    def analyze(
        log_file: str,
        json_out: Optional[str],
        loss_key: str,
        grad_key: str,
        lr_key: str,
        step_key: str,
    ) -> None:
        """Analyze a training log file (JSONL format).

        Each line should be a JSON object with at least a step and loss field.
        """
        # Deferred imports keep `trainpulse.cli` importable without pulling
        # the whole package at module load.
        from trainpulse.monitor import Monitor
        from trainpulse.report import format_report_rich, format_report_text, save_json

        monitor = Monitor()
        path = Path(log_file)

        # n_lines counts only successfully parsed entries; it doubles as the
        # fallback step number for entries missing the step key.
        n_lines = 0
        for line in path.read_text().splitlines():
            line = line.strip()
            if not line:
                continue
            # Malformed lines are skipped silently (best-effort parsing).
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                continue

            step = entry.get(step_key, n_lines)

            # Feed recognized metrics into the monitor under canonical names.
            if loss_key in entry:
                monitor.log("loss", step, float(entry[loss_key]))
            if grad_key in entry:
                monitor.log("grad_norm", step, float(entry[grad_key]))
            if lr_key in entry:
                monitor.log("learning_rate", step, float(entry[lr_key]))

            n_lines += 1

        if n_lines == 0:
            click.echo("No log entries found.", err=True)
            sys.exit(1)

        report = monitor.report()

        # Prefer the rich rendering; fall back to plain text if rich is
        # unavailable or rendering fails for any reason.
        try:
            output = format_report_rich(report)
        except Exception:
            output = format_report_text(report)
        click.echo(output)

        if json_out:
            save_json(report, json_out)
            click.echo(f"Report saved to {json_out}")

    @cli.command()
    @click.argument("report_file", type=click.Path(exists=True))
    def show(report_file: str) -> None:
        """Display a previously saved JSON report."""
        from trainpulse._types import Alert, AlertSeverity, TrainingReport
        from trainpulse.report import format_report_rich, format_report_text, load_json

        # Rehydrate the dataclasses from the saved JSON structure.
        data = load_json(report_file)
        alerts = [
            Alert(
                step=a["step"],
                severity=AlertSeverity(a["severity"]),
                detector=a["detector"],
                message=a["message"],
                metric_name=a.get("metric_name", ""),
                metric_value=a.get("metric_value", 0.0),
            )
            for a in data.get("alerts", [])
        ]
        report = TrainingReport(
            total_steps=data["total_steps"],
            alerts=alerts,
            metrics_summary=data.get("metrics_summary", {}),
            health_score=data.get("health_score", 1.0),
        )

        # Same rich-then-text fallback as `analyze`.
        try:
            output = format_report_rich(report)
        except Exception:
            output = format_report_text(report)
        click.echo(output)

    return cli
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# Build the CLI eagerly so `trainpulse.cli:cli` works as an entry point.
# NOTE(review): if click is missing, this raises SystemExit on *import* of
# trainpulse.cli, not just on CLI use — consider deferring. TODO confirm
# this is intentional.
cli = _build_cli()

if __name__ == "__main__":
    cli()
|
trainpulse/detectors.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
"""Anomaly detectors for training metrics."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import math
|
|
6
|
+
from typing import List, Optional, Sequence
|
|
7
|
+
|
|
8
|
+
from trainpulse._types import Alert, AlertSeverity, MetricType
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class RollingWindow:
    """Fixed-size rolling window of float values.

    Backed by a ``deque`` with ``maxlen`` so both append and eviction are
    O(1); the previous list-based version paid O(n) per ``pop(0)``.
    """

    def __init__(self, size: int) -> None:
        # Local import keeps the module's top-level dependencies unchanged.
        from collections import deque

        # A size below 1 would make the window meaningless; clamp it.
        self._size = max(size, 1)
        self._values: "deque[float]" = deque(maxlen=self._size)

    def add(self, value: float) -> None:
        """Append *value*; the oldest entry is evicted once full."""
        # maxlen makes the deque drop the oldest element automatically.
        self._values.append(value)

    @property
    def values(self) -> List[float]:
        """Snapshot of the current window contents, oldest first."""
        return list(self._values)

    @property
    def mean(self) -> float:
        """Arithmetic mean of the window (0.0 when empty)."""
        if not self._values:
            return 0.0
        return sum(self._values) / len(self._values)

    @property
    def std(self) -> float:
        """Population standard deviation (0.0 with fewer than 2 values)."""
        if len(self._values) < 2:
            return 0.0
        m = self.mean
        return math.sqrt(sum((v - m) ** 2 for v in self._values) / len(self._values))

    @property
    def is_full(self) -> bool:
        """True once the window holds ``size`` values."""
        return len(self._values) >= self._size

    def __len__(self) -> int:
        return len(self._values)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class NaNDetector:
    """Flag metrics whose value is NaN or infinite."""

    def check(self, step: int, name: str, value: float) -> Optional[Alert]:
        """Return a CRITICAL alert when *value* is non-finite, else None."""
        if math.isfinite(value):
            return None
        # Non-finite: distinguish NaN from +/-Inf for the message.
        kind = "NaN" if math.isnan(value) else "Inf"
        return Alert(
            step=step,
            severity=AlertSeverity.CRITICAL,
            detector="nan_detector",
            message=f"{name} is {kind}",
            metric_name=name,
            metric_value=value,
        )
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class LossSpikeDetector:
    """Detect sudden spikes in loss values."""

    def __init__(self, threshold: float = 5.0, window_size: int = 50) -> None:
        self._threshold = threshold
        self._window = RollingWindow(window_size)

    def check(self, step: int, value: float) -> Optional[Alert]:
        """Compare *value* to the rolling average; warn on large spikes.

        No alert is produced until the window has filled.
        """
        alert: Optional[Alert] = None
        if self._window.is_full:
            avg = self._window.mean
            if avg > 0 and value > avg * self._threshold:
                alert = Alert(
                    step=step,
                    severity=AlertSeverity.WARNING,
                    detector="loss_spike",
                    message=f"Loss spike: {value:.4f} ({value/avg:.1f}x rolling avg {avg:.4f})",
                    metric_name="loss",
                    metric_value=value,
                )
        # The spiking value itself still enters the window (matching the
        # original behavior), so repeated spikes raise the baseline.
        self._window.add(value)
        return alert
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class GradientDetector:
    """Detect gradient explosion and vanishing."""

    def __init__(
        self,
        explosion_threshold: float = 100.0,
        vanish_threshold: float = 1e-7,
    ) -> None:
        self._explosion = explosion_threshold
        self._vanish = vanish_threshold

    def check(self, step: int, grad_norm: float) -> Optional[Alert]:
        """Classify *grad_norm* as exploding, vanishing, or fine.

        An exact zero norm is NOT flagged as vanishing (it may simply mean
        frozen parameters).
        """
        exploding = grad_norm > self._explosion
        vanishing = 0 < grad_norm < self._vanish
        if not (exploding or vanishing):
            return None
        if exploding:
            return Alert(
                step=step,
                severity=AlertSeverity.CRITICAL,
                detector="gradient",
                message=f"Gradient explosion: norm={grad_norm:.4f} (threshold={self._explosion})",
                metric_name="gradient_norm",
                metric_value=grad_norm,
            )
        return Alert(
            step=step,
            severity=AlertSeverity.WARNING,
            detector="gradient",
            message=f"Vanishing gradient: norm={grad_norm:.2e} (threshold={self._vanish:.2e})",
            metric_name="gradient_norm",
            metric_value=grad_norm,
        )
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class LRDetector:
    """Detect suspicious learning rate changes."""

    def __init__(self, change_threshold: float = 10.0) -> None:
        self._threshold = change_threshold
        self._prev_lr: Optional[float] = None

    def check(self, step: int, lr: float) -> Optional[Alert]:
        """Warn when the LR jumps by more than the threshold ratio.

        The first observation, and any non-positive LR on either side of
        the comparison, produce no alert.
        """
        # Remember the new LR before deciding; `previous` keeps the old one.
        previous, self._prev_lr = self._prev_lr, lr
        if previous is None or previous <= 0 or lr <= 0:
            return None
        # Symmetric ratio: catches both sudden increases and decreases.
        ratio = max(lr / previous, previous / lr)
        if ratio <= self._threshold:
            return None
        return Alert(
            step=step,
            severity=AlertSeverity.WARNING,
            detector="learning_rate",
            message=f"LR changed {ratio:.1f}x in one step ({previous:.2e} → {lr:.2e})",
            metric_name="learning_rate",
            metric_value=lr,
        )
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class PlateauDetector:
    """Detect loss plateaus (no improvement for N steps)."""

    def __init__(self, patience: int = 100, min_delta: float = 1e-5) -> None:
        self._patience = patience
        self._min_delta = min_delta
        self._best_loss: Optional[float] = None
        self._steps_without_improvement = 0
        self._alerted = False

    def check(self, step: int, loss: float) -> Optional[Alert]:
        """Track the best loss; warn once per plateau after `patience` stalls."""
        best = self._best_loss
        if best is None or loss < best - self._min_delta:
            # First observation or a genuine improvement: record the new
            # best and re-arm the alert. (On the very first call the
            # counter/flag resets are no-ops.)
            self._best_loss = loss
            self._steps_without_improvement = 0
            self._alerted = False
            return None

        self._steps_without_improvement += 1
        # Fire at most once per plateau; re-armed only by an improvement.
        if self._alerted or self._steps_without_improvement < self._patience:
            return None
        self._alerted = True
        return Alert(
            step=step,
            severity=AlertSeverity.WARNING,
            detector="plateau",
            message=f"Loss plateau: no improvement for {self._steps_without_improvement} steps (best={self._best_loss:.6f})",
            metric_name="loss",
            metric_value=loss,
        )
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class StepTimeDetector:
    """Detect unusually slow training steps."""

    def __init__(self, threshold: float = 3.0, window_size: int = 20) -> None:
        self._threshold = threshold
        self._window = RollingWindow(window_size)

    def check(self, step: int, step_time: float) -> Optional[Alert]:
        """Compare *step_time* to the rolling average; warn on big outliers.

        No alert is produced until the window has filled.
        """
        alert: Optional[Alert] = None
        if self._window.is_full:
            avg = self._window.mean
            if avg > 0 and step_time > avg * self._threshold:
                alert = Alert(
                    step=step,
                    severity=AlertSeverity.WARNING,
                    detector="step_time",
                    message=f"Slow step: {step_time:.2f}s ({step_time/avg:.1f}x avg {avg:.2f}s)",
                    metric_name="step_time",
                    metric_value=step_time,
                )
        # The slow step still enters the window (matching the original).
        self._window.add(step_time)
        return alert
|