trainpit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
trainpit/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ """Rich CLI progress monitoring for machine learning training loops."""
2
+
3
+ from trainpit._tracker import TrainTracker, train
4
+
5
# Public API of the package: the tracker class and its factory function.
__all__ = ["TrainTracker", "train"]
trainpit/_state.py ADDED
@@ -0,0 +1,195 @@
1
+ """Internal state for train progress tracking."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Mapping
6
+ from time import monotonic
7
+ from typing import Annotated, Literal
8
+
9
+ from pydantic import BaseModel, ConfigDict, Field
10
+
11
# Numeric input accepted by the setters before it is coerced to float.
Scalar = float | int
# Non-empty string; strict mode turns off pydantic's type coercion.
DisplayText = Annotated[str, Field(strict=True, min_length=1)]
# Strict float that rejects inf and NaN.
FiniteFloat = Annotated[float, Field(strict=True, allow_inf_nan=False)]
# Strict float that must be >= 0 and rejects inf and NaN.
NonNegativeFiniteFloat = Annotated[
    float,
    Field(strict=True, ge=0, allow_inf_nan=False),
]
# Strict integer that must be >= 1 (used for counts and one-based indices).
PositiveInt = Annotated[int, Field(strict=True, ge=1)]
# Lifecycle phase literal; "train" is the only value declared here.
TrainPhase = Literal["train"]
20
+
21
+
22
class TrainState(BaseModel):
    """Current state for one train display.

    The model is strict, forbids unknown fields and validates every
    assignment, so each setter below can still raise a pydantic validation
    error when a value violates its field constraints.

    Every setter validates and coerces all of its arguments *before* it
    assigns anything. A rejected argument (for example a negative ``now``)
    therefore no longer leaves the state half-updated.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
        strict=True,
        validate_assignment=True,
        validate_default=True,
    )

    label: DisplayText | None = Field(
        default=None,
        description="Human-readable label for the training run.",
    )
    total_epochs: PositiveInt | None = Field(
        default=None,
        description="Total epoch count when known.",
    )
    total_steps: PositiveInt | None = Field(
        default=None,
        description="Total step count per epoch when known.",
    )
    phase: TrainPhase = Field(
        default="train",
        description="Current lifecycle phase for the tracker.",
    )
    current_epoch: PositiveInt | None = Field(
        default=None,
        description="Current one-based epoch index.",
    )
    current_step: PositiveInt | None = Field(
        default=None,
        description="Current one-based step index within the active epoch.",
    )
    loss: FiniteFloat | None = Field(
        default=None,
        description="Latest finite loss value.",
    )
    metrics: dict[DisplayText, FiniteFloat] = Field(
        default_factory=dict,
        description="Latest finite scalar metrics keyed by metric name.",
    )
    learning_rate: NonNegativeFiniteFloat | None = Field(
        default=None,
        description="Latest non-negative learning rate.",
    )
    event: DisplayText | None = Field(
        default=None,
        description="Latest event message.",
    )
    started_at: NonNegativeFiniteFloat | None = Field(
        default=None,
        description="Monotonic timestamp when progress tracking started.",
    )
    updated_at: NonNegativeFiniteFloat | None = Field(
        default=None,
        description="Monotonic timestamp for the latest state update.",
    )
    finished_at: NonNegativeFiniteFloat | None = Field(
        default=None,
        description="Monotonic timestamp when progress tracking ended.",
    )
    started: bool = Field(
        default=False,
        description="Whether progress tracking has started.",
    )
    finished: bool = Field(
        default=False,
        description="Whether progress tracking has reached a terminal state.",
    )
    failed: bool = Field(
        default=False,
        description="Whether progress tracking ended with an error.",
    )
    error: BaseException | None = Field(
        default=None,
        description="Error captured when progress tracking fails.",
    )

    def start(self, *, now: Scalar | None = None) -> None:
        """Mark tracking as started and clear any terminal state.

        Args:
            now: Timestamp to record; the monotonic clock is used when omitted.
        """
        timestamp = _resolve_timestamp(now)
        self.started = True
        self.finished = False
        self.failed = False
        self.error = None
        self.started_at = timestamp
        self.updated_at = timestamp
        self.finished_at = None

    def set_epoch(self, value: int, *, now: Scalar | None = None) -> None:
        """Record the current one-based epoch index.

        Args:
            value: Epoch index; must be an integer >= 1.
            now: Timestamp to record; the monotonic clock is used when omitted.

        Raises:
            TypeError: If ``value`` is not an integer.
            ValueError: If ``value`` is below 1 or ``now`` is negative.
        """
        # Validate everything first so a bad ``now`` cannot leave the epoch
        # updated while ``updated_at`` stays stale.
        epoch = _require_positive_integer("epoch", value)
        timestamp = _resolve_timestamp(now)

        self.current_epoch = epoch
        self._touch(timestamp)

    def set_step(
        self,
        value: int,
        *,
        loss: Scalar | None = None,
        metrics: Mapping[str, Scalar] | None = None,
        learning_rate: Scalar | None = None,
        now: Scalar | None = None,
    ) -> None:
        """Record the current step and any values reported with it.

        Optional values left as ``None`` keep their previous state. Reported
        metrics are merged into the existing metrics rather than replacing
        them.

        Args:
            value: One-based step index; must be an integer >= 1.
            loss: Latest loss value.
            metrics: Metric values keyed by metric name.
            learning_rate: Latest learning rate.
            now: Timestamp to record; the monotonic clock is used when omitted.

        Raises:
            TypeError: If ``value`` is not an integer or a scalar argument is
                not an int/float.
            ValueError: If ``value`` is below 1 or ``now`` is negative.
        """
        # Coerce every argument before touching the model, in the same order
        # the errors were reported previously (step, loss, metrics, lr, now).
        step = _require_positive_integer("step", value)
        new_loss = None if loss is None else _coerce_scalar("loss", loss)
        new_metrics = None if metrics is None else _coerce_metrics(metrics)
        new_learning_rate = (
            None
            if learning_rate is None
            else _coerce_scalar("learning_rate", learning_rate)
        )
        timestamp = _resolve_timestamp(now)

        self.current_step = step
        if new_loss is not None:
            self.loss = new_loss
        if new_metrics is not None:
            self.metrics = {**self.metrics, **new_metrics}
        if new_learning_rate is not None:
            self.learning_rate = new_learning_rate

        self._touch(timestamp)

    def set_metrics(
        self,
        values: Mapping[str, Scalar],
        *,
        now: Scalar | None = None,
    ) -> None:
        """Merge ``values`` into the stored metrics.

        Args:
            values: Metric values keyed by metric name.
            now: Timestamp to record; the monotonic clock is used when omitted.

        Raises:
            TypeError: If a metric value is not an int/float.
            ValueError: If ``now`` is negative.
        """
        coerced = _coerce_metrics(values)
        timestamp = _resolve_timestamp(now)

        self.metrics = {**self.metrics, **coerced}
        self._touch(timestamp)

    def set_event(self, message: str, *, now: Scalar | None = None) -> None:
        """Store the latest event message.

        Args:
            message: Event text; the ``event`` field validates it on assignment.
            now: Timestamp to record; the monotonic clock is used when omitted.

        Raises:
            ValueError: If ``now`` is negative.
        """
        # Resolve the timestamp first so an invalid ``now`` leaves the
        # previous event in place.
        timestamp = _resolve_timestamp(now)

        self.event = message
        self._touch(timestamp)

    def finish(self, *, now: Scalar | None = None) -> None:
        """Mark tracking as finished successfully and clear any stored error.

        Args:
            now: Timestamp to record; the monotonic clock is used when omitted.
        """
        timestamp = _resolve_timestamp(now)
        self._touch(timestamp)
        self.finished = True
        self.failed = False
        self.error = None
        self.finished_at = timestamp

    def fail(self, error: BaseException, *, now: Scalar | None = None) -> None:
        """Mark tracking as finished with ``error``.

        Args:
            error: Exception that ended the run; stored on ``error``.
            now: Timestamp to record; the monotonic clock is used when omitted.
        """
        timestamp = _resolve_timestamp(now)
        self._touch(timestamp)
        self.finished = True
        self.failed = True
        self.error = error
        self.finished_at = timestamp

    def _touch(self, timestamp: float) -> None:
        """Record ``timestamp`` as the latest update time.

        Also backfills ``started_at`` when an update arrives before
        ``start()`` has recorded one.
        """
        if self.started_at is None:
            self.started_at = timestamp
        self.updated_at = timestamp
