simba-uw-tf-dev 4.6.2__py3-none-any.whl → 4.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. simba/assets/.recent_projects.txt +1 -0
  2. simba/assets/lookups/tooptips.json +6 -1
  3. simba/assets/lookups/yolo_schematics/yolo_mitra.csv +9 -0
  4. simba/data_processors/agg_clf_counter_mp.py +52 -53
  5. simba/data_processors/blob_location_computer.py +1 -1
  6. simba/data_processors/circling_detector.py +30 -13
  7. simba/data_processors/cuda/geometry.py +45 -27
  8. simba/data_processors/cuda/image.py +1648 -1598
  9. simba/data_processors/cuda/statistics.py +72 -26
  10. simba/data_processors/cuda/timeseries.py +1 -1
  11. simba/data_processors/cue_light_analyzer.py +5 -9
  12. simba/data_processors/egocentric_aligner.py +25 -7
  13. simba/data_processors/freezing_detector.py +55 -47
  14. simba/data_processors/kleinberg_calculator.py +61 -29
  15. simba/feature_extractors/feature_subsets.py +14 -7
  16. simba/feature_extractors/mitra_feature_extractor.py +2 -2
  17. simba/feature_extractors/straub_tail_analyzer.py +4 -6
  18. simba/labelling/standard_labeller.py +1 -1
  19. simba/mixins/config_reader.py +5 -2
  20. simba/mixins/geometry_mixin.py +22 -36
  21. simba/mixins/image_mixin.py +24 -28
  22. simba/mixins/plotting_mixin.py +28 -10
  23. simba/mixins/statistics_mixin.py +48 -11
  24. simba/mixins/timeseries_features_mixin.py +1 -1
  25. simba/mixins/train_model_mixin.py +68 -33
  26. simba/model/inference_batch.py +2 -2
  27. simba/model/yolo_seg_inference.py +3 -3
  28. simba/outlier_tools/skip_outlier_correction.py +1 -1
  29. simba/plotting/ROI_feature_visualizer_mp.py +3 -5
  30. simba/plotting/clf_validator_mp.py +4 -5
  31. simba/plotting/cue_light_visualizer.py +6 -7
  32. simba/plotting/directing_animals_visualizer_mp.py +2 -3
  33. simba/plotting/distance_plotter_mp.py +378 -378
  34. simba/plotting/gantt_creator.py +29 -10
  35. simba/plotting/gantt_creator_mp.py +96 -33
  36. simba/plotting/geometry_plotter.py +270 -272
  37. simba/plotting/heat_mapper_clf_mp.py +4 -6
  38. simba/plotting/heat_mapper_location_mp.py +2 -2
  39. simba/plotting/light_dark_box_plotter.py +2 -2
  40. simba/plotting/path_plotter_mp.py +26 -29
  41. simba/plotting/plot_clf_results_mp.py +455 -454
  42. simba/plotting/pose_plotter_mp.py +28 -29
  43. simba/plotting/probability_plot_creator_mp.py +288 -288
  44. simba/plotting/roi_plotter_mp.py +31 -31
  45. simba/plotting/single_run_model_validation_video_mp.py +427 -427
  46. simba/plotting/spontaneous_alternation_plotter.py +2 -3
  47. simba/plotting/yolo_pose_track_visualizer.py +32 -27
  48. simba/plotting/yolo_pose_visualizer.py +35 -36
  49. simba/plotting/yolo_seg_visualizer.py +2 -3
  50. simba/pose_importers/simba_blob_importer.py +3 -3
  51. simba/roi_tools/roi_aggregate_stats_mp.py +5 -4
  52. simba/roi_tools/roi_clf_calculator_mp.py +4 -4
  53. simba/sandbox/analyze_runtimes.py +30 -0
  54. simba/sandbox/cuda/egocentric_rotator.py +374 -374
  55. simba/sandbox/get_cpu_pool.py +5 -0
  56. simba/sandbox/proboscis_to_tip.py +28 -0
  57. simba/sandbox/test_directionality.py +47 -0
  58. simba/sandbox/test_nonstatic_directionality.py +27 -0
  59. simba/sandbox/test_pycharm_cuda.py +51 -0
  60. simba/sandbox/test_simba_install.py +41 -0
  61. simba/sandbox/test_static_directionality.py +26 -0
  62. simba/sandbox/test_static_directionality_2d.py +26 -0
  63. simba/sandbox/verify_env.py +42 -0
  64. simba/third_party_label_appenders/transform/coco_keypoints_to_yolo.py +3 -3
  65. simba/third_party_label_appenders/transform/coco_keypoints_to_yolo_bbox.py +2 -2
  66. simba/third_party_label_appenders/transform/simba_to_yolo.py +8 -5
  67. simba/ui/pop_ups/clf_plot_pop_up.py +2 -2
  68. simba/ui/pop_ups/fsttc_pop_up.py +27 -25
  69. simba/ui/pop_ups/gantt_pop_up.py +31 -6
  70. simba/ui/pop_ups/kleinberg_pop_up.py +39 -40
  71. simba/ui/pop_ups/run_machine_models_popup.py +21 -21
  72. simba/ui/pop_ups/simba_to_yolo_keypoints_popup.py +2 -2
  73. simba/ui/pop_ups/video_processing_pop_up.py +37 -29
  74. simba/ui/pop_ups/yolo_inference_popup.py +1 -1
  75. simba/ui/pop_ups/yolo_pose_train_popup.py +1 -1
  76. simba/ui/tkinter_functions.py +3 -0
  77. simba/utils/custom_feature_extractor.py +1 -1
  78. simba/utils/data.py +90 -14
  79. simba/utils/enums.py +1 -0
  80. simba/utils/errors.py +441 -440
  81. simba/utils/lookups.py +1203 -1203
  82. simba/utils/printing.py +124 -124
  83. simba/utils/read_write.py +3769 -3721
  84. simba/utils/yolo.py +10 -1
  85. simba/video_processors/blob_tracking_executor.py +2 -2
  86. simba/video_processors/clahe_ui.py +1 -1
  87. simba/video_processors/egocentric_video_rotator.py +44 -41
  88. simba/video_processors/multi_cropper.py +1 -1
  89. simba/video_processors/video_processing.py +75 -33
  90. simba/video_processors/videos_to_frames.py +43 -33
  91. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/METADATA +4 -3
  92. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/RECORD +96 -85
  93. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/LICENSE +0 -0
  94. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/WHEEL +0 -0
  95. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/entry_points.txt +0 -0
  96. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/top_level.txt +0 -0
