flashstudio 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,672 @@
1
+ """FlashStudio โ€” Training Dashboard Page (reads real FlashDet workspace output)."""
2
+
3
+ import os
4
+ import re
5
+ import json
6
+ import glob as glob_module
7
+ import time
8
+ import importlib.util
9
+
10
+ import streamlit as st
11
+ import plotly.graph_objects as go
12
+ from plotly.subplots import make_subplots
13
+ import numpy as np
14
+ from PIL import Image
15
+
16
+
17
def _get_default_workspace():
    """Pick the workspace directory offered by default in the UI.

    Preference order:
      1. The ``save_dir`` session value, when it names an existing directory.
      2. The first existing directory among a few candidate FlashDet
         workspace locations (relative to the CWD and to this file).
      3. The current working directory.

    Returns:
        A directory path string.
    """
    configured = st.session_state.get("save_dir", "")
    if configured and os.path.isdir(configured):
        return configured

    # Probe likely FlashDet workspace locations; the first hit wins.
    search_paths = (
        os.path.join(os.getcwd(), "workspace"),
        os.path.join(os.getcwd(), "..", "FlashDet", "workspace"),
        os.path.join(os.path.dirname(__file__), "..", "..", "..", "FlashDet", "workspace"),
    )
    for raw_path in search_paths:
        resolved = os.path.abspath(raw_path)
        if os.path.isdir(resolved):
            return resolved

    return os.getcwd()
35
+
36
+
37
def render_training_page():
    """Render the training dashboard page.

    Draws the page header, seeds the training flags in the Streamlit
    session state on first use, and lays out the two top-level tabs
    (start a run / monitor a run).
    """
    from flashstudio.components.styles import render_page_header

    render_page_header(
        "๐Ÿ‹๏ธ",
        "Training Dashboard",
        "Monitor real FlashDet training runs โ€” logs, curves, visualizations.",
    )

    # Seed the flags only once per session so later reruns keep their state.
    if "training_active" not in st.session_state:
        st.session_state["training_active"] = False
        st.session_state["training_status"] = "Not started"

    start_tab, monitor_tab = st.tabs(["โ–ถ๏ธ Start Training", "๐Ÿ“Š Monitor Run"])

    with start_tab:
        _render_start_tab()

    with monitor_tab:
        _render_monitor_tab()
54
+
55
+
56
def _render_monitor_tab():
    """Render the 'Monitor Run' tab.

    Asks for a workspace path, lists its sub-directories (most recently
    modified first) as selectable runs, and renders the dashboard for the
    chosen one. Shows a warning or info message and stops early when the
    workspace is missing or contains no sub-directories.
    """
    workspace = st.text_input(
        "Workspace path",
        value=_get_default_workspace(),
        key="workspace_path",
        help="Path containing training run folders (auto-detected from save_dir config or FlashDet workspace)",
    )

    if not os.path.isdir(workspace):
        st.warning(f"Workspace not found: `{workspace}`")
        return

    def _modified_at(entry):
        return os.path.getmtime(os.path.join(workspace, entry))

    # Every sub-directory of the workspace is treated as a run.
    runs = [
        entry
        for entry in os.listdir(workspace)
        if os.path.isdir(os.path.join(workspace, entry))
    ]
    runs.sort(key=_modified_at, reverse=True)

    if not runs:
        st.info("No training runs found in workspace.")
        return

    selected_run = st.selectbox("Select Training Run", runs, key="selected_run")

    st.divider()
    _render_run_dashboard(os.path.join(workspace, selected_run))
84
+
85
+
86
def _render_run_dashboard(run_dir: str):
    """Render the complete dashboard for one training run directory.

    Locates and parses the run's log (when one exists), shows the metric
    cards, then renders one tab per section: curves, visualizations,
    GT verification, full log and checkpoints.

    Args:
        run_dir: Path of the run directory to display.
    """
    log_file = _find_log_file(run_dir)
    history = None
    if log_file:
        history = _parse_training_log(log_file)

    # Metric cards sit above the tab strip.
    _render_metrics_from_history(history, run_dir)

    tabs = st.tabs([
        "๐Ÿ“ˆ Training Curves", "๐Ÿ–ผ๏ธ Visualizations", "โœ… GT Verification",
        "๐Ÿ“œ Full Log", "๐Ÿ“ Checkpoints"
    ])
    # One renderer per tab, in the same order as the labels above.
    renderers = (
        lambda: _render_curves(history, run_dir),
        lambda: _render_visualizations(run_dir),
        lambda: _render_gt_verification(run_dir),
        lambda: _render_full_log(log_file),
        lambda: _render_checkpoints(run_dir),
    )
    for tab, render in zip(tabs, renderers):
        with tab:
            render()
