simba-uw-tf-dev 4.5.8__py3-none-any.whl → 4.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98) hide show
  1. simba/SimBA.py +2 -2
  2. simba/assets/.recent_projects.txt +1 -0
  3. simba/assets/icons/frames_2.png +0 -0
  4. simba/assets/lookups/tooptips.json +15 -1
  5. simba/data_processors/agg_clf_counter_mp.py +52 -53
  6. simba/data_processors/blob_location_computer.py +1 -1
  7. simba/data_processors/circling_detector.py +30 -13
  8. simba/data_processors/cuda/geometry.py +45 -27
  9. simba/data_processors/cuda/image.py +1648 -1598
  10. simba/data_processors/cuda/statistics.py +72 -26
  11. simba/data_processors/cuda/timeseries.py +1 -1
  12. simba/data_processors/cue_light_analyzer.py +5 -9
  13. simba/data_processors/egocentric_aligner.py +25 -7
  14. simba/data_processors/freezing_detector.py +55 -47
  15. simba/data_processors/kleinberg_calculator.py +61 -29
  16. simba/feature_extractors/feature_subsets.py +14 -7
  17. simba/feature_extractors/mitra_feature_extractor.py +2 -2
  18. simba/feature_extractors/straub_tail_analyzer.py +4 -6
  19. simba/labelling/standard_labeller.py +1 -1
  20. simba/mixins/config_reader.py +5 -2
  21. simba/mixins/geometry_mixin.py +22 -36
  22. simba/mixins/image_mixin.py +24 -28
  23. simba/mixins/plotting_mixin.py +28 -10
  24. simba/mixins/statistics_mixin.py +48 -11
  25. simba/mixins/timeseries_features_mixin.py +1 -1
  26. simba/mixins/train_model_mixin.py +67 -29
  27. simba/model/inference_batch.py +1 -1
  28. simba/model/yolo_seg_inference.py +3 -3
  29. simba/outlier_tools/skip_outlier_correction.py +1 -1
  30. simba/plotting/ROI_feature_visualizer_mp.py +3 -5
  31. simba/plotting/clf_validator_mp.py +4 -5
  32. simba/plotting/cue_light_visualizer.py +6 -7
  33. simba/plotting/directing_animals_visualizer_mp.py +2 -3
  34. simba/plotting/distance_plotter_mp.py +378 -378
  35. simba/plotting/frame_mergerer_ffmpeg.py +137 -196
  36. simba/plotting/gantt_creator.py +29 -10
  37. simba/plotting/gantt_creator_mp.py +96 -33
  38. simba/plotting/geometry_plotter.py +270 -272
  39. simba/plotting/heat_mapper_clf_mp.py +4 -6
  40. simba/plotting/heat_mapper_location_mp.py +2 -2
  41. simba/plotting/light_dark_box_plotter.py +2 -2
  42. simba/plotting/path_plotter_mp.py +26 -29
  43. simba/plotting/plot_clf_results_mp.py +455 -454
  44. simba/plotting/pose_plotter_mp.py +28 -29
  45. simba/plotting/probability_plot_creator_mp.py +288 -288
  46. simba/plotting/roi_plotter_mp.py +31 -31
  47. simba/plotting/single_run_model_validation_video_mp.py +427 -427
  48. simba/plotting/spontaneous_alternation_plotter.py +2 -3
  49. simba/plotting/yolo_pose_track_visualizer.py +32 -27
  50. simba/plotting/yolo_pose_visualizer.py +35 -36
  51. simba/plotting/yolo_seg_visualizer.py +2 -3
  52. simba/pose_importers/simba_blob_importer.py +3 -3
  53. simba/roi_tools/roi_aggregate_stats_mp.py +5 -4
  54. simba/roi_tools/roi_clf_calculator_mp.py +4 -4
  55. simba/sandbox/analyze_runtimes.py +30 -0
  56. simba/sandbox/cuda/egocentric_rotator.py +374 -0
  57. simba/sandbox/get_cpu_pool.py +5 -0
  58. simba/sandbox/proboscis_to_tip.py +28 -0
  59. simba/sandbox/test_directionality.py +47 -0
  60. simba/sandbox/test_nonstatic_directionality.py +27 -0
  61. simba/sandbox/test_pycharm_cuda.py +51 -0
  62. simba/sandbox/test_simba_install.py +41 -0
  63. simba/sandbox/test_static_directionality.py +26 -0
  64. simba/sandbox/test_static_directionality_2d.py +26 -0
  65. simba/sandbox/verify_env.py +42 -0
  66. simba/third_party_label_appenders/transform/coco_keypoints_to_yolo.py +3 -3
  67. simba/third_party_label_appenders/transform/coco_keypoints_to_yolo_bbox.py +2 -2
  68. simba/ui/pop_ups/clf_add_remove_print_pop_up.py +37 -30
  69. simba/ui/pop_ups/clf_plot_pop_up.py +2 -2
  70. simba/ui/pop_ups/egocentric_alignment_pop_up.py +20 -21
  71. simba/ui/pop_ups/fsttc_pop_up.py +27 -25
  72. simba/ui/pop_ups/gantt_pop_up.py +31 -6
  73. simba/ui/pop_ups/interpolate_pop_up.py +2 -4
  74. simba/ui/pop_ups/kleinberg_pop_up.py +39 -40
  75. simba/ui/pop_ups/multiple_videos_to_frames_popup.py +10 -11
  76. simba/ui/pop_ups/single_video_to_frames_popup.py +10 -10
  77. simba/ui/pop_ups/video_processing_pop_up.py +186 -174
  78. simba/ui/tkinter_functions.py +10 -1
  79. simba/utils/custom_feature_extractor.py +1 -1
  80. simba/utils/data.py +90 -14
  81. simba/utils/enums.py +1 -0
  82. simba/utils/errors.py +441 -440
  83. simba/utils/lookups.py +1203 -1203
  84. simba/utils/printing.py +124 -124
  85. simba/utils/read_write.py +3769 -3721
  86. simba/utils/yolo.py +10 -1
  87. simba/video_processors/blob_tracking_executor.py +2 -2
  88. simba/video_processors/clahe_ui.py +66 -23
  89. simba/video_processors/egocentric_video_rotator.py +46 -44
  90. simba/video_processors/multi_cropper.py +1 -1
  91. simba/video_processors/video_processing.py +5264 -5300
  92. simba/video_processors/videos_to_frames.py +43 -32
  93. {simba_uw_tf_dev-4.5.8.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/METADATA +4 -3
  94. {simba_uw_tf_dev-4.5.8.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/RECORD +98 -86
  95. {simba_uw_tf_dev-4.5.8.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/LICENSE +0 -0
  96. {simba_uw_tf_dev-4.5.8.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/WHEEL +0 -0
  97. {simba_uw_tf_dev-4.5.8.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/entry_points.txt +0 -0
  98. {simba_uw_tf_dev-4.5.8.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/top_level.txt +0 -0
@@ -1,16 +1,41 @@
1
1
  __author__ = "Simon Nilsson; sronilsson@gmail.com"
2
2
 
3
- import time
4
3
  import warnings
5
4
 
6
5
  warnings.simplefilter(action="ignore", category=FutureWarning)
7
6
  import functools
7
+ import gc
8
8
  import multiprocessing
9
9
  import os
10
10
  import platform
11
+ import sys
12
+ from copy import deepcopy
11
13
  from typing import List, Optional, Union
12
14
 
13
15
  import cv2
16
+
17
+ is_pycharm_ipython = True
18
+ try:
19
+ module_names = list(sys.modules.keys())
20
+ if any('pydev' in str(mod).lower() for mod in module_names):
21
+ is_pycharm_ipython = True
22
+ elif 'IPython' in sys.modules or 'ipython' in sys.modules:
23
+ is_pycharm_ipython = True
24
+ else:
25
+ is_pycharm_ipython = False
26
+ except Exception:
27
+ is_pycharm_ipython = True
28
+
29
+ if not is_pycharm_ipython:
30
+ if 'MPLBACKEND' not in os.environ: os.environ['MPLBACKEND'] = 'Agg'
31
+ try:
32
+ import matplotlib
33
+ matplotlib.use('Agg', force=False)
34
+ except (RecursionError, RuntimeError, ValueError, SystemError):
35
+ import matplotlib
36
+ else:
37
+ import matplotlib
38
+
14
39
  import numpy as np
15
40
  import pandas as pd
16
41
 
@@ -20,13 +45,15 @@ from simba.utils.checks import (
20
45
  check_all_file_names_are_represented_in_video_log,
21
46
  check_file_exist_and_readable, check_int, check_str,
22
47
  check_that_column_exist, check_valid_boolean, check_valid_lst)
23
- from simba.utils.data import create_color_palette, detect_bouts
48
+ from simba.utils.data import (create_color_palette, detect_bouts, get_cpu_pool,
49
+ terminate_cpu_pool)
24
50
  from simba.utils.enums import Formats, Options
25
51
  from simba.utils.errors import NoSpecifiedOutputError
52
+ from simba.utils.lookups import get_fonts
26
53
  from simba.utils.printing import SimbaTimer, stdout_success
27
54
  from simba.utils.read_write import (concatenate_videos_in_folder,
28
55
  create_directory, find_core_cnt,
29
- get_fn_ext, read_df)
56
+ get_current_time, get_fn_ext, read_df)
30
57
 
31
58
  HEIGHT = "height"
32
59
  WIDTH = "width"
@@ -46,6 +73,7 @@ def gantt_creator_mp(data: np.array,
46
73
  width: int,
47
74
  height: int,
48
75
  font_size: int,
76
+ font: str,
49
77
  font_rotation: int,
50
78
  palette: np.ndarray,
51
79
  hhmmss: bool):
@@ -67,6 +95,7 @@ def gantt_creator_mp(data: np.array,
67
95
  width=width,
68
96
  height=height,
69
97
  font_size=font_size,
98
+ font=font,
70
99
  font_rotation=font_rotation,
71
100
  video_name=video_name,
72
101
  save_path=None,
@@ -78,11 +107,17 @@ def gantt_creator_mp(data: np.array,
78
107
  cv2.imwrite(frame_save_path, plot)
79
108
  if video_setting:
80
109
  video_writer.write(plot)
81
- print(f"Gantt frame created: {current_frm + 1}, Video: {video_name}, Processing core: {batch_id + 1}")
110
+ # Clear memory after each frame
111
+ del plot
112
+ if current_frm % 100 == 0: # Periodic garbage collection to prevent memory buildup
113
+ gc.collect()
114
+ print(f"[{get_current_time()}] Gantt frame created: {current_frm + 1}, Video: {video_name}, Processing core: {batch_id + 1}")
82
115
 
83
116
  if video_setting:
84
117
  video_writer.release()
118
+ del video_writer
85
119
 
120
+ gc.collect()
86
121
  return batch_id
87
122
 
88
123
 
@@ -120,7 +155,7 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
120
155
 
121
156
  def __init__(self,
122
157
  config_path: Union[str, os.PathLike],
123
- data_paths: List[Union[str, os.PathLike]],
158
+ data_paths: Optional[Union[Union[str, os.PathLike], List[Union[str, os.PathLike]]]] = None,
124
159
  frame_setting: Optional[bool] = False,
125
160
  video_setting: Optional[bool] = False,
126
161
  last_frm_setting: Optional[bool] = True,
@@ -128,15 +163,16 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
128
163
  height: int = 480,
129
164
  font_size: int = 8,
130
165
  font_rotation: int = 45,
166
+ font: Optional[str] = None,
131
167
  palette: str = 'Set1',
132
- core_cnt: Optional[int] = -1,
133
- hhmmss: bool = False):
168
+ core_cnt: int = -1,
169
+ hhmmss: bool = False,
170
+ clf_names: Optional[List[str]] = None):
134
171
 
135
172
  check_file_exist_and_readable(file_path=config_path)
136
173
  if (not frame_setting) and (not video_setting) and (not last_frm_setting):
137
174
  raise NoSpecifiedOutputError(msg="SIMBA ERROR: Please select gantt videos, frames, and/or last frame.", source=self.__class__.__name__)
138
175
  check_file_exist_and_readable(file_path=config_path)
139
- check_valid_lst(data=data_paths, source=f'{self.__class__.__name__} data_paths', valid_dtypes=(str,), min_len=1)
140
176
  check_int(value=width, min_value=1, name=f'{self.__class__.__name__} width')
141
177
  check_int(value=height, min_value=1, name=f'{self.__class__.__name__} height')
142
178
  check_int(value=font_size, min_value=1, name=f'{self.__class__.__name__} font_size')
@@ -144,14 +180,26 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
144
180
  check_valid_boolean(value=hhmmss, source=f'{self.__class__.__name__} hhmmss', raise_error=False)
145
181
  palettes = Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value
146
182
  check_str(name=f'{self.__class__.__name__} palette', value=palette, options=palettes)
147
- for file_path in data_paths: check_file_exist_and_readable(file_path=file_path)
148
183
  check_int(name=f"{self.__class__.__name__} core_cnt",value=core_cnt, min_value=-1, unaccepted_vals=[0], max_value=find_core_cnt()[0])
149
184
  self.core_cnt = find_core_cnt()[0] if core_cnt == -1 or core_cnt > find_core_cnt()[0] else core_cnt
150
185
  self.width, self.height, self.font_size, self.font_rotation, self.hhmmss = width, height, font_size, font_rotation, hhmmss
186
+ if font is not None:
187
+ check_str(name=f'{self.__class__.__name__} font', value=font, options=list(get_fonts().keys()), raise_error=True)
151
188
  ConfigReader.__init__(self, config_path=config_path, create_logger=False)
189
+ if isinstance(data_paths, list):
190
+ check_valid_lst(data=data_paths, source=f'{self.__class__.__name__} data_paths', valid_dtypes=(str,), min_len=1)
191
+ elif isinstance(data_paths, str):
192
+ check_file_exist_and_readable(file_path=data_paths)
193
+ data_paths = [data_paths]
194
+ else:
195
+ data_paths = deepcopy(self.machine_results_paths)
196
+ for file_path in data_paths: check_file_exist_and_readable(file_path=file_path)
197
+ if clf_names is not None:
198
+ check_valid_lst(data=clf_names, source=f'{self.__class__.__name__} clf_names', valid_dtypes=(str,), valid_values=self.clf_names, min_len=1, raise_error=True)
199
+ self.clf_names = clf_names
152
200
  PlottingMixin.__init__(self)
153
201
  self.clr_lst = create_color_palette(pallete_name=palette, increments=len(self.body_parts_lst) + 1, as_int=True, as_rgb_ratio=True)
154
- self.frame_setting, self.video_setting, self.data_paths, self.last_frm_setting = frame_setting, video_setting,data_paths, last_frm_setting
202
+ self.frame_setting, self.video_setting, self.data_paths, self.last_frm_setting, self.font = frame_setting, video_setting,data_paths, last_frm_setting, font
155
203
  if not os.path.exists(self.gantt_plot_dir): os.makedirs(self.gantt_plot_dir)
156
204
  if platform.system() == "Darwin":
157
205
  multiprocessing.set_start_method("spawn", force=True)
@@ -159,6 +207,10 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
159
207
 
160
208
  def run(self):
161
209
  check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.data_paths)
210
+ if self.video_setting or self.frame_setting:
211
+ self.pool = get_cpu_pool(core_cnt=self.core_cnt, maxtasksperchild=self.maxtasksperchild, verbose=True, source=self.__class__.__name__)
212
+ else:
213
+ self.pool = None
162
214
  for file_cnt, file_path in enumerate(self.data_paths):
163
215
  video_timer = SimbaTimer(start=True)
164
216
  _, self.video_name, _ = get_fn_ext(file_path)
@@ -185,6 +237,7 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
185
237
  font_size=self.font_size,
186
238
  font_rotation=self.font_rotation,
187
239
  video_name=self.video_name,
240
+ font=self.font,
188
241
  save_path=os.path.join(self.gantt_plot_dir, f"{self.video_name}_final_image.png"),
189
242
  palette=self.clr_lst,
190
243
  hhmmss=self.hhmmss)
@@ -192,33 +245,32 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
192
245
  if self.video_setting or self.frame_setting:
193
246
  frame_data = np.array_split(list(range(0, len(self.data_df))), self.core_cnt)
194
247
  frame_data = [(i, x) for i, x in enumerate(frame_data)]
195
- print(f"Creating gantt, multiprocessing (chunksize: {(self.multiprocess_chunksize)}, cores: {self.core_cnt})...")
196
- with multiprocessing.Pool(self.core_cnt, maxtasksperchild=self.maxtasksperchild) as pool:
197
- constants = functools.partial(gantt_creator_mp,
198
- video_setting=self.video_setting,
199
- frame_setting=self.frame_setting,
200
- video_save_dir=self.temp_folder,
201
- frame_folder_dir=self.save_frame_folder_dir,
202
- bouts_df=self.bouts_df,
203
- clf_names=self.clf_names,
204
- fps=self.fps,
205
- width=self.width,
206
- height=self.height,
207
- font_size=self.font_size,
208
- font_rotation=self.font_rotation,
209
- video_name=self.video_name,
210
- palette=self.clr_lst,
211
- hhmmss=self.hhmmss)
212
- for cnt, result in enumerate(pool.imap(constants, frame_data, chunksize=self.multiprocess_chunksize)):
213
- print(f'Batch {result+1/self.core_cnt} complete...')
214
- pool.terminate()
215
- pool.join()
248
+ print(f"[{get_current_time()}] Creating gantt, multiprocessing (chunksize: {(self.multiprocess_chunksize)}, cores: {self.core_cnt})...")
249
+ constants = functools.partial(gantt_creator_mp,
250
+ video_setting=self.video_setting,
251
+ frame_setting=self.frame_setting,
252
+ video_save_dir=self.temp_folder,
253
+ frame_folder_dir=self.save_frame_folder_dir,
254
+ bouts_df=self.bouts_df,
255
+ clf_names=self.clf_names,
256
+ fps=self.fps,
257
+ width=self.width,
258
+ height=self.height,
259
+ font_size=self.font_size,
260
+ font=self.font,
261
+ font_rotation=self.font_rotation,
262
+ video_name=self.video_name,
263
+ palette=self.clr_lst,
264
+ hhmmss=self.hhmmss)
265
+ for cnt, result in enumerate(self.pool.imap(constants, frame_data, chunksize=self.multiprocess_chunksize)):
266
+ print(f'[{get_current_time()}] Batch {result+1}/{self.core_cnt} complete...')
216
267
  if self.video_setting:
217
- print(f"Joining {self.video_name} multiprocessed video...")
268
+ print(f"[{get_current_time()}] Joining {self.video_name} multiprocessed video...")
218
269
  concatenate_videos_in_folder(in_folder=self.temp_folder, save_path=self.save_video_path)
219
270
  video_timer.stop_timer()
220
271
  print(f"Gantt video {self.video_name} complete (elapsed time: {video_timer.elapsed_time_str}s) ...")
221
272
 
273
+ terminate_cpu_pool(pool=self.pool, force=False, source=self.__class__.__name__)
222
274
  self.timer.stop_timer()
223
275
  stdout_success(msg=f"Gantt visualizations for {len(self.data_paths)} videos created in {self.gantt_plot_dir} directory", elapsed_time=self.timer.elapsed_time_str)
224
276
 
@@ -235,7 +287,18 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
235
287
  # font_rotation= 45)
236
288
  # test.run()
237
289
 
238
-
290
+ # if __name__ == "__main__":
291
+ # test = GanttCreatorMultiprocess(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini",
292
+ # frame_setting=False,
293
+ # video_setting=True,
294
+ # data_paths=r"D:\troubleshooting\maplight_ri\project_folder\csv\machine_results\Trial_1_C24_D1_1.csv",
295
+ # last_frm_setting=False,
296
+ # width=640,
297
+ # height= 480,
298
+ # font_size=10,
299
+ # font_rotation= 45,
300
+ # core_cnt=16)
301
+ # test.run()
239
302
 
240
303
 
241
304
  # test = GanttCreatorMultiprocess(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini',