flashstudio 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flashstudio/__init__.py +5 -0
- flashstudio/app.py +64 -0
- flashstudio/cli.py +18 -0
- flashstudio/components/__init__.py +0 -0
- flashstudio/components/sidebar.py +45 -0
- flashstudio/components/styles.py +273 -0
- flashstudio/components/wizard.py +46 -0
- flashstudio/launcher.py +73 -0
- flashstudio/pages/__init__.py +0 -0
- flashstudio/pages/dashboard.py +168 -0
- flashstudio/pages/data.py +272 -0
- flashstudio/pages/export.py +212 -0
- flashstudio/pages/inference.py +1112 -0
- flashstudio/pages/model.py +370 -0
- flashstudio/pages/training.py +672 -0
- flashstudio/utils/__init__.py +0 -0
- flashstudio/utils/device.py +58 -0
- flashstudio-0.1.0.dist-info/METADATA +133 -0
- flashstudio-0.1.0.dist-info/RECORD +22 -0
- flashstudio-0.1.0.dist-info/WHEEL +5 -0
- flashstudio-0.1.0.dist-info/entry_points.txt +2 -0
- flashstudio-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1112 @@
|
|
|
1
|
+
"""FlashStudio — Inference Pipeline Dashboard."""
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
import io
|
|
5
|
+
import json
|
|
6
|
+
import tempfile
|
|
7
|
+
import importlib.util
|
|
8
|
+
import streamlit as st
|
|
9
|
+
import numpy as np
|
|
10
|
+
from PIL import Image, ImageDraw, ImageFont
|
|
11
|
+
from streamlit_image_coordinates import streamlit_image_coordinates
|
|
12
|
+
from flashstudio.components.styles import render_page_header, render_info_bar
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# The 80 COCO class names, in index order ("person" first, "toothbrush" last).
# Grouped by theme below purely for readability; the order is significant.
COCO_CLASSES = [
    # people & vehicles
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat",
    # street furniture
    "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
    # animals
    "bird", "cat", "dog", "horse", "sheep",
    "cow", "elephant", "bear", "zebra", "giraffe",
    # accessories
    "backpack", "umbrella", "handbag", "tie", "suitcase",
    # sports
    "frisbee", "skis", "snowboard", "sports ball", "kite",
    "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
    # kitchenware
    "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
    # food
    "banana", "apple", "sandwich", "orange", "broccoli",
    "carrot", "hot dog", "pizza", "donut", "cake",
    # furniture
    "chair", "couch", "potted plant", "bed", "dining table", "toilet",
    # electronics
    "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
    # appliances
    "microwave", "oven", "toaster", "sink", "refrigerator",
    # indoor objects
    "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]

# Ten-entry hex color palette.
COLORS = [
    "#FF6B6B",
    "#4ECDC4",
    "#45B7D1",
    "#96CEB4",
    "#FFEAA7",
    "#DDA0DD",
    "#98D8C8",
    "#F7DC6F",
    "#BB8FCE",
    "#85C1E9",
]

# Selectable FlashDet solutions, in display order. Every entry carries a
# description, an icon and ``needs_zone``; entries that need a zone also name
# the shape the user draws (``zone_type``: "line" or "polygon").
SOLUTIONS = {
    "None (Detection Only)": {
        "description": "Standard object detection — bounding boxes only",
        "icon": "🔍",
        "needs_zone": False,
    },
    "Object Counter (Line)": {
        "description": "Count objects crossing a line",
        "icon": "🔢",
        "needs_zone": True,
        "zone_type": "line",
    },
    "Region Counter (Polygon)": {
        "description": "Count objects inside polygon zones",
        "icon": "📐",
        "needs_zone": True,
        "zone_type": "polygon",
    },
    "Speed Estimator": {
        "description": "Estimate speed of moving objects",
        "icon": "⚡",
        "needs_zone": True,
        "zone_type": "line",
    },
    "Heatmap": {
        "description": "Generate heatmap of object activity",
        "icon": "🔥",
        "needs_zone": False,
    },
    "Security Alarm (Zone)": {
        "description": "Alert when objects enter restricted zone",
        "icon": "🚨",
        "needs_zone": True,
        "zone_type": "polygon",
    },
    "Trajectory Visualizer": {
        "description": "Draw motion trails for tracked objects",
        "icon": "🎯",
        "needs_zone": False,
    },
    "Object Blurrer": {
        "description": "Blur detected objects for privacy",
        "icon": "🙈",
        "needs_zone": False,
    },
    "Queue Manager": {
        "description": "Monitor queue lengths in defined zones",
        "icon": "🧍",
        "needs_zone": True,
        "zone_type": "polygon",
    },
    "Crowd Density": {
        "description": "Grid-based crowd density estimation",
        "icon": "👥",
        "needs_zone": False,
    },
    "Parking Manager": {
        "description": "Track parking space occupancy",
        "icon": "🅿️",
        "needs_zone": True,
        "zone_type": "polygon",
    },
    "Traffic Flow": {
        "description": "Direction-aware traffic analysis",
        "icon": "🚗",
        "needs_zone": True,
        "zone_type": "line",
    },
    "Dwell Time Analyzer": {
        "description": "Measure time objects spend in zones",
        "icon": "⏱️",
        "needs_zone": True,
        "zone_type": "polygon",
    },
    "Distance Calculator": {
        "description": "Compute pairwise distances between objects",
        "icon": "📏",
        "needs_zone": False,
    },
}
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _get_device_options() -> list:
|
|
54
|
+
"""Get available compute devices."""
|
|
55
|
+
devices = ["cpu"]
|
|
56
|
+
try:
|
|
57
|
+
import torch
|
|
58
|
+
if torch.cuda.is_available():
|
|
59
|
+
for i in range(torch.cuda.device_count()):
|
|
60
|
+
name = torch.cuda.get_device_name(i)
|
|
61
|
+
devices.append(f"cuda:{i} ({name})")
|
|
62
|
+
except ImportError:
|
|
63
|
+
pass
|
|
64
|
+
return devices
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def render_inference_page():
    """Render the FlashStudio inference page.

    Draws the page header, a pipeline-readiness strip, the four pipeline tabs
    (model & weights, data upload, solution & zone, run & results) and a
    bottom info bar with the current model, solution and device taken from
    ``st.session_state``.
    """
    render_page_header("🔍", "FlashStudio Inference",
                       "Full inference pipeline — select model, upload data, configure solution, run detection.")

    # Pipeline readiness indicators. Reserve their slot at the top of the page
    # but fill it only after the tabs have run: the tabs write the session
    # state the bar reads (zone coordinates, the training-output weights
    # path, ...), so rendering it first would show the previous run's status.
    readiness_slot = st.container()

    tab_model, tab_data, tab_solution, tab_run = st.tabs([
        "🧠 Model & Weights", "📁 Upload Data", "📐 Solution & Zone", "🚀 Run & Results"
    ])

    with tab_model:
        _tab_model()

    with tab_data:
        _tab_data()

    with tab_solution:
        _tab_solution()

    with tab_run:
        _tab_run()

    with readiness_slot:
        _render_readiness_bar()

    model_name = st.session_state.get("infer_model_arch", "FlashDet-Nano")
    # Fall back to the real SOLUTIONS key (the selectbox's first option), the
    # same default the readiness bar uses.
    solution = st.session_state.get("selected_solution", "None (Detection Only)")
    device = st.session_state.get("infer_device", "cpu")
    # Device options look like "cuda:0 (GPU name)"; show only the device id.
    render_info_bar({"Model": model_name, "Solution": solution, "Device": device.split(" ")[0]})
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _render_readiness_bar():
    """Render a one-line status strip showing which pipeline inputs are ready.

    Three checks are shown as green/red dots, followed by a "Ready: n/3" count:

    * Model Weights: an uploaded weights file or a weights path in session state.
    * Input Data: uploaded images, an uploaded video, or a stream URL.
    * Zone Config: always satisfied for solutions with ``needs_zone=False``.
      Otherwise it requires a stored zone of the kind the selected solution
      draws: ``zone_line_points`` for line solutions, ``zone_polygons`` for
      polygon solutions.
    """
    state = st.session_state
    has_model = bool(state.get("infer_weights_file") or state.get("infer_weights_path"))
    has_data = bool(
        state.get("infer_images") or
        state.get("infer_video") or
        state.get("infer_stream_url")
    )

    solution = state.get("selected_solution", "None (Detection Only)")
    sol_info = SOLUTIONS.get(solution, {})
    if not sol_info.get("needs_zone", False):
        has_zone = True
    elif sol_info.get("zone_type", "polygon") == "line":
        # A polygon left over from another solution must not mark a
        # line-based solution as configured (and vice versa below).
        has_zone = bool(state.get("zone_line_points"))
    else:
        has_zone = bool(state.get("zone_polygons"))

    def _dot(ok):
        return "🟢" if ok else "🔴"

    st.markdown(
        f'<div style="display:flex;gap:1.5rem;padding:0.5rem 1rem;background:#F9FAFB;'
        f'border-radius:8px;border:1px solid #E8E8EF;margin-bottom:1rem;font-size:0.82rem;">'
        f'<span>{_dot(has_model)} Model Weights</span>'
        f'<span>{_dot(has_data)} Input Data</span>'
        f'<span>{_dot(has_zone)} Zone Config</span>'
        f'<span style="margin-left:auto;color:#6B7280;">Ready: {sum([has_model, has_data, has_zone])}/3</span>'
        f'</div>',
        unsafe_allow_html=True,
    )
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
# ════════════════════════════════════════════════════════════════
|
|
130
|
+
# TAB 1: MODEL & WEIGHTS
|
|
131
|
+
# ════════════════════════════════════════════════════════════════
|
|
132
|
+
|
|
133
|
+
def _tab_model():
    """Render the Model & Weights tab.

    Left column: architecture, confidence / NMS / image-size parameters and
    the compute device. Right column: where the weights come from (upload, a
    typed path, or the training output directory), the class count and an
    optional class filter. Widget values persist under the ``infer_*`` and
    ``weight_source`` session-state keys.
    """
    architectures = ["FlashDet-Pico", "FlashDet-Nano", "FlashDet-Small", "FlashDet-Medium", "FlashDet-Large"]
    weight_sources = ["Upload file", "Enter path", "Use training output"]
    left, right = st.columns(2)

    with left:
        st.markdown("#### Model Architecture")
        st.selectbox("Architecture", architectures, key="infer_model_arch")

        st.markdown("#### Inference Parameters")
        st.slider("Confidence Threshold", 0.0, 1.0, 0.25, 0.05, key="infer_conf")
        st.slider("NMS IoU Threshold", 0.0, 1.0, 0.45, 0.05, key="infer_nms")
        st.number_input("Image Size", 320, 1920, 640, 32, key="infer_img_size")

        st.markdown("#### Device")
        st.selectbox("Compute Device", _get_device_options(), key="infer_device")

    with right:
        st.markdown("#### Model Weights")
        source = st.radio("Source", weight_sources, key="weight_source")

        if source == "Upload file":
            weights = st.file_uploader(
                "Upload .pt / .pth / .onnx / .engine",
                type=["pt", "pth", "onnx", "engine"],
                key="infer_weights_file",
            )
            if weights:
                st.success(f"Loaded: {weights.name} ({weights.size / 1e6:.1f} MB)")
        elif source == "Enter path":
            st.text_input("Weights path", placeholder="/path/to/best.pt", key="infer_weights_path")
        else:
            # Build the checkpoint path from the session's save_dir (with a
            # Colab-style default) and publish it as the weights path.
            run_dir = st.session_state.get("save_dir", "/content/flashstudio_runs")
            checkpoint = f"{run_dir}/model_best_inference.pth"
            st.session_state["infer_weights_path"] = checkpoint
            st.info(f"Using: `{checkpoint}`")

        st.markdown("#### Classes")
        st.number_input("Number of Classes", 1, 1000, 80, key="infer_num_classes")

        st.markdown("#### Class Filter")
        st.multiselect(
            "Detect only (leave empty for all)",
            COCO_CLASSES,
            default=[],
            key="infer_class_filter",
        )
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
# ════════════════════════════════════════════════════════════════
|
|
193
|
+
# TAB 2: UPLOAD DATA
|
|
194
|
+
# ════════════════════════════════════════════════════════════════
|
|
195
|
+
|
|
196
|
+
def _tab_data():
    """Render the Upload Data tab: pick an input type, then show its form."""
    renderers = {
        "Images": _data_images,
        "Video": _data_video,
    }
    choice = st.radio(
        "Input Type",
        ["Images", "Video", "RTSP Stream"],
        horizontal=True,
        key="infer_input_type",
    )
    # Anything other than Images/Video gets the stream configuration form.
    renderers.get(choice, _data_stream)()
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _data_images():
    """Render the multi-image uploader and a preview row of up to four images."""
    preview_limit = 4
    st.markdown("#### Upload Images")
    files = st.file_uploader(
        "Select images (JPG, PNG, WebP)",
        type=["jpg", "jpeg", "png", "webp", "bmp"],
        accept_multiple_files=True,
        key="infer_images",
    )
    if not files:
        return

    st.success(f"{len(files)} image(s) uploaded")
    shown = files[:preview_limit]
    for column, upload in zip(st.columns(len(shown)), shown):
        with column:
            st.image(upload, caption=upload.name, width=None)

    hidden = len(files) - preview_limit
    if hidden > 0:
        st.caption(f"...and {hidden} more")
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def _data_video():
    """Render the video uploader with metadata, first-frame preview and options.

    Left column: the uploader and, once a video is present, its resolution,
    FPS, frame count, duration and first frame. Right column: frame limit,
    frame skip and save-output options.
    """
    left, right = st.columns(2)

    with left:
        st.markdown("#### Upload Video")
        st.file_uploader(
            "Select video (MP4, AVI, MOV)",
            type=["mp4", "avi", "mov", "mkv", "webm"],
            key="infer_video",
        )
        video = st.session_state.get("infer_video")
        if video:
            st.success(f"{video.name} ({video.size / 1e6:.1f} MB)")

            # Show video metadata + first frame
            meta = _get_video_metadata(video)
            if meta:
                fps = meta["fps"]
                duration = meta["total_frames"] / fps if fps > 0 else 0
                tiles = [
                    ("Resolution", f"{meta['width']}x{meta['height']}"),
                    ("FPS", f"{fps:.1f}"),
                    ("Frames", meta["total_frames"]),
                    ("Duration", f"{duration:.1f}s"),
                ]
                for column, (label, value) in zip(st.columns(4), tiles):
                    with column:
                        st.metric(label, value)

            frame = _get_first_frame()
            if frame is not None:
                st.image(frame, caption="First Frame Preview", width=None)

    with right:
        st.markdown("#### Processing Options")
        st.number_input("Max Frames (0=all)", 0, 100000, 300, key="infer_max_frames")
        st.select_slider("Skip every N frames", [1, 2, 3, 5, 10], value=1, key="infer_frame_skip")
        st.checkbox("Save output video", value=True, key="infer_save_video")
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def _data_stream():
    """Render the RTSP / webcam stream form with an optional connection test.

    "Test Connection" opens the URL with OpenCV, reads one frame and previews
    it. A purely numeric entry (e.g. ``0``) is passed to OpenCV as a camera
    index. The capture is always released, including when opening fails or
    ``read()`` raises. Any error is reported in the UI rather than raised.
    """
    st.markdown("#### RTSP / Webcam Stream")
    col_url, col_opts = st.columns([3, 1])
    with col_url:
        st.text_input("Stream URL", placeholder="rtsp://192.168.1.100:554/stream", key="infer_stream_url")
    with col_opts:
        st.number_input("Duration (s)", 0, 3600, 60, key="infer_stream_duration")

    url = st.session_state.get("infer_stream_url")
    if not url:
        return
    if not st.button("Test Connection", key="test_stream"):
        return

    with st.spinner("Testing stream..."):
        try:
            import cv2

            # cv2.VideoCapture treats an int as a local camera index; the
            # string "0" would be looked up as a file/URL instead.
            source = int(url) if url.strip().isdigit() else url
            cap = cv2.VideoCapture(source)
            try:
                if not cap.isOpened():
                    st.error("Failed to connect to stream.")
                    return
                ret, frame = cap.read()
            finally:
                cap.release()

            if ret:
                st.success("Stream connected successfully!")
                st.image(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)),
                         caption="Stream Preview", width=None)
            else:
                st.error("Connected but failed to read frame.")
        except Exception as e:
            st.error(f"Connection error: {e}")
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def _get_video_metadata(video_file) -> dict | None:
|
|
302
|
+
"""Extract video metadata using OpenCV."""
|
|
303
|
+
try:
|
|
304
|
+
import cv2
|
|
305
|
+
tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
|
|
306
|
+
tmp.write(video_file.read())
|
|
307
|
+
tmp.flush()
|
|
308
|
+
video_file.seek(0)
|
|
309
|
+
cap = cv2.VideoCapture(tmp.name)
|
|
310
|
+
if cap.isOpened():
|
|
311
|
+
meta = {
|
|
312
|
+
"width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
|
|
313
|
+
"height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
|
|
314
|
+
"fps": cap.get(cv2.CAP_PROP_FPS) or 30,
|
|
315
|
+
"total_frames": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
|
|
316
|
+
}
|
|
317
|
+
cap.release()
|
|
318
|
+
return meta
|
|
319
|
+
cap.release()
|
|
320
|
+
except Exception:
|
|
321
|
+
pass
|
|
322
|
+
return None
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
# ════════════════════════════════════════════════════════════════
|
|
326
|
+
# TAB 3: SOLUTION & ZONE (Interactive Click-to-Draw)
|
|
327
|
+
# ════════════════════════════════════════════════════════════════
|
|
328
|
+
|
|
329
|
+
def _get_first_frame() -> Image.Image | None:
    """Return the first frame of the uploaded video, else the first uploaded image.

    Sources are read from ``st.session_state``. ``infer_video`` is tried first:
    it is copied to a temporary file for OpenCV, and that file is removed
    afterwards. ``infer_images[0]`` is the fallback. Every upload that is read
    is rewound to position 0 afterwards so later consumers see the full content.

    Returns:
        An RGB ``PIL.Image.Image``, or ``None`` when nothing is uploaded or
        the upload cannot be decoded.
    """
    import os

    video = st.session_state.get("infer_video")
    if video:
        tmp_path = None
        try:
            import cv2

            video.seek(0)
            suffix = os.path.splitext(getattr(video, "name", "") or "")[1] or ".mp4"
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
                tmp_path = tmp.name
                tmp.write(video.read())

            cap = cv2.VideoCapture(tmp_path)
            try:
                ret, frame = cap.read()
            finally:
                cap.release()
            if ret:
                return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        except Exception:
            # Best-effort: fall through to the first uploaded image.
            pass
        finally:
            try:
                video.seek(0)
            except (AttributeError, OSError, ValueError):
                pass
            if tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

    images = st.session_state.get("infer_images")
    if images:
        first = images[0]
        try:
            first.seek(0)
            with Image.open(first) as src:
                # convert() fully decodes, so the result outlives ``src``.
                return src.convert("RGB")
        except Exception:
            pass
        finally:
            try:
                first.seek(0)
            except (AttributeError, OSError, ValueError):
                pass

    return None
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def _draw_overlay(img: Image.Image, points: list, zone_type: str, closed: bool = False) -> Image.Image:
    """Return a copy of ``img`` with the in-progress zone drawn on top.

    Args:
        img: Frame to annotate. It is not modified.
        points: Clicked points as ``[x, y]`` pairs in ``img`` pixel coordinates.
        zone_type: Shape to draw:
            * ``"line"``: drawn once 2 points exist, with a green midpoint dot.
            * ``"polygon"``: edges drawn between consecutive vertices.
            * ``"rect"``: 2 opposite corners, which may be clicked in any order.
        closed: For polygons, draw the closing edge and a translucent fill once
            there are at least 3 vertices.

    Returns:
        The annotated copy. It is converted to RGB when a translucent fill is
        composited. Shapes are drawn first and the numbered vertex markers
        last, so the markers stay visible on top of the fill.
    """
    overlay = img.copy()
    if not points:
        return overlay

    color = "#7C3AED"
    fill_rgb = (124, 58, 237)  # same purple as ``color``, used with alpha for fills
    point_radius = 6
    pts = [tuple(p) for p in points]

    draw = ImageDraw.Draw(overlay)
    if zone_type == "line" and len(pts) == 2:
        draw.line([pts[0], pts[1]], fill=color, width=3)
        # Green dot at the midpoint of the line.
        mx = (pts[0][0] + pts[1][0]) // 2
        my = (pts[0][1] + pts[1][1]) // 2
        draw.ellipse([mx - 4, my - 4, mx + 4, my + 4], fill="#10B981")
    elif zone_type == "polygon" and len(pts) >= 2:
        for start, end in zip(pts, pts[1:]):
            draw.line([start, end], fill=color, width=3)
        if closed and len(pts) >= 3:
            draw.line([pts[-1], pts[0]], fill=color, width=3)
            tint = Image.new("RGBA", overlay.size, (0, 0, 0, 0))
            ImageDraw.Draw(tint).polygon(pts, fill=fill_rgb + (40,))
            overlay = Image.alpha_composite(overlay.convert("RGBA"), tint).convert("RGB")
    elif zone_type == "rect" and len(pts) == 2:
        # Pillow rejects boxes whose second corner lies left of / above the
        # first, so normalise the clicked corners to (left, top, right, bottom).
        (xa, ya), (xb, yb) = pts
        box = [min(xa, xb), min(ya, yb), max(xa, xb), max(ya, yb)]
        draw.rectangle(box, outline=color, width=3)
        tint = Image.new("RGBA", overlay.size, (0, 0, 0, 0))
        ImageDraw.Draw(tint).rectangle(box, fill=fill_rgb + (30,))
        overlay = Image.alpha_composite(overlay.convert("RGBA"), tint).convert("RGB")

    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 12)
    except (OSError, IOError):
        font = ImageFont.load_default()

    # Vertex markers with 1-based labels, drawn last so they sit above the fill.
    draw = ImageDraw.Draw(overlay)
    for idx, (x, y) in enumerate(pts, start=1):
        draw.ellipse([x - point_radius, y - point_radius, x + point_radius, y + point_radius],
                     fill=color, outline="white", width=2)
        draw.text((x + 10, y - 10), str(idx), fill="white", font=font)

    return overlay
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
def _tab_solution():
    """Render the Solution & Zone tab.

    The user picks an entry from ``SOLUTIONS``. Solutions that need a zone get
    drawing instructions and the interactive zone editor. The others show a
    three-column reference grid of every available solution and its zone type.
    """
    col_sol, col_info = st.columns([2, 3])

    with col_sol:
        st.markdown("#### Select Solution")
        selected = st.selectbox(
            "FlashDet Solution",
            list(SOLUTIONS.keys()),
            key="selected_solution",
            format_func=lambda x: f"{SOLUTIONS[x]['icon']} {x}",
        )
        sol = SOLUTIONS[selected]
        st.caption(sol["description"])

    with col_info:
        if sol["needs_zone"]:
            zone_type = sol.get("zone_type", "polygon")
            if zone_type == "line":
                st.info("🖊️ **Click 2 points** on the frame to define the counting line (start → end).")
            else:
                st.info("🖊️ **Click points** on the frame to define the zone. Click 'Close Polygon' when done.")
        else:
            st.success("✅ This solution does not require zone configuration.")

    st.divider()

    if sol["needs_zone"]:
        # Interactive zone drawing
        _zone_draw_ui(sol)
        return

    st.markdown("#### Available Solutions Reference")
    cols = st.columns(3)
    # Drop the detection-only entry *before* numbering, so the grid starts in
    # column 0 and the cards are spread evenly across the three columns.
    entries = [(name, info) for name, info in SOLUTIONS.items() if name != "None (Detection Only)"]
    for i, (name, info) in enumerate(entries):
        with cols[i % 3]:
            zone = "Line" if info.get("zone_type") == "line" else ("Polygon" if info["needs_zone"] else "—")
            # Two trailing spaces before "\n" force a Markdown line break.
            st.markdown(f"{info['icon']} **{name}**  \n<small style='color:#6B7280'>Zone: {zone}</small>",
                        unsafe_allow_html=True)
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
def _zone_draw_ui(sol: dict):
    """Interactive click-to-draw zone editor for the selected solution.

    Shows the first uploaded frame (or a placeholder), overlays the points
    clicked so far, and offers Clear / Undo / Close Polygon actions. The
    points, converted back to original-frame pixels, are handed to
    ``_store_zone_coords``.

    Session-state keys used:
        zone_draw_points: clicked points in *display* pixels (``[x, y]`` lists).
        zone_closed: the shape is finished. This is set automatically after 2
            points for line/rect, and by "Close Polygon" for polygons.
        zone_draw_points_mode: draw mode the stored points were made in.
        zone_last_click: last click value already consumed from the image
            component.
    """
    zone_type = sol.get("zone_type", "polygon")

    # Line solutions always draw a line; polygon solutions may draw either a
    # free-form polygon or a two-corner rectangle.
    draw_mode = zone_type
    if zone_type == "polygon":
        draw_mode = st.radio(
            "Drawing Mode",
            ["polygon", "rect"],
            format_func=lambda x: "🔷 Polygon (click vertices)" if x == "polygon" else "⬜ Rectangle (2 corners)",
            horizontal=True,
            key="zone_draw_mode",
        )

    # Initialize state
    zone_key = "zone_draw_points"
    if zone_key not in st.session_state:
        st.session_state[zone_key] = []
    if "zone_closed" not in st.session_state:
        st.session_state["zone_closed"] = False

    # Points only make sense for the mode they were drawn in. For example, a
    # finished line would otherwise leave the polygon editor "closed" and
    # refuse further clicks. Start over whenever the mode changes.
    if st.session_state.get("zone_draw_points_mode") != draw_mode:
        st.session_state[zone_key] = []
        st.session_state["zone_closed"] = False
        st.session_state["zone_draw_points_mode"] = draw_mode

    # Frame image
    frame_img = _get_first_frame()
    if frame_img is None:
        st.warning("⚠️ Upload a video or image in the **Upload Data** tab first, then return here to draw zones.")
        frame_img = Image.new("RGB", (640, 480), color=(30, 30, 40))
        placeholder = ImageDraw.Draw(frame_img)
        placeholder.text((180, 220), "No frame available", fill=(180, 180, 190))
        placeholder.text((160, 245), "Upload video/image in Tab 2", fill=(120, 120, 130))

    # Scale for display (at most 700 px wide). scale_x / scale_y map display
    # clicks back to original-frame pixels.
    display_w = min(700, frame_img.width)
    aspect = frame_img.height / frame_img.width
    # Clamp to 1 px: a very wide frame would otherwise give a zero height,
    # breaking the resize and the scale_y division.
    display_h = max(1, int(display_w * aspect))
    scale_x = frame_img.width / display_w
    scale_y = frame_img.height / display_h

    # Draw overlay
    display_img = frame_img.resize((display_w, display_h), Image.LANCZOS)
    display_points = st.session_state[zone_key]
    annotated = _draw_overlay(display_img, display_points, draw_mode, st.session_state["zone_closed"])

    # Action bar
    col_b1, col_b2, col_b3, col_b4 = st.columns(4)
    with col_b1:
        if st.button("🗑️ Clear", key="zone_clear", use_container_width=True):
            st.session_state[zone_key] = []
            st.session_state["zone_closed"] = False
            st.rerun()
    with col_b2:
        if st.button("↩️ Undo", key="zone_undo", use_container_width=True):
            if st.session_state[zone_key]:
                st.session_state[zone_key].pop()
                st.session_state["zone_closed"] = False
            st.rerun()
    with col_b3:
        if draw_mode == "polygon" and len(st.session_state[zone_key]) >= 3:
            if st.button("✅ Close Polygon", key="zone_close", use_container_width=True, type="primary"):
                st.session_state["zone_closed"] = True
                st.rerun()
    with col_b4:
        n_pts = len(display_points)
        if draw_mode == "line":
            st.caption(f"📍 Points: {n_pts}/2")
        elif draw_mode == "rect":
            st.caption(f"📍 Corners: {n_pts}/2")
        else:
            status = "Closed ✅" if st.session_state["zone_closed"] else f"{n_pts} vertices"
            st.caption(f"📍 Polygon: {status}")

    # Clickable image
    coords = streamlit_image_coordinates(annotated, key="zone_click_img")

    # Handle click. The component keeps returning its most recent click on
    # every rerun, so only act on a value that has not been consumed yet.
    # Otherwise Undo/Clear would immediately re-add the last clicked point.
    if coords is not None and coords != st.session_state.get("zone_last_click"):
        st.session_state["zone_last_click"] = dict(coords)
        new_point = [coords["x"], coords["y"]]
        max_points = 2 if draw_mode in ("line", "rect") else 50
        if len(display_points) < max_points and not st.session_state["zone_closed"]:
            display_points.append(new_point)
            if draw_mode in ("line", "rect") and len(display_points) == 2:
                st.session_state["zone_closed"] = True
            st.rerun()

    # Store and display coordinates
    st.divider()
    _store_zone_coords(display_points, draw_mode, scale_x, scale_y)