114
+
115
+
116
def _find_log_file(run_dir: str):
    """Return the most recently modified ``train_*.log`` in *run_dir*.

    Args:
        run_dir: Directory to search (non-recursive).

    Returns:
        Path of the newest matching log file, or ``None`` when no file
        matches.
    """
    pattern = os.path.join(run_dir, "train_*.log")
    candidates = glob_module.glob(pattern)
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)
122
+
123
+
124
def _parse_training_log(log_path: str):
    """Parse a FlashDet training log into header info and per-epoch series.

    Recognised lines (matched by substring / regex, anywhere in the line):
      * ``Model Size:``, ``Device:``, ``Classes (N): [...]``, ``Epochs: N``,
        ``Batch Size: N`` header fields.
      * ``Epoch i/N (lr=..., ema_decay=...)`` epoch headers.
      * ``Epoch [i] Batch [a/b] Loss: x (... o2m_cls ... o2o_box ...)`` batch
        lines; the last batch seen for an epoch is kept as that epoch's value.
      * ``Epoch time: Xs`` and ``Validation ... Loss: x ... mAP@0.5: y``.

    Args:
        log_path: Path to the log file.

    Returns:
        ``None`` when *log_path* is empty or not a file; otherwise a dict with
        list-valued series (``epochs``, ``lr``, ``train_loss``, ``val_loss``,
        ``mAP50``, ``o2m_cls``, ``o2m_box``, ``o2o_cls``, ``o2o_box``,
        ``ema_decay``, ``epoch_time``) plus ``model_info``, ``device``,
        ``classes``, ``total_epochs`` and ``batch_size``. Loss series are
        indexed by ``epoch - 1`` and hold ``None`` for epochs without a
        batch line.
    """
    if not log_path or not os.path.isfile(log_path):
        return None

    history = {
        "epochs": [], "lr": [], "train_loss": [],
        "val_loss": [], "mAP50": [],
        "o2m_cls": [], "o2m_box": [], "o2o_cls": [], "o2o_box": [],
        "ema_decay": [], "epoch_time": [],
        "model_info": "", "device": "", "classes": [],
        "total_epochs": 0, "batch_size": 0,
    }

    # Compile once, outside the per-line loop.
    classes_re = re.compile(r"Classes \((\d+)\): \[(.+)\]")
    total_epochs_re = re.compile(r"Epochs: (\d+)")
    batch_size_re = re.compile(r"Batch Size: (\d+)")
    # Epoch header: Epoch 1/10 (lr=0.000010, ema_decay=0.000500)
    epoch_header_re = re.compile(
        r"Epoch (\d+)/(\d+) \(lr=([\d.e+-]+),\s*ema_decay=([\d.e+-]+)\)"
    )
    # Batch loss: Epoch [1] Batch [10/16] Loss: 1156.5105 (...)
    batch_re = re.compile(
        r"Epoch \[(\d+)\] Batch \[\d+/\d+\] Loss: ([\d.]+) "
        r"\(.*?o2m_cls: ([\d.]+), o2m_box: ([\d.]+).*?o2o_cls: ([\d.]+), o2o_box: ([\d.]+)"
    )
    # Epoch time: Epoch time: 366.7s
    epoch_time_re = re.compile(r"Epoch time: ([\d.]+)s")
    # Validation: Validation - Loss: X.XXXX | mAP@0.5: X.XXXX
    validation_re = re.compile(r"Validation.*Loss: ([\d.]+).*mAP@0\.5: ([\d.]+)")

    # Series filled from batch lines, in regex group order 2..6.
    batch_series = ("train_loss", "o2m_cls", "o2m_box", "o2o_cls", "o2o_box")

    # errors="replace": a stray undecodable byte must not abort the parse.
    with open(log_path, "r", errors="replace") as f:
        for line in f:
            # Header info
            if "Model Size:" in line:
                history["model_info"] = line.split("Model Size:")[-1].strip()
            if "Device:" in line:
                history["device"] = line.split("Device:")[-1].strip()
            if "Classes" in line and ":" in line:
                m = classes_re.search(line)
                if m:
                    history["classes"] = [
                        c.strip().strip("'") for c in m.group(2).split(",")
                    ]
            if "Epochs:" in line and "Batch" not in line:
                m = total_epochs_re.search(line)
                if m:
                    history["total_epochs"] = int(m.group(1))
            if "Batch Size:" in line:
                m = batch_size_re.search(line)
                if m:
                    history["batch_size"] = int(m.group(1))

            epoch_m = epoch_header_re.search(line)
            if epoch_m:
                if history["total_epochs"] == 0:
                    history["total_epochs"] = int(epoch_m.group(2))
                history["lr"].append(float(epoch_m.group(3)))
                history["ema_decay"].append(float(epoch_m.group(4)))

            batch_m = batch_re.search(line)
            if batch_m:
                epoch_num = int(batch_m.group(1))
                # Epochs are 1-based. An epoch number of 0 would index -1,
                # which raises on an empty list or overwrites the last
                # epoch's value, so such lines are ignored.
                if epoch_num >= 1:
                    for group_idx, key in enumerate(batch_series, start=2):
                        series = history[key]
                        if len(series) < epoch_num:
                            series.extend([None] * (epoch_num - len(series)))
                        # Later batches overwrite earlier ones, so the last
                        # batch of each epoch ends up as the epoch value.
                        series[epoch_num - 1] = float(batch_m.group(group_idx))

            time_m = epoch_time_re.search(line)
            if time_m:
                history["epoch_time"].append(float(time_m.group(1)))

            val_m = validation_re.search(line)
            if val_m:
                history["val_loss"].append(float(val_m.group(1)))
                history["mAP50"].append(float(val_m.group(2)))

    # Epoch axis covers whichever of the loss / lr series is longer.
    n = max(len(history["train_loss"]), len(history["lr"]))
    history["epochs"] = list(range(1, n + 1))

    return history
