trainpulse 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
trainpulse/__init__.py ADDED
@@ -0,0 +1,33 @@
1
"""trainpulse — lightweight training health monitor."""

__version__ = "0.2.0"

# Public API re-exports. Everything a user should touch is importable from
# the top-level package; `__all__` below is kept sorted (case-sensitive,
# with lowercase names last) and must stay in sync with these imports.
from trainpulse._types import (
    Alert,
    AlertSeverity,
    MetricSnapshot,
    MetricType,
    MonitorConfig,
    TrainingReport,
    TrainpulseError,
)
from trainpulse.callbacks import TrainingCallback
from trainpulse.early_stopping import EarlyStopping, EarlyStopResult, recommend_patience
from trainpulse.monitor import Monitor
from trainpulse.wandb_callback import WandbCallback

__all__ = [
    "Alert",
    "AlertSeverity",
    "EarlyStopping",
    "EarlyStopResult",
    "MetricSnapshot",
    "MetricType",
    "Monitor",
    "MonitorConfig",
    "TrainingCallback",
    "TrainingReport",
    "TrainpulseError",
    "WandbCallback",
    "recommend_patience",
]
trainpulse/_types.py ADDED
@@ -0,0 +1,105 @@
1
+ """Core types for trainpulse."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from enum import Enum
7
+ from typing import Any, Callable, Dict, List, Optional
8
+
9
+
10
class AlertSeverity(str, Enum):
    """Severity level carried by an :class:`Alert`.

    Subclasses ``str`` so severities serialize naturally (e.g. to JSON)
    and compare equal to their string values.
    """

    INFO = "info"
    WARNING = "warning"
    CRITICAL = "critical"
14
+
15
+
16
class MetricType(str, Enum):
    """Broad category of a recorded training metric.

    ``CUSTOM`` is the catch-all for user-supplied metrics that don't fit
    one of the built-in categories.
    """

    LOSS = "loss"
    GRADIENT_NORM = "gradient_norm"
    LEARNING_RATE = "learning_rate"
    STEP_TIME = "step_time"
    MEMORY_USED = "memory_used"
    CUSTOM = "custom"
23
+
24
+
25
@dataclass
class MetricSnapshot:
    """A single metric recording at a given step."""

    # Training step at which the value was observed.
    step: int
    # Metric name as passed to the logger (e.g. "loss", "grad_norm").
    name: str
    # Scalar observation.
    value: float
    # Broad category; defaults to CUSTOM for user-defined metrics.
    metric_type: MetricType = MetricType.CUSTOM
    # Free-form extra context supplied by the caller; fresh dict per instance.
    metadata: Dict[str, Any] = field(default_factory=dict)
34
+
35
+
36
@dataclass
class Alert:
    """An alert triggered by a detector.

    ``metric_name``/``metric_value`` identify the offending observation and
    default to empty values for alerts not tied to a single metric.
    """

    step: int
    severity: AlertSeverity
    detector: str
    message: str
    metric_name: str = ""
    metric_value: float = 0.0

    def __str__(self) -> str:
        # Human-readable one-liner, e.g. "[WARNING] Step 12: Loss spike ...".
        label = self.severity.value.upper()
        return "[{}] Step {}: {}".format(label, self.step, self.message)
49
+
50
+
51
@dataclass
class MonitorConfig:
    """Configuration for the training monitor.

    Field defaults mirror the constructor defaults of the detectors in
    ``trainpulse.detectors`` (spike, gradient, LR, plateau, step-time).
    """

    # Loss spike detection
    loss_spike_threshold: float = 5.0  # Multiplier over rolling average
    loss_spike_window: int = 50  # Rolling window size

    # Gradient monitoring
    grad_norm_threshold: float = 100.0  # Max acceptable gradient norm
    grad_vanish_threshold: float = 1e-7  # Min acceptable (non-zero) gradient norm

    # NaN/Inf detection
    check_nan: bool = True

    # Learning rate monitoring
    lr_change_threshold: float = 10.0  # Max acceptable LR ratio change per step

    # Step time monitoring
    step_time_spike_threshold: float = 3.0  # Multiplier over rolling average
    step_time_window: int = 20

    # Plateau detection
    plateau_patience: int = 100  # Steps without improvement
    plateau_min_delta: float = 1e-5  # Minimum change to count as improvement

    # General
    log_interval: int = 1  # Record every N steps
    # Callbacks receiving each Alert — presumably invoked by Monitor as
    # alerts fire; verify against trainpulse.monitor. Fresh list per instance.
    alert_callbacks: List[Callable[[Alert], None]] = field(default_factory=list)
80
+
81
+
82
@dataclass
class TrainingReport:
    """Summary report of training health."""

    total_steps: int
    alerts: List[Alert]
    metrics_summary: Dict[str, Dict[str, float]]  # metric_name -> {min, max, mean, last}
    health_score: float  # 0.0 (terrible) to 1.0 (perfect)

    @property
    def n_warnings(self) -> int:
        """Count of WARNING-severity alerts."""
        warnings = [a for a in self.alerts if a.severity == AlertSeverity.WARNING]
        return len(warnings)

    @property
    def n_critical(self) -> int:
        """Count of CRITICAL-severity alerts."""
        critical = [a for a in self.alerts if a.severity == AlertSeverity.CRITICAL]
        return len(critical)

    @property
    def is_healthy(self) -> bool:
        """True when no critical alerts fired (warnings are tolerated)."""
        return self.n_critical == 0
102
+
103
+
104
class TrainpulseError(Exception):
    """Base exception for all trainpulse errors."""
@@ -0,0 +1,175 @@
1
+ """Framework integration callbacks."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Optional
6
+
7
+ from trainpulse._types import MonitorConfig
8
+ from trainpulse.monitor import Monitor
9
+
10
+
11
class TrainingCallback:
    """Generic callback that wraps a Monitor for any training loop.

    Usage::

        cb = TrainingCallback()
        for step in range(num_steps):
            cb.on_step_begin(step)
            loss = train_step()
            grad_norm = get_grad_norm()
            cb.on_step_end(step, loss=loss, grad_norm=grad_norm, lr=optimizer.lr)
        report = cb.report()
    """

    def __init__(self, config: Optional[MonitorConfig] = None) -> None:
        self.monitor = Monitor(config)

    def on_step_begin(self, step: int) -> None:
        # Starts step timing; the step number is taken at on_step_end instead.
        self.monitor.step_start()

    def on_step_end(
        self,
        step: int,
        loss: Optional[float] = None,
        grad_norm: Optional[float] = None,
        lr: Optional[float] = None,
        **extra_metrics: float,
    ) -> None:
        """Close the step and log whichever metrics were provided."""
        self.monitor.step_end(step)
        # Log the well-known metrics first (loss, grad_norm, learning_rate,
        # in that order), then any user-supplied extras.
        known = {"loss": loss, "grad_norm": grad_norm, "learning_rate": lr}
        for metric_name, metric_value in known.items():
            if metric_value is not None:
                self.monitor.log(metric_name, step, metric_value)
        for metric_name, metric_value in extra_metrics.items():
            self.monitor.log(metric_name, step, metric_value)

    def report(self) -> Any:
        """Return the monitor's TrainingReport."""
        return self.monitor.report()
51
+
52
+
53
def make_pytorch_hooks(
    model: Any,
    monitor: Monitor,
) -> list:
    """Register backward hooks on a PyTorch model to track gradient norms.

    Returns a list of hook handles for cleanup.

    NOTE(review): the returned list actually contains a single manager
    object, not the raw torch handles; its ``remove()`` detaches every
    registered hook, so the documented ``for h in hooks: h.remove()``
    cleanup still works. Call ``hooks[0].set_step(step)`` each iteration
    so logged norms carry the right step number.

    Usage::

        monitor = Monitor()
        hooks = make_pytorch_hooks(model, monitor)
        # ... training loop ...
        for h in hooks:
            h.remove()
    """
    try:
        import torch  # type: ignore[import-untyped]
    except ImportError:
        raise ImportError("PyTorch is required: pip install trainpulse[torch]")

    handles = []
    # Mutable cell shared with the hook closure so the caller can advance
    # the step number via _HookManager.set_step without re-registering.
    _step_counter = {"step": 0}

    def _grad_hook(module: Any, grad_input: Any, grad_output: Any) -> None:
        # L2 norm over this module's own parameters only — each leaf module
        # logs its own "grad_norm", not a whole-model norm.
        # NOTE(review): full-backward hooks fire during backward; p.grad may
        # still hold the previous accumulation at that point — confirm the
        # intended timing against torch's autograd hook docs.
        total_norm = 0.0
        for p in module.parameters():
            if p.grad is not None:
                total_norm += p.grad.data.norm(2).item() ** 2
        total_norm = total_norm**0.5
        monitor.log("grad_norm", _step_counter["step"], total_norm)

    for module in model.modules():
        # Only hook leaf modules to avoid double counting
        if len(list(module.children())) == 0:
            h = module.register_full_backward_hook(_grad_hook)
            handles.append(h)

    class _HookManager:
        """Manages hooks and step counter."""

        def __init__(self, handles: list, counter: dict) -> None:
            self._handles = handles
            self._counter = counter

        def set_step(self, step: int) -> None:
            # Advance the step number used by subsequent hook invocations.
            self._counter["step"] = step

        def remove(self) -> None:
            # Detach every registered backward hook.
            for h in self._handles:
                h.remove()

    manager = _HookManager(handles, _step_counter)
    return [manager]
