xttmp 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. xttmp/__init__.py +1 -0
  2. xttmp/api/__init__.py +5 -0
  3. xttmp/api/evaluate.py +163 -0
  4. xttmp/api/get_visualize_handle.py +29 -0
  5. xttmp/api/instancing_model.py +35 -0
  6. xttmp/core/__init__.py +0 -0
  7. xttmp/core/apgstmd_core.py +188 -0
  8. xttmp/core/apgstmdv2_core.py +79 -0
  9. xttmp/core/base_core.py +36 -0
  10. xttmp/core/dstmd_core.py +213 -0
  11. xttmp/core/estmd_backbone.py +110 -0
  12. xttmp/core/estmd_core.py +356 -0
  13. xttmp/core/feedbackstmd_core.py +61 -0
  14. xttmp/core/fracstmd_core.py +98 -0
  15. xttmp/core/fstmd_core.py +15 -0
  16. xttmp/core/fstmdv2_core.py +42 -0
  17. xttmp/core/haarstmd_core.py +140 -0
  18. xttmp/core/math_operator.py +307 -0
  19. xttmp/core/stfeedbackstmd_core.py +233 -0
  20. xttmp/core/stmdplus_core.py +187 -0
  21. xttmp/core/stmdplusv2_core.py +82 -0
  22. xttmp/core/vstmd_core.py +420 -0
  23. xttmp/demo/evaluate_model.py +92 -0
  24. xttmp/demo/inference_gui.py +148 -0
  25. xttmp/demo/inference_gui_single_process.py +134 -0
  26. xttmp/demo/inference_image_stream.py +67 -0
  27. xttmp/demo/inference_video.py +66 -0
  28. xttmp/main.py +14 -0
  29. xttmp/model/__init__.py +13 -0
  30. xttmp/model/backbone.py +514 -0
  31. xttmp/model/facilitated_model.py +230 -0
  32. xttmp/model/feedback_model.py +271 -0
  33. xttmp/model/haarstmd.py +61 -0
  34. xttmp/model/vstmd.py +457 -0
  35. xttmp/util/__init__.py +0 -0
  36. xttmp/util/compute_module.py +402 -0
  37. xttmp/util/create_kernel.py +363 -0
  38. xttmp/util/evaluate_module.py +697 -0
  39. xttmp/util/iostream.py +660 -0
  40. xttmp-2.3.0.dist-info/METADATA +85 -0
  41. xttmp-2.3.0.dist-info/RECORD +45 -0
  42. xttmp-2.3.0.dist-info/WHEEL +5 -0
  43. xttmp-2.3.0.dist-info/entry_points.txt +2 -0
  44. xttmp-2.3.0.dist-info/licenses/LICENSE +201 -0
  45. xttmp-2.3.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,420 @@
1
+ import math
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from scipy.optimize import linear_sum_assignment
5
+
6
+ from .base_core import BaseCore
7
+ from .math_operator import SpatialInhibition
8
+ from . import fracstmd_core
9
+ from ..util.create_kernel import create_2d_gaussian_kernel
10
+
11
+
12
class Lamina(fracstmd_core.Lamina):
    """Lamina layer that splits the parent layer's output into ON/OFF channels."""

    def forward(self, LaminaIpt):
        """Half-wave rectify the output of ``fracstmd_core.Lamina.forward``.

        Args:
            LaminaIpt: Input passed unchanged to the parent ``forward``.

        Returns:
            tuple: ``(lamina_ON, lamina_OFF)``. ``lamina_ON`` is the positive
            part of the parent's output, ``lamina_OFF`` is the magnitude of
            its negative part; both are non-negative.
        """
        diff = super().forward(LaminaIpt)

        on_channel = diff.clamp(min=0)      # ON: keeps values above zero
        off_channel = (-diff).clamp(min=0)  # OFF: keeps values below zero, sign flipped

        self.output = (on_channel, off_channel)
        return self.output