@@ -1,378 +1,378 @@
1
- __author__ = "Simon Nilsson; sronilsson@gmail.com"
2
- import functools
3
- import multiprocessing
4
- import os
5
- import platform
6
- from typing import Dict, List, Optional, Union
7
-
8
- import cv2
9
- import numpy as np
10
- from numba import jit
11
-
12
- from simba.mixins.config_reader import ConfigReader
13
- from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin
14
- from simba.mixins.plotting_mixin import PlottingMixin
15
- from simba.utils.checks import (
16
- check_all_file_names_are_represented_in_video_log,
17
- check_file_exist_and_readable, check_instance, check_int, check_valid_lst)
18
- from simba.utils.errors import (CountError, InvalidInputError,
19
- NoSpecifiedOutputError)
20
- from simba.utils.lookups import get_color_dict
21
- from simba.utils.printing import SimbaTimer, stdout_success
22
- from simba.utils.read_write import (concatenate_videos_in_folder,
23
- find_core_cnt, get_fn_ext, read_df)
24
-
25
-
26
def distance_plotter_mp(
    frm_cnts: np.array,
    distances: np.ndarray,
    colors: List[str],
    video_setting: bool,
    frame_setting: bool,
    video_name: str,
    video_save_dir: str,
    frame_folder_dir: str,
    style_attr: dict,
    fps: int,
):
    """
    Multiprocessing worker that renders one batch of distance line-plot frames.

    :param np.ndarray frm_cnts: Frame indexes (rows of ``distances``) rendered by this worker.
    :param np.ndarray distances: 2D array. Column 0 holds the batch (group) index; the remaining columns hold one distance series each.
    :param List[str] colors: One color name per distance series.
    :param bool video_setting: If True, frames are written to ``<video_save_dir>/<group>.avi``.
    :param bool frame_setting: If True, each frame is saved as ``<frame_folder_dir>/<frame index>.png``.
    :param str video_name: Name of the video; used in progress messages only.
    :param str video_save_dir: Directory that receives the per-batch temporary video.
    :param str frame_folder_dir: Directory that receives the individual frame images.
    :param dict style_attr: Plot style. Keys read: ``width``, ``height``, ``line width``, ``font size``, ``y_max``, ``opacity``.
    :param int fps: Frame rate of the temporary video; also passed as ``x_lbl_divisor`` to the plotting call.
    :return: The group (batch) index handled by this worker, or None if ``frm_cnts`` is empty.
    """
    # np.array_split can hand a worker an empty batch (fewer frames than cores); there is nothing to render.
    if len(frm_cnts) == 0:
        return None
    group = int(distances[frm_cnts[0], 0])
    video_writer = None
    if video_setting:
        fourcc = cv2.VideoWriter_fourcc(*"DIVX")
        temp_video_save_path = os.path.join(video_save_dir, f"{group}.avi")
        video_writer = cv2.VideoWriter(
            temp_video_save_path,
            fourcc,
            fps,
            (style_attr["width"], style_attr["height"]),
        )
    try:
        for frm_cnt in frm_cnts:
            # Include the current frame (``frm_cnt + 1``); otherwise frame 0 plots no data
            # and every frame lags one sample behind.
            line_data = distances[: frm_cnt + 1, 1:]
            line_data = np.hsplit(line_data, line_data.shape[1])
            # NOTE(review): the x label says "frame count" although values are divided by fps;
            # the final summary image uses "time (s)" with the same divisor - confirm intended label.
            img = PlottingMixin.make_line_plot_plotly(
                data=line_data,
                colors=colors,
                width=style_attr["width"],
                height=style_attr["height"],
                line_width=style_attr["line width"],
                font_size=style_attr["font size"],
                title="Animal distances",
                y_lbl="distance (cm)",
                x_lbl="frame count",
                x_lbl_divisor=fps,
                y_max=style_attr["y_max"],
                line_opacity=style_attr["opacity"],
                save_path=None,
            ).astype(np.uint8)
            if video_writer is not None:
                # Only the first three channels are written to the video.
                video_writer.write(img[:, :, :3])
            if frame_setting:
                frm_name = os.path.join(frame_folder_dir, f"{frm_cnt}.png")
                cv2.imwrite(frm_name, img)
            print(
                f"Distance frame created: {frm_cnt} (of {distances.shape[0]}), Video: {video_name}, Processing core: {group}"
            )
    finally:
        # Release the writer even if rendering fails, so the partial .avi file handle is not leaked.
        if video_writer is not None:
            video_writer.release()
    return group
81
-
82
-
83
class DistancePlotterMultiCore(ConfigReader, PlottingMixin):
    """
    Visualize the distances between pose-estimated body-parts (e.g., two animals) through line
    charts. Results are saved as individual line charts, and/or a video of line charts.
    Uses multiprocessing.

    :param str config_path: path to SimBA project config file in Configparser format
    :param List[Union[str, os.PathLike]] data_paths: Files to visualize.
    :param bool frame_setting: If True, creates individual frames.
    :param bool video_setting: If True, creates videos.
    :param bool final_img: If True, creates a single .png representing the entire video.
    :param dict style_attr: Style attributes. Keys read: ``width``, ``height``, ``line width``, ``font size``, ``opacity`` and ``y_max``. When creating frames/videos, a ``y_max`` of ``-1`` is replaced by the maximum distance of each video. For the final image, ``y_max`` is passed through unchanged.
    :param List[List[str]] line_attr: One ``[body_part_1, body_part_2, color_name]`` entry per distance line. ``color_name`` must be a key of ``get_color_dict()``.
    :param Optional[int] core_cnt: Number of CPU cores to use. ``-1`` uses all available cores.
    :raises NoSpecifiedOutputError: If ``frame_setting``, ``video_setting`` and ``final_img`` are all False.
    :raises InvalidInputError: If ``core_cnt`` is 0, a color name is unknown, or a body-part is missing from a data file.

    .. note::
       `GitHub tutorial/documentation <https://github.com/sgoldenlab/simba/blob/master/docs/tutorial.md#step-11-visualizations>`__.

    .. image:: _static/img/DistancePlotterMultiCore.png
       :width: 600
       :align: center

    .. image:: _static/img/DistancePlotterMultiCore_1.gif
       :width: 600
       :align: center

    :example:
    >>> style_attr = {'width': 640, 'height': 480, 'line width': 6, 'font size': 8, 'opacity': 0.5, 'y_max': -1}
    >>> line_attr = [['Center_1', 'Center_2', 'Green'], ['Ear_left_2', 'Ear_left_1', 'Red']]
    >>> distance_plotter = DistancePlotterMultiCore(config_path=r'/tests_/project_folder/project_config.ini', frame_setting=False, video_setting=True, final_img=True, style_attr=style_attr, line_attr=line_attr, data_paths=['/test_/project_folder/csv/machine_results/Together_1.csv'], core_cnt=5)
    >>> distance_plotter.run()
    """

    def __init__(
        self,
        config_path: Union[str, os.PathLike],
        data_paths: List[Union[str, os.PathLike]],
        frame_setting: bool,
        video_setting: bool,
        final_img: bool,
        style_attr: Dict[str, int],
        line_attr: List[List[str]],
        core_cnt: Optional[int] = -1,
    ):
        if (not frame_setting) and (not video_setting) and (not final_img):
            raise NoSpecifiedOutputError(
                msg="Please choose to create frames and/or video distance plots",
                source=self.__class__.__name__,
            )
        check_int(
            name=f"{self.__class__.__name__} core_cnt",
            value=core_cnt,
            min_value=-1,
            max_value=find_core_cnt()[0],
        )
        # check_int accepts 0, but multiprocessing.Pool(0) fails later; fail fast instead.
        if core_cnt == 0:
            raise InvalidInputError(
                msg=f"core_cnt must be -1 (all cores) or a positive integer, got {core_cnt}",
                source=self.__class__.__name__,
            )
        if core_cnt == -1:
            core_cnt = find_core_cnt()[0]
        ConfigReader.__init__(self, config_path=config_path)
        PlottingMixin.__init__(self)
        check_instance(
            source=f"{self.__class__.__name__} line_attr",
            instance=line_attr,
            accepted_types=(list,),
        )
        for cnt, i in enumerate(line_attr):
            check_valid_lst(
                source=f"{self.__class__.__name__} line_attr {cnt}",
                data=i,
                valid_dtypes=(str,),
                exact_len=3,
            )
        check_valid_lst(data=data_paths, valid_dtypes=(str,), min_len=1)
        for data_path in data_paths:
            check_file_exist_and_readable(data_path)
        self.video_setting = video_setting
        self.frame_setting = frame_setting
        self.data_paths = data_paths
        self.style_attr = style_attr
        self.line_attr = line_attr
        self.final_img = final_img
        self.core_cnt = core_cnt
        if not os.path.exists(self.line_plot_dir):
            os.makedirs(self.line_plot_dir)
        self.color_names = get_color_dict()
        if platform.system() == "Darwin":
            # Workers are started with the "spawn" method on macOS.
            multiprocessing.set_start_method("spawn", force=True)

    @staticmethod
    @jit(nopython=True)
    def __insert_group_idx_column(data: np.array, group: int):
        # Prepend a column filled with ``group`` so each worker can recover its batch index from row data.
        group_col = np.full((data.shape[0], 1), group)
        return np.hstack((group_col, data))

    def run(self):
        """Create the requested distance plots (frames, video and/or final image) for every file in ``data_paths``."""
        print(f"Processing {len(self.data_paths)} video(s)...")
        check_all_file_names_are_represented_in_video_log(
            video_info_df=self.video_info_df, data_paths=self.data_paths
        )
        for file_path in self.data_paths:
            video_timer = SimbaTimer(start=True)
            _, video_name, _ = get_fn_ext(file_path)
            data_df = read_df(file_path, self.file_type)
            try:
                data_df.columns = self.bp_headers
            except ValueError:
                raise CountError(
                    msg=f"SimBA expects {self.bp_headers} columns but found {len(data_df.columns)} columns in {file_path}",
                    source=self.__class__.__name__,
                )
            self.video_info, px_per_mm, fps = self.read_video_info(video_name=video_name)
            self.save_video_folder = os.path.join(self.line_plot_dir, video_name)
            self.temp_folder = os.path.join(self.line_plot_dir, video_name, "temp")
            self.save_frame_folder_dir = os.path.join(self.line_plot_dir, video_name)
            distances, colors = [], []
            for i in self.line_attr:
                if i[2] not in self.color_names:
                    raise InvalidInputError(
                        msg=f"{i[2]} is not a valid color. Options: {list(self.color_names.keys())}.",
                        source=self.__class__.__name__,
                    )
                colors.append(i[2])
                bp_1, bp_2 = [f"{i[0]}_x", f"{i[0]}_y"], [f"{i[1]}_x", f"{i[1]}_y"]
                if len(set(bp_1) - set(data_df.columns)) > 0:
                    raise InvalidInputError(
                        msg=f"Could not find fields {bp_1} in {file_path}",
                        source=self.__class__.__name__,
                    )
                if len(set(bp_2) - set(data_df.columns)) > 0:
                    raise InvalidInputError(
                        msg=f"Could not find fields {bp_2} in {file_path}",
                        source=self.__class__.__name__,
                    )
                distances.append(
                    FeatureExtractionMixin.framewise_euclidean_distance(
                        location_1=data_df[bp_1].values.astype(np.float64),
                        location_2=data_df[bp_2].values.astype(np.float64),
                        px_per_mm=np.float64(px_per_mm),
                        centimeter=True,
                    )
                )
            if self.frame_setting:
                if os.path.exists(self.save_frame_folder_dir):
                    self.remove_a_folder(self.save_frame_folder_dir)
                os.makedirs(self.save_frame_folder_dir)
            if self.video_setting:
                self.video_folder = os.path.join(self.line_plot_dir, video_name)
                if os.path.exists(self.temp_folder):
                    self.remove_a_folder(self.temp_folder)
                os.makedirs(self.temp_folder)
                self.save_video_path = os.path.join(self.line_plot_dir, f"{video_name}.mp4")

            if self.final_img:
                _ = PlottingMixin.make_line_plot(
                    data=distances,
                    colors=colors,
                    width=self.style_attr["width"],
                    height=self.style_attr["height"],
                    line_width=self.style_attr["line width"],
                    font_size=self.style_attr["font size"],
                    title="Animal distances",
                    y_lbl="distance (cm)",
                    x_lbl="time (s)",
                    x_lbl_divisor=fps,
                    y_max=self.style_attr["y_max"],
                    line_opacity=self.style_attr["opacity"],
                    save_path=os.path.join(self.line_plot_dir, f"{video_name}_final_distances.png"),
                )

            if self.video_setting or self.frame_setting:
                # Per-video copy: an auto y_max (-1) is resolved from this video's data only and is
                # not written back into self.style_attr, where it would leak into the next video.
                video_style_attr = dict(self.style_attr)
                if video_style_attr["y_max"] == -1:
                    video_style_attr["y_max"] = max([np.max(x) for x in distances])
                distances = np.stack(distances, axis=1)
                frm_range = np.array_split(np.arange(0, distances.shape[0]), self.core_cnt)
                # With fewer frames than cores, array_split yields empty trailing chunks; skip them.
                frm_range = [x for x in frm_range if x.size > 0]
                distances = np.array_split(distances, self.core_cnt)
                distances = [
                    self.__insert_group_idx_column(data=i, group=cnt)
                    for cnt, i in enumerate(distances)
                ]
                distances = np.concatenate(distances, axis=0)
                print(
                    f"Creating distance plots, multiprocessing, follow progress in terminal (chunksize: {self.multiprocess_chunksize}, cores: {self.core_cnt})"
                )
                with multiprocessing.Pool(self.core_cnt, maxtasksperchild=self.maxtasksperchild) as pool:
                    constants = functools.partial(
                        distance_plotter_mp,
                        distances=distances,
                        video_setting=self.video_setting,
                        frame_setting=self.frame_setting,
                        video_name=video_name,
                        video_save_dir=self.temp_folder,
                        frame_folder_dir=self.save_frame_folder_dir,
                        style_attr=video_style_attr,
                        colors=colors,
                        fps=fps,
                    )
                    for result in pool.map(constants, frm_range, chunksize=self.multiprocess_chunksize):
                        print(f"Frame batch core {result} complete...")
                    # Pool.join() raises ValueError("Pool is still running") unless close() or
                    # terminate() was called first.
                    pool.close()
                    pool.join()
                if self.video_setting:
                    concatenate_videos_in_folder(
                        in_folder=self.temp_folder,
                        save_path=self.save_video_path,
                        video_format="avi",
                    )
            video_timer.stop_timer()
            stdout_success(
                msg=f"Distance visualizations created for {video_name} saved at {self.line_plot_dir}",
                elapsed_time=video_timer.elapsed_time_str,
            )
        self.timer.stop_timer()
        stdout_success(
            msg=f"Distance visualizations complete for {len(self.data_paths)} video(s)",
            elapsed_time=self.timer.elapsed_time_str,
        )
