trainpit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- trainpit/__init__.py +5 -0
- trainpit/_state.py +195 -0
- trainpit/_tracker.py +126 -0
- trainpit/_tui/__init__.py +43 -0
- trainpit/_tui/app.py +270 -0
- trainpit/_tui/constants.py +26 -0
- trainpit/_tui/formatting.py +414 -0
- trainpit/_tui/graphs.py +188 -0
- trainpit/_tui/panels.py +68 -0
- trainpit/_tui/snapshot.py +162 -0
- trainpit/_tui/types.py +20 -0
- trainpit/_tui/values.py +11 -0
- trainpit/tui.py +43 -0
- trainpit-0.1.0.dist-info/METADATA +233 -0
- trainpit-0.1.0.dist-info/RECORD +18 -0
- trainpit-0.1.0.dist-info/WHEEL +5 -0
- trainpit-0.1.0.dist-info/licenses/LICENSE +21 -0
- trainpit-0.1.0.dist-info/top_level.txt +1 -0
trainpit/__init__.py
ADDED
trainpit/_state.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
"""Internal state for train progress tracking."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Mapping
|
|
6
|
+
from time import monotonic
|
|
7
|
+
from typing import Annotated, Literal
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
10
|
+
|
|
11
|
+
# Numeric input accepted by the public mutators. ``bool`` is an ``int``
# subclass, so the ``_require_*``/``_coerce_*`` helpers reject it explicitly.
Scalar = float | int
# Non-empty ``str``; strict mode rejects non-string input instead of coercing it.
DisplayText = Annotated[str, Field(strict=True, min_length=1)]
# Float that rejects NaN and +/-inf.
FiniteFloat = Annotated[float, Field(strict=True, allow_inf_nan=False)]
# Finite float that must also be >= 0 (learning rate and monotonic timestamps).
NonNegativeFiniteFloat = Annotated[
    float,
    Field(strict=True, ge=0, allow_inf_nan=False),
]
# Integer >= 1 (totals and one-based epoch/step indices).
PositiveInt = Annotated[int, Field(strict=True, ge=1)]
# The only lifecycle phase currently modelled.
TrainPhase = Literal["train"]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class TrainState(BaseModel):
    """Current state for one train display.

    Every field is validated by pydantic, including on assignment
    (``validate_assignment=True``). Each mutator first coerces/checks its
    arguments and resolves the ``now`` timestamp. It then applies all of its
    field changes through :meth:`_commit`. That method restores the previous
    values if pydantic rejects any one assignment, so a rejected call does
    not leave a half-applied update behind.
    """

    model_config = ConfigDict(
        # Required for the ``error: BaseException`` field.
        arbitrary_types_allowed=True,
        extra="forbid",
        strict=True,
        validate_assignment=True,
        validate_default=True,
    )

    label: DisplayText | None = Field(
        default=None,
        description="Human-readable label for the training run.",
    )
    total_epochs: PositiveInt | None = Field(
        default=None,
        description="Total epoch count when known.",
    )
    total_steps: PositiveInt | None = Field(
        default=None,
        description="Total step count per epoch when known.",
    )
    phase: TrainPhase = Field(
        default="train",
        description="Current lifecycle phase for the tracker.",
    )
    current_epoch: PositiveInt | None = Field(
        default=None,
        description="Current one-based epoch index.",
    )
    current_step: PositiveInt | None = Field(
        default=None,
        description="Current one-based step index within the active epoch.",
    )
    loss: FiniteFloat | None = Field(
        default=None,
        description="Latest finite loss value.",
    )
    metrics: dict[DisplayText, FiniteFloat] = Field(
        default_factory=dict,
        description="Latest finite scalar metrics keyed by metric name.",
    )
    learning_rate: NonNegativeFiniteFloat | None = Field(
        default=None,
        description="Latest non-negative learning rate.",
    )
    event: DisplayText | None = Field(
        default=None,
        description="Latest event message.",
    )
    started_at: NonNegativeFiniteFloat | None = Field(
        default=None,
        description="Monotonic timestamp when progress tracking started.",
    )
    updated_at: NonNegativeFiniteFloat | None = Field(
        default=None,
        description="Monotonic timestamp for the latest state update.",
    )
    finished_at: NonNegativeFiniteFloat | None = Field(
        default=None,
        description="Monotonic timestamp when progress tracking ended.",
    )
    started: bool = Field(
        default=False,
        description="Whether progress tracking has started.",
    )
    finished: bool = Field(
        default=False,
        description="Whether progress tracking has reached a terminal state.",
    )
    failed: bool = Field(
        default=False,
        description="Whether progress tracking ended with an error.",
    )
    error: BaseException | None = Field(
        default=None,
        description="Error captured when progress tracking fails.",
    )

    def start(self, *, now: Scalar | None = None) -> None:
        """Mark tracking as started at ``now`` (default: ``monotonic()``).

        This clears any previous terminal state (``finished``, ``failed``,
        ``error``, ``finished_at``) and sets both ``started_at`` and
        ``updated_at`` to the new timestamp. Progress fields (epoch, step,
        loss, metrics, event) are left untouched.
        """
        timestamp = _resolve_timestamp(now)
        self._commit(
            timestamp,
            started=True,
            finished=False,
            failed=False,
            error=None,
            started_at=timestamp,
            finished_at=None,
        )

    def set_epoch(self, value: int, *, now: Scalar | None = None) -> None:
        """Record the current one-based epoch index.

        Raises:
            TypeError: ``value`` is not an ``int`` (``bool`` included) or
                ``now`` is not numeric.
            ValueError: ``value`` < 1 or ``now`` is negative.
        """
        epoch = _require_positive_integer("epoch", value)
        self._commit(_resolve_timestamp(now), current_epoch=epoch)

    def set_step(
        self,
        value: int,
        *,
        loss: Scalar | None = None,
        metrics: Mapping[str, Scalar] | None = None,
        learning_rate: Scalar | None = None,
        now: Scalar | None = None,
    ) -> None:
        """Record the current one-based step and any values reported with it.

        An argument left as ``None`` leaves its field unchanged. ``metrics``
        is merged into the existing metrics rather than replacing them.

        Raises:
            TypeError: a value has the wrong type.
            ValueError: a value is out of range. This includes pydantic
                ``ValidationError`` for non-finite values, negative
                ``learning_rate`` or empty metric names.
        """
        # Check every argument before touching any field.
        changes: dict[str, object] = {
            "current_step": _require_positive_integer("step", value),
        }
        if loss is not None:
            changes["loss"] = _coerce_scalar("loss", loss)
        if metrics is not None:
            changes["metrics"] = {**self.metrics, **_coerce_metrics(metrics)}
        if learning_rate is not None:
            changes["learning_rate"] = _coerce_scalar("learning_rate", learning_rate)

        self._commit(_resolve_timestamp(now), **changes)

    def set_metrics(
        self,
        values: Mapping[str, Scalar],
        *,
        now: Scalar | None = None,
    ) -> None:
        """Merge ``values`` into the current metrics."""
        merged = {**self.metrics, **_coerce_metrics(values)}
        self._commit(_resolve_timestamp(now), metrics=merged)

    def set_event(self, message: str, *, now: Scalar | None = None) -> None:
        """Record ``message`` as the latest event."""
        self._commit(_resolve_timestamp(now), event=message)

    def finish(self, *, now: Scalar | None = None) -> None:
        """Mark tracking as finished successfully, clearing any stored error."""
        timestamp = _resolve_timestamp(now)
        self._commit(
            timestamp,
            finished=True,
            failed=False,
            error=None,
            finished_at=timestamp,
        )

    def fail(self, error: BaseException, *, now: Scalar | None = None) -> None:
        """Mark tracking as finished with ``error``.

        If ``error`` is rejected by validation, no field is changed. In
        particular, ``failed`` is not left ``True`` without an error.
        """
        timestamp = _resolve_timestamp(now)
        self._commit(
            timestamp,
            finished=True,
            failed=True,
            error=error,
            finished_at=timestamp,
        )

    def _touch(self, timestamp: float) -> None:
        """Record ``timestamp`` as the latest update, and as the start if unset."""
        self._commit(timestamp)

    def _commit(self, timestamp: float, **updates: object) -> None:
        """Apply ``updates`` plus timestamp bookkeeping as one all-or-nothing change.

        ``updated_at`` becomes ``timestamp``, and so does ``started_at`` while
        it is still unset, unless ``updates`` sets those fields explicitly.
        Each assignment is validated by pydantic. If one is rejected, the
        fields already assigned by this call are restored and the original
        error propagates.
        """
        changes: dict[str, object] = {"updated_at": timestamp}
        if self.started_at is None:
            changes["started_at"] = timestamp
        changes.update(updates)

        previous: dict[str, object] = {}
        try:
            for name, value in changes.items():
                previous[name] = getattr(self, name)
                setattr(self, name, value)
        except Exception:
            # Previous values already passed validation, so restoring them is safe.
            for name, value in reversed(previous.items()):
                setattr(self, name, value)
            raise
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _require_positive_integer(name: str, value: int) -> int:
|
|
172
|
+
if isinstance(value, bool) or not isinstance(value, int):
|
|
173
|
+
raise TypeError(f"{name} must be an integer")
|
|
174
|
+
if value < 1:
|
|
175
|
+
raise ValueError(f"{name} must be greater than or equal to 1")
|
|
176
|
+
return value
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def _coerce_scalar(name: str, value: Scalar) -> float:
|
|
180
|
+
if isinstance(value, bool) or not isinstance(value, int | float):
|
|
181
|
+
raise TypeError(f"{name} must be a finite scalar")
|
|
182
|
+
|
|
183
|
+
return float(value)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _resolve_timestamp(value: Scalar | None) -> float:
|
|
187
|
+
timestamp = monotonic() if value is None else _coerce_scalar("now", value)
|
|
188
|
+
if timestamp < 0:
|
|
189
|
+
raise ValueError("now must be greater than or equal to 0")
|
|
190
|
+
|
|
191
|
+
return timestamp
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _coerce_metrics(values: Mapping[str, Scalar]) -> dict[str, float]:
|
|
195
|
+
return {name: _coerce_scalar(name, value) for name, value in values.items()}
|
trainpit/_tracker.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
"""Public train tracker entry point implementation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Mapping
|
|
6
|
+
from types import TracebackType
|
|
7
|
+
from typing import Self
|
|
8
|
+
|
|
9
|
+
from trainpit._state import Scalar, TrainState
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TrainTracker:
    """Context manager used to update one train progress display.

    Entering the ``with`` block calls :meth:`start`. Leaving it calls
    :meth:`finish` after a clean exit, or :meth:`fail` with the raised
    exception. That exception is never suppressed (``__exit__`` returns
    ``False``).
    """

    def __init__(
        self,
        *,
        total_epochs: int | None = None,
        total_steps: int | None = None,
        label: str | None = None,
    ) -> None:
        self._state = TrainState(
            label=label,
            total_epochs=_optional_positive_integer("total_epochs", total_epochs),
            total_steps=_optional_positive_integer("total_steps", total_steps),
        )

    @property
    def snapshot(self) -> TrainState:
        """Return a copy of the current progress state.

        The copy gets its own ``metrics`` dict, so mutating it cannot alter
        the tracker. The other fields hold immutable values, except ``error``.
        ``error`` keeps a reference to the original exception. Deep-copying
        an exception re-calls its class with ``self.args``, which raises for
        many exception types and drops ``__traceback__``/``__cause__``. That
        made this property fail after :meth:`fail`.
        """
        state = self._state
        return state.model_copy(update={"metrics": dict(state.metrics)})

    def __enter__(self) -> Self:
        self.start()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        del exc_type, traceback

        if exc is None:
            self.finish()
        else:
            self.fail(exc)

        # Never swallow the exception raised inside the ``with`` block.
        return False

    def epoch(self, value: int) -> None:
        """Set the current one-based epoch index."""
        self._state.set_epoch(value)

    def start(self) -> None:
        """Mark tracking as started (also called on ``with`` entry)."""
        self._state.start()

    def step(
        self,
        value: int,
        *,
        loss: Scalar | None = None,
        metrics: Mapping[str, Scalar] | None = None,
        learning_rate: Scalar | None = None,
        lr: Scalar | None = None,
    ) -> None:
        """Record a one-based step plus optional loss, metrics and learning rate.

        ``lr`` is an alias for ``learning_rate``. Passing both raises ``ValueError``.
        """
        self._state.set_step(
            value,
            loss=loss,
            metrics=metrics,
            learning_rate=_resolve_learning_rate(learning_rate, lr),
        )

    def update_metrics(self, values: Mapping[str, Scalar]) -> None:
        """Merge ``values`` into the current metrics."""
        self._state.set_metrics(values)

    def metrics(self, values: Mapping[str, Scalar]) -> None:
        """Alias for :meth:`update_metrics`."""
        self.update_metrics(values)

    def log(self, message: str) -> None:
        """Record ``message`` as the latest event."""
        self._state.set_event(message)

    def event(self, message: str) -> None:
        """Alias for :meth:`log`."""
        self.log(message)

    def finish(self) -> None:
        """Mark tracking as finished successfully."""
        self._state.finish()

    def fail(self, error: BaseException) -> None:
        """Mark tracking as finished with ``error``."""
        self._state.fail(error)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def train(
    *,
    total_epochs: int | None = None,
    total_steps: int | None = None,
    label: str | None = None,
) -> TrainTracker:
    """Create a tracker for one training loop.

    This is a thin factory around :class:`TrainTracker`. The returned object
    is meant to be used as a context manager, with all arguments passed
    through unchanged.
    """
    tracker = TrainTracker(
        label=label,
        total_epochs=total_epochs,
        total_steps=total_steps,
    )
    return tracker
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _optional_positive_integer(name: str, value: int | None) -> int | None:
|
|
110
|
+
if value is None:
|
|
111
|
+
return None
|
|
112
|
+
if isinstance(value, bool) or not isinstance(value, int):
|
|
113
|
+
raise TypeError(f"{name} must be an integer")
|
|
114
|
+
if value < 1:
|
|
115
|
+
raise ValueError(f"{name} must be greater than or equal to 1")
|
|
116
|
+
return value
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _resolve_learning_rate(
|
|
120
|
+
learning_rate: Scalar | None,
|
|
121
|
+
lr: Scalar | None,
|
|
122
|
+
) -> Scalar | None:
|
|
123
|
+
if learning_rate is not None and lr is not None:
|
|
124
|
+
raise ValueError("Use either learning_rate or lr, not both")
|
|
125
|
+
|
|
126
|
+
return learning_rate if learning_rate is not None else lr
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Internal modules backing the public trainpit.tui facade."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from trainpit._tui.app import TrainDashboardApp as TrainDashboardApp
|
|
6
|
+
from trainpit._tui.formatting import (
|
|
7
|
+
_format_curves as _format_curves,
|
|
8
|
+
_format_metrics as _format_metrics,
|
|
9
|
+
_format_progress as _format_progress,
|
|
10
|
+
_format_timing as _format_timing,
|
|
11
|
+
)
|
|
12
|
+
from trainpit._tui.graphs import line_graph as line_graph
|
|
13
|
+
from trainpit._tui.graphs import scatter_graph as scatter_graph
|
|
14
|
+
from trainpit._tui.panels import DashboardPanel as DashboardPanel
|
|
15
|
+
from trainpit._tui.snapshot import TrainDashboardSnapshot as TrainDashboardSnapshot
|
|
16
|
+
from trainpit._tui.types import CurveGraphRenderer as CurveGraphRenderer
|
|
17
|
+
from trainpit._tui.types import DisplayText as DisplayText
|
|
18
|
+
from trainpit._tui.types import FiniteFloat as FiniteFloat
|
|
19
|
+
from trainpit._tui.types import NonNegativeFiniteFloat as NonNegativeFiniteFloat
|
|
20
|
+
from trainpit._tui.types import PanelSlot as PanelSlot
|
|
21
|
+
from trainpit._tui.types import PositiveInt as PositiveInt
|
|
22
|
+
from trainpit._tui.types import RunStatus as RunStatus
|
|
23
|
+
from trainpit._tui.types import Scalar as Scalar
|
|
24
|
+
|
|
25
|
+
# Declared exports of this internal package, which backs the public
# ``trainpit.tui`` facade. The underscore-prefixed formatters are listed on
# purpose: ``from trainpit._tui import *`` only exports underscore names that
# appear in ``__all__``.
__all__ = [
    "CurveGraphRenderer",
    "DashboardPanel",
    "DisplayText",
    "FiniteFloat",
    "NonNegativeFiniteFloat",
    "PanelSlot",
    "PositiveInt",
    "RunStatus",
    "Scalar",
    "TrainDashboardApp",
    "TrainDashboardSnapshot",
    "_format_curves",
    "_format_metrics",
    "_format_progress",
    "_format_timing",
    "line_graph",
    "scatter_graph",
]
|
trainpit/_tui/app.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
"""Textual app implementation for the trainpit dashboard."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Sequence
|
|
6
|
+
|
|
7
|
+
from textual.app import App, ComposeResult
|
|
8
|
+
from textual.containers import Horizontal, Vertical
|
|
9
|
+
from textual.widgets import Footer, Header, Static
|
|
10
|
+
|
|
11
|
+
from trainpit._tui.formatting import (
|
|
12
|
+
_format_curves,
|
|
13
|
+
_format_metrics,
|
|
14
|
+
_format_progress,
|
|
15
|
+
_format_timing,
|
|
16
|
+
)
|
|
17
|
+
from trainpit._tui.graphs import line_graph
|
|
18
|
+
from trainpit._tui.panels import (
|
|
19
|
+
DashboardPanel,
|
|
20
|
+
_panel_body_id,
|
|
21
|
+
_panel_container_id,
|
|
22
|
+
_panel_title,
|
|
23
|
+
_validate_extra_panels,
|
|
24
|
+
)
|
|
25
|
+
from trainpit._tui.snapshot import TrainDashboardSnapshot
|
|
26
|
+
from trainpit._tui.types import CurveGraphRenderer, PanelSlot
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class TrainDashboardApp(App[None]):
    """A compact Textual dashboard for train progress.

    Layout:
    - Top row: STATUS, PROGRESS and METRICS panels.
    - Bottom row: the LEARNING CURVE graph, followed by a side column with
      TIMING and EVENTS.
    - ``extra_panels`` are added to the top row, the bottom row or the side
      column according to each panel's ``slot`` ("top", "bottom", "side").
    """

    # The status colour rules are scoped as ``#status.status-*``. A bare
    # ``.status-*`` class selector is less specific than the ``#status`` id
    # rule, so its colour never applied and failed runs stayed green.
    CSS = """
    Screen {
        background: #07090d;
        color: #d7dee8;
    }

    Header,
    Footer {
        background: #0b0f14;
        color: #97a3b6;
    }

    #dashboard {
        height: 1fr;
        padding: 0 1 1 1;
    }

    #top-row {
        height: 10;
        min-height: 9;
    }

    #bottom-row {
        height: 1fr;
        min-height: 16;
    }

    .panel {
        background: #0c1117;
        border: round #27313f;
        padding: 0 1;
        margin: 0 1 1 0;
        min-width: 24;
    }

    #status-panel {
        width: 1fr;
        min-width: 22;
    }

    #progress-panel {
        width: 2fr;
        min-width: 34;
    }

    #metrics-panel {
        width: 2fr;
        min-width: 32;
    }

    .graph-panel {
        width: 4fr;
        min-width: 60;
    }

    #bottom-side {
        width: 1fr;
        min-width: 24;
    }

    .side-panel {
        height: 1fr;
    }

    .bottom-panel {
        width: 2fr;
        min-width: 32;
    }

    .custom-panel {
        color: #d7dee8;
    }

    .panel-title {
        background: #101b24;
        color: #6ee7f9;
        text-style: bold reverse;
        padding: 0 1;
        margin-bottom: 1;
        width: 1fr;
    }

    #status {
        color: #7ee787;
        text-style: bold;
    }

    #status.status-running {
        color: #7ee787;
    }

    #status.status-finished {
        color: #6ee7f9;
    }

    #status.status-failed {
        color: #ff7b72;
    }

    #label {
        color: #f0f4fa;
    }

    #curve {
        color: #7ee787;
    }

    #progress {
        color: #f0c36a;
    }

    #metrics {
        color: #f0f4fa;
    }

    #timing {
        color: #97a3b6;
    }

    #events {
        color: #d7dee8;
    }
    """

    TITLE = "trainpit"
    SUB_TITLE = "training monitor"
    ENABLE_COMMAND_PALETTE = False
    BINDINGS = [("q", "quit", "Quit")]

    def __init__(
        self,
        snapshot: TrainDashboardSnapshot | None = None,
        *,
        extra_panels: Sequence[DashboardPanel] | None = None,
        graph_renderer: CurveGraphRenderer = line_graph,
    ) -> None:
        super().__init__()
        # Only a missing snapshot falls back to an empty one. The check does
        # not rely on how the snapshot type defines truthiness.
        self.snapshot = TrainDashboardSnapshot() if snapshot is None else snapshot
        self.extra_panels = _validate_extra_panels(extra_panels or ())
        self.graph_renderer = _validate_graph_renderer(graph_renderer)

    def compose(self) -> ComposeResult:
        """Build the fixed panels, plus the extra panels placed by slot."""
        yield Header(show_clock=True)
        with Vertical(id="dashboard"):
            with Horizontal(id="top-row"):
                with Vertical(id="status-panel", classes="panel"):
                    yield Static(_panel_title("STATUS"), classes="panel-title")
                    yield Static("", id="status")
                    yield Static("", id="label")
                with Vertical(id="progress-panel", classes="panel"):
                    yield Static(_panel_title("PROGRESS"), classes="panel-title")
                    yield Static("", id="progress")
                with Vertical(id="metrics-panel", classes="panel"):
                    yield Static(_panel_title("METRICS"), classes="panel-title")
                    yield Static("", id="metrics")
                yield from self._compose_extra_panels("top", "panel custom-panel")
            with Horizontal(id="bottom-row"):
                with Vertical(classes="panel graph-panel"):
                    yield Static(_panel_title("LEARNING CURVE"), classes="panel-title")
                    yield Static("", id="curve")
                yield from self._compose_extra_panels(
                    "bottom",
                    "panel bottom-panel custom-panel",
                )
                with Vertical(id="bottom-side"):
                    with Vertical(classes="panel side-panel"):
                        yield Static(_panel_title("TIMING"), classes="panel-title")
                        yield Static("", id="timing")
                    with Vertical(classes="panel side-panel"):
                        yield Static(_panel_title("EVENTS"), classes="panel-title")
                        yield Static("", id="events")
                    yield from self._compose_extra_panels(
                        "side",
                        "panel side-panel custom-panel",
                    )
        yield Footer()

    def on_mount(self) -> None:
        self.refresh_dashboard()

    def refresh_dashboard(self) -> None:
        """Render the current snapshot into dashboard widgets."""

        snapshot = self.snapshot
        status = self.query_one("#status", Static)
        status.update(snapshot.status.upper())
        # Exactly one status-* class is active, matching the snapshot status.
        for value in ("running", "finished", "failed"):
            status.set_class(snapshot.status == value, f"status-{value}")

        self.query_one("#label", Static).update(snapshot.label or "unlabeled run")
        self.query_one("#progress", Static).update(_format_progress(snapshot))
        self.query_one("#metrics", Static).update(_format_metrics(snapshot))
        self.query_one("#curve", Static).update(
            _format_curves(snapshot, graph_renderer=self.graph_renderer)
        )
        self.query_one("#timing", Static).update(_format_timing(snapshot))
        self.query_one("#events", Static).update(snapshot.event or "waiting for events")
        for panel in self.extra_panels:
            self.query_one(f"#{_panel_body_id(panel)}", Static).update(
                panel.render(snapshot)
            )

    def update_snapshot(self, snapshot: TrainDashboardSnapshot) -> None:
        """Replace the dashboard snapshot and refresh visible widgets."""

        self.snapshot = snapshot
        if self.is_mounted:
            self.refresh_dashboard()

    def _panels_for_slot(self, slot: PanelSlot) -> tuple[DashboardPanel, ...]:
        """Return the extra panels assigned to ``slot``, in their given order."""
        return tuple(panel for panel in self.extra_panels if panel.slot == slot)

    def _compose_extra_panels(self, slot: PanelSlot, classes: str) -> ComposeResult:
        """Yield one titled container, with CSS ``classes``, per extra panel in ``slot``."""
        for panel in self._panels_for_slot(slot):
            with Vertical(id=_panel_container_id(panel), classes=classes):
                yield Static(_panel_title(panel.title), classes="panel-title")
                yield Static("", id=_panel_body_id(panel))
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def _validate_graph_renderer(renderer: CurveGraphRenderer) -> CurveGraphRenderer:
|
|
267
|
+
if not callable(renderer):
|
|
268
|
+
raise TypeError("graph_renderer must be callable")
|
|
269
|
+
|
|
270
|
+
return renderer
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""Constants used by trainpit TUI rendering."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
# Graph and progress-bar dimensions. Units are presumably terminal character
# cells. TODO: confirm against the renderers in graphs.py / formatting.py.
GRAPH_WIDTH = 56
PRIMARY_GRAPH_HEIGHT = 7
SECONDARY_GRAPH_HEIGHT = 3
PROGRESS_BAR_WIDTH = 24
# First code point of the Unicode Braille Patterns block (U+2800, no dots set).
BRAILLE_BASE = 0x2800
# Dot bit values for each braille column, listed top to bottom. The left
# column holds dots 1, 2, 3, 7 and the right column dots 4, 5, 6, 8 of the
# 2x4 cell.
BRAILLE_COLUMNS = (
    (0x01, 0x02, 0x04, 0x40),
    (0x08, 0x10, 0x20, 0x80),
)
# A braille cell has four dot rows and two dot columns. SAMPLES_PER_BRAILLE
# presumably maps one sample to each column. Confirm in graphs.py.
BRAILLE_ROWS_PER_CELL = 4
SAMPLES_PER_BRAILLE = 2
# Epoch-axis label layout. These are presumably the minimum cells between
# labels and the padding around each label. TODO: confirm against the axis
# renderer.
EPOCH_AXIS_MIN_LABEL_SPACING = 8
EPOCH_AXIS_LABEL_PADDING = 1
# ASCII letters, digits, "_" and "-". These are presumably the characters
# allowed in extra-panel ids. Confirm where panels.py validates them.
PANEL_ID_CHARACTERS = frozenset(
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-"
)
# Translation table for str.translate that maps "A"-"Z" to their fullwidth
# forms U+FF21-U+FF3A (code point + 0xFEE0), presumably for panel titles.
# NOTE(review): the space entry maps " " to the character shown here. If an
# ideographic space (U+3000) was intended, check that it survived as such.
FULLWIDTH_TITLE_TRANSLATION = str.maketrans(
    {
        " ": " ",
        **{chr(value): chr(value + 0xFEE0) for value in range(ord("A"), ord("Z") + 1)},
    }
)
|