flashstudio 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flashstudio/__init__.py +5 -0
- flashstudio/app.py +64 -0
- flashstudio/cli.py +18 -0
- flashstudio/components/__init__.py +0 -0
- flashstudio/components/sidebar.py +45 -0
- flashstudio/components/styles.py +273 -0
- flashstudio/components/wizard.py +46 -0
- flashstudio/launcher.py +73 -0
- flashstudio/pages/__init__.py +0 -0
- flashstudio/pages/dashboard.py +168 -0
- flashstudio/pages/data.py +272 -0
- flashstudio/pages/export.py +212 -0
- flashstudio/pages/inference.py +1112 -0
- flashstudio/pages/model.py +370 -0
- flashstudio/pages/training.py +672 -0
- flashstudio/utils/__init__.py +0 -0
- flashstudio/utils/device.py +58 -0
- flashstudio-0.1.0.dist-info/METADATA +133 -0
- flashstudio-0.1.0.dist-info/RECORD +22 -0
- flashstudio-0.1.0.dist-info/WHEEL +5 -0
- flashstudio-0.1.0.dist-info/entry_points.txt +2 -0
- flashstudio-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
"""FlashStudio — Dashboard Page."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import json
|
|
5
|
+
import glob as glob_module
|
|
6
|
+
|
|
7
|
+
import streamlit as st
|
|
8
|
+
from flashstudio.components.styles import render_page_header
|
|
9
|
+
from flashstudio.utils.device import get_gpu_info, get_colab_runtime_type
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def render_dashboard():
    """Draw the FlashStudio landing page.

    Layout, top to bottom: a row of four hardware/runtime metrics, the
    quick-start guide next to a session summary, the recent-runs panel, and a
    reference table of the supported detector architectures.
    """
    render_page_header(
        "⚡",
        "FlashStudio Dashboard",
        "End-to-end object detection: Data → Model → Train → Export → Inference",
    )

    spacer_html = "<div style='height:1rem;'></div>"

    # --- Hardware / runtime metrics -------------------------------------------
    gpu_col, runtime_col, vram_col, flashdet_col = st.columns(4)
    gpu = get_gpu_info()
    has_gpu = gpu["available"]
    with gpu_col:
        st.metric("🖥️ GPU", gpu["name"] if has_gpu else "CPU")
    with runtime_col:
        st.metric("🌐 Runtime", get_colab_runtime_type())
    with vram_col:
        if has_gpu:
            vram_label = f"{gpu['memory_total']:.1f} GB"
        else:
            vram_label = "—"
        st.metric("💾 VRAM", vram_label)
    with flashdet_col:
        st.metric("📦 FlashDet", _get_flashdet_version())

    st.markdown(spacer_html, unsafe_allow_html=True)

    # --- Quick-start guide (wide) and session summary (narrow) ----------------
    guide_col, session_col = st.columns([2, 1])

    with guide_col, st.container(border=True):
        st.markdown("### 🚀 Quick Start Guide")
        st.markdown("""
| Step | Action | Description |
|:----:|--------|-------------|
| 1 | **📦 Data** | Upload your dataset or download an open-source one |
| 2 | **🧠 Model** | Choose architecture & configure training parameters |
| 3 | **🏋️ Train** | Start training and monitor real-time progress |
| 4 | **📤 Export** | Convert to ONNX for deployment |
| 5 | **🔍 Inference** | Test your model on images, video, or RTSP |
""")
        st.caption("Click **Next** below or use the sidebar to jump to any step.")

    with session_col, st.container(border=True):
        st.markdown("### 📊 Session Info")
        session = st.session_state
        summary_rows = (
            ("Dataset", session.get("dataset_name", "Not selected")),
            ("Model", session.get("model_arch", "FlashDet-Pico")),
            ("Training", session.get("training_status", "Not started")),
        )
        for label, value in summary_rows:
            st.markdown(f"**{label}:** {value}")

        best_map = session.get("best_map")
        if best_map:
            st.metric("Best mAP", f"{best_map:.4f}")
        exported_model = session.get("exported_model")
        if exported_model:
            st.markdown(f"**Exported:** {exported_model}")

    st.markdown(spacer_html, unsafe_allow_html=True)

    # --- Recent training runs ---------------------------------------------------
    _render_recent_runs()

    # --- Supported architectures reference table --------------------------------
    architectures = [
        ["FlashDet-Pico", "LiteBackbone (0.5x)", "~298K", "Ultra-fast", "Edge / MCU"],
        ["FlashDet-Nano", "FlashBackbone (stem=32)", "~790K", "Very fast", "Embedded / IoT"],
        ["FlashDet-Small", "FlashBackbone (stem=48)", "~1.8M", "Fast", "General purpose"],
        ["FlashDet-Medium", "FlashBackbone (stem=64)", "~3.6M", "Balanced", "High accuracy"],
        ["FlashDet-Large", "FlashBackbone (stem=80)", "~5.8M", "Accurate", "High accuracy"],
        ["FlashDet-X", "FlashBackbone (stem=96)", "~9.0M", "Max accuracy", "Server"],
        ["YOLOv8", "YOLOv8Backbone", "Varies", "Fast", "General YOLO"],
        ["YOLOv9", "YOLOv9Backbone", "Varies", "Fast", "PGI-based"],
        ["YOLOv10", "YOLOv10Backbone", "Varies", "Fast", "PSA-enhanced"],
        ["YOLOv11", "YOLOv11Backbone", "Varies", "Fast", "C2PSA-based"],
        ["YOLOX", "YOLOXBackbone", "Varies", "Fast", "Anchor-free"],
    ]
    headers = ("Model", "Backbone", "Params", "Speed", "Best For")
    with st.container(border=True):
        st.markdown("### 🔗 Supported Architectures")
        # Column-oriented mapping: header -> list of that column's cells.
        table = {
            header: [row[idx] for row in architectures]
            for idx, header in enumerate(headers)
        }
        st.dataframe(
            table,
            use_container_width=True,
            hide_index=True,
        )
        st.caption("All architectures use fixed backbone/neck/head — only FlashDet-Pico allows backbone choice (lite vs pico_v2).")
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _render_recent_runs():
    """Show up to five of the most recently modified training runs.

    The first existing directory among a few candidate ``workspace``
    locations is scanned, and each sub-directory is treated as one run.
    Nothing is rendered when no workspace or no run directories are found.

    For each run, ``results.json`` (keys ``best_mAP50`` / ``epochs_trained``)
    is summarised when it is present and readable. Otherwise the presence of
    ``train_*.log`` files is taken as a hint that training is still running.
    A ``checkpoint_best.pth`` file marks the run as done.
    """
    workspace_candidates = [
        os.path.join(os.getcwd(), "workspace"),
        os.path.join(os.getcwd(), "..", "FlashDet", "workspace"),
        os.path.join(os.path.dirname(__file__), "..", "..", "..", "FlashDet", "workspace"),
    ]

    workspace = None
    for candidate in workspace_candidates:
        candidate = os.path.abspath(candidate)
        if os.path.isdir(candidate):
            workspace = candidate
            break

    if not workspace:
        return

    def _mtime(name):
        # A run directory can disappear between listdir() and getmtime();
        # sort such entries last instead of crashing the whole dashboard.
        try:
            return os.path.getmtime(os.path.join(workspace, name))
        except OSError:
            return 0.0

    try:
        entries = os.listdir(workspace)
    except OSError:
        # Unreadable or concurrently removed workspace: skip the panel.
        return

    runs = sorted(
        (d for d in entries if os.path.isdir(os.path.join(workspace, d))),
        key=_mtime,
        reverse=True,
    )[:5]

    if not runs:
        return

    with st.container(border=True):
        st.markdown("### 🕐 Recent Training Runs")

        for run_name in runs:
            run_dir = os.path.join(workspace, run_name)
            results_path = os.path.join(run_dir, "results.json")
            # Escape the directory part so run names containing glob
            # metacharacters ("[", "*", "?") are matched literally.
            log_files = glob_module.glob(
                os.path.join(glob_module.escape(run_dir), "train_*.log")
            )

            col_name, col_info, col_status = st.columns([2, 3, 1])
            with col_name:
                st.markdown(f"**{run_name}**")
            with col_info:
                if os.path.isfile(results_path):
                    try:
                        with open(results_path, encoding="utf-8") as f:
                            results = json.load(f)
                        # A missing or null best_mAP50 is shown as 0.
                        best_map = float(results.get("best_mAP50") or 0)
                        epochs = results.get("epochs_trained", "?")
                    except (OSError, ValueError, TypeError, AttributeError):
                        # The file may be half-written while a run is saving, or
                        # may have an unexpected schema; don't take the page down.
                        st.caption("Results file unreadable")
                    else:
                        st.caption(f"mAP: {best_map:.4f} · {epochs} epochs")
                elif log_files:
                    st.caption("Training in progress / incomplete")
                else:
                    st.caption("No data")
            with col_status:
                has_best = os.path.isfile(os.path.join(run_dir, "checkpoint_best.pth"))
                if has_best:
                    st.caption("✅ Done")
                elif log_files:
                    st.caption("🔄 Active")
                else:
                    st.caption("—")
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def _get_flashdet_version() -> str:
    """Return a display string describing the installed FlashDet package.

    Returns:
        ``flashdet.__version__`` when the package can be imported,
        ``"installed"`` when it imports but exposes no ``__version__``, and
        ``"not installed"`` when it cannot be found or importing it fails.
    """
    import importlib.util

    version = "not installed"
    try:
        if importlib.util.find_spec("flashdet") is not None:
            import flashdet
            version = getattr(flashdet, "__version__", "installed")
    except Exception:
        # Best-effort probe: any failure to locate or import means "unavailable".
        version = "not installed"
    return version
|
|
@@ -0,0 +1,272 @@
|
|
|
1
|
+
"""FlashStudio — Data Upload / Download Page."""
|
|
2
|
+
|
|
3
|
+
import streamlit as st
|
|
4
|
+
|
|
5
|
+
# Datasets the FlashDet CLI can fetch itself (``flashdet download <name>``).
# Each entry becomes one card in _render_download_tab. The keys read there are
# name, description, images, classes, size, format and command. "command" is
# also what the "Run Download" button executes through a shell. "images",
# "size" and "format" are display-only. The "native" flag is not read anywhere
# in this module.
FLASHDET_NATIVE_DATASETS = [
    {
        "name": "Sample (demo)",
        "description": "Built-in demo dataset — fastest way to test training pipeline",
        "images": 80,
        "classes": 10,
        "size": "~5 MB",
        "format": "COCO JSON",
        "command": "flashdet download sample",
        "native": True,
    },
    {
        "name": "COCO 2017",
        "description": "Full COCO 2017 — 118K train + 5K val images, 80 object classes",
        # 118,287 train + 5,000 val
        "images": 123287,
        "classes": 80,
        "size": "~20 GB",
        "format": "COCO JSON",
        "command": "flashdet download coco2017",
        "native": True,
    },
    {
        "name": "COCO 2017 (val only)",
        "description": "COCO 2017 validation set only — for quick evaluation",
        "images": 5000,
        "classes": 80,
        "size": "~1 GB",
        "format": "COCO JSON",
        "command": "flashdet download coco2017-val",
        "native": True,
    },
    {
        "name": "VOC 2007",
        "description": "Pascal VOC 2007 — classic detection benchmark, 20 classes",
        # NOTE(review): 9963 images is trainval + test combined, and those two
        # archives total roughly 900 MB. "~450 MB" matches a single archive.
        # Confirm against what `flashdet download voc2007` actually fetches.
        "images": 9963,
        "classes": 20,
        "size": "~450 MB",
        "format": "VOC XML → COCO JSON",
        "command": "flashdet download voc2007",
        "native": True,
    },
    {
        "name": "VOC 2012",
        "description": "Pascal VOC 2012 — extended VOC with more images",
        "images": 17125,
        "classes": 20,
        "size": "~2 GB",
        "format": "VOC XML → COCO JSON",
        "command": "flashdet download voc2012",
        "native": True,
    },
]
|
|
57
|
+
|
|
58
|
+
# Third-party datasets that FlashDet cannot download itself. Entries use the
# same keys as FLASHDET_NATIVE_DATASETS. "command" is shown to the user, and the
# "Run Download" button can execute it through a shell.
EXTERNAL_DATASETS = [
    {
        "name": "COCO128 (tiny subset)",
        "description": "128 COCO images — quick pipeline testing",
        "images": 128,
        "classes": 80,
        "size": "7 MB",
        # The Ultralytics coco128.zip archive ships YOLO-format labels
        # (labels/*.txt), not a COCO JSON annotation file.
        "format": "YOLO TXT",
        "command": "wget -q https://github.com/ultralytics/assets/releases/download/v0.0.0/coco128.zip && unzip -q coco128.zip -d datasets/",
        "native": False,
    },
    {
        "name": "VisDrone (drone aerial)",
        "description": "Drone aerial images — pedestrians, vehicles from above",
        "images": 6471,
        "classes": 10,
        "size": "~2 GB",
        "format": "Custom → COCO JSON",
        # Not a runnable download: only a pointer to the manual download page.
        "command": "# See: https://github.com/VisDrone/VisDrone-Dataset",
        "native": False,
    },
]
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def render_data_page():
    """Data-setup page: upload a custom dataset or fetch an open-source one.

    The radio choice selects which panel is drawn. The panel is followed, in
    either case, by a divider and the current-dataset status summary.
    """
    from flashstudio.components.styles import render_page_header

    upload_option = "📤 Upload Your Dataset"
    download_option = "📥 Download Open-Source Dataset"

    render_page_header(
        "📦",
        "Data Setup",
        "Upload your own dataset or download a ready-to-use open-source dataset.",
    )

    selected = st.radio(
        "Choose data source",
        [upload_option, download_option],
        horizontal=True,
        key="data_mode",
    )

    render_panel = _render_upload_tab if selected == upload_option else _render_download_tab
    render_panel()

    st.divider()
    _render_dataset_status()
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _render_upload_tab():
    """Render the custom-dataset upload panel.

    Offers archive uploaders, or manual annotation/image paths, for the
    training and validation splits, plus annotation-format and class settings.

    When a training archive is present, its file name becomes the session's
    ``dataset_name``. The class count entered here becomes ``dataset_classes``,
    so the status panel reflects the upload rather than a previously selected
    download.
    """
    st.markdown("### Upload Your Dataset")
    st.info("Supported formats: **COCO JSON**, **Pascal VOC XML**, **YOLO TXT**. Upload as ZIP/TAR.")

    col1, col2 = st.columns(2)

    with col1:
        with st.container(border=True):
            st.markdown("**Training Data**")
            train_upload = st.file_uploader(
                "Upload training dataset (zip/tar)",
                type=["zip", "tar", "gz"],
                key="train_data_upload",
            )
            st.text_input(
                "Or enter path to annotations",
                placeholder="/content/datasets/train/annotations.json",
                key="train_ann_path",
            )
            st.text_input(
                "Images directory",
                placeholder="/content/datasets/train/images/",
                key="train_img_path",
            )

    with col2:
        with st.container(border=True):
            st.markdown("**Validation Data**")
            val_upload = st.file_uploader(
                "Upload validation dataset (zip/tar)",
                type=["zip", "tar", "gz"],
                key="val_data_upload",
            )
            st.text_input(
                "Or enter path to annotations",
                placeholder="/content/datasets/val/annotations.json",
                key="val_ann_path",
            )
            st.text_input(
                "Images directory",
                placeholder="/content/datasets/val/images/",
                key="val_img_path",
            )

    with st.container(border=True):
        st.markdown("**Dataset Format**")
        col_fmt, col_cls = st.columns(2)
        with col_fmt:
            st.selectbox(
                "Annotation Format",
                ["COCO JSON (recommended)", "Pascal VOC XML", "YOLO TXT"],
                key="ann_format",
            )
        with col_cls:
            num_classes = st.number_input(
                "Number of Classes", min_value=1, max_value=1000, value=80,
                key="upload_num_classes",
            )
        class_names_text = st.text_area(
            "Class Names (one per line, optional)",
            placeholder="person\ncar\nbicycle\n...",
            height=100,
            key="class_names",
        )

    # Class names are optional, but when given they should agree with the count.
    class_names = [
        name.strip() for name in (class_names_text or "").splitlines() if name.strip()
    ]
    if class_names and len(class_names) != num_classes:
        st.warning(
            f"{len(class_names)} class names entered, but Number of Classes is {num_classes}."
        )

    if train_upload:
        st.session_state["dataset_name"] = train_upload.name
        # Keep the status panel's class count in sync with this form. Otherwise a
        # "dataset_classes" value left by an earlier download selection takes
        # precedence over the upload's class count.
        st.session_state["dataset_classes"] = num_classes
        st.success(f"✅ Training data uploaded: {train_upload.name} ({train_upload.size / 1e6:.1f} MB)")

    if val_upload:
        st.success(f"✅ Validation data uploaded: {val_upload.name} ({val_upload.size / 1e6:.1f} MB)")
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _render_download_tab():
    """Render the open-source dataset catalogue and the selected download command.

    FlashDet-native datasets are listed first, then third-party ones. Pressing
    a card's "Download" button only *selects* that dataset: its name, command
    and class count are stored in session state and the script reruns. The
    stored command runs through ``_auto_download`` when "Run Download" is
    pressed.
    """
    st.markdown("### FlashDet Native Datasets")
    st.caption("These use `flashdet download` — auto-handled format conversion and split.")
    _render_dataset_cards(FLASHDET_NATIVE_DATASETS, "native_dl")

    st.divider()
    st.markdown("### External Datasets")
    st.caption("Third-party datasets — may need manual format conversion.")
    _render_dataset_cards(EXTERNAL_DATASETS, "ext_dl")

    command = st.session_state.get("dataset_command")
    if command:
        st.divider()
        st.markdown("### 📋 Download Command")
        st.code(command, language="bash")
        st.caption("Run this in a terminal or Colab cell.")

        if st.button("🚀 Run Download", type="primary"):
            _auto_download(command)


def _render_dataset_cards(datasets, key_prefix):
    """Render one bordered card per dataset entry, each with a select button.

    Args:
        datasets: Entries with name, description, images, classes, size,
            format and command keys (the module-level dataset tables).
        key_prefix: Prefix for the button widget keys (``f"{key_prefix}_{i}"``),
            keeping keys unique across the different tables.
    """
    for i, ds in enumerate(datasets):
        with st.container(border=True):
            col_info, col_stats, col_action = st.columns([3, 2, 1])

            with col_info:
                st.markdown(f"**{ds['name']}**")
                st.caption(ds["description"])

            with col_stats:
                st.markdown(
                    f"📸 {ds['images']} images · 🏷️ {ds['classes']} classes · "
                    f"💾 {ds['size']} · 📋 {ds['format']}"
                )

            with col_action:
                if st.button("⬇️ Download", key=f"{key_prefix}_{i}", use_container_width=True):
                    st.session_state["dataset_name"] = ds["name"]
                    st.session_state["dataset_command"] = ds["command"]
                    st.session_state["dataset_classes"] = ds["classes"]
                    st.rerun()
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def _render_dataset_status():
    """Summarise the dataset currently selected in this session.

    When a dataset has been chosen, shows its name, class count and
    annotation format as metrics plus a "ready" message. Otherwise shows a
    warning asking the user to pick one.
    """
    st.markdown("### ✅ Current Dataset")
    state = st.session_state
    dataset = state.get("dataset_name", None)

    if not dataset:
        st.warning("No dataset selected yet. Upload or download a dataset above.")
        return

    name_col, classes_col, format_col = st.columns(3)
    with name_col:
        st.metric("Dataset", dataset)
    with classes_col:
        # A stored "dataset_classes" wins over the upload form's class count.
        fallback_classes = state.get("upload_num_classes", 80)
        st.metric("Classes", state.get("dataset_classes", fallback_classes))
    with format_col:
        fmt = state.get("ann_format", "COCO JSON")
        # Drop a parenthesised hint such as "(recommended)" from the label.
        short_fmt = fmt.partition("(")[0].strip() if "(" in fmt else fmt
        st.metric("Format", short_fmt)
    st.success("Dataset ready! Click **Next** to choose your model.")
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def _auto_download(command: str, timeout: int = 300):
    """Execute a dataset download command and report the outcome in the UI.

    Args:
        command: Shell command to run. It is executed with ``shell=True``
            because the bundled commands chain steps with ``&&``. It must
            therefore come from the trusted dataset tables in this module,
            never from free-form user input.
        timeout: Seconds to wait before giving up. The default of 300 s
            (5 minutes) matches the previous fixed limit.

    On success, ``st.session_state["train_ann_path"]`` is set to
    ``"/content/datasets/"``. Commands made only of ``#`` comment lines (the
    placeholder for datasets that need a manual download) are not executed.
    """
    import subprocess

    runnable_lines = [
        line for line in command.splitlines()
        if line.strip() and not line.strip().startswith("#")
    ]
    if not runnable_lines:
        # e.g. "# See: https://..." would exit 0 in a shell while downloading
        # nothing, which would otherwise be reported as a successful download.
        st.info("This dataset has no automatic download. Follow the instructions shown above.")
        return

    with st.spinner("Downloading dataset..."):
        try:
            result = subprocess.run(
                command, shell=True, capture_output=True, text=True, timeout=timeout
            )
        except subprocess.TimeoutExpired:
            st.error(f"Download timed out ({timeout / 60:g} min limit). Try manually.")
        except Exception as e:  # UI boundary: surface any launch failure to the user
            st.error(f"Error: {e}")
        else:
            if result.returncode == 0:
                st.success("✅ Dataset downloaded successfully!")
                # NOTE(review): Colab-specific directory, not an annotations file.
                # Confirm this is what downstream pages expect.
                st.session_state["train_ann_path"] = "/content/datasets/"
            else:
                # Tools print the decisive error last, so show the tail. Fall
                # back to stdout or the exit code when stderr is empty.
                details = (result.stderr or result.stdout or "").strip()
                if not details:
                    details = f"exit code {result.returncode}"
                st.error(f"Download failed: {details[-200:]}")
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
"""FlashStudio — Model Export Page (matches actual FlashDet capabilities)."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import streamlit as st
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def render_export_page():
    """Export page: training summary on top, settings and results side by side.

    The settings column is rendered before the results column, so an export
    triggered from the settings is already visible in the results in the same
    run.
    """
    from flashstudio.components.styles import render_page_header

    render_page_header(
        "📤",
        "Export Model",
        "Convert your trained model for production deployment.",
    )

    _render_training_summary()
    st.divider()

    settings_col, results_col = st.columns([1, 1])
    with settings_col:
        _render_export_config()
    with results_col:
        _render_export_output()
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _render_training_summary():
    """List the ``.pth`` weight files found in the session's training directory.

    Reads ``st.session_state["save_dir"]``. Each weight file is shown as a
    metric: the label comes from "best"/"last" and "fp16" in the file name,
    and the value is the file size in MB. Metrics are laid out in rows of at
    most four columns, so no file is hidden.
    """
    save_dir = st.session_state.get("save_dir", "")

    if not (save_dir and os.path.isdir(save_dir)):
        st.info("Complete training first, or provide a weights path below.")
        return

    st.markdown("#### Available Weights from Training")
    # Only regular files: a directory that happens to end in ".pth" is not weights.
    weights_files = sorted(
        f for f in os.listdir(save_dir)
        if f.endswith(".pth") and os.path.isfile(os.path.join(save_dir, f))
    )
    if not weights_files:
        st.info("No weights found in save directory.")
        return

    per_row = 4
    # Render every file in rows of four rather than truncating to the first four
    # names alphabetically, which could hide the inference weights behind other files.
    for start in range(0, len(weights_files), per_row):
        chunk = weights_files[start:start + per_row]
        for col, wf in zip(st.columns(len(chunk)), chunk):
            fpath = os.path.join(save_dir, wf)
            try:
                size_mb = os.path.getsize(fpath) / (1024 * 1024)
            except OSError:
                continue  # removed after listing; skip rather than crash
            with col:
                label = "Best" if "best" in wf else "Last" if "last" in wf else "Other"
                precision = "FP16" if "fp16" in wf else "FP32"
                st.metric(f"{label} ({precision})", f"{size_mb:.1f} MB")
    st.caption(f"Directory: `{save_dir}`")
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _render_export_config():
    """Render the export settings column and trigger an ONNX export on demand.

    Resolves the weights file from the chosen source: one of the standard file
    names inside ``save_dir``, or a user-supplied path. Then shows the ONNX
    options and calls ``_run_export(weights_path, img_size)`` when the export
    button is pressed.
    """
    st.markdown("### ⚙️ Export Settings")

    # Preset sources map to the file names FlashDet writes into save_dir.
    preset_files = {
        "Best model (inference weights)": "model_best_inference.pth",
        "Best model (FP16)": "model_best_fp16.pth",
        "Last model (inference weights)": "model_last_inference.pth",
    }
    custom_label = "Custom path"

    with st.container(border=True):
        st.markdown("**Model Weights**")
        weights_source = st.radio(
            "Source",
            [*preset_files, custom_label],
            key="weights_source",
        )
        if weights_source == custom_label:
            st.text_input("Weights path (.pth)", placeholder="/path/to/model.pth",
                          key="export_weights_path")

        save_dir = st.session_state.get("save_dir", "/content/flashstudio_runs")
        preset = preset_files.get(weights_source)
        if preset is not None:
            weights_path = os.path.join(save_dir, preset)
        else:
            weights_path = st.session_state.get("export_weights_path", "")

        st.caption(f"📁 `{weights_path}`")

    with st.container(border=True):
        st.markdown("**Export Format**")
        st.markdown("FlashDet supports **ONNX** export with dynamic batch size (opset 13).")

        st.checkbox("ONNX (recommended — universal format)", value=True, key="export_onnx",
                    disabled=True)
        st.caption("Works with: ONNX Runtime, OpenVINO, TensorRT (via onnx2trt), TFLite (via onnx2tf)")

        st.divider()
        st.caption("**Not natively supported** (convert from ONNX using external tools):")
        # coremltools >= 6 removed its ONNX converter, and OpenVINO replaced
        # `mo` with `ovc` in 2023.1, so the guidance reflects current tooling.
        st.markdown("""
- TensorRT: `trtexec --onnx=model.onnx --saveEngine=model.engine --fp16`
- OpenVINO: `ovc model.onnx --output_model openvino/model` (releases before 2023.1: `mo --input_model model.onnx --output_dir openvino/`)
- CoreML: coremltools ≥ 6 has no ONNX converter — convert the PyTorch model with `coremltools.convert` instead
- NCNN: `onnx2ncnn model.onnx model.param model.bin`
""")

    with st.container(border=True):
        st.markdown("**ONNX Export Options**")
        img_size = st.select_slider(
            "Input Size", [320, 416, 640], value=320, key="export_img_size",
            help="Should match your training input size"
        )
        # NOTE(review): "export_dynamic" and "export_opset" are stored in session
        # state but are not passed to _run_export. Confirm which options the
        # `flashdet export` CLI accepts before wiring them through.
        st.checkbox("Dynamic batch size", value=True, key="export_dynamic",
                    help="Allows variable batch size at inference (default in FlashDet)")
        st.number_input("Opset Version", 11, 18, 13, key="export_opset",
                        help="FlashDet uses opset 13 by default")

        if st.button("🚀 Export to ONNX", use_container_width=True, type="primary"):
            _run_export(weights_path, img_size)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _render_export_output():
    """Render the export results column.

    If no exports are recorded in ``st.session_state["exported_files"]``,
    lists the inference weights that training already saves and shows the
    manual ``flashdet export`` command. Otherwise shows one card per recorded
    export with its status, plus a download button for each successful export
    whose file exists. The "Export complete" message appears only when at
    least one export succeeded.
    """
    st.markdown("### 📦 Export Output")

    exported = st.session_state.get("exported_files", [])

    if not exported:
        st.markdown("#### Pre-Generated Weights")
        st.info(
            "FlashDet **automatically saves** inference weights (FP32 and FP16) during training:\n\n"
            "- `model_best_inference.pth` — best mAP, FP32, o2m head stripped\n"
            "- `model_best_fp16.pth` — best mAP, FP16 (half memory)\n"
            "- `model_last_inference.pth` — final epoch, FP32\n"
            "- `model_last_fp16.pth` — final epoch, FP16\n\n"
            "These are ready to use with `FlashDetPredictor` — no export step needed for PyTorch inference."
        )

        st.divider()
        st.markdown("#### ONNX Export Command (Manual)")
        weights = st.session_state.get("export_weights_path", "model_best_inference.pth")
        st.code(
            f"flashdet export --model {weights} --output model.onnx\n"
            "# Output: model.onnx (opset 13, dynamic batch)",
            language="bash"
        )
        return

    for idx, exp in enumerate(exported):
        with st.container(border=True):
            col_info, col_action = st.columns([3, 1])
            with col_info:
                st.markdown(f"**{exp['format']}** — `{exp['path']}`")
                if exp.get("size"):
                    st.caption(f"Size: {exp['size']}")
                if exp.get("success"):
                    st.success("Export successful")
                else:
                    st.error("Export failed")
            with col_action:
                path = exp.get("path", "")
                if exp.get("success") and os.path.isfile(path):
                    with open(path, "rb") as f:
                        # Index in the key keeps widgets unique when several
                        # exports share a format.
                        st.download_button("📥 Download", f, file_name=os.path.basename(path),
                                           key=f"dl_{idx}_{exp['format']}")

    st.divider()
    if any(exp.get("success") for exp in exported):
        st.success("✅ Export complete! Click **Next** to test inference.")
    else:
        # Only failed attempts are recorded, so don't claim completion.
        st.warning("No successful export yet. Check the error details, or run `flashdet export` manually.")
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def _run_export(weights_path: str, img_size: int):
    """Export ``weights_path`` to ONNX, or explain how to when FlashDet is absent.

    Args:
        weights_path: Path to the weights file to export. An empty value is
            rejected. A missing file only produces a warning, and the export
            is still attempted.
        img_size: Input size selected in the UI, forwarded to ``_flashdet_export``.

    When FlashDet is not importable, the manual command is shown and a failed
    entry is recorded in ``st.session_state["exported_files"]``. The recorded
    path is the weights path with its final extension replaced by ``.onnx``.
    """
    import importlib.util
    import shlex

    if not weights_path:
        st.error("Please select or provide weights path.")
        return

    if not os.path.isfile(weights_path):
        st.warning(f"Weights file not found: `{weights_path}`. Attempting export anyway...")

    if importlib.util.find_spec("flashdet") is not None:
        _flashdet_export(weights_path, img_size)
        return

    # Swap only the final extension. str.replace(".pth", ...) would also rewrite
    # matching directory names and leave non-.pth paths unchanged.
    onnx_path = os.path.splitext(weights_path)[0] + ".onnx"

    st.warning("FlashDet not installed. Run export manually:")
    st.code(
        f"flashdet export --model {shlex.quote(weights_path)} --output {shlex.quote(onnx_path)}",
        language="bash",
    )
    st.session_state["exported_files"] = [{
        "format": "ONNX", "path": onnx_path,
        "size": "—", "success": False
    }]
    # No st.rerun() here. The results column is rendered after the settings
    # column in the same run, so it already sees the updated state, and a rerun
    # would discard the manual instructions shown above.
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def _flashdet_export(weights_path: str, img_size: int):
    """Run ``python -m flashdet.cli export`` in a subprocess and record the result.

    Stores a single entry (format, path, size, success) in
    ``st.session_state["exported_files"]`` and reports the outcome in the UI.
    Failures, including a timeout or an exporter that cannot start, are shown
    with their error details rather than crashing the page.

    Args:
        weights_path: Path to the weights file to export. The ONNX file is
            written next to it with the extension replaced by ``.onnx``.
        img_size: Input size chosen in the UI. NOTE(review): it is not passed
            to the CLI. Confirm whether ``flashdet export`` accepts a size option.
    """
    import subprocess
    import sys

    stem, ext = os.path.splitext(weights_path)
    output_path = stem + ".onnx"
    if ext.lower() == ".onnx":
        # Never let the export write over its own input file.
        output_path = stem + "_export.onnx"

    cmd = [
        sys.executable, "-m", "flashdet.cli", "export",
        "--model", weights_path,
        "--output", output_path,
    ]

    error = None
    with st.spinner("Exporting to ONNX..."):
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
        except subprocess.TimeoutExpired:
            error = "timed out after 120 s"
        except OSError as exc:
            error = f"could not start the exporter ({exc})"
        else:
            if result.returncode != 0:
                # Tracebacks end with the actual exception, so keep the tail.
                details = (result.stderr or result.stdout or "").strip()
                error = details[-500:] or f"exit code {result.returncode}"

    if error is None:
        size_str = "—"
        if os.path.isfile(output_path):
            size_mb = os.path.getsize(output_path) / (1024 * 1024)
            size_str = f"{size_mb:.1f} MB"

        st.session_state["exported_files"] = [{
            "format": "ONNX", "path": output_path, "size": size_str, "success": True
        }]
        st.toast("ONNX export complete!", icon="✅")
    else:
        st.error(f"Export failed:\n{error}")
        st.session_state["exported_files"] = [{
            "format": "ONNX", "path": output_path, "size": "—", "success": False
        }]
    # No st.rerun() here. The results column renders after the settings column
    # in the same run, and a rerun would wipe the error details shown above.
|