xttmp 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xttmp/__init__.py +1 -0
- xttmp/api/__init__.py +5 -0
- xttmp/api/evaluate.py +163 -0
- xttmp/api/get_visualize_handle.py +29 -0
- xttmp/api/instancing_model.py +35 -0
- xttmp/core/__init__.py +0 -0
- xttmp/core/apgstmd_core.py +188 -0
- xttmp/core/apgstmdv2_core.py +79 -0
- xttmp/core/base_core.py +36 -0
- xttmp/core/dstmd_core.py +213 -0
- xttmp/core/estmd_backbone.py +110 -0
- xttmp/core/estmd_core.py +356 -0
- xttmp/core/feedbackstmd_core.py +61 -0
- xttmp/core/fracstmd_core.py +98 -0
- xttmp/core/fstmd_core.py +15 -0
- xttmp/core/fstmdv2_core.py +42 -0
- xttmp/core/haarstmd_core.py +140 -0
- xttmp/core/math_operator.py +307 -0
- xttmp/core/stfeedbackstmd_core.py +233 -0
- xttmp/core/stmdplus_core.py +187 -0
- xttmp/core/stmdplusv2_core.py +82 -0
- xttmp/core/vstmd_core.py +420 -0
- xttmp/demo/evaluate_model.py +92 -0
- xttmp/demo/inference_gui.py +148 -0
- xttmp/demo/inference_gui_single_process.py +134 -0
- xttmp/demo/inference_image_stream.py +67 -0
- xttmp/demo/inference_video.py +66 -0
- xttmp/main.py +14 -0
- xttmp/model/__init__.py +13 -0
- xttmp/model/backbone.py +514 -0
- xttmp/model/facilitated_model.py +230 -0
- xttmp/model/feedback_model.py +271 -0
- xttmp/model/haarstmd.py +61 -0
- xttmp/model/vstmd.py +457 -0
- xttmp/util/__init__.py +0 -0
- xttmp/util/compute_module.py +402 -0
- xttmp/util/create_kernel.py +363 -0
- xttmp/util/evaluate_module.py +697 -0
- xttmp/util/iostream.py +660 -0
- xttmp-2.3.0.dist-info/METADATA +85 -0
- xttmp-2.3.0.dist-info/RECORD +45 -0
- xttmp-2.3.0.dist-info/WHEEL +5 -0
- xttmp-2.3.0.dist-info/entry_points.txt +2 -0
- xttmp-2.3.0.dist-info/licenses/LICENSE +201 -0
- xttmp-2.3.0.dist-info/top_level.txt +1 -0
xttmp/util/iostream.py
ADDED
|
@@ -0,0 +1,660 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
import logging
|
|
5
|
+
from typing import Optional, List, Union, Tuple, Any
|
|
6
|
+
|
|
7
|
+
import cv2
|
|
8
|
+
import numpy as np
|
|
9
|
+
import tkinter as tk
|
|
10
|
+
from tkinter import ttk, filedialog, messagebox
|
|
11
|
+
import torch
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
from .. import model
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Absolute, symlink-resolved path of this module.
filePath = os.path.realpath(__file__)
# Strip four trailing components (iostream.py, util, xttmp and one more level)
# to reach the directory expected to contain the demo data. In a source
# checkout this is presumably the repository root — confirm for installed
# wheels, where it resolves above site-packages.
gitCodePath = filePath
for _level in range(4):
    gitCodePath = os.path.dirname(gitCodePath)
del _level

# Initial directories offered by the video / image-sequence file dialogs.
VID_DEFAULT_FOLDER = os.path.join(gitCodePath, 'demodata')
IMG_DEFAULT_FOLDER = os.path.join(gitCodePath, 'demodata', 'imgstream')

# Model names offered by the GUI, as exported by the sibling `model` package.
ALL_MODEL = model.__all__

logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
class FrameIterator:
    """
    A flexible iterator class that can retrieve images frame-by-frame from a video file
    or from a sequence of numerically sorted image files.

    Each step yields the BGR colour frame and a normalised grayscale PyTorch
    tensor. The object is also a context manager that releases its resources
    on exit.
    """

    def __init__(self, input_path: str, is_video: bool = True, is_silence: bool = True, device: str = 'cpu'):
        """
        Initialize the iterator and open the input source.

        Parameters:
            input_path (str):
                - If is_video is True: full path to the video file.
                - If is_video is False: path to folder containing image sequence.
            is_video (bool): Specifies whether input is a video or image sequence.
            is_silence (bool): If True, suppresses INFO-level messages
                (warnings and errors are always logged).
            device (str): Device the grayscale tensors are moved to ('cpu', 'cuda', etc.).

        Failures to open the source are logged, not raised; check ``is_open``.
        """
        self.input_path = input_path
        self.is_video = is_video
        self.is_silence = is_silence
        self.device = device  # kept on the instance so every frame tensor uses it

        self.current_index = 0
        self.total_frames = 0
        self.is_open = False

        self.img_height, self.img_width = None, None
        # Only a successfully opened video reports a frame rate; default it so
        # the attribute also exists for image sequences and failed opens.
        self.fps = None
        self.cap = None
        self.image_files: List[str] = []

        if self.is_video:
            self._init_video_source()
        else:
            self._init_image_sequence_source()

    def _log(self, message: str, level: int = logging.INFO):
        """Log ``message``; below WARNING it is dropped when ``is_silence`` is set."""
        if not self.is_silence or level >= logging.WARNING:
            logger.log(level, message)

    def _setup(self, current_index: int):
        """Jump to a specific frame index (ignored with a warning if out of range)."""
        if current_index < 0 or (self.total_frames > 0 and current_index >= self.total_frames):
            logger.warning(f"Index {current_index} is out of bounds (0 - {self.total_frames-1}).")
            return

        self.current_index = current_index
        if self.is_video and self.cap:
            success = self.cap.set(cv2.CAP_PROP_POS_FRAMES, current_index)
            if not success:
                logger.warning(f"Unable to set video frame position to {current_index}.")

    # --- Video processing logic ---
    def _init_video_source(self):
        """Open ``input_path`` with OpenCV and read frame count, fps and size."""
        if not os.path.isfile(self.input_path):
            logger.error(f"Video file not found: {self.input_path}")
            return

        self.cap = cv2.VideoCapture(self.input_path)
        if not self.cap.isOpened():
            logger.error(f"Unable to open video file: {self.input_path}")
            # release() skips the handle while is_open is False, so free it here.
            self.cap.release()
            self.cap = None
            return

        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        self.fps = self.cap.get(cv2.CAP_PROP_FPS)
        self.img_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.img_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.is_open = True

        self._log(f"Successfully opened video file. Total frames: {self.total_frames}")

    def _get_next_frame_from_video(self) -> Optional[np.ndarray]:
        """Read the next video frame; on end-of-stream or error release and return None."""
        if not self.is_open or not self.cap:
            return None

        ret, frame = self.cap.read()
        if ret:
            self.current_index += 1
            return frame

        self.release()  # Video reading completed or error occurred
        return None

    # --- Image sequence processing logic ---
    def _init_image_sequence_source(self):
        """Collect the naturally sorted image files in ``input_path``."""
        if not os.path.isdir(self.input_path):
            logger.error(f"Folder not found: {self.input_path}")
            return

        self.image_files = self._get_sorted_image_files(self.input_path)

        if not self.image_files:
            logger.error(f"No image files found in folder: {self.input_path}")
            return

        self.total_frames = len(self.image_files)
        self.is_open = True
        self._log(f"Successfully loaded image sequence. Total images: {self.total_frames}")

        # Read first image to get dimensions
        first_image = cv2.imread(self.image_files[0], cv2.IMREAD_COLOR)
        if first_image is not None:
            self.img_height, self.img_width = first_image.shape[:2]

    def _get_next_frame_from_sequence(self) -> Optional[np.ndarray]:
        """Return the next decodable image, skipping unreadable files; None at the end."""
        while self.is_open and self.current_index < self.total_frames:
            file_path = self.image_files[self.current_index]

            # Always advance, even if decoding fails; otherwise a corrupt
            # image would make this loop spin forever.
            self.current_index += 1

            frame = cv2.imread(file_path, cv2.IMREAD_COLOR)
            if frame is not None:
                return frame
            else:
                logger.warning(f"Unable to read or decode image file: {file_path}")

        self.release()
        return None

    # --- Core interfaces ---
    def get_next_frame(self) -> Tuple[Optional[np.ndarray], Optional[torch.Tensor], bool]:
        """
        [Public interface] Get next image frame and its grayscale PyTorch tensor.

        Returns:
            Tuple[color_img, gray_tensor, is_valid]:
                - color_img: BGR image (NumPy array) or None
                - gray_tensor: Grayscale float32 tensor of shape (1, 1, H, W)
                  divided by 255.0, on ``self.device``, or None
                - is_valid: Boolean indicating if retrieval was successful
        """
        color_img = self._get_next_frame_from_video() if self.is_video else self._get_next_frame_from_sequence()

        if color_img is None:
            return None, None, False

        gray_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2GRAY)
        # (H, W) -> (1, 1, H, W), built directly without intermediate variables
        gray_tensor = torch.from_numpy(gray_img).to(device=self.device, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0

        return color_img, gray_tensor, True

    # --- Iterator & Context Manager Protocols ---
    def __iter__(self):
        return self

    def __next__(self) -> Tuple[np.ndarray, torch.Tensor]:
        color_img, gray_tensor, is_valid = self.get_next_frame()
        if not is_valid:
            raise StopIteration
        return color_img, gray_tensor

    def __enter__(self):
        """Enable context manager: `with FrameIterator(...) as it:`"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.release()

    def release(self):
        """
        Release the video capture (if any) and mark the source as closed.

        Safe to call repeatedly and on partially constructed objects
        (e.g. from ``__del__`` after ``__init__`` failed).
        """
        if not getattr(self, 'is_open', False):
            return
        cap = getattr(self, 'cap', None)
        if self.is_video and cap is not None:
            cap.release()
        self.is_open = False
        self._log("Resources released.")

    def __del__(self):
        self.release()

    # --- Helpers ---
    @staticmethod
    def _natural_sort_key(s: str) -> List[Union[str, int]]:
        """
        Split ``s`` into alternating text / integer chunks for natural ordering.

        ``isdecimal`` matches exactly what ``\\d`` captures, so characters like
        superscript digits stay text instead of crashing ``int()``.
        """
        return [int(text) if text.isdecimal() else text.lower() for text in re.split(r'(\d+)', s)]

    def _get_sorted_image_files(self, folder_path: str) -> List[str]:
        """Return image files directly inside ``folder_path``, naturally sorted by file name."""
        valid_exts = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.tif'}
        path = Path(folder_path)

        # A single iterdir() pass is cheaper than one glob per extension.
        files = [
            str(f) for f in path.iterdir()
            if f.is_file() and f.suffix.lower() in valid_exts
        ]

        return sorted(files, key=lambda f: self._natural_sort_key(Path(f).name))
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
class FrameVisualizer:
    """
    Draw model responses on frames, optionally record them, and show them
    in an OpenCV window.

    Keyboard controls while the window is shown:
        Esc / q : stop (``update`` returns False)
        Space   : toggle pause
        n       : show one more frame and stay paused
    """

    def __init__(self, window_name="Visualizer",
                 result_index_type="matrix",
                 win_width=None, win_height=None,
                 is_headless=False,
                 conf_threshold=0.8  # visualisation threshold
                 ):
        """
        Initialize the visualizer.

        :param window_name: Title of the OpenCV window.
        :param result_index_type: Layout of ``result`` given to ``update``:
            "matrix" (2-D response map), "dots" (rows [x, y, score, (dir)])
            or "bbox" (rows [x1, y1, x2, y2, score, (dir)]).
        :param win_width: Window width in pixels (800 when falsy).
        :param win_height: Window height in pixels (600 when falsy).
        :param is_headless: If True, no window is created or shown.
        :param conf_threshold: Values/scores must be strictly greater than this
            to be drawn. Described as a relative threshold (0.0 ~ 1.0) but it is
            compared directly against the raw values, so it presumably assumes
            normalised responses — confirm.
        """
        self.window_name = window_name
        self.result_index_type = result_index_type  # "matrix", "dots", "bbox"
        self.win_width = win_width or 800
        self.win_height = win_height or 600
        self.is_headless = is_headless
        self.conf_threshold = conf_threshold

        self.save_output = False
        self.video_writer = None  # created by setup_video_writer()
        self._writer_size = None  # (width, height) the writer was opened with
        self.paused = False

        self._setup_window()

    def _setup_window(self):
        """Create and size the display window (skipped in headless mode)."""
        if self.is_headless:
            return
        cv2.namedWindow(self.window_name, cv2.WINDOW_GUI_NORMAL)
        cv2.resizeWindow(self.window_name, self.win_width, self.win_height)

    def setup_video_writer(self, output_path, fps=30, width=None, height=None):
        """
        Open a video writer that records every frame passed to ``update``
        (intended to be called explicitly by the user).

        :param output_path: Destination file; missing parent folders are created.
        :param fps: Frame rate of the output video.
        :param width: Output width (defaults to the window width).
        :param height: Output height (defaults to the window height).

        Frames of a different size are resized before writing, because
        OpenCV's VideoWriter expects frames matching the size it was opened
        with (mismatched frames are typically dropped silently).
        """
        output_dir = os.path.dirname(output_path)
        if output_dir:
            # exist_ok avoids failing if the folder appears concurrently
            os.makedirs(output_dir, exist_ok=True)

        width = width if width is not None else self.win_width
        height = height if height is not None else self.win_height

        # Finalise a writer left over from an earlier call.
        if self.video_writer is not None:
            self.video_writer.release()
            self.video_writer = None

        # mp4v is a widely compatible codec
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not writer.isOpened():
            writer.release()
            self.save_output = False
            self._writer_size = None
            print(f">>> Failed to open video writer: {output_path}")
            return

        self.video_writer = writer
        self._writer_size = (int(width), int(height))
        self.save_output = True
        print(f">>> Video writer initialized: {output_path}")

    def update(self, frame, result=None, direction=None, annotation=None, process_time=None) -> bool:
        """
        Draw ``result`` on ``frame``, record it and show it.

        :param frame: Image array; annotations are drawn on it in place.
        :param result: Model output in the layout of ``result_index_type``.
            NumPy arrays, sequences and PyTorch tensors (any device, with or
            without grad) are accepted.
        :param direction: Direction map (radians) used with "matrix" results.
        :param annotation: Per-row labels used with "bbox" results.
        :param process_time: Processing time in seconds (displayed in ms).
        :return: False if ``frame`` is None or the user pressed Esc/q, else True.
        """
        if frame is None:
            return False

        # --- Drawing ---
        # An empty result is still processed; only None is skipped.
        if result is not None:
            result = self._to_numpy(result)
            if self.result_index_type == "matrix":
                self._draw_matrix(frame, result, self._to_numpy(direction), self.conf_threshold)
            elif self.result_index_type == "dots":
                self._draw_dots(frame, result, self.conf_threshold)
            elif self.result_index_type == "bbox":
                self._draw_bbox(frame, result, self.conf_threshold, annotation)

        # --- Info overlay ---
        if process_time is not None:
            cv2.putText(frame, f'Time: {process_time*1000:.1f} ms',
                        (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8,
                        (0, 255, 0), 2, cv2.LINE_AA)

        # --- Recording ---
        if self.save_output and self.video_writer is not None:
            self.video_writer.write(self._fit_to_writer(frame))

        # --- Headless: nothing to show ---
        if self.is_headless:
            return True

        # --- Display and keyboard control ---
        cv2.imshow(self.window_name, frame)

        while True:
            key = cv2.waitKey(1) & 0xFF

            if key == 27 or key == ord('q'):  # Esc or q: stop
                return False

            if key == 32:  # Space: toggle pause
                self.paused = not self.paused
                print(f">>> State: {'Paused' if self.paused else 'Running'}")
                if not self.paused:
                    break
                continue

            if key == ord('n'):  # Next: advance one frame, then pause again
                self.paused = True
                break

            if not self.paused:
                break

        return True

    def close(self):
        """Finalise any recording and destroy the display window."""
        # Release the writer first so the output file is finalised even if
        # destroying the window fails.
        if self.video_writer is not None:
            self.video_writer.release()
            self.video_writer = None
        self.save_output = False
        if not self.is_headless:
            cv2.destroyWindow(self.window_name)

    def _fit_to_writer(self, frame):
        """Return ``frame`` resized to the writer's (width, height) when it differs."""
        if self._writer_size is None:
            return frame
        width, height = self._writer_size
        if frame.shape[1] == width and frame.shape[0] == height:
            return frame
        return cv2.resize(frame, (width, height))

    @staticmethod
    def _to_numpy(data):
        """Convert a tensor or sequence to a NumPy array; None passes through."""
        if data is None:
            return None
        if isinstance(data, torch.Tensor):
            # detach() so tensors that require grad can be converted too
            return data.detach().cpu().numpy()
        return np.asarray(data)

    @staticmethod
    def _draw_arrows(frame, x_coords, y_coords, directions, length=15):
        """
        Draw a red arrow of ``length`` pixels at each point.

        :param x_coords: X coordinate array (columns)
        :param y_coords: Y coordinate array (rows)
        :param directions: Direction angle array (radians)
        """
        if len(x_coords) == 0: return

        cos_d = np.cos(directions)
        sin_d = np.sin(directions)

        for x, y, c, s in zip(x_coords, y_coords, cos_d, sin_d):
            # OpenCV needs integer pixel coordinates
            start_pt = (int(x), int(y))
            # y is subtracted because image rows grow downwards
            end_pt = (int(x + length * c), int(y - length * s))

            cv2.arrowedLine(frame, start_pt, end_pt,
                            color=(0, 0, 255), thickness=1,
                            tipLength=0.3, line_type=cv2.LINE_AA)

    @staticmethod
    def _draw_matrix(frame, matrix, direction_map, threshold):
        """Draw a 2-D response map: a star at every value above ``threshold``,
        plus an arrow where ``direction_map`` is given and not NaN."""
        matrix = np.asarray(matrix)
        if matrix.size == 0 or np.max(matrix) <= 0:
            return

        # np.where returns (rows, cols), i.e. (y, x)
        rows, cols = np.where(matrix > threshold)

        for r, c in zip(rows, cols):
            # OpenCV points are (x, y) = (col, row); cast NumPy ints to Python ints
            cv2.drawMarker(frame, (int(c), int(r)), color=(0, 0, 255),
                           markerType=cv2.MARKER_STAR, markerSize=5, thickness=1)

        if direction_map is not None and len(rows) > 0:
            # assumes direction_map has the same shape as matrix — TODO confirm
            valid_dirs = np.asarray(direction_map)[rows, cols]

            # skip NaN directions
            valid_mask = ~np.isnan(valid_dirs)
            if np.any(valid_mask):
                # _draw_arrows takes (x, y) = (cols, rows)
                FrameVisualizer._draw_arrows(frame,
                                             x_coords=cols[valid_mask],
                                             y_coords=rows[valid_mask],
                                             directions=valid_dirs[valid_mask])

    @staticmethod
    def _draw_dots(frame, response, threshold):
        """Draw point detections given as rows [x, y, score, (dir)]."""
        response = np.asarray(response)
        if response.size == 0: return

        # Column 0 = x, column 1 = y, column 2 = score
        scores = response[:, 2]
        filtered = response[scores > threshold]

        if len(filtered) == 0: return

        xs = filtered[:, 0]
        ys = filtered[:, 1]

        for x, y in zip(xs, ys):
            cv2.drawMarker(frame, (int(x), int(y)), color=(0, 0, 255),
                           markerType=cv2.MARKER_STAR, markerSize=5, thickness=1)

        # Column 3, when present, holds the direction
        if response.shape[1] > 3:
            dirs = filtered[:, 3]
            valid_mask = ~np.isnan(dirs)
            FrameVisualizer._draw_arrows(frame,
                                         x_coords=xs[valid_mask],
                                         y_coords=ys[valid_mask],
                                         directions=dirs[valid_mask])

    @staticmethod
    def _draw_bbox(frame, response, threshold, annotation=None):
        """Draw boxes given as rows [x1, y1, x2, y2, score, (dir)], optionally labelled."""
        # Convert before checking emptiness: on a tensor `.size` is a method.
        response = FrameVisualizer._to_numpy(response)
        if response is None or response.size == 0: return

        # Filter by score first so later conversions touch less data
        mask = response[:, 4] > threshold
        filtered_res = response[mask]
        if filtered_res.size == 0: return

        # Only the coordinate columns are converted to integers
        boxes = filtered_res[:, :4].astype(np.int32)

        # Column 5, when present, holds the direction
        has_dir = filtered_res.shape[1] > 5

        # tolist() yields plain Python ints, which OpenCV always accepts
        if annotation is not None:
            filtered_anno = np.asanyarray(annotation)[mask]
            for (x1, y1, x2, y2), anno in zip(boxes.tolist(), filtered_anno):
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 1, cv2.LINE_AA)
                cv2.putText(frame, str(anno), (x1, y1 - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
        else:
            for x1, y1, x2, y2 in boxes.tolist():
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 1, cv2.LINE_AA)

        # Direction arrows start at the (integer) box centres
        if has_dir:
            dirs = filtered_res[:, 5]
            v_mask = ~np.isnan(dirs)
            if np.any(v_mask):
                v_boxes = boxes[v_mask]
                c_xs = (v_boxes[:, 0] + v_boxes[:, 2]) // 2
                c_ys = (v_boxes[:, 1] + v_boxes[:, 3]) // 2
                FrameVisualizer._draw_arrows(frame, c_xs, c_ys, dirs[v_mask])
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
class ModelSelectorGUI:
    """Model-selection row (label + combobox) of the launcher window."""

    def __init__(self, root):
        self.root = root

    def create_gui(self, modelList, default_index=11):
        """
        Place the model label and combobox on row 0 of ``root``.

        :param modelList: Model names offered in the combobox.
        :param default_index: Entry pre-selected in the combobox. It is clamped
            to the list length, because ``Combobox.current`` errors on an
            out-of-range index.
        """
        self.modelLabel = ttk.Label(self.root, text="Select a model:", width = 15)
        self.modelLabel.grid(row=0, column=0, padx=10, pady=10)

        self.modelCombobox = ttk.Combobox(self.root, values=modelList, width = 30)
        if len(modelList) > 0:
            self.modelCombobox.current(max(0, min(default_index, len(modelList) - 1)))
        self.modelCombobox.grid(row=0, column=1, columnspan=2, padx=10, pady=10)
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
class InputSelectorGUI:
    """
    Input-selection part of the launcher window: either a video file or the
    start and end frames of an image sequence.
    """

    def __init__(self, root):
        self.root = root

        # Widgets currently shown for each input mode, keyed by role
        self.vidElement = {}
        self.imgElement = {}

        # Folder used as the dialog start and shown in the folder label
        self.imgSelectFolder = None

        self.inputType = None
        self.startFrame = None
        self.endFrame = None
        self.video_file = None
        # True only while start and end frames share a folder and extension
        self.check = False
        self.vidName = None
        self.startFolder = None
        self.startImgName = None
        self.endFolder = None
        self.endImgName = None

    def create_gui(self):
        """Create the input-type label and the video / image radio buttons on row 1."""
        self.inputTypeLabel = ttk.Label(self.root, text="Select input from:", width = 15)
        self.inputTypeLabel.grid(row=1, column=0, padx=10, pady=10)

        self.selectedOption = tk.IntVar(value=0)

        self.vidLabel = ttk.Radiobutton(self.root,
                                        text='Video stream',
                                        variable=self.selectedOption,
                                        value=1,
                                        command=self.select_vidstream)
        self.vidLabel.grid(row=1, column=2, padx=10, pady=10)

        self.imgLabel = ttk.Radiobutton(self.root,
                                        text='Image stream',
                                        variable=self.selectedOption,
                                        value=2,
                                        command=self.select_imgstream)
        self.imgLabel.grid(row=1, column=1, padx=10, pady=10)

    @staticmethod
    def _destroy_elements(elements):
        """Destroy every widget in ``elements`` and empty the dict."""
        for element in elements.values():
            element.destroy()
        elements.clear()

    @staticmethod
    def _ask_first_file(initial_dir):
        """Open a file dialog and return the first chosen path, or None if cancelled."""
        selection = filedialog.askopenfilenames(initialdir=initial_dir)
        # A cancelled dialog yields an empty result (empty tuple or string)
        if not selection:
            return None
        return selection[0]

    @staticmethod
    def _same_folder(folder_a, folder_b):
        """True if both paths name the same folder (full path, not just basename)."""
        return (os.path.normcase(os.path.abspath(folder_a))
                == os.path.normcase(os.path.abspath(folder_b)))

    def select_vidstream(self):
        """Switch to video mode: drop image-sequence state and show the video picker."""
        if self.vidElement:  # video widgets already shown; avoid stacking duplicates
            return
        self.imgSelectFolder = None
        self.startFolder = None
        self.startImgName = None
        self.endFolder = None
        self.endImgName = None
        self.check = False
        self._destroy_elements(self.imgElement)

        self.vidElement['lblVidIndicate'] = ttk.Label(self.root, text= 'Video\'s path:', width = 15)
        self.vidElement['lblVidIndicate'].grid(row=2, column=0, padx=10, pady=30)
        self.vidElement['lblVidPath'] = ttk.Label(self.root,
                                                  text="Waiting for the selection",
                                                  wraplength=220
                                                  )
        self.vidElement['lblVidPath'].grid(row=2, column=1, columnspan=2, padx=10, pady=10)

        self.vidElement['btn'] = ttk.Button(self.root, text="Select a video", command=self._clicked_vid)
        self.vidElement['btn'].grid(row=3, column=2, padx=10, pady=10)

    def _clicked_vid(self):
        """Ask for a video file; a cancelled dialog keeps the previous choice."""
        selected = self._ask_first_file(VID_DEFAULT_FOLDER)
        if selected is None:
            return
        self.vidName = selected
        self.vidElement['lblVidPath'].config(text=self.vidName)

    def select_imgstream(self):
        """Switch to image-sequence mode: drop video state and show the frame pickers."""
        if self.imgElement:  # image widgets already shown; avoid stacking duplicates
            return
        self.vidName = None
        self._destroy_elements(self.vidElement)

        self.imgElement['lblFolder'] = ttk.Label(self.root, text="Image's folder: ", width = 15)
        self.imgElement['lblFolder'].grid(row=2, column=0, padx=10, pady=10)
        self.imgElement['lblFolderName'] = ttk.Label(self.root, text="Waiting for the selection", wraplength=220)
        self.imgElement['lblFolderName'].grid(row=2, column=1, columnspan=2, padx=10, pady=30)

        self.imgElement['btnStart'] = ttk.Button(self.root, text="Select start frame", command=self._clicked_start_img)
        self.imgElement['btnStart'].grid(row=3, column=1, padx=10, pady=10)
        self.imgElement['btnEnd'] = ttk.Button(self.root, text="Select end frame", command=self._clicked_end_img)
        self.imgElement['btnEnd'].grid(row=4, column=1, padx=10, pady=10)

    def _validate_pair(self):
        """
        Recompute ``self.check`` for the current start/end selection.

        ``check`` becomes True only when both frames are chosen, lie in the
        same folder and share an extension; mismatches are reported in a
        message box. It is reset first so a later bad pick invalidates it.
        """
        self.check = False
        if self.startFolder is None or self.endFolder is None:
            return
        if not self._same_folder(self.startFolder, self.endFolder):
            messagebox.showinfo("Message title", "The image stream must be in the same folder!")
            return
        if self.startImgName is None or self.endImgName is None:
            return
        if check_same_ext_name(self.startImgName, self.endImgName):
            self.check = True
        else:
            messagebox.showinfo("Message title", "Start image has a different extension than end image.")

    def _show_image_name(self, key, name, row):
        """Show ``name`` in the label stored under ``key`` (column 2), creating it once."""
        label = self.imgElement.get(key)
        if label is None:
            label = ttk.Label(self.root, text=name)
            label.grid(row=row, column=2, padx=10, pady=10)
            self.imgElement[key] = label
        else:
            label.config(text=name)

    def _clicked_start_img(self):
        """Ask for the first frame of the sequence and re-validate the pair."""
        selected = self._ask_first_file(
            IMG_DEFAULT_FOLDER if self.imgSelectFolder is None else self.imgSelectFolder)
        if selected is None:
            return
        self.startFolder, self.startImgName = os.path.split(selected)
        self._validate_pair()

        self.imgSelectFolder = self.startFolder
        self.imgElement['lblFolderName'].config(text=self.imgSelectFolder)
        self._show_image_name('lblStartImg', self.startImgName, row=3)

    def _clicked_end_img(self):
        """Ask for the last frame of the sequence and re-validate the pair."""
        selected = self._ask_first_file(
            IMG_DEFAULT_FOLDER if self.imgSelectFolder is None else self.imgSelectFolder)
        if selected is None:
            return
        self.endFolder, self.endImgName = os.path.split(selected)
        self._validate_pair()

        self.imgSelectFolder = self.endFolder
        self.imgElement['lblFolderName'].config(text=self.imgSelectFolder)
        self._show_image_name('lblEndImg', self.endImgName, row=4)
|
|
578
|
+
|
|
579
|
+
|
|
580
|
+
class ModelAndInputSelectorGUI:
    """
    Launcher window combining the model selector and the input selector.

    ``create_gui`` blocks in the Tk main loop until the user confirms with
    "Run" / "Stepping" (or closes the window) and then returns the choice.
    """

    def __init__(self, root):
        self.root = root

        windowHeight = 350
        windowWidth = 400

        # Centre the window on the screen
        startHeight = (root.winfo_screenheight() - windowHeight) // 2
        startWidth = (root.winfo_screenwidth() - windowWidth) // 2

        self.root.geometry('{}x{}+{}+{}'.format(windowWidth, windowHeight, startWidth, startHeight))
        self.root.title("Small target motion detector - Runner")
        try:
            self.root.iconbitmap(os.path.join(os.path.dirname(filePath), 'stmd.ico'))
        except tk.TclError as exc:
            # The icon is cosmetic: Tk generally cannot load .ico bitmaps
            # outside Windows, and the file may not be shipped with the package.
            logger.debug(f"Window icon not set: {exc}")

        self.objModelSelector = ModelSelectorGUI(root)
        self.objInputSelector = InputSelectorGUI(root)

        self.btnRun = ttk.Button(self.root, text="Run", command=self._run)
        self.btnRun.place(x = 20, y=300)
        self.btnStepping = ttk.Button(self.root, text="Stepping", command=self._stepping)
        self.btnStepping.place(x = 20, y=270)
        self.isStepping = False

        # Filled in by _launch() once the selection has been validated
        self.modelName = None
        self.vidName = None
        self.startImgName = None
        self.endImgName = None
        self._selectedInput = None
        self._confirmed = False

    def create_gui(self):
        """
        Build the widgets, run the Tk main loop and return the selection.

        :return: ``(modelName, videoPath, None, isStepping)`` for a video,
            ``(modelName, startImagePath, endImagePath, isStepping)`` for an
            image sequence, or None if the window was closed without a
            confirmed selection.
        """
        self.objModelSelector.create_gui(ALL_MODEL)
        self.objInputSelector.create_gui()

        self.root.mainloop()

        if not self._confirmed:
            return None
        if self._selectedInput == 1:
            return self.modelName, self.vidName, None, self.isStepping
        elif self._selectedInput == 2:
            return self.modelName, self.startImgName, self.endImgName, self.isStepping

    def _run(self):
        """"Run" button: launch in continuous mode."""
        self._launch(stepping=False)

    def _stepping(self):
        """"Stepping" button: launch in frame-by-frame mode."""
        self._launch(stepping=True)

    def _launch(self, stepping):
        """
        Validate model and input; on success store them and close the window.

        ``isStepping`` is only updated on success, so a rejected "Stepping"
        click does not leak into a later "Run".
        """
        modelName = self.objModelSelector.modelCombobox.get()
        if modelName not in ALL_MODEL:
            messagebox.showinfo("Message title", "Please select a STMD-based model!")
            return

        selector = self.objInputSelector
        option = selector.selectedOption.get()
        if option == 1:
            if selector.vidName is None:
                messagebox.showinfo("Message title", "Please select a video")
                return
            self.vidName = selector.vidName
        elif option == 2:
            if selector.startImgName is None:
                messagebox.showinfo("Message title", "Please select start frame!")
                return
            if selector.endImgName is None:
                messagebox.showinfo("Message title", "Please select end frame!")
                return
            if not selector.check:
                messagebox.showinfo("Message title", "The image stream must be in the same folder!")
                return
            # Join each name with the folder it was actually picked from
            self.startImgName = os.path.join(selector.startFolder, selector.startImgName)
            self.endImgName = os.path.join(selector.endFolder, selector.endImgName)
        else:
            messagebox.showinfo("Message title", "Please select input")
            return

        self.modelName = modelName
        self.isStepping = stepping
        self._selectedInput = option
        self._confirmed = True
        self.root.destroy()
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
def check_same_ext_name(startImgName, endImgName):
    """
    Return True when both file names carry exactly the same extension.

    Extensions come from ``os.path.splitext`` and are compared
    case-sensitively; two names without an extension compare equal.
    """
    start_ext = os.path.splitext(startImgName)[1]
    end_ext = os.path.splitext(endImgName)[1]
    return start_ext == end_ext
|
|
659
|
+
|
|
660
|
+
|