107
+
108
+
109
def make_hf_callback(
    config: Optional[MonitorConfig] = None,
) -> Any:
    """Create a HuggingFace Trainer callback.

    Usage::

        from transformers import Trainer
        from trainpulse.callbacks import make_hf_callback

        cb = make_hf_callback()
        trainer = Trainer(..., callbacks=[cb])
        trainer.train()
        report = cb.trainpulse_monitor.report()

    Returns a TrainerCallback subclass instance.
    """
    # Deferred import: transformers is an optional dependency.
    try:
        from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments  # type: ignore[import-untyped]
    except ImportError:
        raise ImportError(
            "HuggingFace transformers is required: pip install transformers"
        )

    # One Monitor per callback instance, closed over by the class below and
    # also exposed as `trainpulse_monitor` for post-training reporting.
    monitor = Monitor(config)

    class _TrainpulseCallback(TrainerCallback):
        def __init__(self) -> None:
            self.trainpulse_monitor = monitor

        def on_step_begin(
            self,
            args: TrainingArguments,
            state: TrainerState,
            control: TrainerControl,
            **kwargs: Any,
        ) -> None:
            monitor.step_start()

        def on_log(
            self,
            args: TrainingArguments,
            state: TrainerState,
            control: TrainerControl,
            logs: Optional[dict] = None,
            **kwargs: Any,
        ) -> None:
            # Fires at the Trainer's logging cadence (logging_steps), not
            # every step — metrics are only as granular as Trainer logs them.
            if logs is None:
                return
            step = state.global_step
            if "loss" in logs:
                monitor.log("loss", step, logs["loss"])
            if "learning_rate" in logs:
                monitor.log("learning_rate", step, logs["learning_rate"])
            if "grad_norm" in logs:
                monitor.log("grad_norm", step, logs["grad_norm"])

        def on_step_end(
            self,
            args: TrainingArguments,
            state: TrainerState,
            control: TrainerControl,
            **kwargs: Any,
        ) -> None:
            monitor.step_end(state.global_step)

    return _TrainpulseCallback()