323
-
324
-
325
- # style_attr = {'width': 640, 'height': 480, 'line width': 6, 'font size': 12, 'y_max': -1, 'opacity': 0.5}
326
- # line_attr = [['Center_1', 'Center_2', 'Green'], ['Ear_left_2', 'Ear_right_2', 'Red']]
327
- # test = DistancePlotterMultiCore(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
328
- # frame_setting=True,
329
- # video_setting=True,
330
- # style_attr=style_attr,
331
- # final_img=True,
332
- # data_paths=['/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/csv/outlier_corrected_movement_location/Together_1.csv'],
333
- # line_attr=line_attr,
334
- # core_cnt=-1)
335
- # test.run()
336
-
337
-
338
- # style_attr = {'width': 640, 'height': 480, 'line width': 6, 'font size': 8, 'y_max': 'auto', 'opacity': 0.9}
339
- # line_attr = {0: ['Center_1', 'Center_2', 'Green'], 1: ['Ear_left_2', 'Ear_left_1', 'Red']}
340
- #
341
- # test = DistancePlotterMultiCore(config_path=r'/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
342
- # frame_setting=False,
343
- # video_setting=True,
344
- # style_attr=style_attr,
345
- # final_img=True,
346
- # files_found=['/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/csv/machine_results/Together_1.csv'],
347
- # line_attr=line_attr,
348
- # core_cnt=3)
349
- # test.create_distance_plot()
350
- # #
351
- # style_attr = {'width': 640, 'height': 480, 'line width': 6, 'font size': 8}
352
- # line_attr = {0: ['Termite_1_Head_1', 'Termite_1_Thorax_1', 'Dark-red']}
353
- #
354
-
355
-
356
- # style_attr = {'width': 640, 'height': 480, 'line width': 6, 'font size': 8, 'opacity': 0.5, 'y_max': 'auto'}
357
- # line_attr = {0: ['Center_1', 'Center_2', 'Green'], 1: ['Ear_left_2', 'Ear_left_1', 'Red']}
358
- #
359
- # test = DistancePlotterMultiCore(config_path=r'/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
360
- # frame_setting=False,
361
- # video_setting=True,
362
- # style_attr=style_attr,
363
- # final_img=False,
364
- # files_found=['/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/csv/machine_results/Together_1.csv'],
365
- # line_attr=line_attr,
366
- # core_cnt=5)
367
- # test.create_distance_plot()
368
-
369
- # style_attr = {'width': 640, 'height': 480, 'line width': 6, 'font size': 8}
370
- # line_attr = {0: ['Termite_1_Head_1', 'Termite_1_Thorax_1', 'Dark-red']}
371
-
372
- # test = DistancePlotterSingleCore(config_path=r'/Users/simon/Desktop/envs/troubleshooting/Termites_5/project_folder/project_config.ini',
373
- # frame_setting=False,
374
- # video_setting=True,
375
- # style_attr=style_attr,
376
- # files_found=['/Users/simon/Desktop/envs/troubleshooting/Termites_5/project_folder/csv/outlier_corrected_movement_location/termites_1.csv'],
377
- # line_attr=line_attr)
378
- # test.create_distance_plot()
1
+ __author__ = "Simon Nilsson; sronilsson@gmail.com"
2
+ import functools
3
+ import multiprocessing
4
+ import os
5
+ import platform
6
+ from typing import Dict, List, Optional, Union
7
+
8
+ import cv2
9
+ import numpy as np
10
+ from numba import jit
11
+
12
+ from simba.mixins.config_reader import ConfigReader
13
+ from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin
14
+ from simba.mixins.plotting_mixin import PlottingMixin
15
+ from simba.utils.checks import (
16
+ check_all_file_names_are_represented_in_video_log,
17
+ check_file_exist_and_readable, check_instance, check_int, check_valid_lst)
18
+ from simba.utils.data import terminate_cpu_pool
19
+ from simba.utils.errors import (CountError, InvalidInputError,
20
+ NoSpecifiedOutputError)
21
+ from simba.utils.lookups import get_color_dict
22
+ from simba.utils.printing import SimbaTimer, stdout_success
23
+ from simba.utils.read_write import (concatenate_videos_in_folder,
24
+ find_core_cnt, get_fn_ext, read_df)
25
+
26
+
27
def distance_plotter_mp(
    frm_cnts: np.array,
    distances: np.ndarray,
    colors: List[str],
    video_setting: bool,
    frame_setting: bool,
    video_name: str,
    video_save_dir: str,
    frame_folder_dir: str,
    style_attr: dict,
    fps: int,
):
    """
    Multiprocessing worker that renders one batch of distance line-plot frames.

    :param np.ndarray frm_cnts: Frame indexes (rows of ``distances``) rendered by this worker.
    :param np.ndarray distances: 2D array. Column 0 holds the batch (group) index; the remaining columns hold one distance series each.
    :param List[str] colors: One color name per distance series.
    :param bool video_setting: If True, frames are written to ``<video_save_dir>/<group>.avi``.
    :param bool frame_setting: If True, each frame is saved as ``<frame_folder_dir>/<frame index>.png``.
    :param str video_name: Name of the video; used in progress messages only.
    :param str video_save_dir: Directory that receives the per-batch temporary video.
    :param str frame_folder_dir: Directory that receives the individual frame images.
    :param dict style_attr: Plot style. Keys read: ``width``, ``height``, ``line width``, ``font size``, ``y_max``, ``opacity``.
    :param int fps: Frame rate of the temporary video; also passed as ``x_lbl_divisor`` to the plotting call.
    :return: The group (batch) index handled by this worker, or None if ``frm_cnts`` is empty.
    """
    # np.array_split can hand a worker an empty batch (fewer frames than cores); there is nothing to render.
    if len(frm_cnts) == 0:
        return None
    group = int(distances[frm_cnts[0], 0])
    video_writer = None
    if video_setting:
        fourcc = cv2.VideoWriter_fourcc(*"DIVX")
        temp_video_save_path = os.path.join(video_save_dir, f"{group}.avi")
        video_writer = cv2.VideoWriter(
            temp_video_save_path,
            fourcc,
            fps,
            (style_attr["width"], style_attr["height"]),
        )
    try:
        for frm_cnt in frm_cnts:
            # Include the current frame (``frm_cnt + 1``); otherwise frame 0 plots no data
            # and every frame lags one sample behind.
            line_data = distances[: frm_cnt + 1, 1:]
            line_data = np.hsplit(line_data, line_data.shape[1])
            # NOTE(review): the x label says "frame count" although values are divided by fps;
            # the final summary image uses "time (s)" with the same divisor - confirm intended label.
            img = PlottingMixin.make_line_plot_plotly(
                data=line_data,
                colors=colors,
                width=style_attr["width"],
                height=style_attr["height"],
                line_width=style_attr["line width"],
                font_size=style_attr["font size"],
                title="Animal distances",
                y_lbl="distance (cm)",
                x_lbl="frame count",
                x_lbl_divisor=fps,
                y_max=style_attr["y_max"],
                line_opacity=style_attr["opacity"],
                save_path=None,
            ).astype(np.uint8)
            if video_writer is not None:
                # Only the first three channels are written to the video.
                video_writer.write(img[:, :, :3])
            if frame_setting:
                frm_name = os.path.join(frame_folder_dir, f"{frm_cnt}.png")
                cv2.imwrite(frm_name, img)
            print(
                f"Distance frame created: {frm_cnt} (of {distances.shape[0]}), Video: {video_name}, Processing core: {group}"
            )
    finally:
        # Release the writer even if rendering fails, so the partial .avi file handle is not leaked.
        if video_writer is not None:
            video_writer.release()
    return group
