xttmp 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. xttmp/__init__.py +1 -0
  2. xttmp/api/__init__.py +5 -0
  3. xttmp/api/evaluate.py +163 -0
  4. xttmp/api/get_visualize_handle.py +29 -0
  5. xttmp/api/instancing_model.py +35 -0
  6. xttmp/core/__init__.py +0 -0
  7. xttmp/core/apgstmd_core.py +188 -0
  8. xttmp/core/apgstmdv2_core.py +79 -0
  9. xttmp/core/base_core.py +36 -0
  10. xttmp/core/dstmd_core.py +213 -0
  11. xttmp/core/estmd_backbone.py +110 -0
  12. xttmp/core/estmd_core.py +356 -0
  13. xttmp/core/feedbackstmd_core.py +61 -0
  14. xttmp/core/fracstmd_core.py +98 -0
  15. xttmp/core/fstmd_core.py +15 -0
  16. xttmp/core/fstmdv2_core.py +42 -0
  17. xttmp/core/haarstmd_core.py +140 -0
  18. xttmp/core/math_operator.py +307 -0
  19. xttmp/core/stfeedbackstmd_core.py +233 -0
  20. xttmp/core/stmdplus_core.py +187 -0
  21. xttmp/core/stmdplusv2_core.py +82 -0
  22. xttmp/core/vstmd_core.py +420 -0
  23. xttmp/demo/evaluate_model.py +92 -0
  24. xttmp/demo/inference_gui.py +148 -0
  25. xttmp/demo/inference_gui_single_process.py +134 -0
  26. xttmp/demo/inference_image_stream.py +67 -0
  27. xttmp/demo/inference_video.py +66 -0
  28. xttmp/main.py +14 -0
  29. xttmp/model/__init__.py +13 -0
  30. xttmp/model/backbone.py +514 -0
  31. xttmp/model/facilitated_model.py +230 -0
  32. xttmp/model/feedback_model.py +271 -0
  33. xttmp/model/haarstmd.py +61 -0
  34. xttmp/model/vstmd.py +457 -0
  35. xttmp/util/__init__.py +0 -0
  36. xttmp/util/compute_module.py +402 -0
  37. xttmp/util/create_kernel.py +363 -0
  38. xttmp/util/evaluate_module.py +697 -0
  39. xttmp/util/iostream.py +660 -0
  40. xttmp-2.3.0.dist-info/METADATA +85 -0
  41. xttmp-2.3.0.dist-info/RECORD +45 -0
  42. xttmp-2.3.0.dist-info/WHEEL +5 -0
  43. xttmp-2.3.0.dist-info/entry_points.txt +2 -0
  44. xttmp-2.3.0.dist-info/licenses/LICENSE +201 -0
  45. xttmp-2.3.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,134 @@