trainpulse/cli.py ADDED
@@ -0,0 +1,125 @@
1
+ """CLI for trainpulse."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import sys
7
+ from pathlib import Path
8
+ from typing import Optional
9
+
10
+
11
def _build_cli():  # type: ignore[no-untyped-def]
    """Build the CLI. Deferred import so click/rich are optional.

    Returns the top-level click group with `analyze` and `show` commands
    attached. Raises SystemExit with an install hint if click is missing.
    """
    try:
        import click
    except ImportError:
        raise SystemExit(
            "CLI dependencies required: pip install trainpulse[cli]"
        )

    @click.group()
    @click.version_option(package_name="trainpulse")
    def cli() -> None:
        """trainpulse — lightweight training health monitor."""

    @cli.command()
    @click.argument("log_file", type=click.Path(exists=True))
    @click.option("--json-out", "-o", type=click.Path(), default=None, help="Save report as JSON.")
    @click.option("--loss-key", default="loss", help="Key for loss in log entries.")
    @click.option("--grad-key", default="grad_norm", help="Key for gradient norm.")
    @click.option("--lr-key", default="learning_rate", help="Key for learning rate.")
    @click.option("--step-key", default="step", help="Key for step number.")
    def analyze(
        log_file: str,
        json_out: Optional[str],
        loss_key: str,
        grad_key: str,
        lr_key: str,
        step_key: str,
    ) -> None:
        """Analyze a training log file (JSONL format).

        Each line should be a JSON object with at least a step and loss field.
        """
        # Local imports keep CLI startup fast and avoid import cycles.
        from trainpulse.monitor import Monitor
        from trainpulse.report import format_report_rich, format_report_text, save_json

        monitor = Monitor()
        path = Path(log_file)

        n_lines = 0
        for line in path.read_text().splitlines():
            line = line.strip()
            # Blank and non-JSON lines are skipped silently (best-effort parse).
            if not line:
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                continue

            # Fall back to the count of parsed entries when no step field.
            step = entry.get(step_key, n_lines)

            if loss_key in entry:
                monitor.log("loss", step, float(entry[loss_key]))
            if grad_key in entry:
                monitor.log("grad_norm", step, float(entry[grad_key]))
            if lr_key in entry:
                monitor.log("learning_rate", step, float(entry[lr_key]))

            n_lines += 1

        if n_lines == 0:
            click.echo("No log entries found.", err=True)
            sys.exit(1)

        report = monitor.report()

        # Prefer rich output; fall back to plain text if rich is missing or
        # rendering fails for any reason (deliberate best-effort).
        try:
            output = format_report_rich(report)
        except Exception:
            output = format_report_text(report)
        click.echo(output)

        if json_out:
            save_json(report, json_out)
            click.echo(f"Report saved to {json_out}")

    @cli.command()
    @click.argument("report_file", type=click.Path(exists=True))
    def show(report_file: str) -> None:
        """Display a previously saved JSON report."""
        from trainpulse._types import Alert, AlertSeverity, TrainingReport
        from trainpulse.report import format_report_rich, format_report_text, load_json

        # Rehydrate the dataclasses from the JSON dict produced by save_json;
        # optional alert fields get the same defaults as the Alert dataclass.
        data = load_json(report_file)
        alerts = [
            Alert(
                step=a["step"],
                severity=AlertSeverity(a["severity"]),
                detector=a["detector"],
                message=a["message"],
                metric_name=a.get("metric_name", ""),
                metric_value=a.get("metric_value", 0.0),
            )
            for a in data.get("alerts", [])
        ]
        report = TrainingReport(
            total_steps=data["total_steps"],
            alerts=alerts,
            metrics_summary=data.get("metrics_summary", {}),
            health_score=data.get("health_score", 1.0),
        )

        # Same rich-then-text fallback as `analyze`.
        try:
            output = format_report_rich(report)
        except Exception:
            output = format_report_text(report)
        click.echo(output)

    return cli
120
+
121
+
122
# Build the CLI eagerly at import time. NOTE(review): if click is not
# installed this raises SystemExit on `import trainpulse.cli` — confirm no
# non-CLI code path imports this module.
cli = _build_cli()

if __name__ == "__main__":
    cli()
@@ -0,0 +1,216 @@
1
+ """Anomaly detectors for training metrics."""
2
+
3
from __future__ import annotations

import math
from collections import deque
from typing import List, Optional, Sequence

from trainpulse._types import Alert, AlertSeverity, MetricType
9
+
10
+
11
class RollingWindow:
    """Fixed-size rolling window of float values.

    Keeps at most ``size`` of the most recently added values and exposes
    simple streaming statistics over them. A ``size`` below 1 is clamped
    to 1.
    """

    def __init__(self, size: int) -> None:
        self._size = max(size, 1)
        # deque(maxlen=...) evicts the oldest value in O(1) on append;
        # the previous list + pop(0) implementation shifted the whole
        # buffer (O(n)) on every add once the window was full.
        self._values: deque[float] = deque(maxlen=self._size)

    def add(self, value: float) -> None:
        """Append *value*, evicting the oldest entry once the window is full."""
        self._values.append(value)

    @property
    def values(self) -> List[float]:
        """Snapshot of the current window contents, oldest first."""
        return list(self._values)

    @property
    def mean(self) -> float:
        """Arithmetic mean of the window; 0.0 when empty."""
        if not self._values:
            return 0.0
        return sum(self._values) / len(self._values)

    @property
    def std(self) -> float:
        """Population standard deviation; 0.0 with fewer than 2 values."""
        if len(self._values) < 2:
            return 0.0
        m = self.mean
        return math.sqrt(sum((v - m) ** 2 for v in self._values) / len(self._values))

    @property
    def is_full(self) -> bool:
        """True once the window holds ``size`` values."""
        return len(self._values) >= self._size

    def __len__(self) -> int:
        return len(self._values)
46
+
47
+
48
class NaNDetector:
    """Detect NaN or Inf values in metrics."""

    def check(self, step: int, name: str, value: float) -> Optional[Alert]:
        """Return a CRITICAL alert when *value* is NaN or Inf, else None."""
        if math.isnan(value):
            label = "NaN"
        elif math.isinf(value):
            label = "Inf"
        else:
            return None
        return Alert(
            step=step,
            severity=AlertSeverity.CRITICAL,
            detector="nan_detector",
            message=f"{name} is {label}",
            metric_name=name,
            metric_value=value,
        )
71
+
72
+
73
class LossSpikeDetector:
    """Detect sudden spikes in loss values."""

    def __init__(self, threshold: float = 5.0, window_size: int = 50) -> None:
        self._threshold = threshold
        self._window = RollingWindow(window_size)

    def check(self, step: int, value: float) -> Optional[Alert]:
        """Warn when *value* exceeds threshold × the rolling average.

        No alert fires until the window is full. Every observation —
        including spike values — is folded into the window afterwards, so
        a sustained high loss gradually raises the baseline.
        """
        alert: Optional[Alert] = None
        if self._window.is_full:
            baseline = self._window.mean
            if baseline > 0 and value > baseline * self._threshold:
                avg = baseline
                alert = Alert(
                    step=step,
                    severity=AlertSeverity.WARNING,
                    detector="loss_spike",
                    message=f"Loss spike: {value:.4f} ({value/avg:.1f}x rolling avg {avg:.4f})",
                    metric_name="loss",
                    metric_value=value,
                )
        self._window.add(value)
        return alert
97
+
98
+
99
class GradientDetector:
    """Detect gradient explosion and vanishing."""

    def __init__(
        self,
        explosion_threshold: float = 100.0,
        vanish_threshold: float = 1e-7,
    ) -> None:
        self._explosion = explosion_threshold
        self._vanish = vanish_threshold

    def check(self, step: int, grad_norm: float) -> Optional[Alert]:
        """Alert on exploding (CRITICAL) or vanishing (WARNING) gradients.

        A norm of exactly 0.0 is not treated as vanishing.
        """
        exploded = grad_norm > self._explosion
        vanished = 0 < grad_norm < self._vanish
        if not exploded and not vanished:
            return None
        # Explosion takes priority when both conditions could hold.
        if exploded:
            severity = AlertSeverity.CRITICAL
            message = f"Gradient explosion: norm={grad_norm:.4f} (threshold={self._explosion})"
        else:
            severity = AlertSeverity.WARNING
            message = f"Vanishing gradient: norm={grad_norm:.2e} (threshold={self._vanish:.2e})"
        return Alert(
            step=step,
            severity=severity,
            detector="gradient",
            message=message,
            metric_name="gradient_norm",
            metric_value=grad_norm,
        )
130
+
131
+
132
class LRDetector:
    """Detect suspicious learning rate changes."""

    def __init__(self, change_threshold: float = 10.0) -> None:
        self._threshold = change_threshold
        self._prev_lr: Optional[float] = None

    def check(self, step: int, lr: float) -> Optional[Alert]:
        """Warn when the LR moves more than the threshold ratio in one step.

        The ratio is symmetric (covers both increases and decreases). The
        first observation, and any non-positive LR, never alerts; the
        previous value is updated on every call regardless.
        """
        alert: Optional[Alert] = None
        previous = self._prev_lr
        if previous is not None and previous > 0 and lr > 0:
            ratio = max(lr / previous, previous / lr)
            if ratio > self._threshold:
                alert = Alert(
                    step=step,
                    severity=AlertSeverity.WARNING,
                    detector="learning_rate",
                    message=f"LR changed {ratio:.1f}x in one step ({previous:.2e} → {lr:.2e})",
                    metric_name="learning_rate",
                    metric_value=lr,
                )
        self._prev_lr = lr
        return alert
155
+
156
+
157
class PlateauDetector:
    """Detect loss plateaus (no improvement for N steps)."""

    def __init__(self, patience: int = 100, min_delta: float = 1e-5) -> None:
        self._patience = patience
        self._min_delta = min_delta
        self._best_loss: Optional[float] = None
        self._steps_without_improvement = 0
        self._alerted = False

    def check(self, step: int, loss: float) -> Optional[Alert]:
        """Warn once per plateau after *patience* non-improving steps.

        An improvement (loss below best by more than min_delta) resets the
        counter and re-arms the alert.
        """
        best = self._best_loss
        # First observation or genuine improvement: record and reset. On the
        # very first call the counter/flag are already at their initial
        # values, so resetting them is a no-op.
        if best is None or loss < best - self._min_delta:
            self._best_loss = loss
            self._steps_without_improvement = 0
            self._alerted = False
            return None

        self._steps_without_improvement += 1
        stalled = self._steps_without_improvement >= self._patience
        if not stalled or self._alerted:
            return None

        # Fire once; stay silent until the next improvement re-arms us.
        self._alerted = True
        return Alert(
            step=step,
            severity=AlertSeverity.WARNING,
            detector="plateau",
            message=f"Loss plateau: no improvement for {self._steps_without_improvement} steps (best={self._best_loss:.6f})",
            metric_name="loss",
            metric_value=loss,
        )
191
+
192
+
193
class StepTimeDetector:
    """Detect unusually slow training steps."""

    def __init__(self, threshold: float = 3.0, window_size: int = 20) -> None:
        self._threshold = threshold
        self._window = RollingWindow(window_size)

    def check(self, step: int, step_time: float) -> Optional[Alert]:
        """Warn when *step_time* exceeds threshold × the rolling average.

        No alert fires until the window is full; every observation — slow
        steps included — is folded into the window afterwards.
        """
        alert: Optional[Alert] = None
        if self._window.is_full:
            avg = self._window.mean
            if avg > 0 and step_time > avg * self._threshold:
                alert = Alert(
                    step=step,
                    severity=AlertSeverity.WARNING,
                    detector="step_time",
                    message=f"Slow step: {step_time:.2f}s ({step_time/avg:.1f}x avg {avg:.2f}s)",
                    metric_name="step_time",
                    metric_value=step_time,
                )
        self._window.add(step_time)
        return alert