82
+
83
+
84
class DistancePlotterMultiCore(ConfigReader, PlottingMixin):
    """
    Visualize the distances between pose-estimated body-parts (e.g., two animals) through line
    charts. Results are saved as individual line charts, and/or a video of line charts.
    Uses multiprocessing.

    :param str config_path: path to SimBA project config file in Configparser format
    :param List[Union[str, os.PathLike]] data_paths: Files to visualize.
    :param bool frame_setting: If True, creates individual frames.
    :param bool video_setting: If True, creates videos.
    :param bool final_img: If True, creates a single .png representing the entire video.
    :param dict style_attr: Style attributes. Keys read: ``width``, ``height``, ``line width``, ``font size``, ``opacity`` and ``y_max``. When creating frames/videos, a ``y_max`` of ``-1`` is replaced by the maximum distance of each video. For the final image, ``y_max`` is passed through unchanged.
    :param List[List[str]] line_attr: One ``[body_part_1, body_part_2, color_name]`` entry per distance line. ``color_name`` must be a key of ``get_color_dict()``.
    :param Optional[int] core_cnt: Number of CPU cores to use. ``-1`` uses all available cores.
    :raises NoSpecifiedOutputError: If ``frame_setting``, ``video_setting`` and ``final_img`` are all False.
    :raises InvalidInputError: If ``core_cnt`` is 0, a color name is unknown, or a body-part is missing from a data file.

    .. note::
       `GitHub tutorial/documentation <https://github.com/sgoldenlab/simba/blob/master/docs/tutorial.md#step-11-visualizations>`__.

    .. image:: _static/img/DistancePlotterMultiCore.png
       :width: 600
       :align: center

    .. image:: _static/img/DistancePlotterMultiCore_1.gif
       :width: 600
       :align: center

    :example:
    >>> style_attr = {'width': 640, 'height': 480, 'line width': 6, 'font size': 8, 'opacity': 0.5, 'y_max': -1}
    >>> line_attr = [['Center_1', 'Center_2', 'Green'], ['Ear_left_2', 'Ear_left_1', 'Red']]
    >>> distance_plotter = DistancePlotterMultiCore(config_path=r'/tests_/project_folder/project_config.ini', frame_setting=False, video_setting=True, final_img=True, style_attr=style_attr, line_attr=line_attr, data_paths=['/test_/project_folder/csv/machine_results/Together_1.csv'], core_cnt=5)
    >>> distance_plotter.run()
    """

    def __init__(
        self,
        config_path: Union[str, os.PathLike],
        data_paths: List[Union[str, os.PathLike]],
        frame_setting: bool,
        video_setting: bool,
        final_img: bool,
        style_attr: Dict[str, int],
        line_attr: List[List[str]],
        core_cnt: Optional[int] = -1,
    ):
        if (not frame_setting) and (not video_setting) and (not final_img):
            raise NoSpecifiedOutputError(
                msg="Please choose to create frames and/or video distance plots",
                source=self.__class__.__name__,
            )
        check_int(
            name=f"{self.__class__.__name__} core_cnt",
            value=core_cnt,
            min_value=-1,
            max_value=find_core_cnt()[0],
        )
        # check_int accepts 0, but multiprocessing.Pool(0) fails later; fail fast instead.
        if core_cnt == 0:
            raise InvalidInputError(
                msg=f"core_cnt must be -1 (all cores) or a positive integer, got {core_cnt}",
                source=self.__class__.__name__,
            )
        if core_cnt == -1:
            core_cnt = find_core_cnt()[0]
        ConfigReader.__init__(self, config_path=config_path)
        PlottingMixin.__init__(self)
        check_instance(
            source=f"{self.__class__.__name__} line_attr",
            instance=line_attr,
            accepted_types=(list,),
        )
        for cnt, i in enumerate(line_attr):
            check_valid_lst(
                source=f"{self.__class__.__name__} line_attr {cnt}",
                data=i,
                valid_dtypes=(str,),
                exact_len=3,
            )
        check_valid_lst(data=data_paths, valid_dtypes=(str,), min_len=1)
        for data_path in data_paths:
            check_file_exist_and_readable(data_path)
        self.video_setting = video_setting
        self.frame_setting = frame_setting
        self.data_paths = data_paths
        self.style_attr = style_attr
        self.line_attr = line_attr
        self.final_img = final_img
        self.core_cnt = core_cnt
        if not os.path.exists(self.line_plot_dir):
            os.makedirs(self.line_plot_dir)
        self.color_names = get_color_dict()
        if platform.system() == "Darwin":
            # Workers are started with the "spawn" method on macOS.
            multiprocessing.set_start_method("spawn", force=True)

    @staticmethod
    @jit(nopython=True)
    def __insert_group_idx_column(data: np.array, group: int):
        # Prepend a column filled with ``group`` so each worker can recover its batch index from row data.
        group_col = np.full((data.shape[0], 1), group)
        return np.hstack((group_col, data))

    def run(self):
        """Create the requested distance plots (frames, video and/or final image) for every file in ``data_paths``."""
        print(f"Processing {len(self.data_paths)} video(s)...")
        check_all_file_names_are_represented_in_video_log(
            video_info_df=self.video_info_df, data_paths=self.data_paths
        )
        for file_path in self.data_paths:
            video_timer = SimbaTimer(start=True)
            _, video_name, _ = get_fn_ext(file_path)
            data_df = read_df(file_path, self.file_type)
            try:
                data_df.columns = self.bp_headers
            except ValueError:
                raise CountError(
                    msg=f"SimBA expects {self.bp_headers} columns but found {len(data_df.columns)} columns in {file_path}",
                    source=self.__class__.__name__,
                )
            self.video_info, px_per_mm, fps = self.read_video_info(video_name=video_name)
            self.save_video_folder = os.path.join(self.line_plot_dir, video_name)
            self.temp_folder = os.path.join(self.line_plot_dir, video_name, "temp")
            self.save_frame_folder_dir = os.path.join(self.line_plot_dir, video_name)
            distances, colors = [], []
            for i in self.line_attr:
                if i[2] not in self.color_names:
                    raise InvalidInputError(
                        msg=f"{i[2]} is not a valid color. Options: {list(self.color_names.keys())}.",
                        source=self.__class__.__name__,
                    )
                colors.append(i[2])
                bp_1, bp_2 = [f"{i[0]}_x", f"{i[0]}_y"], [f"{i[1]}_x", f"{i[1]}_y"]
                if len(set(bp_1) - set(data_df.columns)) > 0:
                    raise InvalidInputError(
                        msg=f"Could not find fields {bp_1} in {file_path}",
                        source=self.__class__.__name__,
                    )
                if len(set(bp_2) - set(data_df.columns)) > 0:
                    raise InvalidInputError(
                        msg=f"Could not find fields {bp_2} in {file_path}",
                        source=self.__class__.__name__,
                    )
                distances.append(
                    FeatureExtractionMixin.framewise_euclidean_distance(
                        location_1=data_df[bp_1].values.astype(np.float64),
                        location_2=data_df[bp_2].values.astype(np.float64),
                        px_per_mm=np.float64(px_per_mm),
                        centimeter=True,
                    )
                )
            if self.frame_setting:
                if os.path.exists(self.save_frame_folder_dir):
                    self.remove_a_folder(self.save_frame_folder_dir)
                os.makedirs(self.save_frame_folder_dir)
            if self.video_setting:
                self.video_folder = os.path.join(self.line_plot_dir, video_name)
                if os.path.exists(self.temp_folder):
                    self.remove_a_folder(self.temp_folder)
                os.makedirs(self.temp_folder)
                self.save_video_path = os.path.join(self.line_plot_dir, f"{video_name}.mp4")

            if self.final_img:
                _ = PlottingMixin.make_line_plot(
                    data=distances,
                    colors=colors,
                    width=self.style_attr["width"],
                    height=self.style_attr["height"],
                    line_width=self.style_attr["line width"],
                    font_size=self.style_attr["font size"],
                    title="Animal distances",
                    y_lbl="distance (cm)",
                    x_lbl="time (s)",
                    x_lbl_divisor=fps,
                    y_max=self.style_attr["y_max"],
                    line_opacity=self.style_attr["opacity"],
                    save_path=os.path.join(self.line_plot_dir, f"{video_name}_final_distances.png"),
                )

            if self.video_setting or self.frame_setting:
                # Per-video copy: an auto y_max (-1) is resolved from this video's data only and is
                # not written back into self.style_attr, where it would leak into the next video.
                video_style_attr = dict(self.style_attr)
                if video_style_attr["y_max"] == -1:
                    video_style_attr["y_max"] = max([np.max(x) for x in distances])
                distances = np.stack(distances, axis=1)
                frm_range = np.array_split(np.arange(0, distances.shape[0]), self.core_cnt)
                # With fewer frames than cores, array_split yields empty trailing chunks; skip them.
                frm_range = [x for x in frm_range if x.size > 0]
                distances = np.array_split(distances, self.core_cnt)
                distances = [
                    self.__insert_group_idx_column(data=i, group=cnt)
                    for cnt, i in enumerate(distances)
                ]
                distances = np.concatenate(distances, axis=0)
                print(
                    f"Creating distance plots, multiprocessing, follow progress in terminal (chunksize: {self.multiprocess_chunksize}, cores: {self.core_cnt})"
                )
                with multiprocessing.Pool(self.core_cnt, maxtasksperchild=self.maxtasksperchild) as pool:
                    constants = functools.partial(
                        distance_plotter_mp,
                        distances=distances,
                        video_setting=self.video_setting,
                        frame_setting=self.frame_setting,
                        video_name=video_name,
                        video_save_dir=self.temp_folder,
                        frame_folder_dir=self.save_frame_folder_dir,
                        style_attr=video_style_attr,
                        colors=colors,
                        fps=fps,
                    )
                    for result in pool.map(constants, frm_range, chunksize=self.multiprocess_chunksize):
                        print(f"Frame batch core {result} complete...")
                    terminate_cpu_pool(pool=pool, force=False)
                if self.video_setting:
                    concatenate_videos_in_folder(
                        in_folder=self.temp_folder,
                        save_path=self.save_video_path,
                        video_format="avi",
                    )
            video_timer.stop_timer()
            stdout_success(
                msg=f"Distance visualizations created for {video_name} saved at {self.line_plot_dir}",
                elapsed_time=video_timer.elapsed_time_str,
            )
        self.timer.stop_timer()
        stdout_success(
            msg=f"Distance visualizations complete for {len(self.data_paths)} video(s)",
            elapsed_time=self.timer.elapsed_time_str,
        )