210
+
211
+
212
def _render_metrics_from_history(history, run_dir):
    """Render the six metric cards shown above the dashboard tabs.

    Args:
        history: Parsed log dict, or a falsy value when no log was parsed;
            in that case every card shows "No data".
        run_dir: Run directory (not read by this function).
    """
    columns = st.columns(6)

    if not history:
        for column in columns:
            with column:
                st.metric("โ€”", "No data")
        return

    n_epochs = len(history["epochs"])
    # Fall back to the observed epoch count when the log gave no total.
    total = history["total_epochs"] or n_epochs

    recorded_losses = [x for x in history["train_loss"] if x is not None]
    val_losses = history["val_loss"]
    map_scores = history["mAP50"]
    learning_rates = history["lr"]

    cards = (
        ("Epoch", f"{n_epochs}/{total}"),
        ("Train Loss", f"{recorded_losses[-1]:.1f}" if recorded_losses else "โ€”"),
        ("Val Loss", f"{val_losses[-1]:.4f}" if val_losses else "โ€”"),
        ("mAP@0.5", f"{map_scores[-1]:.4f}" if map_scores else "โ€”"),
        ("Best mAP", f"{max(map_scores):.4f}" if map_scores else "โ€”"),
        ("Current LR", f"{learning_rates[-1]:.2e}" if learning_rates else "โ€”"),
    )
    for column, (label, text) in zip(columns, cards):
        with column:
            st.metric(label, text)

    # Progress bar only while fewer epochs were seen than planned.
    if n_epochs < total:
        st.progress(n_epochs / max(total, 1))
245
+
246
+
247
def _render_curves(history, run_dir):
    """Show saved plot images and interactive charts built from the parsed log.

    When ``plots/training_curves.png`` exists under *run_dir* it is shown
    (followed by ``plots/mAP_curve.png`` when that exists too). A 2x2 Plotly
    figure is then drawn from *history*: total loss, mAP@0.5, sub-losses and
    learning rate.

    Args:
        history: Parsed log dict (see ``_parse_training_log``) or ``None``.
        run_dir: Run directory that may contain a ``plots`` sub-directory.
    """

    def _recorded_points(series):
        """Return (epoch_numbers, values) for entries that are not None.

        Loss series are indexed by ``epoch - 1`` with ``None`` gaps, so the
        epoch number must come from the position, not from a running count.
        """
        pairs = [(pos, val) for pos, val in enumerate(series, start=1) if val is not None]
        return [pos for pos, _ in pairs], [val for _, val in pairs]

    # Pre-generated plot images
    plots_dir = os.path.join(run_dir, "plots")
    training_curves_img = os.path.join(plots_dir, "training_curves.png")
    map_curve_img = os.path.join(plots_dir, "mAP_curve.png")

    if os.path.isfile(training_curves_img):
        st.markdown("#### Saved Training Curves (matplotlib)")
        st.image(training_curves_img, caption="training_curves.png")
        if os.path.isfile(map_curve_img):
            st.image(map_curve_img, caption="mAP_curve.png")
        st.divider()

    # Interactive plotly charts from parsed log
    if not history or not history["train_loss"]:
        st.info("No training data parsed yet. Training may still be starting.")
        return

    st.markdown("#### Interactive Charts (parsed from log)")

    epochs_for_loss, losses = _recorded_points(history["train_loss"])

    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=("Total Loss", "mAP@0.5", "Sub-Losses (o2m/o2o)", "Learning Rate"),
    )

    # Total loss
    fig.add_trace(
        go.Scatter(x=epochs_for_loss, y=losses, mode="lines+markers",
                   name="Train Loss", line=dict(color="#7C3AED", width=2), marker=dict(size=4)),
        row=1, col=1,
    )
    if history["val_loss"]:
        val_epochs = list(range(1, len(history["val_loss"]) + 1))
        fig.add_trace(
            go.Scatter(x=val_epochs, y=history["val_loss"], mode="lines+markers",
                       name="Val Loss", line=dict(color="#F59E0B", width=2), marker=dict(size=4)),
            row=1, col=1,
        )

    # mAP
    if history["mAP50"]:
        map_epochs = list(range(1, len(history["mAP50"]) + 1))
        fig.add_trace(
            go.Scatter(x=map_epochs, y=history["mAP50"], mode="lines+markers",
                       name="mAP@0.5", line=dict(color="#10B981", width=2), marker=dict(size=5)),
            row=1, col=2,
        )

    # Sub-losses. Each trace is plotted against the epochs it was actually
    # recorded for, matching how the total-loss trace is built; renumbering
    # 1..len after dropping None gaps would shift values onto wrong epochs.
    sub_loss_styles = (
        ("o2m_cls", dict(color="#EF4444", width=1.5)),
        ("o2o_cls", dict(color="#3B82F6", width=1.5)),
        ("o2m_box", dict(color="#F97316", width=1.5, dash="dash")),
    )
    for key, line_style in sub_loss_styles:
        ep, values = _recorded_points(history[key])
        if values:
            fig.add_trace(
                go.Scatter(x=ep, y=values, mode="lines", name=key, line=line_style),
                row=2, col=1,
            )

    # LR
    if history["lr"]:
        lr_epochs = list(range(1, len(history["lr"]) + 1))
        fig.add_trace(
            go.Scatter(x=lr_epochs, y=history["lr"], mode="lines",
                       name="LR", line=dict(color="#6366F1", width=2)),
            row=2, col=2,
        )

    fig.update_layout(
        template="plotly_white",
        height=500,
        margin=dict(l=40, r=20, t=40, b=40),
        showlegend=True,
        legend=dict(orientation="h", yanchor="bottom", y=1.05),
    )
    st.plotly_chart(fig, use_container_width=True)
342
+
343
+
344
def _render_visualizations(run_dir):
    """Show the visualization images saved under ``<run_dir>/visualizations``.

    Displays ``latest_visualization.jpg`` first when it exists, then every
    other ``.jpg`` in the directory (sorted by file name) in a two-column
    grid. Shows an info message and returns early when the directory or the
    per-epoch images are missing.

    Args:
        run_dir: Run directory to look in.
    """
    vis_dir = os.path.join(run_dir, "visualizations")

    if not os.path.isdir(vis_dir):
        st.info("No visualizations directory found. Training may not have generated any yet.")
        return

    # The "latest" image is shown separately, so keep it out of the grid.
    images = [
        name
        for name in os.listdir(vis_dir)
        if name.endswith(".jpg") and name != "latest_visualization.jpg"
    ]
    images.sort()

    if not images:
        st.info("No visualization images found yet.")
        return

    st.markdown(f"**{len(images)} epoch visualizations** (GT vs Predictions)")

    # Latest image goes above the grid.
    latest = os.path.join(vis_dir, "latest_visualization.jpg")
    if os.path.isfile(latest):
        st.markdown("#### Latest Visualization")
        st.image(latest, caption="Latest (GT left | Predictions right)")

    st.divider()
    st.markdown("#### All Epoch Visualizations")

    # Two images per row; a short final row leaves the last column empty.
    per_row = 2
    for start in range(0, len(images), per_row):
        row_columns = st.columns(per_row)
        for column, name in zip(row_columns, images[start:start + per_row]):
            with column:
                st.image(os.path.join(vis_dir, name), caption=name, use_container_width=True)
381
+
382
+
383
def _render_gt_verification(run_dir):
    """Show the ground-truth verification summary, report and images.

    Reads from ``<run_dir>/gt_verification``:
      * ``verification_summary.txt`` - shown verbatim.
      * ``verification_report.json`` - pass/fail banner plus image and
        annotation counts taken from ``splits.train.coco`` and
        ``splits.val.coco``.
      * ``images/raw`` and ``images/dataloader`` - rendered as image grids.

    An unreadable or malformed report produces a warning instead of an
    exception, so the rest of the tab still renders.

    Args:
        run_dir: Run directory to look in.
    """
    gt_dir = os.path.join(run_dir, "gt_verification")

    if not os.path.isdir(gt_dir):
        st.info("No GT verification data found.")
        return

    # Summary
    summary_path = os.path.join(gt_dir, "verification_summary.txt")
    if os.path.isfile(summary_path):
        # errors="replace": show the text even if it has undecodable bytes.
        with open(summary_path, errors="replace") as f:
            st.code(f.read(), language="text")

    # Report JSON highlights
    report_path = os.path.join(gt_dir, "verification_report.json")
    if os.path.isfile(report_path):
        report = None
        try:
            with open(report_path, encoding="utf-8") as f:
                report = json.load(f)
        except (OSError, ValueError) as exc:
            # json.JSONDecodeError and UnicodeDecodeError are ValueErrors;
            # a partially written or corrupt report must not take the tab down.
            st.warning(f"Could not read verification report: {exc}")

        if report is not None and not isinstance(report, dict):
            st.warning("Verification report is not a JSON object.")
            report = None

        if report is not None:
            if report.get("passed", False):
                st.success("Annotation verification: PASSED")
            else:
                st.error("Annotation verification: FAILED")

            cols = st.columns(4)
            train_coco = report.get("splits", {}).get("train", {}).get("coco", {})
            val_coco = report.get("splits", {}).get("val", {}).get("coco", {})
            with cols[0]:
                st.metric("Train Images", train_coco.get("num_images", 0))
            with cols[1]:
                st.metric("Train Annotations", train_coco.get("num_annotations", 0))
            with cols[2]:
                st.metric("Val Images", val_coco.get("num_images", 0))
            with cols[3]:
                st.metric("Val Annotations", val_coco.get("num_annotations", 0))

    # GT Images
    st.divider()
    raw_dir = os.path.join(gt_dir, "images", "raw")
    dl_dir = os.path.join(gt_dir, "images", "dataloader")

    gt_tab_raw, gt_tab_dl = st.tabs(["Raw GT Images", "Dataloader GT Images"])

    with gt_tab_raw:
        _render_image_grid(raw_dir, "Raw ground truth with bounding boxes")

    with gt_tab_dl:
        _render_image_grid(dl_dir, "After dataloader transforms (letterbox, normalize)")
433
+
434
+
435
def _render_image_grid(img_dir, description):
    """Render the ``.jpg`` / ``.png`` files of a directory in a 4-column grid.

    Args:
        img_dir: Directory whose images are shown (sorted by file name).
        description: Text placed in the caption above the grid, together
            with the image count.
    """
    if not os.path.isdir(img_dir):
        st.info(f"Directory not found: {img_dir}")
        return

    images = [name for name in os.listdir(img_dir) if name.endswith((".jpg", ".png"))]
    images.sort()
    if not images:
        st.info("No images found.")
        return

    st.caption(f"{description} โ€” {len(images)} images")

    per_row = 4
    for start in range(0, len(images), per_row):
        row_columns = st.columns(per_row)
        for column, name in zip(row_columns, images[start:start + per_row]):
            with column:
                # Captions are cut to 30 characters.
                st.image(os.path.join(img_dir, name), caption=name[:30], use_container_width=True)
456
+
457
+
458
def _render_full_log(log_path):
    """Show the training log: the last 50 lines by default, or all of it.

    Args:
        log_path: Path of the log file; an info message is shown when it is
            empty or not an existing file.
    """
    if not log_path or not os.path.isfile(log_path):
        st.info("No training log found.")
        return

    tail_lines = 50

    # errors="replace": an undecodable byte in the log must not crash the tab.
    with open(log_path, errors="replace") as f:
        content = f.read()

    st.markdown(f"**Log file:** `{os.path.basename(log_path)}`")

    # Show only the tail by default, with a checkbox to show everything.
    lines = content.strip().split("\n")
    show_all = st.checkbox("Show full log", value=False, key="show_full_log")

    if show_all:
        st.code(content, language="bash")
    else:
        st.code("\n".join(lines[-tail_lines:]), language="bash")
        if len(lines) > tail_lines:
            st.caption(
                f"Showing last {tail_lines} of {len(lines)} lines. "
                "Check 'Show full log' to see all."
            )
479
+
480
+
481
def _render_checkpoints(run_dir):
    """List checkpoint/result files of a run and summarise ``results.json``.

    Files directly inside *run_dir* ending in ``.pth``, ``.json``, ``.csv``
    or ``.log`` are listed with a human-readable size and a category from
    ``_file_type``. When ``results.json`` exists, four summary metrics are
    shown; an unreadable file or a missing / non-numeric value is reported
    in the UI instead of raising.

    Args:
        run_dir: Run directory to inspect.
    """
    st.markdown("#### Saved Checkpoints & Weights")

    files_info = []
    for f in sorted(os.listdir(run_dir)):
        fpath = os.path.join(run_dir, f)
        if os.path.isfile(fpath) and f.endswith((".pth", ".json", ".csv", ".log")):
            size = os.path.getsize(fpath)
            if size > 1024 * 1024:
                size_str = f"{size / (1024*1024):.1f} MB"
            elif size > 1024:
                size_str = f"{size / 1024:.1f} KB"
            else:
                size_str = f"{size} B"
            files_info.append({"File": f, "Size": size_str, "Type": _file_type(f)})

    if files_info:
        st.dataframe(files_info, use_container_width=True, hide_index=True)
    else:
        st.info("No checkpoint files found.")

    # results.json
    results_path = os.path.join(run_dir, "results.json")
    if not os.path.isfile(results_path):
        return

    st.divider()
    st.markdown("#### Training Results Summary")
    try:
        with open(results_path, encoding="utf-8") as f:
            results = json.load(f)
    except (OSError, ValueError) as exc:
        # json.JSONDecodeError and UnicodeDecodeError are ValueErrors; the
        # file may be corrupt or only partially written.
        st.warning(f"Could not read results.json: {exc}")
        return
    if not isinstance(results, dict):
        st.warning("results.json does not contain a JSON object.")
        return

    def _formatted(key, spec, suffix=""):
        """Format results[key] as a number, or '?' when it is not numeric."""
        value = results.get(key, 0)
        # A JSON null or string would make the float format spec raise.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return "?"
        return f"{value:{spec}}{suffix}"

    cols = st.columns(4)
    with cols[0]:
        st.metric("Epochs Trained", results.get("epochs_trained", "?"))
    with cols[1]:
        st.metric("Best mAP@0.5", _formatted("best_mAP50", ".4f"))
    with cols[2]:
        st.metric("Best Val Loss", _formatted("best_val_loss", ".4f"))
    with cols[3]:
        st.metric("Model Params", _formatted("model_params_M", ".2f", "M"))
520
+
521
+
522
def _file_type(filename):
    """Return a display category for *filename*.

    Keyword matches anywhere in the name are tried first, in this order:
    ``best``, ``last``, ``inference``, ``fp16``. If none match, the file
    extension decides (``.json``, ``.csv``, ``.log``). Anything else is
    ``"Other"``.
    """
    keyword_labels = (
        ("best", "Best checkpoint"),
        ("last", "Latest checkpoint"),
        ("inference", "Inference weights"),
        ("fp16", "FP16 weights"),
    )
    for keyword, label in keyword_labels:
        if keyword in filename:
            return label

    suffix_labels = (
        (".json", "Results/Report"),
        (".csv", "Training log CSV"),
        (".log", "Training log"),
    )
    for suffix, label in suffix_labels:
        if filename.endswith(suffix):
            return label

    return "Other"
539
+
540
+
541
+ # ---- Start New Training Tab ----
542
+
543
def _render_start_tab():
    """Render the 'Start Training' tab.

    Shows the current configuration read from the session state and a
    single action button: Start while no training is flagged active,
    Stop otherwise.

    NOTE(review): the Stop branch only resets the session-state flags and
    reruns the script; no process is signalled from here.
    """
    st.markdown("### Start New Training")

    info_column, action_column = st.columns([3, 1])

    with info_column:
        arch = st.session_state.get("model_arch", "FlashDet-Pico")
        dataset = st.session_state.get("dataset_name", "Not selected")
        epochs = st.session_state.get("epochs", 100)
        bs = st.session_state.get("batch_size", 16)
        lr = st.session_state.get("lr", 0.001)

        st.markdown(
            f"**Config:** `{arch}` ยท Dataset: `{dataset}` ยท "
            f"Epochs: `{epochs}` ยท BS: `{bs}` ยท LR: `{lr}`"
        )

    with action_column:
        if st.session_state.get("training_active", False):
            stop_clicked = st.button(
                "โน๏ธ Stop Training",
                use_container_width=True,
                type="secondary",
                key="btn_stop_training",
            )
            if stop_clicked:
                st.session_state["training_active"] = False
                st.session_state["training_status"] = "Stopped"
                st.rerun()
        else:
            start_clicked = st.button(
                "โ–ถ๏ธ Start Training",
                use_container_width=True,
                type="primary",
                key="btn_start_training",
            )
            if start_clicked:
                # Flags are set before launching; the launcher clears them
                # again if it cannot start.
                st.session_state["training_active"] = True
                st.session_state["training_status"] = "Running"
                _start_training()

    # Re-read the flag: the button handlers above may have changed it.
    if st.session_state.get("training_active"):
        st.info("Training is running... Switch to 'Monitor Run' tab and select latest run to see progress.")
577
+
578
+
579
def _start_training():
    """Launch FlashDet training when the ``flashdet`` package is importable.

    When the package cannot be found, a warning is shown and the training
    flags in the session state are reset; nothing is launched.
    """
    flashdet_available = importlib.util.find_spec("flashdet") is not None
    if not flashdet_available:
        st.warning("FlashDet not installed. Install with `pip install flashdet` to run real training.")
        st.session_state["training_active"] = False
        st.session_state["training_status"] = "FlashDet not found"
        return

    _run_flashdet_training()
587
+
588
+
589
def _run_flashdet_training():
    """Launch ``python -m flashdet.cli train`` as a non-blocking child process.

    The command line is assembled from session-state values (model size,
    epochs, batch size, learning rate, save dir, augmentation flags,
    optimizer, workers). Dataset paths come from ``train_img_path`` /
    ``val_img_path`` or, failing that, from a ``data/<dataset>/train`` +
    ``data/<dataset>/valid`` layout matching the selected dataset name.

    On success the child's PID is stored in ``st.session_state["training_pid"]``.
    When no dataset can be resolved or the process cannot be started, an
    error is shown and the training flags are reset.
    """
    import subprocess
    import sys

    size_map = {
        "FlashDet-Pico": "p", "FlashDet-Nano": "n", "FlashDet-Small": "s",
        "FlashDet-Medium": "m", "FlashDet-Large": "l", "FlashDet-X": "x",
    }

    model_arch = st.session_state.get("model_arch", "FlashDet-Pico")
    # `or` rather than a .get() default: an empty save_dir stored in the
    # session must also fall back instead of being passed as `--save-dir ""`.
    save_dir = st.session_state.get("save_dir") or os.path.join(
        _get_default_workspace(), "flashstudio_run"
    )

    # Auto-detect device
    device = "cpu"
    try:
        import torch
        if torch.cuda.is_available():
            device = "cuda"
    except ImportError:
        pass

    # Resolve dataset paths from session state
    train_images = st.session_state.get("train_img_path", "")
    val_images = st.session_state.get("val_img_path", "")

    # Fall back to the conventional data/<dataset>/ layout.
    if not train_images or not os.path.isdir(train_images):
        dataset_name = st.session_state.get("dataset_name", "")
        dataset_id = dataset_name.lower().replace(" ", "").replace("(demo)", "sample")
        for candidate_id in ["sample", "coco2017", "coco2017-val", "voc2007", "voc2012"]:
            if candidate_id in dataset_id or dataset_id in candidate_id:
                candidate_dir = os.path.join("data", candidate_id)
                if os.path.isdir(os.path.join(candidate_dir, "train")):
                    train_images = os.path.join(candidate_dir, "train")
                    val_images = os.path.join(candidate_dir, "valid")
                    break

    if not train_images or not val_images:
        st.error(
            "No dataset paths configured. Please go to the **Data** page and either:\n"
            "- Enter train/val image paths manually, or\n"
            "- Download a dataset first (`flashdet download --dataset sample`)"
        )
        st.session_state["training_active"] = False
        st.session_state["training_status"] = "No dataset"
        return

    cmd = [
        sys.executable, "-m", "flashdet.cli", "train",
        "--model-size", size_map.get(model_arch, "n"),
        "--epochs", str(st.session_state.get("epochs", 100)),
        "--batch-size", str(st.session_state.get("batch_size", 16)),
        "--lr", str(st.session_state.get("lr", 1e-3)),
        "--save-dir", save_dir,
        "--device", device,
        "--train-images", train_images,
        "--val-images", val_images,
    ]

    # AMP is only requested when running on CUDA.
    if st.session_state.get("amp", True) and device == "cuda":
        cmd.append("--amp")
    if st.session_state.get("aug_mosaic", True):
        cmd.append("--mosaic")
    if st.session_state.get("aug_mixup", False):
        cmd.append("--mixup")

    optimizer = st.session_state.get("optimizer", "AdamW").lower()
    if optimizer in ("adamw", "sgd", "musgd"):
        cmd.extend(["--optimizer", optimizer])

    workers = st.session_state.get("num_workers", 2)
    cmd.extend(["--workers", str(workers)])

    # Run as a non-blocking subprocess so the Streamlit UI stays responsive.
    # Nothing in this module reads the child's stdout, so it must not be a
    # PIPE: once the OS pipe buffer fills, the child would block on write
    # and training would hang. Discard the stream instead.
    try:
        process = subprocess.Popen(
            cmd, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT,
        )
    except OSError as exc:
        # Without this the session would stay flagged as "training active".
        st.error(f"Failed to launch training process: {exc}")
        st.session_state["training_active"] = False
        st.session_state["training_status"] = "Launch failed"
        return

    st.session_state["training_pid"] = process.pid

    st.info(f"Training started (PID {process.pid}): `{' '.join(cmd)}`")
    st.caption(f"Output will be saved to: `{save_dir}`")
    st.caption("Switch to 'Monitor Run' tab to see live progress.")
File without changes