|
|
542
|
+
|
|
543
|
+
|
|
544
|
+
def _store_zone_coords(display_points: list, draw_mode: str, scale_x: float, scale_y: float):
    """Convert clicked display points to frame pixels, store them and report them.

    Args:
        display_points: ``[x, y]`` points in the downscaled display image.
        draw_mode: ``"line"``, ``"rect"`` or ``"polygon"``.
        scale_x: Frame-width / display-width ratio used to map x back.
        scale_y: Frame-height / display-height ratio used to map y back.

    Session state written:
        * line (2+ points): ``zone_line_points``, ``line_x1``..``line_y2`` and
          ``auto_line=False``.
        * rect (2+ points) or polygon (3+ points): ``zone_polygons`` (a
          one-element list) and ``polygon_points`` (one ``"x,y"`` per line).
        * With too few points for the mode, the stored zone of that kind
          (``zone_line_points`` for lines, ``zone_polygons`` otherwise) is
          removed, so a cleared or undone drawing no longer counts as
          configured.
    """
    zone_state_key = "zone_line_points" if draw_mode == "line" else "zone_polygons"

    def _forget_zone():
        # NOTE(review): line_x1..line_y2, auto_line and polygon_points are left
        # untouched here; confirm how the run step treats them after a clear.
        if zone_state_key in st.session_state:
            del st.session_state[zone_state_key]

    if not display_points:
        _forget_zone()
        st.caption("👆 Click on the frame above to start drawing the zone.")
        return

    orig_points = [(int(x * scale_x), int(y * scale_y)) for x, y in display_points]

    st.markdown("#### Zone Coordinates")

    if draw_mode == "line" and len(orig_points) >= 2:
        (x1, y1), (x2, y2) = orig_points[0], orig_points[1]
        st.session_state["zone_line_points"] = [orig_points[0], orig_points[1]]
        st.session_state["line_x1"] = x1
        st.session_state["line_y1"] = y1
        st.session_state["line_x2"] = x2
        st.session_state["line_y2"] = y2
        st.session_state["auto_line"] = False
        st.success(f"✅ Line: ({x1}, {y1}) → ({x2}, {y2})")

    elif draw_mode == "rect" and len(orig_points) >= 2:
        x1, y1 = orig_points[0]
        x2, y2 = orig_points[1]
        rect_poly = [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
        st.session_state["zone_polygons"] = [rect_poly]
        st.session_state["polygon_points"] = "\n".join(f"{x},{y}" for x, y in rect_poly)
        st.success(f"✅ Rectangle: ({x1},{y1}) → ({x2},{y2})")

    elif draw_mode == "polygon" and len(orig_points) >= 3:
        st.session_state["zone_polygons"] = [orig_points]
        st.session_state["polygon_points"] = "\n".join(f"{x},{y}" for x, y in orig_points)
        pts_display = " → ".join(f"({x},{y})" for x, y in orig_points)
        if st.session_state.get("zone_closed", False):
            st.success(f"✅ Polygon ({len(orig_points)} pts): {pts_display}")
        else:
            st.info(f"⏳ {len(orig_points)} points placed — click more or Close Polygon")

    else:
        _forget_zone()
        needed = 2 if draw_mode in ("line", "rect") else 3
        st.info(f"⏳ {len(orig_points)}/{needed} points — keep clicking to add more.")
|
|
585
|
+
|
|
586
|
+
|
|
587
|
+
# ════════════════════════════════════════════════════════════════
|
|
588
|
+
# TAB 4: RUN & RESULTS
|
|
589
|
+
# ════════════════════════════════════════════════════════════════
|
|
590
|
+
|
|
591
|
+
def _tab_run():
    """Render the Run & Results tab.

    Shows the configuration summary and refuses to run until data exists for
    the input type selected in the Upload Data tab. The Run button dispatches
    image or video inference; streams are not supported yet. Cached results
    for the current input type are always shown.
    """
    # Configuration summary
    _render_config_summary()

    st.divider()

    # Readiness check: only the selected input kind counts, so a leftover
    # upload of another kind cannot enable a run that then finds no data.
    input_type = st.session_state.get("infer_input_type", "Images")
    data_key = {"Images": "infer_images", "Video": "infer_video"}.get(input_type, "infer_stream_url")
    has_data = bool(st.session_state.get(data_key))

    if not has_data:
        st.warning("⚠️ No data uploaded. Go to the **Upload Data** tab first.")
        return

    # Run controls
    col_run, col_opts = st.columns([1, 2])
    with col_run:
        run_clicked = st.button("🚀 Run Inference", type="primary", key="run_infer_btn", use_container_width=True)
    with col_opts:
        st.checkbox("Show per-frame detections", value=False, key="show_per_frame")

    if run_clicked:
        if input_type == "Images":
            _run_images()
        elif input_type == "Video":
            _run_video()
        else:
            st.info("Stream inference — coming soon.")

    st.divider()

    if input_type == "Images":
        _show_image_results()
    else:
        _show_video_results()
|
|
631
|
+
|
|
632
|
+
|
|
633
|
+
def _render_config_summary():
    """Show a compact configuration summary before running.

    Four metric tiles (model, input type, solution, device) are followed by an
    expander listing the model settings and the stored zone configuration,
    all read from ``st.session_state``.
    """
    # Default to the real SOLUTIONS key the selectbox starts on.
    solution = st.session_state.get("selected_solution", "None (Detection Only)")
    # Metric tiles are narrow: shorten long names, but mark the cut with "…".
    solution_label = solution if len(solution) <= 18 else solution[:17] + "…"
    device = st.session_state.get("infer_device", "cpu")

    c1, c2, c3, c4 = st.columns(4)
    with c1:
        st.metric("Model", st.session_state.get("infer_model_arch", "FlashDet-Nano"))
    with c2:
        st.metric("Input", st.session_state.get("infer_input_type", "Images"))
    with c3:
        st.metric("Solution", solution_label)
    with c4:
        # Device options look like "cuda:0 (GPU name)"; show only the id.
        st.metric("Device", device.split(" ")[0])

    # Expandable detailed config
    with st.expander("View Full Configuration", expanded=False):
        col_a, col_b = st.columns(2)
        with col_a:
            st.markdown("**Model Settings**")
            st.text(f" Architecture: {st.session_state.get('infer_model_arch', 'FlashDet-Nano')}")
            st.text(f" Confidence: {st.session_state.get('infer_conf', 0.25)}")
            st.text(f" NMS IoU: {st.session_state.get('infer_nms', 0.45)}")
            st.text(f" Image Size: {st.session_state.get('infer_img_size', 640)}")
            class_filter = st.session_state.get("infer_class_filter", [])
            st.text(f" Class Filter: {', '.join(class_filter) if class_filter else 'All classes'}")
        with col_b:
            st.markdown("**Zone Configuration**")
            st.text(f" Solution: {solution}")
            line_pts = st.session_state.get("zone_line_points")
            if line_pts:
                st.text(f" Line: {line_pts[0]} → {line_pts[1]}")
            polys = st.session_state.get("zone_polygons")
            if polys:
                st.text(f" Polygon: {len(polys[0])} vertices")
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
# ════════════════════════════════════════════════════════════════
|
|
670
|
+
# IMAGE INFERENCE
|
|
671
|
+
# ════════════════════════════════════════════════════════════════
|
|
672
|
+
|
|
673
|
+
def _run_images():
|
|
674
|
+
"""Run inference on uploaded images."""
|
|
675
|
+
images = st.session_state.get("infer_images", [])
|
|
676
|
+
if not images:
|
|
677
|
+
return
|
|
678
|
+
|
|
679
|
+
results = []
|
|
680
|
+
total_time = 0
|
|
681
|
+
progress = st.progress(0, text="Running inference...")
|
|
682
|
+
|
|
683
|
+
for i, f in enumerate(images):
|
|
684
|
+
img = Image.open(f).convert("RGB")
|
|
685
|
+
t0 = time.perf_counter()
|
|
686
|
+
dets = _detect(img)
|
|
687
|
+
elapsed = time.perf_counter() - t0
|
|
688
|
+
total_time += elapsed
|
|
689
|
+
annotated = _draw_boxes(img.copy(), dets)
|
|
690
|
+
results.append({
|
|
691
|
+
"name": f.name,
|
|
692
|
+
"dets": dets,
|
|
693
|
+
"annotated": annotated,
|
|
694
|
+
"inference_time_ms": elapsed * 1000,
|
|
695
|
+
"width": img.width,
|
|
696
|
+
"height": img.height,
|
|
697
|
+
})
|
|
698
|
+
progress.progress((i + 1) / len(images), text=f"Processing {i+1}/{len(images)}...")
|
|
699
|
+
|
|
700
|
+
progress.empty()
|
|
701
|
+
st.session_state["infer_img_results"] = results
|
|
702
|
+
st.session_state["infer_total_time"] = total_time
|
|
703
|
+
st.rerun()
|
|
704
|
+
|
|
705
|
+
|
|
706
|
+
def _show_image_results():
|
|
707
|
+
"""Display image results with real metrics."""
|
|
708
|
+
if "infer_img_results" not in st.session_state:
|
|
709
|
+
st.info("Click **Run Inference** to see results.")
|
|
710
|
+
return
|
|
711
|
+
|
|
712
|
+
results = st.session_state["infer_img_results"]
|
|
713
|
+
total_dets = sum(len(r["dets"]) for r in results)
|
|
714
|
+
total_time = st.session_state.get("infer_total_time", 1)
|
|
715
|
+
avg_time_ms = (total_time / len(results)) * 1000 if results else 0
|
|
716
|
+
|
|
717
|
+
# Real computed metrics
|
|
718
|
+
c1, c2, c3, c4 = st.columns(4)
|
|
719
|
+
with c1:
|
|
720
|
+
fps = len(results) / total_time if total_time > 0 else 0
|
|
721
|
+
st.metric("Throughput", f"{fps:.1f} img/s")
|
|
722
|
+
with c2:
|
|
723
|
+
st.metric("Avg Latency", f"{avg_time_ms:.1f} ms")
|
|
724
|
+
with c3:
|
|
725
|
+
st.metric("Total Detections", total_dets)
|
|
726
|
+
with c4:
|
|
727
|
+
st.metric("Images Processed", len(results))
|
|
728
|
+
|
|
729
|
+
# Per-image results
|
|
730
|
+
for r in results:
|
|
731
|
+
st.divider()
|
|
732
|
+
col_img, col_tbl = st.columns(2)
|
|
733
|
+
with col_img:
|
|
734
|
+
st.image(r["annotated"], caption=f"{r['name']} ({r['inference_time_ms']:.1f}ms)", width=None)
|
|
735
|
+
with col_tbl:
|
|
736
|
+
if r["dets"]:
|
|
737
|
+
import pandas as pd
|
|
738
|
+
df = pd.DataFrame(r["dets"], columns=["Class", "Confidence", "BBox"])
|
|
739
|
+
df.index += 1
|
|
740
|
+
st.dataframe(df, use_container_width=True, height=250)
|
|
741
|
+
else:
|
|
742
|
+
st.warning("No detections.")
|
|
743
|
+
|
|
744
|
+
# Export options
|
|
745
|
+
st.divider()
|
|
746
|
+
st.markdown("#### Export Results")
|
|
747
|
+
col_d1, col_d2, col_d3 = st.columns(3)
|
|
748
|
+
|
|
749
|
+
with col_d1:
|
|
750
|
+
import pandas as pd
|
|
751
|
+
all_d = []
|
|
752
|
+
for r in results:
|
|
753
|
+
for d in r["dets"]:
|
|
754
|
+
all_d.append([r["name"]] + d)
|
|
755
|
+
if all_d:
|
|
756
|
+
df = pd.DataFrame(all_d, columns=["File", "Class", "Confidence", "BBox"])
|
|
757
|
+
st.download_button("📥 CSV Detections", df.to_csv(index=False),
|
|
758
|
+
"detections.csv", "text/csv", type="primary", use_container_width=True)
|
|
759
|
+
|
|
760
|
+
with col_d2:
|
|
761
|
+
coco_json = _to_coco_format(results)
|
|
762
|
+
st.download_button("📥 COCO JSON", json.dumps(coco_json, indent=2),
|
|
763
|
+
"detections_coco.json", "application/json", use_container_width=True)
|
|
764
|
+
|
|
765
|
+
with col_d3:
|
|
766
|
+
if results:
|
|
767
|
+
buf = io.BytesIO()
|
|
768
|
+
results[0]["annotated"].save(buf, format="PNG")
|
|
769
|
+
st.download_button("🖼️ Annotated Image", buf.getvalue(),
|
|
770
|
+
"annotated.png", "image/png", use_container_width=True)
|
|
771
|
+
|
|
772
|
+
|
|
773
|
+
def _to_coco_format(results: list) -> dict:
|
|
774
|
+
"""Convert detection results to COCO JSON format."""
|
|
775
|
+
coco = {
|
|
776
|
+
"images": [],
|
|
777
|
+
"annotations": [],
|
|
778
|
+
"categories": [{"id": i, "name": name} for i, name in enumerate(COCO_CLASSES)],
|
|
779
|
+
}
|
|
780
|
+
ann_id = 1
|
|
781
|
+
for img_id, r in enumerate(results, 1):
|
|
782
|
+
coco["images"].append({
|
|
783
|
+
"id": img_id,
|
|
784
|
+
"file_name": r["name"],
|
|
785
|
+
"width": r.get("width", 640),
|
|
786
|
+
"height": r.get("height", 480),
|
|
787
|
+
})
|
|
788
|
+
for cls_name, conf, bbox_str in r["dets"]:
|
|
789
|
+
bbox = json.loads(bbox_str.replace("'", '"'))
|
|
790
|
+
x1, y1, x2, y2 = bbox
|
|
791
|
+
coco["annotations"].append({
|
|
792
|
+
"id": ann_id,
|
|
793
|
+
"image_id": img_id,
|
|
794
|
+
"category_id": COCO_CLASSES.index(cls_name) if cls_name in COCO_CLASSES else 0,
|
|
795
|
+
"bbox": [x1, y1, x2 - x1, y2 - y1],
|
|
796
|
+
"area": (x2 - x1) * (y2 - y1),
|
|
797
|
+
"score": float(conf),
|
|
798
|
+
"iscrowd": 0,
|
|
799
|
+
})
|
|
800
|
+
ann_id += 1
|
|
801
|
+
return coco
|
|
802
|
+
|
|
803
|
+
|
|
804
|
+
# ════════════════════════════════════════════════════════════════
|
|
805
|
+
# VIDEO INFERENCE
|
|
806
|
+
# ════════════════════════════════════════════════════════════════
|
|
807
|
+
|
|
808
|
+
def _run_video():
|
|
809
|
+
"""Run video inference with zone integration."""
|
|
810
|
+
video = st.session_state.get("infer_video")
|
|
811
|
+
solution = st.session_state.get("selected_solution", "None (Detection Only)")
|
|
812
|
+
|
|
813
|
+
if importlib.util.find_spec("flashdet"):
|
|
814
|
+
try:
|
|
815
|
+
_run_flashdet_video(video, solution)
|
|
816
|
+
return
|
|
817
|
+
except Exception as e:
|
|
818
|
+
st.warning(f"FlashDet unavailable: {e}. Running demo mode.")
|
|
819
|
+
|
|
820
|
+
_run_demo_video(solution)
|
|
821
|
+
|
|
822
|
+
|
|
823
|
+
def _show_video_results():
|
|
824
|
+
"""Display video inference results."""
|
|
825
|
+
if "video_results" not in st.session_state:
|
|
826
|
+
st.info("Click **Run Inference** to see results.")
|
|
827
|
+
return
|
|
828
|
+
|
|
829
|
+
res = st.session_state["video_results"]
|
|
830
|
+
|
|
831
|
+
# Metrics
|
|
832
|
+
c1, c2, c3, c4 = st.columns(4)
|
|
833
|
+
with c1:
|
|
834
|
+
st.metric("Frames Processed", res.get("frames", 0))
|
|
835
|
+
with c2:
|
|
836
|
+
st.metric("Total Detections", res.get("detections", 0))
|
|
837
|
+
with c3:
|
|
838
|
+
st.metric("Processing FPS", f"{res.get('processing_fps', 0):.1f}")
|
|
839
|
+
with c4:
|
|
840
|
+
elapsed = res.get("elapsed_time", 0)
|
|
841
|
+
st.metric("Total Time", f"{elapsed:.1f}s")
|
|
842
|
+
|
|
843
|
+
# Solution output
|
|
844
|
+
sol = res.get("solution_output", {})
|
|
845
|
+
if sol:
|
|
846
|
+
st.divider()
|
|
847
|
+
st.markdown("#### Solution Results")
|
|
848
|
+
|
|
849
|
+
if "in_count" in sol:
|
|
850
|
+
c1, c2, c3 = st.columns(3)
|
|
851
|
+
with c1:
|
|
852
|
+
st.metric("In Count", sol["in_count"])
|
|
853
|
+
with c2:
|
|
854
|
+
st.metric("Out Count", sol["out_count"])
|
|
855
|
+
with c3:
|
|
856
|
+
st.metric("Net Flow", sol["in_count"] - sol["out_count"])
|
|
857
|
+
|
|
858
|
+
if "region_counts" in sol:
|
|
859
|
+
cols = st.columns(len(sol["region_counts"]))
|
|
860
|
+
for i, (region, count) in enumerate(sol["region_counts"].items()):
|
|
861
|
+
with cols[i]:
|
|
862
|
+
st.metric(f"Zone: {region}", count)
|
|
863
|
+
|
|
864
|
+
if "speeds" in sol:
|
|
865
|
+
import pandas as pd
|
|
866
|
+
df = pd.DataFrame(sol["speeds"], columns=["Track ID", "Speed (km/h)"])
|
|
867
|
+
st.dataframe(df, use_container_width=True, hide_index=True)
|
|
868
|
+
|
|
869
|
+
# Sample frames
|
|
870
|
+
if res.get("frames_preview"):
|
|
871
|
+
st.divider()
|
|
872
|
+
st.markdown("#### Sample Frames")
|
|
873
|
+
n_frames = len(res["frames_preview"])
|
|
874
|
+
cols = st.columns(min(n_frames, 4))
|
|
875
|
+
for i, frame in enumerate(res["frames_preview"][:4]):
|
|
876
|
+
with cols[i]:
|
|
877
|
+
st.image(frame, caption=f"Frame {i+1}", width=None)
|
|
878
|
+
|
|
879
|
+
if res.get("output_path"):
|
|
880
|
+
st.success(f"Output saved: `{res['output_path']}`")
|
|
881
|
+
|
|
882
|
+
# Export
|
|
883
|
+
st.divider()
|
|
884
|
+
st.markdown("#### Export")
|
|
885
|
+
col_e1, col_e2 = st.columns(2)
|
|
886
|
+
with col_e1:
|
|
887
|
+
summary = {
|
|
888
|
+
"solution": st.session_state.get("selected_solution"),
|
|
889
|
+
"frames_processed": res.get("frames"),
|
|
890
|
+
"total_detections": res.get("detections"),
|
|
891
|
+
"processing_fps": res.get("processing_fps"),
|
|
892
|
+
"solution_output": sol,
|
|
893
|
+
}
|
|
894
|
+
st.download_button("📥 Results JSON", json.dumps(summary, indent=2),
|
|
895
|
+
"video_results.json", "application/json", use_container_width=True)
|
|
896
|
+
|
|
897
|
+
|
|
898
|
+
def _run_demo_video(solution_name: str):
|
|
899
|
+
"""Demo video processing with realistic simulation."""
|
|
900
|
+
t0 = time.perf_counter()
|
|
901
|
+
progress = st.progress(0, text="Processing video...")
|
|
902
|
+
frames_preview = []
|
|
903
|
+
|
|
904
|
+
for i in range(4):
|
|
905
|
+
time.sleep(0.3)
|
|
906
|
+
progress.progress((i + 1) / 4, text=f"Frame {(i+1)*75}/300")
|
|
907
|
+
img = Image.new("RGB", (640, 480), color=(248, 248, 252))
|
|
908
|
+
draw = ImageDraw.Draw(img)
|
|
909
|
+
draw.rectangle([80 + i * 40, 80 + i * 30, 250 + i * 40, 230 + i * 30],
|
|
910
|
+
outline="#7C3AED", width=3)
|
|
911
|
+
draw.text((90 + i * 40, 65 + i * 30), f"car 0.{85 + i}", fill="#1A1A2E")
|
|
912
|
+
|
|
913
|
+
# Draw zone overlay on preview frames
|
|
914
|
+
line_pts = st.session_state.get("zone_line_points")
|
|
915
|
+
if line_pts:
|
|
916
|
+
sx = 640 / 700
|
|
917
|
+
draw.line([(int(line_pts[0][0]), int(line_pts[0][1])),
|
|
918
|
+
(int(line_pts[1][0]), int(line_pts[1][1]))],
|
|
919
|
+
fill="#10B981", width=2)
|
|
920
|
+
|
|
921
|
+
frames_preview.append(img)
|
|
922
|
+
|
|
923
|
+
progress.empty()
|
|
924
|
+
elapsed = time.perf_counter() - t0
|
|
925
|
+
|
|
926
|
+
sol_output = {}
|
|
927
|
+
if "Counter" in solution_name:
|
|
928
|
+
sol_output = {"in_count": 12, "out_count": 8}
|
|
929
|
+
elif "Region" in solution_name:
|
|
930
|
+
sol_output = {"region_counts": {"Zone A": 5, "Zone B": 3}}
|
|
931
|
+
elif "Speed" in solution_name:
|
|
932
|
+
sol_output = {"speeds": [[1, 45.2], [2, 38.7], [3, 52.1]]}
|
|
933
|
+
elif "Security" in solution_name:
|
|
934
|
+
sol_output = {"alerts": 2, "alert_frames": [45, 178]}
|
|
935
|
+
elif "Queue" in solution_name:
|
|
936
|
+
sol_output = {"region_counts": {"Queue 1": 7, "Queue 2": 4}}
|
|
937
|
+
|
|
938
|
+
st.session_state["video_results"] = {
|
|
939
|
+
"frames": 300,
|
|
940
|
+
"detections": 1247,
|
|
941
|
+
"processing_fps": 300 / elapsed if elapsed > 0 else 0,
|
|
942
|
+
"elapsed_time": elapsed,
|
|
943
|
+
"solution_output": sol_output,
|
|
944
|
+
"frames_preview": frames_preview,
|
|
945
|
+
"output_path": "/content/output.mp4" if st.session_state.get("infer_save_video") else None,
|
|
946
|
+
}
|
|
947
|
+
st.rerun()
|
|
948
|
+
|
|
949
|
+
|
|
950
|
+
def _run_flashdet_video(video_file, solution_name: str):
|
|
951
|
+
"""Run real FlashDet video processing with zone integration."""
|
|
952
|
+
import cv2
|
|
953
|
+
from flashdet import FlashDetPredictor
|
|
954
|
+
|
|
955
|
+
tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
|
|
956
|
+
tmp.write(video_file.read())
|
|
957
|
+
tmp.flush()
|
|
958
|
+
|
|
959
|
+
cap = cv2.VideoCapture(tmp.name)
|
|
960
|
+
total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
|
961
|
+
fps = cap.get(cv2.CAP_PROP_FPS) or 30
|
|
962
|
+
|
|
963
|
+
predictor = FlashDetPredictor(
|
|
964
|
+
conf_threshold=st.session_state.get("infer_conf", 0.25),
|
|
965
|
+
nms_threshold=st.session_state.get("infer_nms", 0.45),
|
|
966
|
+
)
|
|
967
|
+
|
|
968
|
+
max_frames = st.session_state.get("infer_max_frames", 300) or total
|
|
969
|
+
skip = st.session_state.get("infer_frame_skip", 1)
|
|
970
|
+
class_filter = st.session_state.get("infer_class_filter", [])
|
|
971
|
+
|
|
972
|
+
frames_preview = []
|
|
973
|
+
processed = 0
|
|
974
|
+
dets = 0
|
|
975
|
+
t0 = time.perf_counter()
|
|
976
|
+
progress = st.progress(0, text="Processing video...")
|
|
977
|
+
|
|
978
|
+
for idx in range(min(total, max_frames)):
|
|
979
|
+
ret, frame = cap.read()
|
|
980
|
+
if not ret:
|
|
981
|
+
break
|
|
982
|
+
if idx % skip != 0:
|
|
983
|
+
continue
|
|
984
|
+
results = predictor.predict(frame)
|
|
985
|
+
if class_filter:
|
|
986
|
+
results = [r for r in results if r.get("class_name") in class_filter]
|
|
987
|
+
dets += len(results)
|
|
988
|
+
processed += 1
|
|
989
|
+
if idx % max(total // 4, 1) == 0:
|
|
990
|
+
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
|
991
|
+
frames_preview.append(Image.fromarray(rgb))
|
|
992
|
+
progress.progress(idx / min(total, max_frames), text=f"Frame {idx}/{min(total, max_frames)}")
|
|
993
|
+
|
|
994
|
+
cap.release()
|
|
995
|
+
progress.empty()
|
|
996
|
+
elapsed = time.perf_counter() - t0
|
|
997
|
+
|
|
998
|
+
st.session_state["video_results"] = {
|
|
999
|
+
"frames": processed,
|
|
1000
|
+
"detections": dets,
|
|
1001
|
+
"processing_fps": processed / elapsed if elapsed > 0 else 0,
|
|
1002
|
+
"elapsed_time": elapsed,
|
|
1003
|
+
"solution_output": {},
|
|
1004
|
+
"frames_preview": frames_preview,
|
|
1005
|
+
"output_path": "/content/output.mp4" if st.session_state.get("infer_save_video") else None,
|
|
1006
|
+
}
|
|
1007
|
+
st.rerun()
|
|
1008
|
+
|
|
1009
|
+
|
|
1010
|
+
# ════════════════════════════════════════════════════════════════
|
|
1011
|
+
# DETECTION UTILITIES
|
|
1012
|
+
# ════════════════════════════════════════════════════════════════
|
|
1013
|
+
|
|
1014
|
+
def _detect(image: Image.Image) -> list:
|
|
1015
|
+
"""Run detection (real or demo)."""
|
|
1016
|
+
if importlib.util.find_spec("flashdet"):
|
|
1017
|
+
try:
|
|
1018
|
+
return _detect_real(image)
|
|
1019
|
+
except Exception:
|
|
1020
|
+
pass
|
|
1021
|
+
return _detect_demo(image)
|
|
1022
|
+
|
|
1023
|
+
|
|
1024
|
+
def _detect_real(image: Image.Image) -> list:
|
|
1025
|
+
"""Real FlashDet detection."""
|
|
1026
|
+
from flashdet import FlashDetPredictor
|
|
1027
|
+
weights = st.session_state.get("infer_weights_path", "")
|
|
1028
|
+
arch_map = {
|
|
1029
|
+
"FlashDet-Pico": "flashdet_pico", "FlashDet-Nano": "flashdet_nano",
|
|
1030
|
+
"FlashDet-Small": "flashdet_small", "FlashDet-Medium": "flashdet_medium",
|
|
1031
|
+
"FlashDet-Large": "flashdet_large",
|
|
1032
|
+
}
|
|
1033
|
+
model = st.session_state.get("infer_model_arch", "FlashDet-Nano")
|
|
1034
|
+
predictor = FlashDetPredictor(
|
|
1035
|
+
model_name=arch_map.get(model, "flashdet_nano"),
|
|
1036
|
+
weights=weights,
|
|
1037
|
+
conf_threshold=st.session_state.get("infer_conf", 0.25),
|
|
1038
|
+
nms_threshold=st.session_state.get("infer_nms", 0.45),
|
|
1039
|
+
)
|
|
1040
|
+
results = predictor.predict(np.array(image))
|
|
1041
|
+
class_filter = st.session_state.get("infer_class_filter", [])
|
|
1042
|
+
dets = []
|
|
1043
|
+
for d in results:
|
|
1044
|
+
cid = int(d["class_id"])
|
|
1045
|
+
name = COCO_CLASSES[cid] if cid < len(COCO_CLASSES) else f"class_{cid}"
|
|
1046
|
+
if class_filter and name not in class_filter:
|
|
1047
|
+
continue
|
|
1048
|
+
dets.append([name, f"{float(d['confidence']):.2f}", str([int(x) for x in d["bbox"]])])
|
|
1049
|
+
return dets
|
|
1050
|
+
|
|
1051
|
+
|
|
1052
|
+
def _detect_demo(image: Image.Image) -> list:
|
|
1053
|
+
"""Demo detection with class filtering."""
|
|
1054
|
+
w, h = image.size
|
|
1055
|
+
rng = np.random.default_rng(hash(image.tobytes()[:100]) % (2**31))
|
|
1056
|
+
conf_thr = st.session_state.get("infer_conf", 0.25)
|
|
1057
|
+
class_filter = st.session_state.get("infer_class_filter", [])
|
|
1058
|
+
dets = []
|
|
1059
|
+
for _ in range(rng.integers(3, 8)):
|
|
1060
|
+
cid = rng.integers(0, 10)
|
|
1061
|
+
name = COCO_CLASSES[cid]
|
|
1062
|
+
if class_filter and name not in class_filter:
|
|
1063
|
+
continue
|
|
1064
|
+
conf = rng.uniform(0.3, 0.98)
|
|
1065
|
+
if conf < conf_thr:
|
|
1066
|
+
continue
|
|
1067
|
+
x1 = int(rng.integers(0, max(w - 100, 1)))
|
|
1068
|
+
y1 = int(rng.integers(0, max(h - 100, 1)))
|
|
1069
|
+
x2 = min(x1 + int(rng.integers(60, 200)), w)
|
|
1070
|
+
y2 = min(y1 + int(rng.integers(60, 200)), h)
|
|
1071
|
+
dets.append([name, f"{conf:.2f}", str([x1, y1, x2, y2])])
|
|
1072
|
+
return sorted(dets, key=lambda x: float(x[1]), reverse=True)
|
|
1073
|
+
|
|
1074
|
+
|
|
1075
|
+
def _draw_boxes(image: Image.Image, dets: list) -> Image.Image:
|
|
1076
|
+
"""Draw bounding boxes with labels."""
|
|
1077
|
+
draw = ImageDraw.Draw(image)
|
|
1078
|
+
try:
|
|
1079
|
+
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 14)
|
|
1080
|
+
except (OSError, IOError):
|
|
1081
|
+
font = ImageFont.load_default()
|
|
1082
|
+
|
|
1083
|
+
for i, (cls, conf, bbox_str) in enumerate(dets):
|
|
1084
|
+
bbox = json.loads(bbox_str.replace("'", '"'))
|
|
1085
|
+
color = COLORS[i % len(COLORS)]
|
|
1086
|
+
draw.rectangle(bbox, outline=color, width=3)
|
|
1087
|
+
label = f"{cls} {conf}"
|
|
1088
|
+
tb = draw.textbbox((0, 0), label, font=font)
|
|
1089
|
+
tw, th = tb[2] - tb[0], tb[3] - tb[1]
|
|
1090
|
+
ly = max(bbox[1] - th - 6, 0)
|
|
1091
|
+
draw.rectangle([bbox[0], ly, bbox[0] + tw + 8, ly + th + 6], fill=color)
|
|
1092
|
+
draw.text((bbox[0] + 4, ly + 2), label, fill="white", font=font)
|
|
1093
|
+
return image
|
|
1094
|
+
|
|
1095
|
+
|
|
1096
|
+
def _parse_polygon() -> list:
|
|
1097
|
+
"""Parse polygon points from session state."""
|
|
1098
|
+
polygons = st.session_state.get("zone_polygons")
|
|
1099
|
+
if polygons and len(polygons) > 0:
|
|
1100
|
+
return polygons[-1]
|
|
1101
|
+
|
|
1102
|
+
text = st.session_state.get("polygon_points", "100,100\n500,100\n500,400\n100,400")
|
|
1103
|
+
points = []
|
|
1104
|
+
for line in text.strip().split("\n"):
|
|
1105
|
+
line = line.strip()
|
|
1106
|
+
if line and "," in line:
|
|
1107
|
+
parts = line.split(",")
|
|
1108
|
+
try:
|
|
1109
|
+
points.append((int(parts[0].strip()), int(parts[1].strip())))
|
|
1110
|
+
except (ValueError, IndexError):
|
|
1111
|
+
continue
|
|
1112
|
+
return points if len(points) >= 3 else [(100, 100), (500, 100), (500, 400), (100, 400)]
|