simba-uw-tf-dev 4.6.2__py3-none-any.whl → 4.6.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of simba-uw-tf-dev might be problematic; see the package's registry page for more details.

Files changed (46):
  1. simba/assets/lookups/tooptips.json +6 -1
  2. simba/data_processors/agg_clf_counter_mp.py +52 -53
  3. simba/data_processors/cuda/image.py +3 -1
  4. simba/data_processors/cue_light_analyzer.py +5 -9
  5. simba/data_processors/kleinberg_calculator.py +57 -29
  6. simba/mixins/geometry_mixin.py +14 -28
  7. simba/mixins/image_mixin.py +10 -14
  8. simba/mixins/train_model_mixin.py +2 -2
  9. simba/plotting/ROI_feature_visualizer_mp.py +3 -5
  10. simba/plotting/clf_validator_mp.py +4 -5
  11. simba/plotting/cue_light_visualizer.py +6 -7
  12. simba/plotting/directing_animals_visualizer_mp.py +2 -3
  13. simba/plotting/distance_plotter_mp.py +378 -378
  14. simba/plotting/gantt_creator_mp.py +61 -31
  15. simba/plotting/geometry_plotter.py +270 -272
  16. simba/plotting/heat_mapper_clf_mp.py +2 -4
  17. simba/plotting/heat_mapper_location_mp.py +2 -2
  18. simba/plotting/light_dark_box_plotter.py +2 -2
  19. simba/plotting/path_plotter_mp.py +26 -29
  20. simba/plotting/plot_clf_results_mp.py +455 -454
  21. simba/plotting/pose_plotter_mp.py +28 -29
  22. simba/plotting/probability_plot_creator_mp.py +288 -288
  23. simba/plotting/roi_plotter_mp.py +31 -31
  24. simba/plotting/single_run_model_validation_video_mp.py +427 -427
  25. simba/plotting/spontaneous_alternation_plotter.py +2 -3
  26. simba/plotting/yolo_pose_track_visualizer.py +32 -27
  27. simba/plotting/yolo_pose_visualizer.py +35 -36
  28. simba/plotting/yolo_seg_visualizer.py +2 -3
  29. simba/roi_tools/roi_aggregate_stats_mp.py +4 -3
  30. simba/roi_tools/roi_clf_calculator_mp.py +3 -3
  31. simba/sandbox/get_cpu_pool.py +5 -0
  32. simba/ui/pop_ups/kleinberg_pop_up.py +39 -41
  33. simba/ui/tkinter_functions.py +3 -0
  34. simba/utils/data.py +89 -12
  35. simba/utils/enums.py +1 -0
  36. simba/utils/printing.py +124 -124
  37. simba/utils/read_write.py +3730 -3721
  38. simba/video_processors/egocentric_video_rotator.py +2 -4
  39. simba/video_processors/video_processing.py +19 -8
  40. simba/video_processors/videos_to_frames.py +1 -1
  41. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.4.dist-info}/METADATA +1 -1
  42. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.4.dist-info}/RECORD +46 -45
  43. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.4.dist-info}/LICENSE +0 -0
  44. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.4.dist-info}/WHEEL +0 -0
  45. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.4.dist-info}/entry_points.txt +0 -0
  46. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.4.dist-info}/top_level.txt +0 -0
@@ -1,32 +1,37 @@
1
1
  __author__ = "Simon Nilsson; sronilsson@gmail.com"
2
2
 
3
- import time
4
3
  import warnings
5
4
 
6
5
  warnings.simplefilter(action="ignore", category=FutureWarning)
7
6
  import functools
7
+ import gc
8
8
  import multiprocessing
9
9
  import os
10
10
  import platform
11
+ from copy import deepcopy
11
12
  from typing import List, Optional, Union
12
13
 
13
14
  import cv2
15
+ import matplotlib
14
16
  import numpy as np
15
17
  import pandas as pd
16
18
 
19
+ matplotlib.use('Agg')
20
+
17
21
  from simba.mixins.config_reader import ConfigReader
18
22
  from simba.mixins.plotting_mixin import PlottingMixin
19
23
  from simba.utils.checks import (
20
24
  check_all_file_names_are_represented_in_video_log,
21
25
  check_file_exist_and_readable, check_int, check_str,
22
26
  check_that_column_exist, check_valid_boolean, check_valid_lst)
23
- from simba.utils.data import create_color_palette, detect_bouts
27
+ from simba.utils.data import (create_color_palette, detect_bouts, get_cpu_pool,
28
+ terminate_cpu_pool)
24
29
  from simba.utils.enums import Formats, Options
25
30
  from simba.utils.errors import NoSpecifiedOutputError
