simba-uw-tf-dev 4.6.1__py3-none-any.whl → 4.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. simba/SimBA.py +2 -2
  2. simba/assets/icons/frames_2.png +0 -0
  3. simba/data_processors/agg_clf_counter_mp.py +52 -53
  4. simba/data_processors/cuda/image.py +3 -1
  5. simba/data_processors/cue_light_analyzer.py +5 -9
  6. simba/mixins/geometry_mixin.py +14 -28
  7. simba/mixins/image_mixin.py +10 -14
  8. simba/mixins/train_model_mixin.py +2 -2
  9. simba/plotting/ROI_feature_visualizer_mp.py +3 -5
  10. simba/plotting/clf_validator_mp.py +4 -5
  11. simba/plotting/cue_light_visualizer.py +6 -7
  12. simba/plotting/directing_animals_visualizer_mp.py +2 -3
  13. simba/plotting/distance_plotter_mp.py +378 -378
  14. simba/plotting/frame_mergerer_ffmpeg.py +137 -137
  15. simba/plotting/gantt_creator_mp.py +59 -31
  16. simba/plotting/geometry_plotter.py +270 -272
  17. simba/plotting/heat_mapper_clf_mp.py +2 -4
  18. simba/plotting/heat_mapper_location_mp.py +2 -2
  19. simba/plotting/light_dark_box_plotter.py +2 -2
  20. simba/plotting/path_plotter_mp.py +26 -29
  21. simba/plotting/plot_clf_results_mp.py +455 -454
  22. simba/plotting/pose_plotter_mp.py +27 -32
  23. simba/plotting/probability_plot_creator_mp.py +288 -288
  24. simba/plotting/roi_plotter_mp.py +29 -30
  25. simba/plotting/single_run_model_validation_video_mp.py +427 -427
  26. simba/plotting/spontaneous_alternation_plotter.py +2 -3
  27. simba/plotting/yolo_pose_track_visualizer.py +31 -27
  28. simba/plotting/yolo_pose_visualizer.py +32 -34
  29. simba/plotting/yolo_seg_visualizer.py +2 -3
  30. simba/roi_tools/roi_aggregate_stats_mp.py +4 -3
  31. simba/roi_tools/roi_clf_calculator_mp.py +3 -3
  32. simba/sandbox/cuda/egocentric_rotator.py +374 -0
  33. simba/sandbox/get_cpu_pool.py +5 -0
  34. simba/ui/pop_ups/clf_add_remove_print_pop_up.py +3 -1
  35. simba/ui/pop_ups/egocentric_alignment_pop_up.py +6 -3
  36. simba/ui/pop_ups/multiple_videos_to_frames_popup.py +10 -11
  37. simba/ui/pop_ups/single_video_to_frames_popup.py +10 -10
  38. simba/ui/pop_ups/video_processing_pop_up.py +63 -63
  39. simba/ui/tkinter_functions.py +7 -1
  40. simba/utils/data.py +89 -12
  41. simba/utils/enums.py +1 -0
  42. simba/utils/printing.py +9 -8
  43. simba/utils/read_write.py +3726 -3721
  44. simba/video_processors/clahe_ui.py +65 -22
  45. simba/video_processors/egocentric_video_rotator.py +6 -9
  46. simba/video_processors/video_processing.py +21 -10
  47. simba/video_processors/videos_to_frames.py +3 -2
  48. {simba_uw_tf_dev-4.6.1.dist-info → simba_uw_tf_dev-4.6.3.dist-info}/METADATA +1 -1
  49. {simba_uw_tf_dev-4.6.1.dist-info → simba_uw_tf_dev-4.6.3.dist-info}/RECORD +53 -50
  50. {simba_uw_tf_dev-4.6.1.dist-info → simba_uw_tf_dev-4.6.3.dist-info}/LICENSE +0 -0
  51. {simba_uw_tf_dev-4.6.1.dist-info → simba_uw_tf_dev-4.6.3.dist-info}/WHEEL +0 -0
  52. {simba_uw_tf_dev-4.6.1.dist-info → simba_uw_tf_dev-4.6.3.dist-info}/entry_points.txt +0 -0
  53. {simba_uw_tf_dev-4.6.1.dist-info → simba_uw_tf_dev-4.6.3.dist-info}/top_level.txt +0 -0
