xttmp 2.3.0.1__tar.gz → 2.3.0.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. {xttmp-2.3.0.1/src/xttmp.egg-info → xttmp-2.3.0.3}/PKG-INFO +16 -5
  2. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/README.md +11 -1
  3. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/pyproject.toml +8 -4
  4. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/demo/inference_gui.py +24 -38
  5. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/main.py +7 -0
  6. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/util/compute_module.py +145 -187
  7. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/util/iostream.py +172 -47
  8. {xttmp-2.3.0.1 → xttmp-2.3.0.3/src/xttmp.egg-info}/PKG-INFO +16 -5
  9. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp.egg-info/requires.txt +2 -0
  10. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/LICENSE +0 -0
  11. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/setup.cfg +0 -0
  12. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/__init__.py +0 -0
  13. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/api/__init__.py +0 -0
  14. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/api/evaluate.py +0 -0
  15. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/api/get_visualize_handle.py +0 -0
  16. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/api/instancing_model.py +0 -0
  17. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/core/__init__.py +0 -0
  18. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/core/apgstmd_core.py +0 -0
  19. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/core/apgstmdv2_core.py +0 -0
  20. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/core/base_core.py +0 -0
  21. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/core/dstmd_core.py +0 -0
  22. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/core/estmd_backbone.py +0 -0
  23. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/core/estmd_core.py +0 -0
  24. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/core/feedbackstmd_core.py +0 -0
  25. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/core/fracstmd_core.py +0 -0
  26. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/core/fstmd_core.py +0 -0
  27. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/core/fstmdv2_core.py +0 -0
  28. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/core/haarstmd_core.py +0 -0
  29. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/core/math_operator.py +0 -0
  30. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/core/stfeedbackstmd_core.py +0 -0
  31. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/core/stmdplus_core.py +0 -0
  32. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/core/stmdplusv2_core.py +0 -0
  33. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/core/vstmd_core.py +0 -0
  34. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/demo/evaluate_model.py +0 -0
  35. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/demo/inference_gui_single_process.py +0 -0
  36. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/demo/inference_image_stream.py +0 -0
  37. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/demo/inference_video.py +0 -0
  38. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/model/__init__.py +0 -0
  39. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/model/backbone.py +0 -0
  40. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/model/facilitated_model.py +0 -0
  41. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/model/feedback_model.py +0 -0
  42. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/model/haarstmd.py +0 -0
  43. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/model/vstmd.py +0 -0
  44. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/util/__init__.py +0 -0
  45. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/util/create_kernel.py +0 -0
  46. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/util/evaluate_module.py +0 -0
  47. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp/util/stmd.ico +0 -0
  48. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp.egg-info/SOURCES.txt +0 -0
  49. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp.egg-info/dependency_links.txt +0 -0
  50. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp.egg-info/entry_points.txt +0 -0
  51. {xttmp-2.3.0.1 → xttmp-2.3.0.3}/src/xttmp.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xttmp
3
- Version: 2.3.0.1
3
+ Version: 2.3.0.3
4
4
  Summary: eXtremely Tiny Target - Motion Perception
5
5
  Author-email: Shawn MX <mingshuoxu@hotmail.com>
6
6
  Project-URL: Homepage, https://github.com/MingshuoXu/Small-Target-Motion-Detectors
@@ -16,15 +16,16 @@ Classifier: Programming Language :: Python :: 3.9
16
16
  Classifier: Programming Language :: Python :: 3.10
17
17
  Classifier: Programming Language :: Python :: 3.11
18
18
  Classifier: Programming Language :: Python :: 3.12
19
- Classifier: Topic :: Scientific/Engineering :: Image Recognition
19
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
20
  Requires-Python: >=3.8
21
21
  Description-Content-Type: text/markdown
22
22
  License-File: LICENSE
23
23
  Requires-Dist: matplotlib
24
24
  Requires-Dist: opencv-python
25
25
  Requires-Dist: scipy
26
- Requires-Dist: torch>=2.5.0
27
- Requires-Dist: torchvision>=0.20.0
26
+ Provides-Extra: torch
27
+ Requires-Dist: torch>=2.5.0; extra == "torch"
28
+ Requires-Dist: torchvision>=0.20.0; extra == "torch"
28
29
  Dynamic: license-file
29
30
 
30
31
  # Small Target Motion Detectors, Version 2.3 (XTT-MP: Extremely Tiny Target - Motion Perception)
@@ -66,8 +67,18 @@ Built with modularity and extensibility in mind, XTT-MP provides a robust suite
66
67
  - After `pip install xttmp`, use the installed code and bring your own input data, or run from a repository checkout to access the bundled examples.
67
68
 
68
69
  ### Via PyPI
70
+ #### CPU
71
+ ```bash
72
+ pip install xttmp[torch]
73
+ ```
74
+
75
+ #### NVIDIA GPU (CUDA 12.6)
76
+ ```bash
77
+ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126
78
+ ```
79
+
80
+ ### Running the GUI Demo
69
81
  ```bash
70
- pip install xttmp
71
82
  xttmp_gui
72
83
  ```
73
84
 
@@ -37,8 +37,18 @@ Built with modularity and extensibility in mind, XTT-MP provides a robust suite
37
37
  - After `pip install xttmp`, use the installed code and bring your own input data, or run from a repository checkout to access the bundled examples.
38
38
 
39
39
  ### Via PyPI
40
+ #### CPU
41
+ ```bash
42
+ pip install xttmp[torch]
43
+ ```
44
+
45
+ #### NVIDIA GPU (CUDA 12.6)
46
+ ```bash
47
+ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126
48
+ ```
49
+
50
+ ### Running the GUI Demo
40
51
  ```bash
41
- pip install xttmp
42
52
  xttmp_gui
43
53
  ```
44
54
 
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "xttmp"
7
- version = "2.3.0.1"
7
+ version = "2.3.0.3"
8
8
  description = "eXtremely Tiny Target - Motion Perception"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.8"
@@ -20,14 +20,12 @@ classifiers = [
20
20
  "Programming Language :: Python :: 3.10",
21
21
  "Programming Language :: Python :: 3.11",
22
22
  "Programming Language :: Python :: 3.12",
23
- "Topic :: Scientific/Engineering :: Image Recognition",
23
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
24
24
  ]
