xttmp 2.3.0.2__tar.gz → 2.3.0.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xttmp-2.3.0.2/src/xttmp.egg-info → xttmp-2.3.0.3}/PKG-INFO +1 -1
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/pyproject.toml +1 -1
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/demo/inference_gui.py +24 -38
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/util/compute_module.py +145 -187
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/util/iostream.py +172 -47
- {xttmp-2.3.0.2 → xttmp-2.3.0.3/src/xttmp.egg-info}/PKG-INFO +1 -1
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/LICENSE +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/README.md +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/setup.cfg +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/__init__.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/api/__init__.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/api/evaluate.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/api/get_visualize_handle.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/api/instancing_model.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/core/__init__.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/core/apgstmd_core.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/core/apgstmdv2_core.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/core/base_core.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/core/dstmd_core.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/core/estmd_backbone.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/core/estmd_core.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/core/feedbackstmd_core.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/core/fracstmd_core.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/core/fstmd_core.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/core/fstmdv2_core.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/core/haarstmd_core.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/core/math_operator.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/core/stfeedbackstmd_core.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/core/stmdplus_core.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/core/stmdplusv2_core.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/core/vstmd_core.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/demo/evaluate_model.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/demo/inference_gui_single_process.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/demo/inference_image_stream.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/demo/inference_video.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/main.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/model/__init__.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/model/backbone.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/model/facilitated_model.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/model/feedback_model.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/model/haarstmd.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/model/vstmd.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/util/__init__.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/util/create_kernel.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/util/evaluate_module.py +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp/util/stmd.ico +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp.egg-info/SOURCES.txt +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp.egg-info/dependency_links.txt +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp.egg-info/entry_points.txt +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp.egg-info/requires.txt +0 -0
- {xttmp-2.3.0.2 → xttmp-2.3.0.3}/src/xttmp.egg-info/top_level.txt +0 -0
|
@@ -12,19 +12,14 @@ file_path = os.path.realpath(__file__)
|
|
|
12
12
|
py_pkg_path = os.path.dirname(os.path.dirname(os.path.dirname(file_path)))
|
|
13
13
|
sys.path.append(py_pkg_path)
|
|
14
14
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
instancing_model,
|
|
24
|
-
)
|
|
25
|
-
except ImportError as e:
|
|
26
|
-
raise ImportError("Failed to import required modules. "
|
|
27
|
-
"Ensure that the 'xttmp' package is correctly installed.") from e
|
|
15
|
+
|
|
16
|
+
from xttmp.util.iostream import ( # type: ignore
|
|
17
|
+
XTTMP_GUI,
|
|
18
|
+
FrameIterator,
|
|
19
|
+
FrameVisualizer,
|
|
20
|
+
)
|
|
21
|
+
from xttmp.api import instancing_model
|
|
22
|
+
|
|
28
23
|
|
|
29
24
|
# configure logging
|
|
30
25
|
logging.basicConfig(level=logging.INFO,
|
|
@@ -32,18 +27,15 @@ logging.basicConfig(level=logging.INFO,
|
|
|
32
27
|
logger = logging.getLogger(__name__)
|
|
33
28
|
|
|
34
29
|
class StmdGui:
|
|
35
|
-
def __init__(self
|
|
30
|
+
def __init__(self):
|
|
36
31
|
""" Initialize STMD GUI """
|
|
37
|
-
self.device =
|
|
38
|
-
self.
|
|
39
|
-
self.get_top_num = get_top_num
|
|
40
|
-
self.ModelAndInputSelectorGUI = ModelAndInputSelectorGUI
|
|
32
|
+
self.device = None
|
|
33
|
+
self.ModelAndInputSelectorGUI = XTTMP_GUI
|
|
41
34
|
self.FrameIterator = FrameIterator
|
|
42
35
|
self.FrameVisualizer = FrameVisualizer
|
|
43
|
-
self.
|
|
36
|
+
self.post_processor = None
|
|
44
37
|
self.instancing_model = instancing_model
|
|
45
38
|
|
|
46
|
-
|
|
47
39
|
def _get_user_input(self) -> tuple:
|
|
48
40
|
""" get user input """
|
|
49
41
|
root = tk.Tk()
|
|
@@ -93,21 +85,17 @@ class StmdGui:
|
|
|
93
85
|
logger.info("User cancelled input.")
|
|
94
86
|
return
|
|
95
87
|
|
|
96
|
-
model_name, opt1, opt2, is_stepping = user_input
|
|
88
|
+
model_name, opt1, opt2, is_stepping, device, post_processor, show_threshold = user_input
|
|
89
|
+
self.post_processor = post_processor
|
|
90
|
+
self.device = device
|
|
97
91
|
reader = self._create_frame_reader(opt1, opt2)
|
|
98
|
-
model = self.instancing_model(model_name, device
|
|
99
|
-
post_processor = self.PostProcessing(
|
|
100
|
-
device=self.device,
|
|
101
|
-
nms_radio=8,
|
|
102
|
-
get_top_num=self.get_top_num,
|
|
103
|
-
)
|
|
92
|
+
model = self.instancing_model(model_name, device)
|
|
104
93
|
|
|
105
94
|
visualizer = self.FrameVisualizer(
|
|
106
95
|
window_name=model_name,
|
|
107
|
-
result_index_type="dots",
|
|
108
96
|
win_width=reader.img_width,
|
|
109
97
|
win_height=reader.img_height,
|
|
110
|
-
conf_threshold=
|
|
98
|
+
conf_threshold=show_threshold,
|
|
111
99
|
)
|
|
112
100
|
if is_stepping:
|
|
113
101
|
visualizer.paused = True
|
|
@@ -117,16 +105,17 @@ class StmdGui:
|
|
|
117
105
|
if not is_valid:
|
|
118
106
|
break
|
|
119
107
|
|
|
120
|
-
if self.device
|
|
108
|
+
if self.device == 'cuda':
|
|
121
109
|
torch.cuda.synchronize()
|
|
122
110
|
time_start = time.perf_counter()
|
|
123
111
|
result = model(gray_tensor)
|
|
124
|
-
if self.device
|
|
112
|
+
if self.device == 'cuda':
|
|
125
113
|
torch.cuda.synchronize()
|
|
126
114
|
run_time = time.perf_counter() - time_start
|
|
127
115
|
|
|
128
|
-
|
|
129
|
-
|
|
116
|
+
post_res = post_processor(result['response'], result.get('direction'))
|
|
117
|
+
show_str = f'{self.device} : {run_time*1000:.1f} ms'
|
|
118
|
+
if not visualizer.update(color_img, result=post_res, show_str=show_str):
|
|
130
119
|
break
|
|
131
120
|
|
|
132
121
|
except Exception as e:
|
|
@@ -139,10 +128,7 @@ class StmdGui:
|
|
|
139
128
|
reader.release()
|
|
140
129
|
logger.info("Shutdown completed")
|
|
141
130
|
|
|
142
|
-
def main(show_threshold: float = 0, get_top_num: int = 10):
|
|
143
|
-
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
144
|
-
obj = StmdGui(DEVICE, show_threshold = show_threshold, get_top_num = get_top_num)
|
|
145
|
-
obj.run()
|
|
146
131
|
|
|
147
132
|
if __name__ == "__main__":
|
|
148
|
-
|
|
133
|
+
obj = StmdGui()
|
|
134
|
+
obj.run()
|
|
@@ -6,67 +6,6 @@ import torch
|
|
|
6
6
|
import torch.nn.functional as F
|
|
7
7
|
|
|
8
8
|
|
|
9
|
-
def compute_temporal_conv(iptCell, kernel, pointer=None):
|
|
10
|
-
"""
|
|
11
|
-
Computes temporal convolution.
|
|
12
|
-
|
|
13
|
-
Parameters:
|
|
14
|
-
- iptCell: A list of arrays where each element has the same dimension.
|
|
15
|
-
- kernel: A vector representing the convolution kernel.
|
|
16
|
-
- headPointer: Head pointer of the input cell array (optional).
|
|
17
|
-
|
|
18
|
-
Returns:
|
|
19
|
-
- optMatrix: The result of the temporal convolution.
|
|
20
|
-
"""
|
|
21
|
-
|
|
22
|
-
# Default value for headPointer
|
|
23
|
-
if pointer is None:
|
|
24
|
-
pointer = len(iptCell) - 1
|
|
25
|
-
|
|
26
|
-
# Initialize output matrix
|
|
27
|
-
if iptCell[pointer] is None:
|
|
28
|
-
return None
|
|
29
|
-
|
|
30
|
-
# Ensure kernel is a vector
|
|
31
|
-
kernel = np.squeeze(kernel)
|
|
32
|
-
if not np.ndim(kernel) == 1:
|
|
33
|
-
raise ValueError('The kernel must be a vector.')
|
|
34
|
-
|
|
35
|
-
# Determine the lengths of input cell array and kernel
|
|
36
|
-
k1 = len(iptCell)
|
|
37
|
-
k2 = len(kernel)
|
|
38
|
-
length = min(k1, k2)
|
|
39
|
-
|
|
40
|
-
if isinstance(iptCell[pointer], np.ndarray):
|
|
41
|
-
optMatrix = np.zeros_like(iptCell[pointer])
|
|
42
|
-
elif isinstance(iptCell[pointer], torch.Tensor):
|
|
43
|
-
optMatrix = torch.zeros_like(iptCell[pointer])
|
|
44
|
-
# Perform temporal convolution
|
|
45
|
-
for t in range(length):
|
|
46
|
-
j = (pointer - t) % k1
|
|
47
|
-
if abs(kernel[t]) > 1e-16 and iptCell[j] is not None:
|
|
48
|
-
optMatrix += iptCell[j] * kernel[t]
|
|
49
|
-
|
|
50
|
-
return optMatrix
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
def compute_circularlist_conv(circularCell, temporalKernel):
|
|
54
|
-
"""
|
|
55
|
-
Compute the convolution of a circular cell with a temporal kernel.
|
|
56
|
-
|
|
57
|
-
Args:
|
|
58
|
-
- circularCell: The circular cell data.
|
|
59
|
-
- temporalKernel: The temporal kernel data.
|
|
60
|
-
|
|
61
|
-
Returns:
|
|
62
|
-
- opt_matrix: The result of the convolution.
|
|
63
|
-
"""
|
|
64
|
-
optMatrix = compute_temporal_conv(circularCell,
|
|
65
|
-
temporalKernel,
|
|
66
|
-
circularCell.pointer )
|
|
67
|
-
return optMatrix
|
|
68
|
-
|
|
69
|
-
|
|
70
9
|
def compute_response(ipt):
|
|
71
10
|
"""
|
|
72
11
|
Computes the maximum response from multiple inputs.
|
|
@@ -233,139 +172,17 @@ class AreaNMS:
|
|
|
233
172
|
return matrix * (matrix == local_max)
|
|
234
173
|
|
|
235
174
|
|
|
236
|
-
def get_top_k_torch(response_tensor, direction_tensor, k=1000):
|
|
237
|
-
"""
|
|
238
|
-
输入:
|
|
239
|
-
response_tensor: (..., H, W) 任意维度的 Tensor
|
|
240
|
-
direction_tensor: (..., H, W) 形状需与 response 匹配 (可选)
|
|
241
|
-
输出:
|
|
242
|
-
torch.Tensor: shape=(M, 4), dtype=float32, 其中 M <= k
|
|
243
|
-
格式: [[x, y, response, direction], ...]
|
|
244
|
-
"""
|
|
245
|
-
# 1. 获取维度
|
|
246
|
-
H, W = response_tensor.shape[-2:]
|
|
247
|
-
k = min(k, H * W)
|
|
248
|
-
|
|
249
|
-
# 2. 展平 (Flatten)
|
|
250
|
-
# view(-1) 零拷贝,极快
|
|
251
|
-
flat_response = response_tensor.view(-1)
|
|
252
|
-
|
|
253
|
-
# 3. TopK (GPU 上极速排序)
|
|
254
|
-
top_vals, top_indices = torch.topk(flat_response, k=k)
|
|
255
|
-
|
|
256
|
-
# 4. 过滤掉 <= 0 的值 ---
|
|
257
|
-
# 创建掩码:只保留大于 0 的值
|
|
258
|
-
mask = top_vals > 0
|
|
259
|
-
|
|
260
|
-
# 如果全都是 0,直接返回空数组,避免后续报错
|
|
261
|
-
if not mask.any():
|
|
262
|
-
return torch.empty((0, 4))
|
|
263
|
-
|
|
264
|
-
# 应用掩码,缩减 tensor 长度
|
|
265
|
-
top_vals = top_vals[mask]
|
|
266
|
-
top_indices = top_indices[mask]
|
|
267
|
-
# ------------------------------------
|
|
268
|
-
|
|
269
|
-
# 5. 计算坐标 (x, y)
|
|
270
|
-
# 此时计算量已经减少,只计算非零点
|
|
271
|
-
top_y = top_indices.div(W, rounding_mode='floor').float()
|
|
272
|
-
top_x = (top_indices % W).float()
|
|
273
|
-
|
|
274
|
-
# 6. 获取 Direction
|
|
275
|
-
if direction_tensor is not None and direction_tensor.numel() > 0:
|
|
276
|
-
flat_direction = direction_tensor.view(-1)
|
|
277
|
-
# 注意:这里使用过滤后的 top_indices
|
|
278
|
-
top_dirs = flat_direction[top_indices]
|
|
279
|
-
else:
|
|
280
|
-
top_dirs = torch.empty_like(top_vals).fill_(float('nan'))
|
|
281
|
-
|
|
282
|
-
# 7. 堆叠 (Stack) -> (M, 4)
|
|
283
|
-
result_tensor = torch.stack([top_x, top_y, top_vals, top_dirs], dim=1)
|
|
284
|
-
|
|
285
|
-
return result_tensor
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
def get_top_k_numpy(response_array, direction_array=None, k=1000):
|
|
289
|
-
"""
|
|
290
|
-
输入:
|
|
291
|
-
response_array: (..., H, W) numpy.ndarray
|
|
292
|
-
direction_array: (..., H, W) (可选)
|
|
293
|
-
输出:
|
|
294
|
-
numpy.ndarray: shape=(M, 4), dtype=float32, 其中 M <= k
|
|
295
|
-
格式: [[x, y, response, direction], ...]
|
|
296
|
-
"""
|
|
297
|
-
# 1. 获取维度
|
|
298
|
-
shape = response_array.shape
|
|
299
|
-
H, W = shape[-2:]
|
|
300
|
-
|
|
301
|
-
# 零拷贝展平
|
|
302
|
-
flat_response = response_array.ravel()
|
|
303
|
-
k = min(k, flat_response.size)
|
|
304
|
-
|
|
305
|
-
# 2. TopK 核心优化 (O(N))
|
|
306
|
-
# argpartition 找出最大的 k 个 (无序)
|
|
307
|
-
unsorted_top_indices = np.argpartition(flat_response, -k)[-k:]
|
|
308
|
-
unsorted_top_vals = flat_response[unsorted_top_indices]
|
|
309
|
-
|
|
310
|
-
# 3. 局部排序 (O(k log k))
|
|
311
|
-
# argsort 默认升序,[::-1] 翻转为降序
|
|
312
|
-
sort_idx = np.argsort(unsorted_top_vals)[::-1]
|
|
313
|
-
|
|
314
|
-
# 获取排序后的 Top K 索引和值
|
|
315
|
-
top_indices = unsorted_top_indices[sort_idx]
|
|
316
|
-
top_vals = unsorted_top_vals[sort_idx]
|
|
317
|
-
|
|
318
|
-
# --- [关键修改] 4. 过滤掉 <= 0 的值 ---
|
|
319
|
-
# 创建掩码
|
|
320
|
-
mask = top_vals > 0
|
|
321
|
-
|
|
322
|
-
# 极速判断:如果没有有效值,直接返回空数组
|
|
323
|
-
# np.any() 很快
|
|
324
|
-
if not np.any(mask):
|
|
325
|
-
return np.empty((0, 4), dtype=np.float32)
|
|
326
|
-
|
|
327
|
-
# 应用掩码 (切片操作,只保留有效值)
|
|
328
|
-
# 因为 k 通常不大 (比如 1000),这里的拷贝开销可忽略不计
|
|
329
|
-
top_vals = top_vals[mask]
|
|
330
|
-
top_indices = top_indices[mask]
|
|
331
|
-
|
|
332
|
-
# 更新实际数量 M
|
|
333
|
-
M = top_vals.size
|
|
334
|
-
# ------------------------------------
|
|
335
|
-
|
|
336
|
-
# 5. 计算坐标 (x, y)
|
|
337
|
-
# 只对过滤后的索引计算,节省算力
|
|
338
|
-
top_y, top_x = np.unravel_index(top_indices, (H, W))
|
|
339
|
-
|
|
340
|
-
# 6. 获取 Direction
|
|
341
|
-
if direction_array is not None and direction_array.size > 0:
|
|
342
|
-
flat_direction = direction_array.ravel()
|
|
343
|
-
top_dirs = flat_direction[top_indices]
|
|
344
|
-
else:
|
|
345
|
-
top_dirs = np.full(M, np.nan, dtype=np.float32)
|
|
346
|
-
|
|
347
|
-
# 7. 堆叠结果
|
|
348
|
-
# 分配恰好大小为 M 的内存
|
|
349
|
-
result = np.empty((M, 4), dtype=np.float32)
|
|
350
|
-
result[:, 0] = top_x # x
|
|
351
|
-
result[:, 1] = top_y # y
|
|
352
|
-
result[:, 2] = top_vals # response
|
|
353
|
-
result[:, 3] = top_dirs # direction
|
|
354
|
-
|
|
355
|
-
return result
|
|
356
|
-
|
|
357
|
-
|
|
358
175
|
class PostProcessing:
|
|
359
176
|
"""
|
|
360
177
|
Post-processing class to apply AreaNMS, get top K, and return list format.
|
|
361
178
|
"""
|
|
362
179
|
|
|
363
|
-
def __init__(self,
|
|
180
|
+
def __init__(self, nms_radio = 8, get_top_num=1000):
|
|
364
181
|
"""
|
|
365
182
|
Args:
|
|
366
|
-
|
|
183
|
+
nms_radio (int): Radius for AreaNMS.
|
|
184
|
+
get_top_num (int): Number of top points to extract.
|
|
367
185
|
"""
|
|
368
|
-
self.device = device
|
|
369
186
|
self.area_nms = AreaNMS(radio=nms_radio)
|
|
370
187
|
self.get_top_num = get_top_num
|
|
371
188
|
|
|
@@ -389,7 +206,7 @@ class PostProcessing:
|
|
|
389
206
|
"""
|
|
390
207
|
nms_response = self.area_nms(response)
|
|
391
208
|
|
|
392
|
-
res = get_top_k_torch(nms_response,
|
|
209
|
+
res, _ = get_top_k_torch(nms_response,
|
|
393
210
|
direction,
|
|
394
211
|
k=self.get_top_num)
|
|
395
212
|
if res.shape[0] == 0:
|
|
@@ -400,3 +217,144 @@ class PostProcessing:
|
|
|
400
217
|
res[:, 2] /= max_score
|
|
401
218
|
|
|
402
219
|
return res
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
@torch.no_grad()
|
|
223
|
+
def gen_bboxes_around_points(results, box_size=16, shift_ratio=0.3):
|
|
224
|
+
"""Generate initial bboxes around detected motion points.
|
|
225
|
+
|
|
226
|
+
Args:
|
|
227
|
+
results: (N, 4) -> [x, y, response, direction]
|
|
228
|
+
box_size: int, box size
|
|
229
|
+
|
|
230
|
+
Returns:
|
|
231
|
+
[[x1, y1, x2, y2], ... ] (N, 4) tensors
|
|
232
|
+
"""
|
|
233
|
+
N = results.shape[0]
|
|
234
|
+
|
|
235
|
+
if N == 0:
|
|
236
|
+
return torch.empty((0, 4), device=results.device, dtype=torch.int)
|
|
237
|
+
|
|
238
|
+
rear_x, rear_y, direction = results[:, 0], results[:, 1], results[:, 3]
|
|
239
|
+
|
|
240
|
+
radius = box_size * 0.5
|
|
241
|
+
shift_mag = box_size * shift_ratio
|
|
242
|
+
|
|
243
|
+
# 1. 计算偏移量 (dx, dy),根据方向和预设的 shift_mag
|
|
244
|
+
dx = torch.cos(direction) * shift_mag
|
|
245
|
+
dy = -torch.sin(direction) * shift_mag
|
|
246
|
+
|
|
247
|
+
# 2. 一次性将 NaN 偏移量替换为 0.0
|
|
248
|
+
dx = torch.nan_to_num(dx, nan=0.0)
|
|
249
|
+
dy = torch.nan_to_num(dy, nan=0.0)
|
|
250
|
+
|
|
251
|
+
# 3. 计算中心点
|
|
252
|
+
center_x = rear_x + dx
|
|
253
|
+
center_y = rear_y + dy
|
|
254
|
+
|
|
255
|
+
x1 = (center_x - radius)
|
|
256
|
+
y1 = (center_y - radius)
|
|
257
|
+
x2 = (center_x + radius)
|
|
258
|
+
y2 = (center_y + radius)
|
|
259
|
+
|
|
260
|
+
# Stack as (N, 4) -> [x1, y1, x2, y2]
|
|
261
|
+
return torch.stack([x1, y1, x2, y2], dim=1)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
@torch.no_grad()
|
|
265
|
+
def get_top_k_torch(response_tensor, direction_tensor=None, k=100):
|
|
266
|
+
"""
|
|
267
|
+
Extract the top-k points with highest responses from feature maps, filtering out non-positive values.
|
|
268
|
+
|
|
269
|
+
Args:
|
|
270
|
+
response_tensor (torch.Tensor): The response map tensor of shape (B, 1, H, W) or (B, H, W).
|
|
271
|
+
direction_tensor (torch.Tensor, optional): The corresponding direction map tensor of
|
|
272
|
+
shape (B, 1, H, W) or (B, H, W). Must match response_tensor's shape. Defaults to None.
|
|
273
|
+
k (int, optional): The maximum number of top points to extract per batch. Defaults to 100.
|
|
274
|
+
|
|
275
|
+
Returns:
|
|
276
|
+
Tuple[torch.Tensor, torch.Tensor]:
|
|
277
|
+
- results (torch.Tensor): A tensor of shape (M, 4) containing the valid extracted points
|
|
278
|
+
across the entire batch. M <= B * k. Each row is formatted as [x, y, response, direction].
|
|
279
|
+
- batch_ids (torch.Tensor): A 1D tensor of shape (M,) containing the corresponding
|
|
280
|
+
batch index (from 0 to B-1) for each point in `results`. dtype is torch.long.
|
|
281
|
+
"""
|
|
282
|
+
B, _, H, W = response_tensor.shape
|
|
283
|
+
k = min(k, H * W)
|
|
284
|
+
device = response_tensor.device
|
|
285
|
+
|
|
286
|
+
# 1. Flatten -> (B, H*W)
|
|
287
|
+
flat_response = response_tensor.reshape(B, -1)
|
|
288
|
+
|
|
289
|
+
# 2. TopK -> top_vals and top_indices are both (B, k)
|
|
290
|
+
top_vals, top_indices = torch.topk(flat_response, k=k, dim=-1)
|
|
291
|
+
|
|
292
|
+
# 3. Get Direction -> (B, k)
|
|
293
|
+
if direction_tensor is not None and direction_tensor.numel() > 0:
|
|
294
|
+
flat_direction = direction_tensor.reshape(B, -1)
|
|
295
|
+
top_dirs = torch.gather(flat_direction, dim=-1, index=top_indices)
|
|
296
|
+
else:
|
|
297
|
+
top_dirs = torch.full_like(top_vals, float('nan'))
|
|
298
|
+
|
|
299
|
+
# 4. Calculate coordinates (x, y) -> (B, k)
|
|
300
|
+
top_y = top_indices.div(W, rounding_mode='floor').float()
|
|
301
|
+
top_x = (top_indices % W).float()
|
|
302
|
+
|
|
303
|
+
# 5. Stack -> merge on the last dimension, shape becomes (B, k, 4)
|
|
304
|
+
stacked = torch.stack([top_x, top_y, top_vals, top_dirs], dim=-1)
|
|
305
|
+
|
|
306
|
+
# 6. Generate Mask -> (B, k)
|
|
307
|
+
mask = top_vals > 0
|
|
308
|
+
|
|
309
|
+
# 7. Split and filter by Batch
|
|
310
|
+
result_list = []
|
|
311
|
+
batch_id_list = []
|
|
312
|
+
|
|
313
|
+
for i in range(B):
|
|
314
|
+
batch_mask = mask[i] # Get the mask for the i-th batch
|
|
315
|
+
|
|
316
|
+
# Apply mask: [k, 4] -> [M_i, 4]
|
|
317
|
+
valid_stacked = stacked[i][batch_mask]
|
|
318
|
+
result_list.append(valid_stacked)
|
|
319
|
+
|
|
320
|
+
# Create a batch index tensor of shape (M_i,) filled with the current batch index 'i'
|
|
321
|
+
batch_id_list.append(torch.full((valid_stacked.shape[0],), i, device=device, dtype=torch.long))
|
|
322
|
+
|
|
323
|
+
# Concatenate all valid items into continuous tensors
|
|
324
|
+
return torch.cat(result_list, dim=0), torch.cat(batch_id_list, dim=0)
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
@torch.no_grad()
|
|
328
|
+
def get_STMD_region_proposal(response_tensor, direction_tensor=None, top_k=1, box_size=16, spatial_scale=1.0, shift_ratio=0.3):
|
|
329
|
+
nms_win = int(box_size * spatial_scale) | 1 # 确保是奇数
|
|
330
|
+
score_mask = F.max_pool2d(response_tensor, kernel_size=nms_win, stride=1, padding=nms_win//2)
|
|
331
|
+
|
|
332
|
+
nms_response_tensor = torch.where(response_tensor == score_mask, response_tensor, 0.0)
|
|
333
|
+
|
|
334
|
+
vSTMD_res, batch_id = get_top_k_torch(nms_response_tensor, direction_tensor, k=top_k)
|
|
335
|
+
|
|
336
|
+
if spatial_scale > 1:
|
|
337
|
+
vSTMD_res[:, :2] *= spatial_scale # 将坐标放大回原图尺度
|
|
338
|
+
|
|
339
|
+
bboxes = gen_bboxes_around_points(vSTMD_res, box_size, shift_ratio)
|
|
340
|
+
|
|
341
|
+
return vSTMD_res, bboxes, batch_id
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
@torch.no_grad()
|
|
345
|
+
def bbox_post_processing(top_k=1, box_size=16, spatial_scale=1.0, shift_ratio=0.3):
|
|
346
|
+
|
|
347
|
+
def post_process_func(
|
|
348
|
+
response_tensor,
|
|
349
|
+
direction_tensor=None
|
|
350
|
+
):
|
|
351
|
+
vSTMD_res, bboxes, _ = get_STMD_region_proposal(response_tensor,
|
|
352
|
+
direction_tensor,
|
|
353
|
+
top_k=top_k,
|
|
354
|
+
box_size=box_size,
|
|
355
|
+
spatial_scale=spatial_scale,
|
|
356
|
+
shift_ratio=shift_ratio )
|
|
357
|
+
return torch.cat([bboxes, vSTMD_res[..., 2:3]], dim=1)
|
|
358
|
+
|
|
359
|
+
return post_process_func
|
|
360
|
+
|
|
@@ -3,6 +3,7 @@ import re
|
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
import logging
|
|
5
5
|
from typing import Optional, List, Union, Tuple, Any
|
|
6
|
+
from functools import partial
|
|
6
7
|
|
|
7
8
|
import cv2
|
|
8
9
|
import numpy as np
|
|
@@ -12,6 +13,7 @@ import torch
|
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
from .. import model
|
|
16
|
+
from .compute_module import PostProcessing, bbox_post_processing
|
|
15
17
|
|
|
16
18
|
|
|
17
19
|
# Get the full path of this file
|
|
@@ -215,7 +217,6 @@ class FrameIterator:
|
|
|
215
217
|
|
|
216
218
|
class FrameVisualizer:
|
|
217
219
|
def __init__(self, window_name="Visualizer",
|
|
218
|
-
result_index_type="matrix",
|
|
219
220
|
win_width=None, win_height=None,
|
|
220
221
|
is_headless=False,
|
|
221
222
|
conf_threshold=0.8 # 阈值参数
|
|
@@ -225,7 +226,6 @@ class FrameVisualizer:
|
|
|
225
226
|
:param conf_threshold: 可视化过滤的相对阈值 (0.0 ~ 1.0)
|
|
226
227
|
"""
|
|
227
228
|
self.window_name = window_name
|
|
228
|
-
self.result_index_type = result_index_type # "matrix", "dots", "bbox"
|
|
229
229
|
self.win_width = win_width or 800
|
|
230
230
|
self.win_height = win_height or 600
|
|
231
231
|
self.is_headless = is_headless
|
|
@@ -260,24 +260,26 @@ class FrameVisualizer:
|
|
|
260
260
|
self.save_output = True
|
|
261
261
|
print(f">>> Video writer initialized: {output_path}")
|
|
262
262
|
|
|
263
|
-
def update(self, frame, result=None, direction=None, annotation=None,
|
|
263
|
+
def update(self, frame, result=None, direction=None, annotation=None, show_str=None) -> bool:
|
|
264
264
|
if frame is None:
|
|
265
265
|
return False
|
|
266
266
|
|
|
267
267
|
# --- 绘制逻辑 ---
|
|
268
268
|
# 即使 result 是空的,只要不为 None 也可以处理
|
|
269
269
|
if result is not None:
|
|
270
|
-
if
|
|
270
|
+
if result.dim() == 4:
|
|
271
271
|
self._draw_matrix(frame, result, direction, self.conf_threshold)
|
|
272
|
-
elif
|
|
272
|
+
elif result.shape[1] == 4:
|
|
273
273
|
result = result.cpu().numpy() if isinstance(result, torch.Tensor) else result
|
|
274
274
|
self._draw_dots(frame, result, self.conf_threshold)
|
|
275
|
-
elif
|
|
275
|
+
elif result.shape[1] == 5:
|
|
276
276
|
self._draw_bbox(frame, result, self.conf_threshold, annotation)
|
|
277
|
-
|
|
277
|
+
device_str = f'{result.device}'
|
|
278
|
+
else:
|
|
279
|
+
device_str = 'Time'
|
|
278
280
|
# --- 信息显示 ---
|
|
279
|
-
if
|
|
280
|
-
cv2.putText(frame,
|
|
281
|
+
if show_str is not None and show_str != '':
|
|
282
|
+
cv2.putText(frame, str(show_str),
|
|
281
283
|
(10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8,
|
|
282
284
|
(0, 255, 0), 2, cv2.LINE_AA)
|
|
283
285
|
|
|
@@ -344,10 +346,12 @@ class FrameVisualizer:
|
|
|
344
346
|
@staticmethod
|
|
345
347
|
def _draw_matrix(frame, matrix, direction_map, threshold):
|
|
346
348
|
"""处理 Matrix 格式 (Heatmap)"""
|
|
347
|
-
if
|
|
349
|
+
if torch.max(matrix) <= 0: return
|
|
348
350
|
|
|
349
351
|
# np.where 返回 (rows, cols) 即 (y, x)
|
|
350
|
-
rows, cols =
|
|
352
|
+
_, _, rows, cols = torch.where(matrix > threshold)
|
|
353
|
+
rows = rows.cpu().numpy()
|
|
354
|
+
cols = cols.cpu().numpy()
|
|
351
355
|
|
|
352
356
|
# 画点
|
|
353
357
|
for r, c in zip(rows, cols):
|
|
@@ -358,7 +362,7 @@ class FrameVisualizer:
|
|
|
358
362
|
# 画箭头
|
|
359
363
|
if direction_map is not None and len(rows) > 0:
|
|
360
364
|
# 确保 direction_map 维度匹配,这里假设是同样大小的矩阵
|
|
361
|
-
valid_dirs = direction_map[rows, cols]
|
|
365
|
+
valid_dirs = direction_map[0, 0, rows, cols]
|
|
362
366
|
|
|
363
367
|
# 过滤 NaN
|
|
364
368
|
valid_mask = ~np.isnan(valid_dirs)
|
|
@@ -450,9 +454,9 @@ class ModelSelectorGUI:
|
|
|
450
454
|
self.modelLabel = ttk.Label(self.root, text="Select a model:", width = 15)
|
|
451
455
|
self.modelLabel.grid(row=0, column=0, padx=10, pady=10)
|
|
452
456
|
|
|
453
|
-
self.modelCombobox = ttk.Combobox(self.root, values=modelList, width =
|
|
457
|
+
self.modelCombobox = ttk.Combobox(self.root, values=modelList, width = 25)
|
|
454
458
|
self.modelCombobox.current(11)
|
|
455
|
-
self.modelCombobox.grid(row=0, column=1, columnspan=2,
|
|
459
|
+
self.modelCombobox.grid(row=0, column=1, columnspan=2, pady=10, sticky='w')
|
|
456
460
|
|
|
457
461
|
|
|
458
462
|
class InputSelectorGUI:
|
|
@@ -476,25 +480,24 @@ class InputSelectorGUI:
|
|
|
476
480
|
self.endImgName = None
|
|
477
481
|
|
|
478
482
|
def create_gui(self):
|
|
479
|
-
self.inputTypeLabel = ttk.Label(self.root, text="
|
|
480
|
-
self.inputTypeLabel.grid(row=1, column=0, padx=10, pady=10)
|
|
483
|
+
self.inputTypeLabel = ttk.Label(self.root, text="Input Type:", width = 15)
|
|
484
|
+
self.inputTypeLabel.grid(row=1, column=0, padx=10, pady=10, sticky='w')
|
|
481
485
|
|
|
482
486
|
self.selectedOption = tk.IntVar(value=0)
|
|
483
487
|
|
|
484
|
-
|
|
485
|
-
self.vidLabel = ttk.Radiobutton(self.root,
|
|
486
|
-
text='Video stream',
|
|
487
|
-
variable=self.selectedOption,
|
|
488
|
-
value=1,
|
|
489
|
-
command=self.select_vidstream)
|
|
490
|
-
self.vidLabel.grid(row=1, column=2, padx=10, pady=10)
|
|
491
|
-
|
|
492
488
|
self.imgLabel = ttk.Radiobutton(self.root,
|
|
493
|
-
text='Image
|
|
489
|
+
text='Image Sequence',
|
|
494
490
|
variable=self.selectedOption,
|
|
495
491
|
value=2,
|
|
496
492
|
command=self.select_imgstream)
|
|
497
|
-
self.imgLabel.grid(row=1, column=
|
|
493
|
+
self.imgLabel.grid(row=1, column=2, padx=10, pady=10, sticky="w")
|
|
494
|
+
|
|
495
|
+
self.vidLabel = ttk.Radiobutton(self.root,
|
|
496
|
+
text='Video',
|
|
497
|
+
variable=self.selectedOption,
|
|
498
|
+
value=1,
|
|
499
|
+
command=self.select_vidstream)
|
|
500
|
+
self.vidLabel.grid(row=1, column=1, padx=10, pady=10, sticky="w")
|
|
498
501
|
|
|
499
502
|
def select_vidstream(self):
|
|
500
503
|
self.imgSelectFolder = None
|
|
@@ -503,16 +506,16 @@ class InputSelectorGUI:
|
|
|
503
506
|
for element in self.imgElement.values():
|
|
504
507
|
element.destroy()
|
|
505
508
|
|
|
506
|
-
self.vidElement['lblVidIndicate'] = ttk.Label(self.root, text= 'Video\'s path:'
|
|
507
|
-
self.vidElement['lblVidIndicate'].grid(row=2, column=
|
|
509
|
+
self.vidElement['lblVidIndicate'] = ttk.Label(self.root, text= 'Video\'s path:')
|
|
510
|
+
self.vidElement['lblVidIndicate'].grid(row=2, column=1, padx=10, pady=30, sticky='w')
|
|
508
511
|
self.vidElement['lblVidPath'] = ttk.Label(self.root,
|
|
509
512
|
text="Waiting for the selection",
|
|
510
513
|
wraplength=220
|
|
511
514
|
)
|
|
512
|
-
self.vidElement['lblVidPath'].grid(row=2, column=
|
|
515
|
+
self.vidElement['lblVidPath'].grid(row=2, column=2, padx=10, pady=10, sticky='w')
|
|
513
516
|
|
|
514
517
|
self.vidElement['btn'] = ttk.Button(self.root, text="Select a video", command=self._clicked_vid)
|
|
515
|
-
self.vidElement['btn'].grid(row=3, column=2, padx=10, pady=10)
|
|
518
|
+
self.vidElement['btn'].grid(row=3, column=2, padx=10, pady=10, sticky='w')
|
|
516
519
|
|
|
517
520
|
def _clicked_vid(self):
|
|
518
521
|
self.vidName = filedialog.askopenfilenames(initialdir=VID_DEFAULT_FOLDER)
|
|
@@ -524,15 +527,20 @@ class InputSelectorGUI:
|
|
|
524
527
|
for element in self.vidElement.values():
|
|
525
528
|
element.destroy()
|
|
526
529
|
|
|
527
|
-
self.imgElement['lblFolder'] = ttk.Label(self.root, text="Image's folder: "
|
|
528
|
-
self.imgElement['lblFolder'].grid(row=2, column=
|
|
530
|
+
self.imgElement['lblFolder'] = ttk.Label(self.root, text="Image's folder: ")
|
|
531
|
+
self.imgElement['lblFolder'].grid(row=2, column=1, padx=10, pady=10, sticky='w')
|
|
529
532
|
self.imgElement['lblFolderName'] = ttk.Label(self.root, text="Waiting for the selection", wraplength=220)
|
|
530
|
-
self.imgElement['lblFolderName'].grid(row=2, column=
|
|
533
|
+
self.imgElement['lblFolderName'].grid(row=2, column=2, padx=10, pady=30, sticky='w')
|
|
531
534
|
|
|
532
535
|
self.imgElement['btnStart'] = ttk.Button(self.root, text="Select start frame", command=self._clicked_start_img)
|
|
533
|
-
self.imgElement['btnStart'].grid(row=3, column=1, padx=10, pady=10)
|
|
536
|
+
self.imgElement['btnStart'].grid(row=3, column=1, padx=10, pady=10, sticky='w')
|
|
537
|
+
self.imgElement['lblStartImg'] = ttk.Label(self.root, text=self.startImgName)
|
|
538
|
+
self.imgElement['lblStartImg'].grid(row=3, column=2, padx=10, pady=10, sticky='w')
|
|
539
|
+
|
|
534
540
|
self.imgElement['btnEnd'] = ttk.Button(self.root, text="Select end frame", command=self._clicked_end_img)
|
|
535
|
-
self.imgElement['btnEnd'].grid(row=4, column=1, padx=10, pady=10)
|
|
541
|
+
self.imgElement['btnEnd'].grid(row=4, column=1, padx=10, pady=10, sticky='w')
|
|
542
|
+
self.imgElement['lblEndImg'] = ttk.Label(self.root, text=self.endImgName)
|
|
543
|
+
self.imgElement['lblEndImg'].grid(row=4, column=2, padx=10, pady=10, sticky='w')
|
|
536
544
|
|
|
537
545
|
def _clicked_start_img(self):
|
|
538
546
|
startImgFullPath = filedialog.askopenfilenames(
|
|
@@ -551,8 +559,7 @@ class InputSelectorGUI:
|
|
|
551
559
|
self.imgSelectFolder = self.startFolder
|
|
552
560
|
self.imgElement['lblFolderName'].config(text=self.imgSelectFolder)
|
|
553
561
|
|
|
554
|
-
self.imgElement['lblStartImg']
|
|
555
|
-
self.imgElement['lblStartImg'].grid(row=3, column=2, padx=10, pady=10)
|
|
562
|
+
self.imgElement['lblStartImg'].config(text=self.startImgName)
|
|
556
563
|
|
|
557
564
|
def _clicked_end_img(self):
|
|
558
565
|
endImgFullPath = filedialog.askopenfilenames(
|
|
@@ -573,16 +580,127 @@ class InputSelectorGUI:
|
|
|
573
580
|
self.imgSelectFolder = self.endFolder
|
|
574
581
|
self.imgElement['lblFolderName'].config(text=self.imgSelectFolder)
|
|
575
582
|
|
|
576
|
-
self.imgElement['lblEndImg']
|
|
577
|
-
|
|
583
|
+
self.imgElement['lblEndImg'].config(text=self.endImgName)
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
class PostProcessingSelectorGUI:
|
|
587
|
+
def __init__(self, root):
|
|
588
|
+
self.root = root
|
|
589
|
+
self.output_type = "dot"
|
|
590
|
+
self.show_threshold = 0.0
|
|
591
|
+
self.top_num = 1
|
|
592
|
+
|
|
593
|
+
self.outputTypeLabel = ttk.Label(self.root, text="Output Type:", width = 15)
|
|
594
|
+
self.outputTypeLabel.grid(row=5, column=0, padx=10, pady=10)
|
|
595
|
+
|
|
596
|
+
self.selectedOption = tk.IntVar(value=2)
|
|
597
|
+
|
|
598
|
+
self.dotLabel = ttk.Radiobutton(self.root,
|
|
599
|
+
text='dot output',
|
|
600
|
+
variable=self.selectedOption,
|
|
601
|
+
value=2,
|
|
602
|
+
command=self.select_dot)
|
|
603
|
+
self.dotLabel.grid(row=5, column=1, padx=10, pady=10, sticky="w")
|
|
604
|
+
|
|
605
|
+
self.bboxLabel = ttk.Radiobutton(self.root,
|
|
606
|
+
text='bbox output',
|
|
607
|
+
variable=self.selectedOption,
|
|
608
|
+
value=1,
|
|
609
|
+
command=self.select_bbox)
|
|
610
|
+
self.bboxLabel.grid(row=5, column=2, padx=10, pady=10, sticky="w")
|
|
611
|
+
|
|
612
|
+
self.showThresholdLabel = ttk.Label(self.root, text="Threshold:", width=10)
|
|
613
|
+
self.showThresholdLabel.grid(row=6, column=1, padx=10, pady=10, sticky='w')
|
|
578
614
|
|
|
615
|
+
self.showThresholdVar = tk.StringVar(value="0")
|
|
616
|
+
self.showThresholdVar.trace_add('write', self.update_show_threshold)
|
|
617
|
+
self.showThresholdEntry = ttk.Entry(self.root, textvariable=self.showThresholdVar, width=5)
|
|
618
|
+
self.showThresholdEntry.grid(row=6, column=2, padx=10, pady=10, sticky='w')
|
|
579
619
|
|
|
580
|
-
|
|
620
|
+
self.getTopNumLabel = ttk.Label(self.root, text="Top Num:", width=10)
|
|
621
|
+
self.getTopNumLabel.grid(row=7, column=1, padx=10, pady=10, sticky='w')
|
|
622
|
+
|
|
623
|
+
self.getTopNumVar = tk.StringVar(value="1")
|
|
624
|
+
self.getTopNumVar.trace_add('write', self.update_top_num)
|
|
625
|
+
self.getTopNumEntry = ttk.Entry(self.root, textvariable=self.getTopNumVar, width=5)
|
|
626
|
+
self.getTopNumEntry.grid(row=7, column=2, padx=10, pady=10, sticky='w')
|
|
627
|
+
|
|
628
|
+
self.select_dot()
|
|
629
|
+
|
|
630
|
+
def select_dot(self):
|
|
631
|
+
self.selectedOption.set(2)
|
|
632
|
+
self.output_type = "dot"
|
|
633
|
+
|
|
634
|
+
def select_bbox(self):
|
|
635
|
+
self.selectedOption.set(1)
|
|
636
|
+
self.output_type = "bbox"
|
|
637
|
+
|
|
638
|
+
def get_post_processing(self):
|
|
639
|
+
if self.output_type == "dot":
|
|
640
|
+
return PostProcessing(get_top_num = self.top_num)
|
|
641
|
+
elif self.output_type == "bbox":
|
|
642
|
+
return bbox_post_processing(self.top_num)
|
|
643
|
+
else:
|
|
644
|
+
raise ValueError(f"Unknown output type: {self.output_type}")
|
|
645
|
+
|
|
646
|
+
def update_show_threshold(self, *args):
|
|
647
|
+
value = float(self.showThresholdVar.get())
|
|
648
|
+
self.show_threshold = min(max(value, 0.0), 1.0) # 确保在 [0.0, 1.0] 范围内
|
|
649
|
+
|
|
650
|
+
def update_top_num(self, *args):
|
|
651
|
+
value = int(self.getTopNumVar.get())
|
|
652
|
+
self.top_num = max(value, 1) # 确保 top_num 至少为 1
|
|
653
|
+
|
|
654
|
+
|
|
655
|
+
|
|
656
|
+
|
|
657
|
+
class DeviceSelectorGUI:
|
|
581
658
|
def __init__(self, root):
|
|
582
659
|
self.root = root
|
|
660
|
+
self.device = "cpu"
|
|
583
661
|
|
|
584
|
-
|
|
585
|
-
|
|
662
|
+
self.deviceLabel = ttk.Label(self.root, text="Select Device:", width=15)
|
|
663
|
+
self.deviceLabel.grid(row=8, column=0, padx=10, pady=10)
|
|
664
|
+
|
|
665
|
+
self.selectedOption = tk.IntVar(value=1)
|
|
666
|
+
|
|
667
|
+
if torch.cuda.is_available():
|
|
668
|
+
self.selectedOption.set(2)
|
|
669
|
+
self.device = "cuda"
|
|
670
|
+
|
|
671
|
+
self.cpuLabel = ttk.Radiobutton(self.root,
|
|
672
|
+
text='CPU',
|
|
673
|
+
variable=self.selectedOption,
|
|
674
|
+
value=1,
|
|
675
|
+
command=self.select_cpu)
|
|
676
|
+
self.cpuLabel.grid(row=8, column=1, padx=10, pady=10, sticky="w")
|
|
677
|
+
|
|
678
|
+
self.gpuLabel = ttk.Radiobutton(self.root,
|
|
679
|
+
text='GPU',
|
|
680
|
+
variable=self.selectedOption,
|
|
681
|
+
value=2,
|
|
682
|
+
command=self.select_gpu)
|
|
683
|
+
self.gpuLabel.grid(row=8, column=2, padx=10, pady=10, sticky="w")
|
|
684
|
+
|
|
685
|
+
def select_cpu(self):
|
|
686
|
+
self.selectedOption.set(1)
|
|
687
|
+
self.device = "cpu"
|
|
688
|
+
|
|
689
|
+
def select_gpu(self):
|
|
690
|
+
if torch.cuda.is_available():
|
|
691
|
+
self.selectedOption.set(2)
|
|
692
|
+
self.device = "cuda"
|
|
693
|
+
else:
|
|
694
|
+
messagebox.showinfo("Message title", "CUDA is not available. Please select CPU.")
|
|
695
|
+
self.select_cpu()
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
class XTTMP_GUI:
|
|
699
|
+
def __init__(self, root):
|
|
700
|
+
self.root = root
|
|
701
|
+
|
|
702
|
+
windowHeight = 550
|
|
703
|
+
windowWidth = 510
|
|
586
704
|
|
|
587
705
|
startHeight = (root.winfo_screenheight() - windowHeight) // 2
|
|
588
706
|
startWidth = (root.winfo_screenwidth() - windowWidth) // 2
|
|
@@ -592,13 +710,16 @@ class ModelAndInputSelectorGUI:
|
|
|
592
710
|
self._set_window_icon()
|
|
593
711
|
|
|
594
712
|
self.objModelSelector = ModelSelectorGUI(root)
|
|
713
|
+
self.objPostProcessingSelector = PostProcessingSelectorGUI(root)
|
|
595
714
|
self.objInputSelector = InputSelectorGUI(root)
|
|
715
|
+
self.objDeviceSelector = DeviceSelectorGUI(root)
|
|
716
|
+
|
|
596
717
|
|
|
597
|
-
self.btnRun = ttk.Button(self.root, text="Run", command=self._run)
|
|
598
|
-
self.btnRun.place(x = 20, y=300)
|
|
599
718
|
self.btnStepping = ttk.Button(self.root, text="Stepping", command=self._stepping)
|
|
600
|
-
self.btnStepping.place(x = 20, y=270)
|
|
601
719
|
self.isStepping = False
|
|
720
|
+
self.btnStepping.grid(row=9, column=2, padx=10, pady=10, sticky='e')
|
|
721
|
+
self.btnRun = ttk.Button(self.root, text="Run", command=self._run)
|
|
722
|
+
self.btnRun.grid(row=10, column=2, padx=10, pady=10, sticky='e')
|
|
602
723
|
|
|
603
724
|
def create_gui(self):
|
|
604
725
|
self.objModelSelector.create_gui(ALL_MODEL)
|
|
@@ -607,9 +728,13 @@ class ModelAndInputSelectorGUI:
|
|
|
607
728
|
self.root.mainloop()
|
|
608
729
|
|
|
609
730
|
if self.objInputSelector.selectedOption.get() == 1:
|
|
610
|
-
return self.modelName, self.vidName, None, self.isStepping
|
|
731
|
+
return (self.modelName, self.vidName, None, self.isStepping,
|
|
732
|
+
self.objDeviceSelector.device, self.objPostProcessingSelector.get_post_processing(),
|
|
733
|
+
self.objPostProcessingSelector.show_threshold)
|
|
611
734
|
elif self.objInputSelector.selectedOption.get() == 2:
|
|
612
|
-
return self.modelName, self.startImgName, self.endImgName, self.isStepping
|
|
735
|
+
return (self.modelName, self.startImgName, self.endImgName, self.isStepping,
|
|
736
|
+
self.objDeviceSelector.device, self.objPostProcessingSelector.get_post_processing(),
|
|
737
|
+
self.objPostProcessingSelector.show_threshold)
|
|
613
738
|
|
|
614
739
|
def _run(self):
|
|
615
740
|
self.modelName = self.objModelSelector.modelCombobox.get()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|