22
+
23
+
24
class cIDP(BaseCore):
    """Cross-Inhibition Dynamics Potentials (cIDP).

    Keeps a per-element potential ``post_MP`` that is advanced by one explicit
    Euler step on every ``forward`` call. The same-polarity input drives the
    potential towards ``v_exci``; the opposite-polarity input scales the leak
    towards ``v_rest`` by ``exp(oppo_polarity)``.
    """

    def __init__(self):
        super().__init__()

        self.g_leak = 0.5  # coefficient of decay
        self.v_rest = 0    # passive/rest potential
        self.v_exci = 1    # excitatory saturation potential
        self.reset_buffer()

    def reset_buffer(self):
        """Drop the stored potential; the next ``forward`` restarts from zeros."""
        self.post_MP = None

    def forward(self, same_polarity, oppo_polarity):
        """Advance the stored potential by one step and return it.

        Args:
            same_polarity: Excitatory drive (tensor).
            oppo_polarity: Inhibitory drive (tensor); it enters as
                ``exp(oppo_polarity)``, a gain on the leak term.

        Returns:
            The updated potential. A new tensor is produced on every call.
        """
        if self.post_MP is None:
            self.post_MP = torch.zeros_like(same_polarity)

        # Decay towards the rest potential.
        decay_term = self.g_leak * (self.v_rest - self.post_MP)

        # Inhibition: the opposite polarity amplifies the decay.
        inhi_gain = torch.exp(oppo_polarity)

        # Excitation towards the saturation potential.
        exci_term = same_polarity * (self.v_exci - self.post_MP)

        # Euler step. Rebind instead of using `+=`: the tensor returned by the
        # previous call may still be held by a caller, and an in-place update
        # would silently overwrite that earlier result.
        self.post_MP = self.post_MP + (inhi_gain * decay_term + exci_term)

        return self.post_MP
55
+
56
+
57
class Medulla(BaseCore):
    """Medulla layer of the motion detection system.

    Holds one ``cIDP`` unit per polarity; each unit is excited by its own
    lamina channel and cross-inhibited by the other one.
    """

    def __init__(self):
        super().__init__()
        self.on_pathway = cIDP()
        self.off_pathway = cIDP()

    def setup(self):
        """Run ``setup`` on both pathways."""
        for pathway in (self.on_pathway, self.off_pathway):
            pathway.setup()

    def forward(self, lamina_ON, lamina_OFF):
        """Process the lamina channels through the two cIDP pathways.

        Args:
            lamina_ON: ON channel from the lamina layer.
            lamina_OFF: OFF channel from the lamina layer.

        Returns:
            tuple: ``(medulla_ON, medulla_OFF)``, the outputs of the ON and
            OFF pathways respectively. Also stored in ``self.output``.
        """
        # Each pathway takes (same polarity, opposite polarity).
        on_response = self.on_pathway.forward(lamina_ON, lamina_OFF)
        off_response = self.off_pathway.forward(lamina_OFF, lamina_ON)

        self.output = (on_response, off_response)
        return self.output