25
25
  dependencies = [
26
26
  "matplotlib",
27
27
  "opencv-python",
28
28
  "scipy",
29
- "torch>=2.5.0",
30
- "torchvision>=0.20.0",
31
29
  ]
32
30
 
33
31
  [project.urls]
@@ -47,6 +45,12 @@ where = ["src"]
47
45
  [tool.setuptools.package-data]
48
46
  "xttmp" = ["util/*.ico"]
49
47
 
48
+ [project.optional-dependencies]
49
+ torch = [
50
+ "torch>=2.5.0",
51
+ "torchvision>=0.20.0"
52
+ ]
53
+
50
54
  [dependency-groups]
51
55
  dev = []
52
56
 
@@ -12,19 +12,14 @@ file_path = os.path.realpath(__file__)
12
12
  py_pkg_path = os.path.dirname(os.path.dirname(os.path.dirname(file_path)))
13
13
  sys.path.append(py_pkg_path)
14
14
 
15
- try:
16
- from xttmp.util.iostream import ( # type: ignore
17
- ModelAndInputSelectorGUI,
18
- FrameIterator,
19
- FrameVisualizer,
20
- )
21
- from xttmp.util.compute_module import PostProcessing # type: ignore
22
- from xttmp.api import ( # type: ignore
23
- instancing_model,
24
- )
25
- except ImportError as e:
26
- raise ImportError("Failed to import required modules. "
27
- "Ensure that the 'xttmp' package is correctly installed.") from e
15
+
16
+ from xttmp.util.iostream import ( # type: ignore
17
+ XTTMP_GUI,
18
+ FrameIterator,
19
+ FrameVisualizer,
20
+ )
21
+ from xttmp.api import instancing_model
22
+
28
23
 
29
24
  # configure logging