1
+ import os
2
+ import sys
3
+ import time
4
+ from typing import Optional
5
+
6
+ import tkinter as tk
7
+ import torch
8
+
9
# Make the directory that contains the ``xttmp`` package importable: it sits
# three levels above this file (xttmp/demo/<script>.py).
file_path = os.path.realpath(__file__)
project_path = os.path.normpath(os.path.join(file_path, os.pardir, os.pardir, os.pardir))
repo_root = os.path.split(project_path)[0]
sys.path.append(project_path)
13
+
14
+ from xttmp.util.iostream import ModelAndInputSelectorGUI, FrameIterator, FrameVisualizer
15
+ from xttmp.api import instancing_model # type: ignore
16
+ from xttmp.util.compute_module import PostProcessing # type: ignore
17
+
18
+
19
# Use the GPU whenever PyTorch reports one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
20
+
21
+
22
class StmdGuiSingleProcess:
    """Single-process STMD inference demo driven by a Tk selection dialog.

    The dialog yields a model name and an input source: either a video file, or
    the two endpoint images of a range inside one folder. Frames are then run
    through the model, post-processed, and drawn by ``FrameVisualizer`` in the
    same process.

    The collaborating classes/functions are stored as instance attributes so a
    caller can substitute them without patching module globals.
    """

    def __init__(self, device: str = DEVICE, show_threshold: float = 0.2, get_top_num: int = 20):
        """Store run-time settings and the collaborators used by ``run``.

        Args:
            device: torch device string used for the model, the frame reader and
                post-processing (defaults to the module-level ``DEVICE``).
            show_threshold: forwarded to ``FrameVisualizer`` as ``conf_threshold``.
            get_top_num: forwarded to ``PostProcessing`` as ``get_top_num``.
        """
        self.device = device
        self.show_threshold = show_threshold
        self.get_top_num = get_top_num
        self.ModelAndInputSelectorGUI = ModelAndInputSelectorGUI
        self.FrameIterator = FrameIterator
        self.FrameVisualizer = FrameVisualizer
        self.PostProcessing = PostProcessing
        self.instancing_model = instancing_model

    def _get_user_input(self):
        """Show the model/input selector dialog and return its result.

        Returns:
            Whatever ``ModelAndInputSelectorGUI.create_gui()`` returns. ``run``
            treats a falsy value as "cancelled" and otherwise unpacks it as
            ``(model_name, opt1, opt2, is_stepping)``.
        """
        root = tk.Tk()
        try:
            gui = self.ModelAndInputSelectorGUI(root)
            return gui.create_gui()
        finally:
            # Best-effort teardown: destroy() raises TclError when the window is
            # already gone (presumably closed by the dialog itself).
            try:
                root.destroy()
            except tk.TclError:
                pass

    def _create_frame_reader(self, opt1: str, opt2: Optional[str]):
        """Build a ``FrameIterator`` for the user's selection.

        Args:
            opt1: path to a video file when ``opt2`` is None; otherwise the path
                of one endpoint image of the range (its folder is the source).
            opt2: path of the other endpoint image, or None for video input.

        Returns:
            The configured reader. For an image range the endpoints may be given
            in either order; the reader is limited to the span between them.

        Raises:
            ValueError: if either endpoint's file name is not among
                ``reader.image_files``. The partially built reader is released
                first, because the caller never receives it.
        """
        if opt2 is None:
            return self.FrameIterator(opt1, is_video=True, device=self.device)

        reader = self.FrameIterator(os.path.dirname(opt1), is_video=False, device=self.device)
        # Endpoints are matched by file name only (first match wins).
        names = [os.path.basename(path) for path in reader.image_files]
        try:
            start_index = names.index(os.path.basename(opt1))
            end_index = names.index(os.path.basename(opt2))
        except ValueError:
            # Without this the reader would leak: run() only releases readers
            # that were successfully returned to it.
            reader.release()
            raise ValueError("Selected image range could not be located in the folder.") from None

        if start_index > end_index:
            start_index, end_index = end_index, start_index

        # NOTE(review): relies on FrameIterator internals. _setup() presumably
        # repositions the reader at start_index, and total_frames presumably acts
        # as an exclusive stop index (so end_index is included). Confirm.
        reader._setup(start_index)
        reader.total_frames = end_index + 1
        return reader

    def run(self):
        """Ask for inputs, then run inference until frames run out or the viewer stops.

        Prints the accumulated model time and the FPS over the frames that were
        actually run through the model. The reader and visualizer are always
        released, even when setup or inference fails part-way.
        """
        reader = None
        visualizer = None
        try:
            user_input = self._get_user_input()
            if not user_input:
                return  # dialog was cancelled

            model_name, opt1, opt2, is_stepping = user_input
            reader = self._create_frame_reader(opt1, opt2)
            model = self.instancing_model(model_name, device=self.device)
            post_processor = self.PostProcessing(
                device=self.device,
                nms_radio=8,
                get_top_num=self.get_top_num,
            )

            visualizer = self.FrameVisualizer(
                window_name=model.__class__.__name__,
                result_index_type='dots',
                win_height=reader.img_height,
                win_width=reader.img_width,
                conf_threshold=self.show_threshold,
            )
            if is_stepping:
                visualizer.paused = True

            # Cover 'cuda:N' device strings too, not only the bare 'cuda'.
            use_cuda = str(self.device).startswith('cuda')
            total_time = 0.0
            frame_count = 0

            while True:
                color_img, gray_tensor, is_valid = reader.get_next_frame()
                if not is_valid:
                    break

                # CUDA work is asynchronous; synchronize on both sides so the
                # timer measures the whole forward pass.
                if use_cuda:
                    torch.cuda.synchronize(self.device)
                start_time = time.perf_counter()
                result = model(gray_tensor)
                if use_cuda:
                    torch.cuda.synchronize(self.device)
                run_time = time.perf_counter() - start_time

                # Count frames locally rather than via reader.current_index, which
                # may be an absolute index when reading from a range start. Each
                # timed inference is paired with exactly one counted frame.
                total_time += run_time
                frame_count += 1

                dot_res = post_processor(result['response'], result.get('direction'))
                # update() returning False ends the run (presumably window closed / quit).
                if not visualizer.update(color_img, dot_res, process_time=run_time):
                    break

            if total_time > 0:
                print(f"Total time: {total_time:.4f} seconds, "
                      f"FPS: {frame_count / total_time :.4f} frames/second")

        finally:
            # Release the reader even if closing the visualizer raises.
            try:
                if visualizer is not None:
                    visualizer.close()
            finally:
                if reader is not None:
                    reader.release()