95
+
96
+
97
class CDGC(BaseCore):
    """
    Collaborative Directional Encoding-Decoding (CDGC)
    Pure PyTorch Implementation.

    Builds a per-pixel ON/OFF ratio map and convolves it with a cosine and a
    sine directional kernel; the angle of the resulting (cos, sin) pair is the
    direction estimate in ``[0, 2*pi)``.
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size

        # Empty placeholders; setup() installs the real kernels.
        self.register_buffer('corr_kernel_cos', torch.empty(0))
        self.register_buffer('corr_kernel_sin', torch.empty(0))

        self.setup()

    def setup(self):
        """(Re)generate the directional kernels and store them in the buffers.

        The kernels are created on CPU, so they are moved to whatever device
        the buffers currently live on. Without this, calling ``setup()`` after
        the module has been moved to another device would put the buffers back
        on CPU and break ``forward`` with a device mismatch.
        """
        cos_kernel, sin_kernel = self._create_directional_kernels()

        self.corr_kernel_cos.data = cos_kernel.to(self.corr_kernel_cos.device)
        self.corr_kernel_sin.data = sin_kernel.to(self.corr_kernel_sin.device)

    def _create_directional_kernels(self):
        """Build the cosine and sine receptive-field kernels.

        Returns:
            tuple: ``(cos_k, sin_k)``, float32 tensors of shape
            ``(1, 1, kernel_size, kernel_size)``.
        """
        # 1-D coordinates centred on zero, e.g. [-1, 0, 1] for kernel_size 3.
        coords = torch.arange(self.kernel_size, dtype=torch.float32) - self.kernel_size // 2

        # 'ij' indexing: y varies along rows, x along columns.
        y, x = torch.meshgrid(coords, coords, indexing='ij')

        # Euclidean distance of every cell from the centre.
        r = torch.sqrt(x**2 + y**2)

        # Avoid 0/0 at the centre; the centre weights are forced to 0 below.
        r[r == 0] = 1.0

        # Unit direction components. The sine is negated so that rows above
        # the centre (negative y) get a positive weight.
        cos_k = x / r
        sin_k = -y / r

        center = self.kernel_size // 2
        cos_k[center, center] = 0.0
        sin_k[center, center] = 0.0

        # Shape (out_channels=1, in_channels=1, H, W).
        cos_k = cos_k.view(1, 1, self.kernel_size, self.kernel_size)
        sin_k = sin_k.view(1, 1, self.kernel_size, self.kernel_size)

        return cos_k, sin_k

    def _depthwise_weight(self, kernel, like, channels):
        """Return ``kernel`` expanded to ``(channels, 1, K, K)`` for a depthwise conv.

        For floating-point inputs the kernel is first matched to the input's
        device and dtype. Non-float inputs leave the kernel untouched, because
        casting the fractional weights to an integer dtype would zero them.
        """
        if like.is_floating_point():
            kernel = kernel.to(device=like.device, dtype=like.dtype)
        # expand() creates a view, so no extra memory is allocated.
        return kernel.expand(channels, 1, self.kernel_size, self.kernel_size)

    def forward(self, medulla_on, medulla_off, lamina_on, lamina_off):
        """Estimate the per-pixel direction.

        All inputs are tensors of shape (B, C, H, W).

        Returns:
            Tensor of angles in radians, wrapped to ``[0, 2*pi)``. The
            intermediate ratio map is kept in ``self.direction_gradient``.
        """
        C = medulla_on.shape[1]

        # --- 1. Ratio map ---
        direction_gradient = torch.zeros_like(medulla_on)

        mask_on = (lamina_on > 0) & (medulla_on > 0)
        mask_off = (lamina_off > 0) & (medulla_off > 0)

        # Masked assignment divides only where the denominator is positive.
        # Where both masks hold, the OFF assignment (done last) wins.
        direction_gradient[mask_on] = medulla_off[mask_on] / medulla_on[mask_on]
        direction_gradient[mask_off] = medulla_on[mask_off] / medulla_off[mask_off]

        # --- 2. Directional convolution ---
        weight_cos = self._depthwise_weight(self.corr_kernel_cos, direction_gradient, C)
        weight_sin = self._depthwise_weight(self.corr_kernel_sin, direction_gradient, C)

        # groups=C: every channel is convolved independently.
        direction_cos = F.conv2d(direction_gradient, weight_cos, padding='same', groups=C)
        direction_sin = F.conv2d(direction_gradient, weight_sin, padding='same', groups=C)

        # --- 3. Angle ---
        direction = torch.atan2(direction_sin, direction_cos)

        # atan2 returns (-pi, pi]; shift negative angles up by 2*pi.
        self.output = torch.where(direction < 0, direction + 2 * torch.pi, direction)

        self.direction_gradient = direction_gradient

        return self.output
185
+
186
+
187
class Lobula(BaseCore):
    """Lobula layer of the motion detection system.

    Multiplies the ON and OFF medulla channels, applies spatial inhibition to
    that product, and obtains a direction map from ``CDGC``.
    """

    def __init__(self):
        super().__init__()
        self.spatial_inhibition = SpatialInhibition(B=3, e=3, sigma1=5, sigma2=10)
        self.cdgc = CDGC()

        self.setup()

    def setup(self):
        """Run ``setup`` on both sub-modules."""
        self.spatial_inhibition.setup()
        self.cdgc.setup()

    def forward(self, medulla_on, medulla_off, lamina_on, lamina_off):
        """Compute the localisation response and the direction map.

        Args:
            medulla_on: ON channel signal from the medulla layer.
            medulla_off: OFF channel signal from the medulla layer.
            lamina_on: ON channel signal from the lamina layer.
            lamina_off: OFF channel signal from the lamina layer.

        Returns:
            tuple: ``(lobula_output, direction)``. The ON*OFF product before
            inhibition is not returned; it is kept in
            ``self.correlation_output``.
        """
        correlation = medulla_on * medulla_off
        self.correlation_output = correlation

        located = self.spatial_inhibition.forward(correlation)
        heading = self.cdgc.forward(medulla_on, medulla_off, lamina_on, lamina_off)

        self.output = (located, heading)
        return self.output
230
+
231
+
232
class Lobula_with_Feedback(BaseCore):
    """Lobula layer with a one-frame-delayed feedback signal.

    The feedback computed on one call is subtracted from the medulla channels
    on the next call.
    """

    def __init__(self):
        """Create the sub-modules and the Gaussian-kernel buffer, then run setup()."""
        super().__init__()
        self.spatial_inhibition = SpatialInhibition()  # SpatialInhibition component
        self.cdgc = CDGC()

        self.beta = 1      # Gain applied to the feedback signal
        self.sigma = 1.5   # Sigma of the Gaussian kernel

        # Placeholder; setup() installs the actual Gaussian blur kernel.
        self.register_buffer('gaussian_kernel', torch.empty(0))

        self.setup()

    def setup(self):
        """Build the Gaussian kernel, set up the sub-modules and clear the feedback."""
        super().setup()

        self.gaussian_kernel.data = create_2d_gaussian_kernel(size=3,
                                                              sigma=self.sigma)

        self.spatial_inhibition.setup()
        self.cdgc.setup()

        self.reset_buffer()

    def reset_buffer(self):
        """Reset the buffer for the feedback signal."""
        self.feedback_signal = None

    def forward_localization(self, medulla_ON, medulla_OFF):
        """Compute the localisation response and update the feedback.

        Args:
            medulla_ON: ON channel from the medulla layer.
            medulla_OFF: OFF channel from the medulla layer.

        Returns:
            The spatially inhibited product of the feedback-corrected channels.
        """
        # Start from zero feedback only when none is stored (first frame or
        # after reset_buffer()) or when the input shape changed. Resetting on
        # every call would discard the feedback recorded below, so it would
        # never reach the next frame.
        if (self.feedback_signal is None
                or self.feedback_signal.shape != medulla_ON.shape):
            self.feedback_signal = torch.zeros_like(medulla_ON)

        # Formula (8)
        self.v_on = torch.clamp(medulla_ON - self.feedback_signal, min=0)
        self.v_off = torch.clamp(medulla_OFF - self.feedback_signal, min=0)
        correlationD = self.v_on * self.v_off

        # Formula (10)
        correlationE = F.conv2d(medulla_ON * medulla_OFF, self.gaussian_kernel, padding='same')

        # Only record (correlationD + correlationE) for next delay in Formula (9)
        self.feedback_signal = self.beta * (correlationD + correlationE)

        # Formula (14)
        response = self.spatial_inhibition.forward(correlationD)
        return response

    def forward(self, medulla_ON, medulla_OFF, lamina_ON, lamina_OFF):
        """Return ``(response, direction)`` and store the pair in ``self.output``."""
        response = self.forward_localization(medulla_ON, medulla_OFF)
        direction = self.cdgc.forward(medulla_ON, medulla_OFF, lamina_ON, lamina_OFF)

        self.output = (response, direction)
        return self.output
291
+
292
+
293
class FastEuclideanTracker:
    """
    A lightweight multi-target tracker based on Euclidean distance and the
    Hungarian algorithm (PyTorch version).

    Each detection is one row of the ``response`` tensor; the last two columns
    of a row are read as ``(y, x)``. Earlier columns are stored with the track
    but are not used for matching.
    """

    def __init__(self, max_distance=5.0, max_unmatched=5, device='cpu'):
        """
        Initialise the tracker.

        Args:
            max_distance (float): Largest Euclidean distance accepted for a match.
            max_unmatched (int): Number of unmatched frames a track may
                accumulate before it is deleted.
            device (str): Accepted for interface compatibility but unused;
                ``update`` takes the device from its input tensor.
        """
        self.next_track_id = 0
        # {track_id: {'center': tensor row, 'unmatched_count': int, 'direction': float | None}}
        self.tracks = {}
        self.max_distance = max_distance
        self.max_unmatched = max_unmatched

    def _start_track(self, detection):
        """Register ``detection`` as a new track; it has no direction yet."""
        self.tracks[self.next_track_id] = {
            'center': detection,
            'unmatched_count': 0
        }
        self.next_track_id += 1

    def update(self, response):
        """
        Update the tracks with the detections of one frame.

        Args:
            response (torch.Tensor): Detections, one per row, with ``(y, x)``
                in the last two columns.

        Returns:
            torch.Tensor: float32 tensor of shape (N, 3) holding
            ``(y, x, direction)`` for every track that has a direction.
            ``direction`` is in radians in ``[0, 2*pi)``. Tracks that have not
            moved since their previous match, or were never matched, are
            omitted.
        """
        device = response.device
        track_ids = list(self.tracks)

        # === No existing tracks: every detection starts a new one ===
        if not track_ids:
            for detection in response:
                self._start_track(detection)
            return torch.empty((0, 3), device=device)

        # === 1. Centres of the existing tracks, stacked to (T, cols) ===
        track_centers = torch.stack([self.tracks[tid]['center'] for tid in track_ids])

        # === 2 + 3. Cost matrix and Hungarian matching ===
        matched_pairs = []
        if len(response) > 0:
            # Both sides use the LAST two columns, so the detection layout
            # only has to end in (y, x).
            cost_matrix = torch.cdist(track_centers[:, -2:].float(),
                                      response[:, -2:].float(), p=2)  # (T, D)

            # PyTorch has no linear-assignment solver; use SciPy on CPU.
            cost_matrix_np = cost_matrix.detach().cpu().numpy()
            track_idx_arr, det_idx_arr = linear_sum_assignment(cost_matrix_np)

            matched_pairs = [
                (int(t_i), int(d_i))
                for t_i, d_i in zip(track_idx_arr, det_idx_arr)
                if cost_matrix_np[t_i, d_i] <= self.max_distance
            ]

        matched_tracks = {t_i for t_i, _ in matched_pairs}
        matched_dets = {d_i for _, d_i in matched_pairs}

        # === 4. Update the matched tracks ===
        if matched_pairs:
            t_indices = [t_i for t_i, _ in matched_pairs]
            d_indices = [d_i for _, d_i in matched_pairs]

            past_centers = track_centers[t_indices]
            curr_centers = response[d_indices]

            # Batched direction computation.
            dy = curr_centers[:, -2] - past_centers[:, -2]
            dx = curr_centers[:, -1] - past_centers[:, -1]

            # -dy: a decrease in the row index yields a positive angle.
            angles = torch.atan2(-dy, dx)
            angles = torch.remainder(angles, 2 * torch.pi)

            # A float tensor cannot hold None, so stationary targets are
            # marked with NaN and converted to None below.
            static_mask = (dx == 0) & (dy == 0)
            angles[static_mask] = float('nan')

            for idx, (t_i, _) in enumerate(matched_pairs):
                track = self.tracks[track_ids[t_i]]
                track['center'] = curr_centers[idx]
                track['unmatched_count'] = 0

                ang = angles[idx].item()
                track['direction'] = None if math.isnan(ang) else ang

        # === 5. Unmatched tracks: age them, or delete them once too old ===
        for t_i, tid in enumerate(track_ids):
            if t_i in matched_tracks:
                continue
            if self.tracks[tid]['unmatched_count'] >= self.max_unmatched:
                del self.tracks[tid]
            else:
                self.tracks[tid]['unmatched_count'] += 1

        # === 6. Unmatched detections start new tracks ===
        for d_i in range(len(response)):
            if d_i not in matched_dets:
                self._start_track(response[d_i])

        # === 7. Collect the output ===
        results = []
        for info in self.tracks.values():
            direction = info.get('direction')
            if direction is None:
                continue
            c = info['center']
            results.append(torch.tensor([float(c[-2]), float(c[-1]), direction],
                                        dtype=torch.float32, device=device))

        if results:
            return torch.stack(results)
        return torch.empty((0, 3), device=device)
419
+
420
+
@@ -0,0 +1,92 @@
1
+ # demo_vidstream
2
+ import os
3
+ import sys
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import json
7
+
8
+ filePath = os.path.realpath(__file__)
9
+ pyPackagePath = os.path.dirname(os.path.dirname(os.path.dirname(filePath)))
10
+ gitCodePath = os.path.dirname(pyPackagePath)
11
+ sys.path.append(pyPackagePath)
12
+
13
+ from smalltargetmotiondetectors.api import inference_task, evaluate_task # type: ignore
14
+
15
+
16
def inference_and_evaluate_task(modelName,
                                inputpath,
                                inputType = 'ImgstreamReader',
                                groundTruth = None,
                                gTError = 1,
                                startFrame = 0,
                                endFrame = None,
                                savePath1 = None,
                                savePath2 = None,
                                **kwargs):
    """Run inference with a model, evaluate the result and save both stages.

    Args:
        modelName: Forwarded to ``inference_task``.
        inputpath: Forwarded to ``inference_task``.
        inputType: Forwarded to ``inference_task``.
        groundTruth: Forwarded to ``evaluate_task``.
        gTError: Forwarded to ``evaluate_task``.
        startFrame: Forwarded to both tasks.
        endFrame: Forwarded to both tasks.
        savePath1: JSON path for the inference output; ``None`` skips the dump.
        savePath2: JSON path for the evaluation metrics; ``None`` skips the dump.
        **kwargs: Extra keyword arguments forwarded to ``inference_task``.

    Returns:
        tuple: ``(rocFig, AUC, mR)`` as returned by ``evaluate_task``.

    Side effects:
        Always writes the ROC figure to ``roc_curve.png`` in the current
        working directory.
    """
    # --- inference ---
    modelOpt, modelDire = inference_task(modelName, inputpath, inputType, startFrame, endFrame, **kwargs)
    # The save paths default to None; passing None on to save_as_json would
    # fail on its file-name handling, so saving is skipped instead.
    if savePath1 is not None:
        save_as_json(savePath1, modelOpt, modelDire)

    # --- evaluate ---
    rocFig, AUC, mR, RPIList, FPPIList, thresholdList = evaluate_task(modelOpt, groundTruth, gTError, startFrame, endFrame)
    rocFig.savefig('roc_curve.png')  # Save as PNG file
    if savePath2 is not None:
        save_as_json(savePath2, AUC, mR, RPIList, FPPIList, thresholdList)

    return rocFig, AUC, mR
39
+
40
def save_as_json(file_name='output.json', *args, ):
    """
    Save multiple arguments as a JSON file.

    The arguments are stored under the keys ``data_1``, ``data_2``, ... in the
    order given.

    Parameters:
    - file_name (str | os.PathLike | None): Target file. ``None`` falls back to
      the default ``'output.json'``. ``'.json'`` is appended when the name does
      not already end with it.
    - *args: The data to be saved. Every object must be JSON-serialisable.

    Raises:
    - TypeError: if an argument cannot be serialised by ``json.dump``.
    """
    # None and path-like objects have no usable .endswith(); normalise first.
    file_name = 'output.json' if file_name is None else os.fspath(file_name)

    # Generate a unique, 1-based key for each argument.
    data = {f"data_{i}": arg for i, arg in enumerate(args, start=1)}

    # Ensure the file extension is '.json'.
    if not file_name.endswith('.json'):
        file_name += '.json'

    with open(file_name, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)
63
+
64
if __name__ == '__main__':
    # Demo driver: every path below is a hard-coded location on the author's
    # machine and has to be adapted before running.
    result_dir = os.path.join('C:\\Users\\mings\\Desktop', 'temp_result')

    # Ground truth is read from a JSON file with a 'groundTruth' entry.
    with open(os.path.join(result_dir, 'gt.json'), 'r') as file:
        data = json.load(file)

    modelName = 'ESTMD'
    inputpath = os.path.join(
        'D:\\STMD_Dataset\\PanoramaStimuli\\BV-250-Leftward',
        'SingleTarget-TW-5-TH-5-TV-300-TL-0-Rightward-Amp-15-Theta-0-TemFre-2-SamFre-1000',
        'PanoramaStimuli*.tif',
    )
    inputType = 'ImgstreamReader'
    groundTruth = data['groundTruth']
    gTError = 1
    startFrame = 1
    endFrame = 500
    savePath1 = os.path.join(result_dir, 'opt1.json')
    savePath2 = os.path.join(result_dir, 'opt2.json')

    # sigma1 is not a named parameter of the task function; it travels
    # through **kwargs.
    rocFig, AUC, mR = inference_and_evaluate_task(
        modelName,
        inputpath,
        inputType=inputType,
        groundTruth=groundTruth,
        gTError=gTError,
        startFrame=startFrame,
        endFrame=endFrame,
        savePath1=savePath1,
        savePath2=savePath2,
        sigma1=1,
    )

    plt.show()
@@ -0,0 +1,148 @@
1
+ import os
2
+ import sys
3
+ import logging
4
+ import time
5
+ from typing import Optional
6
+
7
+
8
+ import tkinter as tk
9
+ import torch
10
+
11
+ file_path = os.path.realpath(__file__)
12
+ py_pkg_path = os.path.dirname(os.path.dirname(os.path.dirname(file_path)))
13
+ sys.path.append(py_pkg_path)
14
+
15
+ try:
16
+ from xttmp.util.iostream import ( # type: ignore
17
+ ModelAndInputSelectorGUI,
18
+ FrameIterator,
19
+ FrameVisualizer,
20
+ )
21
+ from xttmp.util.compute_module import PostProcessing # type: ignore
22
+ from xttmp.api import ( # type: ignore
23
+ instancing_model,
24
+ )
25
+ except ImportError as e:
26
+ raise ImportError("Failed to import required modules. "
27
+ "Ensure that the 'xttmp' package is correctly installed.") from e
28
+
29
+ # configure logging
30
+ logging.basicConfig(level=logging.INFO,
31
+ format='%(asctime)s - %(levelname)s - %(message)s')
32
+ logger = logging.getLogger(__name__)
33
+
34
class StmdGui:
    """Interactive demo: pick a model and an input, then run and display inference."""

    def __init__(self, device='cpu', show_threshold: float = 0.8, get_top_num: int = 1):
        """Initialize STMD GUI.

        Args:
            device: Device passed to the reader, the model and the post-processor.
            show_threshold: Passed to the visualizer as ``conf_threshold``.
            get_top_num: Passed to the post-processor as ``get_top_num``.
        """
        self.device = device
        self.show_threshold = show_threshold
        self.get_top_num = get_top_num
        # Collaborators are kept as attributes and always accessed through self.
        self.ModelAndInputSelectorGUI = ModelAndInputSelectorGUI
        self.FrameIterator = FrameIterator
        self.FrameVisualizer = FrameVisualizer
        self.PostProcessing = PostProcessing
        self.instancing_model = instancing_model

    def _get_user_input(self) -> tuple:
        """Show the selection dialog and return whatever ``create_gui()`` returns."""
        root = tk.Tk()
        try:
            gui = self.ModelAndInputSelectorGUI(root)
            return gui.create_gui()
        finally:
            # create_gui() may already have destroyed the window (e.g. after
            # the user confirmed), so a failing destroy() is ignored.
            try:
                root.destroy()
            except tk.TclError:
                pass

    def _create_frame_reader(self, opt1: str, opt2: Optional[str]):
        """Create a frame reader for a video file or an image sequence.

        Args:
            opt1: Video path, or the path of one end of the image range.
            opt2: ``None`` for a video, otherwise the path of the other end
                of the image range.

        Returns:
            The configured reader.

        Raises:
            ValueError: if either image cannot be found among the reader's
                ``image_files``. The reader is released before raising.
        """
        if opt2 is None:
            return self.FrameIterator(opt1, is_video=True, device=self.device)

        reader = self.FrameIterator(os.path.dirname(opt1), is_video=False, device=self.device)

        # Map each base name to the index of its first occurrence.
        index_by_name = {}
        for i, path in enumerate(reader.image_files):
            index_by_name.setdefault(os.path.basename(path), i)

        start_index = index_by_name.get(os.path.basename(opt1))
        end_index = index_by_name.get(os.path.basename(opt2))

        if start_index is None or end_index is None:
            # The caller never receives this reader, so release it here.
            reader.release()
            raise ValueError("Selected image range could not be located in the folder.")

        # Accept the two ends in either order.
        if start_index > end_index:
            start_index, end_index = end_index, start_index

        reader._setup(start_index)
        reader.total_frames = end_index + 1
        return reader

    def run(self):
        """Run the interactive inference loop until the input ends or the window closes."""
        reader = None
        visualizer = None
        try:
            user_input = self._get_user_input()
            if not user_input:
                logger.info("User cancelled input.")
                return

            model_name, opt1, opt2, is_stepping = user_input
            reader = self._create_frame_reader(opt1, opt2)
            model = self.instancing_model(model_name, device=self.device)
            post_processor = self.PostProcessing(
                device=self.device,
                nms_radio=8,
                get_top_num=self.get_top_num,
            )

            visualizer = self.FrameVisualizer(
                window_name=model_name,
                result_index_type="dots",
                win_width=reader.img_width,
                win_height=reader.img_height,
                conf_threshold=self.show_threshold,
            )
            if is_stepping:
                visualizer.paused = True

            # CUDA kernels run asynchronously; synchronise around the model
            # call so the measured time covers the actual computation.
            sync_cuda = self.device != 'cpu' and torch.cuda.is_available()

            while True:
                color_img, gray_tensor, is_valid = reader.get_next_frame()
                if not is_valid:
                    break

                if sync_cuda:
                    torch.cuda.synchronize()
                time_start = time.perf_counter()
                result = model(gray_tensor)
                if sync_cuda:
                    torch.cuda.synchronize()
                run_time = time.perf_counter() - time_start

                dots = post_processor(result['response'], result.get('direction'))
                if not visualizer.update(color_img, result=dots, process_time=run_time):
                    break

        except Exception as e:
            # Top-level boundary of the demo: log with traceback so the
            # failure can be diagnosed, then fall through to the cleanup.
            logger.exception("Main process error: %s", e)
        finally:
            logger.info("Cleaning up resources...")
            if visualizer is not None:
                visualizer.close()
            if reader is not None:
                reader.release()
            logger.info("Shutdown completed")
141
+
142
def main(show_threshold: float = 0, get_top_num: int = 10):
    """Launch the STMD GUI demo on CUDA when available, otherwise on CPU.

    Args:
        show_threshold: Forwarded to ``StmdGui`` as ``show_threshold``.
        get_top_num: Forwarded to ``StmdGui`` as ``get_top_num``.
    """
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    gui = StmdGui(device, show_threshold=show_threshold, get_top_num=get_top_num)
    gui.run()


if __name__ == "__main__":
    main(get_top_num=20)