26
31
  from simba.utils.printing import SimbaTimer, stdout_success
27
32
  from simba.utils.read_write import (concatenate_videos_in_folder,
28
33
  create_directory, find_core_cnt,
29
- get_fn_ext, read_df)
34
+ get_current_time, get_fn_ext, read_df)
30
35
 
31
36
  HEIGHT = "height"
32
37
  WIDTH = "width"
@@ -78,11 +83,17 @@ def gantt_creator_mp(data: np.array,
78
83
  cv2.imwrite(frame_save_path, plot)
79
84
  if video_setting:
80
85
  video_writer.write(plot)
81
- print(f"Gantt frame created: {current_frm + 1}, Video: {video_name}, Processing core: {batch_id + 1}")
86
+ # Clear memory after each frame
87
+ del plot
88
+ if current_frm % 100 == 0: # Periodic garbage collection to prevent memory buildup
89
+ gc.collect()
90
+ print(f"[{get_current_time()}] Gantt frame created: {current_frm + 1}, Video: {video_name}, Processing core: {batch_id + 1}")
82
91
 
83
92
  if video_setting:
84
93
  video_writer.release()
94
+ del video_writer
85
95
 
96
+ gc.collect()
86
97
  return batch_id
87
98
 
88
99
 
@@ -120,7 +131,7 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
120
131
 
121
132
  def __init__(self,
122
133
  config_path: Union[str, os.PathLike],
123
- data_paths: List[Union[str, os.PathLike]],
134
+ data_paths: Optional[Union[Union[str, os.PathLike], List[Union[str, os.PathLike]]]] = None,
124
135
  frame_setting: Optional[bool] = False,
125
136
  video_setting: Optional[bool] = False,
126
137
  last_frm_setting: Optional[bool] = True,
@@ -129,14 +140,13 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
129
140
  font_size: int = 8,
130
141
  font_rotation: int = 45,
131
142
  palette: str = 'Set1',
132
- core_cnt: Optional[int] = -1,
143
+ core_cnt: int = -1,
133
144
  hhmmss: bool = False):
134
145
 
135
146
  check_file_exist_and_readable(file_path=config_path)
136
147
  if (not frame_setting) and (not video_setting) and (not last_frm_setting):
137
148
  raise NoSpecifiedOutputError(msg="SIMBA ERROR: Please select gantt videos, frames, and/or last frame.", source=self.__class__.__name__)
138
149
  check_file_exist_and_readable(file_path=config_path)
139
- check_valid_lst(data=data_paths, source=f'{self.__class__.__name__} data_paths', valid_dtypes=(str,), min_len=1)
140
150
  check_int(value=width, min_value=1, name=f'{self.__class__.__name__} width')
141
151
  check_int(value=height, min_value=1, name=f'{self.__class__.__name__} height')
142
152
  check_int(value=font_size, min_value=1, name=f'{self.__class__.__name__} font_size')
@@ -144,11 +154,18 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
144
154
  check_valid_boolean(value=hhmmss, source=f'{self.__class__.__name__} hhmmss', raise_error=False)
145
155
  palettes = Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value
146
156
  check_str(name=f'{self.__class__.__name__} palette', value=palette, options=palettes)
147
- for file_path in data_paths: check_file_exist_and_readable(file_path=file_path)
148
157
  check_int(name=f"{self.__class__.__name__} core_cnt",value=core_cnt, min_value=-1, unaccepted_vals=[0], max_value=find_core_cnt()[0])
149
158
  self.core_cnt = find_core_cnt()[0] if core_cnt == -1 or core_cnt > find_core_cnt()[0] else core_cnt
150
159
  self.width, self.height, self.font_size, self.font_rotation, self.hhmmss = width, height, font_size, font_rotation, hhmmss
151
160
  ConfigReader.__init__(self, config_path=config_path, create_logger=False)
161
+ if isinstance(data_paths, list):
162
+ check_valid_lst(data=data_paths, source=f'{self.__class__.__name__} data_paths', valid_dtypes=(str,), min_len=1)
163
+ elif isinstance(data_paths, str):
164
+ check_file_exist_and_readable(file_path=data_paths)
165
+ data_paths = [data_paths]
166
+ else:
167
+ data_paths = deepcopy(self.machine_results_paths)
168
+ for file_path in data_paths: check_file_exist_and_readable(file_path=file_path)
152
169
  PlottingMixin.__init__(self)
153
170
  self.clr_lst = create_color_palette(pallete_name=palette, increments=len(self.body_parts_lst) + 1, as_int=True, as_rgb_ratio=True)
154
171
  self.frame_setting, self.video_setting, self.data_paths, self.last_frm_setting = frame_setting, video_setting,data_paths, last_frm_setting
@@ -159,6 +176,10 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
159
176
 
160
177
  def run(self):
161
178
  check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.data_paths)