126
+
127
+
128
def main(show_threshold: float = 0.2, get_top_num: int = 20):
    """Launch the single-process STMD GUI demo on ``DEVICE``.

    Args:
        show_threshold: forwarded to ``StmdGuiSingleProcess(show_threshold=...)``.
        get_top_num: forwarded to ``StmdGuiSingleProcess(get_top_num=...)``.
    """
    StmdGuiSingleProcess(
        device=DEVICE,
        show_threshold=show_threshold,
        get_top_num=get_top_num,
    ).run()


if __name__ == '__main__':
    main()
@@ -0,0 +1,67 @@
1
+ # demo_imgstream
2
+ import os
3
+ import sys
4
+ import time
5
+
6
+ import torch
7
+
8
# Make the directory that contains the ``xttmp`` package importable: it sits
# three levels above this file (xttmp/demo/<script>.py). Its parent holds the
# example data used below.
filePath = os.path.realpath(__file__)
project_path = os.path.normpath(os.path.join(filePath, os.pardir, os.pardir, os.pardir))
gitCodePath = os.path.split(project_path)[0]
sys.path.append(project_path)
12
+
13
+ from xttmp.util.iostream import FrameIterator, FrameVisualizer
14
+ from xttmp.api import instancing_model # type: ignore
15
+ from xttmp.util.compute_module import PostProcessing # type: ignore
16
+
17
+
18
# Run configuration for this demo.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
MODEL_NAME = 'vSTMD_F'
# Folder of example frames under gitCodePath (the parent of project_path).
INPUT_PATH = os.path.join(os.path.join(gitCodePath, 'example-data'), 'imgstream')
SHOW_THRESHOLD = 0.2  # forwarded to FrameVisualizer(conf_threshold=...)
GET_TOP_NUM = 20      # forwarded to PostProcessing(get_top_num=...)
23
+
24
+
25
def main():
    """Run ``MODEL_NAME`` over the image sequence in ``INPUT_PATH`` and display detections.

    Prints the accumulated model time and the FPS over the frames that were run
    through the model. The reader and visualizer are always released, even if
    setup or inference fails part-way.
    """
    model = instancing_model(MODEL_NAME, device=DEVICE)
    # Cover 'cuda:N' device strings too, not only the bare 'cuda'.
    use_cuda = str(DEVICE).startswith('cuda')

    frame_reader = None
    visualizer = None
    try:
        # Constructed inside the try so a failure in a later constructor still
        # releases the reader.
        frame_reader = FrameIterator(INPUT_PATH, is_video=False, device=DEVICE)
        visualizer = FrameVisualizer(
            window_name=model.__class__.__name__,
            result_index_type='dots',
            win_height=frame_reader.img_height,
            win_width=frame_reader.img_width,
            conf_threshold=SHOW_THRESHOLD,
        )
        post_processor = PostProcessing(device=DEVICE, nms_radio=8, get_top_num=GET_TOP_NUM)

        total_time = 0.0
        frame_count = 0
        for color_img, gray_tensor in frame_reader:
            # CUDA work is asynchronous; synchronize on both sides so the timer
            # measures the whole forward pass.
            if use_cuda:
                torch.cuda.synchronize(DEVICE)
            time_start = time.perf_counter()
            result = model(gray_tensor)
            if use_cuda:
                torch.cuda.synchronize(DEVICE)
            run_time = time.perf_counter() - time_start

            # Pair each timed inference with exactly one counted frame, so FPS
            # stays consistent when the loop is ended from the visualizer.
            total_time += run_time
            frame_count += 1

            dot_res = post_processor(result['response'], result.get('direction'))
            # update() returning False ends the run (presumably window closed / quit).
            if not visualizer.update(color_img, dot_res, process_time=run_time):
                break

        if total_time > 0:
            print(f"Total time: {total_time:.4f} seconds, "
                  f"FPS: {frame_count / total_time :.4f} frames/second")
    finally:
        # Release the reader even if closing the visualizer raises.
        try:
            if visualizer is not None:
                visualizer.close()
        finally:
            if frame_reader is not None:
                frame_reader.release()


if __name__ == "__main__":
    main()
