xttmp 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. xttmp/__init__.py +1 -0
  2. xttmp/api/__init__.py +5 -0
  3. xttmp/api/evaluate.py +163 -0
  4. xttmp/api/get_visualize_handle.py +29 -0
  5. xttmp/api/instancing_model.py +35 -0
  6. xttmp/core/__init__.py +0 -0
  7. xttmp/core/apgstmd_core.py +188 -0
  8. xttmp/core/apgstmdv2_core.py +79 -0
  9. xttmp/core/base_core.py +36 -0
  10. xttmp/core/dstmd_core.py +213 -0
  11. xttmp/core/estmd_backbone.py +110 -0
  12. xttmp/core/estmd_core.py +356 -0
  13. xttmp/core/feedbackstmd_core.py +61 -0
  14. xttmp/core/fracstmd_core.py +98 -0
  15. xttmp/core/fstmd_core.py +15 -0
  16. xttmp/core/fstmdv2_core.py +42 -0
  17. xttmp/core/haarstmd_core.py +140 -0
  18. xttmp/core/math_operator.py +307 -0
  19. xttmp/core/stfeedbackstmd_core.py +233 -0
  20. xttmp/core/stmdplus_core.py +187 -0
  21. xttmp/core/stmdplusv2_core.py +82 -0
  22. xttmp/core/vstmd_core.py +420 -0
  23. xttmp/demo/evaluate_model.py +92 -0
  24. xttmp/demo/inference_gui.py +148 -0
  25. xttmp/demo/inference_gui_single_process.py +134 -0
  26. xttmp/demo/inference_image_stream.py +67 -0
  27. xttmp/demo/inference_video.py +66 -0
  28. xttmp/main.py +14 -0
  29. xttmp/model/__init__.py +13 -0
  30. xttmp/model/backbone.py +514 -0
  31. xttmp/model/facilitated_model.py +230 -0
  32. xttmp/model/feedback_model.py +271 -0
  33. xttmp/model/haarstmd.py +61 -0
  34. xttmp/model/vstmd.py +457 -0
  35. xttmp/util/__init__.py +0 -0
  36. xttmp/util/compute_module.py +402 -0
  37. xttmp/util/create_kernel.py +363 -0
  38. xttmp/util/evaluate_module.py +697 -0
  39. xttmp/util/iostream.py +660 -0
  40. xttmp-2.3.0.dist-info/METADATA +85 -0
  41. xttmp-2.3.0.dist-info/RECORD +45 -0
  42. xttmp-2.3.0.dist-info/WHEEL +5 -0
  43. xttmp-2.3.0.dist-info/entry_points.txt +2 -0
  44. xttmp-2.3.0.dist-info/licenses/LICENSE +201 -0
  45. xttmp-2.3.0.dist-info/top_level.txt +1 -0
xttmp/util/iostream.py ADDED
@@ -0,0 +1,660 @@
1
+ import os
2
+ import re
3
+ from pathlib import Path
4
+ import logging
5
+ from typing import Optional, List, Union, Tuple, Any
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import tkinter as tk
10
+ from tkinter import ttk, filedialog, messagebox
11
+ import torch
12
+
13
+
14
+ from .. import model
15
+
16
+
17
# Real on-disk location of this module (symlinks resolved).
filePath = os.path.realpath(__file__)

# Climb four directory levels above this file; the demo data folders are
# resolved relative to that directory (presumably the repository root when
# running from a source checkout -- verify for installed wheels).
_utilDir = os.path.dirname(filePath)
_packageDir = os.path.dirname(_utilDir)
_packageParentDir = os.path.dirname(_packageDir)
gitCodePath = os.path.dirname(_packageParentDir)

# Default start folders offered by the file dialogs.
VID_DEFAULT_FOLDER = os.path.join(gitCodePath, 'demodata')
IMG_DEFAULT_FOLDER = os.path.join(VID_DEFAULT_FOLDER, 'imgstream')

# Names exported by the sibling ``model`` package; used to fill and validate
# the model selection combobox.
ALL_MODEL = model.__all__