179
+ if self.video_setting or self.frame_setting:
180
+ self.pool = get_cpu_pool(core_cnt=self.core_cnt, maxtasksperchild=self.maxtasksperchild, verbose=True, source=self.__class__.__name__)
181
+ else:
182
+ self.pool = None
162
183
  for file_cnt, file_path in enumerate(self.data_paths):
163
184
  video_timer = SimbaTimer(start=True)
164
185
  _, self.video_name, _ = get_fn_ext(file_path)
@@ -192,33 +213,31 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
192
213
  if self.video_setting or self.frame_setting:
193
214
  frame_data = np.array_split(list(range(0, len(self.data_df))), self.core_cnt)
194
215
  frame_data = [(i, x) for i, x in enumerate(frame_data)]
195
- print(f"Creating gantt, multiprocessing (chunksize: {(self.multiprocess_chunksize)}, cores: {self.core_cnt})...")
196
- with multiprocessing.Pool(self.core_cnt, maxtasksperchild=self.maxtasksperchild) as pool:
197
- constants = functools.partial(gantt_creator_mp,
198
- video_setting=self.video_setting,
199
- frame_setting=self.frame_setting,
200
- video_save_dir=self.temp_folder,
201
- frame_folder_dir=self.save_frame_folder_dir,
202
- bouts_df=self.bouts_df,
203
- clf_names=self.clf_names,
204
- fps=self.fps,
205
- width=self.width,
206
- height=self.height,
207
- font_size=self.font_size,
208
- font_rotation=self.font_rotation,
209
- video_name=self.video_name,
210
- palette=self.clr_lst,
211
- hhmmss=self.hhmmss)
212
- for cnt, result in enumerate(pool.imap(constants, frame_data, chunksize=self.multiprocess_chunksize)):
213
- print(f'Batch {result+1/self.core_cnt} complete...')
214
- pool.terminate()
215
- pool.join()
216
+ print(f"[{get_current_time()}] Creating gantt, multiprocessing (chunksize: {(self.multiprocess_chunksize)}, cores: {self.core_cnt})...")
217
+ constants = functools.partial(gantt_creator_mp,
218
+ video_setting=self.video_setting,
219
+ frame_setting=self.frame_setting,
220
+ video_save_dir=self.temp_folder,
221
+ frame_folder_dir=self.save_frame_folder_dir,
222
+ bouts_df=self.bouts_df,
223
+ clf_names=self.clf_names,
224
+ fps=self.fps,
225
+ width=self.width,
226
+ height=self.height,
227
+ font_size=self.font_size,
228
+ font_rotation=self.font_rotation,
229
+ video_name=self.video_name,
230
+ palette=self.clr_lst,
231
+ hhmmss=self.hhmmss)
232
+ for cnt, result in enumerate(self.pool.imap(constants, frame_data, chunksize=self.multiprocess_chunksize)):
233
+ print(f'[{get_current_time()}] Batch {result+1}/{self.core_cnt} complete...')
216
234
  if self.video_setting:
217
- print(f"Joining {self.video_name} multiprocessed video...")
235
+ print(f"[{get_current_time()}] Joining {self.video_name} multiprocessed video...")
218
236
  concatenate_videos_in_folder(in_folder=self.temp_folder, save_path=self.save_video_path)
219
237
  video_timer.stop_timer()
220
238
  print(f"Gantt video {self.video_name} complete (elapsed time: {video_timer.elapsed_time_str}s) ...")
221
239
 
240
+ terminate_cpu_pool(pool=self.pool, force=False, source=self.__class__.__name__)
222
241
  self.timer.stop_timer()
223
242
  stdout_success(msg=f"Gantt visualizations for {len(self.data_paths)} videos created in {self.gantt_plot_dir} directory", elapsed_time=self.timer.elapsed_time_str)
224
243
 
@@ -235,7 +254,18 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
235
254
  # font_rotation= 45)
236
255
  # test.run()
237
256
 
238
-
257
+ if __name__ == "__main__":
258
+ test = GanttCreatorMultiprocess(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini",
259
+ frame_setting=False,
260
+ video_setting=True,
261
+ data_paths=r"D:\troubleshooting\maplight_ri\project_folder\csv\machine_results\Trial_1_C24_D1_1.csv",
262
+ last_frm_setting=False,
263
+ width=640,
264
+ height= 480,
265
+ font_size=10,
266
+ font_rotation= 45,
267
+ core_cnt=16)
268
+ test.run()
239
269
 
240
270
 
241
271
  # test = GanttCreatorMultiprocess(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini',