xttmp 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xttmp/__init__.py +1 -0
- xttmp/api/__init__.py +5 -0
- xttmp/api/evaluate.py +163 -0
- xttmp/api/get_visualize_handle.py +29 -0
- xttmp/api/instancing_model.py +35 -0
- xttmp/core/__init__.py +0 -0
- xttmp/core/apgstmd_core.py +188 -0
- xttmp/core/apgstmdv2_core.py +79 -0
- xttmp/core/base_core.py +36 -0
- xttmp/core/dstmd_core.py +213 -0
- xttmp/core/estmd_backbone.py +110 -0
- xttmp/core/estmd_core.py +356 -0
- xttmp/core/feedbackstmd_core.py +61 -0
- xttmp/core/fracstmd_core.py +98 -0
- xttmp/core/fstmd_core.py +15 -0
- xttmp/core/fstmdv2_core.py +42 -0
- xttmp/core/haarstmd_core.py +140 -0
- xttmp/core/math_operator.py +307 -0
- xttmp/core/stfeedbackstmd_core.py +233 -0
- xttmp/core/stmdplus_core.py +187 -0
- xttmp/core/stmdplusv2_core.py +82 -0
- xttmp/core/vstmd_core.py +420 -0
- xttmp/demo/evaluate_model.py +92 -0
- xttmp/demo/inference_gui.py +148 -0
- xttmp/demo/inference_gui_single_process.py +134 -0
- xttmp/demo/inference_image_stream.py +67 -0
- xttmp/demo/inference_video.py +66 -0
- xttmp/main.py +14 -0
- xttmp/model/__init__.py +13 -0
- xttmp/model/backbone.py +514 -0
- xttmp/model/facilitated_model.py +230 -0
- xttmp/model/feedback_model.py +271 -0
- xttmp/model/haarstmd.py +61 -0
- xttmp/model/vstmd.py +457 -0
- xttmp/util/__init__.py +0 -0
- xttmp/util/compute_module.py +402 -0
- xttmp/util/create_kernel.py +363 -0
- xttmp/util/evaluate_module.py +697 -0
- xttmp/util/iostream.py +660 -0
- xttmp-2.3.0.dist-info/METADATA +85 -0
- xttmp-2.3.0.dist-info/RECORD +45 -0
- xttmp-2.3.0.dist-info/WHEEL +5 -0
- xttmp-2.3.0.dist-info/entry_points.txt +2 -0
- xttmp-2.3.0.dist-info/licenses/LICENSE +201 -0
- xttmp-2.3.0.dist-info/top_level.txt +1 -0
xttmp/core/vstmd_core.py
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
from scipy.optimize import linear_sum_assignment
|
|
5
|
+
|
|
6
|
+
from .base_core import BaseCore
|
|
7
|
+
from .math_operator import SpatialInhibition
|
|
8
|
+
from . import fracstmd_core
|
|
9
|
+
from ..util.create_kernel import create_2d_gaussian_kernel
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Lamina(fracstmd_core.Lamina):
    """Lamina layer: rectifies the parent lamina output into ON and OFF channels."""

    def forward(self, LaminaIpt):
        """Run the inherited lamina filter, then split its output by sign.

        Args:
            LaminaIpt: input passed unchanged to ``fracstmd_core.Lamina.forward``.

        Returns:
            tuple: ``(lamina_ON, lamina_OFF)`` where ``lamina_ON`` keeps the positive
            part of the parent output and ``lamina_OFF`` the positive part of its
            negation. The tuple is also stored on ``self.output``.
        """
        diff = super().forward(LaminaIpt)
        # Half-wave rectification in both directions: positive values -> ON,
        # negative values (sign-flipped) -> OFF.
        on_part = torch.clamp(diff, min=0)
        off_part = torch.clamp(-diff, min=0)
        self.output = (on_part, off_part)
        return self.output
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class cIDP(BaseCore):
    """cross-Inhibition Dynamics Potentials (cIDP).

    Keeps a membrane potential ``post_MP`` that is advanced by one forward-Euler
    step per call:

        post_MP <- post_MP + exp(oppo) * g_leak * (v_rest - post_MP)
                           + same * (v_exci - post_MP)
    """

    def __init__(self):
        super().__init__()

        self.g_leak = 0.5  # coefficient of decay
        self.v_rest = 0  # passive/rest potential
        self.v_exci = 1  # excitatory saturation potential
        self.reset_buffer()

    def reset_buffer(self):
        """Clear the membrane potential; the next call starts again from zero."""
        self.post_MP = None

    def forward(self, same_polarity, oppo_polarity):
        """Advance the membrane potential by one Euler step.

        Args:
            same_polarity: tensor providing the excitatory drive; its shape defines
                the state shape.
            oppo_polarity: tensor whose exponential scales the decay term; must
                broadcast against ``same_polarity``.

        Returns:
            Tensor: the updated membrane potential (also kept in ``self.post_MP``).
        """
        # (Re)initialise the state on first use or when the input geometry changes.
        if self.post_MP is None or self.post_MP.shape != same_polarity.shape:
            self.post_MP = torch.zeros_like(same_polarity)

        # Decay towards the rest potential
        decay_term = self.g_leak * (self.v_rest - self.post_MP)

        # Inhibition from the opposite polarity scales the decay term
        inhi_gain = torch.exp(oppo_polarity)

        # Excitation towards the saturation potential
        exci_term = same_polarity * (self.v_exci - self.post_MP)

        # Euler step. Done out of place: the previous state tensor was returned to
        # the caller, so an in-place update would silently rewrite values they hold.
        self.post_MP = self.post_MP + inhi_gain * decay_term + exci_term

        return self.post_MP
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class Medulla(BaseCore):
    """Medulla layer of the motion detection system.

    Two cIDP units receive the lamina channels crosswise: the ON unit is driven
    by ``lamina_ON`` with ``lamina_OFF`` as its opposite-polarity input, and the
    OFF unit the other way round.
    """

    def __init__(self):
        super().__init__()
        # Initialize components
        self.on_pathway = cIDP()
        self.off_pathway = cIDP()

    def setup(self):
        """Call ``setup`` on both cIDP pathways."""
        self.on_pathway.setup()
        self.off_pathway.setup()

    def forward(self, lamina_ON, lamina_OFF):
        """Process the lamina ON/OFF signals through the Medulla layer.

        Args:
            lamina_ON: ON-channel output of the lamina layer.
            lamina_OFF: OFF-channel output of the lamina layer.

        Returns:
            tuple: ``(medulla_ON, medulla_OFF)``, the membrane potentials of the ON
            and OFF cIDP pathways. The tuple is also stored on ``self.output``.
        """
        medulla_ON = self.on_pathway.forward(lamina_ON, lamina_OFF)  # ON
        medulla_OFF = self.off_pathway.forward(lamina_OFF, lamina_ON)  # OFF

        # Store the output signals
        self.output = (medulla_ON, medulla_OFF)

        return self.output
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class CDGC(BaseCore):
    """Collaborative Directional Encoding-Decoding (CDGC), pure PyTorch.

    Builds a per-pixel ON/OFF ratio map, filters it with cosine and sine
    direction kernels and returns the resulting angle map.
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size

        # Placeholders; filled by setup() so they follow the module across devices.
        self.register_buffer('corr_kernel_cos', torch.empty(0))
        self.register_buffer('corr_kernel_sin', torch.empty(0))

        self.setup()

    def setup(self):
        """(Re)build the directional kernels and store them in the registered buffers."""
        cos_kernel, sin_kernel = self._create_directional_kernels()
        self.corr_kernel_cos.data = cos_kernel
        self.corr_kernel_sin.data = sin_kernel

    def _create_directional_kernels(self):
        """Return the cosine and sine receptive-field kernels, each shaped (1, 1, K, K)."""
        k = self.kernel_size
        mid = k // 2

        # 1-D offsets relative to the kernel middle, e.g. [-1, 0, 1] for K = 3.
        offsets = torch.arange(k, dtype=torch.float32) - mid
        # 'ij' indexing: the first grid varies along rows (y), the second along columns (x).
        grid_y, grid_x = torch.meshgrid(offsets, offsets, indexing='ij')

        radius = torch.sqrt(grid_x ** 2 + grid_y ** 2)
        # Only the centre has radius 0; use 1 there to avoid 0/0, then zero it below.
        radius = torch.where(radius == 0, torch.ones_like(radius), radius)

        cos_kernel = grid_x / radius
        sin_kernel = -grid_y / radius  # row index grows downwards, so flip y
        cos_kernel[mid, mid] = 0.0
        sin_kernel[mid, mid] = 0.0

        return cos_kernel.view(1, 1, k, k), sin_kernel.view(1, 1, k, k)

    def forward(self, medulla_on, medulla_off, lamina_on, lamina_off):
        """Estimate a direction angle per pixel.

        Inputs are 4-D tensors whose dimension 1 is the channel axis (used as the
        depthwise-convolution group count).

        Returns:
            Tensor: angle map in radians, negative atan2 results shifted by 2*pi.
            Also stored on ``self.output``; the ratio map is kept on
            ``self.direction_gradient``.
        """
        channels = medulla_on.shape[1]
        k = self.kernel_size

        # --- ratio map, written only where a pathway is active (no global division) ---
        ratio = torch.zeros_like(medulla_on)
        on_active = (lamina_on > 0) & (medulla_on > 0)
        off_active = (lamina_off > 0) & (medulla_off > 0)
        ratio[on_active] = medulla_off[on_active] / medulla_on[on_active]
        ratio[off_active] = medulla_on[off_active] / medulla_off[off_active]

        # --- depthwise directional filtering; expand() shares the single-channel kernel ---
        cos_weight = self.corr_kernel_cos.expand(channels, 1, k, k)
        sin_weight = self.corr_kernel_sin.expand(channels, 1, k, k)
        cos_resp = F.conv2d(ratio, cos_weight, padding='same', groups=channels)
        sin_resp = F.conv2d(ratio, sin_weight, padding='same', groups=channels)

        # --- angle, moved from (-pi, pi] to the non-negative range ---
        angle = torch.atan2(sin_resp, cos_resp)
        self.output = torch.where(angle < 0, angle + 2 * torch.pi, angle)

        self.direction_gradient = ratio

        return self.output
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
class Lobula(BaseCore):
    """Lobula layer of the motion detection system (localisation + direction)."""

    def __init__(self):
        super().__init__()
        self.spatial_inhibition = SpatialInhibition(B=3, e=3, sigma1=5, sigma2=10)
        self.cdgc = CDGC()

        self.setup()

    def setup(self):
        """Initialise the spatial-inhibition and CDGC sub-modules."""
        self.spatial_inhibition.setup()
        self.cdgc.setup()

    def forward(self, medulla_on, medulla_off, lamina_on, lamina_off):
        """Locate targets and estimate their motion direction.

        Args:
            medulla_on: ON channel signal from the medulla layer (tensor).
            medulla_off: OFF channel signal from the medulla layer (tensor).
            lamina_on: ON channel signal from the lamina layer (tensor).
            lamina_off: OFF channel signal from the lamina layer (tensor).

        Returns:
            tuple: ``(lobula_output, direction)``
                - lobula_output: ON x OFF correlation after spatial inhibition
                  (location output).
                - direction: direction map produced by ``self.cdgc``.
            The tuple is also stored on ``self.output``. The correlation before
            inhibition is kept on ``self.correlation_output`` (it is not returned).
        """
        self.correlation_output = medulla_on * medulla_off
        lobula_output = self.spatial_inhibition.forward(self.correlation_output)

        direction = self.cdgc.forward(medulla_on, medulla_off, lamina_on, lamina_off)

        self.output = (lobula_output, direction)

        return self.output
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
class Lobula_with_Feedback(BaseCore):
    """Lobula layer with a one-step-delayed feedback loop for localisation.

    Each call subtracts the feedback recorded on the previous call from the
    medulla ON/OFF signals before correlating them; direction comes from CDGC.
    """

    def __init__(self):
        """Constructor method."""
        super().__init__()
        self.spatial_inhibition = SpatialInhibition()  # SpatialInhibition component
        self.cdgc = CDGC()

        self.beta = 1  # gain applied to the recorded feedback signal
        self.sigma = 1.5  # sigma of the 3x3 Gaussian kernel

        self.register_buffer('gaussian_kernel', torch.empty(0))  # Buffer for Gaussian blur kernel

        self.setup()

    def setup(self):
        """Build the Gaussian kernel, set up the sub-modules and clear the feedback state."""
        super().setup()

        self.gaussian_kernel.data = create_2d_gaussian_kernel(size=3,
                                                              sigma=self.sigma)

        self.spatial_inhibition.setup()
        self.cdgc.setup()

        self.reset_buffer()

    def reset_buffer(self):
        """Reset the buffer for feedback signal."""
        self.feedback_signal = None

    def forward_localization(self, medulla_ON, medulla_OFF):
        """Compute the localisation response and record feedback for the next call.

        Args:
            medulla_ON: ON channel signal from the medulla layer.
            medulla_OFF: OFF channel signal from the medulla layer.

        Returns:
            Tensor: spatially inhibited correlation of the feedback-corrected signals.
        """
        # Zero feedback only on the first call (or after the input geometry changes);
        # afterwards use the signal recorded at the end of the previous call.
        if self.feedback_signal is None or self.feedback_signal.shape != medulla_ON.shape:
            self.feedback_signal = torch.zeros_like(medulla_ON)

        # Formula (8)
        self.v_on = torch.clamp(medulla_ON - self.feedback_signal, min=0)
        self.v_off = torch.clamp(medulla_OFF - self.feedback_signal, min=0)
        correlationD = self.v_on * self.v_off

        # Formula (10)
        correlationE = F.conv2d(medulla_ON * medulla_OFF, self.gaussian_kernel, padding='same')

        # Formula (9): only (correlationD + correlationE) is recorded; it is consumed
        # as the delayed feedback on the next call.
        self.feedback_signal = self.beta * (correlationD + correlationE)

        # Formula (14)
        response = self.spatial_inhibition.forward(correlationD)
        return response

    def forward(self, medulla_ON, medulla_OFF, lamina_ON, lamina_OFF):
        """Return ``(response, direction)``: the localisation output and the CDGC direction map."""
        response = self.forward_localization(medulla_ON, medulla_OFF)
        direction = self.cdgc.forward(medulla_ON, medulla_OFF, lamina_ON, lamina_OFF)

        self.output = (response, direction)
        return self.output
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
class FastEuclideanTracker:
    """Fast multi-target tracker based on Euclidean distance and Hungarian matching (PyTorch).

    Detections are rows of a 2-D tensor whose last two columns are the (y, x)
    position. Any leading columns are kept in the stored track but ignored for
    matching.
    """

    def __init__(self, max_distance=5.0, max_unmatched=5, device='cpu'):
        """Initialise the tracker.

        Args:
            max_distance (float): largest Euclidean distance accepted for a match.
            max_unmatched (int): a track that is missed again after already having
                ``max_unmatched`` consecutive misses is deleted.
            device (str): stored as ``self.device``; ``update`` creates its outputs
                on the device of the detections it receives.
        """
        self.next_track_id = 0
        # {track_id: {'center': detection row (..., y, x), 'unmatched_count': int,
        #             'direction': float | None (absent until the first match)}}
        self.tracks = {}
        self.max_distance = max_distance
        self.max_unmatched = max_unmatched
        self.device = device

    def _new_track(self, center):
        """Start a new track at ``center`` and advance the id counter."""
        self.tracks[self.next_track_id] = {'center': center, 'unmatched_count': 0}
        self.next_track_id += 1

    def update(self, response):
        """Match detections to tracks and update the tracks.

        Args:
            response (torch.Tensor): detections, shape (D, K) with K >= 2; the last
                two columns are (y, x).

        Returns:
            torch.Tensor: float32 tensor of shape (N, 3) with rows (y, x, direction)
            for every live track that has a direction. A track keeps its last position
            and direction while it is unmatched. The direction is
            ``atan2(-dy, dx)`` of the last matched displacement, wrapped with
            ``torch.remainder`` into [0, 2*pi]. Tracks whose last match did not move
            have no direction and are omitted.
        """
        device = response.device
        track_ids = list(self.tracks)

        # === 1. No tracks yet: every detection seeds a track (direction needs 2 frames) ===
        if not track_ids:
            for det in response:
                self._new_track(det)
            return torch.empty((0, 3), device=device)

        track_centers = torch.stack([self.tracks[tid]['center'] for tid in track_ids])

        # === 2./3. Cost matrix on (y, x) and Hungarian assignment ===
        matched_pairs = []
        if len(response) > 0:
            # Both sides use the last two columns so any row width >= 2 works.
            cost_matrix = torch.cdist(track_centers[:, -2:].float(),
                                      response[:, -2:].float(), p=2)  # (T, D)
            # PyTorch has no assignment solver; solve it on the CPU with SciPy.
            cost_np = cost_matrix.cpu().numpy()
            rows, cols = linear_sum_assignment(cost_np)
            matched_pairs = [(int(t), int(d)) for t, d in zip(rows, cols)
                             if cost_np[t, d] <= self.max_distance]

        # === 4. Update matched tracks ===
        if matched_pairs:
            past_centers = track_centers[[t for t, _ in matched_pairs]]
            curr_centers = response[[d for _, d in matched_pairs]]

            dy = curr_centers[:, -2] - past_centers[:, -2]
            dx = curr_centers[:, -1] - past_centers[:, -1]
            # y grows downwards in image coordinates, so negate it for the angle.
            angles = torch.remainder(torch.atan2(-dy, dx), 2 * torch.pi)
            # A float tensor cannot hold None: mark stationary targets with NaN.
            angles[(dx == 0) & (dy == 0)] = float('nan')

            for k, (t_i, _) in enumerate(matched_pairs):
                track = self.tracks[track_ids[t_i]]
                track['center'] = curr_centers[k]
                track['unmatched_count'] = 0
                ang = angles[k].item()
                track['direction'] = None if math.isnan(ang) else ang

        # === 5. Unmatched tracks: count the miss or delete ===
        matched_tracks = {t for t, _ in matched_pairs}
        for t_i, tid in enumerate(track_ids):
            if t_i in matched_tracks:
                continue
            if self.tracks[tid]['unmatched_count'] >= self.max_unmatched:
                del self.tracks[tid]
            else:
                self.tracks[tid]['unmatched_count'] += 1

        # === 6. Unmatched detections start new tracks ===
        matched_dets = {d for _, d in matched_pairs}
        for d_i in range(len(response)):
            if d_i not in matched_dets:
                self._new_track(response[d_i])

        # === 7. Output (y, x, direction) as one float32 tensor ===
        rows_out = [
            [float(info['center'][-2]), float(info['center'][-1]), info['direction']]
            for info in self.tracks.values()
            if info.get('direction') is not None
        ]
        if not rows_out:
            return torch.empty((0, 3), device=device)
        return torch.tensor(rows_out, dtype=torch.float32, device=device)
|
|
419
|
+
|
|
420
|
+
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
# demo_vidstream
|
|
2
|
+
import os
|
|
3
|
+
import sys
|
|
4
|
+
import matplotlib.pyplot as plt
|
|
5
|
+
import numpy as np
|
|
6
|
+
import json
|
|
7
|
+
|
|
8
|
+
filePath = os.path.realpath(__file__)
# Strip three path components (file name, 'demo', package directory) to reach
# the directory that contains the package, and make it importable.
pyPackagePath = filePath
for _level in range(3):
    pyPackagePath = os.path.dirname(pyPackagePath)
gitCodePath = os.path.dirname(pyPackagePath)  # one level above that directory
sys.path.append(pyPackagePath)
|
|
12
|
+
|
|
13
|
+
from smalltargetmotiondetectors.api import inference_task, evaluate_task # type: ignore
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def inference_and_evaluate_task(modelName,
                                inputpath,
                                inputType='ImgstreamReader',
                                groundTruth=None,
                                gTError=1,
                                startFrame=0,
                                endFrame=None,
                                savePath1=None,
                                savePath2=None,
                                **kwargs):
    """Run a model over an input sequence, then evaluate it against ground truth.

    Args:
        modelName: model identifier passed to ``inference_task``.
        inputpath: input location passed to ``inference_task``.
        inputType: reader type passed to ``inference_task``.
        groundTruth: ground truth passed to ``evaluate_task``.
        gTError: tolerance passed to ``evaluate_task``.
        startFrame, endFrame: frame range used for both inference and evaluation.
        savePath1: JSON path for the inference outputs; nothing is saved when None.
        savePath2: JSON path for the evaluation results; nothing is saved when None.
        **kwargs: forwarded to ``inference_task``.

    Returns:
        tuple: ``(rocFig, AUC, mR)`` as produced by ``evaluate_task``.

    Side effects:
        Writes the ROC figure to 'roc_curve.png' in the current working directory.
    """
    # --- inference ---
    modelOpt, modelDire = inference_task(modelName, inputpath, inputType, startFrame, endFrame, **kwargs)
    # The None defaults previously crashed inside save_as_json; treat None as "don't save".
    if savePath1 is not None:
        save_as_json(savePath1, modelOpt, modelDire)

    # --- evaluation ---
    rocFig, AUC, mR, RPIList, FPPIList, thresholdList = evaluate_task(modelOpt, groundTruth, gTError, startFrame, endFrame)
    rocFig.savefig('roc_curve.png')  # Save as PNG file
    if savePath2 is not None:
        save_as_json(savePath2, AUC, mR, RPIList, FPPIList, thresholdList)

    return rocFig, AUC, mR
|
|
39
|
+
|
|
40
|
+
def save_as_json(file_name='output.json', *args):
    """
    Save multiple arguments as a JSON file.

    Parameters:
    - file_name (str | None): target path; '.json' is appended if missing.
      None falls back to 'output.json'.
    - *args: values to store, under the keys "data_1", "data_2", ... in order.
      Objects exposing a ``tolist()`` method (e.g. NumPy arrays, torch tensors)
      are converted with it; anything else must be natively JSON-serialisable.

    Raises:
    - TypeError: if a value cannot be serialised.
    """
    def _to_builtin(obj):
        # Deliberate conversion for array-likes only; everything else stays an error.
        tolist = getattr(obj, 'tolist', None)
        if callable(tolist):
            return tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    if file_name is None:
        file_name = 'output.json'

    # Generate unique keys for each argument
    data = {f"data_{i}": arg for i, arg in enumerate(args, start=1)}

    # Ensure the file extension is '.json'
    if not file_name.endswith('.json'):
        file_name += '.json'

    # Save data to JSON file
    with open(file_name, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4, default=_to_builtin)
|
|
63
|
+
|
|
64
|
+
if __name__ == '__main__':
    # Machine-specific locations from the author's environment; adjust before running.
    result_dir = os.path.join('C:\\Users\\mings\\Desktop', 'temp_result')
    with open(os.path.join(result_dir, 'gt.json'), 'r') as file:
        data = json.load(file)

    run_args = dict(
        modelName='ESTMD',
        inputpath=os.path.join(
            'D:\\STMD_Dataset\\PanoramaStimuli\\BV-250-Leftward',
            'SingleTarget-TW-5-TH-5-TV-300-TL-0-Rightward-Amp-15-Theta-0-TemFre-2-SamFre-1000',
            'PanoramaStimuli*.tif'),
        inputType='ImgstreamReader',
        groundTruth=data['groundTruth'],
        gTError=1,
        startFrame=1,
        endFrame=500,
        savePath1=os.path.join(result_dir, 'opt1.json'),
        savePath2=os.path.join(result_dir, 'opt2.json'),
        sigma1=1,  # forwarded to inference_task through **kwargs
    )
    rocFig, AUC, mR = inference_and_evaluate_task(**run_args)

    plt.show()
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import logging
|
|
4
|
+
import time
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
import tkinter as tk
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
file_path = os.path.realpath(__file__)
# Walk three levels up from this file so the directory containing the package
# is importable when the demo is run as a plain script.
py_pkg_path = file_path
for _depth in range(3):
    py_pkg_path = os.path.dirname(py_pkg_path)
sys.path.append(py_pkg_path)
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
from xttmp.util.iostream import ( # type: ignore
|
|
17
|
+
ModelAndInputSelectorGUI,
|
|
18
|
+
FrameIterator,
|
|
19
|
+
FrameVisualizer,
|
|
20
|
+
)
|
|
21
|
+
from xttmp.util.compute_module import PostProcessing # type: ignore
|
|
22
|
+
from xttmp.api import ( # type: ignore
|
|
23
|
+
instancing_model,
|
|
24
|
+
)
|
|
25
|
+
except ImportError as e:
|
|
26
|
+
raise ImportError("Failed to import required modules. "
|
|
27
|
+
"Ensure that the 'xttmp' package is correctly installed.") from e
|
|
28
|
+
|
|
29
|
+
# configure logging
|
|
30
|
+
# configure logging: INFO and above, one timestamped line per record
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                    level=logging.INFO)
logger = logging.getLogger(name=__name__)
|
|
33
|
+
|
|
34
|
+
class StmdGui:
    """Interactive STMD demo: pick a model and an input, run it frame by frame, show detections."""

    def __init__(self, device='cpu', show_threshold: float = 0.8, get_top_num: int = 1):
        """Initialize STMD GUI.

        Args:
            device: torch device used for the frame reader, model and post-processing.
            show_threshold: passed to ``FrameVisualizer`` as ``conf_threshold``.
            get_top_num: passed to ``PostProcessing``.
        """
        self.device = device
        self.show_threshold = show_threshold
        self.get_top_num = get_top_num
        # Collaborators are kept as attributes so they can be swapped (e.g. in tests).
        self.ModelAndInputSelectorGUI = ModelAndInputSelectorGUI
        self.FrameIterator = FrameIterator
        self.FrameVisualizer = FrameVisualizer
        self.PostProcessing = PostProcessing
        self.instancing_model = instancing_model

    @staticmethod
    def _needs_cuda_sync(device) -> bool:
        """Return True when ``device`` is a CUDA device and CUDA is available.

        Accepts device strings ('cpu', 'cuda', 'cuda:1', ...) and ``torch.device`` objects.
        """
        dev = device if isinstance(device, torch.device) else torch.device(device)
        return dev.type == 'cuda' and torch.cuda.is_available()

    def _get_user_input(self) -> tuple:
        """Show the model/input selector dialog and return its result."""
        root = tk.Tk()
        try:
            gui = self.ModelAndInputSelectorGUI(root)
            return gui.create_gui()
        finally:
            # Safe teardown: create_gui() may already have destroyed the window
            # (e.g. internally after the user confirms), so a TclError from a
            # second destroy() is expected and ignored.
            try:
                root.destroy()
            except tk.TclError:
                pass

    def _create_frame_reader(self, opt1: str, opt2: Optional[str]):
        """Create a frame reader for a video file (``opt2`` is None) or an image range.

        For an image range, ``opt1`` and ``opt2`` are the first/last image paths (in
        either order); both must be found by file name in the reader's image list.

        Raises:
            ValueError: if either end of the image range is not in the folder.
        """
        if opt2 is None:
            return self.FrameIterator(opt1, is_video=True, device=self.device)

        reader = self.FrameIterator(os.path.dirname(opt1), is_video=False, device=self.device)
        start_name = os.path.basename(opt1)
        end_name = os.path.basename(opt2)

        start_index = next((i for i, path in enumerate(reader.image_files)
                            if os.path.basename(path) == start_name), None)
        end_index = next((i for i, path in enumerate(reader.image_files)
                          if os.path.basename(path) == end_name), None)

        if start_index is None or end_index is None:
            raise ValueError("Selected image range could not be located in the folder.")

        if start_index > end_index:
            start_index, end_index = end_index, start_index

        reader._setup(start_index)
        reader.total_frames = end_index + 1
        return reader

    def run(self):
        """Run the video processor until the input is exhausted or the viewer stops."""
        reader = None
        visualizer = None
        try:
            user_input = self._get_user_input()
            if not user_input:
                logger.info("User cancelled input.")
                return

            model_name, opt1, opt2, is_stepping = user_input
            reader = self._create_frame_reader(opt1, opt2)
            model = self.instancing_model(model_name, device=self.device)
            post_processor = self.PostProcessing(
                device=self.device,
                nms_radio=8,
                get_top_num=self.get_top_num,
            )

            visualizer = self.FrameVisualizer(
                window_name=model_name,
                result_index_type="dots",
                win_width=reader.img_width,
                win_height=reader.img_height,
                conf_threshold=self.show_threshold,
            )
            if is_stepping:
                visualizer.paused = True

            # Synchronise only for CUDA devices so timings cover the queued GPU work.
            sync = self._needs_cuda_sync(self.device)

            # Pure inference: no autograd graph is needed for any frame.
            with torch.no_grad():
                while True:
                    color_img, gray_tensor, is_valid = reader.get_next_frame()
                    if not is_valid:
                        break

                    if sync:
                        torch.cuda.synchronize()
                    time_start = time.perf_counter()
                    result = model(gray_tensor)
                    if sync:
                        torch.cuda.synchronize()
                    run_time = time.perf_counter() - time_start

                    dots = post_processor(result['response'], result.get('direction'))
                    if not visualizer.update(color_img, result=dots, process_time=run_time):
                        break

        except Exception as e:
            # Top-level boundary of the demo: log with the traceback instead of crashing.
            logger.exception("Main process error: %s", e)
        finally:
            logger.info("Cleaning up resources...")
            try:
                if visualizer is not None:
                    visualizer.close()
            finally:
                # Release the reader even if closing the visualizer fails.
                if reader is not None:
                    reader.release()
            logger.info("Shutdown completed")
|
|
141
|
+
|
|
142
|
+
def main(show_threshold: float = 0, get_top_num: int = 10):
    """Run the STMD GUI on CUDA when available, otherwise on the CPU."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    app = StmdGui(device, show_threshold=show_threshold, get_top_num=get_top_num)
    app.run()
|
|
146
|
+
|
|
147
|
+
if __name__ == "__main__":
    # Launch the demo; get_top_num is passed through to PostProcessing.
    main(get_top_num=20)
|