simba-uw-tf-dev 4.6.2-py3-none-any.whl → 4.7.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. simba/assets/.recent_projects.txt +1 -0
  2. simba/assets/lookups/tooptips.json +6 -1
  3. simba/data_processors/agg_clf_counter_mp.py +52 -53
  4. simba/data_processors/blob_location_computer.py +1 -1
  5. simba/data_processors/circling_detector.py +30 -13
  6. simba/data_processors/cuda/geometry.py +45 -27
  7. simba/data_processors/cuda/image.py +1648 -1598
  8. simba/data_processors/cuda/statistics.py +72 -26
  9. simba/data_processors/cuda/timeseries.py +1 -1
  10. simba/data_processors/cue_light_analyzer.py +5 -9
  11. simba/data_processors/egocentric_aligner.py +25 -7
  12. simba/data_processors/freezing_detector.py +55 -47
  13. simba/data_processors/kleinberg_calculator.py +61 -29
  14. simba/feature_extractors/feature_subsets.py +14 -7
  15. simba/feature_extractors/mitra_feature_extractor.py +2 -2
  16. simba/feature_extractors/straub_tail_analyzer.py +4 -6
  17. simba/labelling/standard_labeller.py +1 -1
  18. simba/mixins/config_reader.py +5 -2
  19. simba/mixins/geometry_mixin.py +22 -36
  20. simba/mixins/image_mixin.py +24 -28
  21. simba/mixins/plotting_mixin.py +28 -10
  22. simba/mixins/statistics_mixin.py +48 -11
  23. simba/mixins/timeseries_features_mixin.py +1 -1
  24. simba/mixins/train_model_mixin.py +67 -29
  25. simba/model/inference_batch.py +1 -1
  26. simba/model/yolo_seg_inference.py +3 -3
  27. simba/outlier_tools/skip_outlier_correction.py +1 -1
  28. simba/plotting/ROI_feature_visualizer_mp.py +3 -5
  29. simba/plotting/clf_validator_mp.py +4 -5
  30. simba/plotting/cue_light_visualizer.py +6 -7
  31. simba/plotting/directing_animals_visualizer_mp.py +2 -3
  32. simba/plotting/distance_plotter_mp.py +378 -378
  33. simba/plotting/gantt_creator.py +29 -10
  34. simba/plotting/gantt_creator_mp.py +96 -33
  35. simba/plotting/geometry_plotter.py +270 -272
  36. simba/plotting/heat_mapper_clf_mp.py +4 -6
  37. simba/plotting/heat_mapper_location_mp.py +2 -2
  38. simba/plotting/light_dark_box_plotter.py +2 -2
  39. simba/plotting/path_plotter_mp.py +26 -29
  40. simba/plotting/plot_clf_results_mp.py +455 -454
  41. simba/plotting/pose_plotter_mp.py +28 -29
  42. simba/plotting/probability_plot_creator_mp.py +288 -288
  43. simba/plotting/roi_plotter_mp.py +31 -31
  44. simba/plotting/single_run_model_validation_video_mp.py +427 -427
  45. simba/plotting/spontaneous_alternation_plotter.py +2 -3
  46. simba/plotting/yolo_pose_track_visualizer.py +32 -27
  47. simba/plotting/yolo_pose_visualizer.py +35 -36
  48. simba/plotting/yolo_seg_visualizer.py +2 -3
  49. simba/pose_importers/simba_blob_importer.py +3 -3
  50. simba/roi_tools/roi_aggregate_stats_mp.py +5 -4
  51. simba/roi_tools/roi_clf_calculator_mp.py +4 -4
  52. simba/sandbox/analyze_runtimes.py +30 -0
  53. simba/sandbox/cuda/egocentric_rotator.py +374 -374
  54. simba/sandbox/get_cpu_pool.py +5 -0
  55. simba/sandbox/proboscis_to_tip.py +28 -0
  56. simba/sandbox/test_directionality.py +47 -0
  57. simba/sandbox/test_nonstatic_directionality.py +27 -0
  58. simba/sandbox/test_pycharm_cuda.py +51 -0
  59. simba/sandbox/test_simba_install.py +41 -0
  60. simba/sandbox/test_static_directionality.py +26 -0
  61. simba/sandbox/test_static_directionality_2d.py +26 -0
  62. simba/sandbox/verify_env.py +42 -0
  63. simba/third_party_label_appenders/transform/coco_keypoints_to_yolo.py +3 -3
  64. simba/third_party_label_appenders/transform/coco_keypoints_to_yolo_bbox.py +2 -2
  65. simba/ui/pop_ups/clf_plot_pop_up.py +2 -2
  66. simba/ui/pop_ups/fsttc_pop_up.py +27 -25
  67. simba/ui/pop_ups/gantt_pop_up.py +31 -6
  68. simba/ui/pop_ups/kleinberg_pop_up.py +39 -40
  69. simba/ui/pop_ups/video_processing_pop_up.py +37 -29
  70. simba/ui/tkinter_functions.py +3 -0
  71. simba/utils/custom_feature_extractor.py +1 -1
  72. simba/utils/data.py +90 -14
  73. simba/utils/enums.py +1 -0
  74. simba/utils/errors.py +441 -440
  75. simba/utils/lookups.py +1203 -1203
  76. simba/utils/printing.py +124 -124
  77. simba/utils/read_write.py +3769 -3721
  78. simba/utils/yolo.py +10 -1
  79. simba/video_processors/blob_tracking_executor.py +2 -2
  80. simba/video_processors/clahe_ui.py +1 -1
  81. simba/video_processors/egocentric_video_rotator.py +44 -41
  82. simba/video_processors/multi_cropper.py +1 -1
  83. simba/video_processors/video_processing.py +5264 -5222
  84. simba/video_processors/videos_to_frames.py +43 -33
  85. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/METADATA +4 -3
  86. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/RECORD +90 -80
  87. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/LICENSE +0 -0
  88. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/WHEEL +0 -0
  89. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/entry_points.txt +0 -0
  90. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/top_level.txt +0 -0