323
+
324
+
325
+ # style_attr = {'width': 640, 'height': 480, 'line width': 6, 'font size': 12, 'y_max': -1, 'opacity': 0.5}
326
+ # line_attr = [['Center_1', 'Center_2', 'Green'], ['Ear_left_2', 'Ear_right_2', 'Red']]
327
+ # test = DistancePlotterMultiCore(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
328
+ # frame_setting=True,
329
+ # video_setting=True,
330
+ # style_attr=style_attr,
331
+ # final_img=True,
332
+ # data_paths=['/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/csv/outlier_corrected_movement_location/Together_1.csv'],
333
+ # line_attr=line_attr,
334
+ # core_cnt=-1)
335
+ # test.run()
336
+
337
+
338
+ # style_attr = {'width': 640, 'height': 480, 'line width': 6, 'font size': 8, 'y_max': 'auto', 'opacity': 0.9}
339
+ # line_attr = {0: ['Center_1', 'Center_2', 'Green'], 1: ['Ear_left_2', 'Ear_left_1', 'Red']}
340
+ #
341
+ # test = DistancePlotterMultiCore(config_path=r'/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
342
+ # frame_setting=False,
343
+ # video_setting=True,
344
+ # style_attr=style_attr,
345
+ # final_img=True,
346
+ # files_found=['/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/csv/machine_results/Together_1.csv'],
347
+ # line_attr=line_attr,
348
+ # core_cnt=3)
349
+ # test.create_distance_plot()
350
+ # #
351
+ # style_attr = {'width': 640, 'height': 480, 'line width': 6, 'font size': 8}
352
+ # line_attr = {0: ['Termite_1_Head_1', 'Termite_1_Thorax_1', 'Dark-red']}
353
+ #
354
+
355
+
356
+ # style_attr = {'width': 640, 'height': 480, 'line width': 6, 'font size': 8, 'opacity': 0.5, 'y_max': 'auto'}
357
+ # line_attr = {0: ['Center_1', 'Center_2', 'Green'], 1: ['Ear_left_2', 'Ear_left_1', 'Red']}
358
+ #
359
+ # test = DistancePlotterMultiCore(config_path=r'/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
360
+ # frame_setting=False,
361
+ # video_setting=True,
362
+ # style_attr=style_attr,
363
+ # final_img=False,
364
+ # files_found=['/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/csv/machine_results/Together_1.csv'],
365
+ # line_attr=line_attr,
366
+ # core_cnt=5)
367
+ # test.create_distance_plot()
368
+
369
+ # style_attr = {'width': 640, 'height': 480, 'line width': 6, 'font size': 8}
370
+ # line_attr = {0: ['Termite_1_Head_1', 'Termite_1_Thorax_1', 'Dark-red']}
371
+
372
+ # test = DistancePlotterSingleCore(config_path=r'/Users/simon/Desktop/envs/troubleshooting/Termites_5/project_folder/project_config.ini',
373
+ # frame_setting=False,
374
+ # video_setting=True,
375
+ # style_attr=style_attr,
376
+ # files_found=['/Users/simon/Desktop/envs/troubleshooting/Termites_5/project_folder/csv/outlier_corrected_movement_location/termites_1.csv'],
377
+ # line_attr=line_attr)
378
+ # test.create_distance_plot()