30
25
  logging.basicConfig(level=logging.INFO,
@@ -32,18 +27,15 @@ logging.basicConfig(level=logging.INFO,
32
27
  logger = logging.getLogger(__name__)
33
28
 
34
29
  class StmdGui:
35
- def __init__(self, device='cpu', show_threshold: float = 0.8, get_top_num: int = 1):
30
+ def __init__(self):
36
31
  """ Initialize STMD GUI """
37
- self.device = device
38
- self.show_threshold = show_threshold
39
- self.get_top_num = get_top_num
40
- self.ModelAndInputSelectorGUI = ModelAndInputSelectorGUI
32
+ self.device = None
33
+ self.ModelAndInputSelectorGUI = XTTMP_GUI
41
34
  self.FrameIterator = FrameIterator
42
35
  self.FrameVisualizer = FrameVisualizer
43
- self.PostProcessing = PostProcessing
36
+ self.post_processor = None
44
37
  self.instancing_model = instancing_model
45
38
 
46
-
47
39
  def _get_user_input(self) -> tuple:
48
40
  """ get user input """
49
41
  root = tk.Tk()
@@ -93,21 +85,17 @@ class StmdGui:
93
85
  logger.info("User cancelled input.")
94
86
  return
95
87
 
96
- model_name, opt1, opt2, is_stepping = user_input
88
+ model_name, opt1, opt2, is_stepping, device, post_processor, show_threshold = user_input
89
+ self.post_processor = post_processor
90
+ self.device = device
97
91
  reader = self._create_frame_reader(opt1, opt2)
98
- model = self.instancing_model(model_name, device=self.device)
99
- post_processor = self.PostProcessing(
100
- device=self.device,
101
- nms_radio=8,
102
- get_top_num=self.get_top_num,
103
- )
92
+ model = self.instancing_model(model_name, device)
104
93
 
105
94
  visualizer = self.FrameVisualizer(
106
95
  window_name=model_name,
107
- result_index_type="dots",
108
96
  win_width=reader.img_width,
109
97
  win_height=reader.img_height,
110
- conf_threshold=self.show_threshold,
98
+ conf_threshold=show_threshold,
111
99
  )
112
100
  if is_stepping:
113
101
  visualizer.paused = True
@@ -117,16 +105,17 @@ class StmdGui:
117
105
  if not is_valid:
118
106
  break
119
107
 
120
- if self.device != 'cpu' and torch.cuda.is_available():
108
+ if self.device == 'cuda':
121
109
  torch.cuda.synchronize()
122
110
  time_start = time.perf_counter()
123
111
  result = model(gray_tensor)
124
- if self.device != 'cpu' and torch.cuda.is_available():
112
+ if self.device == 'cuda':
125
113
  torch.cuda.synchronize()
126
114
  run_time = time.perf_counter() - time_start
127
115
 
128
- dots = post_processor(result['response'], result.get('direction'))
129
- if not visualizer.update(color_img, result=dots, process_time=run_time):
116
+ post_res = post_processor(result['response'], result.get('direction'))
117
+ show_str = f'{self.device} : {run_time*1000:.1f} ms'
118
+ if not visualizer.update(color_img, result=post_res, show_str=show_str):
130
119
  break
131
120
 
132
121
  except Exception as e:
@@ -139,10 +128,7 @@ class StmdGui:
139
128
  reader.release()
140
129
  logger.info("Shutdown completed")
141
130
 
142
- def main(show_threshold: float = 0, get_top_num: int = 10):
143
- DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
144
- obj = StmdGui(DEVICE, show_threshold = show_threshold, get_top_num = get_top_num)
145
- obj.run()
146
131
 
147
132
  if __name__ == "__main__":
148
- main(get_top_num = 20)
133
+ obj = StmdGui()
134
+ obj.run()
@@ -2,6 +2,13 @@ from pathlib import Path
2
2
  import subprocess
3
3
  import sys
4
4
 
5
+ try:
6
+ import torch
7
+ except ImportError:
8
+ raise ImportError(
9
+ "Please install PyTorch first. "
10
+ "See https://pytorch.org/get-started/locally/"
11
+ )
5
12
 
6
13
  def main():
7
14
  script_path = Path(__file__).resolve().parent / 'demo' / 'inference_gui.py'
@@ -6,67 +6,6 @@ import torch
6
6
  import torch.nn.functional as F
7
7
 
8
8
 
9
- def compute_temporal_conv(iptCell, kernel, pointer=None):
10
- """
11
- Computes temporal convolution.
12
-
13
- Parameters:
14
- - iptCell: A list of arrays where each element has the same dimension.
15
- - kernel: A vector representing the convolution kernel.
16
- - headPointer: Head pointer of the input cell array (optional).
17
-
18
- Returns:
19
- - optMatrix: The result of the temporal convolution.
20
- """
21
-
22
- # Default value for headPointer
23
- if pointer is None:
24
- pointer = len(iptCell) - 1
25
-
26
- # Initialize output matrix
27
- if iptCell[pointer] is None:
28
- return None
29
-
30
- # Ensure kernel is a vector
31
- kernel = np.squeeze(kernel)
32
- if not np.ndim(kernel) == 1:
33
- raise ValueError('The kernel must be a vector.')
34
-
35
- # Determine the lengths of input cell array and kernel
36
- k1 = len(iptCell)
37
- k2 = len(kernel)
38
- length = min(k1, k2)
39
-
40
- if isinstance(iptCell[pointer], np.ndarray):
41
- optMatrix = np.zeros_like(iptCell[pointer])
42
- elif isinstance(iptCell[pointer], torch.Tensor):
43
- optMatrix = torch.zeros_like(iptCell[pointer])
44
- # Perform temporal convolution
45
- for t in range(length):
46
- j = (pointer - t) % k1
47
- if abs(kernel[t]) > 1e-16 and iptCell[j] is not None:
48
- optMatrix += iptCell[j] * kernel[t]
49
-
50
- return optMatrix
51
-
52
-
53
- def compute_circularlist_conv(circularCell, temporalKernel):
54
- """
55
- Compute the convolution of a circular cell with a temporal kernel.
56
-
57
- Args:
58
- - circularCell: The circular cell data.
59
- - temporalKernel: The temporal kernel data.
60
-
61
- Returns:
62
- - opt_matrix: The result of the convolution.
63
- """
64
- optMatrix = compute_temporal_conv(circularCell,
65
- temporalKernel,
66
- circularCell.pointer )
67
- return optMatrix
68
-
69
-
70
9
  def compute_response(ipt):
71
10
  """
72
11
  Computes the maximum response from multiple inputs.
@@ -233,139 +172,17 @@ class AreaNMS:
233
172
  return matrix * (matrix == local_max)
234
173
 
235
174
 
236
- def get_top_k_torch(response_tensor, direction_tensor, k=1000):
237
- """
238
- 输入:
239
- response_tensor: (..., H, W) 任意维度的 Tensor
240
- direction_tensor: (..., H, W) 形状需与 response 匹配 (可选)
241
- 输出:
242
- torch.Tensor: shape=(M, 4), dtype=float32, 其中 M <= k
243
- 格式: [[x, y, response, direction], ...]
244
- """
245
- # 1. 获取维度
246
- H, W = response_tensor.shape[-2:]
247
- k = min(k, H * W)
248
-
249
- # 2. 展平 (Flatten)
250
- # view(-1) 零拷贝,极快
251
- flat_response = response_tensor.view(-1)
252
-
253
- # 3. TopK (GPU 上极速排序)
254
- top_vals, top_indices = torch.topk(flat_response, k=k)
255
-
256
- # 4. 过滤掉 <= 0 的值 ---
257
- # 创建掩码:只保留大于 0 的值
258
- mask = top_vals > 0
259
-
260
- # 如果全都是 0,直接返回空数组,避免后续报错
261
- if not mask.any():
262
- return torch.empty((0, 4))
263
-
264
- # 应用掩码,缩减 tensor 长度
265
- top_vals = top_vals[mask]
266
- top_indices = top_indices[mask]
267
- # ------------------------------------
268
-
269
- # 5. 计算坐标 (x, y)
270
- # 此时计算量已经减少,只计算非零点
271
- top_y = top_indices.div(W, rounding_mode='floor').float()
272
- top_x = (top_indices % W).float()
273
-
274
- # 6. 获取 Direction
275
- if direction_tensor is not None and direction_tensor.numel() > 0:
276
- flat_direction = direction_tensor.view(-1)
277
- # 注意:这里使用过滤后的 top_indices
278
- top_dirs = flat_direction[top_indices]
279
- else:
280
- top_dirs = torch.empty_like(top_vals).fill_(float('nan'))
281
-
282
- # 7. 堆叠 (Stack) -> (M, 4)
283
- result_tensor = torch.stack([top_x, top_y, top_vals, top_dirs], dim=1)
284
-
285
- return result_tensor
286
-
287
-
288
- def get_top_k_numpy(response_array, direction_array=None, k=1000):
289
- """
290
- 输入:
291
- response_array: (..., H, W) numpy.ndarray
292
- direction_array: (..., H, W) (可选)
293
- 输出:
294
- numpy.ndarray: shape=(M, 4), dtype=float32, 其中 M <= k
295
- 格式: [[x, y, response, direction], ...]
296
- """
297
- # 1. 获取维度
298
- shape = response_array.shape
299
- H, W = shape[-2:]
300
-
301
- # 零拷贝展平
302
- flat_response = response_array.ravel()
303
- k = min(k, flat_response.size)
304
-
305
- # 2. TopK 核心优化 (O(N))
306
- # argpartition 找出最大的 k 个 (无序)
307
- unsorted_top_indices = np.argpartition(flat_response, -k)[-k:]
308
- unsorted_top_vals = flat_response[unsorted_top_indices]
309
-
310
- # 3. 局部排序 (O(k log k))
311
- # argsort 默认升序,[::-1] 翻转为降序
312
- sort_idx = np.argsort(unsorted_top_vals)[::-1]
313
-
314
- # 获取排序后的 Top K 索引和值
315
- top_indices = unsorted_top_indices[sort_idx]
316
- top_vals = unsorted_top_vals[sort_idx]
317
-
318
- # --- [关键修改] 4. 过滤掉 <= 0 的值 ---
319
- # 创建掩码
320
- mask = top_vals > 0
321
-
322
- # 极速判断:如果没有有效值,直接返回空数组
323
- # np.any() 很快
324
- if not np.any(mask):
325
- return np.empty((0, 4), dtype=np.float32)
326
-
327
- # 应用掩码 (切片操作,只保留有效值)
328
- # 因为 k 通常不大 (比如 1000),这里的拷贝开销可忽略不计
329
- top_vals = top_vals[mask]
330
- top_indices = top_indices[mask]
331
-
332
- # 更新实际数量 M
333
- M = top_vals.size
334
- # ------------------------------------
335
-
336
- # 5. 计算坐标 (x, y)
337
- # 只对过滤后的索引计算,节省算力
338
- top_y, top_x = np.unravel_index(top_indices, (H, W))
339
-
340
- # 6. 获取 Direction
341
- if direction_array is not None and direction_array.size > 0:
342
- flat_direction = direction_array.ravel()
343
- top_dirs = flat_direction[top_indices]
344
- else:
345
- top_dirs = np.full(M, np.nan, dtype=np.float32)
346
-
347
- # 7. 堆叠结果
348
- # 分配恰好大小为 M 的内存
349
- result = np.empty((M, 4), dtype=np.float32)
350
- result[:, 0] = top_x # x
351
- result[:, 1] = top_y # y
352
- result[:, 2] = top_vals # response
353
- result[:, 3] = top_dirs # direction
354
-
355
- return result
356
-
357
-
358
175
  class PostProcessing:
359
176
  """
360
177
  Post-processing class to apply AreaNMS, get top K, and return list format.
361
178
  """
362
179
 
363
- def __init__(self, device='cpu', nms_radio = 8, get_top_num=1000):
180
+ def __init__(self, nms_radio = 8, get_top_num=1000):
364
181
  """
365
182
  Args:
366
- device (str): Computing device ('cpu' or 'cuda').
183
+ nms_radio (int): Radius for AreaNMS.
184
+ get_top_num (int): Number of top points to extract.
367
185
  """
368
- self.device = device
369
186
  self.area_nms = AreaNMS(radio=nms_radio)
370
187
  self.get_top_num = get_top_num
371
188
 
@@ -389,7 +206,7 @@ class PostProcessing:
389
206
  """
390
207
  nms_response = self.area_nms(response)
391
208
 
392
- res = get_top_k_torch(nms_response,
209
+ res, _ = get_top_k_torch(nms_response,
393
210
  direction,
394
211
  k=self.get_top_num)
395
212
  if res.shape[0] == 0:
@@ -400,3 +217,144 @@ class PostProcessing:
400
217
  res[:, 2] /= max_score
401
218
 
402
219
  return res
220
+
221
+
222
+ @torch.no_grad()
223
+ def gen_bboxes_around_points(results, box_size=16, shift_ratio=0.3):
224
+ """Generate initial bboxes around detected motion points.
225
+
226
+ Args:
227
+ results: (N, 4) -> [x, y, response, direction]
228
+ box_size: int, box size
229
+
230
+ Returns:
231
+ [[x1, y1, x2, y2], ... ] (N, 4) tensors
232
+ """
233
+ N = results.shape[0]
234
+
235
+ if N == 0:
236
+ return torch.empty((0, 4), device=results.device, dtype=torch.int)
237
+
238
+ rear_x, rear_y, direction = results[:, 0], results[:, 1], results[:, 3]
239
+
240
+ radius = box_size * 0.5
241
+ shift_mag = box_size * shift_ratio
242
+
243
+ # 1. 计算偏移量 (dx, dy),根据方向和预设的 shift_mag
244
+ dx = torch.cos(direction) * shift_mag
245
+ dy = -torch.sin(direction) * shift_mag
246
+
247
+ # 2. 一次性将 NaN 偏移量替换为 0.0
248
+ dx = torch.nan_to_num(dx, nan=0.0)
249
+ dy = torch.nan_to_num(dy, nan=0.0)
250
+
251
+ # 3. 计算中心点
252
+ center_x = rear_x + dx
253
+ center_y = rear_y + dy
254
+
255
+ x1 = (center_x - radius)
256
+ y1 = (center_y - radius)
257
+ x2 = (center_x + radius)
258
+ y2 = (center_y + radius)
259
+
260
+ # Stack as (N, 4) -> [x1, y1, x2, y2]
261
+ return torch.stack([x1, y1, x2, y2], dim=1)
262
+
263
+
264
+ @torch.no_grad()
265
+ def get_top_k_torch(response_tensor, direction_tensor=None, k=100):
266
+ """
267
+ Extract the top-k points with highest responses from feature maps, filtering out non-positive values.
268
+
269
+ Args:
270
+ response_tensor (torch.Tensor): The response map tensor of shape (B, 1, H, W) or (B, H, W).
271
+ direction_tensor (torch.Tensor, optional): The corresponding direction map tensor of
272
+ shape (B, 1, H, W) or (B, H, W). Must match response_tensor's shape. Defaults to None.
273
+ k (int, optional): The maximum number of top points to extract per batch. Defaults to 100.
274
+
275
+ Returns:
276
+ Tuple[torch.Tensor, torch.Tensor]:
277
+ - results (torch.Tensor): A tensor of shape (M, 4) containing the valid extracted points
278
+ across the entire batch. M <= B * k. Each row is formatted as [x, y, response, direction].
279
+ - batch_ids (torch.Tensor): A 1D tensor of shape (M,) containing the corresponding
280
+ batch index (from 0 to B-1) for each point in `results`. dtype is torch.long.
281
+ """
282
+ B, _, H, W = response_tensor.shape
283
+ k = min(k, H * W)
284
+ device = response_tensor.device
285
+
286
+ # 1. Flatten -> (B, H*W)
287
+ flat_response = response_tensor.reshape(B, -1)
288
+
289
+ # 2. TopK -> top_vals and top_indices are both (B, k)
290
+ top_vals, top_indices = torch.topk(flat_response, k=k, dim=-1)
291
+
292
+ # 3. Get Direction -> (B, k)
293
+ if direction_tensor is not None and direction_tensor.numel() > 0:
294
+ flat_direction = direction_tensor.reshape(B, -1)
295
+ top_dirs = torch.gather(flat_direction, dim=-1, index=top_indices)
296
+ else:
297
+ top_dirs = torch.full_like(top_vals, float('nan'))
298
+
299
+ # 4. Calculate coordinates (x, y) -> (B, k)
300
+ top_y = top_indices.div(W, rounding_mode='floor').float()
301
+ top_x = (top_indices % W).float()
302
+
303
+ # 5. Stack -> merge on the last dimension, shape becomes (B, k, 4)
304
+ stacked = torch.stack([top_x, top_y, top_vals, top_dirs], dim=-1)
305
+
306
+ # 6. Generate Mask -> (B, k)
307
+ mask = top_vals > 0
308
+
309
+ # 7. Split and filter by Batch
310
+ result_list = []
311
+ batch_id_list = []
312
+
313
+ for i in range(B):
314
+ batch_mask = mask[i] # Get the mask for the i-th batch
315
+
316
+ # Apply mask: [k, 4] -> [M_i, 4]
317
+ valid_stacked = stacked[i][batch_mask]
318
+ result_list.append(valid_stacked)
319
+
320
+ # Create a batch index tensor of shape (M_i,) filled with the current batch index 'i'
321
+ batch_id_list.append(torch.full((valid_stacked.shape[0],), i, device=device, dtype=torch.long))
322
+
323
+ # Concatenate all valid items into continuous tensors
324
+ return torch.cat(result_list, dim=0), torch.cat(batch_id_list, dim=0)
325
+
326
+
327
+ @torch.no_grad()
328
+ def get_STMD_region_proposal(response_tensor, direction_tensor=None, top_k=1, box_size=16, spatial_scale=1.0, shift_ratio=0.3):
329
+ nms_win = int(box_size * spatial_scale) | 1 # 确保是奇数
330
+ score_mask = F.max_pool2d(response_tensor, kernel_size=nms_win, stride=1, padding=nms_win//2)
331
+
332
+ nms_response_tensor = torch.where(response_tensor == score_mask, response_tensor, 0.0)
333
+
334
+ vSTMD_res, batch_id = get_top_k_torch(nms_response_tensor, direction_tensor, k=top_k)
335
+
336
+ if spatial_scale > 1:
337
+ vSTMD_res[:, :2] *= spatial_scale # 将坐标放大回原图尺度
338
+
339
+ bboxes = gen_bboxes_around_points(vSTMD_res, box_size, shift_ratio)
340
+
341
+ return vSTMD_res, bboxes, batch_id
342
+
343
+
344
+ @torch.no_grad()
345
+ def bbox_post_processing(top_k=1, box_size=16, spatial_scale=1.0, shift_ratio=0.3):
346
+
347
+ def post_process_func(
348
+ response_tensor,
349
+ direction_tensor=None
350
+ ):
351
+ vSTMD_res, bboxes, _ = get_STMD_region_proposal(response_tensor,
352
+ direction_tensor,
353
+ top_k=top_k,
354
+ box_size=box_size,
355
+ spatial_scale=spatial_scale,
356
+ shift_ratio=shift_ratio )
357
+ return torch.cat([bboxes, vSTMD_res[..., 2:3]], dim=1)
358
+
359
+ return post_process_func
360
+
@@ -3,6 +3,7 @@ import re
3
3
  from pathlib import Path
4
4
  import logging
5
5
  from typing import Optional, List, Union, Tuple, Any
6
+ from functools import partial
6
7
 
7
8
  import cv2
8
9
  import numpy as np
@@ -12,6 +13,7 @@ import torch
12
13
 
13
14
 
14
15
  from .. import model
16
+ from .compute_module import PostProcessing, bbox_post_processing
15
17
 
16
18
 
17
19
  # Get the full path of this file
@@ -215,7 +217,6 @@ class FrameIterator:
215
217
 
216
218
  class FrameVisualizer:
217
219
  def __init__(self, window_name="Visualizer",
218
- result_index_type="matrix",
219
220
  win_width=None, win_height=None,
220
221
  is_headless=False,
221
222
  conf_threshold=0.8 # 阈值参数
@@ -225,7 +226,6 @@ class FrameVisualizer:
225
226
  :param conf_threshold: 可视化过滤的相对阈值 (0.0 ~ 1.0)
226
227
  """
227
228
  self.window_name = window_name
228
- self.result_index_type = result_index_type # "matrix", "dots", "bbox"
229
229
  self.win_width = win_width or 800
230
230
  self.win_height = win_height or 600
231
231
  self.is_headless = is_headless
@@ -260,24 +260,26 @@ class FrameVisualizer:
260
260
  self.save_output = True
261
261
  print(f">>> Video writer initialized: {output_path}")
262
262
 
263
- def update(self, frame, result=None, direction=None, annotation=None, process_time=None) -> bool:
263
+ def update(self, frame, result=None, direction=None, annotation=None, show_str=None) -> bool:
264
264
  if frame is None:
265
265
  return False
266
266
 
267
267
  # --- 绘制逻辑 ---
268
268
  # 即使 result 是空的,只要不为 None 也可以处理
269
269
  if result is not None:
270
- if self.result_index_type == "matrix":
270
+ if result.dim() == 4:
271
271
  self._draw_matrix(frame, result, direction, self.conf_threshold)
272
- elif self.result_index_type == "dots":
272
+ elif result.shape[1] == 4:
273
273
  result = result.cpu().numpy() if isinstance(result, torch.Tensor) else result
274
274
  self._draw_dots(frame, result, self.conf_threshold)
275
- elif self.result_index_type == "bbox":
275
+ elif result.shape[1] == 5:
276
276
  self._draw_bbox(frame, result, self.conf_threshold, annotation)
277
-
277
+ device_str = f'{result.device}'
278
+ else:
279
+ device_str = 'Time'
278
280
  # --- 信息显示 ---
279
- if process_time is not None:
280
- cv2.putText(frame, f'Time: {process_time*1000:.1f} ms',
281
+ if show_str is not None and show_str != '':
282
+ cv2.putText(frame, str(show_str),
281
283
  (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8,
282
284
  (0, 255, 0), 2, cv2.LINE_AA)
283
285
 
@@ -344,10 +346,12 @@ class FrameVisualizer:
344
346
  @staticmethod
345
347
  def _draw_matrix(frame, matrix, direction_map, threshold):
346
348
  """处理 Matrix 格式 (Heatmap)"""
347
- if np.max(matrix) <= 0: return
349
+ if torch.max(matrix) <= 0: return
348
350
 
349
351
  # np.where 返回 (rows, cols) 即 (y, x)
350
- rows, cols = np.where(matrix > threshold)
352
+ _, _, rows, cols = torch.where(matrix > threshold)
353
+ rows = rows.cpu().numpy()
354
+ cols = cols.cpu().numpy()
351
355
 
352
356
  # 画点
353
357
  for r, c in zip(rows, cols):
@@ -358,7 +362,7 @@ class FrameVisualizer:
358
362
  # 画箭头
359
363
  if direction_map is not None and len(rows) > 0:
360
364
  # 确保 direction_map 维度匹配,这里假设是同样大小的矩阵
361
- valid_dirs = direction_map[rows, cols]
365
+ valid_dirs = direction_map[0, 0, rows, cols]
362
366
 
363
367
  # 过滤 NaN
364
368
  valid_mask = ~np.isnan(valid_dirs)
@@ -450,9 +454,9 @@ class ModelSelectorGUI:
450
454
  self.modelLabel = ttk.Label(self.root, text="Select a model:", width = 15)
451
455
  self.modelLabel.grid(row=0, column=0, padx=10, pady=10)
452
456
 
453
- self.modelCombobox = ttk.Combobox(self.root, values=modelList, width = 30)
457
+ self.modelCombobox = ttk.Combobox(self.root, values=modelList, width = 25)
454
458
  self.modelCombobox.current(11)
455
- self.modelCombobox.grid(row=0, column=1, columnspan=2, padx=10, pady=10)
459
+ self.modelCombobox.grid(row=0, column=1, columnspan=2, pady=10, sticky='w')
456
460
 
457
461
 
458
462
  class InputSelectorGUI:
@@ -476,25 +480,24 @@ class InputSelectorGUI:
476
480
  self.endImgName = None
477
481
 
478
482
  def create_gui(self):
479
- self.inputTypeLabel = ttk.Label(self.root, text="Select input from:", width = 15)
480
- self.inputTypeLabel.grid(row=1, column=0, padx=10, pady=10)
483
+ self.inputTypeLabel = ttk.Label(self.root, text="Input Type:", width = 15)
484
+ self.inputTypeLabel.grid(row=1, column=0, padx=10, pady=10, sticky='w')
481
485
 
482
486
  self.selectedOption = tk.IntVar(value=0)
483
487
 
484
-
485
- self.vidLabel = ttk.Radiobutton(self.root,
486
- text='Video stream',
487
- variable=self.selectedOption,
488
- value=1,
489
- command=self.select_vidstream)
490
- self.vidLabel.grid(row=1, column=2, padx=10, pady=10)
491
-
492
488
  self.imgLabel = ttk.Radiobutton(self.root,
493
- text='Image stream',
489
+ text='Image Sequence',
494
490
  variable=self.selectedOption,
495
491
  value=2,
496
492
  command=self.select_imgstream)
497
- self.imgLabel.grid(row=1, column=1, padx=10, pady=10)
493
+ self.imgLabel.grid(row=1, column=2, padx=10, pady=10, sticky="w")
494
+
495
+ self.vidLabel = ttk.Radiobutton(self.root,
496
+ text='Video',
497
+ variable=self.selectedOption,
498
+ value=1,
499
+ command=self.select_vidstream)
500
+ self.vidLabel.grid(row=1, column=1, padx=10, pady=10, sticky="w")
498
501
 
499
502
  def select_vidstream(self):
500
503
  self.imgSelectFolder = None
@@ -503,16 +506,16 @@ class InputSelectorGUI:
503
506
  for element in self.imgElement.values():
504
507
  element.destroy()
505
508
 
506
- self.vidElement['lblVidIndicate'] = ttk.Label(self.root, text= 'Video\'s path:', width = 15)
507
- self.vidElement['lblVidIndicate'].grid(row=2, column=0, padx=10, pady=30)
509
+ self.vidElement['lblVidIndicate'] = ttk.Label(self.root, text= 'Video\'s path:')
510
+ self.vidElement['lblVidIndicate'].grid(row=2, column=1, padx=10, pady=30, sticky='w')
508
511
  self.vidElement['lblVidPath'] = ttk.Label(self.root,
509
512
  text="Waiting for the selection",
510
513
  wraplength=220
511
514
  )
512
- self.vidElement['lblVidPath'].grid(row=2, column=1, columnspan=2, padx=10, pady=10)
515
+ self.vidElement['lblVidPath'].grid(row=2, column=2, padx=10, pady=10, sticky='w')
513
516
 
514
517
  self.vidElement['btn'] = ttk.Button(self.root, text="Select a video", command=self._clicked_vid)
515
- self.vidElement['btn'].grid(row=3, column=2, padx=10, pady=10)
518
+ self.vidElement['btn'].grid(row=3, column=2, padx=10, pady=10, sticky='w')
516
519
 
517
520
  def _clicked_vid(self):
518
521
  self.vidName = filedialog.askopenfilenames(initialdir=VID_DEFAULT_FOLDER)
@@ -524,15 +527,20 @@ class InputSelectorGUI:
524
527
  for element in self.vidElement.values():
525
528
  element.destroy()
526
529
 
527
- self.imgElement['lblFolder'] = ttk.Label(self.root, text="Image's folder: ", width = 15)
528
- self.imgElement['lblFolder'].grid(row=2, column=0, padx=10, pady=10)
530
+ self.imgElement['lblFolder'] = ttk.Label(self.root, text="Image's folder: ")
531
+ self.imgElement['lblFolder'].grid(row=2, column=1, padx=10, pady=10, sticky='w')
529
532
  self.imgElement['lblFolderName'] = ttk.Label(self.root, text="Waiting for the selection", wraplength=220)
530
- self.imgElement['lblFolderName'].grid(row=2, column=1, columnspan=2, padx=10, pady=30)
533
+ self.imgElement['lblFolderName'].grid(row=2, column=2, padx=10, pady=30, sticky='w')
531
534
 
532
535
  self.imgElement['btnStart'] = ttk.Button(self.root, text="Select start frame", command=self._clicked_start_img)
533
- self.imgElement['btnStart'].grid(row=3, column=1, padx=10, pady=10)
536
+ self.imgElement['btnStart'].grid(row=3, column=1, padx=10, pady=10, sticky='w')
537
+ self.imgElement['lblStartImg'] = ttk.Label(self.root, text=self.startImgName)
538
+ self.imgElement['lblStartImg'].grid(row=3, column=2, padx=10, pady=10, sticky='w')
539
+
534
540
  self.imgElement['btnEnd'] = ttk.Button(self.root, text="Select end frame", command=self._clicked_end_img)
535
- self.imgElement['btnEnd'].grid(row=4, column=1, padx=10, pady=10)
541
+ self.imgElement['btnEnd'].grid(row=4, column=1, padx=10, pady=10, sticky='w')
542
+ self.imgElement['lblEndImg'] = ttk.Label(self.root, text=self.endImgName)
543
+ self.imgElement['lblEndImg'].grid(row=4, column=2, padx=10, pady=10, sticky='w')
536
544
 
537
545
  def _clicked_start_img(self):
538
546
  startImgFullPath = filedialog.askopenfilenames(
@@ -551,8 +559,7 @@ class InputSelectorGUI:
551
559
  self.imgSelectFolder = self.startFolder
552
560
  self.imgElement['lblFolderName'].config(text=self.imgSelectFolder)
553
561
 
554
- self.imgElement['lblStartImg'] = ttk.Label(self.root, text=self.startImgName)
555
- self.imgElement['lblStartImg'].grid(row=3, column=2, padx=10, pady=10)
562
+ self.imgElement['lblStartImg'].config(text=self.startImgName)
556
563
 
557
564
  def _clicked_end_img(self):
558
565
  endImgFullPath = filedialog.askopenfilenames(
@@ -573,16 +580,127 @@ class InputSelectorGUI:
573
580
  self.imgSelectFolder = self.endFolder
574
581
  self.imgElement['lblFolderName'].config(text=self.imgSelectFolder)
575
582
 
576
- self.imgElement['lblEndImg'] = ttk.Label(self.root, text=self.endImgName)
577
- self.imgElement['lblEndImg'].grid(row=4, column=2, padx=10, pady=10)
583
+ self.imgElement['lblEndImg'].config(text=self.endImgName)
584
+
585
+
586
+ class PostProcessingSelectorGUI:
587
+ def __init__(self, root):
588
+ self.root = root
589
+ self.output_type = "dot"
590
+ self.show_threshold = 0.0
591
+ self.top_num = 1
592
+
593
+ self.outputTypeLabel = ttk.Label(self.root, text="Output Type:", width = 15)
594
+ self.outputTypeLabel.grid(row=5, column=0, padx=10, pady=10)
595
+
596
+ self.selectedOption = tk.IntVar(value=2)
597
+
598
+ self.dotLabel = ttk.Radiobutton(self.root,
599
+ text='dot output',
600
+ variable=self.selectedOption,
601
+ value=2,
602
+ command=self.select_dot)
603
+ self.dotLabel.grid(row=5, column=1, padx=10, pady=10, sticky="w")
604
+
605
+ self.bboxLabel = ttk.Radiobutton(self.root,
606
+ text='bbox output',
607
+ variable=self.selectedOption,
608
+ value=1,
609
+ command=self.select_bbox)
610
+ self.bboxLabel.grid(row=5, column=2, padx=10, pady=10, sticky="w")
611
+
612
+ self.showThresholdLabel = ttk.Label(self.root, text="Threshold:", width=10)
613
+ self.showThresholdLabel.grid(row=6, column=1, padx=10, pady=10, sticky='w')
578
614
 
615
+ self.showThresholdVar = tk.StringVar(value="0")
616
+ self.showThresholdVar.trace_add('write', self.update_show_threshold)
617
+ self.showThresholdEntry = ttk.Entry(self.root, textvariable=self.showThresholdVar, width=5)
618
+ self.showThresholdEntry.grid(row=6, column=2, padx=10, pady=10, sticky='w')
579
619
 
580
- class ModelAndInputSelectorGUI:
620
+ self.getTopNumLabel = ttk.Label(self.root, text="Top Num:", width=10)
621
+ self.getTopNumLabel.grid(row=7, column=1, padx=10, pady=10, sticky='w')
622
+
623
+ self.getTopNumVar = tk.StringVar(value="1")
624
+ self.getTopNumVar.trace_add('write', self.update_top_num)
625
+ self.getTopNumEntry = ttk.Entry(self.root, textvariable=self.getTopNumVar, width=5)
626
+ self.getTopNumEntry.grid(row=7, column=2, padx=10, pady=10, sticky='w')
627
+
628
+ self.select_dot()
629
+
630
+ def select_dot(self):
631
+ self.selectedOption.set(2)
632
+ self.output_type = "dot"
633
+
634
+ def select_bbox(self):
635
+ self.selectedOption.set(1)
636
+ self.output_type = "bbox"
637
+
638
+ def get_post_processing(self):
639
+ if self.output_type == "dot":
640
+ return PostProcessing(get_top_num = self.top_num)
641
+ elif self.output_type == "bbox":
642
+ return bbox_post_processing(self.top_num)
643
+ else:
644
+ raise ValueError(f"Unknown output type: {self.output_type}")
645
+
646
+ def update_show_threshold(self, *args):
647
+ value = float(self.showThresholdVar.get())
648
+ self.show_threshold = min(max(value, 0.0), 1.0) # 确保在 [0.0, 1.0] 范围内
649
+
650
+ def update_top_num(self, *args):
651
+ value = int(self.getTopNumVar.get())
652
+ self.top_num = max(value, 1) # 确保 top_num 至少为 1
653
+
654
+
655
+
656
+
657
+ class DeviceSelectorGUI:
581
658
  def __init__(self, root):
582
659
  self.root = root
660
+ self.device = "cpu"
583
661
 
584
- windowHeight = 350
585
- windowWidth = 400
662
+ self.deviceLabel = ttk.Label(self.root, text="Select Device:", width=15)
663
+ self.deviceLabel.grid(row=8, column=0, padx=10, pady=10)
664
+
665
+ self.selectedOption = tk.IntVar(value=1)
666
+
667
+ if torch.cuda.is_available():
668
+ self.selectedOption.set(2)
669
+ self.device = "cuda"
670
+
671
+ self.cpuLabel = ttk.Radiobutton(self.root,
672
+ text='CPU',
673
+ variable=self.selectedOption,
674
+ value=1,
675
+ command=self.select_cpu)
676
+ self.cpuLabel.grid(row=8, column=1, padx=10, pady=10, sticky="w")
677
+
678
+ self.gpuLabel = ttk.Radiobutton(self.root,
679
+ text='GPU',
680
+ variable=self.selectedOption,
681
+ value=2,
682
+ command=self.select_gpu)
683
+ self.gpuLabel.grid(row=8, column=2, padx=10, pady=10, sticky="w")
684
+
685
+ def select_cpu(self):
686
+ self.selectedOption.set(1)
687
+ self.device = "cpu"
688
+
689
+ def select_gpu(self):
690
+ if torch.cuda.is_available():
691
+ self.selectedOption.set(2)
692
+ self.device = "cuda"
693
+ else:
694
+ messagebox.showinfo("Message title", "CUDA is not available. Please select CPU.")
695
+ self.select_cpu()
696
+
697
+
698
+ class XTTMP_GUI:
699
+ def __init__(self, root):
700
+ self.root = root
701
+
702
+ windowHeight = 550
703
+ windowWidth = 510
586
704
 
587
705
  startHeight = (root.winfo_screenheight() - windowHeight) // 2
588
706
  startWidth = (root.winfo_screenwidth() - windowWidth) // 2
@@ -592,13 +710,16 @@ class ModelAndInputSelectorGUI:
592
710
  self._set_window_icon()
593
711
 
594
712
  self.objModelSelector = ModelSelectorGUI(root)
713
+ self.objPostProcessingSelector = PostProcessingSelectorGUI(root)
595
714
  self.objInputSelector = InputSelectorGUI(root)
715
+ self.objDeviceSelector = DeviceSelectorGUI(root)
716
+
596
717
 
597
- self.btnRun = ttk.Button(self.root, text="Run", command=self._run)
598
- self.btnRun.place(x = 20, y=300)
599
718
  self.btnStepping = ttk.Button(self.root, text="Stepping", command=self._stepping)
600
- self.btnStepping.place(x = 20, y=270)
601
719
  self.isStepping = False
720
+ self.btnStepping.grid(row=9, column=2, padx=10, pady=10, sticky='e')
721
+ self.btnRun = ttk.Button(self.root, text="Run", command=self._run)
722
+ self.btnRun.grid(row=10, column=2, padx=10, pady=10, sticky='e')
602
723
 
603
724
  def create_gui(self):
604
725
  self.objModelSelector.create_gui(ALL_MODEL)
@@ -607,9 +728,13 @@ class ModelAndInputSelectorGUI:
607
728
  self.root.mainloop()
608
729
 
609
730
  if self.objInputSelector.selectedOption.get() == 1:
610
- return self.modelName, self.vidName, None, self.isStepping
731
+ return (self.modelName, self.vidName, None, self.isStepping,
732
+ self.objDeviceSelector.device, self.objPostProcessingSelector.get_post_processing(),
733
+ self.objPostProcessingSelector.show_threshold)
611
734
  elif self.objInputSelector.selectedOption.get() == 2:
612
- return self.modelName, self.startImgName, self.endImgName, self.isStepping
735
+ return (self.modelName, self.startImgName, self.endImgName, self.isStepping,
736
+ self.objDeviceSelector.device, self.objPostProcessingSelector.get_post_processing(),
737
+ self.objPostProcessingSelector.show_threshold)
613
738
 
614
739
  def _run(self):
615
740
  self.modelName = self.objModelSelector.modelCombobox.get()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xttmp
3
- Version: 2.3.0.1
3
+ Version: 2.3.0.3
4
4
  Summary: eXtremely Tiny Target - Motion Perception
5
5
  Author-email: Shawn MX <mingshuoxu@hotmail.com>
6
6
  Project-URL: Homepage, https://github.com/MingshuoXu/Small-Target-Motion-Detectors
@@ -16,15 +16,16 @@ Classifier: Programming Language :: Python :: 3.9
16
16
  Classifier: Programming Language :: Python :: 3.10
17
17
  Classifier: Programming Language :: Python :: 3.11
18
18
  Classifier: Programming Language :: Python :: 3.12
19
- Classifier: Topic :: Scientific/Engineering :: Image Recognition
19
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
20
  Requires-Python: >=3.8
21
21
  Description-Content-Type: text/markdown
22
22
  License-File: LICENSE
23
23
  Requires-Dist: matplotlib
24
24
  Requires-Dist: opencv-python
25
25
  Requires-Dist: scipy
26
- Requires-Dist: torch>=2.5.0
27
- Requires-Dist: torchvision>=0.20.0
26
+ Provides-Extra: torch
27
+ Requires-Dist: torch>=2.5.0; extra == "torch"
28
+ Requires-Dist: torchvision>=0.20.0; extra == "torch"
28
29
  Dynamic: license-file
29
30
 
30
31
  # Small Target Motion Detectors, Version 2.3 (XTT-MP: Extremely Tiny Target - Motion Perception)
@@ -66,8 +67,18 @@ Built with modularity and extensibility in mind, XTT-MP provides a robust suite
66
67
  - After `pip install xttmp`, use the installed code and bring your own input data, or run from a repository checkout to access the bundled examples.
67
68
 
68
69
  ### Via PyPI
70
+ #### CPU
71
+ ```bash
72
+ pip install xttmp[torch]
73
+ ```
74
+
75
+ #### NVIDIA GPU (CUDA 12.6)
76
+ ```bash
77
+ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126
78
+ ```
79
+
80
+ ### Running the GUI Demo
69
81
  ```bash
70
- pip install xttmp
71
82
  xttmp_gui
72
83
  ```
73
84
 
@@ -1,5 +1,7 @@
1
1
  matplotlib
2
2
  opencv-python
3
3
  scipy
4
+
5
+ [torch]
4
6
  torch>=2.5.0
5
7
  torchvision>=0.20.0
File without changes
File without changes
File without changes
File without changes