logger = logging.getLogger(__name__)
26
+
27
class FrameIterator:
    """
    A flexible iterator that retrieves images frame-by-frame from a video file
    or from a folder of naturally (numerically) sorted image files.

    Supports the iterator protocol (yielding ``(color_img, gray_tensor)``) and
    the context-manager protocol (releasing the source on exit).
    """

    def __init__(self, input_path: str, is_video: bool = True, is_silence: bool = True, device: str = 'cpu'):
        """
        Initialize the iterator and open the source.

        Opening failures are logged, not raised; check ``is_open`` afterwards.

        Parameters:
            input_path (str):
                - If is_video is True: full path to the video file.
                - If is_video is False: path to folder containing image sequence.
            is_video (bool): Specifies whether input is a video or image sequence.
            is_silence (bool): If True, suppresses standard informational output.
            device (str): Computation device for PyTorch tensors ('cpu', 'cuda', etc.).
        """
        self.input_path = input_path
        self.is_video = is_video
        self.is_silence = is_silence
        self.device = device  # kept on the instance so every tensor lands on the same device

        self.current_index = 0
        self.total_frames = 0
        self.is_open = False

        self.img_height, self.img_width = None, None
        # Always defined so callers can read it for image sequences or after a
        # failed open; only a successfully opened video overwrites it.
        self.fps = None
        self.cap = None
        self.image_files: List[str] = []

        if self.is_video:
            self._init_video_source()
        else:
            self._init_image_sequence_source()

    def _log(self, message: str, level: int = logging.INFO):
        """Log ``message``; below-WARNING messages are dropped when ``is_silence`` is set."""
        if not self.is_silence or level >= logging.WARNING:
            logger.log(level, message)

    def _setup(self, current_index: int):
        """Jump to a specific frame index (ignored, with a warning, when out of bounds)."""
        if current_index < 0 or (self.total_frames > 0 and current_index >= self.total_frames):
            logger.warning(f"Index {current_index} is out of bounds (0 - {self.total_frames-1}).")
            return

        self.current_index = current_index
        if self.is_video and self.cap:
            success = self.cap.set(cv2.CAP_PROP_POS_FRAMES, current_index)
            if not success:
                logger.warning(f"Unable to set video frame position to {current_index}.")

    # --- Video processing logic ---
    def _init_video_source(self):
        """Open the video file and cache its frame count, fps and frame size."""
        if not os.path.isfile(self.input_path):
            logger.error(f"Video file not found: {self.input_path}")
            return

        self.cap = cv2.VideoCapture(self.input_path)
        if not self.cap.isOpened():
            logger.error(f"Unable to open video file: {self.input_path}")
            return

        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        self.fps = self.cap.get(cv2.CAP_PROP_FPS)
        self.img_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.img_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.is_open = True

        self._log(f"Successfully opened video file. Total frames: {self.total_frames}")

    def _get_next_frame_from_video(self) -> Optional[np.ndarray]:
        """Read the next video frame; release the capture and return None when reading fails."""
        if not self.is_open or not self.cap:
            return None

        ret, frame = self.cap.read()
        if ret:
            self.current_index += 1
            return frame

        self.release()  # Video reading completed or error occurred
        return None

    # --- Image sequence processing logic ---
    def _init_image_sequence_source(self):
        """Collect the sorted image list and read the first image to learn the frame size."""
        if not os.path.isdir(self.input_path):
            logger.error(f"Folder not found: {self.input_path}")
            return

        self.image_files = self._get_sorted_image_files(self.input_path)

        if not self.image_files:
            logger.error(f"No image files found in folder: {self.input_path}")
            return

        self.total_frames = len(self.image_files)
        self.is_open = True
        self._log(f"Successfully loaded image sequence. Total images: {self.total_frames}")

        # Read first image to get dimensions (left as None if it cannot be decoded)
        first_image = cv2.imread(self.image_files[0], cv2.IMREAD_COLOR)
        if first_image is not None:
            self.img_height, self.img_width = first_image.shape[:2]

    def _get_next_frame_from_sequence(self) -> Optional[np.ndarray]:
        """Return the next decodable image, skipping unreadable files; None at the end."""
        while self.is_open and self.current_index < self.total_frames:
            file_path = self.image_files[self.current_index]

            # Advance before decoding: the index must move on even when the
            # file is unreadable, otherwise a corrupt image would loop forever.
            self.current_index += 1

            frame = cv2.imread(file_path, cv2.IMREAD_COLOR)
            if frame is not None:
                return frame
            else:
                logger.warning(f"Unable to read or decode image file: {file_path}")

        self.release()
        return None

    # --- Core interfaces ---
    def get_next_frame(self) -> Tuple[Optional[np.ndarray], Optional[torch.Tensor], bool]:
        """
        [Public interface] Get next image frame and its grayscale PyTorch tensor.

        Returns:
            Tuple[color_img, gray_tensor, is_valid]:
                - color_img: BGR image (NumPy array) or None
                - gray_tensor: Grayscale tensor shape (1, 1, H, W), float32 scaled to [0, 1], or None
                - is_valid: Boolean indicating if retrieval was successful
        """
        color_img = self._get_next_frame_from_video() if self.is_video else self._get_next_frame_from_sequence()

        if color_img is None:
            return None, None, False

        gray_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2GRAY)
        # Build the tensor in one expression: move to device as float32, add
        # batch and channel axes, then normalise 8-bit values to [0, 1].
        gray_tensor = torch.from_numpy(gray_img).to(device=self.device, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0

        return color_img, gray_tensor, True

    # --- Iterator & Context Manager Protocols ---
    def __iter__(self):
        return self

    def __next__(self) -> Tuple[np.ndarray, torch.Tensor]:
        color_img, gray_tensor, is_valid = self.get_next_frame()
        if not is_valid:
            raise StopIteration
        return color_img, gray_tensor

    def __enter__(self):
        """Enable context manager: `with FrameIterator(...) as it:`"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.release()

    def release(self):
        """Release the underlying capture (if any) and mark the source closed. Idempotent."""
        if self.is_open:
            if self.is_video and self.cap is not None:
                self.cap.release()
            self.is_open = False
            self._log("Resources released.")

    def __del__(self):
        # A finalizer may run on a partially constructed instance or while the
        # interpreter is shutting down, so it must never raise.
        try:
            if getattr(self, 'is_open', False):
                self.release()
        except Exception:
            pass

    # --- Helpers ---
    @staticmethod
    def _natural_sort_key(s: str) -> List[Union[str, int]]:
        """Split ``s`` into lower-cased text and integer runs so 'f2' sorts before 'f10'."""
        return [int(text) if text.isdigit() else text.lower() for text in re.split(r'(\d+)', s)]

    def _get_sorted_image_files(self, folder_path: str) -> List[str]:
        """Return the image files directly inside ``folder_path`` in natural filename order."""
        valid_exts = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.tif'}
        path = Path(folder_path)

        # A single iterdir() pass instead of one glob call per extension.
        files = [
            str(f) for f in path.iterdir()
            if f.is_file() and f.suffix.lower() in valid_exts
        ]

        return sorted(files, key=lambda f: self._natural_sort_key(Path(f).name))
214
+
215
+
216
class FrameVisualizer:
    """
    Draw detection results onto frames, optionally show them in an OpenCV
    window with keyboard control, and optionally record them to a video file.

    Keyboard controls (window mode): Esc / 'q' stop, Space toggles pause,
    'n' advances one frame and stays paused.
    """

    def __init__(self, window_name="Visualizer",
                 result_index_type="matrix",
                 win_width=None, win_height=None,
                 is_headless=False,
                 conf_threshold=0.8  # score threshold
                 ):
        """
        Initialize the visualizer.

        :param window_name: Title/identifier of the OpenCV window.
        :param result_index_type: Result layout: "matrix", "dots" or "bbox".
        :param win_width: Window width in pixels (800 when omitted).
        :param win_height: Window height in pixels (600 when omitted).
        :param is_headless: When True no window is created or shown.
        :param conf_threshold: Only responses strictly greater than this value
            are drawn (intended range 0.0 ~ 1.0).
        """
        self.window_name = window_name
        self.result_index_type = result_index_type  # "matrix", "dots", "bbox"
        self.win_width = win_width or 800
        self.win_height = win_height or 600
        self.is_headless = is_headless
        self.conf_threshold = conf_threshold

        self.save_output = False
        self.video_writer = None  # created by setup_video_writer()
        self._writer_size = None  # (width, height) the writer was opened with
        self.paused = False

        self._setup_window()

    def _setup_window(self):
        """Create and size the display window (no-op in headless mode)."""
        if self.is_headless:
            return
        cv2.namedWindow(self.window_name, cv2.WINDOW_GUI_NORMAL)
        cv2.resizeWindow(self.window_name, self.win_width, self.win_height)

    def setup_video_writer(self, output_path, fps=30, width=None, height=None):
        """
        Initialize the video writer (call explicitly before update()).

        :param output_path: Destination file; missing parent folders are created.
        :param fps: Frame rate of the output video.
        :param width: Output width (defaults to the window width).
        :param height: Output height (defaults to the window height).
        """
        # Make sure the destination folder exists (race-free).
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        width = width if width is not None else self.win_width
        height = height if height is not None else self.win_height

        # Re-initialising must not leak the previous writer's file handle.
        if self.video_writer is not None:
            self.video_writer.release()

        # mp4v is widely supported.
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        self.video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        self._writer_size = (width, height)
        self.save_output = True
        if not self.video_writer.isOpened():
            logger.warning(f"Video writer could not be opened: {output_path}")
        print(f">>> Video writer initialized: {output_path}")

    @staticmethod
    def _to_numpy(data):
        """Return ``data`` as-is unless it is a torch tensor, which is moved to CPU NumPy."""
        if isinstance(data, torch.Tensor):
            return data.detach().cpu().numpy()
        return data

    def _fit_to_writer(self, frame):
        """Resize ``frame`` to the writer's size when they differ.

        Frames whose size does not match the size the writer was opened with
        are typically dropped without any error, producing an empty video.
        """
        size = getattr(self, '_writer_size', None)
        if size is None:
            return frame
        height, width = frame.shape[:2]
        if (width, height) == tuple(size):
            return frame
        return cv2.resize(frame, tuple(size))

    def update(self, frame, result=None, direction=None, annotation=None, process_time=None) -> bool:
        """
        Draw ``result`` on ``frame`` (in place), record/show it, and handle keys.

        :param frame: BGR image to annotate; None aborts immediately.
        :param result: Detection output in the layout given by ``result_index_type``.
        :param direction: Direction map (radians) used by the "matrix" layout.
        :param annotation: Per-box labels used by the "bbox" layout.
        :param process_time: Processing time in seconds, rendered in milliseconds.
        :return: False when ``frame`` is None or the user asked to quit, else True.
        """
        if frame is None:
            return False

        # --- Drawing ---
        # An empty result is still handled; only None skips drawing.
        if result is not None:
            result = self._to_numpy(result)
            if self.result_index_type == "matrix":
                self._draw_matrix(frame, result, self._to_numpy(direction), self.conf_threshold)
            elif self.result_index_type == "dots":
                self._draw_dots(frame, result, self.conf_threshold)
            elif self.result_index_type == "bbox":
                self._draw_bbox(frame, result, self.conf_threshold, annotation)

        # --- Info overlay ---
        if process_time is not None:
            cv2.putText(frame, f'Time: {process_time*1000:.1f} ms',
                        (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8,
                        (0, 255, 0), 2, cv2.LINE_AA)

        # --- Video recording ---
        if self.save_output and self.video_writer is not None:
            self.video_writer.write(self._fit_to_writer(frame))

        # --- Headless fast path ---
        if self.is_headless:
            return True

        # --- Display and key handling ---
        cv2.imshow(self.window_name, frame)

        while True:
            key = cv2.waitKey(1) & 0xFF

            if key == 27 or key == ord('q'):  # Esc or 'q': stop
                return False

            if key == 32:  # Space: toggle pause
                self.paused = not self.paused
                print(f">>> State: {'Paused' if self.paused else 'Running'}")
                if not self.paused: break
                continue

            if key == ord('n'):  # Next: advance one frame, stay paused
                self.paused = True
                break

            if not self.paused:
                break

        return True

    def close(self):
        """Destroy the window (if any) and finalize the video file (if any)."""
        if not self.is_headless:
            try:
                cv2.destroyWindow(self.window_name)
            except cv2.error:
                # The window may already have been closed by the user.
                pass
        if self.video_writer is not None:
            self.video_writer.release()
            self.video_writer = None
            self.save_output = False

    @staticmethod
    def _draw_arrows(frame, x_coords, y_coords, directions, length=15):
        """
        Draw one arrow per point.

        :param x_coords: X coordinates (columns)
        :param y_coords: Y coordinates (rows)
        :param directions: Direction angles in radians
        :param length: Arrow length in pixels
        """
        if len(x_coords) == 0: return

        cos_d = np.cos(directions)
        sin_d = np.sin(directions)

        for x, y, c, s in zip(x_coords, y_coords, cos_d, sin_d):
            # OpenCV points must be plain ints. The y term is subtracted
            # because image rows grow downwards.
            start_pt = (int(x), int(y))
            end_pt = (int(x + length * c), int(y - length * s))

            cv2.arrowedLine(frame, start_pt, end_pt,
                            color=(0, 0, 255), thickness=1,
                            tipLength=0.3, line_type=cv2.LINE_AA)

    @staticmethod
    def _draw_matrix(frame, matrix, direction_map, threshold):
        """Handle the matrix (heatmap) layout: mark every cell above ``threshold``."""
        matrix = np.asarray(FrameVisualizer._to_numpy(matrix))
        # np.max raises on an empty array, so test the size first.
        if matrix.size == 0 or np.max(matrix) <= 0: return

        # np.where returns (rows, cols), i.e. (y, x)
        rows, cols = np.where(matrix > threshold)

        # Markers: OpenCV points are (x, y) -> (col, row); tolist() yields
        # plain Python ints, which every OpenCV build accepts.
        for r, c in zip(rows.tolist(), cols.tolist()):
            cv2.drawMarker(frame, (c, r), color=(0, 0, 255),
                           markerType=cv2.MARKER_STAR, markerSize=5, thickness=1)

        # Arrows
        if direction_map is not None and len(rows) > 0:
            # Assumes direction_map has the same shape as matrix.
            direction_map = np.asarray(FrameVisualizer._to_numpy(direction_map))
            valid_dirs = direction_map[rows, cols]

            # Skip cells without a direction (NaN)
            valid_mask = ~np.isnan(valid_dirs)
            if np.any(valid_mask):
                FrameVisualizer._draw_arrows(frame,
                                             x_coords=cols[valid_mask],
                                             y_coords=rows[valid_mask],
                                             directions=valid_dirs[valid_mask])

    @staticmethod
    def _draw_dots(frame, response, threshold):
        """Handle the dots layout: [[x, y, score, dir], ...] (dir column optional)."""
        response = np.asarray(FrameVisualizer._to_numpy(response))
        if response.size == 0: return

        # Column 0 = x, column 1 = y, column 2 = score
        scores = response[:, 2]

        mask = scores > threshold
        filtered = response[mask]

        if len(filtered) == 0: return

        xs = filtered[:, 0]
        ys = filtered[:, 1]

        for x, y in zip(xs, ys):
            cv2.drawMarker(frame, (int(x), int(y)), color=(0, 0, 255),
                           markerType=cv2.MARKER_STAR, markerSize=5, thickness=1)

        # Optional direction column (column 3)
        if response.shape[1] > 3:
            dirs = filtered[:, 3]
            valid_mask = ~np.isnan(dirs)
            FrameVisualizer._draw_arrows(frame,
                                         x_coords=xs[valid_mask],
                                         y_coords=ys[valid_mask],
                                         directions=dirs[valid_mask])

    @staticmethod
    def _draw_bbox(frame, response, threshold, annotation=None):
        """Handle the bbox layout: [[x1, y1, x2, y2, score, dir], ...] (dir column optional)."""
        # Convert first: on a torch tensor ``.size`` is a method, so an
        # emptiness test before conversion would never trigger.
        response = np.asarray(FrameVisualizer._to_numpy(response))
        if response.size == 0: return

        # 1. Filter by score before any further conversion.
        mask = response[:, 4] > threshold
        filtered_res = response[mask]
        if filtered_res.size == 0: return

        # 2. Convert only the coordinate columns to integers.
        boxes = filtered_res[:, :4].astype(np.int32)

        # 3. A sixth column, when present, carries the direction.
        has_dir = filtered_res.shape[1] > 5

        # 4. Branch once outside the loops; tolist() yields plain Python ints.
        if annotation is not None:
            filtered_anno = np.asanyarray(annotation)[mask]
            for (x1, y1, x2, y2), anno in zip(boxes.tolist(), filtered_anno):
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 1, cv2.LINE_AA)
                cv2.putText(frame, str(anno), (x1, y1 - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
        else:
            for x1, y1, x2, y2 in boxes.tolist():
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 1, cv2.LINE_AA)

        # 5. Direction arrows from the (integer) box centres.
        if has_dir:
            dirs = filtered_res[:, 5]
            v_mask = ~np.isnan(dirs)
            if np.any(v_mask):
                v_boxes = boxes[v_mask]
                c_xs = (v_boxes[:, 0] + v_boxes[:, 2]) // 2
                c_ys = (v_boxes[:, 1] + v_boxes[:, 3]) // 2
                FrameVisualizer._draw_arrows(frame, c_xs, c_ys, dirs[v_mask])
443
+
444
+
445
class ModelSelectorGUI:
    """Row 0 of the launcher window: a label plus a combobox of model names."""

    # Entry preselected in the combobox when the list is long enough.
    DEFAULT_MODEL_INDEX = 11

    def __init__(self, root):
        """:param root: Parent Tk container the widgets are gridded into."""
        self.root = root

    @staticmethod
    def _default_index(modelList, preferred=11):
        """Return a valid preselection index for ``modelList``.

        ``preferred`` is clamped to the last entry so a short list cannot
        trigger a Tcl "index out of range" error; None means the list is
        empty and nothing can be preselected.
        """
        count = len(modelList)
        if count == 0:
            return None
        return min(preferred, count - 1)

    def create_gui(self, modelList):
        """Create the label and the combobox filled with ``modelList``."""
        self.modelLabel = ttk.Label(self.root, text="Select a model:", width = 15)
        self.modelLabel.grid(row=0, column=0, padx=10, pady=10)

        self.modelCombobox = ttk.Combobox(self.root, values=modelList, width = 30)
        default_index = self._default_index(modelList, self.DEFAULT_MODEL_INDEX)
        if default_index is not None:
            self.modelCombobox.current(default_index)
        self.modelCombobox.grid(row=0, column=1, columnspan=2, padx=10, pady=10)
456
+
457
+
458
class InputSelectorGUI:
    """
    Input section of the launcher window: choose between a video file and an
    image stream (start/end image taken from one folder).

    State read by the caller: ``selectedOption`` (1 = video, 2 = images),
    ``vidName``, ``imgSelectFolder``, ``startImgName``, ``endImgName`` and
    ``check`` (True only when start and end image form a valid pair).
    """

    def __init__(self, root):
        """:param root: Parent Tk container the widgets are gridded into."""
        self.root = root

        # Widgets that exist only while the matching input mode is shown.
        self.vidElement = {}
        self.imgElement = {}

        self.imgSelectFolder = None

        self.inputType = None
        self.startFrame = None
        self.endFrame = None
        self.video_file = None
        self.check = False
        self.vidName = None
        self.startFolder = None
        self.startImgName = None
        self.endFolder = None
        self.endImgName = None

    def create_gui(self):
        """Create the input-type label and the two radio buttons."""
        self.inputTypeLabel = ttk.Label(self.root, text="Select input from:", width = 15)
        self.inputTypeLabel.grid(row=1, column=0, padx=10, pady=10)

        self.selectedOption = tk.IntVar(value=0)

        self.vidLabel = ttk.Radiobutton(self.root,
                                        text='Video stream',
                                        variable=self.selectedOption,
                                        value=1,
                                        command=self.select_vidstream)
        self.vidLabel.grid(row=1, column=2, padx=10, pady=10)

        self.imgLabel = ttk.Radiobutton(self.root,
                                        text='Image stream',
                                        variable=self.selectedOption,
                                        value=2,
                                        command=self.select_imgstream)
        self.imgLabel.grid(row=1, column=1, padx=10, pady=10)

    # --- helpers ---
    @staticmethod
    def _destroy_all(elements):
        """Destroy every widget in ``elements`` and forget them."""
        for widget in elements.values():
            widget.destroy()
        elements.clear()

    @staticmethod
    def _same_folder(first, second):
        """Compare two folder paths in full (not just their last component)."""
        return (os.path.normcase(os.path.normpath(first))
                == os.path.normcase(os.path.normpath(second)))

    def _ask_image(self):
        """Open the file dialog; return (folder, file name) or None when cancelled."""
        initial = IMG_DEFAULT_FOLDER if self.imgSelectFolder is None else self.imgSelectFolder
        selection = filedialog.askopenfilenames(initialdir=initial)
        if not selection:
            # Cancelling yields an empty result; indexing it would raise.
            return None
        return os.path.split(selection[0])

    def _refresh_check(self):
        """Recompute ``check`` from the current start/end selection."""
        # Start pessimistic so an earlier valid pair cannot leak through.
        self.check = False
        if self.startFolder is None or self.endFolder is None:
            return
        if not self._same_folder(self.startFolder, self.endFolder):
            messagebox.showinfo("Message title", "The image stream must be in the same folder!")
            return
        if self.startImgName is None or self.endImgName is None:
            return
        if check_same_ext_name(self.startImgName, self.endImgName):
            self.check = True
        else:
            messagebox.showinfo("Message title", "Start image has a different extension than end image.")

    def _show_image_name(self, key, row, text):
        """Show ``text`` next to the start/end button, reusing the label if it exists."""
        label = self.imgElement.get(key)
        if label is None:
            label = ttk.Label(self.root, text=text)
            label.grid(row=row, column=2, padx=10, pady=10)
            self.imgElement[key] = label
        else:
            label.config(text=text)

    # --- video mode ---
    def select_vidstream(self):
        """Switch to video mode: drop image-mode state/widgets and show the video widgets."""
        self.imgSelectFolder = None
        self.startImgName = None
        self.endImgName = None
        self.startFolder = None
        self.endFolder = None
        self.check = False
        self._destroy_all(self.imgElement)

        # The radio command also fires when the button is already selected;
        # do not stack a second set of widgets on top of the first.
        if self.vidElement:
            return

        self.vidElement['lblVidIndicate'] = ttk.Label(self.root, text= 'Video\'s path:', width = 15)
        self.vidElement['lblVidIndicate'].grid(row=2, column=0, padx=10, pady=30)
        self.vidElement['lblVidPath'] = ttk.Label(self.root,
                                                  text="Waiting for the selection",
                                                  wraplength=220
                                                  )
        self.vidElement['lblVidPath'].grid(row=2, column=1, columnspan=2, padx=10, pady=10)

        self.vidElement['btn'] = ttk.Button(self.root, text="Select a video", command=self._clicked_vid)
        self.vidElement['btn'].grid(row=3, column=2, padx=10, pady=10)

    def _clicked_vid(self):
        """Ask for a video file and remember the first selected path."""
        selection = filedialog.askopenfilenames(initialdir=VID_DEFAULT_FOLDER)
        if not selection:
            return  # dialog cancelled: keep the previous choice
        self.vidName = selection[0]
        self.vidElement['lblVidPath'].config(text=self.vidName)

    # --- image mode ---
    def select_imgstream(self):
        """Switch to image mode: drop video-mode state/widgets and show the image widgets."""
        self.vidName = None
        self._destroy_all(self.vidElement)

        # See select_vidstream(): avoid duplicating already visible widgets.
        if self.imgElement:
            return

        self.imgElement['lblFolder'] = ttk.Label(self.root, text="Image's folder: ", width = 15)
        self.imgElement['lblFolder'].grid(row=2, column=0, padx=10, pady=10)
        self.imgElement['lblFolderName'] = ttk.Label(self.root, text="Waiting for the selection", wraplength=220)
        self.imgElement['lblFolderName'].grid(row=2, column=1, columnspan=2, padx=10, pady=30)

        self.imgElement['btnStart'] = ttk.Button(self.root, text="Select start frame", command=self._clicked_start_img)
        self.imgElement['btnStart'].grid(row=3, column=1, padx=10, pady=10)
        self.imgElement['btnEnd'] = ttk.Button(self.root, text="Select end frame", command=self._clicked_end_img)
        self.imgElement['btnEnd'].grid(row=4, column=1, padx=10, pady=10)

    def _clicked_start_img(self):
        """Ask for the start image, validate it against the end image, update the labels."""
        picked = self._ask_image()
        if picked is None:
            return
        self.startFolder, self.startImgName = picked
        self._refresh_check()

        self.imgSelectFolder = self.startFolder
        self.imgElement['lblFolderName'].config(text=self.imgSelectFolder)
        self._show_image_name('lblStartImg', 3, self.startImgName)

    def _clicked_end_img(self):
        """Ask for the end image, validate it against the start image, update the labels."""
        picked = self._ask_image()
        if picked is None:
            return
        self.endFolder, self.endImgName = picked
        self._refresh_check()

        self.imgSelectFolder = self.endFolder
        self.imgElement['lblFolderName'].config(text=self.imgSelectFolder)
        self._show_image_name('lblEndImg', 4, self.endImgName)
578
+
579
+
580
class ModelAndInputSelectorGUI:
    """
    Launcher window combining the model selector and the input selector with
    "Run" and "Stepping" buttons.
    """

    def __init__(self, root):
        """
        Configure the window (size, centred position, title, icon) and build
        the two selector helpers and the action buttons.

        :param root: The Tk root window.
        """
        self.root = root

        windowHeight = 350
        windowWidth = 400

        # Centre the window on the screen.
        startHeight = (root.winfo_screenheight() - windowHeight) // 2
        startWidth = (root.winfo_screenwidth() - windowWidth) // 2

        self.root.geometry('{}x{}+{}+{}'.format(windowWidth, windowHeight, startWidth, startHeight))
        self.root.title("Small target motion detector - Runner")
        try:
            self.root.iconbitmap(os.path.join(os.path.dirname(filePath), 'stmd.ico'))
        except tk.TclError:
            # The icon is purely cosmetic: a missing file or a platform
            # without .ico support must not prevent the launcher from opening.
            logger.warning("Window icon 'stmd.ico' could not be loaded; continuing without it.")

        self.objModelSelector = ModelSelectorGUI(root)
        self.objInputSelector = InputSelectorGUI(root)

        self.btnRun = ttk.Button(self.root, text="Run", command=self._run)
        self.btnRun.place(x = 20, y=300)
        self.btnStepping = ttk.Button(self.root, text="Stepping", command=self._stepping)
        self.btnStepping.place(x = 20, y=270)
        self.isStepping = False

        # Results; defined up front so they exist even when the window is
        # closed without a successful "Run".
        self.modelName = None
        self.vidName = None
        self.startImgName = None
        self.endImgName = None
        # Input option (1 = video, 2 = images) captured when the selection is
        # confirmed, so no Tk variable has to be read after destroy().
        self._confirmedInput = None

    def create_gui(self):
        """
        Build the widgets, run the Tk main loop and return the user's choice.

        :return: ``(modelName, vidName, None, isStepping)`` for a video,
            ``(modelName, startImgPath, endImgPath, isStepping)`` for an image
            stream, or None when the window was closed without confirming.
        """
        self.objModelSelector.create_gui(ALL_MODEL)
        self.objInputSelector.create_gui()

        self.root.mainloop()

        if self._confirmedInput == 1:
            return self.modelName, self.vidName, None, self.isStepping
        if self._confirmedInput == 2:
            return self.modelName, self.startImgName, self.endImgName, self.isStepping
        return None

    def _confirm(self, option):
        """Record the confirmed input option and close the window."""
        self._confirmedInput = option
        self.root.destroy()

    def _run(self):
        """Validate the current selection; on success store it and close the window."""
        self.modelName = self.objModelSelector.modelCombobox.get()
        if self.modelName not in ALL_MODEL:
            messagebox.showinfo("Message title", "Please select a STMD-based model!")
            return

        selector = self.objInputSelector
        option = selector.selectedOption.get()

        if option == 1:
            if selector.vidName is not None:
                self.vidName = selector.vidName
                self._confirm(option)
            else:
                messagebox.showinfo("Message title", "Please select a video")
        elif option == 2:
            if selector.startImgName is None:
                messagebox.showinfo("Message title", "Please select start frame!")
                return

            if selector.endImgName is None:
                messagebox.showinfo("Message title", "Please select end frame!")
                return

            if selector.check:
                self.startImgName = os.path.join(selector.imgSelectFolder,
                                                 selector.startImgName)
                self.endImgName = os.path.join(selector.imgSelectFolder,
                                               selector.endImgName)
                self._confirm(option)
            else:
                messagebox.showinfo("Message title", "The image stream must be in the same folder!")
        else:
            messagebox.showinfo("Message title", "Please select input")

    def _stepping(self):
        """Like "Run", but flag the session as frame-by-frame stepping."""
        self.isStepping = True
        self._run()
        if self._confirmedInput is None:
            # Validation failed and the window is still open: do not let the
            # stepping flag leak into a later plain "Run".
            self.isStepping = False
649
+
650
+
651
def check_same_ext_name(startImgName, endImgName):
    """Return True when both file names end in exactly the same extension.

    The comparison is case-sensitive and uses only the final suffix as
    reported by ``os.path.splitext`` (two names without any extension match).
    """
    start_ext = os.path.splitext(startImgName)[1]
    end_ext = os.path.splitext(endImgName)[1]
    return start_ext == end_ext
659
+
660
+