@@ -0,0 +1,66 @@
1
+ # demo_vidstream
2
+ import os
3
+ import sys
4
+ import time
5
+
6
+ import torch
7
+
8
# Make the directory that contains the ``xttmp`` package importable: it sits
# three levels above this file (xttmp/demo/<script>.py).
filePath = os.path.realpath(__file__)
project_path = os.path.normpath(os.path.join(filePath, os.pardir, os.pardir, os.pardir))
sys.path.append(project_path)
11
+
12
+ from xttmp.util.iostream import FrameIterator, FrameVisualizer
13
+ from xttmp.api import instancing_model # type: ignore
14
+ from xttmp.util.compute_module import PostProcessing # type: ignore
15
+
16
+
17
# Run configuration for this demo.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
MODEL_NAME = 'vSTMD_F'
# Example video under the parent directory of project_path.
INPUT_PATH = os.path.join(
    os.path.split(project_path)[0],
    'example-data',
    'RIST_GX010290_orignal_240Hz.mp4',
)
SHOW_THRESHOLD = 0.2  # forwarded to FrameVisualizer(conf_threshold=...)
GET_TOP_NUM = 20      # forwarded to PostProcessing(get_top_num=...)
22
+
23
+
24
def main():
    """Run ``MODEL_NAME`` over the video at ``INPUT_PATH`` and display detections.

    Prints the accumulated model time and the FPS over the frames that were run
    through the model. The reader and visualizer are always released, even if
    setup or inference fails part-way.
    """
    model = instancing_model(MODEL_NAME, device=DEVICE)
    # Cover 'cuda:N' device strings too, not only the bare 'cuda'.
    use_cuda = str(DEVICE).startswith('cuda')

    frame_reader = None
    visualizer = None
    try:
        # Constructed inside the try so a failure in a later constructor still
        # releases the reader.
        frame_reader = FrameIterator(INPUT_PATH, is_video=True, device=DEVICE)
        visualizer = FrameVisualizer(
            window_name=model.__class__.__name__,
            result_index_type='dots',
            win_height=frame_reader.img_height,
            win_width=frame_reader.img_width,
            conf_threshold=SHOW_THRESHOLD,
        )
        post_processor = PostProcessing(device=DEVICE, nms_radio=8, get_top_num=GET_TOP_NUM)

        total_time = 0.0
        frame_count = 0
        for color_img, gray_tensor in frame_reader:
            # CUDA work is asynchronous; synchronize on both sides so the timer
            # measures the whole forward pass.
            if use_cuda:
                torch.cuda.synchronize(DEVICE)
            time_start = time.perf_counter()
            result = model(gray_tensor)
            if use_cuda:
                torch.cuda.synchronize(DEVICE)
            run_time = time.perf_counter() - time_start

            # Pair each timed inference with exactly one counted frame, so FPS
            # stays consistent when the loop is ended from the visualizer.
            total_time += run_time
            frame_count += 1

            dot_res = post_processor(result['response'], result.get('direction'))
            # update() returning False ends the run (presumably window closed / quit).
            if not visualizer.update(color_img, dot_res, process_time=run_time):
                break

        if total_time > 0:
            print(f"Total time: {total_time:.4f} seconds, "
                  f"FPS: {frame_count / total_time :.4f} frames/second")
    finally:
        # Release the reader even if closing the visualizer raises.
        try:
            if visualizer is not None:
                visualizer.close()
        finally:
            if frame_reader is not None:
                frame_reader.release()


if __name__ == "__main__":
    main()
xttmp/main.py ADDED
@@ -0,0 +1,14 @@
1
+ from pathlib import Path
2
+ import subprocess
3
+ import sys
4
+
5
+
6
def main():
    """Launch the GUI demo (``demo/inference_gui.py``) in a child Python interpreter.

    The child uses the same interpreter as the caller (``sys.executable``), and a
    non-zero exit status is returned rather than raised.

    Returns:
        The child process's exit status.
    """
    package_dir = Path(__file__).resolve().parent
    gui_script = package_dir.joinpath('demo', 'inference_gui.py')
    completed = subprocess.run([sys.executable, str(gui_script)], check=False)
    return completed.returncode


if __name__ == '__main__':
    sys.exit(main())
14
+
@@ -0,0 +1,13 @@
1
+ from .backbone import ESTMD, ESTMDBackbone, FracSTMD, DSTMD, DSTMDBackbone
2
+ from .feedback_model import FeedbackSTMD, FSTMD, FracSTMD_F
3
+ from .facilitated_model import STMDPlus, ApgSTMD
4
+ from .haarstmd import HaarSTMD
5
+ from .vstmd import vSTMD, vSTMD_F
6
+
7
+
8
# Public model classes re-exported by ``xttmp.model``.
__all__ = [
    # backbone with four basis layers
    'ESTMD',
    'ESTMDBackbone',
    'FracSTMD',
    'DSTMD',
    'DSTMDBackbone',
    # model with feedback pathway
    'FeedbackSTMD',
    'FSTMD',
    'FracSTMD_F',
    # facilitated model
    'STMDPlus',
    'ApgSTMD',
    # Haar-based model
    'HaarSTMD',
    # vSTMD family
    'vSTMD',
    'vSTMD_F',
]