@@ -1,288 +1,288 @@
1
- __author__ = "Simon Nilsson; sronilsson@gmail.com"
2
-
3
- import functools
4
- import multiprocessing
5
- import os
6
- import platform
7
- import shutil
8
- from copy import deepcopy
9
- from typing import List, Optional, Tuple, Union
10
-
11
- import cv2
12
- import numpy as np
13
-
14
- from simba.mixins.config_reader import ConfigReader
15
- from simba.mixins.plotting_mixin import PlottingMixin
16
- from simba.utils.checks import (
17
- check_all_file_names_are_represented_in_video_log,
18
- check_file_exist_and_readable, check_float, check_instance, check_int,
19
- check_str, check_that_column_exist, check_valid_boolean, check_valid_tuple)
20
- from simba.utils.enums import Formats
21
- from simba.utils.errors import NoSpecifiedOutputError
22
- from simba.utils.lookups import get_color_dict
23
- from simba.utils.printing import SimbaTimer, stdout_success
24
- from simba.utils.read_write import (concatenate_videos_in_folder,
25
- find_core_cnt, get_fn_ext, read_df)
26
-
27
- STYLE_WIDTH = 'width'
28
- STYLE_HEIGHT = 'height'
29
- STYLE_FONT_SIZE = 'font size'
30
- STYLE_LINE_WIDTH = 'line width'
31
- STYLE_YMAX = 'y_max'
32
- STYLE_COLOR = 'color'
33
- AUTO = 'AUTO'
34
- STYLE_OPACITY = 'opacity'
35
-
36
- VALID_COLORS = list(get_color_dict().keys())
37
- FOURCC = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value)
38
-
39
- STYLE_ATTR = [STYLE_WIDTH, STYLE_HEIGHT, STYLE_FONT_SIZE, STYLE_LINE_WIDTH, STYLE_COLOR, STYLE_YMAX, STYLE_OPACITY]
40
-
41
- def _probability_plot_mp(frm_range: Tuple[int, np.ndarray],
42
- clf_data: np.ndarray,
43
- clf_name: str,
44
- video_setting: bool,
45
- frame_setting: bool,
46
- video_dir: str,
47
- frame_dir: str,
48
- fps: int,
49
- video_name: str,
50
- y_max: Union[int, float],
51
- size: tuple,
52
- line_width: int,
53
- font_size: int,
54
- opacity: float,
55
- color: str,
56
- show_thresholds: bool):
57
-
58
-
59
-
60
- group, data = frm_range[0], frm_range[1]
61
- start_frm, end_frm, current_frm = data[0], data[-1], data[0]
62
-
63
- if video_setting:
64
- fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value)
65
- video_save_path = os.path.join(video_dir, f"{group}.mp4")
66
- video_writer = cv2.VideoWriter(video_save_path, fourcc, fps, size)
67
-
68
- while current_frm < end_frm:
69
- current_lst = [np.array(clf_data[0 : current_frm + 1])]
70
- current_frm += 1
71
- img = PlottingMixin.make_line_plot(data=current_lst,
72
- colors=[color],
73
- width=size[0],
74
- height=size[1],
75
- line_width=line_width,
76
- font_size=font_size,
77
- line_opacity=opacity,
78
- y_lbl=f"{clf_name} probability",
79
- title=f'{video_name} - {clf_name}',
80
- y_max=y_max,
81
- x_lbl='frame count',
82
- show_thresholds=show_thresholds)
83
-
84
- if video_setting:
85
- video_writer.write(img[:, :, :3])
86
- if frame_setting:
87
- frame_save_name = os.path.join(frame_dir, f"{current_frm}.png")
88
- cv2.imwrite(frame_save_name, img)
89
- current_frm += 1
90
- print(f"Probability frame created: {current_frm + 1}, Video: {video_name}, Processing core: {group}")
91
- return group
92
-
93
-
94
- class TresholdPlotCreatorMultiprocess(ConfigReader, PlottingMixin):
95
- """
96
- Class for line chart visualizations displaying the classification probabilities of a single classifier.
97
- Uses multiprocessing.
98
-
99
- :param str config_path: path to SimBA project config file in Configparser format
100
- :param str clf_name: Name of the classifier to create visualizations for
101
- :param bool frame_setting: When True, SimBA creates indidvidual frames in png format
102
- :param bool video_setting: When True, SimBA creates compressed video in mp4 format
103
- :param bool last_image: When True, creates image .png representing last frame of the video.
104
- :param dict style_attr: User-defined style attributes of the visualization (line size, color etc).
105
- :param List[str] files_found: Files to create threshold plots for.
106
- :param int cores: Number of cores to use.
107
-
108
- .. note::
109
- `Visualization tutorials <https://github.com/sgoldenlab/simba/blob/master/docs/tutorial.md#step-11-visualizations>`__.
110
-
111
- .. image:: _static/img/prob_plot.png
112
- :width: 300
113
- :align: center
114
-
115
- :example:
116
- >>> plot_creator = TresholdPlotCreatorMultiprocess(config_path='/Users/simon/Desktop/troubleshooting/train_model_project/project_folder/project_config.ini', frame_setting=True, video_setting=True, clf_name='Attack', style_attr={'width': 640, 'height': 480, 'font size': 10, 'line width': 6, 'color': 'magneta', 'circle size': 20}, cores=5)
117
- >>> plot_creator.run()
118
- """
119
-
120
- def __init__(self,
121
- config_path: Union[str, os.PathLike],
122
- data_path: Union[List[Union[str, os.PathLike]], str, os.PathLike],
123
- clf_name: str,
124
- frame_setting: Optional[bool] = False,
125
- video_setting: Optional[bool] = False,
126
- last_frame: Optional[bool] = True,
127
- size: Tuple[int, int] = (640, 480),
128
- font_size: int = 10,
129
- line_width: int = 2,
130
- y_max: Optional[int] = None,
131
- line_color: str = 'Red',
132
- line_opacity: float = 0.8,
133
- cores: Optional[int] = -1,
134
- show_thresholds: bool = True):
135
-
136
- if platform.system() == "Darwin":
137
- multiprocessing.set_start_method("spawn", force=True)
138
- if (not video_setting) and (not frame_setting) and (not last_frame):
139
- raise NoSpecifiedOutputError(msg="SIMBA ERROR: Please choose to create video and/or frames data plots. SimBA found that you ticked neither video and/or frames")
140
- check_int(name=f"{self.__class__.__name__} core_cnt", value=cores, min_value=-1, max_value=find_core_cnt()[0], unaccepted_vals=[0])
141
- if cores == -1: cores = find_core_cnt()[0]
142
- check_valid_tuple(x=size, source=f'{self.__class__.__name__} size', accepted_lengths=(2,), valid_dtypes=Formats.INTEGER_DTYPES.value, min_integer=100)
143
- check_int(name=f'{self.__class__.__name__} font_size', value=font_size, min_value=1, raise_error=True)
144
- check_int(name=f'{self.__class__.__name__} line_width', value=line_width, min_value=1, raise_error=True)
145
- check_valid_boolean(value=show_thresholds, source=f'{self.__class__.__name__} show_thresholds')
146
- if y_max is not None:
147
- check_float(name=f'{self.__class__.__name__} y_max', value=y_max, min_value=0.00001, raise_error=True)
148
- check_str(name=f'{self.__class__.__name__} color', value=line_color, options=VALID_COLORS)
149
- check_float(name=f'{self.__class__.__name__} line_opacity', value=line_opacity, min_value=0.001, max_value=1.0, raise_error=True)
150
- ConfigReader.__init__(self, config_path=config_path)
151
- PlottingMixin.__init__(self)
152
- check_str(name=f"{self.__class__.__name__} clf_name", value=clf_name, options=(self.clf_names))
153
- self.frame_setting, self.video_setting, self.last_image = frame_setting, video_setting, last_frame
154
- self.line_opacity, self.line_clr, self.line_width = line_opacity, line_color, line_width
155
- self.font_size, self.img_size, self.y_max = font_size, size, y_max
156
- check_instance(source=f'{self.__class__.__name__} data_path' , instance=data_path, accepted_types=(str, list,), raise_error=True)
157
- if isinstance(data_path, str):
158
- data_path = [data_path]
159
- for path in data_path:
160
- check_file_exist_and_readable(file_path=path, raise_error=True)
161
- check_str(name=f"{self.__class__.__name__} clf_name", value=clf_name, options=(self.clf_names))
162
- self.show_thresholds = show_thresholds
163
- self.frame_setting, self.video_setting, self.cores, self.last_frame = (frame_setting, video_setting, cores, last_frame)
164
- self.clf_name, self.data_paths = clf_name, data_path
165
- self.probability_col, self.img_size = f"Probability_{self.clf_name}", size
166
- if not os.path.exists(self.probability_plot_dir): os.makedirs(self.probability_plot_dir)
167
- print(f"Processing {len(self.data_paths)} video(s)...")
168
-
169
- def run(self):
170
- check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.data_paths)
171
- for file_cnt, file_path in enumerate(self.data_paths):
172
- video_timer = SimbaTimer(start=True)
173
- _, self.video_name, _ = get_fn_ext(file_path)
174
- video_info, self.px_per_mm, self.fps = self.read_video_info(video_name=self.video_name)
175
- data_df = read_df(file_path, self.file_type)
176
- check_that_column_exist(df=data_df, column_name=[self.clf_name, self.probability_col], file_name=file_path)
177
- self.save_frame_folder_dir = os.path.join(self.probability_plot_dir, self.video_name + f"_{self.clf_name}")
178
- self.video_folder = os.path.join(self.probability_plot_dir, self.video_name + f"_{self.clf_name}")
179
- self.temp_folder = os.path.join(self.probability_plot_dir, f"{self.video_name}_{self.clf_name}", "temp")
180
- if self.frame_setting:
181
- if os.path.exists(self.save_frame_folder_dir):
182
- shutil.rmtree(self.save_frame_folder_dir)
183
- os.makedirs(self.save_frame_folder_dir)
184
- if self.video_setting:
185
- if os.path.exists(self.temp_folder):
186
- shutil.rmtree(self.temp_folder)
187
- shutil.rmtree(self.video_folder)
188
- os.makedirs(self.temp_folder)
189
- self.save_video_path = os.path.join(self.probability_plot_dir, f"{self.video_name}_{self.clf_name}.mp4")
190
-
191
- clf_data = data_df[self.probability_col].values
192
- y_max = deepcopy(self.y_max) if self.y_max is not None else float(np.max(clf_data))
193
-
194
- if self.last_frame:
195
- final_frm_save_path = os.path.join(self.probability_plot_dir, f'{self.video_name}_{self.clf_name}_final_frm_{self.datetime}.png')
196
- _ = PlottingMixin.make_line_plot(data=[clf_data],
197
- colors=[self.line_clr],
198
- width=self.img_size[0],
199
- height=self.img_size[1],
200
- line_width=self.line_width,
201
- font_size=self.font_size,
202
- y_lbl=f"{self.clf_name} probability",
203
- y_max=y_max,
204
- x_lbl='frame count',
205
- title=f'{self.video_name} - {self.clf_name}',
206
- save_path=final_frm_save_path,
207
- line_opacity=self.line_opacity,
208
- show_thresholds=self.show_thresholds)
209
-
210
- if self.video_setting or self.frame_setting:
211
- frm_nums = np.arange(0, len(data_df)+1)
212
- data_split = np.array_split(frm_nums, self.cores)
213
- frm_range = []
214
- for cnt, i in enumerate(data_split): frm_range.append((cnt, i))
215
- print(f"Creating probability images, multiprocessing (chunksize: {self.multiprocess_chunksize}, cores: {self.cores})...")
216
- with multiprocessing.Pool(self.cores, maxtasksperchild=self.maxtasksperchild) as pool:
217
- constants = functools.partial(_probability_plot_mp,
218
- clf_name=self.clf_name,
219
- clf_data=clf_data,
220
- video_setting=self.video_setting,
221
- frame_setting=self.frame_setting,
222
- fps=self.fps,
223
- video_dir=self.temp_folder,
224
- frame_dir=self.save_frame_folder_dir,
225
- video_name=self.video_name,
226
- y_max=y_max,
227
- size=self.img_size,
228
- line_width=self.line_width,
229
- font_size=self.font_size,
230
- opacity=self.line_opacity,
231
- color=self.line_clr,
232
- show_thresholds=self.show_thresholds)
233
-
234
- for cnt, result in enumerate(pool.imap(constants, frm_range, chunksize=self.multiprocess_chunksize)):
235
- print(f"Core batch {result} complete...")
236
-
237
- pool.join()
238
- pool.terminate()
239
- if self.video_setting:
240
- print(f"Joining {self.video_name} multiprocessed video...")
241
- concatenate_videos_in_folder(in_folder=self.temp_folder, save_path=self.save_video_path)
242
-
243
- video_timer.stop_timer()
244
- print(f"Probability video {self.video_name} complete (elapsed time: {video_timer.elapsed_time_str}s) ...")
245
-
246
- self.timer.stop_timer()
247
- stdout_success(msg=f"Probability visualizations for {str(len(self.data_paths))} videos created in {self.probability_plot_dir} directory", elapsed_time=self.timer.elapsed_time_str,)
248
-
249
-
250
- # test = TresholdPlotCreatorMultiprocess(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/beepboop174/project_folder/project_config.ini',
251
- # frame_setting=False,
252
- # video_setting=True,
253
- # last_frame=True,
254
- # clf_name='Nose to Nose',
255
- # cores=-1,
256
- # files_found=['/Users/simon/Desktop/envs/simba/troubleshooting/beepboop174/project_folder/csv/machine_results/Trial 10.csv'],
257
- # style_attr={'width': 640, 'height': 480, 'font size': 10, 'line width': 6, 'color': 'Red', 'circle size': 20, 'y_max': 'auto'})
258
- # #test = TresholdPlotCreatorMultiprocess(config_path='/Users/simon/Desktop/troubleshooting/train_model_project/project_folder/project_config.ini', frame_setting=False, video_setting=True, clf_name='Attack')
259
- # test.run()
260
-
261
-
262
- # test = TresholdPlotCreatorMultiprocess(config_path='/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
263
- # frame_setting=False,
264
- # video_setting=True,
265
- # last_frame=True,
266
- # clf_name='Attack',
267
- # cores=5,
268
- # files_found=['/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/csv/machine_results/Together_1.csv'],
269
- # style_attr={'width': 640, 'height': 480, 'font size': 10, 'line width': 3, 'color': 'blue', 'circle size': 20, 'y_max': 'auto'})
270
- # test.create_plots()
271
-
272
- # if __name__ == "__main__":
273
- # test = TresholdPlotCreatorMultiprocess(config_path=r"C:\troubleshooting\sleap_two_animals\project_folder\project_config.ini",
274
- # frame_setting=True,
275
- # video_setting=False,
276
- # last_frame=True,
277
- # clf_name='Attack',
278
- # data_path=[r"C:\troubleshooting\sleap_two_animals\project_folder\csv\machine_results\Together_1.csv"],
279
- # size = (640, 480),
280
- # font_size=10,
281
- # line_width=6,
282
- # line_color='Orange',
283
- # y_max=None,
284
- # line_opacity=0.8,
285
- # cores=4)
286
- # test.run()
287
- #
288
-
1
+ __author__ = "Simon Nilsson; sronilsson@gmail.com"
2
+
3
+ import functools
4
+ import multiprocessing
5
+ import os
6
+ import platform
7
+ import shutil
8
+ from copy import deepcopy
9
+ from typing import List, Optional, Tuple, Union
10
+
11
+ import cv2
12
+ import numpy as np
13
+
14
+ from simba.mixins.config_reader import ConfigReader
15
+ from simba.mixins.plotting_mixin import PlottingMixin
16
+ from simba.utils.checks import (
17
+ check_all_file_names_are_represented_in_video_log,
18
+ check_file_exist_and_readable, check_float, check_instance, check_int,
19
+ check_str, check_that_column_exist, check_valid_boolean, check_valid_tuple)
20
+ from simba.utils.data import terminate_cpu_pool
21
+ from simba.utils.enums import Formats
22
+ from simba.utils.errors import NoSpecifiedOutputError
23
+ from simba.utils.lookups import get_color_dict
24
+ from simba.utils.printing import SimbaTimer, stdout_success
25
+ from simba.utils.read_write import (concatenate_videos_in_folder,
26
+ find_core_cnt, get_fn_ext, read_df)
27
+
28
+ STYLE_WIDTH = 'width'
29
+ STYLE_HEIGHT = 'height'
30
+ STYLE_FONT_SIZE = 'font size'
31
+ STYLE_LINE_WIDTH = 'line width'
32
+ STYLE_YMAX = 'y_max'
33
+ STYLE_COLOR = 'color'
34
+ AUTO = 'AUTO'
35
+ STYLE_OPACITY = 'opacity'
36
+
37
+ VALID_COLORS = list(get_color_dict().keys())
38
+ FOURCC = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value)
39
+
40
+ STYLE_ATTR = [STYLE_WIDTH, STYLE_HEIGHT, STYLE_FONT_SIZE, STYLE_LINE_WIDTH, STYLE_COLOR, STYLE_YMAX, STYLE_OPACITY]
41
+
42
+ def _probability_plot_mp(frm_range: Tuple[int, np.ndarray],
43
+ clf_data: np.ndarray,
44
+ clf_name: str,
45
+ video_setting: bool,
46
+ frame_setting: bool,
47
+ video_dir: str,
48
+ frame_dir: str,
49
+ fps: int,
50
+ video_name: str,
51
+ y_max: Union[int, float],
52
+ size: tuple,
53
+ line_width: int,
54
+ font_size: int,
55
+ opacity: float,
56
+ color: str,
57
+ show_thresholds: bool):
58
+
59
+
60
+
61
+ group, data = frm_range[0], frm_range[1]
62
+ start_frm, end_frm, current_frm = data[0], data[-1], data[0]
63
+
64
+ if video_setting:
65
+ fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value)
66
+ video_save_path = os.path.join(video_dir, f"{group}.mp4")
67
+ video_writer = cv2.VideoWriter(video_save_path, fourcc, fps, size)
68
+
69
+ while current_frm < end_frm:
70
+ current_lst = [np.array(clf_data[0 : current_frm + 1])]
71
+ current_frm += 1
72
+ img = PlottingMixin.make_line_plot(data=current_lst,
73
+ colors=[color],
74
+ width=size[0],
75
+ height=size[1],
76
+ line_width=line_width,
77
+ font_size=font_size,
78
+ line_opacity=opacity,
79
+ y_lbl=f"{clf_name} probability",
80
+ title=f'{video_name} - {clf_name}',
81
+ y_max=y_max,
82
+ x_lbl='frame count',
83
+ show_thresholds=show_thresholds)
84
+
85
+ if video_setting:
86
+ video_writer.write(img[:, :, :3])
87
+ if frame_setting:
88
+ frame_save_name = os.path.join(frame_dir, f"{current_frm}.png")
89
+ cv2.imwrite(frame_save_name, img)
90
+ current_frm += 1
91
+ print(f"Probability frame created: {current_frm + 1}, Video: {video_name}, Processing core: {group}")
92
+ return group
93
+
94
+
95
+ class TresholdPlotCreatorMultiprocess(ConfigReader, PlottingMixin):
96
+ """
97
+ Class for line chart visualizations displaying the classification probabilities of a single classifier.
98
+ Uses multiprocessing.
99
+
100
+ :param str config_path: path to SimBA project config file in Configparser format
101
+ :param str clf_name: Name of the classifier to create visualizations for
102
+ :param bool frame_setting: When True, SimBA creates indidvidual frames in png format
103
+ :param bool video_setting: When True, SimBA creates compressed video in mp4 format
104
+ :param bool last_image: When True, creates image .png representing last frame of the video.
105
+ :param dict style_attr: User-defined style attributes of the visualization (line size, color etc).
106
+ :param List[str] files_found: Files to create threshold plots for.
107
+ :param int cores: Number of cores to use.
108
+
109
+ .. note::
110
+ `Visualization tutorials <https://github.com/sgoldenlab/simba/blob/master/docs/tutorial.md#step-11-visualizations>`__.
111
+
112
+ .. image:: _static/img/prob_plot.png
113
+ :width: 300
114
+ :align: center
115
+
116
+ :example:
117
+ >>> plot_creator = TresholdPlotCreatorMultiprocess(config_path='/Users/simon/Desktop/troubleshooting/train_model_project/project_folder/project_config.ini', frame_setting=True, video_setting=True, clf_name='Attack', style_attr={'width': 640, 'height': 480, 'font size': 10, 'line width': 6, 'color': 'magneta', 'circle size': 20}, cores=5)
118
+ >>> plot_creator.run()
119
+ """
120
+
121
+ def __init__(self,
122
+ config_path: Union[str, os.PathLike],
123
+ data_path: Union[List[Union[str, os.PathLike]], str, os.PathLike],
124
+ clf_name: str,
125
+ frame_setting: Optional[bool] = False,
126
+ video_setting: Optional[bool] = False,
127
+ last_frame: Optional[bool] = True,
128
+ size: Tuple[int, int] = (640, 480),
129
+ font_size: int = 10,
130
+ line_width: int = 2,
131
+ y_max: Optional[int] = None,
132
+ line_color: str = 'Red',
133
+ line_opacity: float = 0.8,
134
+ cores: Optional[int] = -1,
135
+ show_thresholds: bool = True):
136
+
137
+ if platform.system() == "Darwin":
138
+ multiprocessing.set_start_method("spawn", force=True)
139
+ if (not video_setting) and (not frame_setting) and (not last_frame):
140
+ raise NoSpecifiedOutputError(msg="SIMBA ERROR: Please choose to create video and/or frames data plots. SimBA found that you ticked neither video and/or frames")
141
+ check_int(name=f"{self.__class__.__name__} core_cnt", value=cores, min_value=-1, max_value=find_core_cnt()[0], unaccepted_vals=[0])
142
+ if cores == -1: cores = find_core_cnt()[0]
143
+ check_valid_tuple(x=size, source=f'{self.__class__.__name__} size', accepted_lengths=(2,), valid_dtypes=Formats.INTEGER_DTYPES.value, min_integer=100)
144
+ check_int(name=f'{self.__class__.__name__} font_size', value=font_size, min_value=1, raise_error=True)
145
+ check_int(name=f'{self.__class__.__name__} line_width', value=line_width, min_value=1, raise_error=True)
146
+ check_valid_boolean(value=show_thresholds, source=f'{self.__class__.__name__} show_thresholds')
147
+ if y_max is not None:
148
+ check_float(name=f'{self.__class__.__name__} y_max', value=y_max, min_value=0.00001, raise_error=True)
149
+ check_str(name=f'{self.__class__.__name__} color', value=line_color, options=VALID_COLORS)
150
+ check_float(name=f'{self.__class__.__name__} line_opacity', value=line_opacity, min_value=0.001, max_value=1.0, raise_error=True)
151
+ ConfigReader.__init__(self, config_path=config_path)
152
+ PlottingMixin.__init__(self)
153
+ check_str(name=f"{self.__class__.__name__} clf_name", value=clf_name, options=(self.clf_names))
154
+ self.frame_setting, self.video_setting, self.last_image = frame_setting, video_setting, last_frame
155
+ self.line_opacity, self.line_clr, self.line_width = line_opacity, line_color, line_width
156
+ self.font_size, self.img_size, self.y_max = font_size, size, y_max
157
+ check_instance(source=f'{self.__class__.__name__} data_path' , instance=data_path, accepted_types=(str, list,), raise_error=True)
158
+ if isinstance(data_path, str):
159
+ data_path = [data_path]
160
+ for path in data_path:
161
+ check_file_exist_and_readable(file_path=path, raise_error=True)
162
+ check_str(name=f"{self.__class__.__name__} clf_name", value=clf_name, options=(self.clf_names))
163
+ self.show_thresholds = show_thresholds
164
+ self.frame_setting, self.video_setting, self.cores, self.last_frame = (frame_setting, video_setting, cores, last_frame)
165
+ self.clf_name, self.data_paths = clf_name, data_path
166
+ self.probability_col, self.img_size = f"Probability_{self.clf_name}", size
167
+ if not os.path.exists(self.probability_plot_dir): os.makedirs(self.probability_plot_dir)
168
+ print(f"Processing {len(self.data_paths)} video(s)...")
169
+
170
+ def run(self):
171
+ check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.data_paths)
172
+ for file_cnt, file_path in enumerate(self.data_paths):
173
+ video_timer = SimbaTimer(start=True)
174
+ _, self.video_name, _ = get_fn_ext(file_path)
175
+ video_info, self.px_per_mm, self.fps = self.read_video_info(video_name=self.video_name)
176
+ data_df = read_df(file_path, self.file_type)
177
+ check_that_column_exist(df=data_df, column_name=[self.clf_name, self.probability_col], file_name=file_path)
178
+ self.save_frame_folder_dir = os.path.join(self.probability_plot_dir, self.video_name + f"_{self.clf_name}")
179
+ self.video_folder = os.path.join(self.probability_plot_dir, self.video_name + f"_{self.clf_name}")
180
+ self.temp_folder = os.path.join(self.probability_plot_dir, f"{self.video_name}_{self.clf_name}", "temp")
181
+ if self.frame_setting:
182
+ if os.path.exists(self.save_frame_folder_dir):
183
+ shutil.rmtree(self.save_frame_folder_dir)
184
+ os.makedirs(self.save_frame_folder_dir)
185
+ if self.video_setting:
186
+ if os.path.exists(self.temp_folder):
187
+ shutil.rmtree(self.temp_folder)
188
+ shutil.rmtree(self.video_folder)
189
+ os.makedirs(self.temp_folder)
190
+ self.save_video_path = os.path.join(self.probability_plot_dir, f"{self.video_name}_{self.clf_name}.mp4")
191
+
192
+ clf_data = data_df[self.probability_col].values
193
+ y_max = deepcopy(self.y_max) if self.y_max is not None else float(np.max(clf_data))
194
+
195
+ if self.last_frame:
196
+ final_frm_save_path = os.path.join(self.probability_plot_dir, f'{self.video_name}_{self.clf_name}_final_frm_{self.datetime}.png')
197
+ _ = PlottingMixin.make_line_plot(data=[clf_data],
198
+ colors=[self.line_clr],
199
+ width=self.img_size[0],
200
+ height=self.img_size[1],
201
+ line_width=self.line_width,
202
+ font_size=self.font_size,
203
+ y_lbl=f"{self.clf_name} probability",
204
+ y_max=y_max,
205
+ x_lbl='frame count',
206
+ title=f'{self.video_name} - {self.clf_name}',
207
+ save_path=final_frm_save_path,
208
+ line_opacity=self.line_opacity,
209
+ show_thresholds=self.show_thresholds)
210
+
211
+ if self.video_setting or self.frame_setting:
212
+ frm_nums = np.arange(0, len(data_df)+1)
213
+ data_split = np.array_split(frm_nums, self.cores)
214
+ frm_range = []
215
+ for cnt, i in enumerate(data_split): frm_range.append((cnt, i))
216
+ print(f"Creating probability images, multiprocessing (chunksize: {self.multiprocess_chunksize}, cores: {self.cores})...")
217
+ with multiprocessing.Pool(self.cores, maxtasksperchild=self.maxtasksperchild) as pool:
218
+ constants = functools.partial(_probability_plot_mp,
219
+ clf_name=self.clf_name,
220
+ clf_data=clf_data,
221
+ video_setting=self.video_setting,
222
+ frame_setting=self.frame_setting,
223
+ fps=self.fps,
224
+ video_dir=self.temp_folder,
225
+ frame_dir=self.save_frame_folder_dir,
226
+ video_name=self.video_name,
227
+ y_max=y_max,
228
+ size=self.img_size,
229
+ line_width=self.line_width,
230
+ font_size=self.font_size,
231
+ opacity=self.line_opacity,
232
+ color=self.line_clr,
233
+ show_thresholds=self.show_thresholds)
234
+
235
+ for cnt, result in enumerate(pool.imap(constants, frm_range, chunksize=self.multiprocess_chunksize)):
236
+ print(f"Core batch {result} complete...")
237
+
238
+ terminate_cpu_pool(pool=pool, force=False)
239
+ if self.video_setting:
240
+ print(f"Joining {self.video_name} multiprocessed video...")
241
+ concatenate_videos_in_folder(in_folder=self.temp_folder, save_path=self.save_video_path)
242
+
243
+ video_timer.stop_timer()
244
+ print(f"Probability video {self.video_name} complete (elapsed time: {video_timer.elapsed_time_str}s) ...")
245
+
246
+ self.timer.stop_timer()
247
+ stdout_success(msg=f"Probability visualizations for {str(len(self.data_paths))} videos created in {self.probability_plot_dir} directory", elapsed_time=self.timer.elapsed_time_str,)
248
+
249
+
250
+ # test = TresholdPlotCreatorMultiprocess(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/beepboop174/project_folder/project_config.ini',
251
+ # frame_setting=False,
252
+ # video_setting=True,
253
+ # last_frame=True,
254
+ # clf_name='Nose to Nose',
255
+ # cores=-1,
256
+ # files_found=['/Users/simon/Desktop/envs/simba/troubleshooting/beepboop174/project_folder/csv/machine_results/Trial 10.csv'],
257
+ # style_attr={'width': 640, 'height': 480, 'font size': 10, 'line width': 6, 'color': 'Red', 'circle size': 20, 'y_max': 'auto'})
258
+ # #test = TresholdPlotCreatorMultiprocess(config_path='/Users/simon/Desktop/troubleshooting/train_model_project/project_folder/project_config.ini', frame_setting=False, video_setting=True, clf_name='Attack')
259
+ # test.run()
260
+
261
+
262
+ # test = TresholdPlotCreatorMultiprocess(config_path='/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
263
+ # frame_setting=False,
264
+ # video_setting=True,
265
+ # last_frame=True,
266
+ # clf_name='Attack',
267
+ # cores=5,
268
+ # files_found=['/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/csv/machine_results/Together_1.csv'],
269
+ # style_attr={'width': 640, 'height': 480, 'font size': 10, 'line width': 3, 'color': 'blue', 'circle size': 20, 'y_max': 'auto'})
270
+ # test.create_plots()
271
+
272
+ # if __name__ == "__main__":
273
+ # test = TresholdPlotCreatorMultiprocess(config_path=r"C:\troubleshooting\sleap_two_animals\project_folder\project_config.ini",
274
+ # frame_setting=True,
275
+ # video_setting=False,
276
+ # last_frame=True,
277
+ # clf_name='Attack',
278
+ # data_path=[r"C:\troubleshooting\sleap_two_animals\project_folder\csv\machine_results\Together_1.csv"],
279
+ # size = (640, 480),
280
+ # font_size=10,
281
+ # line_width=6,
282
+ # line_color='Orange',
283
+ # y_max=None,
284
+ # line_opacity=0.8,
285
+ # cores=4)
286
+ # test.run()
287
+ #
288
+