169
+
170
+
171
+ def _require_positive_integer(name: str, value: int) -> int:
172
+ if isinstance(value, bool) or not isinstance(value, int):
173
+ raise TypeError(f"{name} must be an integer")
174
+ if value < 1:
175
+ raise ValueError(f"{name} must be greater than or equal to 1")
176
+ return value
177
+
178
+
179
+ def _coerce_scalar(name: str, value: Scalar) -> float:
180
+ if isinstance(value, bool) or not isinstance(value, int | float):
181
+ raise TypeError(f"{name} must be a finite scalar")
182
+
183
+ return float(value)
184
+
185
+
186
+ def _resolve_timestamp(value: Scalar | None) -> float:
187
+ timestamp = monotonic() if value is None else _coerce_scalar("now", value)
188
+ if timestamp < 0:
189
+ raise ValueError("now must be greater than or equal to 0")
190
+
191
+ return timestamp
192
+
193
+
194
+ def _coerce_metrics(values: Mapping[str, Scalar]) -> dict[str, float]:
195
+ return {name: _coerce_scalar(name, value) for name, value in values.items()}
trainpit/_tracker.py ADDED
@@ -0,0 +1,126 @@
1
+ """Public train tracker entry point implementation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Mapping
6
+ from types import TracebackType
7
+ from typing import Self
8
+
9
+ from trainpit._state import Scalar, TrainState
10
+
11
+
12
class TrainTracker:
    """Context manager used to update one train progress display.

    Entering the context starts tracking. Leaving it records either a
    successful finish or the exception that ended the block. Exceptions are
    never suppressed.
    """

    def __init__(
        self,
        *,
        total_epochs: int | None = None,
        total_steps: int | None = None,
        label: str | None = None,
    ) -> None:
        """Create the tracker and its backing state.

        Args:
            total_epochs: Total epoch count when known; an integer >= 1.
            total_steps: Total step count per epoch when known; an integer >= 1.
            label: Optional label for the run.

        Raises:
            TypeError: If a total is given but is not an integer.
            ValueError: If a total is given but is below 1.
        """
        epochs = _optional_positive_integer("total_epochs", total_epochs)
        steps = _optional_positive_integer("total_steps", total_steps)
        self._state = TrainState(
            label=label,
            total_epochs=epochs,
            total_steps=steps,
        )

    @property
    def snapshot(self) -> TrainState:
        """Return a copy of the current progress state."""

        state_copy = self._state.model_copy(deep=True)
        return state_copy

    def __enter__(self) -> Self:
        """Start tracking and return the tracker itself."""
        self.start()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        """Record how the block ended; always lets exceptions propagate."""
        del exc_type, traceback

        if exc is not None:
            self.fail(exc)
        else:
            self.finish()

        # False tells Python not to swallow the in-flight exception.
        return False

    def epoch(self, value: int) -> None:
        """Record the current one-based epoch index."""
        self._state.set_epoch(value)

    def start(self) -> None:
        """Mark tracking as started."""
        self._state.start()

    def step(
        self,
        value: int,
        *,
        loss: Scalar | None = None,
        metrics: Mapping[str, Scalar] | None = None,
        learning_rate: Scalar | None = None,
        lr: Scalar | None = None,
    ) -> None:
        """Record the current step together with optional reported values.

        ``lr`` is a short alias for ``learning_rate``; passing both raises
        ``ValueError``.
        """
        rate = _resolve_learning_rate(learning_rate, lr)
        self._state.set_step(
            value,
            loss=loss,
            metrics=metrics,
            learning_rate=rate,
        )

    def update_metrics(self, values: Mapping[str, Scalar]) -> None:
        """Merge ``values`` into the tracked metrics."""
        self._state.set_metrics(values)

    def metrics(self, values: Mapping[str, Scalar]) -> None:
        """Alias for :meth:`update_metrics`."""
        self.update_metrics(values)

    def log(self, message: str) -> None:
        """Record ``message`` as the latest event."""
        self._state.set_event(message)

    def event(self, message: str) -> None:
        """Alias for :meth:`log`."""
        self.log(message)

    def finish(self) -> None:
        """Mark tracking as finished successfully."""
        self._state.finish()

    def fail(self, error: BaseException) -> None:
        """Mark tracking as finished with ``error``."""
        self._state.fail(error)
92
+
93
+
94
def train(
    *,
    total_epochs: int | None = None,
    total_steps: int | None = None,
    label: str | None = None,
) -> TrainTracker:
    """Create a tracker for one training loop.

    All arguments are keyword-only and are forwarded unchanged to
    ``TrainTracker``.
    """

    tracker = TrainTracker(
        total_epochs=total_epochs,
        total_steps=total_steps,
        label=label,
    )
    return tracker
107
+
108
+
109
+ def _optional_positive_integer(name: str, value: int | None) -> int | None:
110
+ if value is None:
111
+ return None
112
+ if isinstance(value, bool) or not isinstance(value, int):
113
+ raise TypeError(f"{name} must be an integer")
114
+ if value < 1:
115
+ raise ValueError(f"{name} must be greater than or equal to 1")
116
+ return value
117
+
118
+
119
+ def _resolve_learning_rate(
120
+ learning_rate: Scalar | None,
121
+ lr: Scalar | None,
122
+ ) -> Scalar | None:
123
+ if learning_rate is not None and lr is not None:
124
+ raise ValueError("Use either learning_rate or lr, not both")
125
+
126
+ return learning_rate if learning_rate is not None else lr
@@ -0,0 +1,43 @@
1
+ """Internal modules backing the public trainpit.tui facade."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from trainpit._tui.app import TrainDashboardApp as TrainDashboardApp
6
+ from trainpit._tui.formatting import (
7
+ _format_curves as _format_curves,
8
+ _format_metrics as _format_metrics,
9
+ _format_progress as _format_progress,
10
+ _format_timing as _format_timing,
11
+ )
12
+ from trainpit._tui.graphs import line_graph as line_graph
13
+ from trainpit._tui.graphs import scatter_graph as scatter_graph
14
+ from trainpit._tui.panels import DashboardPanel as DashboardPanel
15
+ from trainpit._tui.snapshot import TrainDashboardSnapshot as TrainDashboardSnapshot
16
+ from trainpit._tui.types import CurveGraphRenderer as CurveGraphRenderer
17
+ from trainpit._tui.types import DisplayText as DisplayText
18
+ from trainpit._tui.types import FiniteFloat as FiniteFloat
19
+ from trainpit._tui.types import NonNegativeFiniteFloat as NonNegativeFiniteFloat
20
+ from trainpit._tui.types import PanelSlot as PanelSlot
21
+ from trainpit._tui.types import PositiveInt as PositiveInt
22
+ from trainpit._tui.types import RunStatus as RunStatus
23
+ from trainpit._tui.types import Scalar as Scalar
24
+
25
+ __all__ = [
26
+ "CurveGraphRenderer",
27
+ "DashboardPanel",
28
+ "DisplayText",
29
+ "FiniteFloat",
30
+ "NonNegativeFiniteFloat",
31
+ "PanelSlot",
32
+ "PositiveInt",
33
+ "RunStatus",
34
+ "Scalar",
35
+ "TrainDashboardApp",
36
+ "TrainDashboardSnapshot",
37
+ "_format_curves",
38
+ "_format_metrics",
39
+ "_format_progress",
40
+ "_format_timing",
41
+ "line_graph",
42
+ "scatter_graph",
43
+ ]
trainpit/_tui/app.py ADDED
@@ -0,0 +1,270 @@
1
+ """Textual app implementation for the trainpit dashboard."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Sequence
6
+
7
+ from textual.app import App, ComposeResult
8
+ from textual.containers import Horizontal, Vertical
9
+ from textual.widgets import Footer, Header, Static
10
+
11
+ from trainpit._tui.formatting import (
12
+ _format_curves,
13
+ _format_metrics,
14
+ _format_progress,
15
+ _format_timing,
16
+ )
17
+ from trainpit._tui.graphs import line_graph
18
+ from trainpit._tui.panels import (
19
+ DashboardPanel,
20
+ _panel_body_id,
21
+ _panel_container_id,
22
+ _panel_title,
23
+ _validate_extra_panels,
24
+ )
25
+ from trainpit._tui.snapshot import TrainDashboardSnapshot
26
+ from trainpit._tui.types import CurveGraphRenderer, PanelSlot
27
+
28
+
29
class TrainDashboardApp(App[None]):
    """A compact Textual dashboard for train progress.

    The layout has a top row (status, progress, metrics, plus any extra
    panels in the "top" slot) and a bottom row (learning curve, extra
    "bottom" panels, and a side column with timing, events and extra "side"
    panels). Widget contents are filled in by :meth:`refresh_dashboard` from
    the current snapshot.
    """

    # Stylesheet for the dashboard; selectors match the ids and classes
    # assigned in compose() and refresh_dashboard().
    CSS = """
    Screen {
        background: #07090d;
        color: #d7dee8;
    }

    Header,
    Footer {
        background: #0b0f14;
        color: #97a3b6;
    }

    #dashboard {
        height: 1fr;
        padding: 0 1 1 1;
    }

    #top-row {
        height: 10;
        min-height: 9;
    }

    #bottom-row {
        height: 1fr;
        min-height: 16;
    }

    .panel {
        background: #0c1117;
        border: round #27313f;
        padding: 0 1;
        margin: 0 1 1 0;
        min-width: 24;
    }

    #status-panel {
        width: 1fr;
        min-width: 22;
    }

    #progress-panel {
        width: 2fr;
        min-width: 34;
    }

    #metrics-panel {
        width: 2fr;
        min-width: 32;
    }

    .graph-panel {
        width: 4fr;
        min-width: 60;
    }

    #bottom-side {
        width: 1fr;
        min-width: 24;
    }

    .side-panel {
        height: 1fr;
    }

    .bottom-panel {
        width: 2fr;
        min-width: 32;
    }

    .custom-panel {
        color: #d7dee8;
    }

    .panel-title {
        background: #101b24;
        color: #6ee7f9;
        text-style: bold reverse;
        padding: 0 1;
        margin-bottom: 1;
        width: 1fr;
    }

    #status {
        color: #7ee787;
        text-style: bold;
    }

    .status-running {
        color: #7ee787;
    }

    .status-finished {
        color: #6ee7f9;
    }

    .status-failed {
        color: #ff7b72;
    }

    #label {
        color: #f0f4fa;
    }

    #curve {
        color: #7ee787;
    }

    #progress {
        color: #f0c36a;
    }

    #metrics {
        color: #f0f4fa;
    }

    #timing {
        color: #97a3b6;
    }

    #events {
        color: #d7dee8;
    }
    """

    TITLE = "trainpit"
    SUB_TITLE = "training monitor"
    ENABLE_COMMAND_PALETTE = False
    # Pressing "q" triggers the "quit" action; "Quit" is the displayed label.
    BINDINGS = [("q", "quit", "Quit")]

    def __init__(
        self,
        snapshot: TrainDashboardSnapshot | None = None,
        *,
        extra_panels: Sequence[DashboardPanel] | None = None,
        graph_renderer: CurveGraphRenderer = line_graph,
    ) -> None:
        """Create the dashboard app.

        Args:
            snapshot: Initial snapshot to display. A default
                ``TrainDashboardSnapshot()`` is used when this is falsy.
            extra_panels: Additional panels, checked by
                ``_validate_extra_panels``; an empty tuple is used when falsy.
            graph_renderer: Callable passed to ``_format_curves`` to draw the
                learning curve.

        Raises:
            TypeError: If ``graph_renderer`` is not callable.
        """
        super().__init__()
        self.snapshot = snapshot or TrainDashboardSnapshot()
        self.extra_panels = _validate_extra_panels(extra_panels or ())
        self.graph_renderer = _validate_graph_renderer(graph_renderer)

    def compose(self) -> ComposeResult:
        """Build the widget tree: header, top row, bottom row, footer.

        Every body widget starts empty; refresh_dashboard() fills it in.
        """
        yield Header(show_clock=True)
        with Vertical(id="dashboard"):
            with Horizontal(id="top-row"):
                with Vertical(id="status-panel", classes="panel"):
                    yield Static(_panel_title("STATUS"), classes="panel-title")
                    yield Static("", id="status")
                    yield Static("", id="label")
                with Vertical(id="progress-panel", classes="panel"):
                    yield Static(_panel_title("PROGRESS"), classes="panel-title")
                    yield Static("", id="progress")
                with Vertical(id="metrics-panel", classes="panel"):
                    yield Static(_panel_title("METRICS"), classes="panel-title")
                    yield Static("", id="metrics")
                # Extra panels in the "top" slot follow the built-in ones.
                for panel in self._panels_for_slot("top"):
                    with Vertical(
                        id=_panel_container_id(panel),
                        classes="panel custom-panel",
                    ):
                        yield Static(
                            _panel_title(panel.title),
                            classes="panel-title",
                        )
                        yield Static("", id=_panel_body_id(panel))
            with Horizontal(id="bottom-row"):
                with Vertical(classes="panel graph-panel"):
                    yield Static(_panel_title("LEARNING CURVE"), classes="panel-title")
                    yield Static("", id="curve")
                # Extra "bottom" panels sit between the curve and the side column.
                for panel in self._panels_for_slot("bottom"):
                    with Vertical(
                        id=_panel_container_id(panel),
                        classes="panel bottom-panel custom-panel",
                    ):
                        yield Static(
                            _panel_title(panel.title),
                            classes="panel-title",
                        )
                        yield Static("", id=_panel_body_id(panel))
                with Vertical(id="bottom-side"):
                    with Vertical(classes="panel side-panel"):
                        yield Static(_panel_title("TIMING"), classes="panel-title")
                        yield Static("", id="timing")
                    with Vertical(classes="panel side-panel"):
                        yield Static(_panel_title("EVENTS"), classes="panel-title")
                        yield Static("", id="events")
                    # Extra "side" panels are stacked below timing and events.
                    for panel in self._panels_for_slot("side"):
                        with Vertical(
                            id=_panel_container_id(panel),
                            classes="panel side-panel custom-panel",
                        ):
                            yield Static(
                                _panel_title(panel.title),
                                classes="panel-title",
                            )
                            yield Static("", id=_panel_body_id(panel))
        yield Footer()

    def on_mount(self) -> None:
        """Render the initial snapshot into the dashboard widgets."""
        self.refresh_dashboard()

    def refresh_dashboard(self) -> None:
        """Render the current snapshot into dashboard widgets."""

        snapshot = self.snapshot
        status = self.query_one("#status", Static)
        status.update(snapshot.status.upper())
        # Enable the one status-* class matching the current status and
        # disable the other two.
        for value in ("running", "finished", "failed"):
            status.set_class(snapshot.status == value, f"status-{value}")

        # Placeholder text is shown when the snapshot has no label or event.
        self.query_one("#label", Static).update(snapshot.label or "unlabeled run")
        self.query_one("#progress", Static).update(_format_progress(snapshot))
        self.query_one("#metrics", Static).update(_format_metrics(snapshot))
        self.query_one("#curve", Static).update(
            _format_curves(snapshot, graph_renderer=self.graph_renderer)
        )
        self.query_one("#timing", Static).update(_format_timing(snapshot))
        self.query_one("#events", Static).update(snapshot.event or "waiting for events")
        # Each extra panel renders its own body from the snapshot.
        for panel in self.extra_panels:
            self.query_one(f"#{_panel_body_id(panel)}", Static).update(
                panel.render(snapshot)
            )

    def update_snapshot(self, snapshot: TrainDashboardSnapshot) -> None:
        """Replace the dashboard snapshot and refresh visible widgets."""

        self.snapshot = snapshot
        # Refresh only when ``is_mounted`` is truthy; otherwise the snapshot
        # is just stored (on_mount() renders whatever is current then).
        if self.is_mounted:
            self.refresh_dashboard()

    def _panels_for_slot(self, slot: PanelSlot) -> tuple[DashboardPanel, ...]:
        """Return the extra panels assigned to ``slot``, in their given order."""
        return tuple(panel for panel in self.extra_panels if panel.slot == slot)
264
+
265
+
266
+ def _validate_graph_renderer(renderer: CurveGraphRenderer) -> CurveGraphRenderer:
267
+ if not callable(renderer):
268
+ raise TypeError("graph_renderer must be callable")
269
+
270
+ return renderer
@@ -0,0 +1,26 @@
1
+ """Constants used by trainpit TUI rendering."""
2
+
3
+ from __future__ import annotations
4
+
5
# Graph and progress-bar dimensions.
# NOTE(review): presumably measured in terminal cells — confirm against the
# graph and formatting modules that consume them.
GRAPH_WIDTH = 56
PRIMARY_GRAPH_HEIGHT = 7
SECONDARY_GRAPH_HEIGHT = 3
PROGRESS_BAR_WIDTH = 24
# 0x2800 is the first code point of the Unicode Braille Patterns block.
BRAILLE_BASE = 0x2800
# Two tuples of four single-bit masks each.
# NOTE(review): these look like the left/right dot columns of a braille
# cell, listed top to bottom — confirm against the graph renderer.
BRAILLE_COLUMNS = (
    (0x01, 0x02, 0x04, 0x40),
    (0x08, 0x10, 0x20, 0x80),
)
# These match the shape of BRAILLE_COLUMNS: four masks per tuple, two tuples.
BRAILLE_ROWS_PER_CELL = 4
SAMPLES_PER_BRAILLE = 2
# NOTE(review): presumably spacing/padding for epoch-axis labels — confirm usage.
EPOCH_AXIS_MIN_LABEL_SPACING = 8
EPOCH_AXIS_LABEL_PADDING = 1
# Set of ASCII letters, digits, underscore and hyphen.
# NOTE(review): presumably the characters allowed in a panel id — confirm
# against the panel validation code.
PANEL_ID_CHARACTERS = frozenset(
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-"
)
21
+ FULLWIDTH_TITLE_TRANSLATION = str.maketrans(
22
+ {
23
+ " ": " ",
24
+ **{chr(value): chr(value + 0xFEE0) for value in range(ord("A"), ord("Z") + 1)},
25
+ }
26
+ )