flashstudio 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flashstudio/__init__.py +5 -0
- flashstudio/app.py +64 -0
- flashstudio/cli.py +18 -0
- flashstudio/components/__init__.py +0 -0
- flashstudio/components/sidebar.py +45 -0
- flashstudio/components/styles.py +273 -0
- flashstudio/components/wizard.py +46 -0
- flashstudio/launcher.py +73 -0
- flashstudio/pages/__init__.py +0 -0
- flashstudio/pages/dashboard.py +168 -0
- flashstudio/pages/data.py +272 -0
- flashstudio/pages/export.py +212 -0
- flashstudio/pages/inference.py +1112 -0
- flashstudio/pages/model.py +370 -0
- flashstudio/pages/training.py +672 -0
- flashstudio/utils/__init__.py +0 -0
- flashstudio/utils/device.py +58 -0
- flashstudio-0.1.0.dist-info/METADATA +133 -0
- flashstudio-0.1.0.dist-info/RECORD +22 -0
- flashstudio-0.1.0.dist-info/WHEEL +5 -0
- flashstudio-0.1.0.dist-info/entry_points.txt +2 -0
- flashstudio-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,672 @@
|
|
|
1
|
+
"""FlashStudio — Training Dashboard Page (reads real FlashDet workspace output)."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import re
|
|
5
|
+
import json
|
|
6
|
+
import glob as glob_module
|
|
7
|
+
import time
|
|
8
|
+
import importlib.util
|
|
9
|
+
|
|
10
|
+
import streamlit as st
|
|
11
|
+
import plotly.graph_objects as go
|
|
12
|
+
from plotly.subplots import make_subplots
|
|
13
|
+
import numpy as np
|
|
14
|
+
from PIL import Image
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _get_default_workspace():
    """Return the directory the monitor tab should start browsing from.

    Preference order:

    1. ``st.session_state["save_dir"]`` when it names an existing directory;
    2. the first existing one of ``<cwd>/workspace``,
       ``<cwd>/../FlashDet/workspace`` and ``FlashDet/workspace`` three levels
       above this module (returned as an absolute path);
    3. the current working directory.
    """
    configured = st.session_state.get("save_dir", "")
    if configured and os.path.isdir(configured):
        return configured

    cwd = os.getcwd()
    here = os.path.dirname(__file__)
    # Working-directory candidates first, then one relative to the package.
    search_paths = (
        os.path.join(cwd, "workspace"),
        os.path.join(cwd, "..", "FlashDet", "workspace"),
        os.path.join(here, "..", "..", "..", "FlashDet", "workspace"),
    )
    found = next(
        (path for path in map(os.path.abspath, search_paths) if os.path.isdir(path)),
        None,
    )
    return found if found is not None else cwd
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def render_training_page():
    """Render the Training Dashboard page.

    Draws the page header, initialises the ``training_active`` and
    ``training_status`` session keys on the first visit, and lays out two
    tabs: one for launching a new run and one for monitoring existing runs.
    """
    from flashstudio.components.styles import render_page_header

    render_page_header(
        "๐๏ธ",
        "Training Dashboard",
        "Monitor real FlashDet training runs โ logs, curves, visualizations.",
    )

    state = st.session_state
    if "training_active" not in state:
        # First visit in this session: seed both training flags together.
        state["training_active"] = False
        state["training_status"] = "Not started"

    start_tab, monitor_tab = st.tabs(["โถ๏ธ Start Training", "๐ Monitor Run"])
    with start_tab:
        _render_start_tab()
    with monitor_tab:
        _render_monitor_tab()
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _render_monitor_tab():
    """Render the "Monitor Run" tab.

    Asks for a workspace directory and lists its sub-directories as runs,
    most recently modified first. Renders the dashboard for the selected run.
    Shows a warning or info message and returns early when the workspace is
    missing or has no run folders.
    """
    workspace = st.text_input(
        "Workspace path",
        value=_get_default_workspace(),
        key="workspace_path",
        help="Path containing training run folders (auto-detected from save_dir config or FlashDet workspace)",
    )

    if not os.path.isdir(workspace):
        st.warning(f"Workspace not found: `{workspace}`")
        return

    def _run_path(name):
        return os.path.join(workspace, name)

    runs = [name for name in os.listdir(workspace) if os.path.isdir(_run_path(name))]
    # Newest run (by directory mtime) first.
    runs.sort(key=lambda name: os.path.getmtime(_run_path(name)), reverse=True)

    if not runs:
        st.info("No training runs found in workspace.")
        return

    selected_run = st.selectbox("Select Training Run", runs, key="selected_run")

    st.divider()
    _render_run_dashboard(_run_path(selected_run))
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _render_run_dashboard(run_dir: str):
    """Render the full dashboard for one training run directory.

    Finds and parses the newest ``train_*.log`` in *run_dir* (if any) and
    shows the headline metric row. Then lays out five tabs: curves,
    visualisations, GT verification, the raw log and the checkpoint listing.
    """
    log_file = _find_log_file(run_dir)
    history = None
    if log_file:
        history = _parse_training_log(log_file)

    # Headline metrics sit above the tabs.
    _render_metrics_from_history(history, run_dir)

    labels = [
        "๐ Training Curves",
        "๐ผ๏ธ Visualizations",
        "โ GT Verification",
        "๐ Full Log",
        "๐ Checkpoints",
    ]
    renderers = (
        lambda: _render_curves(history, run_dir),
        lambda: _render_visualizations(run_dir),
        lambda: _render_gt_verification(run_dir),
        lambda: _render_full_log(log_file),
        lambda: _render_checkpoints(run_dir),
    )
    for tab, render in zip(st.tabs(labels), renderers):
        with tab:
            render()
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _find_log_file(run_dir: str):
|
|
117
|
+
"""Find the training log file in a run directory."""
|
|
118
|
+
logs = glob_module.glob(os.path.join(run_dir, "train_*.log"))
|
|
119
|
+
if logs:
|
|
120
|
+
return max(logs, key=os.path.getmtime)
|
|
121
|
+
return None
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _parse_training_log(log_path: str):
|
|
125
|
+
"""Parse FlashDet training log and extract metrics per epoch."""
|
|
126
|
+
if not log_path or not os.path.isfile(log_path):
|
|
127
|
+
return None
|
|
128
|
+
|
|
129
|
+
history = {
|
|
130
|
+
"epochs": [], "lr": [], "train_loss": [],
|
|
131
|
+
"val_loss": [], "mAP50": [],
|
|
132
|
+
"o2m_cls": [], "o2m_box": [], "o2o_cls": [], "o2o_box": [],
|
|
133
|
+
"ema_decay": [], "epoch_time": [],
|
|
134
|
+
"model_info": "", "device": "", "classes": [],
|
|
135
|
+
"total_epochs": 0, "batch_size": 0,
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
with open(log_path, "r") as f:
|
|
139
|
+
lines = f.readlines()
|
|
140
|
+
|
|
141
|
+
current_epoch = 0
|
|
142
|
+
for line in lines:
|
|
143
|
+
# Parse header info
|
|
144
|
+
if "Model Size:" in line:
|
|
145
|
+
history["model_info"] = line.split("Model Size:")[-1].strip()
|
|
146
|
+
if "Device:" in line:
|
|
147
|
+
history["device"] = line.split("Device:")[-1].strip()
|
|
148
|
+
if "Classes" in line and ":" in line:
|
|
149
|
+
m = re.search(r"Classes \((\d+)\): \[(.+)\]", line)
|
|
150
|
+
if m:
|
|
151
|
+
history["classes"] = [c.strip().strip("'") for c in m.group(2).split(",")]
|
|
152
|
+
if "Epochs:" in line and "Batch" not in line:
|
|
153
|
+
m = re.search(r"Epochs: (\d+)", line)
|
|
154
|
+
if m:
|
|
155
|
+
history["total_epochs"] = int(m.group(1))
|
|
156
|
+
if "Batch Size:" in line:
|
|
157
|
+
m = re.search(r"Batch Size: (\d+)", line)
|
|
158
|
+
if m:
|
|
159
|
+
history["batch_size"] = int(m.group(1))
|
|
160
|
+
|
|
161
|
+
# Parse epoch header: Epoch 1/10 (lr=0.000010, ema_decay=0.000500)
|
|
162
|
+
epoch_m = re.search(r"Epoch (\d+)/(\d+) \(lr=([\d.e+-]+),\s*ema_decay=([\d.e+-]+)\)", line)
|
|
163
|
+
if epoch_m:
|
|
164
|
+
current_epoch = int(epoch_m.group(1))
|
|
165
|
+
if history["total_epochs"] == 0:
|
|
166
|
+
history["total_epochs"] = int(epoch_m.group(2))
|
|
167
|
+
history["lr"].append(float(epoch_m.group(3)))
|
|
168
|
+
history["ema_decay"].append(float(epoch_m.group(4)))
|
|
169
|
+
|
|
170
|
+
# Parse batch loss: Epoch [1] Batch [10/16] Loss: 1156.5105 (...)
|
|
171
|
+
batch_m = re.search(
|
|
172
|
+
r"Epoch \[(\d+)\] Batch \[\d+/\d+\] Loss: ([\d.]+) "
|
|
173
|
+
r"\(.*?o2m_cls: ([\d.]+), o2m_box: ([\d.]+).*?o2o_cls: ([\d.]+), o2o_box: ([\d.]+)",
|
|
174
|
+
line
|
|
175
|
+
)
|
|
176
|
+
if batch_m:
|
|
177
|
+
epoch_num = int(batch_m.group(1))
|
|
178
|
+
total_loss = float(batch_m.group(2))
|
|
179
|
+
# Keep only the last batch of each epoch as the "epoch loss"
|
|
180
|
+
while len(history["train_loss"]) < epoch_num:
|
|
181
|
+
history["train_loss"].append(None)
|
|
182
|
+
history["train_loss"][epoch_num - 1] = total_loss
|
|
183
|
+
|
|
184
|
+
while len(history["o2m_cls"]) < epoch_num:
|
|
185
|
+
history["o2m_cls"].append(None)
|
|
186
|
+
history["o2m_box"].append(None)
|
|
187
|
+
history["o2o_cls"].append(None)
|
|
188
|
+
history["o2o_box"].append(None)
|
|
189
|
+
history["o2m_cls"][epoch_num - 1] = float(batch_m.group(3))
|
|
190
|
+
history["o2m_box"][epoch_num - 1] = float(batch_m.group(4))
|
|
191
|
+
history["o2o_cls"][epoch_num - 1] = float(batch_m.group(5))
|
|
192
|
+
history["o2o_box"][epoch_num - 1] = float(batch_m.group(6))
|
|
193
|
+
|
|
194
|
+
# Parse epoch time: Epoch time: 366.7s
|
|
195
|
+
time_m = re.search(r"Epoch time: ([\d.]+)s", line)
|
|
196
|
+
if time_m:
|
|
197
|
+
history["epoch_time"].append(float(time_m.group(1)))
|
|
198
|
+
|
|
199
|
+
# Parse validation: Validation - Loss: X.XXXX | mAP@0.5: X.XXXX
|
|
200
|
+
val_m = re.search(r"Validation.*Loss: ([\d.]+).*mAP@0.5: ([\d.]+)", line)
|
|
201
|
+
if val_m:
|
|
202
|
+
history["val_loss"].append(float(val_m.group(1)))
|
|
203
|
+
history["mAP50"].append(float(val_m.group(2)))
|
|
204
|
+
|
|
205
|
+
# Fill in epochs list
|
|
206
|
+
n = max(len(history["train_loss"]), len(history["lr"]))
|
|
207
|
+
history["epochs"] = list(range(1, n + 1))
|
|
208
|
+
|
|
209
|
+
return history
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def _render_metrics_from_history(history, run_dir):
    """Render the six headline metric cards for a run.

    With a parsed *history* the cards are: epoch progress, latest train loss,
    latest val loss, latest mAP@0.5, best mAP@0.5 and current LR. A progress
    bar follows while fewer epochs than expected have been logged. Without
    history every card shows a "No data" placeholder. *run_dir* is accepted
    for a uniform call signature and is not used.
    """
    cols = st.columns(6)

    if not history:
        for col in cols:
            with col:
                st.metric("โ", "No data")
        return

    placeholder = "โ"
    n_epochs = len(history["epochs"])
    total = history["total_epochs"] or n_epochs
    train_losses = [x for x in history["train_loss"] if x is not None]
    val_losses = history["val_loss"]
    maps = history["mAP50"]
    lrs = history["lr"]

    cards = [
        ("Epoch", f"{n_epochs}/{total}"),
        ("Train Loss", f"{train_losses[-1]:.1f}" if train_losses else placeholder),
        ("Val Loss", f"{val_losses[-1]:.4f}" if val_losses else placeholder),
        ("mAP@0.5", f"{maps[-1]:.4f}" if maps else placeholder),
        ("Best mAP", f"{max(maps):.4f}" if maps else placeholder),
        ("Current LR", f"{lrs[-1]:.2e}" if lrs else placeholder),
    ]
    for col, (label, value) in zip(cols, cards):
        with col:
            st.metric(label, value)

    if n_epochs < total:
        st.progress(n_epochs / max(total, 1))
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def _render_curves(history, run_dir):
|
|
248
|
+
"""Render training curves โ from parsed log or saved plots."""
|
|
249
|
+
# Check for pre-generated plot images
|
|
250
|
+
plots_dir = os.path.join(run_dir, "plots")
|
|
251
|
+
training_curves_img = os.path.join(plots_dir, "training_curves.png")
|
|
252
|
+
map_curve_img = os.path.join(plots_dir, "mAP_curve.png")
|
|
253
|
+
|
|
254
|
+
if os.path.isfile(training_curves_img):
|
|
255
|
+
st.markdown("#### Saved Training Curves (matplotlib)")
|
|
256
|
+
st.image(training_curves_img, caption="training_curves.png")
|
|
257
|
+
if os.path.isfile(map_curve_img):
|
|
258
|
+
st.image(map_curve_img, caption="mAP_curve.png")
|
|
259
|
+
st.divider()
|
|
260
|
+
|
|
261
|
+
# Interactive plotly charts from parsed log
|
|
262
|
+
if not history or not history["train_loss"]:
|
|
263
|
+
st.info("No training data parsed yet. Training may still be starting.")
|
|
264
|
+
return
|
|
265
|
+
|
|
266
|
+
st.markdown("#### Interactive Charts (parsed from log)")
|
|
267
|
+
|
|
268
|
+
losses = [x for x in history["train_loss"] if x is not None]
|
|
269
|
+
epochs_for_loss = [i + 1 for i, x in enumerate(history["train_loss"]) if x is not None]
|
|
270
|
+
|
|
271
|
+
fig = make_subplots(
|
|
272
|
+
rows=2, cols=2,
|
|
273
|
+
subplot_titles=("Total Loss", "mAP@0.5", "Sub-Losses (o2m/o2o)", "Learning Rate"),
|
|
274
|
+
)
|
|
275
|
+
|
|
276
|
+
# Total loss
|
|
277
|
+
fig.add_trace(
|
|
278
|
+
go.Scatter(x=epochs_for_loss, y=losses, mode="lines+markers",
|
|
279
|
+
name="Train Loss", line=dict(color="#7C3AED", width=2), marker=dict(size=4)),
|
|
280
|
+
row=1, col=1,
|
|
281
|
+
)
|
|
282
|
+
if history["val_loss"]:
|
|
283
|
+
val_epochs = list(range(1, len(history["val_loss"]) + 1))
|
|
284
|
+
fig.add_trace(
|
|
285
|
+
go.Scatter(x=val_epochs, y=history["val_loss"], mode="lines+markers",
|
|
286
|
+
name="Val Loss", line=dict(color="#F59E0B", width=2), marker=dict(size=4)),
|
|
287
|
+
row=1, col=1,
|
|
288
|
+
)
|
|
289
|
+
|
|
290
|
+
# mAP
|
|
291
|
+
if history["mAP50"]:
|
|
292
|
+
map_epochs = list(range(1, len(history["mAP50"]) + 1))
|
|
293
|
+
fig.add_trace(
|
|
294
|
+
go.Scatter(x=map_epochs, y=history["mAP50"], mode="lines+markers",
|
|
295
|
+
name="mAP@0.5", line=dict(color="#10B981", width=2), marker=dict(size=5)),
|
|
296
|
+
row=1, col=2,
|
|
297
|
+
)
|
|
298
|
+
|
|
299
|
+
# Sub-losses
|
|
300
|
+
o2m_cls = [x for x in history["o2m_cls"] if x is not None]
|
|
301
|
+
o2o_cls = [x for x in history["o2o_cls"] if x is not None]
|
|
302
|
+
o2m_box = [x for x in history["o2m_box"] if x is not None]
|
|
303
|
+
if o2m_cls:
|
|
304
|
+
ep = list(range(1, len(o2m_cls) + 1))
|
|
305
|
+
fig.add_trace(
|
|
306
|
+
go.Scatter(x=ep, y=o2m_cls, mode="lines", name="o2m_cls",
|
|
307
|
+
line=dict(color="#EF4444", width=1.5)),
|
|
308
|
+
row=2, col=1,
|
|
309
|
+
)
|
|
310
|
+
if o2o_cls:
|
|
311
|
+
ep = list(range(1, len(o2o_cls) + 1))
|
|
312
|
+
fig.add_trace(
|
|
313
|
+
go.Scatter(x=ep, y=o2o_cls, mode="lines", name="o2o_cls",
|
|
314
|
+
line=dict(color="#3B82F6", width=1.5)),
|
|
315
|
+
row=2, col=1,
|
|
316
|
+
)
|
|
317
|
+
if o2m_box:
|
|
318
|
+
ep = list(range(1, len(o2m_box) + 1))
|
|
319
|
+
fig.add_trace(
|
|
320
|
+
go.Scatter(x=ep, y=o2m_box, mode="lines", name="o2m_box",
|
|
321
|
+
line=dict(color="#F97316", width=1.5, dash="dash")),
|
|
322
|
+
row=2, col=1,
|
|
323
|
+
)
|
|
324
|
+
|
|
325
|
+
# LR
|
|
326
|
+
if history["lr"]:
|
|
327
|
+
lr_epochs = list(range(1, len(history["lr"]) + 1))
|
|
328
|
+
fig.add_trace(
|
|
329
|
+
go.Scatter(x=lr_epochs, y=history["lr"], mode="lines",
|
|
330
|
+
name="LR", line=dict(color="#6366F1", width=2)),
|
|
331
|
+
row=2, col=2,
|
|
332
|
+
)
|
|
333
|
+
|
|
334
|
+
fig.update_layout(
|
|
335
|
+
template="plotly_white",
|
|
336
|
+
height=500,
|
|
337
|
+
margin=dict(l=40, r=20, t=40, b=40),
|
|
338
|
+
showlegend=True,
|
|
339
|
+
legend=dict(orientation="h", yanchor="bottom", y=1.05),
|
|
340
|
+
)
|
|
341
|
+
st.plotly_chart(fig, use_container_width=True)
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def _render_visualizations(run_dir):
|
|
345
|
+
"""Show per-epoch GT vs Prediction visualizations."""
|
|
346
|
+
vis_dir = os.path.join(run_dir, "visualizations")
|
|
347
|
+
|
|
348
|
+
if not os.path.isdir(vis_dir):
|
|
349
|
+
st.info("No visualizations directory found. Training may not have generated any yet.")
|
|
350
|
+
return
|
|
351
|
+
|
|
352
|
+
images = sorted([
|
|
353
|
+
f for f in os.listdir(vis_dir)
|
|
354
|
+
if f.endswith(".jpg") and f != "latest_visualization.jpg"
|
|
355
|
+
])
|
|
356
|
+
|
|
357
|
+
if not images:
|
|
358
|
+
st.info("No visualization images found yet.")
|
|
359
|
+
return
|
|
360
|
+
|
|
361
|
+
st.markdown(f"**{len(images)} epoch visualizations** (GT vs Predictions)")
|
|
362
|
+
|
|
363
|
+
# Show latest first
|
|
364
|
+
latest = os.path.join(vis_dir, "latest_visualization.jpg")
|
|
365
|
+
if os.path.isfile(latest):
|
|
366
|
+
st.markdown("#### Latest Visualization")
|
|
367
|
+
st.image(latest, caption="Latest (GT left | Predictions right)")
|
|
368
|
+
|
|
369
|
+
st.divider()
|
|
370
|
+
st.markdown("#### All Epoch Visualizations")
|
|
371
|
+
|
|
372
|
+
# Display in a grid (2 per row)
|
|
373
|
+
for i in range(0, len(images), 2):
|
|
374
|
+
cols = st.columns(2)
|
|
375
|
+
for j, col in enumerate(cols):
|
|
376
|
+
idx = i + j
|
|
377
|
+
if idx < len(images):
|
|
378
|
+
img_path = os.path.join(vis_dir, images[idx])
|
|
379
|
+
with col:
|
|
380
|
+
st.image(img_path, caption=images[idx], use_container_width=True)
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
def _render_gt_verification(run_dir):
    """Show the ground-truth verification output of a run.

    Reads from ``<run_dir>/gt_verification``:

    * ``verification_summary.txt``, shown verbatim;
    * ``verification_report.json``, shown as a PASSED/FAILED banner plus image
      and annotation counts for the train/val splits;
    * the ``images/raw`` and ``images/dataloader`` grids, each in its own tab.

    An unreadable or malformed report is shown as an error message instead of
    aborting the page.
    """
    gt_dir = os.path.join(run_dir, "gt_verification")
    if not os.path.isdir(gt_dir):
        st.info("No GT verification data found.")
        return

    summary_path = os.path.join(gt_dir, "verification_summary.txt")
    if os.path.isfile(summary_path):
        with open(summary_path, encoding="utf-8", errors="replace") as f:
            st.code(f.read(), language="text")

    report_path = os.path.join(gt_dir, "verification_report.json")
    if os.path.isfile(report_path):
        report = None
        try:
            with open(report_path, encoding="utf-8") as f:
                report = json.load(f)
        except (OSError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            st.error(f"Could not read verification_report.json: {exc}")

        if isinstance(report, dict):
            if report.get("passed", False):
                st.success("Annotation verification: PASSED")
            else:
                st.error("Annotation verification: FAILED")

            # `or {}` also tolerates explicit nulls at each level of the report.
            splits = report.get("splits") or {}
            train_coco = (splits.get("train") or {}).get("coco") or {}
            val_coco = (splits.get("val") or {}).get("coco") or {}

            cols = st.columns(4)
            with cols[0]:
                st.metric("Train Images", train_coco.get("num_images", 0))
            with cols[1]:
                st.metric("Train Annotations", train_coco.get("num_annotations", 0))
            with cols[2]:
                st.metric("Val Images", val_coco.get("num_images", 0))
            with cols[3]:
                st.metric("Val Annotations", val_coco.get("num_annotations", 0))
        elif report is not None:
            st.error("verification_report.json does not contain a JSON object.")

    # GT image grids
    st.divider()
    raw_dir = os.path.join(gt_dir, "images", "raw")
    dl_dir = os.path.join(gt_dir, "images", "dataloader")

    gt_tab_raw, gt_tab_dl = st.tabs(["Raw GT Images", "Dataloader GT Images"])

    with gt_tab_raw:
        _render_image_grid(raw_dir, "Raw ground truth with bounding boxes")

    with gt_tab_dl:
        _render_image_grid(dl_dir, "After dataloader transforms (letterbox, normalize)")
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
def _render_image_grid(img_dir, description):
    """Render every ``.jpg``/``.png`` file in *img_dir* as a four-column grid.

    Files appear in sorted name order under a caption made of *description*
    and the image count. Each image caption is the file name cut to 30
    characters. Shows an info message if the directory is missing or empty.
    """
    if not os.path.isdir(img_dir):
        st.info(f"Directory not found: {img_dir}")
        return

    names = sorted(name for name in os.listdir(img_dir) if name.endswith((".jpg", ".png")))
    if not names:
        st.info("No images found.")
        return

    st.caption(f"{description} โ {len(names)} images")

    per_row = 4
    rows = [names[start:start + per_row] for start in range(0, len(names), per_row)]
    for row in rows:
        for col, name in zip(st.columns(per_row), row):
            with col:
                st.image(os.path.join(img_dir, name), caption=name[:30], use_container_width=True)
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def _render_full_log(log_path):
    """Display a training log file, showing its tail by default.

    Shows the last 50 lines of the stripped log. A "Show full log" checkbox
    switches to the complete file. The file is decoded as UTF-8 with
    undecodable bytes replaced, so a stray byte cannot crash the page.
    """
    if not log_path or not os.path.isfile(log_path):
        st.info("No training log found.")
        return

    with open(log_path, encoding="utf-8", errors="replace") as f:
        content = f.read()

    st.markdown(f"**Log file:** `{os.path.basename(log_path)}`")

    # Show the last `tail_lines` lines by default, with an option to expand.
    tail_lines = 50
    lines = content.strip().split("\n")
    show_all = st.checkbox("Show full log", value=False, key="show_full_log")

    if show_all:
        st.code(content, language="bash")
    else:
        st.code("\n".join(lines[-tail_lines:]), language="bash")
        if len(lines) > tail_lines:
            st.caption(
                f"Showing last {tail_lines} of {len(lines)} lines. Check 'Show full log' to see all."
            )
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def _render_checkpoints(run_dir):
|
|
482
|
+
"""Show checkpoint files and their sizes."""
|
|
483
|
+
st.markdown("#### Saved Checkpoints & Weights")
|
|
484
|
+
|
|
485
|
+
files_info = []
|
|
486
|
+
for f in sorted(os.listdir(run_dir)):
|
|
487
|
+
fpath = os.path.join(run_dir, f)
|
|
488
|
+
if os.path.isfile(fpath) and f.endswith((".pth", ".json", ".csv", ".log")):
|
|
489
|
+
size = os.path.getsize(fpath)
|
|
490
|
+
if size > 1024 * 1024:
|
|
491
|
+
size_str = f"{size / (1024*1024):.1f} MB"
|
|
492
|
+
elif size > 1024:
|
|
493
|
+
size_str = f"{size / 1024:.1f} KB"
|
|
494
|
+
else:
|
|
495
|
+
size_str = f"{size} B"
|
|
496
|
+
files_info.append({"File": f, "Size": size_str, "Type": _file_type(f)})
|
|
497
|
+
|
|
498
|
+
if files_info:
|
|
499
|
+
st.dataframe(files_info, use_container_width=True, hide_index=True)
|
|
500
|
+
else:
|
|
501
|
+
st.info("No checkpoint files found.")
|
|
502
|
+
|
|
503
|
+
# results.json
|
|
504
|
+
results_path = os.path.join(run_dir, "results.json")
|
|
505
|
+
if os.path.isfile(results_path):
|
|
506
|
+
st.divider()
|
|
507
|
+
st.markdown("#### Training Results Summary")
|
|
508
|
+
with open(results_path) as f:
|
|
509
|
+
results = json.load(f)
|
|
510
|
+
|
|
511
|
+
cols = st.columns(4)
|
|
512
|
+
with cols[0]:
|
|
513
|
+
st.metric("Epochs Trained", results.get("epochs_trained", "?"))
|
|
514
|
+
with cols[1]:
|
|
515
|
+
st.metric("Best mAP@0.5", f"{results.get('best_mAP50', 0):.4f}")
|
|
516
|
+
with cols[2]:
|
|
517
|
+
st.metric("Best Val Loss", f"{results.get('best_val_loss', 0):.4f}")
|
|
518
|
+
with cols[3]:
|
|
519
|
+
st.metric("Model Params", f"{results.get('model_params_M', 0):.2f}M")
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
def _file_type(filename):
|
|
523
|
+
"""Categorize a file by its name."""
|
|
524
|
+
if "best" in filename:
|
|
525
|
+
return "Best checkpoint"
|
|
526
|
+
if "last" in filename:
|
|
527
|
+
return "Latest checkpoint"
|
|
528
|
+
if "inference" in filename:
|
|
529
|
+
return "Inference weights"
|
|
530
|
+
if "fp16" in filename:
|
|
531
|
+
return "FP16 weights"
|
|
532
|
+
if filename.endswith(".json"):
|
|
533
|
+
return "Results/Report"
|
|
534
|
+
if filename.endswith(".csv"):
|
|
535
|
+
return "Training log CSV"
|
|
536
|
+
if filename.endswith(".log"):
|
|
537
|
+
return "Training log"
|
|
538
|
+
return "Other"
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
# ---- Start New Training Tab ----
|
|
542
|
+
|
|
543
|
+
def _render_start_tab():
|
|
544
|
+
"""Tab: start a new training run."""
|
|
545
|
+
st.markdown("### Start New Training")
|
|
546
|
+
|
|
547
|
+
col_info, col_actions = st.columns([3, 1])
|
|
548
|
+
|
|
549
|
+
with col_info:
|
|
550
|
+
arch = st.session_state.get("model_arch", "FlashDet-Pico")
|
|
551
|
+
dataset = st.session_state.get("dataset_name", "Not selected")
|
|
552
|
+
epochs = st.session_state.get("epochs", 100)
|
|
553
|
+
bs = st.session_state.get("batch_size", 16)
|
|
554
|
+
lr = st.session_state.get("lr", 0.001)
|
|
555
|
+
|
|
556
|
+
st.markdown(
|
|
557
|
+
f"**Config:** `{arch}` ยท Dataset: `{dataset}` ยท "
|
|
558
|
+
f"Epochs: `{epochs}` ยท BS: `{bs}` ยท LR: `{lr}`"
|
|
559
|
+
)
|
|
560
|
+
|
|
561
|
+
with col_actions:
|
|
562
|
+
if not st.session_state.get("training_active", False):
|
|
563
|
+
if st.button("โถ๏ธ Start Training", use_container_width=True, type="primary",
|
|
564
|
+
key="btn_start_training"):
|
|
565
|
+
st.session_state["training_active"] = True
|
|
566
|
+
st.session_state["training_status"] = "Running"
|
|
567
|
+
_start_training()
|
|
568
|
+
else:
|
|
569
|
+
if st.button("โน๏ธ Stop Training", use_container_width=True, type="secondary",
|
|
570
|
+
key="btn_stop_training"):
|
|
571
|
+
st.session_state["training_active"] = False
|
|
572
|
+
st.session_state["training_status"] = "Stopped"
|
|
573
|
+
st.rerun()
|
|
574
|
+
|
|
575
|
+
if st.session_state.get("training_active"):
|
|
576
|
+
st.info("Training is running... Switch to 'Monitor Run' tab and select latest run to see progress.")
|
|
577
|
+
|
|
578
|
+
|
|
579
|
+
def _start_training():
    """Launch a FlashDet training run, or explain why it cannot start.

    Hands off to ``_run_flashdet_training`` when the ``flashdet`` package can
    be found. Otherwise shows an install hint and resets the training flags
    in session state.
    """
    flashdet_available = importlib.util.find_spec("flashdet") is not None
    if flashdet_available:
        _run_flashdet_training()
        return

    st.warning("FlashDet not installed. Install with `pip install flashdet` to run real training.")
    st.session_state["training_active"] = False
    st.session_state["training_status"] = "FlashDet not found"
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
def _run_flashdet_training():
|
|
590
|
+
"""Run actual FlashDet training via subprocess (non-blocking)."""
|
|
591
|
+
import subprocess
|
|
592
|
+
import sys
|
|
593
|
+
|
|
594
|
+
size_map = {
|
|
595
|
+
"FlashDet-Pico": "p", "FlashDet-Nano": "n", "FlashDet-Small": "s",
|
|
596
|
+
"FlashDet-Medium": "m", "FlashDet-Large": "l", "FlashDet-X": "x",
|
|
597
|
+
}
|
|
598
|
+
|
|
599
|
+
model_arch = st.session_state.get("model_arch", "FlashDet-Pico")
|
|
600
|
+
save_dir = st.session_state.get("save_dir", os.path.join(_get_default_workspace(), "flashstudio_run"))
|
|
601
|
+
|
|
602
|
+
# Auto-detect device
|
|
603
|
+
device = "cpu"
|
|
604
|
+
try:
|
|
605
|
+
import torch
|
|
606
|
+
if torch.cuda.is_available():
|
|
607
|
+
device = "cuda"
|
|
608
|
+
except ImportError:
|
|
609
|
+
pass
|
|
610
|
+
|
|
611
|
+
# Resolve dataset paths from session state
|
|
612
|
+
train_images = st.session_state.get("train_img_path", "")
|
|
613
|
+
val_images = st.session_state.get("val_img_path", "")
|
|
614
|
+
|
|
615
|
+
# Try to auto-detect from dataset download command
|
|
616
|
+
if not train_images or not os.path.isdir(train_images):
|
|
617
|
+
dataset_name = st.session_state.get("dataset_name", "")
|
|
618
|
+
dataset_id = dataset_name.lower().replace(" ", "").replace("(demo)", "sample")
|
|
619
|
+
for candidate_id in ["sample", "coco2017", "coco2017-val", "voc2007", "voc2012"]:
|
|
620
|
+
if candidate_id in dataset_id or dataset_id in candidate_id:
|
|
621
|
+
candidate_dir = os.path.join("data", candidate_id)
|
|
622
|
+
if os.path.isdir(os.path.join(candidate_dir, "train")):
|
|
623
|
+
train_images = os.path.join(candidate_dir, "train")
|
|
624
|
+
val_images = os.path.join(candidate_dir, "valid")
|
|
625
|
+
break
|
|
626
|
+
|
|
627
|
+
if not train_images or not val_images:
|
|
628
|
+
st.error(
|
|
629
|
+
"No dataset paths configured. Please go to the **Data** page and either:\n"
|
|
630
|
+
"- Enter train/val image paths manually, or\n"
|
|
631
|
+
"- Download a dataset first (`flashdet download --dataset sample`)"
|
|
632
|
+
)
|
|
633
|
+
st.session_state["training_active"] = False
|
|
634
|
+
st.session_state["training_status"] = "No dataset"
|
|
635
|
+
return
|
|
636
|
+
|
|
637
|
+
cmd = [
|
|
638
|
+
sys.executable, "-m", "flashdet.cli", "train",
|
|
639
|
+
"--model-size", size_map.get(model_arch, "n"),
|
|
640
|
+
"--epochs", str(st.session_state.get("epochs", 100)),
|
|
641
|
+
"--batch-size", str(st.session_state.get("batch_size", 16)),
|
|
642
|
+
"--lr", str(st.session_state.get("lr", 1e-3)),
|
|
643
|
+
"--save-dir", save_dir,
|
|
644
|
+
"--device", device,
|
|
645
|
+
"--train-images", train_images,
|
|
646
|
+
"--val-images", val_images,
|
|
647
|
+
]
|
|
648
|
+
|
|
649
|
+
if st.session_state.get("amp", True) and device == "cuda":
|
|
650
|
+
cmd.append("--amp")
|
|
651
|
+
if st.session_state.get("aug_mosaic", True):
|
|
652
|
+
cmd.append("--mosaic")
|
|
653
|
+
if st.session_state.get("aug_mixup", False):
|
|
654
|
+
cmd.append("--mixup")
|
|
655
|
+
|
|
656
|
+
optimizer = st.session_state.get("optimizer", "AdamW").lower()
|
|
657
|
+
if optimizer in ("adamw", "sgd", "musgd"):
|
|
658
|
+
cmd.extend(["--optimizer", optimizer])
|
|
659
|
+
|
|
660
|
+
workers = st.session_state.get("num_workers", 2)
|
|
661
|
+
cmd.extend(["--workers", str(workers)])
|
|
662
|
+
|
|
663
|
+
# Run as non-blocking subprocess so Streamlit UI stays responsive
|
|
664
|
+
process = subprocess.Popen(
|
|
665
|
+
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1,
|
|
666
|
+
)
|
|
667
|
+
|
|
668
|
+
st.session_state["training_pid"] = process.pid
|
|
669
|
+
|
|
670
|
+
st.info(f"Training started (PID {process.pid}): `{' '.join(cmd)}`")
|
|
671
|
+
st.caption(f"Output will be saved to: `{save_dir}`")
|
|
672
|
+
st.caption("Switch to 'Monitor Run' tab to see live progress.")
|
|
File without changes
|