simba-uw-tf-dev 4.6.1-py3-none-any.whl → 4.6.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. simba/SimBA.py +2 -2
  2. simba/assets/icons/frames_2.png +0 -0
  3. simba/data_processors/agg_clf_counter_mp.py +52 -53
  4. simba/data_processors/cuda/image.py +3 -1
  5. simba/data_processors/cue_light_analyzer.py +5 -9
  6. simba/mixins/geometry_mixin.py +14 -28
  7. simba/mixins/image_mixin.py +10 -14
  8. simba/mixins/train_model_mixin.py +2 -2
  9. simba/plotting/ROI_feature_visualizer_mp.py +3 -5
  10. simba/plotting/clf_validator_mp.py +4 -5
  11. simba/plotting/cue_light_visualizer.py +6 -7
  12. simba/plotting/directing_animals_visualizer_mp.py +2 -3
  13. simba/plotting/distance_plotter_mp.py +378 -378
  14. simba/plotting/frame_mergerer_ffmpeg.py +137 -137
  15. simba/plotting/gantt_creator_mp.py +59 -31
  16. simba/plotting/geometry_plotter.py +270 -272
  17. simba/plotting/heat_mapper_clf_mp.py +2 -4
  18. simba/plotting/heat_mapper_location_mp.py +2 -2
  19. simba/plotting/light_dark_box_plotter.py +2 -2
  20. simba/plotting/path_plotter_mp.py +26 -29
  21. simba/plotting/plot_clf_results_mp.py +455 -454
  22. simba/plotting/pose_plotter_mp.py +27 -32
  23. simba/plotting/probability_plot_creator_mp.py +288 -288
  24. simba/plotting/roi_plotter_mp.py +29 -30
  25. simba/plotting/single_run_model_validation_video_mp.py +427 -427
  26. simba/plotting/spontaneous_alternation_plotter.py +2 -3
  27. simba/plotting/yolo_pose_track_visualizer.py +31 -27
  28. simba/plotting/yolo_pose_visualizer.py +32 -34
  29. simba/plotting/yolo_seg_visualizer.py +2 -3
  30. simba/roi_tools/roi_aggregate_stats_mp.py +4 -3
  31. simba/roi_tools/roi_clf_calculator_mp.py +3 -3
  32. simba/sandbox/cuda/egocentric_rotator.py +374 -0
  33. simba/sandbox/get_cpu_pool.py +5 -0
  34. simba/ui/pop_ups/clf_add_remove_print_pop_up.py +3 -1
  35. simba/ui/pop_ups/egocentric_alignment_pop_up.py +6 -3
  36. simba/ui/pop_ups/multiple_videos_to_frames_popup.py +10 -11
  37. simba/ui/pop_ups/single_video_to_frames_popup.py +10 -10
  38. simba/ui/pop_ups/video_processing_pop_up.py +63 -63
  39. simba/ui/tkinter_functions.py +7 -1
  40. simba/utils/data.py +89 -12
  41. simba/utils/enums.py +1 -0
  42. simba/utils/printing.py +9 -8
  43. simba/utils/read_write.py +3726 -3721
  44. simba/video_processors/clahe_ui.py +65 -22
  45. simba/video_processors/egocentric_video_rotator.py +6 -9
  46. simba/video_processors/video_processing.py +21 -10
  47. simba/video_processors/videos_to_frames.py +3 -2
  48. {simba_uw_tf_dev-4.6.1.dist-info → simba_uw_tf_dev-4.6.3.dist-info}/METADATA +1 -1
  49. {simba_uw_tf_dev-4.6.1.dist-info → simba_uw_tf_dev-4.6.3.dist-info}/RECORD +53 -50
  50. {simba_uw_tf_dev-4.6.1.dist-info → simba_uw_tf_dev-4.6.3.dist-info}/LICENSE +0 -0
  51. {simba_uw_tf_dev-4.6.1.dist-info → simba_uw_tf_dev-4.6.3.dist-info}/WHEEL +0 -0
  52. {simba_uw_tf_dev-4.6.1.dist-info → simba_uw_tf_dev-4.6.3.dist-info}/entry_points.txt +0 -0
  53. {simba_uw_tf_dev-4.6.1.dist-info → simba_uw_tf_dev-4.6.3.dist-info}/top_level.txt +0 -0
@@ -1,455 +1,456 @@
-__author__ = "Simon Nilsson; sronilsson@gmail.com"
-
-import functools
-import multiprocessing
-import os
-import platform
-from copy import deepcopy
-from typing import List, Optional, Tuple, Union
-
-import cv2
-import numpy as np
-import pandas as pd
-
-from simba.mixins.config_reader import ConfigReader
-from simba.mixins.geometry_mixin import GeometryMixin
-from simba.mixins.plotting_mixin import PlottingMixin
-from simba.mixins.train_model_mixin import TrainModelMixin
-from simba.utils.checks import (check_float, check_if_valid_rgb_tuple,
-                                check_int, check_nvidea_gpu_available,
-                                check_str, check_that_column_exist,
-                                check_valid_boolean,
-                                check_video_and_data_frm_count_align)
-from simba.utils.data import create_color_palette, detect_bouts
-from simba.utils.enums import ConfigKey, Dtypes, Options, TagNames, TextOptions
-from simba.utils.errors import (InvalidInputError, NoDataError,
-                                NoSpecifiedOutputError)
-from simba.utils.lookups import get_current_time
-from simba.utils.printing import SimbaTimer, log_event, stdout_success
-from simba.utils.read_write import (concatenate_videos_in_folder,
-                                    create_directory,
-                                    find_all_videos_in_project, find_core_cnt,
-                                    get_fn_ext, get_video_meta_data,
-                                    read_config_entry, read_df)
-from simba.utils.warnings import FrameRangeWarning
-
-
-def _multiprocess_sklearn_video(data: pd.DataFrame,
-                                bp_dict: dict,
-                                video_save_dir: str,
-                                frame_save_dir: str,
-                                clf_cumsum: dict,
-                                rotate: bool,
-                                video_path: str,
-                                print_timers: bool,
-                                video_setting: bool,
-                                frame_setting: bool,
-                                pose_threshold: float,
-                                clf_confidence: Union[dict, None],
-                                show_pose: bool,
-                                show_animal_names: bool,
-                                show_bbox: bool,
-                                circle_size: int,
-                                font_size: int,
-                                space_size: int,
-                                text_thickness: int,
-                                text_opacity: float,
-                                text_bg_clr: Tuple[int, int, int],
-                                text_color: Tuple[int, int, int],
-                                pose_clr_lst: List[Tuple[int, int, int]],
-                                show_gantt: Optional[int],
-                                bouts_df: Optional[pd.DataFrame],
-                                final_gantt: Optional[np.ndarray],
-                                gantt_clrs: List[Tuple[float, float, float]],
-                                clf_names: List[str],
-                                verbose:bool):
-
-    fourcc, font = cv2.VideoWriter_fourcc(*"mp4v"), cv2.FONT_HERSHEY_DUPLEX
-    video_meta_data = get_video_meta_data(video_path=video_path)
-    if rotate:
-        video_meta_data["height"], video_meta_data["width"] = (video_meta_data['width'], video_meta_data['height'])
-    cap = cv2.VideoCapture(video_path)
-    batch, data = data
-    start_frm, current_frm, end_frm = (data["index"].iloc[0], data["index"].iloc[0], data["index"].iloc[-1])
-    if video_setting:
-        video_save_path = os.path.join(video_save_dir, f"{batch}.mp4")
-        if show_gantt is None:
-            video_writer = cv2.VideoWriter(video_save_path, fourcc, video_meta_data["fps"], (video_meta_data["width"], video_meta_data["height"]))
-        else:
-            video_writer = cv2.VideoWriter(video_save_path, fourcc, video_meta_data["fps"], (int(video_meta_data["width"] + final_gantt.shape[1]), video_meta_data["height"]))
-    cap.set(1, start_frm)
-    while current_frm < end_frm:
-        ret, img = cap.read()
-        if ret:
-            clr_cnt = 0
-            for animal_cnt, (animal_name, animal_data) in enumerate(bp_dict.items()):
-                if show_pose:
-                    for bp_no in range(len(animal_data["X_bps"])):
-                        x_bp, y_bp, p_bp = (animal_data["X_bps"][bp_no], animal_data["Y_bps"][bp_no], animal_data["P_bps"][bp_no])
-                        bp_cords = data.loc[current_frm, [x_bp, y_bp, p_bp]]
-                        if bp_cords[p_bp] >= pose_threshold:
-                            img = cv2.circle(img, (int(bp_cords[x_bp]), int(bp_cords[y_bp])), circle_size, pose_clr_lst[clr_cnt], -1)
-                        clr_cnt += 1
-                if show_animal_names:
-                    x_bp, y_bp, p_bp = (animal_data["X_bps"][0], animal_data["Y_bps"][0], animal_data["P_bps"][0])
-                    bp_cords = data.loc[current_frm, [x_bp, y_bp, p_bp]]
-                    img = cv2.putText(img, animal_name, (int(bp_cords[x_bp]), int(bp_cords[y_bp])), font, font_size, pose_clr_lst[0], text_thickness)
-                if show_bbox:
-                    animal_headers = [val for pair in zip(animal_data["X_bps"], animal_data["Y_bps"]) for val in pair]
-                    animal_cords = data.loc[current_frm, animal_headers].values.reshape(-1, 2).astype(np.int32)
-                    try:
-                        bbox = GeometryMixin().keypoints_to_axis_aligned_bounding_box(keypoints=animal_cords.reshape(-1, len(animal_cords), 2).astype(np.int32))
-                        img = cv2.polylines(img, [bbox], True, pose_clr_lst[animal_cnt], thickness=circle_size, lineType=cv2.LINE_AA)
-                    except Exception as e:
-                        #print(e.args)
-                        pass
-            if rotate:
-                img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
-            if show_gantt == 1:
-                img = np.concatenate((img, final_gantt), axis=1)
-            elif show_gantt == 2:
-                bout_rows = bouts_df.loc[bouts_df["End_frame"] <= current_frm]
-                gantt_plot = PlottingMixin().make_gantt_plot(x_length=current_frm + 1,
-                                                             bouts_df=bout_rows,
-                                                             clf_names=clf_names,
-                                                             fps=video_meta_data['fps'],
-                                                             width=video_meta_data['width'],
-                                                             height=video_meta_data['height'],
-                                                             font_size=12,
-                                                             font_rotation=90,
-                                                             video_name=video_meta_data['video_name'],
-                                                             save_path=None,
-                                                             palette=gantt_clrs)
-                img = np.concatenate((img, gantt_plot), axis=1)
-            if print_timers:
-                img = PlottingMixin().put_text(img=img, text="TIMERS:", pos=(TextOptions.BORDER_BUFFER_Y.value, ((video_meta_data["height"] - video_meta_data["height"]) + space_size)), font_size=font_size, font_thickness=text_thickness, font=font, text_bg_alpha=text_opacity, text_color_bg=text_bg_clr, text_color=text_color)
-            add_spacer = 2
-            for clf_name, clf_time_df in clf_cumsum.items():
-                frame_results = clf_time_df.loc[current_frm]
-                clf_time = round(frame_results / video_meta_data['fps'], 2)
-                if print_timers:
-                    img = PlottingMixin().put_text(img=img, text=f"{clf_name} {clf_time}",pos=(TextOptions.BORDER_BUFFER_Y.value, ((video_meta_data["height"] - video_meta_data["height"]) + space_size * add_spacer)), font_size=font_size, font_thickness=text_thickness, font=font, text_bg_alpha=text_opacity, text_color_bg=text_bg_clr, text_color=text_color)
-                    add_spacer += 1
-                if clf_confidence is not None:
-                    frm_clf_conf_txt = f'{clf_name} CONFIDENCE {clf_confidence[clf_name][current_frm]}'
-                    img = PlottingMixin().put_text(img=img, text=frm_clf_conf_txt,pos=(TextOptions.BORDER_BUFFER_Y.value, ((video_meta_data["height"] - video_meta_data["height"]) + space_size * add_spacer)), font_size=font_size, font_thickness=text_thickness, font=font, text_bg_alpha=text_opacity, text_color_bg=text_bg_clr, text_color=text_color)
-                    add_spacer += 1
-
-            img = PlottingMixin().put_text(img=img, text="ENSEMBLE PREDICTION:", pos=(TextOptions.BORDER_BUFFER_Y.value, ((video_meta_data["height"] - video_meta_data["height"]) + space_size * add_spacer)), font_size=font_size, font_thickness=text_thickness, font=font, text_bg_alpha=text_opacity, text_color_bg=text_bg_clr, text_color=text_color)
-            add_spacer += 1
-            for clf_name in clf_cumsum.keys():
-                if data.loc[current_frm, clf_name] == 1:
-                    img = PlottingMixin().put_text(img=img, text=clf_name, pos=(TextOptions.BORDER_BUFFER_Y.value, (video_meta_data["height"] - video_meta_data["height"]) + space_size * add_spacer), font_size=font_size, font_thickness=text_thickness, font=font, text_color=TextOptions.COLOR.value, text_bg_alpha=text_opacity)
-                    add_spacer += 1
-            if video_setting:
-                video_writer.write(img.astype(np.uint8))
-            if frame_setting:
-                frame_save_name = os.path.join(frame_save_dir, f"{current_frm}.png")
-                cv2.imwrite(frame_save_name, img)
-            current_frm += 1
-            if verbose: print(f"Multi-processing video frame {current_frm} on core {batch}...")
-        else:
-            FrameRangeWarning(msg=f'Could not read frame {current_frm} in video {video_path}. Stopping video creation.')
-            break
-
-    cap.release()
-    if video_setting:
-        video_writer.release()
-    return batch
-
-
-class PlotSklearnResultsMultiProcess(ConfigReader, TrainModelMixin, PlottingMixin):
-    """
-    Plot classification results on videos using multiprocessing. Results are stored in the
-    `project_folder/frames/output/sklearn_results` directory of the SimBA project.
-
-    This class creates annotated videos/frames showing classifier predictions overlaid on pose-estimation data,
-    with optional Gantt charts, timers, and bounding boxes. Processing is parallelized across multiple CPU cores
-    for faster rendering of large video datasets.
-
-    .. seealso::
-       `Tutorial <https://github.com/sgoldenlab/simba/blob/master/docs/tutorial.md#step-10-sklearn-visualization>`__.
-       For single-core processing, see :meth:`simba.plotting.plot_clf_results.PlotSklearnResultsSingleCore`.
-
-    .. image:: _static/img/sklearn_visualization.gif
-       :width: 600
-       :align: center
-
-    .. video:: _static/img/T1.webm
-       :width: 1000
-       :autoplay:
-       :loop:
-
-    .. youtube:: Frq6mMcaHBc
-       :width: 640
-       :height: 480
-       :align: center
-
-    :param Union[str, os.PathLike] config_path: Path to SimBA project config file in Configparser format.
-    :param bool video_setting: If True, creates compressed MP4 videos. Default True.
-    :param bool frame_setting: If True, saves individual annotated frames as PNG images. Default False.
-    :param Optional[Union[List[Union[str, os.PathLike]], Union[str, os.PathLike]]] video_paths: Path(s) to video file(s) to process. If None, processes all videos found in the project's video directory. Default None.
-    :param bool rotate: If True, rotates output videos 90 degrees clockwise. Default False.
-    :param bool animal_names: If True, displays animal names on the video frames. Default False.
-    :param bool show_pose: If True, overlays pose-estimation keypoints on the video. Default True.
-    :param Optional[Union[int, float]] font_size: Font size for text overlays. If None, auto-computed based on video resolution. Default None.
-    :param Optional[Union[int, float]] space_size: Vertical spacing between text lines. If None, auto-computed. Default None.
-    :param Optional[Union[int, float]] text_thickness: Thickness of text characters. If None, uses default. Default None.
-    :param Optional[Union[int, float]] text_opacity: Opacity of text background (0.0-1.0). If None, defaults to 0.8. Default None.
-    :param Optional[Union[int, float]] circle_size: Radius of pose keypoint circles. If None, auto-computed based on video resolution. Default None.
-    :param Optional[str] pose_palette: Name of color palette for pose keypoints. Must be from :class:`simba.utils.enums.Options.PALETTE_OPTIONS_CATEGORICAL` or :class:`simba.utils.enums.Options.PALETTE_OPTIONS`. Default 'Set1'.
-    :param bool print_timers: If True, displays cumulative time for each classifier behavior on each frame. Default True.
-    :param bool show_bbox: If True, draws axis-aligned bounding boxes around detected animals. Default False.
-    :param Optional[int] show_gantt: If 1, appends static Gantt chart to video. If 2, appends dynamic Gantt chart that updates per frame. If None, no Gantt chart. Default None.
-    :param Tuple[int, int, int] text_clr: RGB color tuple for text foreground. Default (255, 255, 255) (white).
-    :param Tuple[int, int, int] text_bg_clr: RGB color tuple for text background. Default (0, 0, 0) (black).
-    :param bool gpu: If True, uses GPU acceleration for video concatenation (requires CUDA-capable GPU). Default False.
-    :param int core_cnt: Number of CPU cores to use for parallel processing. Pass -1 to use all available cores. Default -1.
-
-    :example:
-    >>> clf_plotter = PlotSklearnResultsMultiProcess(
-    ...     config_path='/Users/simon/Desktop/envs/simba/troubleshooting/beepboop174/project_folder/project_config.ini',
-    ...     video_setting=True,
-    ...     frame_setting=False,
-    ...     video_paths='Trial_10.mp4',
-    ...     rotate=False,
-    ...     show_pose=True,
-    ...     show_bbox=True,
-    ...     print_timers=True,
-    ...     show_gantt=1,
-    ...     core_cnt=5
-    ... )
-    >>> clf_plotter.run()
-    """
-
-    def __init__(self,
-                 config_path: Union[str, os.PathLike],
-                 video_setting: bool = True,
-                 frame_setting: bool = False,
-                 video_paths: Optional[Union[List[Union[str, os.PathLike]], Union[str, os.PathLike]]] = None,
-                 rotate: bool = False,
-                 animal_names: bool = False,
-                 show_pose: bool = True,
-                 show_confidence: bool = False,
-                 font_size: Optional[Union[int, float]] = None,
-                 space_size: Optional[Union[int, float]] = None,
-                 text_thickness: Optional[Union[int, float]] = None,
-                 text_opacity: Optional[Union[int, float]] = None,
-                 circle_size: Optional[Union[int, float]] = None,
-                 pose_palette: Optional[str] = 'Set1',
-                 print_timers: bool = True,
-                 show_bbox: bool = False,
-                 show_gantt: Optional[int] = None,
-                 text_clr: Tuple[int, int, int] = (255, 255, 255),
-                 text_bg_clr: Tuple[int, int, int] = (0, 0, 0),
-                 gpu: bool = False,
-                 verbose: bool = True,
-                 core_cnt: int = -1):
-
-
-        ConfigReader.__init__(self, config_path=config_path)
-        TrainModelMixin.__init__(self)
-        PlottingMixin.__init__(self)
-        log_event(logger_name=str(__class__.__name__), log_type=TagNames.CLASS_INIT.value, msg=self.create_log_msg_from_init_args(locals=locals()))
-        for i in [video_setting, frame_setting, rotate, print_timers, animal_names, show_pose, gpu, show_bbox, show_confidence]:
-            check_valid_boolean(value=i, source=self.__class__.__name__, raise_error=True)
-        if (not video_setting) and (not frame_setting):
-            raise NoSpecifiedOutputError(msg="Please choose to create a video and/or frames. SimBA found that you ticked neither video and/or frames", source=self.__class__.__name__)
-        if font_size is not None: check_float(name=f'{self.__class__.__name__} font_size', value=font_size, min_value=0.1)
-        if space_size is not None: check_float(name=f'{self.__class__.__name__} space_size', value=space_size, min_value=0.1)
-        if text_thickness is not None: check_float(name=f'{self.__class__.__name__} text_thickness', value=text_thickness, min_value=0.1)
-        if circle_size is not None: check_float(name=f'{self.__class__.__name__} text_thickness', value=circle_size, min_value=0.1)
-        if circle_size is not None: check_float(name=f'{self.__class__.__name__} text_thickness', value=circle_size, min_value=0.1)
-        if text_opacity is not None: check_float(name=f'{self.__class__.__name__} text_opacity', value=text_opacity, min_value=0.1)
-        pose_palettes = Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value
-        check_str(name=f'{self.__class__.__name__} pose_palette', value=pose_palette, options=pose_palettes)
-        self.clr_lst = create_color_palette(pallete_name=pose_palette, increments=len(self.body_parts_lst)+1)
-        check_if_valid_rgb_tuple(data=text_clr, source=f'{self.__class__.__name__} text_clr')
-        check_if_valid_rgb_tuple(data=text_bg_clr, source=f'{self.__class__.__name__} text_bg_clr')
-        check_valid_boolean(value=verbose, source=f'{self.__class__.__name__} verbose', raise_error=True)
-        if show_gantt is not None:
-            check_int(name=f"{self.__class__.__name__} show_gantt", value=show_gantt, max_value=2, min_value=1)
-        self.video_setting, self.frame_setting, self.rotate, self.print_timers = video_setting, frame_setting, rotate, print_timers
-        self.circle_size, self.font_size, self.animal_names, self.text_opacity = circle_size, font_size, animal_names, text_opacity
-        self.text_thickness, self.space_size, self.show_pose, self.pose_palette, self.verbose = text_thickness, space_size, show_pose, pose_palette, verbose
-        self.text_color, self.text_bg_color, self.show_bbox, self.show_gantt, self.show_confidence = text_clr, text_bg_clr, show_bbox, show_gantt, show_confidence
-        self.gpu = True if check_nvidea_gpu_available() and gpu else False
-        self.pose_threshold = read_config_entry(self.config, ConfigKey.THRESHOLD_SETTINGS.value, ConfigKey.SKLEARN_BP_PROB_THRESH.value, Dtypes.FLOAT.value, 0.00)
-        if not os.path.exists(self.sklearn_plot_dir):
-            os.makedirs(self.sklearn_plot_dir)
-        if isinstance(video_paths, str): self.video_paths = [video_paths]
-        elif isinstance(video_paths, list): self.video_paths = video_paths
-        elif video_paths is None:
-            self.video_paths = find_all_videos_in_project(videos_dir=self.video_dir)
-            if len(self.video_paths) == 0:
-                raise NoDataError(msg=f'Cannot create classification videos. No videos exist in {self.video_dir} directory', source=self.__class__.__name__)
-        else:
-            raise InvalidInputError(msg=f'video_paths has to be a path of a list of paths. Got {type(video_paths)}', source=self.__class__.__name__)
-
-        for video_path in self.video_paths:
-            video_name = get_fn_ext(filepath=video_path)[1]
-            data_path = os.path.join(self.machine_results_dir, f'{video_name}.{self.file_type}')
-            if not os.path.isfile(data_path): raise NoDataError(msg=f'Cannot create classification videos for {video_name}. Expected classification data at location {data_path} but file does not exist', source=self.__class__.__name__)
-        check_int(name=f'{self.__class__.__name__} core_cnt', value=core_cnt, min_value=-1, unaccepted_vals=[0])
-        self.core_cnt = find_core_cnt()[0] if int(core_cnt) == -1 or int(core_cnt) > find_core_cnt()[0] else int(core_cnt)
-        self.conf_cols = [f'Probability_{x}' for x in self.clf_names]
-        if platform.system() == "Darwin":
-            multiprocessing.set_start_method("spawn", force=True)
-
-    def __get_print_settings(self):
-        optimal_circle_size = self.get_optimal_circle_size(frame_size=(self.video_meta_data["width"], self.video_meta_data["height"]), circle_frame_ratio=100)
-        longest_str = str(max(['TIMERS:', 'ENSEMBLE PREDICTION:'] + self.clf_names, key=len))
-        self.video_text_thickness = TextOptions.TEXT_THICKNESS.value if self.text_thickness is None else int(max(self.text_thickness, 1))
-        optimal_font_size, _, optimal_spacing_scale = self.get_optimal_font_scales(text=longest_str, accepted_px_width=int(self.video_meta_data["width"] / 3), accepted_px_height=int(self.video_meta_data["height"] / 10), text_thickness=self.video_text_thickness)
-        self.video_circle_size = optimal_circle_size if self.circle_size is None else int(max(1, self.circle_size))
-        self.video_font_size = optimal_font_size if self.font_size is None else self.font_size
-        self.video_space_size = optimal_spacing_scale if self.space_size is None else int(max(self.space_size, 1))
-        self.video_text_opacity = 0.8 if self.text_opacity is None else float(self.text_opacity)
-
-    def run(self):
-        if self.verbose: print(f'Creating {len(self.video_paths)} classification visualization(s) using {self.core_cnt} cores... ({get_current_time()})')
-        for video_cnt, video_path in enumerate(self.video_paths):
-            video_timer = SimbaTimer(start=True)
-            _, self.video_name, _ = get_fn_ext(video_path)
-            if self.verbose: print(f"Creating classification visualization for video {self.video_name}... ({get_current_time()})")
-            self.data_path = os.path.join(self.machine_results_dir, f'{self.video_name}.{self.file_type}')
-            self.data_df = read_df(self.data_path, self.file_type).reset_index(drop=True).fillna(0)
-            if self.show_pose: check_that_column_exist(df=self.data_df, column_name=self.bp_col_names, file_name=self.data_path)
-            if self.show_confidence: check_that_column_exist(df=self.data_df, column_name=self.conf_cols, file_name=self.data_path)
-            self.video_meta_data = get_video_meta_data(video_path=video_path)
-            height, width = deepcopy(self.video_meta_data["height"]), deepcopy(self.video_meta_data["width"])
-            self.save_path = os.path.join(self.sklearn_plot_dir, f"{self.video_name}.mp4")
-            self.video_frame_dir, self.video_temp_dir = None, None
-            if self.video_setting:
-                self.video_save_path = os.path.join(self.sklearn_plot_dir, f"{self.video_name}.mp4")
-                self.video_temp_dir = os.path.join(self.sklearn_plot_dir, self.video_name, "temp")
-                create_directory(paths=self.video_temp_dir, overwrite=True)
-            if self.frame_setting:
-                self.video_frame_dir = os.path.join(self.sklearn_plot_dir, self.video_name)
-                create_directory(paths=self.video_temp_dir, overwrite=True)
-            if self.rotate:
-                self.video_meta_data["height"], self.video_meta_data["width"] = (width, height)
-            check_video_and_data_frm_count_align(video=video_path, data=self.data_df, name=self.video_name, raise_error=False)
-            check_that_column_exist(df=self.data_df, column_name=self.clf_names, file_name=self.data_path)
-            self.__get_print_settings()
-            if self.show_gantt is not None:
-                self.gantt_clrs = create_color_palette(pallete_name=self.pose_palette, increments=len(self.clf_names) + 1, as_int=True, as_rgb_ratio=True)
-                self.bouts_df = detect_bouts(data_df=self.data_df, target_lst=list(self.clf_names), fps=int(self.video_meta_data["fps"]))
-                self.final_gantt_img = PlottingMixin().make_gantt_plot(x_length=len(self.data_df) + 1, bouts_df=self.bouts_df, clf_names=self.clf_names, fps=self.video_meta_data["fps"], width=self.video_meta_data["width"], height=self.video_meta_data["height"], font_size=12, font_rotation=90, video_name=self.video_meta_data["video_name"], save_path=None, palette=self.gantt_clrs)
-                self.final_gantt_img = self.resize_gantt(self.final_gantt_img, self.video_meta_data["height"])
-            else:
-                self.bouts_df, self.final_gantt_img, self.gantt_clrs = None, None, None
-
-
-            self.clf_cumsums, self.clf_p = {}, {} if self.show_confidence else None
-            for clf_name in self.clf_names:
-                self.clf_cumsums[clf_name] = self.data_df[clf_name].cumsum()
-                if self.show_confidence: self.clf_p[clf_name] = np.round(self.data_df[f'Probability_{clf_name}'].values.reshape(-1), 4)
-
-            self.data_df["index"] = self.data_df.index
-            data = np.array_split(self.data_df, self.core_cnt)
-            data = [(cnt, x) for (cnt, x) in enumerate(data)]
-
-            with multiprocessing.Pool(self.core_cnt, maxtasksperchild=self.maxtasksperchild) as pool:
-                constants = functools.partial(_multiprocess_sklearn_video,
-                                              bp_dict=self.animal_bp_dict,
-                                              video_save_dir=self.video_temp_dir,
-                                              frame_save_dir=self.video_frame_dir,
-                                              clf_cumsum=self.clf_cumsums,
-                                              rotate=self.rotate,
-                                              video_path=video_path,
-                                              clf_confidence=self.clf_p,
-                                              print_timers=self.print_timers,
-                                              video_setting=self.video_setting,
-                                              frame_setting=self.frame_setting,
-                                              pose_threshold=self.pose_threshold,
-                                              show_pose=self.show_pose,
-                                              show_animal_names=self.animal_names,
-                                              circle_size=self.video_circle_size,
-                                              font_size=self.video_font_size,
-                                              space_size=self.video_space_size,
-                                              text_thickness=self.video_text_thickness,
-                                              text_opacity=self.video_text_opacity,
-                                              text_bg_clr=self.text_bg_color,
-                                              text_color=self.text_color,
-                                              pose_clr_lst=self.clr_lst,
-                                              show_bbox=self.show_bbox,
-                                              show_gantt=self.show_gantt,
-                                              bouts_df=self.bouts_df,
-                                              final_gantt=self.final_gantt_img,
-                                              gantt_clrs=self.gantt_clrs,
-                                              clf_names=self.clf_names,
-                                              verbose=self.verbose)
-
-                for cnt, result in enumerate(pool.imap(constants, data, chunksize=self.multiprocess_chunksize)):
-                    if self.verbose: print(f"Image batch {result} complete, Video {(video_cnt + 1)}/{len(self.video_paths)}...")
-
-                if self.video_setting:
-                    if self.verbose: print(f"Joining {self.video_name} multiprocessed video...")
-                    concatenate_videos_in_folder(in_folder=self.video_temp_dir, save_path=self.video_save_path, gpu=self.gpu, verbose=self.verbose)
-                video_timer.stop_timer()
-                pool.terminate()
-                pool.join()
-                print(f"Video {self.video_name} complete (elapsed time: {video_timer.elapsed_time_str}s)...")
-
-        self.timer.stop_timer()
-        if self.video_setting:
-            stdout_success(msg=f"{len(self.video_paths)} video(s) saved in {self.sklearn_plot_dir} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
-        if self.frame_setting:
-            stdout_success(f"Frames for {len(self.video_paths)} videos saved in sub-folders within {self.sklearn_plot_dir} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
-
-
-
-# if __name__ == "__main__":
-#     clf_plotter = PlotSklearnResultsMultiProcess(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini",
-#                                                  video_setting=True,
-#                                                  frame_setting=False,
-#                                                  video_paths=None,
-#                                                  print_timers=True,
-#                                                  rotate=False,
-#                                                  animal_names=False,
-#                                                  show_bbox=True,
-#                                                  show_gantt=None)
-#     clf_plotter.run()
-
-
-
-
-# if __name__ == "__main__":
-#     clf_plotter = PlotSklearnResultsMultiProcess(config_path=r"/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini",
-#                                                  video_setting=True,
-#                                                  frame_setting=False,
-#                                                  video_paths=r"/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/videos/Together_1.mp4",
-#                                                  print_timers=True,
-#                                                  rotate=False,
-#                                                  animal_names=False,
-#                                                  show_bbox=True,
-#                                                  show_gantt=None)
-#     clf_plotter.run()
-
-
-
-
-#text_settings = {'circle_scale': 5, 'font_size': 0.528, 'spacing_scale': 28, 'text_thickness': 2}
-# clf_plotter = PlotSklearnResultsMultiProcess(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/beepboop174/project_folder/project_config.ini',
-#                                              video_setting=True,
-#                                              frame_setting=False,
-#                                              rotate=False,
-#                                              video_file_path='592_MA147_Gq_CNO_0515.mp4',
-#                                              cores=-1,
-#                                              text_settings=False)
-# clf_plotter.run()
-#
-
-# clf_plotter = PlotSklearnResultsMultiProcess(config_path='/Users/simon/Desktop/envs/troubleshooting/DLC_2_Black_animals/project_folder/project_config.ini', video_setting=True, frame_setting=False, rotate=False, video_file_path='Together_1.avi', cores=5)
-# clf_plotter.run()

-# if __name__ == "__main__":
-#     clf_plotter = PlotSklearnResultsMultiProcess(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
-#                                                  video_setting = True,
-#                                                  frame_setting = False,
-#                                                  rotate = False,
-#                                                  core_cnt = 6,
-#                                                  show_confidence=True,
-#                                                  video_paths=r"C:\troubleshooting\mitra\project_folder\videos\501_MA142_Gi_CNO_0521.mp4")
+__author__ = "Simon Nilsson; sronilsson@gmail.com"
+
+import functools
+import multiprocessing
+import os
+import platform
+from copy import deepcopy
+from typing import List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+import pandas as pd
+
+from simba.mixins.config_reader import ConfigReader
+from simba.mixins.geometry_mixin import GeometryMixin
+from simba.mixins.plotting_mixin import PlottingMixin
+from simba.mixins.train_model_mixin import TrainModelMixin
+from simba.utils.checks import (check_float, check_if_valid_rgb_tuple,
+                                check_int, check_nvidea_gpu_available,
+                                check_str, check_that_column_exist,
+                                check_valid_boolean,
+                                check_video_and_data_frm_count_align)
+from simba.utils.data import (create_color_palette, detect_bouts, get_cpu_pool,
+                              terminate_cpu_pool)
+from simba.utils.enums import ConfigKey, Dtypes, Options, TagNames, TextOptions
+from simba.utils.errors import (InvalidInputError, NoDataError,
+                                NoSpecifiedOutputError)
+from simba.utils.lookups import get_current_time
+from simba.utils.printing import SimbaTimer, log_event, stdout_success
+from simba.utils.read_write import (concatenate_videos_in_folder,
+                                    create_directory,
+                                    find_all_videos_in_project, find_core_cnt,
+                                    get_fn_ext, get_video_meta_data,
+                                    read_config_entry, read_df)
+from simba.utils.warnings import FrameRangeWarning
+
+
+def _multiprocess_sklearn_video(data: pd.DataFrame,
+                                bp_dict: dict,
+                                video_save_dir: str,
+                                frame_save_dir: str,
+                                clf_cumsum: dict,
+                                rotate: bool,
+                                video_path: str,
+                                print_timers: bool,
+                                video_setting: bool,
+                                frame_setting: bool,
+                                pose_threshold: float,
+                                clf_confidence: Union[dict, None],
+                                show_pose: bool,
+                                show_animal_names: bool,
+                                show_bbox: bool,
+                                circle_size: int,
+                                font_size: int,
+                                space_size: int,
+                                text_thickness: int,
+                                text_opacity: float,
+                                text_bg_clr: Tuple[int, int, int],
+                                text_color: Tuple[int, int, int],
+                                pose_clr_lst: List[Tuple[int, int, int]],
+                                show_gantt: Optional[int],
+                                bouts_df: Optional[pd.DataFrame],
+                                final_gantt: Optional[np.ndarray],
+                                gantt_clrs: List[Tuple[float, float, float]],
+                                clf_names: List[str],
+                                verbose:bool):
+
+    fourcc, font = cv2.VideoWriter_fourcc(*"mp4v"), cv2.FONT_HERSHEY_DUPLEX
+    video_meta_data = get_video_meta_data(video_path=video_path)
+    if rotate:
+        video_meta_data["height"], video_meta_data["width"] = (video_meta_data['width'], video_meta_data['height'])
+    cap = cv2.VideoCapture(video_path)
+    batch, data = data
+    start_frm, current_frm, end_frm = (data["index"].iloc[0], data["index"].iloc[0], data["index"].iloc[-1])
+    if video_setting:
+        video_save_path = os.path.join(video_save_dir, f"{batch}.mp4")
+        if show_gantt is None:
+            video_writer = cv2.VideoWriter(video_save_path, fourcc, video_meta_data["fps"], (video_meta_data["width"], video_meta_data["height"]))
+        else:
+            video_writer = cv2.VideoWriter(video_save_path, fourcc, video_meta_data["fps"], (int(video_meta_data["width"] + final_gantt.shape[1]), video_meta_data["height"]))
+    cap.set(1, start_frm)
+    while current_frm < end_frm:
+        ret, img = cap.read()
+        if ret:
+            clr_cnt = 0
+            for animal_cnt, (animal_name, animal_data) in enumerate(bp_dict.items()):
+                if show_pose:
+                    for bp_no in range(len(animal_data["X_bps"])):
+                        x_bp, y_bp, p_bp = (animal_data["X_bps"][bp_no], animal_data["Y_bps"][bp_no], animal_data["P_bps"][bp_no])
+                        bp_cords = data.loc[current_frm, [x_bp, y_bp, p_bp]]
+                        if bp_cords[p_bp] >= pose_threshold:
+                            img = cv2.circle(img, (int(bp_cords[x_bp]), int(bp_cords[y_bp])), circle_size, pose_clr_lst[clr_cnt], -1)
+                        clr_cnt += 1
+                if show_animal_names:
+                    x_bp, y_bp, p_bp = (animal_data["X_bps"][0], animal_data["Y_bps"][0], animal_data["P_bps"][0])
+                    bp_cords = data.loc[current_frm, [x_bp, y_bp, p_bp]]
+                    img = cv2.putText(img, animal_name, (int(bp_cords[x_bp]), int(bp_cords[y_bp])), font, font_size, pose_clr_lst[0], text_thickness)
+                if show_bbox:
+                    animal_headers = [val for pair in zip(animal_data["X_bps"], animal_data["Y_bps"]) for val in pair]
+                    animal_cords = data.loc[current_frm, animal_headers].values.reshape(-1, 2).astype(np.int32)
+                    try:
+                        bbox = GeometryMixin().keypoints_to_axis_aligned_bounding_box(keypoints=animal_cords.reshape(-1, len(animal_cords), 2).astype(np.int32))
+                        img = cv2.polylines(img, [bbox], True, pose_clr_lst[animal_cnt], thickness=circle_size, lineType=cv2.LINE_AA)
+                    except Exception as e:
+                        #print(e.args)
+                        pass
+            if rotate:
+                img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
+            if show_gantt == 1:
+                img = np.concatenate((img, final_gantt), axis=1)
+            elif show_gantt == 2:
+                bout_rows = bouts_df.loc[bouts_df["End_frame"] <= current_frm]
+                gantt_plot = PlottingMixin().make_gantt_plot(x_length=current_frm + 1,
+                                                             bouts_df=bout_rows,
+                                                             clf_names=clf_names,
+                                                             fps=video_meta_data['fps'],
+                                                             width=video_meta_data['width'],
+                                                             height=video_meta_data['height'],
+                                                             font_size=12,
+                                                             font_rotation=90,
+                                                             video_name=video_meta_data['video_name'],
+                                                             save_path=None,
+                                                             palette=gantt_clrs)
+                img = np.concatenate((img, gantt_plot), axis=1)
+            if print_timers:
+                img = PlottingMixin().put_text(img=img, text="TIMERS:", pos=(TextOptions.BORDER_BUFFER_Y.value, ((video_meta_data["height"] - video_meta_data["height"]) + space_size)), font_size=font_size, font_thickness=text_thickness, font=font, text_bg_alpha=text_opacity, text_color_bg=text_bg_clr, text_color=text_color)
+            add_spacer = 2
+            for clf_name, clf_time_df in clf_cumsum.items():
+                frame_results = clf_time_df.loc[current_frm]
+                clf_time = round(frame_results / video_meta_data['fps'], 2)
+                if print_timers:
+                    img = PlottingMixin().put_text(img=img, text=f"{clf_name} {clf_time}",pos=(TextOptions.BORDER_BUFFER_Y.value, ((video_meta_data["height"] - video_meta_data["height"]) + space_size * add_spacer)), font_size=font_size, font_thickness=text_thickness, font=font, text_bg_alpha=text_opacity, text_color_bg=text_bg_clr, text_color=text_color)
+                    add_spacer += 1
+                if clf_confidence is not None:
+                    frm_clf_conf_txt = f'{clf_name} CONFIDENCE {clf_confidence[clf_name][current_frm]}'
+                    img = PlottingMixin().put_text(img=img, text=frm_clf_conf_txt,pos=(TextOptions.BORDER_BUFFER_Y.value, ((video_meta_data["height"] - video_meta_data["height"]) + space_size * add_spacer)), font_size=font_size, font_thickness=text_thickness, font=font, text_bg_alpha=text_opacity, text_color_bg=text_bg_clr, text_color=text_color)
+                    add_spacer += 1
+
+            img = PlottingMixin().put_text(img=img, text="ENSEMBLE PREDICTION:", pos=(TextOptions.BORDER_BUFFER_Y.value, ((video_meta_data["height"] - video_meta_data["height"]) + space_size * add_spacer)), font_size=font_size, font_thickness=text_thickness, font=font, text_bg_alpha=text_opacity, text_color_bg=text_bg_clr, text_color=text_color)
+            add_spacer += 1
+            for clf_name in clf_cumsum.keys():
+                if data.loc[current_frm, clf_name] == 1:
+                    img = PlottingMixin().put_text(img=img, text=clf_name, pos=(TextOptions.BORDER_BUFFER_Y.value, (video_meta_data["height"] - video_meta_data["height"]) + space_size * add_spacer), font_size=font_size, font_thickness=text_thickness, font=font, text_color=TextOptions.COLOR.value, text_bg_alpha=text_opacity)
+                    add_spacer += 1
+            if video_setting:
+                video_writer.write(img.astype(np.uint8))
+            if frame_setting:
+                frame_save_name = os.path.join(frame_save_dir, f"{current_frm}.png")
+                cv2.imwrite(frame_save_name, img)
+            current_frm += 1
+            if verbose: print(f"[{get_current_time()}] Multi-processing video frame {current_frm}/{video_meta_data['frame_count']} (core batch: {batch}, video name: {video_meta_data['video_name']})...")
+        else:
+            FrameRangeWarning(msg=f'Could not read frame {current_frm} in video {video_path}. Stopping video creation.')
+            break
+
+    cap.release()
+    if video_setting:
+        video_writer.release()
+    return batch
+
+
+class PlotSklearnResultsMultiProcess(ConfigReader, TrainModelMixin, PlottingMixin):
+    """
+    Plot classification results on videos using multiprocessing. Results are stored in the
+    `project_folder/frames/output/sklearn_results` directory of the SimBA project.
+
+    This class creates annotated videos/frames showing classifier predictions overlaid on pose-estimation data,
+    with optional Gantt charts, timers, and bounding boxes. Processing is parallelized across multiple CPU cores
+    for faster rendering of large video datasets.
+
+    .. seealso::
+       `Tutorial <https://github.com/sgoldenlab/simba/blob/master/docs/tutorial.md#step-10-sklearn-visualization>`__.
+       For single-core processing, see :meth:`simba.plotting.plot_clf_results.PlotSklearnResultsSingleCore`.
+
+    .. image:: _static/img/sklearn_visualization.gif
+       :width: 600
+       :align: center
+
+    .. video:: _static/img/T1.webm
+       :width: 1000
+       :autoplay:
+       :loop:
+
+    .. youtube:: Frq6mMcaHBc
+       :width: 640
+       :height: 480
+       :align: center
+
+    :param Union[str, os.PathLike] config_path: Path to SimBA project config file in Configparser format.
+    :param bool video_setting: If True, creates compressed MP4 videos. Default True.
+    :param bool frame_setting: If True, saves individual annotated frames as PNG images. Default False.
+    :param Optional[Union[List[Union[str, os.PathLike]], Union[str, os.PathLike]]] video_paths: Path(s) to video file(s) to process. If None, processes all videos found in the project's video directory. Default None.
+    :param bool rotate: If True, rotates output videos 90 degrees clockwise. Default False.
+    :param bool animal_names: If True, displays animal names on the video frames. Default False.
+    :param bool show_pose: If True, overlays pose-estimation keypoints on the video. Default True.
+    :param Optional[Union[int, float]] font_size: Font size for text overlays. If None, auto-computed based on video resolution. Default None.
+    :param Optional[Union[int, float]] space_size: Vertical spacing between text lines. If None, auto-computed. Default None.
+    :param Optional[Union[int, float]] text_thickness: Thickness of text characters. If None, uses default. Default None.
+    :param Optional[Union[int, float]] text_opacity: Opacity of text background (0.0-1.0). If None, defaults to 0.8. Default None.
+    :param Optional[Union[int, float]] circle_size: Radius of pose keypoint circles. If None, auto-computed based on video resolution. Default None.
+    :param Optional[str] pose_palette: Name of color palette for pose keypoints. Must be from :class:`simba.utils.enums.Options.PALETTE_OPTIONS_CATEGORICAL` or :class:`simba.utils.enums.Options.PALETTE_OPTIONS`. Default 'Set1'.
+    :param bool print_timers: If True, displays cumulative time for each classifier behavior on each frame. Default True.
+    :param bool show_bbox: If True, draws axis-aligned bounding boxes around detected animals. Default False.
+    :param Optional[int] show_gantt: If 1, appends static Gantt chart to video. If 2, appends dynamic Gantt chart that updates per frame. If None, no Gantt chart. Default None.
+    :param Tuple[int, int, int] text_clr: RGB color tuple for text foreground. Default (255, 255, 255) (white).
+    :param Tuple[int, int, int] text_bg_clr: RGB color tuple for text background. Default (0, 0, 0) (black).
+    :param bool gpu: If True, uses GPU acceleration for video concatenation (requires CUDA-capable GPU). Default False.
+    :param int core_cnt: Number of CPU cores to use for parallel processing. Pass -1 to use all available cores. Default -1.
+
+    :example:
+    >>> clf_plotter = PlotSklearnResultsMultiProcess(
+    ...     config_path='/Users/simon/Desktop/envs/simba/troubleshooting/beepboop174/project_folder/project_config.ini',
+    ...     video_setting=True,
+    ...     frame_setting=False,
+    ...     video_paths='Trial_10.mp4',
+    ...     rotate=False,
+    ...     show_pose=True,
+    ...     show_bbox=True,
+    ...     print_timers=True,
+    ...     show_gantt=1,
+    ...     core_cnt=5
+    ... )
+    >>> clf_plotter.run()
+    """
+
+    def __init__(self,
+                 config_path: Union[str, os.PathLike],
+                 video_setting: bool = True,
+                 frame_setting: bool = False,
+                 video_paths: Optional[Union[List[Union[str, os.PathLike]], Union[str, os.PathLike]]] = None,
+                 rotate: bool = False,
+                 animal_names: bool = False,
+                 show_pose: bool = True,
+                 show_confidence: bool = False,
+                 font_size: Optional[Union[int, float]] = None,
+                 space_size: Optional[Union[int, float]] = None,
+                 text_thickness: Optional[Union[int, float]] = None,
+                 text_opacity: Optional[Union[int, float]] = None,
+                 circle_size: Optional[Union[int, float]] = None,
+                 pose_palette: Optional[str] = 'Set1',
+                 print_timers: bool = True,
+                 show_bbox: bool = False,
+                 show_gantt: Optional[int] = None,
+                 text_clr: Tuple[int, int, int] = (255, 255, 255),
+                 text_bg_clr: Tuple[int, int, int] = (0, 0, 0),
+                 gpu: bool = False,
+                 verbose: bool = True,
+                 core_cnt: int = -1):
+
+
+        ConfigReader.__init__(self, config_path=config_path)
+        TrainModelMixin.__init__(self)
+        PlottingMixin.__init__(self)
+        log_event(logger_name=str(__class__.__name__), log_type=TagNames.CLASS_INIT.value, msg=self.create_log_msg_from_init_args(locals=locals()))
+        for i in [video_setting, frame_setting, rotate, print_timers, animal_names, show_pose, gpu, show_bbox, show_confidence]:
+            check_valid_boolean(value=i, source=self.__class__.__name__, raise_error=True)
+        if (not video_setting) and (not frame_setting):
+            raise NoSpecifiedOutputError(msg="Please choose to create a video and/or frames. SimBA found that you ticked neither video and/or frames", source=self.__class__.__name__)
+        if font_size is not None: check_float(name=f'{self.__class__.__name__} font_size', value=font_size, min_value=0.1)
+        if space_size is not None: check_float(name=f'{self.__class__.__name__} space_size', value=space_size, min_value=0.1)
+        if text_thickness is not None: check_float(name=f'{self.__class__.__name__} text_thickness', value=text_thickness, min_value=0.1)
+        if circle_size is not None: check_float(name=f'{self.__class__.__name__} text_thickness', value=circle_size, min_value=0.1)
+        if circle_size is not None: check_float(name=f'{self.__class__.__name__} text_thickness', value=circle_size, min_value=0.1)
+        if text_opacity is not None: check_float(name=f'{self.__class__.__name__} text_opacity', value=text_opacity, min_value=0.1)
+        pose_palettes = Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value
+        check_str(name=f'{self.__class__.__name__} pose_palette', value=pose_palette, options=pose_palettes)
+        self.clr_lst = create_color_palette(pallete_name=pose_palette, increments=len(self.body_parts_lst)+1)
+        check_if_valid_rgb_tuple(data=text_clr, source=f'{self.__class__.__name__} text_clr')
+        check_if_valid_rgb_tuple(data=text_bg_clr, source=f'{self.__class__.__name__} text_bg_clr')
+        check_valid_boolean(value=verbose, source=f'{self.__class__.__name__} verbose', raise_error=True)
+        if show_gantt is not None:
+            check_int(name=f"{self.__class__.__name__} show_gantt", value=show_gantt, max_value=2, min_value=1)
+        self.video_setting, self.frame_setting, self.rotate, self.print_timers = video_setting, frame_setting, rotate, print_timers
+        self.circle_size, self.font_size, self.animal_names, self.text_opacity = circle_size, font_size, animal_names, text_opacity
+        self.text_thickness, self.space_size, self.show_pose, self.pose_palette, self.verbose = text_thickness, space_size, show_pose, pose_palette, verbose
+        self.text_color, self.text_bg_color, self.show_bbox, self.show_gantt, self.show_confidence = text_clr, text_bg_clr, show_bbox, show_gantt, show_confidence
+        self.gpu = True if check_nvidea_gpu_available() and gpu else False
+        self.pose_threshold = read_config_entry(self.config, ConfigKey.THRESHOLD_SETTINGS.value, ConfigKey.SKLEARN_BP_PROB_THRESH.value, Dtypes.FLOAT.value, 0.00)
+        if not os.path.exists(self.sklearn_plot_dir):
+            os.makedirs(self.sklearn_plot_dir)
+        if isinstance(video_paths, str): self.video_paths = [video_paths]
+        elif isinstance(video_paths, list): self.video_paths = video_paths
+        elif video_paths is None:
+            self.video_paths = find_all_videos_in_project(videos_dir=self.video_dir)
+            if len(self.video_paths) == 0:
+                raise NoDataError(msg=f'Cannot create classification videos. No videos exist in {self.video_dir} directory', source=self.__class__.__name__)
+        else:
+            raise InvalidInputError(msg=f'video_paths has to be a path of a list of paths. Got {type(video_paths)}', source=self.__class__.__name__)
+
+        for video_path in self.video_paths:
+            video_name = get_fn_ext(filepath=video_path)[1]
+            data_path = os.path.join(self.machine_results_dir, f'{video_name}.{self.file_type}')
+            if not os.path.isfile(data_path): raise NoDataError(msg=f'Cannot create classification videos for {video_name}. Expected classification data at location {data_path} but file does not exist', source=self.__class__.__name__)
+        check_int(name=f'{self.__class__.__name__} core_cnt', value=core_cnt, min_value=-1, unaccepted_vals=[0])
+        self.core_cnt = find_core_cnt()[0] if int(core_cnt) == -1 or int(core_cnt) > find_core_cnt()[0] else int(core_cnt)
+        self.conf_cols = [f'Probability_{x}' for x in self.clf_names]
+        if platform.system() == "Darwin":
+            multiprocessing.set_start_method("spawn", force=True)
+
+    def __get_print_settings(self):
+        optimal_circle_size = self.get_optimal_circle_size(frame_size=(self.video_meta_data["width"], self.video_meta_data["height"]), circle_frame_ratio=100)
+        longest_str = str(max(['TIMERS:', 'ENSEMBLE PREDICTION:'] + self.clf_names, key=len))
+        self.video_text_thickness = TextOptions.TEXT_THICKNESS.value if self.text_thickness is None else int(max(self.text_thickness, 1))
+        optimal_font_size, _, optimal_spacing_scale = self.get_optimal_font_scales(text=longest_str, accepted_px_width=int(self.video_meta_data["width"] / 3), accepted_px_height=int(self.video_meta_data["height"] / 10), text_thickness=self.video_text_thickness)
+        self.video_circle_size = optimal_circle_size if self.circle_size is None else int(max(1, self.circle_size))
+        self.video_font_size = optimal_font_size if self.font_size is None else self.font_size
+        self.video_space_size = optimal_spacing_scale if self.space_size is None else int(max(self.space_size, 1))
+        self.video_text_opacity = 0.8 if self.text_opacity is None else float(self.text_opacity)
+
+    def run(self):
+        if self.verbose: print(f'Creating {len(self.video_paths)} classification visualization(s) using {self.core_cnt} cores... ({get_current_time()})')
+        self.pool = get_cpu_pool(core_cnt=self.core_cnt, source=self.__class__.__name__, )
+        for video_cnt, video_path in enumerate(self.video_paths):
+            video_timer = SimbaTimer(start=True)
+            _, self.video_name, _ = get_fn_ext(video_path)
+            if self.verbose: print(f"[{get_current_time()}] Creating classification visualization for video {self.video_name}...")
+            self.data_path = os.path.join(self.machine_results_dir, f'{self.video_name}.{self.file_type}')
+            self.data_df = read_df(self.data_path, self.file_type).reset_index(drop=True).fillna(0)
+            if self.show_pose: check_that_column_exist(df=self.data_df, column_name=self.bp_col_names, file_name=self.data_path)
+            if self.show_confidence: check_that_column_exist(df=self.data_df, column_name=self.conf_cols, file_name=self.data_path)
+            self.video_meta_data = get_video_meta_data(video_path=video_path)
+            height, width = deepcopy(self.video_meta_data["height"]), deepcopy(self.video_meta_data["width"])
+            self.save_path = os.path.join(self.sklearn_plot_dir, f"{self.video_name}.mp4")
+            self.video_frame_dir, self.video_temp_dir = None, None
+            if self.video_setting:
+                self.video_save_path = os.path.join(self.sklearn_plot_dir, f"{self.video_name}.mp4")
+                self.video_temp_dir = os.path.join(self.sklearn_plot_dir, self.video_name, "temp")
+                create_directory(paths=self.video_temp_dir, overwrite=True)
+            if self.frame_setting:
+                self.video_frame_dir = os.path.join(self.sklearn_plot_dir, self.video_name)
+                create_directory(paths=self.video_temp_dir, overwrite=True)
+            if self.rotate:
+                self.video_meta_data["height"], self.video_meta_data["width"] = (width, height)
+            check_video_and_data_frm_count_align(video=video_path, data=self.data_df, name=self.video_name, raise_error=False)
+            check_that_column_exist(df=self.data_df, column_name=self.clf_names, file_name=self.data_path)
+            self.__get_print_settings()
+            if self.show_gantt is not None:
+                self.gantt_clrs = create_color_palette(pallete_name=self.pose_palette, increments=len(self.clf_names) + 1, as_int=True, as_rgb_ratio=True)
+                self.bouts_df = detect_bouts(data_df=self.data_df, target_lst=list(self.clf_names), fps=int(self.video_meta_data["fps"]))
+                self.final_gantt_img = PlottingMixin().make_gantt_plot(x_length=len(self.data_df) + 1, bouts_df=self.bouts_df, clf_names=self.clf_names, fps=self.video_meta_data["fps"], width=self.video_meta_data["width"], height=self.video_meta_data["height"], font_size=12, font_rotation=90, video_name=self.video_meta_data["video_name"], save_path=None, palette=self.gantt_clrs)
+                self.final_gantt_img = self.resize_gantt(self.final_gantt_img, self.video_meta_data["height"])
+            else:
+                self.bouts_df, self.final_gantt_img, self.gantt_clrs = None, None, None
+
+
+            self.clf_cumsums, self.clf_p = {}, {} if self.show_confidence else None
+            for clf_name in self.clf_names:
+                self.clf_cumsums[clf_name] = self.data_df[clf_name].cumsum()
+                if self.show_confidence: self.clf_p[clf_name] = np.round(self.data_df[f'Probability_{clf_name}'].values.reshape(-1), 4)
+
+            self.data_df["index"] = self.data_df.index
+            data = np.array_split(self.data_df, self.core_cnt)
+            data = [(cnt, x) for (cnt, x) in enumerate(data)]
+
+            constants = functools.partial(_multiprocess_sklearn_video,
+                                          bp_dict=self.animal_bp_dict,
+                                          video_save_dir=self.video_temp_dir,
+                                          frame_save_dir=self.video_frame_dir,
+                                          clf_cumsum=self.clf_cumsums,
+                                          rotate=self.rotate,
+                                          video_path=video_path,
+                                          clf_confidence=self.clf_p,
+                                          print_timers=self.print_timers,
+                                          video_setting=self.video_setting,
+                                          frame_setting=self.frame_setting,
+                                          pose_threshold=self.pose_threshold,
+                                          show_pose=self.show_pose,
+                                          show_animal_names=self.animal_names,
+                                          circle_size=self.video_circle_size,
+                                          font_size=self.video_font_size,
+                                          space_size=self.video_space_size,
+                                          text_thickness=self.video_text_thickness,
+                                          text_opacity=self.video_text_opacity,
+                                          text_bg_clr=self.text_bg_color,
+                                          text_color=self.text_color,
+                                          pose_clr_lst=self.clr_lst,
+                                          show_bbox=self.show_bbox,
+                                          show_gantt=self.show_gantt,
+                                          bouts_df=self.bouts_df,
+                                          final_gantt=self.final_gantt_img,
+                                          gantt_clrs=self.gantt_clrs,
+                                          clf_names=self.clf_names,
+                                          verbose=self.verbose)
+
+            for cnt, result in enumerate(self.pool.imap(constants, data, chunksize=self.multiprocess_chunksize)):
+                if self.verbose: print(f"[{get_current_time()}] Image batch {result} complete, Video {(video_cnt + 1)}/{len(self.video_paths)}...")
+
+            if self.video_setting:
+                if self.verbose: print(f"Joining {self.video_name} multiprocessed video...")
+                concatenate_videos_in_folder(in_folder=self.video_temp_dir, save_path=self.video_save_path, gpu=self.gpu, verbose=self.verbose)
+            video_timer.stop_timer()
+            print(f"Video {self.video_name} complete (elapsed time: {video_timer.elapsed_time_str}s)...")
+
+        terminate_cpu_pool(pool=self.pool, force=False)
+        self.timer.stop_timer()
+        if self.video_setting:
+            stdout_success(msg=f"{len(self.video_paths)} video(s) saved in {self.sklearn_plot_dir} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
+        if self.frame_setting:
+            stdout_success(f"Frames for {len(self.video_paths)} videos saved in sub-folders within {self.sklearn_plot_dir} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
+
+
+
+if __name__ == "__main__":
+    clf_plotter = PlotSklearnResultsMultiProcess(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini",
+                                                 video_setting=True,
+                                                 frame_setting=False,
+                                                 video_paths=None,
+                                                 print_timers=True,
+                                                 rotate=False,
+                                                 core_cnt=21,
+                                                 animal_names=False,
+                                                 show_bbox=True,
+                                                 show_gantt=None)
+    clf_plotter.run()
+
+
+
+
+# if __name__ == "__main__":
+#     clf_plotter = PlotSklearnResultsMultiProcess(config_path=r"/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini",
+#                                                  video_setting=True,
+#                                                  frame_setting=False,
+#                                                  video_paths=r"/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/videos/Together_1.mp4",
+#                                                  print_timers=True,
+#                                                  rotate=False,
+#                                                  animal_names=False,
+#                                                  show_bbox=True,
+#                                                  show_gantt=None)
+#     clf_plotter.run()
+
+
+
+
+#text_settings = {'circle_scale': 5, 'font_size': 0.528, 'spacing_scale': 28, 'text_thickness': 2}
+# clf_plotter = PlotSklearnResultsMultiProcess(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/beepboop174/project_folder/project_config.ini',
+#                                              video_setting=True,
+#                                              frame_setting=False,
+#                                              rotate=False,
+#                                              video_file_path='592_MA147_Gq_CNO_0515.mp4',
+#                                              cores=-1,
+#                                              text_settings=False)
+# clf_plotter.run()
+#
+
+# clf_plotter = PlotSklearnResultsMultiProcess(config_path='/Users/simon/Desktop/envs/troubleshooting/DLC_2_Black_animals/project_folder/project_config.ini', video_setting=True, frame_setting=False, rotate=False, video_file_path='Together_1.avi', cores=5)
+# clf_plotter.run()

+# if __name__ == "__main__":
+#     clf_plotter = PlotSklearnResultsMultiProcess(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
+#                                                  video_setting = True,
+#                                                  frame_setting = False,
+#                                                  rotate = False,
+#                                                  core_cnt = 6,
+#                                                  show_confidence=True,
+#                                                  video_paths=r"C:\troubleshooting\mitra\project_folder\videos\501_MA142_Gi_CNO_0521.mp4")
 # clf_plotter.run()
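
The substantive change in this file between 4.6.1 and 4.6.3 is the worker-pool lifecycle: 4.6.1 opened a fresh `multiprocessing.Pool` inside the per-video loop and tore it down after each video, while 4.6.3 obtains a single pool from the new `get_cpu_pool` helper before the loop and hands it back to `terminate_cpu_pool` once all videos are rendered, so worker processes are reused across videos. Below is a minimal sketch of the same pattern using only the Python standard library; the `render_batch` stand-in and the batch layout are illustrative, not SimBA's API.

import multiprocessing
from typing import List


def render_batch(batch_id: int) -> int:
    # Stand-in for _multiprocess_sklearn_video: renders one chunk of frames
    # and returns the batch identifier so the parent can report progress.
    return batch_id


def render_videos(video_batches: List[List[int]], core_cnt: int) -> None:
    # 4.6.3 pattern: create the pool once (cf. get_cpu_pool in
    # simba.utils.data), reuse it for every video, and only shut it down
    # after the last video (cf. terminate_cpu_pool(pool=..., force=False)).
    pool = multiprocessing.Pool(processes=core_cnt)
    try:
        for batches in video_batches:
            for batch_id in pool.imap(render_batch, batches, chunksize=1):
                print(f"Image batch {batch_id} complete...")
    finally:
        pool.close()  # stop accepting new work, then wait for workers to exit
        pool.join()


if __name__ == "__main__":
    render_videos(video_batches=[[0, 1, 2], [0, 1, 2]], core_cnt=2)

Under the 4.6.1 arrangement the pool setup and teardown cost was paid once per video; reusing one pool amortizes process startup across the whole run, which is most noticeable on platforms using the spawn start method (as this module forces on macOS).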