@@ -1,427 +1,427 @@
- __author__ = "Simon Nilsson; sronilsson@gmail.com"
-
- import warnings
-
- warnings.filterwarnings("ignore", category=FutureWarning)
- warnings.filterwarnings("ignore", category=DeprecationWarning)
- import functools
- import multiprocessing
- import os
- import platform
- from copy import deepcopy
- from typing import List, Optional, Tuple, Union
-
- import cv2
- import imutils
- import pandas as pd
-
- try:
-     from typing import Literal
- except:
-     from typing_extensions import Literal
-
- import matplotlib
- import matplotlib.pyplot as plt
- import numpy as np
- from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
-
- from simba.mixins.config_reader import ConfigReader
- from simba.mixins.geometry_mixin import GeometryMixin
- from simba.mixins.plotting_mixin import PlottingMixin
- from simba.mixins.train_model_mixin import TrainModelMixin
- from simba.utils.checks import (check_file_exist_and_readable, check_float,
-                                 check_int, check_str, check_valid_boolean,
-                                 check_video_and_data_frm_count_align)
- from simba.utils.data import create_color_palette, plug_holes_shortest_bout
- from simba.utils.enums import Options, TextOptions
- from simba.utils.printing import SimbaTimer, stdout_success
- from simba.utils.read_write import (concatenate_videos_in_folder,
-                                     create_directory, find_core_cnt,
-                                     get_fn_ext, get_video_meta_data, read_df,
-                                     read_pickle, write_df)
- from simba.utils.warnings import FrameRangeWarning, NoDataFoundWarning
-
-
- def _validation_video_mp(data: pd.DataFrame,
-                          bp_dict: dict,
-                          video_save_dir: str,
-                          video_path: str,
-                          text_thickness: int,
-                          text_opacity: float,
-                          font_size: int,
-                          text_spacing: int,
-                          circle_size: int,
-                          show_pose: bool,
-                          show_animal_bounding_boxes: bool,
-                          show_animal_names: bool,
-                          gantt_setting: Union[int, None],
-                          final_gantt: Optional[np.ndarray],
-                          clf_data: np.ndarray,
-                          clrs: List[List],
-                          clf_name: str,
-                          bouts_df: pd.DataFrame,
-                          conf_data: np.ndarray):
-
-     def _put_text(img: np.ndarray,
-                   text: str,
-                   pos: Tuple[int, int],
-                   font_size: int,
-                   font_thickness: Optional[int] = 2,
-                   font: Optional[int] = cv2.FONT_HERSHEY_DUPLEX,
-                   text_color: Optional[Tuple[int, int, int]] = (255, 255, 255),
-                   text_color_bg: Optional[Tuple[int, int, int]] = (0, 0, 0),
-                   text_bg_alpha: float = 0.8):
-
-         x, y = pos
-         text_size, px_buffer = cv2.getTextSize(text, font, font_size, font_thickness)
-         w, h = text_size
-         overlay, output = img.copy(), img.copy()
-         cv2.rectangle(overlay, (x, y-h), (x + w, y + px_buffer), text_color_bg, -1)
-         cv2.addWeighted(overlay, text_bg_alpha, output, 1 - text_bg_alpha, 0, output)
-         cv2.putText(output, text, (x, y), font, font_size, text_color, font_thickness)
-         return output
-
-
-     def _create_gantt(bouts_df: pd.DataFrame,
-                       clf_name: str,
-                       image_index: int,
-                       fps: int,
-                       header_font_size: int = 24,
-                       label_font_size: int = 12):
-
-         fig, ax = plt.subplots(figsize=(final_gantt.shape[1] / dpi, final_gantt.shape[0] / dpi))
-         matplotlib.font_manager._get_font.cache_clear()
-         relRows = bouts_df.loc[bouts_df["End_frame"] <= image_index]
-         for i, event in enumerate(relRows.groupby("Event")):
-             data_event = event[1][["Start_time", "Bout_time"]]
-             ax.broken_barh(data_event.values, (4, 4), facecolors="red")
-         xLength = (round(image_index / fps)) + 1
-         if xLength < 10:
-             xLength = 10
-
-         ax.set_xlim(0, xLength)
-         ax.set_ylim([0, 12])
-         ax.set_xlabel("Session (s)", fontsize=label_font_size)
-         ax.set_ylabel(clf_name, fontsize=label_font_size)
-         ax.set_title(f"{clf_name} GANTT CHART", fontsize=header_font_size)
-         ax.set_yticks([])
-         ax.yaxis.set_ticklabels([])
-         ax.yaxis.grid(True)
-         canvas = FigureCanvas(fig)
-         canvas.draw()
-         img = np.array(np.uint8(np.array(canvas.renderer._renderer)))[:, :, :3]
-         plt.close(fig)
-         return img
-
-     dpi = plt.rcParams["figure.dpi"]
-     fourcc, font = cv2.VideoWriter_fourcc(*"mp4v"), cv2.FONT_HERSHEY_DUPLEX
-     cap = cv2.VideoCapture(video_path)
-     video_meta_data = get_video_meta_data(video_path=video_path, fps_as_int=False)
-     batch_id, batch_data = data[0], data[1]
-     start_frm, current_frm, end_frm = batch_data.index[0], batch_data.index[0], batch_data.index[-1]
-     video_save_path = os.path.join(video_save_dir, f"{batch_id}.mp4")
-     if gantt_setting is not None:
-         video_size = (int(video_meta_data["width"] + final_gantt.shape[1]), int(video_meta_data["height"]))
-         writer = cv2.VideoWriter(video_save_path, fourcc, video_meta_data["fps"], video_size)
-     else:
-         video_size = (int(video_meta_data["width"]), int(video_meta_data["height"]))
-         writer = cv2.VideoWriter(video_save_path, fourcc, video_meta_data["fps"], video_size)
-     cap.set(1, start_frm)
-     while (current_frm <= end_frm) & (current_frm <= video_meta_data["frame_count"]):
-         clf_frm_cnt = np.sum(clf_data[0:current_frm])
-         ret, img = cap.read()
-         if ret:
-             frm_timer = SimbaTimer(start=True)
-             if show_pose:
-                 for animal_cnt, (animal_name, animal_data) in enumerate(bp_dict.items()):
-                     for bp_cnt, bp in enumerate(range(len(animal_data["X_bps"]))):
-                         x_header, y_header = (animal_data["X_bps"][bp], animal_data["Y_bps"][bp])
-                         animal_cords = tuple(batch_data.loc[current_frm, [x_header, y_header]])
-                         cv2.circle(img, (int(animal_cords[0]), int(animal_cords[1])), circle_size, clrs[animal_cnt][bp_cnt], -1)
-             if show_animal_names:
-                 for animal_cnt, (animal_name, animal_data) in enumerate(bp_dict.items()):
-                     x_header, y_header = (animal_data["X_bps"][0], animal_data["Y_bps"][0],)
-                     animal_cords = tuple(batch_data.loc[current_frm, [x_header, y_header]])
-                     cv2.putText(img, animal_name, (int(animal_cords[0]), int(animal_cords[1])), font, font_size, clrs[animal_cnt][0], text_thickness)
-             if show_animal_bounding_boxes:
-                 for animal_cnt, (animal_name, animal_data) in enumerate(bp_dict.items()):
-                     animal_headers = [val for pair in zip(animal_data["X_bps"], animal_data["Y_bps"]) for val in pair]
-                     animal_cords = batch_data.loc[current_frm, animal_headers].values.reshape(-1, 2).astype(np.int32)
-                     try:
-                         bbox = GeometryMixin().keypoints_to_axis_aligned_bounding_box(keypoints=animal_cords.reshape(-1, len(animal_cords), 2).astype(np.int32))
-                         cv2.polylines(img, [bbox], True, clrs[animal_cnt][0], thickness=text_thickness, lineType=-1)
-                     except:
-                         pass
-             target_timer = round((1 / video_meta_data["fps"]) * clf_frm_cnt, 2)
-             img = _put_text(img=img, text="BEHAVIOR TIMER:", pos=(TextOptions.BORDER_BUFFER_Y.value, text_spacing), font_size=font_size, font_thickness=TextOptions.TEXT_THICKNESS.value)
-             addSpacer = 2
-             img = _put_text(img=img, text=f"{clf_name} {target_timer}s", pos=(TextOptions.BORDER_BUFFER_Y.value, text_spacing * addSpacer), font_size=font_size, font_thickness=TextOptions.TEXT_THICKNESS.value, text_bg_alpha=text_opacity)
-             addSpacer += 1
-             if conf_data is not None:
-                 img = _put_text(img=img, text=f"{clf_name} PROBABILITY: {round(conf_data[current_frm], 4)}", pos=(TextOptions.BORDER_BUFFER_Y.value, text_spacing * addSpacer), font_size=font_size, font_thickness=TextOptions.TEXT_THICKNESS.value, text_bg_alpha=text_opacity)
-                 addSpacer += 1
-             img = _put_text(img=img, text="ENSEMBLE PREDICTION:", pos=(TextOptions.BORDER_BUFFER_Y.value, text_spacing * addSpacer), font_size=font_size, font_thickness=TextOptions.TEXT_THICKNESS.value, text_bg_alpha=text_opacity)
-             addSpacer += 1
-             if clf_data[current_frm] == 1:
-                 img = _put_text(img=img, text=clf_name, pos=(TextOptions.BORDER_BUFFER_Y.value, text_spacing * addSpacer), font_size=font_size, font_thickness=TextOptions.TEXT_THICKNESS.value, text_color=TextOptions.COLOR.value, text_bg_alpha=text_opacity)
-                 addSpacer += 1
-             if gantt_setting == 1:
-                 img = np.concatenate((img, final_gantt), axis=1)
-             elif gantt_setting == 2:
-                 gantt_img = _create_gantt(bouts_df, clf_name, current_frm, video_meta_data["fps"], header_font_size=9, label_font_size=12)
-                 gantt_img = imutils.resize(gantt_img, height=video_meta_data["height"])
-                 img = np.concatenate((img, gantt_img), axis=1)
-             img = cv2.resize(img, video_size, interpolation=cv2.INTER_LINEAR)
-             writer.write(np.uint8(img))
-             current_frm += 1
-             frm_timer.stop_timer()
-             print(f"Multi-processing video frame {current_frm} on core {batch_id}...(elapsed time: {frm_timer.elapsed_time_str}s)")
-         else:
-             FrameRangeWarning(msg=f'Frame {current_frm} could not be read in video {video_path}. The video contains {video_meta_data["frame_count"]} frames while the data file contains data for {len(batch_data)} frames. Consider re-encoding the video, or make sure the pose-estimation data and associated video contains the same number of frames. ', source=_validation_video_mp.__name__)
-             break
-
-     cap.release()
-     writer.release()
-     return batch_id
-
-
- class ValidateModelOneVideoMultiprocess(ConfigReader, PlottingMixin, TrainModelMixin):
-     """
-     Create classifier validation video for a single input video using multiprocessing for improved performance.
-
-     This class generates validation videos that overlay classifier predictions, pose estimations, and
-     optional Gantt charts onto the original video using multiple CPU cores for faster processing.
-     Results are stored in the `project_folder/frames/output/validation` directory.
-
-     .. note::
-        This multiprocess version provides significant speed improvements over the single-core
-        :class:`simba.plotting.single_run_model_validation_video.ValidateModelOneVideo` class.
-
-     :param Union[str, os.PathLike] config_path: Path to SimBA project config file in Configparser format.
-     :param Union[str, os.PathLike] feature_path: Path to SimBA file (parquet or CSV) containing pose-estimation and feature data.
-     :param Union[str, os.PathLike] model_path: Path to pickled classifier object (.sav file).
-     :param bool show_pose: If True, overlay pose estimation keypoints on the video. Default: True.
-     :param bool show_animal_names: If True, display animal names near the first body part. Default: False.
-     :param Optional[int] font_size: Font size for text overlays. If None, automatically calculated based on video dimensions.
-     :param Optional[str] bp_palette: Optional name of the palette to use to color the animal body-parts (e.g., Pastel1). If None, ``spring`` is used.
-
-
-     :param Optional[int] circle_size: Size of pose estimation circles. If None, automatically calculated based on video dimensions.
-     :param Optional[int] text_spacing: Spacing between text lines. If None, automatically calculated.
-     :param Optional[int] text_thickness: Thickness of text overlay. If None, uses default value.
-     :param Optional[float] text_opacity: Opacity of text overlays (0.1-1.0). If None, defaults to 0.8.
-     :param float discrimination_threshold: Classification probability threshold (0.0-1.0). Default: 0.0.
-     :param int shortest_bout: Minimum classified bout length in milliseconds. Bouts shorter than this will be reclassified as absent. Default: 0.
-     :param int core_cnt: Number of CPU cores to use for processing. If -1, uses all available cores. Default: -1.
-     :param Optional[Union[None, int]] create_gantt: Gantt chart creation option:
-
-         - None: No Gantt chart
-         - 1: Static Gantt chart (final frame only, faster)
-         - 2: Dynamic Gantt chart (updated per frame)
-
-
-     .. youtube:: UOLSj7DGKRo
-        :width: 640
-        :height: 480
-        :align: center
-
-     .. video:: _static/img/T1.webm
-        :width: 1000
-        :autoplay:
-        :loop:
-
-     :example:
-     >>> # Create multiprocess validation video with dynamic Gantt chart
-     >>> validator = ValidateModelOneVideoMultiprocess(
-     ...     config_path=r'/path/to/project_config.ini',
-     ...     feature_path=r'/path/to/features.csv',
-     ...     model_path=r'/path/to/classifier.sav',
-     ...     show_pose=True,
-     ...     show_animal_names=True,
-     ...     discrimination_threshold=0.6,
-     ...     shortest_bout=500,
-     ...     core_cnt=4,
-     ...     create_gantt=2
-     ... )
-     >>> validator.run()
-     """
-
-     def __init__(self,
-                  config_path: Union[str, os.PathLike],
-                  feature_path: Union[str, os.PathLike],
-                  model_path: Union[str, os.PathLike],
-                  show_pose: bool = True,
-                  show_animal_names: bool = False,
-                  show_animal_bounding_boxes: bool = False,
-                  show_clf_confidence: bool = False,
-                  font_size: Optional[bool] = None,
-                  circle_size: Optional[int] = None,
-                  text_spacing: Optional[int] = None,
-                  text_thickness: Optional[int] = None,
-                  text_opacity: Optional[float] = None,
-                  bp_palette: Optional[str] = None,
-                  discrimination_threshold: float = 0.0,
-                  shortest_bout: int = 0.0,
-                  core_cnt: int = -1,
-                  create_gantt: Optional[Union[None, int]] = None):
-
-
-         ConfigReader.__init__(self, config_path=config_path)
-         PlottingMixin.__init__(self)
-         TrainModelMixin.__init__(self)
-         check_file_exist_and_readable(file_path=config_path)
-         check_file_exist_and_readable(file_path=feature_path)
-         check_file_exist_and_readable(file_path=model_path)
-         check_valid_boolean(value=[show_pose], source=f'{self.__class__.__name__} show_pose', raise_error=True)
-         check_valid_boolean(value=[show_animal_names], source=f'{self.__class__.__name__} show_animal_names', raise_error=True)
-         check_valid_boolean(value=[show_animal_bounding_boxes], source=f'{self.__class__.__name__} show_animal_bounding_boxes', raise_error=True)
-         check_valid_boolean(value=[show_clf_confidence], source=f'{self.__class__.__name__} show_clf_confidence', raise_error=True)
-         check_int(name=f"{self.__class__.__name__} core_cnt", value=core_cnt, min_value=-1, unaccepted_vals=[0])
-         if font_size is not None: check_int(name=f'{self.__class__.__name__} font_size', value=font_size)
-         if circle_size is not None: check_int(name=f'{self.__class__.__name__} circle_size', value=circle_size)
-         if text_spacing is not None: check_int(name=f'{self.__class__.__name__} text_spacing', value=text_spacing)
-         if text_opacity is not None: check_float(name=f'{self.__class__.__name__} text_opacity', value=text_opacity, min_value=0.1)
-         if text_thickness is not None: check_float(name=f'{self.__class__.__name__} text_thickness', value=text_thickness, min_value=0.1)
-         check_float(name=f"{self.__class__.__name__} discrimination_threshold", value=discrimination_threshold, min_value=0, max_value=1.0)
-         check_int(name=f"{self.__class__.__name__} shortest_bout", value=shortest_bout, min_value=0)
-         if create_gantt is not None:
-             check_int(name=f"{self.__class__.__name__} create gantt", value=create_gantt, max_value=2, min_value=1)
-         if not os.path.exists(self.single_validation_video_save_dir):
-             os.makedirs(self.single_validation_video_save_dir)
-         if bp_palette is not None:
-             self.bp_palette = []
-             check_str(name=f'{self.__class__.__name__} bp_palette', value=bp_palette, options=(Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value))
-             for animal in range(self.animal_cnt):
-                 self.bp_palette.append(create_color_palette(pallete_name=bp_palette, increments=(int(len(self.body_parts_lst)/self.animal_cnt) +1), as_int=True))
-         else:
-             self.bp_palette = deepcopy(self.clr_lst)
-         _, self.feature_filename, ext = get_fn_ext(feature_path)
-         self.video_path = self.find_video_of_file(self.video_dir, self.feature_filename)
-         self.video_meta_data = get_video_meta_data(video_path=self.video_path, fps_as_int=False)
-         self.clf_name, self.feature_file_path = (os.path.basename(model_path).replace(".sav", ""), feature_path)
-         self.vid_output_path = os.path.join(self.single_validation_video_save_dir, f"{self.feature_filename} {self.clf_name}.mp4")
-         self.clf_data_save_path = os.path.join(self.clf_data_validation_dir, f"{self.feature_filename }.csv")
-         self.show_pose, self.show_animal_names = show_pose, show_animal_names
-         self.font_size, self.circle_size, self.text_spacing, self.show_clf_confidence = font_size, circle_size, text_spacing, show_clf_confidence
-         self.text_opacity, self.text_thickness, self.show_animal_bounding_boxes = text_opacity, text_thickness, show_animal_bounding_boxes
-         self.clf = read_pickle(data_path=model_path, verbose=True)
-         self.data_df = read_df(feature_path, self.file_type)
-         self.x_df = self.drop_bp_cords(df=self.data_df)
-         self.discrimination_threshold, self.shortest_bout, self.create_gantt = float(discrimination_threshold), shortest_bout, create_gantt
-         check_video_and_data_frm_count_align(video=self.video_path, data=self.data_df, name=self.feature_filename, raise_error=False)
-         self.core_cnt = find_core_cnt()[0] if core_cnt == -1 or core_cnt > find_core_cnt()[0] else core_cnt
-         self.temp_dir = os.path.join(self.single_validation_video_save_dir, "temp")
-         self.video_save_path = os.path.join(self.single_validation_video_save_dir, f"{self.feature_filename}.mp4")
-         create_directory(paths=self.temp_dir, overwrite=True)
-         if platform.system() == "Darwin":
-             multiprocessing.set_start_method("spawn", force=True)
-
-     def _get_styles(self):
-         self.video_text_thickness = TextOptions.TEXT_THICKNESS.value if self.text_thickness is None else int(max(self.text_thickness, 1))
-         longest_str = str(max(['TIMERS:', 'ENSEMBLE PREDICTION:'] + self.clf_names, key=len))
-         optimal_font_size, _, optimal_spacing_scale = self.get_optimal_font_scales(text=longest_str, accepted_px_width=int(self.video_meta_data["width"] / 3), accepted_px_height=int(self.video_meta_data["height"] / 10), text_thickness=self.video_text_thickness)
-         optimal_circle_size = self.get_optimal_circle_size(frame_size=(self.video_meta_data["width"], self.video_meta_data["height"]), circle_frame_ratio=100)
-         self.video_circle_size = optimal_circle_size if self.circle_size is None else int(self.circle_size)
-         self.video_font_size = optimal_font_size if self.font_size is None else self.font_size
-         self.video_space_size = optimal_spacing_scale if self.text_spacing is None else int(max(self.text_spacing, 1))
-         self.video_text_opacity = 0.8 if self.text_opacity is None else float(self.text_opacity)
-
-     def run(self):
-         self.prob_col_name = f"Probability_{self.clf_name}"
-         self.data_df[self.prob_col_name] = self.clf_predict_proba(clf=self.clf, x_df=self.x_df, model_name=self.clf_name, data_path=self.feature_file_path)
-         self.data_df[self.clf_name] = np.where(self.data_df[self.prob_col_name] > self.discrimination_threshold, 1, 0)
-         if self.shortest_bout > 1:
-             self.data_df = plug_holes_shortest_bout(data_df=self.data_df, clf_name=self.clf_name, fps=self.video_meta_data['fps'], shortest_bout=self.shortest_bout)
-         _ = write_df(df=self.data_df, file_type=self.file_type, save_path=self.clf_data_save_path)
-         print(f"Predictions created for video {self.feature_filename} (creating video, follow progressin OS terminal)...")
-         self._get_styles()
-         if self.create_gantt is not None:
-             self.bouts_df = self.get_bouts_for_gantt(data_df=self.data_df, clf_name=self.clf_name, fps=self.video_meta_data['fps'])
-             self.final_gantt_img = self.create_gantt_img(self.bouts_df ,self.clf_name,len(self.data_df), self.video_meta_data['fps'],f"Behavior gantt chart (entire session, length (s): {self.video_meta_data['video_length_s']}, frames: {self.video_meta_data['frame_count']})", header_font_size=9, label_font_size=12)
-             self.final_gantt_img = self.resize_gantt(self.final_gantt_img, self.video_meta_data["height"])
-         else:
-             self.bouts_df, self.final_gantt_img = None, None
-         conf_data = self.data_df[self.prob_col_name].values if self.show_clf_confidence else None
-
-         self.data_df = self.data_df.head(min(len(self.data_df), self.video_meta_data["frame_count"]))
-         data = np.array_split(self.data_df, self.core_cnt)
-         data = [(i, j) for i, j in enumerate(data)]
-
-         with multiprocessing.Pool(self.core_cnt, maxtasksperchild=self.maxtasksperchild) as pool:
-             constants = functools.partial(_validation_video_mp,
-                                           bp_dict=self.animal_bp_dict,
-                                           video_save_dir=self.temp_dir,
-                                           text_thickness=self.video_text_thickness,
-                                           text_opacity=self.video_text_opacity,
-                                           font_size=self.video_font_size,
-                                           text_spacing=self.video_space_size,
-                                           circle_size=self.video_circle_size,
-                                           video_path=self.video_path,
-                                           show_pose=self.show_pose,
-                                           show_animal_names=self.show_animal_names,
-                                           show_animal_bounding_boxes=self.show_animal_bounding_boxes,
-                                           gantt_setting=self.create_gantt,
-                                           final_gantt=self.final_gantt_img,
-                                           clf_data=self.data_df[self.clf_name].values,
-                                           clrs=self.bp_palette,
-                                           clf_name=self.clf_name,
-                                           bouts_df=self.bouts_df,
-                                           conf_data=conf_data)
-
-             for cnt, result in enumerate(pool.imap(constants, data, chunksize=self.multiprocess_chunksize)):
-                 print(f"Image batch {result} complete, Video {self.feature_filename}...")
-             pool.terminate()
-             pool.join()
-         concatenate_videos_in_folder(in_folder=self.temp_dir, save_path=self.video_save_path)
-         self.timer.stop_timer()
-         stdout_success(msg=f"Video complete, saved at {self.video_save_path}", elapsed_time=self.timer.elapsed_time_str)
-
- #
- # if __name__ == "__main__":
- # test = ValidateModelOneVideoMultiprocess(config_path=r"D:\troubleshooting\mitra\project_folder\project_config.ini",
- # feature_path=r"D:\troubleshooting\mitra\project_folder\csv\features_extracted\592_MA147_CNO1_0515.csv",
- # model_path=r"C:\troubleshooting\mitra\models\validations\rearing_5\rearing.sav",
- # create_gantt=2,
- # show_pose=True,
- # show_animal_names=True,
- # core_cnt=13,
- # show_clf_confidence=True,
- # discrimination_threshold=0.20)
- # test.run()
-
-
- #
- # if __name__ == "__main__":
- # test = ValidateModelOneVideoMultiprocess(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
- # feature_file_path=r"C:\troubleshooting\mitra\project_folder\csv\features_extracted\844_MA131_gq_CNO_0624.csv",
- # model_path=r"C:\troubleshooting\mitra\models\validations\lay-on-belly_1\lay-on-belly.sav",
- # discrimination_threshold=0.35,
- # shortest_bout=200,
- # cores=-1,
- # settings={'pose': True, 'animal_names': False, 'styles': None},
- # create_gantt=2)
- # test.run()
-
-
-
-
-
- # test = ValidateModelOneVideoMultiprocess(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/project_config.ini',
- # feature_file_path='/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/csv/features_extracted/SI_DAY3_308_CD1_PRESENT.csv',
- # model_path='/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/models/generated_models/Running.sav',
- # discrimination_threshold=0.6,
- # shortest_bout=50,
- # cores=6,
- # settings={'pose': True, 'animal_names': True, 'styles': None},
- # create_gantt=None)
- # test.run()
-
- # test = ValidateModelOneVideoMultiprocess(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
- # feature_file_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/csv/features_extracted/Together_1.csv',
- # model_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/models/generated_models/Attack.sav',
- # discrimination_threshold=0.6,
- # shortest_bout=50,
- # cores=6,
- # settings={'pose': True, 'animal_names': True, 'styles': None},
- # create_gantt=None)
- # test.run()
+ __author__ = "Simon Nilsson; sronilsson@gmail.com"
+
+ import warnings
+
+ warnings.filterwarnings("ignore", category=FutureWarning)
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+ import functools
+ import multiprocessing
+ import os
+ import platform
+ from copy import deepcopy
+ from typing import List, Optional, Tuple, Union
+
+ import cv2
+ import imutils
+ import pandas as pd
+
+ try:
+     from typing import Literal
+ except:
+     from typing_extensions import Literal
+
+ import matplotlib
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
+
+ from simba.mixins.config_reader import ConfigReader
+ from simba.mixins.geometry_mixin import GeometryMixin
+ from simba.mixins.plotting_mixin import PlottingMixin
+ from simba.mixins.train_model_mixin import TrainModelMixin
+ from simba.utils.checks import (check_file_exist_and_readable, check_float,
+                                 check_int, check_str, check_valid_boolean,
+                                 check_video_and_data_frm_count_align)
+ from simba.utils.data import (create_color_palette, plug_holes_shortest_bout,
+                               terminate_cpu_pool)
+ from simba.utils.enums import Options, TextOptions
+ from simba.utils.printing import SimbaTimer, stdout_success
+ from simba.utils.read_write import (concatenate_videos_in_folder,
+                                     create_directory, find_core_cnt,
+                                     get_fn_ext, get_video_meta_data, read_df,
+                                     read_pickle, write_df)
+ from simba.utils.warnings import FrameRangeWarning
+
+
+ def _validation_video_mp(data: pd.DataFrame,
+                          bp_dict: dict,
+                          video_save_dir: str,
+                          video_path: str,
+                          text_thickness: int,
+                          text_opacity: float,
+                          font_size: int,
+                          text_spacing: int,
+                          circle_size: int,
+                          show_pose: bool,
+                          show_animal_bounding_boxes: bool,
+                          show_animal_names: bool,
+                          gantt_setting: Union[int, None],
+                          final_gantt: Optional[np.ndarray],
+                          clf_data: np.ndarray,
+                          clrs: List[List],
+                          clf_name: str,
+                          bouts_df: pd.DataFrame,
+                          conf_data: np.ndarray):
+
+     def _put_text(img: np.ndarray,
+                   text: str,
+                   pos: Tuple[int, int],
+                   font_size: int,
+                   font_thickness: Optional[int] = 2,
+                   font: Optional[int] = cv2.FONT_HERSHEY_DUPLEX,
+                   text_color: Optional[Tuple[int, int, int]] = (255, 255, 255),
+                   text_color_bg: Optional[Tuple[int, int, int]] = (0, 0, 0),
+                   text_bg_alpha: float = 0.8):
+
+         x, y = pos
+         text_size, px_buffer = cv2.getTextSize(text, font, font_size, font_thickness)
+         w, h = text_size
+         overlay, output = img.copy(), img.copy()
+         cv2.rectangle(overlay, (x, y-h), (x + w, y + px_buffer), text_color_bg, -1)
+         cv2.addWeighted(overlay, text_bg_alpha, output, 1 - text_bg_alpha, 0, output)
+         cv2.putText(output, text, (x, y), font, font_size, text_color, font_thickness)
+         return output
+
+
+     def _create_gantt(bouts_df: pd.DataFrame,
+                       clf_name: str,
+                       image_index: int,
+                       fps: int,
+                       header_font_size: int = 24,
+                       label_font_size: int = 12):
+
+         fig, ax = plt.subplots(figsize=(final_gantt.shape[1] / dpi, final_gantt.shape[0] / dpi))
+         matplotlib.font_manager._get_font.cache_clear()
+         relRows = bouts_df.loc[bouts_df["End_frame"] <= image_index]
+         for i, event in enumerate(relRows.groupby("Event")):
+             data_event = event[1][["Start_time", "Bout_time"]]
+             ax.broken_barh(data_event.values, (4, 4), facecolors="red")
+         xLength = (round(image_index / fps)) + 1
+         if xLength < 10:
+             xLength = 10
+
+         ax.set_xlim(0, xLength)
+         ax.set_ylim([0, 12])
+         ax.set_xlabel("Session (s)", fontsize=label_font_size)
+         ax.set_ylabel(clf_name, fontsize=label_font_size)
+         ax.set_title(f"{clf_name} GANTT CHART", fontsize=header_font_size)
+         ax.set_yticks([])
+         ax.yaxis.set_ticklabels([])
+         ax.yaxis.grid(True)
+         canvas = FigureCanvas(fig)
+         canvas.draw()
+         img = np.array(np.uint8(np.array(canvas.renderer._renderer)))[:, :, :3]
+         plt.close(fig)
+         return img
+
+     dpi = plt.rcParams["figure.dpi"]
+     fourcc, font = cv2.VideoWriter_fourcc(*"mp4v"), cv2.FONT_HERSHEY_DUPLEX
+     cap = cv2.VideoCapture(video_path)
+     video_meta_data = get_video_meta_data(video_path=video_path, fps_as_int=False)
+     batch_id, batch_data = data[0], data[1]
+     start_frm, current_frm, end_frm = batch_data.index[0], batch_data.index[0], batch_data.index[-1]
+     video_save_path = os.path.join(video_save_dir, f"{batch_id}.mp4")
+     if gantt_setting is not None:
+         video_size = (int(video_meta_data["width"] + final_gantt.shape[1]), int(video_meta_data["height"]))
+         writer = cv2.VideoWriter(video_save_path, fourcc, video_meta_data["fps"], video_size)
+     else:
+         video_size = (int(video_meta_data["width"]), int(video_meta_data["height"]))
+         writer = cv2.VideoWriter(video_save_path, fourcc, video_meta_data["fps"], video_size)
+     cap.set(1, start_frm)
+     while (current_frm <= end_frm) & (current_frm <= video_meta_data["frame_count"]):
+         clf_frm_cnt = np.sum(clf_data[0:current_frm])
+         ret, img = cap.read()
+         if ret:
+             frm_timer = SimbaTimer(start=True)
+             if show_pose:
+                 for animal_cnt, (animal_name, animal_data) in enumerate(bp_dict.items()):
+                     for bp_cnt, bp in enumerate(range(len(animal_data["X_bps"]))):
+                         x_header, y_header = (animal_data["X_bps"][bp], animal_data["Y_bps"][bp])
+                         animal_cords = tuple(batch_data.loc[current_frm, [x_header, y_header]])
+                         cv2.circle(img, (int(animal_cords[0]), int(animal_cords[1])), circle_size, clrs[animal_cnt][bp_cnt], -1)
+             if show_animal_names:
+                 for animal_cnt, (animal_name, animal_data) in enumerate(bp_dict.items()):
+                     x_header, y_header = (animal_data["X_bps"][0], animal_data["Y_bps"][0],)
+                     animal_cords = tuple(batch_data.loc[current_frm, [x_header, y_header]])
+                     cv2.putText(img, animal_name, (int(animal_cords[0]), int(animal_cords[1])), font, font_size, clrs[animal_cnt][0], text_thickness)
+             if show_animal_bounding_boxes:
+                 for animal_cnt, (animal_name, animal_data) in enumerate(bp_dict.items()):
+                     animal_headers = [val for pair in zip(animal_data["X_bps"], animal_data["Y_bps"]) for val in pair]
+                     animal_cords = batch_data.loc[current_frm, animal_headers].values.reshape(-1, 2).astype(np.int32)
+                     try:
+                         bbox = GeometryMixin().keypoints_to_axis_aligned_bounding_box(keypoints=animal_cords.reshape(-1, len(animal_cords), 2).astype(np.int32))
+                         cv2.polylines(img, [bbox], True, clrs[animal_cnt][0], thickness=text_thickness, lineType=-1)
+                     except:
+                         pass
+             target_timer = round((1 / video_meta_data["fps"]) * clf_frm_cnt, 2)
+             img = _put_text(img=img, text="BEHAVIOR TIMER:", pos=(TextOptions.BORDER_BUFFER_Y.value, text_spacing), font_size=font_size, font_thickness=TextOptions.TEXT_THICKNESS.value)
+             addSpacer = 2
+             img = _put_text(img=img, text=f"{clf_name} {target_timer}s", pos=(TextOptions.BORDER_BUFFER_Y.value, text_spacing * addSpacer), font_size=font_size, font_thickness=TextOptions.TEXT_THICKNESS.value, text_bg_alpha=text_opacity)
+             addSpacer += 1
+             if conf_data is not None:
+                 img = _put_text(img=img, text=f"{clf_name} PROBABILITY: {round(conf_data[current_frm], 4)}", pos=(TextOptions.BORDER_BUFFER_Y.value, text_spacing * addSpacer), font_size=font_size, font_thickness=TextOptions.TEXT_THICKNESS.value, text_bg_alpha=text_opacity)
+                 addSpacer += 1
+             img = _put_text(img=img, text="ENSEMBLE PREDICTION:", pos=(TextOptions.BORDER_BUFFER_Y.value, text_spacing * addSpacer), font_size=font_size, font_thickness=TextOptions.TEXT_THICKNESS.value, text_bg_alpha=text_opacity)
+             addSpacer += 1
+             if clf_data[current_frm] == 1:
+                 img = _put_text(img=img, text=clf_name, pos=(TextOptions.BORDER_BUFFER_Y.value, text_spacing * addSpacer), font_size=font_size, font_thickness=TextOptions.TEXT_THICKNESS.value, text_color=TextOptions.COLOR.value, text_bg_alpha=text_opacity)
+                 addSpacer += 1
+             if gantt_setting == 1:
+                 img = np.concatenate((img, final_gantt), axis=1)
+             elif gantt_setting == 2:
+                 gantt_img = _create_gantt(bouts_df, clf_name, current_frm, video_meta_data["fps"], header_font_size=9, label_font_size=12)
+                 gantt_img = imutils.resize(gantt_img, height=video_meta_data["height"])
+                 img = np.concatenate((img, gantt_img), axis=1)
+             img = cv2.resize(img, video_size, interpolation=cv2.INTER_LINEAR)
+             writer.write(np.uint8(img))
+             current_frm += 1
+             frm_timer.stop_timer()
+             print(f"Multi-processing video frame {current_frm} on core {batch_id}...(elapsed time: {frm_timer.elapsed_time_str}s)")
+         else:
+             FrameRangeWarning(msg=f'Frame {current_frm} could not be read in video {video_path}. The video contains {video_meta_data["frame_count"]} frames while the data file contains data for {len(batch_data)} frames. Consider re-encoding the video, or make sure the pose-estimation data and associated video contains the same number of frames. ', source=_validation_video_mp.__name__)
+             break
+
+     cap.release()
+     writer.release()
+     return batch_id
+
+
+ class ValidateModelOneVideoMultiprocess(ConfigReader, PlottingMixin, TrainModelMixin):
+     """
+     Create classifier validation video for a single input video using multiprocessing for improved performance.
+
+     This class generates validation videos that overlay classifier predictions, pose estimations, and
+     optional Gantt charts onto the original video using multiple CPU cores for faster processing.
+     Results are stored in the `project_folder/frames/output/validation` directory.
+
+     .. note::
+        This multiprocess version provides significant speed improvements over the single-core
+        :class:`simba.plotting.single_run_model_validation_video.ValidateModelOneVideo` class.
+
+     :param Union[str, os.PathLike] config_path: Path to SimBA project config file in Configparser format.
+     :param Union[str, os.PathLike] feature_path: Path to SimBA file (parquet or CSV) containing pose-estimation and feature data.
+     :param Union[str, os.PathLike] model_path: Path to pickled classifier object (.sav file).
+     :param bool show_pose: If True, overlay pose estimation keypoints on the video. Default: True.
+     :param bool show_animal_names: If True, display animal names near the first body part. Default: False.
+     :param Optional[int] font_size: Font size for text overlays. If None, automatically calculated based on video dimensions.
+     :param Optional[str] bp_palette: Optional name of the palette to use to color the animal body-parts (e.g., Pastel1). If None, ``spring`` is used.
+
+
+     :param Optional[int] circle_size: Size of pose estimation circles. If None, automatically calculated based on video dimensions.
+     :param Optional[int] text_spacing: Spacing between text lines. If None, automatically calculated.
+     :param Optional[int] text_thickness: Thickness of text overlay. If None, uses default value.
+     :param Optional[float] text_opacity: Opacity of text overlays (0.1-1.0). If None, defaults to 0.8.
+     :param float discrimination_threshold: Classification probability threshold (0.0-1.0). Default: 0.0.
+     :param int shortest_bout: Minimum classified bout length in milliseconds. Bouts shorter than this will be reclassified as absent. Default: 0.
+     :param int core_cnt: Number of CPU cores to use for processing. If -1, uses all available cores. Default: -1.
+     :param Optional[Union[None, int]] create_gantt: Gantt chart creation option:
+
+         - None: No Gantt chart
+         - 1: Static Gantt chart (final frame only, faster)
+         - 2: Dynamic Gantt chart (updated per frame)
+
+
+     .. youtube:: UOLSj7DGKRo
+        :width: 640
+        :height: 480
+        :align: center
+
+     .. video:: _static/img/T1.webm
+        :width: 1000
+        :autoplay:
+        :loop:
+
+     :example:
+     >>> # Create multiprocess validation video with dynamic Gantt chart
+     >>> validator = ValidateModelOneVideoMultiprocess(
+     ...     config_path=r'/path/to/project_config.ini',
+     ...     feature_path=r'/path/to/features.csv',
+     ...     model_path=r'/path/to/classifier.sav',
+     ...     show_pose=True,
+     ...     show_animal_names=True,
+     ...     discrimination_threshold=0.6,
+     ...     shortest_bout=500,
+     ...     core_cnt=4,
+     ...     create_gantt=2
+     ... )
+     >>> validator.run()
+     """
+
+     def __init__(self,
+                  config_path: Union[str, os.PathLike],
+                  feature_path: Union[str, os.PathLike],
+                  model_path: Union[str, os.PathLike],
+                  show_pose: bool = True,
+                  show_animal_names: bool = False,
+                  show_animal_bounding_boxes: bool = False,
+                  show_clf_confidence: bool = False,
+                  font_size: Optional[bool] = None,
+                  circle_size: Optional[int] = None,
+                  text_spacing: Optional[int] = None,
+                  text_thickness: Optional[int] = None,
+                  text_opacity: Optional[float] = None,
+                  bp_palette: Optional[str] = None,
+                  discrimination_threshold: float = 0.0,
+                  shortest_bout: int = 0.0,
+                  core_cnt: int = -1,
+                  create_gantt: Optional[Union[None, int]] = None):
+
+
+         ConfigReader.__init__(self, config_path=config_path)
+         PlottingMixin.__init__(self)
+         TrainModelMixin.__init__(self)
+         check_file_exist_and_readable(file_path=config_path)
+         check_file_exist_and_readable(file_path=feature_path)
+         check_file_exist_and_readable(file_path=model_path)
+         check_valid_boolean(value=[show_pose], source=f'{self.__class__.__name__} show_pose', raise_error=True)
+         check_valid_boolean(value=[show_animal_names], source=f'{self.__class__.__name__} show_animal_names', raise_error=True)
+         check_valid_boolean(value=[show_animal_bounding_boxes], source=f'{self.__class__.__name__} show_animal_bounding_boxes', raise_error=True)
+         check_valid_boolean(value=[show_clf_confidence], source=f'{self.__class__.__name__} show_clf_confidence', raise_error=True)
+         check_int(name=f"{self.__class__.__name__} core_cnt", value=core_cnt, min_value=-1, unaccepted_vals=[0])
+         if font_size is not None: check_int(name=f'{self.__class__.__name__} font_size', value=font_size)
+         if circle_size is not None: check_int(name=f'{self.__class__.__name__} circle_size', value=circle_size)
+         if text_spacing is not None: check_int(name=f'{self.__class__.__name__} text_spacing', value=text_spacing)
+         if text_opacity is not None: check_float(name=f'{self.__class__.__name__} text_opacity', value=text_opacity, min_value=0.1)
+         if text_thickness is not None: check_float(name=f'{self.__class__.__name__} text_thickness', value=text_thickness, min_value=0.1)
+         check_float(name=f"{self.__class__.__name__} discrimination_threshold", value=discrimination_threshold, min_value=0, max_value=1.0)
+         check_int(name=f"{self.__class__.__name__} shortest_bout", value=shortest_bout, min_value=0)
+         if create_gantt is not None:
+             check_int(name=f"{self.__class__.__name__} create gantt", value=create_gantt, max_value=2, min_value=1)
+         if not os.path.exists(self.single_validation_video_save_dir):
+             os.makedirs(self.single_validation_video_save_dir)
+         if bp_palette is not None:
+             self.bp_palette = []
+             check_str(name=f'{self.__class__.__name__} bp_palette', value=bp_palette, options=(Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value))
+             for animal in range(self.animal_cnt):
+                 self.bp_palette.append(create_color_palette(pallete_name=bp_palette, increments=(int(len(self.body_parts_lst)/self.animal_cnt) +1), as_int=True))
+         else:
+             self.bp_palette = deepcopy(self.clr_lst)
+         _, self.feature_filename, ext = get_fn_ext(feature_path)
+         self.video_path = self.find_video_of_file(self.video_dir, self.feature_filename)
+         self.video_meta_data = get_video_meta_data(video_path=self.video_path, fps_as_int=False)
+         self.clf_name, self.feature_file_path = (os.path.basename(model_path).replace(".sav", ""), feature_path)
+         self.vid_output_path = os.path.join(self.single_validation_video_save_dir, f"{self.feature_filename} {self.clf_name}.mp4")
+         self.clf_data_save_path = os.path.join(self.clf_data_validation_dir, f"{self.feature_filename }.csv")
+         self.show_pose, self.show_animal_names = show_pose, show_animal_names
+         self.font_size, self.circle_size, self.text_spacing, self.show_clf_confidence = font_size, circle_size, text_spacing, show_clf_confidence
+         self.text_opacity, self.text_thickness, self.show_animal_bounding_boxes = text_opacity, text_thickness, show_animal_bounding_boxes
+         self.clf = read_pickle(data_path=model_path, verbose=True)
+         self.data_df = read_df(feature_path, self.file_type)
+         self.x_df = self.drop_bp_cords(df=self.data_df)
+         self.discrimination_threshold, self.shortest_bout, self.create_gantt = float(discrimination_threshold), shortest_bout, create_gantt
+         check_video_and_data_frm_count_align(video=self.video_path, data=self.data_df, name=self.feature_filename, raise_error=False)
+         self.core_cnt = find_core_cnt()[0] if core_cnt == -1 or core_cnt > find_core_cnt()[0] else core_cnt
+         self.temp_dir = os.path.join(self.single_validation_video_save_dir, "temp")
+         self.video_save_path = os.path.join(self.single_validation_video_save_dir, f"{self.feature_filename}.mp4")
+         create_directory(paths=self.temp_dir, overwrite=True)
+         if platform.system() == "Darwin":
+             multiprocessing.set_start_method("spawn", force=True)
+
+     def _get_styles(self):
+         self.video_text_thickness = TextOptions.TEXT_THICKNESS.value if self.text_thickness is None else int(max(self.text_thickness, 1))
+         longest_str = str(max(['TIMERS:', 'ENSEMBLE PREDICTION:'] + self.clf_names, key=len))
+         optimal_font_size, _, optimal_spacing_scale = self.get_optimal_font_scales(text=longest_str, accepted_px_width=int(self.video_meta_data["width"] / 3), accepted_px_height=int(self.video_meta_data["height"] / 10), text_thickness=self.video_text_thickness)
+         optimal_circle_size = self.get_optimal_circle_size(frame_size=(self.video_meta_data["width"], self.video_meta_data["height"]), circle_frame_ratio=100)
+         self.video_circle_size = optimal_circle_size if self.circle_size is None else int(self.circle_size)
+         self.video_font_size = optimal_font_size if self.font_size is None else self.font_size
+         self.video_space_size = optimal_spacing_scale if self.text_spacing is None else int(max(self.text_spacing, 1))
+         self.video_text_opacity = 0.8 if self.text_opacity is None else float(self.text_opacity)
+
+     def run(self):
+         self.prob_col_name = f"Probability_{self.clf_name}"
+         self.data_df[self.prob_col_name] = self.clf_predict_proba(clf=self.clf, x_df=self.x_df, model_name=self.clf_name, data_path=self.feature_file_path)
+         self.data_df[self.clf_name] = np.where(self.data_df[self.prob_col_name] > self.discrimination_threshold, 1, 0)
+         if self.shortest_bout > 1:
+             self.data_df = plug_holes_shortest_bout(data_df=self.data_df, clf_name=self.clf_name, fps=self.video_meta_data['fps'], shortest_bout=self.shortest_bout)
+         _ = write_df(df=self.data_df, file_type=self.file_type, save_path=self.clf_data_save_path)
+         print(f"Predictions created for video {self.feature_filename} (creating video, follow progressin OS terminal)...")
+         self._get_styles()
+         if self.create_gantt is not None:
+             self.bouts_df = self.get_bouts_for_gantt(data_df=self.data_df, clf_name=self.clf_name, fps=self.video_meta_data['fps'])
+             self.final_gantt_img = self.create_gantt_img(self.bouts_df ,self.clf_name,len(self.data_df), self.video_meta_data['fps'],f"Behavior gantt chart (entire session, length (s): {self.video_meta_data['video_length_s']}, frames: {self.video_meta_data['frame_count']})", header_font_size=9, label_font_size=12)
+             self.final_gantt_img = self.resize_gantt(self.final_gantt_img, self.video_meta_data["height"])
+         else:
+             self.bouts_df, self.final_gantt_img = None, None
+         conf_data = self.data_df[self.prob_col_name].values if self.show_clf_confidence else None
+
+         self.data_df = self.data_df.head(min(len(self.data_df), self.video_meta_data["frame_count"]))
+         data = np.array_split(self.data_df, self.core_cnt)
+         data = [(i, j) for i, j in enumerate(data)]
+
+         with multiprocessing.Pool(self.core_cnt, maxtasksperchild=self.maxtasksperchild) as pool:
+             constants = functools.partial(_validation_video_mp,
+                                           bp_dict=self.animal_bp_dict,
+                                           video_save_dir=self.temp_dir,
+                                           text_thickness=self.video_text_thickness,
+                                           text_opacity=self.video_text_opacity,
+                                           font_size=self.video_font_size,
+                                           text_spacing=self.video_space_size,
+                                           circle_size=self.video_circle_size,
+                                           video_path=self.video_path,
+                                           show_pose=self.show_pose,
+                                           show_animal_names=self.show_animal_names,
+                                           show_animal_bounding_boxes=self.show_animal_bounding_boxes,
+                                           gantt_setting=self.create_gantt,
+                                           final_gantt=self.final_gantt_img,
+                                           clf_data=self.data_df[self.clf_name].values,
+                                           clrs=self.bp_palette,
+                                           clf_name=self.clf_name,
+                                           bouts_df=self.bouts_df,
+                                           conf_data=conf_data)
+
+             for cnt, result in enumerate(pool.imap(constants, data, chunksize=self.multiprocess_chunksize)):
+                 print(f"Image batch {result} complete, Video {self.feature_filename}...")
+             terminate_cpu_pool(pool=pool, force=False)
+         concatenate_videos_in_folder(in_folder=self.temp_dir, save_path=self.video_save_path)
+         self.timer.stop_timer()
+         stdout_success(msg=f"Video complete, saved at {self.video_save_path}", elapsed_time=self.timer.elapsed_time_str)
+
+ #
+ # if __name__ == "__main__":
+ # test = ValidateModelOneVideoMultiprocess(config_path=r"D:\troubleshooting\mitra\project_folder\project_config.ini",
+ # feature_path=r"D:\troubleshooting\mitra\project_folder\csv\features_extracted\592_MA147_CNO1_0515.csv",
+ # model_path=r"C:\troubleshooting\mitra\models\validations\rearing_5\rearing.sav",
+ # create_gantt=2,
+ # show_pose=True,
+ # show_animal_names=True,
+ # core_cnt=13,
+ # show_clf_confidence=True,
+ # discrimination_threshold=0.20)
+ # test.run()
+
+
+ #
+ # if __name__ == "__main__":
+ # test = ValidateModelOneVideoMultiprocess(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
+ # feature_file_path=r"C:\troubleshooting\mitra\project_folder\csv\features_extracted\844_MA131_gq_CNO_0624.csv",
+ # model_path=r"C:\troubleshooting\mitra\models\validations\lay-on-belly_1\lay-on-belly.sav",
+ # discrimination_threshold=0.35,
+ # shortest_bout=200,
+ # cores=-1,
+ # settings={'pose': True, 'animal_names': False, 'styles': None},
+ # create_gantt=2)
+ # test.run()
+
+
+
+
+
+ # test = ValidateModelOneVideoMultiprocess(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/project_config.ini',
+ # feature_file_path='/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/csv/features_extracted/SI_DAY3_308_CD1_PRESENT.csv',
+ # model_path='/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/models/generated_models/Running.sav',
+ # discrimination_threshold=0.6,
+ # shortest_bout=50,
+ # cores=6,
+ # settings={'pose': True, 'animal_names': True, 'styles': None},
+ # create_gantt=None)
+ # test.run()
+
+ # test = ValidateModelOneVideoMultiprocess(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
+ # feature_file_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/csv/features_extracted/Together_1.csv',
+ # model_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/models/generated_models/Attack.sav',
+ # discrimination_threshold=0.6,
+ # shortest_bout=50,
+ # cores=6,
+ # settings={'pose': True, 'animal_names': True, 'styles': None},
+ # create_gantt=None)
+ # test.run()