simba-uw-tf-dev 4.6.2__py3-none-any.whl → 4.6.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of simba-uw-tf-dev might be problematic. Click here for more details.

Files changed (46) hide show
  1. simba/assets/lookups/tooptips.json +6 -1
  2. simba/data_processors/agg_clf_counter_mp.py +52 -53
  3. simba/data_processors/cuda/image.py +3 -1
  4. simba/data_processors/cue_light_analyzer.py +5 -9
  5. simba/data_processors/kleinberg_calculator.py +57 -29
  6. simba/mixins/geometry_mixin.py +14 -28
  7. simba/mixins/image_mixin.py +10 -14
  8. simba/mixins/train_model_mixin.py +2 -2
  9. simba/plotting/ROI_feature_visualizer_mp.py +3 -5
  10. simba/plotting/clf_validator_mp.py +4 -5
  11. simba/plotting/cue_light_visualizer.py +6 -7
  12. simba/plotting/directing_animals_visualizer_mp.py +2 -3
  13. simba/plotting/distance_plotter_mp.py +378 -378
  14. simba/plotting/gantt_creator_mp.py +61 -31
  15. simba/plotting/geometry_plotter.py +270 -272
  16. simba/plotting/heat_mapper_clf_mp.py +2 -4
  17. simba/plotting/heat_mapper_location_mp.py +2 -2
  18. simba/plotting/light_dark_box_plotter.py +2 -2
  19. simba/plotting/path_plotter_mp.py +26 -29
  20. simba/plotting/plot_clf_results_mp.py +455 -454
  21. simba/plotting/pose_plotter_mp.py +28 -29
  22. simba/plotting/probability_plot_creator_mp.py +288 -288
  23. simba/plotting/roi_plotter_mp.py +31 -31
  24. simba/plotting/single_run_model_validation_video_mp.py +427 -427
  25. simba/plotting/spontaneous_alternation_plotter.py +2 -3
  26. simba/plotting/yolo_pose_track_visualizer.py +32 -27
  27. simba/plotting/yolo_pose_visualizer.py +35 -36
  28. simba/plotting/yolo_seg_visualizer.py +2 -3
  29. simba/roi_tools/roi_aggregate_stats_mp.py +4 -3
  30. simba/roi_tools/roi_clf_calculator_mp.py +3 -3
  31. simba/sandbox/get_cpu_pool.py +5 -0
  32. simba/ui/pop_ups/kleinberg_pop_up.py +39 -41
  33. simba/ui/tkinter_functions.py +3 -0
  34. simba/utils/data.py +89 -12
  35. simba/utils/enums.py +1 -0
  36. simba/utils/printing.py +124 -124
  37. simba/utils/read_write.py +3730 -3721
  38. simba/video_processors/egocentric_video_rotator.py +2 -4
  39. simba/video_processors/video_processing.py +19 -8
  40. simba/video_processors/videos_to_frames.py +1 -1
  41. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.4.dist-info}/METADATA +1 -1
  42. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.4.dist-info}/RECORD +46 -45
  43. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.4.dist-info}/LICENSE +0 -0
  44. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.4.dist-info}/WHEEL +0 -0
  45. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.4.dist-info}/entry_points.txt +0 -0
  46. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.4.dist-info}/top_level.txt +0 -0
@@ -38,5 +38,10 @@
38
38
  "EGOCENTRIC_ANCHOR": "This body-part will be placed in the center of the video",
39
39
  "EGOCENTRIC_DIRECTION_ANCHOR": "This body-part will be placed at N degrees relative to the anchor",
40
40
  "EGOCENTRIC_DIRECTION": "The anchor body-part will always be placed at these degrees relative to the center anchor",
41
- "CORE_COUNT": "Higher core counts speeds up processing but may require more RAM memory"
41
+ "CORE_COUNT": "Higher core counts speeds up processing but may require more RAM memory",
42
+ "KLEINBERG_SIGMA": "Higher values (e.g., 2-3) produce fewer but longer bursts; lower values (e.g., 1.1-1.5) detect more frequent, shorter bursts. Must be > 1.01",
43
+ "KLEINBERG_GAMMA": "Higher values (e.g., 0.5-1.0) reduce total burst count by making downward transitions costly; lower values (e.g., 0.1-0.3) allow more flexible state changes",
44
+ "KLEINBERG_HIERARCHY": "Hierarchy level to extract bursts from (0=lowest, higher=more selective).\n Level 0 captures all bursts; level 1-2 typically filters noise; level 3+ selects only the most prominent, sustained bursts.\nHigher levels yield fewer but more confident detections",
45
+ "KLEINBERG_HIERARCHY_SEARCH": "If True, searches for target hierarchy level within detected burst periods,\n falling back to lower levels if target not found. If False, extracts only bursts at the exact specified hierarchy level.\n Recommended when target hierarchy may be sparse.",
46
+ "KLEINBERG_SAVE_ORIGINALS": "If True, saves the original data in a new sub-directory of \nthe project_folder/csv/machine_results directory"
42
47
  }
@@ -20,7 +20,7 @@ from simba.utils.checks import (
20
20
  check_all_file_names_are_represented_in_video_log,
21
21
  check_file_exist_and_readable, check_if_dir_exists, check_int,
22
22
  check_valid_boolean, check_valid_dataframe, check_valid_lst)
23
- from simba.utils.data import detect_bouts
23
+ from simba.utils.data import detect_bouts, terminate_cpu_pool
24
24
  from simba.utils.enums import TagNames
25
25
  from simba.utils.errors import NoChoosenMeasurementError
26
26
  from simba.utils.printing import SimbaTimer, log_event, stdout_success
@@ -210,8 +210,7 @@ class AggregateClfCalculatorMultiprocess(ConfigReader):
210
210
  self.bouts_df_lst.append(batch_bouts_df_lst)
211
211
  print(f"Data batch core {batch_id+1} / {self.core_cnt} complete...")
212
212
  self.bouts_df_lst = [df for sub in self.bouts_df_lst for df in sub]
213
- pool.join()
214
- pool.terminate()
213
+ terminate_cpu_pool(pool=pool, force=False)
215
214
 
216
215
  def save(self) -> None:
217
216
  """
@@ -242,56 +241,56 @@ class AggregateClfCalculatorMultiprocess(ConfigReader):
242
241
  self.timer.stop_timer()
243
242
  stdout_success(msg=f"Data aggregate log saved at {self.save_path}", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
244
243
 
245
- if __name__ == "__main__" and not hasattr(sys, 'ps1'):
246
- parser = argparse.ArgumentParser(description='Compute aggregate descriptive statistics from classification data.')
247
- parser.add_argument('--config_path', type=str, required=True, help='Path to SimBA project config file')
248
- parser.add_argument('--classifiers', type=str, nargs='+', required=True, help='List of classifier names to analyze')
249
- parser.add_argument('--data_dir', type=str, default=None, help='Directory containing machine results CSV files (default: project machine_results directory)')
250
- parser.add_argument('--detailed_bout_data', action='store_true', help='Save detailed bout data for each bout')
251
- parser.add_argument('--transpose', action='store_true', help='Create output with one video per row')
252
- parser.add_argument('--no_first_occurrence', action='store_true', help='Disable first occurrence calculation')
253
- parser.add_argument('--no_event_count', action='store_true', help='Disable event count calculation')
254
- parser.add_argument('--no_total_event_duration', action='store_true', help='Disable total event duration calculation')
255
- parser.add_argument('--no_mean_event_duration', action='store_true', help='Disable mean event duration calculation')
256
- parser.add_argument('--no_median_event_duration', action='store_true', help='Disable median event duration calculation')
257
- parser.add_argument('--no_mean_interval_duration', action='store_true', help='Disable mean interval duration calculation')
258
- parser.add_argument('--no_median_interval_duration', action='store_true', help='Disable median interval duration calculation')
259
- parser.add_argument('--frame_count', action='store_true', help='Include frame count in output')
260
- parser.add_argument('--video_length', action='store_true', help='Include video length in output')
261
-
262
- args = parser.parse_args()
263
-
264
- clf_calculator = AggregateClfCalculatorMultiprocess(
265
- config_path=args.config_path,
266
- classifiers=args.classifiers,
267
- data_dir=args.data_dir,
268
- detailed_bout_data=args.detailed_bout_data,
269
- transpose=args.transpose,
270
- first_occurrence=not args.no_first_occurrence,
271
- event_count=not args.no_event_count,
272
- total_event_duration=not args.no_total_event_duration,
273
- mean_event_duration=not args.no_mean_event_duration,
274
- median_event_duration=not args.no_median_event_duration,
275
- mean_interval_duration=not args.no_mean_interval_duration,
276
- median_interval_duration=not args.no_median_interval_duration,
277
- frame_count=args.frame_count,
278
- video_length=args.video_length
279
- )
280
- clf_calculator.run()
281
- clf_calculator.save()
282
-
283
- # if __name__ == "__main__":
284
- # test = AggregateClfCalculatorMultiprocess(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini",
285
- # classifiers=['attack'],
286
- # transpose=True,
287
- # mean_event_duration = True,
288
- # median_event_duration = True,
289
- # mean_interval_duration = True,
290
- # median_interval_duration = True,
291
- # detailed_bout_data=True,
292
- # core_cnt=12)
293
- # test.run()
294
- # test.save()
244
+ # if __name__ == "__main__" and not hasattr(sys, 'ps1'):
245
+ # parser = argparse.ArgumentParser(description='Compute aggregate descriptive statistics from classification data.')
246
+ # parser.add_argument('--config_path', type=str, required=True, help='Path to SimBA project config file')
247
+ # parser.add_argument('--classifiers', type=str, nargs='+', required=True, help='List of classifier names to analyze')
248
+ # parser.add_argument('--data_dir', type=str, default=None, help='Directory containing machine results CSV files (default: project machine_results directory)')
249
+ # parser.add_argument('--detailed_bout_data', action='store_true', help='Save detailed bout data for each bout')
250
+ # parser.add_argument('--transpose', action='store_true', help='Create output with one video per row')
251
+ # parser.add_argument('--no_first_occurrence', action='store_true', help='Disable first occurrence calculation')
252
+ # parser.add_argument('--no_event_count', action='store_true', help='Disable event count calculation')
253
+ # parser.add_argument('--no_total_event_duration', action='store_true', help='Disable total event duration calculation')
254
+ # parser.add_argument('--no_mean_event_duration', action='store_true', help='Disable mean event duration calculation')
255
+ # parser.add_argument('--no_median_event_duration', action='store_true', help='Disable median event duration calculation')
256
+ # parser.add_argument('--no_mean_interval_duration', action='store_true', help='Disable mean interval duration calculation')
257
+ # parser.add_argument('--no_median_interval_duration', action='store_true', help='Disable median interval duration calculation')
258
+ # parser.add_argument('--frame_count', action='store_true', help='Include frame count in output')
259
+ # parser.add_argument('--video_length', action='store_true', help='Include video length in output')
260
+ #
261
+ # args = parser.parse_args()
262
+ #
263
+ # clf_calculator = AggregateClfCalculatorMultiprocess(
264
+ # config_path=args.config_path,
265
+ # classifiers=args.classifiers,
266
+ # data_dir=args.data_dir,
267
+ # detailed_bout_data=args.detailed_bout_data,
268
+ # transpose=args.transpose,
269
+ # first_occurrence=not args.no_first_occurrence,
270
+ # event_count=not args.no_event_count,
271
+ # total_event_duration=not args.no_total_event_duration,
272
+ # mean_event_duration=not args.no_mean_event_duration,
273
+ # median_event_duration=not args.no_median_event_duration,
274
+ # mean_interval_duration=not args.no_mean_interval_duration,
275
+ # median_interval_duration=not args.no_median_interval_duration,
276
+ # frame_count=args.frame_count,
277
+ # video_length=args.video_length
278
+ # )
279
+ # clf_calculator.run()
280
+ # clf_calculator.save()
281
+
282
+ if __name__ == "__main__":
283
+ test = AggregateClfCalculatorMultiprocess(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini",
284
+ classifiers=['attack'],
285
+ transpose=True,
286
+ mean_event_duration = True,
287
+ median_event_duration = True,
288
+ mean_interval_duration = True,
289
+ median_interval_duration = True,
290
+ detailed_bout_data=True,
291
+ core_cnt=12)
292
+ test.run()
293
+ test.save()
295
294
 
296
295
 
297
296
  # test = AggregateClfCalculator(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
@@ -213,7 +213,6 @@ def _average_3d_stack_cuda(image_stack: np.ndarray) -> np.ndarray:
213
213
  return results
214
214
 
215
215
 
216
-
217
216
  def create_average_frm_cuda(video_path: Union[str, os.PathLike],
218
217
  start_frm: Optional[int] = None,
219
218
  end_frm: Optional[int] = None,
@@ -1512,6 +1511,9 @@ def pose_plotter(data: Union[str, os.PathLike, np.ndarray],
1512
1511
  stdout_success(msg=f'Pose-estimation video saved at {save_path}.', elapsed_time=total_timer.elapsed_time_str)
1513
1512
 
1514
1513
 
1514
+
1515
+ #x = create_average_frm_cuda(video_path=r"D:\troubleshooting\mitra\project_folder\videos\average_cpu_test\20min.mp4", verbose=True, batch_size=500, async_frame_read=False)
1516
+
1515
1517
  # VIDEO_PATH = "/mnt/d/troubleshooting/maplight_ri/project_folder/blob/videos/Trial_1_C24_D1_1.mp4"
1516
1518
  # #
1517
1519
  #
@@ -1,13 +1,9 @@
1
1
  __author__ = "Simon Nilsson; sronilsson@gmail.com"
2
2
 
3
3
  import functools
4
- import glob
5
- import itertools
6
4
  import multiprocessing
7
5
  import os
8
- import platform
9
- import time
10
- from typing import Dict, List, Optional, Union
6
+ from typing import Dict, List, Union
11
7
 
12
8
  import cv2
13
9
  import numpy as np
@@ -17,9 +13,9 @@ from simba.mixins.config_reader import ConfigReader
17
13
  from simba.mixins.statistics_mixin import Statistics
18
14
  from simba.utils.checks import (
19
15
  check_all_file_names_are_represented_in_video_log, check_if_dir_exists,
20
- check_if_valid_img, check_int, check_nvidea_gpu_available,
21
- check_valid_boolean, check_valid_lst)
22
- from simba.utils.data import detect_bouts, slice_roi_dict_from_attribute
16
+ check_if_valid_img, check_int, check_valid_boolean, check_valid_lst)
17
+ from simba.utils.data import (detect_bouts, slice_roi_dict_from_attribute,
18
+ terminate_cpu_pool)
23
19
  from simba.utils.enums import Defaults, Keys
24
20
  from simba.utils.errors import NoROIDataError
25
21
  from simba.utils.printing import SimbaTimer, stdout_success
@@ -220,7 +216,7 @@ class CueLightAnalyzer(ConfigReader):
220
216
  else: self.intensities[key] = subdict
221
217
  if self.verbose:
222
218
  print(f'Batch {int(np.ceil(cnt + 1 / self.core_cnt))} complete...')
223
- pool.terminate(); pool.join()
219
+ terminate_cpu_pool(pool=pool, force=False)
224
220
  kmeans = self._get_kmeans(intensities=self.intensities)
225
221
  self.data_df = self._append_light_data(data_df=self.data_df, kmeans_data=kmeans)
226
222
  self.data_df = self._remove_outlier_events(data_df=self.data_df)
@@ -13,10 +13,10 @@ from simba.data_processors.pybursts_calculator import kleinberg_burst_detection
13
13
  from simba.mixins.config_reader import ConfigReader
14
14
  from simba.utils.checks import (check_float, check_if_dir_exists,
15
15
  check_if_filepath_list_is_empty, check_int,
16
- check_that_column_exist, check_valid_lst)
16
+ check_that_column_exist, check_valid_lst, check_valid_boolean)
17
17
  from simba.utils.enums import Paths, TagNames
18
18
  from simba.utils.printing import SimbaTimer, log_event, stdout_success
19
- from simba.utils.read_write import get_fn_ext, read_df, write_df
19
+ from simba.utils.read_write import get_fn_ext, read_df, write_df, get_current_time, find_files_of_filetypes_in_directory, remove_a_folder, copy_files_to_directory
20
20
  from simba.utils.warnings import KleinbergWarning
21
21
 
22
22
 
@@ -38,12 +38,13 @@ class KleinbergCalculator(ConfigReader):
38
38
 
39
39
  :param str config_path: path to SimBA project config file in Configparser format
40
40
  :param List[str] classifier_names: Classifier names to apply Kleinberg smoothing to.
41
- :param float sigma: Burst detection sigma value. Higher sigma values and fewer, longer, behavioural bursts will be recognised. Default: 2.
42
- :param float gamma: Burst detection gamma value. Higher gamma values and fewer behavioural bursts will be recognised. Default: 0.3.
43
- :param int hierarchy: Burst detection hierarchy level. Higher hierarchy values and fewer behavioural bursts will to be recognised. Default: 1.
44
- :param bool hierarchical_search: See `Tutorial <https://github.com/sgoldenlab/simba/blob/master/docs/kleinberg_filter.md#hierarchical-search-example>`_ Default: False.
41
+ :param float sigma: State transition cost for moving to higher burst levels. Higher values (e.g., 2-3) produce fewer but longer bursts; lower values (e.g., 1.1-1.5) detect more frequent, shorter bursts. Must be > 1.01. Default: 2.
42
+ :param float gamma: State transition cost for moving to lower burst levels. Higher values (e.g., 0.5-1.0) reduce total burst count by making downward transitions costly; lower values (e.g., 0.1-0.3) allow more flexible state changes. Must be >= 0. Default: 0.3.
43
+ :param int hierarchy: Hierarchy level to extract bursts from (0=lowest, higher=more selective). Level 0 captures all bursts; level 1-2 typically filters noise; level 3+ selects only the most prominent, sustained bursts. Higher levels yield fewer but more confident detections. Must be >= 0. Default: 1.
44
+ :param bool hierarchical_search: If True, searches for target hierarchy level within detected burst periods, falling back to lower levels if target not found. If False, extracts only bursts at the exact specified hierarchy level. Recommended when target hierarchy may be sparse. Default: False.
45
45
  :param Optional[Union[str, os.PathLike]] input_dir: The directory with files to perform kleinberg smoothing on. If None, defaults to `project_folder/csv/machine_results`
46
46
  :param Optional[Union[str, os.PathLike]] output_dir: Location to save smoothened data in. If None, defaults to `project_folder/csv/machine_results`
47
+ :param Optional[bool] save_originals: If True, saves the original data in a sub-directory of the output directory.
47
48
 
48
49
  :example I:
49
50
  >>> kleinberg_calculator = KleinbergCalculator(config_path='MySimBAConfigPath', classifier_names=['Attack'], sigma=2, gamma=0.3, hierarchy=2, hierarchical_search=False)
@@ -68,10 +69,12 @@ class KleinbergCalculator(ConfigReader):
68
69
 
69
70
  def __init__(self,
70
71
  config_path: Union[str, os.PathLike],
71
- classifier_names: List[str],
72
- sigma: Optional[int] = 2,
73
- gamma: Optional[float] = 0.3,
72
+ classifier_names: Optional[List[str]] = None,
73
+ sigma: float = 2,
74
+ gamma: float = 0.3,
74
75
  hierarchy: Optional[int] = 1,
76
+ verbose: bool = True,
77
+ save_originals: bool = True,
75
78
  hierarchical_search: Optional[bool] = False,
76
79
  input_dir: Optional[Union[str, os.PathLike]] = None,
77
80
  output_dir: Optional[Union[str, os.PathLike]] = None):
@@ -81,25 +84,31 @@ class KleinbergCalculator(ConfigReader):
81
84
  check_float(value=sigma, name=f'{self.__class__.__name__} sigma', min_value=1.01)
82
85
  check_float(value=gamma, name=f'{self.__class__.__name__} gamma', min_value=0)
83
86
  check_int(value=hierarchy, name=f'{self.__class__.__name__} hierarchy', min_value=0)
84
- check_valid_lst(data=classifier_names, source=f'{self.__class__.__name__} classifier_names', valid_dtypes=(str,), min_len=1)
87
+ if isinstance(classifier_names, list):
88
+ check_valid_lst(data=classifier_names, source=f'{self.__class__.__name__} classifier_names', valid_dtypes=(str,), min_len=1)
89
+ else:
90
+ classifier_names = deepcopy(self.clf_names)
91
+ check_valid_boolean(value=verbose, source=f'{self.__class__.__name__} verbose', raise_error=True)
92
+ check_valid_boolean(value=save_originals, source=f'{self.__class__.__name__} save_originals', raise_error=True)
85
93
  self.hierarchical_search, sigma, gamma, hierarchy, self.output_dir = (hierarchical_search, float(sigma), float(gamma), int(hierarchy), output_dir)
86
- self.sigma, self.gamma, self.hierarchy, self.clfs = ( float(sigma), float(gamma), float(hierarchy), classifier_names)
94
+ self.sigma, self.gamma, self.hierarchy, self.clfs = ( float(sigma), float(gamma), int(hierarchy), classifier_names)
95
+ self.verbose, self.save_originals = verbose, save_originals
87
96
  if input_dir is None:
88
- self.data_paths, self.output_dir = self.machine_results_paths, self.machine_results_dir
89
- check_if_filepath_list_is_empty(filepaths=self.machine_results_paths, error_msg=f"SIMBA ERROR: No data files found in {self.machine_results_dir}. Cannot perform Kleinberg smoothing")
90
- original_data_files_folder = os.path.join(self.project_path, Paths.MACHINE_RESULTS_DIR.value, f"Pre_Kleinberg_{self.datetime}")
91
- if not os.path.exists(original_data_files_folder):
92
- os.makedirs(original_data_files_folder)
93
- for file_path in self.machine_results_paths:
94
- _, file_name, ext = get_fn_ext(file_path)
95
- shutil.copyfile(file_path, os.path.join(original_data_files_folder, file_name + ext))
97
+ self.input_dir = os.path.join(self.project_path, Paths.MACHINE_RESULTS_DIR.value)
96
98
  else:
97
99
  check_if_dir_exists(in_dir=input_dir)
98
- self.data_paths = glob.glob(input_dir + f"/*.{self.file_type}")
99
- check_if_filepath_list_is_empty(filepaths=self.data_paths, error_msg=f"SIMBA ERROR: No data files found in {input_dir}. Cannot perform Kleinberg smoothing")
100
- if not os.path.isdir(output_dir):
101
- os.makedirs(output_dir)
102
- print(f"Processing Kleinberg burst detection for {len(self.data_paths)} file(s) and {len(classifier_names)} classifier(s)...")
100
+ self.input_dir = deepcopy(input_dir)
101
+ self.data_paths = find_files_of_filetypes_in_directory(directory=self.input_dir, extensions=[f'.{self.file_type}'], sort_alphabetically=True, raise_error=True)
102
+ if output_dir is None:
103
+ self.output_dir = deepcopy(self.input_dir)
104
+ else:
105
+ check_if_dir_exists(in_dir=output_dir)
106
+ self.output_dir = deepcopy(output_dir)
107
+ self.original_data_files_folder = os.path.join(self.output_dir, f"Pre_Kleinberg_{self.datetime}")
108
+ remove_a_folder(folder_dir=self.original_data_files_folder, ignore_errors=True)
109
+ os.makedirs(self.original_data_files_folder)
110
+ copy_files_to_directory(file_paths=self.data_paths, dir=self.original_data_files_folder, verbose=False, integer_save_names=False)
111
+ if self.verbose: print(f"Processing Kleinberg burst detection for {len(self.data_paths)} file(s) and {len(classifier_names)} classifier(s)...")
103
112
 
104
113
  def hierarchical_searcher(self):
105
114
  if (len(self.kleinberg_bouts["Hierarchy"]) == 1) and (int(self.kleinberg_bouts.at[0, "Hierarchy"]) == 0):
@@ -135,7 +144,7 @@ class KleinbergCalculator(ConfigReader):
135
144
  for file_cnt, file_path in enumerate(self.data_paths):
136
145
  _, video_name, _ = get_fn_ext(file_path)
137
146
  video_timer = SimbaTimer(start=True)
138
- print(f"Performing Kleinberg burst detection for video {video_name} (Video {file_cnt+1}/{len(self.data_paths)})...")
147
+ if self.verbose: print(f"[{get_current_time()}] Performing Kleinberg burst detection for video {video_name} (Video {file_cnt+1}/{len(self.data_paths)})...")
139
148
  data_df = read_df(file_path, self.file_type).reset_index(drop=True)
140
149
  video_out_df = deepcopy(data_df)
141
150
  check_that_column_exist(df=data_df, column_name=self.clfs, file_name=video_name)
@@ -150,7 +159,7 @@ class KleinbergCalculator(ConfigReader):
150
159
  self.kleinberg_bouts.insert(loc=0, column="Video", value=video_name)
151
160
  detailed_df_lst.append(self.kleinberg_bouts)
152
161
  if self.hierarchical_search:
153
- print(f"Applying hierarchical search for video {video_name}...")
162
+ if self.verbose: print(f"[{get_current_time()}] Applying hierarchical search for video {video_name}...")
154
163
  self.hierarchical_searcher()
155
164
  else:
156
165
  self.clf_bouts_in_hierarchy = self.kleinberg_bouts[self.kleinberg_bouts["Hierarchy"] == self.hierarchy]
@@ -160,19 +169,38 @@ class KleinbergCalculator(ConfigReader):
160
169
  video_out_df.loc[hierarchy_idx, clf] = 1
161
170
  write_df(video_out_df, self.file_type, save_path)
162
171
  video_timer.stop_timer()
163
- print(f'Kleinberg analysis complete for video {video_name} (saved at {save_path}), elapsed time: {video_timer.elapsed_time_str}s.')
172
+ if self.verbose: print(f'[{get_current_time()}] Kleinberg analysis complete for video {video_name} (saved at {save_path}), elapsed time: {video_timer.elapsed_time_str}s.')
164
173
 
165
174
  self.timer.stop_timer()
175
+ if not self.save_originals:
176
+ remove_a_folder(folder_dir=self.original_data_files_folder, ignore_errors=False)
177
+ else:
178
+ if self.verbose: stdout_success(msg=f"Original, un-smoothened data, saved in {self.original_data_files_folder} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
166
179
  if len(detailed_df_lst) > 0:
167
180
  self.detailed_df = pd.concat(detailed_df_lst, axis=0)
168
181
  detailed_save_path = os.path.join(self.logs_path, f"Kleinberg_detailed_log_{self.datetime}.csv")
169
182
  self.detailed_df.to_csv(detailed_save_path)
170
- stdout_success(msg=f"Kleinberg analysis complete. See {detailed_save_path} for details of detected bouts of all classifiers in all hierarchies", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
183
+ if self.verbose: stdout_success(msg=f"Kleinberg analysis complete for {len(self.data_paths)} files. Results stored in {self.output_dir} directory. See {detailed_save_path} for details of detected bouts of all classifiers in all hierarchies", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
171
184
  else:
172
- print("Kleinberg analysis complete.")
185
+ if self.verbose: print(f"[{get_current_time()}] Kleinberg analysis complete for {len(self.data_paths)} files. Results stored in {self.output_dir} directory.")
173
186
  KleinbergWarning(msg="All behavior bouts removed following kleinberg smoothing", source=self.__class__.__name__)
174
187
 
175
188
 
189
+
190
+
191
+ # test = KleinbergCalculator(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
192
+ # classifier_names=['straub_tail'],
193
+ # sigma=1.1,
194
+ # gamma=0.1,
195
+ # hierarchy=1,
196
+ # save_originals=False,
197
+ # hierarchical_search=False)
198
+ #
199
+ # test.run()
200
+ #
201
+
202
+
203
+
176
204
  # test = KleinbergCalculator(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/levi/project_folder/project_config.ini',
177
205
  # classifier_names=['No_Seizure_(0)'],
178
206
  # sigma=1.1,
@@ -1339,8 +1339,7 @@ class GeometryMixin(object):
1339
1339
  )
1340
1340
  for cnt, result in enumerate(pool.imap(constants, data, chunksize=1)):
1341
1341
  results.append(result)
1342
- pool.join()
1343
- pool.terminate()
1342
+ terminate_cpu_pool(pool=pool, force=False)
1344
1343
  if data_ndim == 2:
1345
1344
  return [i for s in results for i in s]
1346
1345
  else:
@@ -1370,8 +1369,7 @@ class GeometryMixin(object):
1370
1369
  cap_style=cap_style)
1371
1370
  for cnt, mp_return in enumerate(pool.imap(constants, geomety_lst, chunksize=1)):
1372
1371
  results.append(mp_return)
1373
- pool.join()
1374
- pool.terminate()
1372
+ terminate_cpu_pool(pool=pool, force=False)
1375
1373
  return [l for ll in results for l in ll]
1376
1374
 
1377
1375
  def multiframe_bodyparts_to_circle(self,
@@ -1524,8 +1522,7 @@ class GeometryMixin(object):
1524
1522
  )
1525
1523
  for cnt, result in enumerate(pool.imap(constants, data, chunksize=1)):
1526
1524
  results.append(result)
1527
- pool.join()
1528
- pool.terminate()
1525
+ terminate_cpu_pool(pool=pool, force=False)
1529
1526
  return results
1530
1527
 
1531
1528
  def multiframe_compute_pct_shape_overlap(self,
@@ -1798,8 +1795,7 @@ class GeometryMixin(object):
1798
1795
  timer.stop_timer()
1799
1796
  if verbose:
1800
1797
  stdout_success(msg="Rotated rectangles complete.", elapsed_time=timer.elapsed_time_str)
1801
- pool.join()
1802
- pool.terminate()
1798
+ terminate_cpu_pool(pool=pool, force=False)
1803
1799
  return results
1804
1800
 
1805
1801
  @staticmethod
@@ -2003,8 +1999,7 @@ class GeometryMixin(object):
2003
1999
  )
2004
2000
  for cnt, result in enumerate(pool.imap(constants, shapes, chunksize=1)):
2005
2001
  results.append(result)
2006
- pool.join()
2007
- pool.terminate()
2002
+ terminate_cpu_pool(pool=pool, force=False)
2008
2003
  return results
2009
2004
 
2010
2005
  def multiframe_union(self, shapes: Iterable[Union[LineString, MultiLineString, Polygon]], core_cnt: int = -1) -> \
@@ -2043,8 +2038,7 @@ class GeometryMixin(object):
2043
2038
  with multiprocessing.Pool(core_cnt, maxtasksperchild=Defaults.LARGE_MAX_TASK_PER_CHILD.value) as pool:
2044
2039
  for cnt, result in enumerate(pool.imap(GeometryMixin().union, shapes, chunksize=1)):
2045
2040
  results.append(result)
2046
- pool.join()
2047
- pool.terminate()
2041
+ terminate_cpu_pool(pool=pool, force=False)
2048
2042
  return results
2049
2043
 
2050
2044
  def multiframe_symmetric_difference(self, shapes: Iterable[Union[LineString, MultiLineString, Polygon]],
@@ -2084,8 +2078,7 @@ class GeometryMixin(object):
2084
2078
  pool.imap(GeometryMixin().symmetric_difference, shapes, chunksize=1)
2085
2079
  ):
2086
2080
  results.append(result)
2087
- pool.join()
2088
- pool.terminate()
2081
+ terminate_cpu_pool(pool=pool, force=False)
2089
2082
  return results
2090
2083
 
2091
2084
  def multiframe_delaunay_triangulate_keypoints(self, data: np.ndarray, core_cnt: int = -1) -> List[List[Polygon]]:
@@ -2132,8 +2125,7 @@ class GeometryMixin(object):
2132
2125
  ):
2133
2126
  results.append(result)
2134
2127
 
2135
- pool.join()
2136
- pool.terminate()
2128
+ terminate_cpu_pool(pool=pool, force=False)
2137
2129
  return results
2138
2130
 
2139
2131
  def multiframe_difference(
@@ -2221,8 +2213,7 @@ class GeometryMixin(object):
2221
2213
  msg="Multi-frame difference compute complete",
2222
2214
  elapsed_time=timer.elapsed_time_str,
2223
2215
  )
2224
- pool.join()
2225
- pool.terminate()
2216
+ terminate_cpu_pool(pool=pool, force=False)
2226
2217
  return results
2227
2218
 
2228
2219
  def multiframe_area(self,
@@ -2276,8 +2267,7 @@ class GeometryMixin(object):
2276
2267
 
2277
2268
  timer.stop_timer()
2278
2269
  stdout_success(msg="Multi-frame area compute complete", elapsed_time=timer.elapsed_time_str)
2279
- pool.join()
2280
- pool.terminate()
2270
+ terminate_cpu_pool(pool=pool, force=False)
2281
2271
  return results
2282
2272
 
2283
2273
  def multiframe_bodyparts_to_multistring_skeleton(
@@ -2619,8 +2609,7 @@ class GeometryMixin(object):
2619
2609
  pool.imap(GeometryMixin.is_shape_covered, shapes, chunksize=1)
2620
2610
  ):
2621
2611
  results.append(mp_return)
2622
- pool.join()
2623
- pool.terminate()
2612
+ terminate_cpu_pool(pool=pool, force=False)
2624
2613
  return results
2625
2614
 
2626
2615
  @staticmethod
@@ -3321,8 +3310,7 @@ class GeometryMixin(object):
3321
3310
  for cnt, result in enumerate(pool.imap(constants, data, chunksize=1)):
3322
3311
  if result[1] != -1:
3323
3312
  img_arr[result[0], result[2] - 1, result[1] - 1] = 1
3324
- pool.join()
3325
- pool.terminate()
3313
+ terminate_cpu_pool(pool=pool, force=False)
3326
3314
  timer.stop_timer()
3327
3315
  stdout_success(
3328
3316
  msg="Cumulative coordinates in geometries complete",
@@ -3415,8 +3403,7 @@ class GeometryMixin(object):
3415
3403
  for cnt, result in enumerate(pool.imap(constants, data, chunksize=1)):
3416
3404
  if result[1] != -1:
3417
3405
  img_arr[result[0], result[2] - 1, result[1] - 1] = 1
3418
- pool.join()
3419
- pool.terminate()
3406
+ terminate_cpu_pool(pool=pool, force=False)
3420
3407
  if fps is None:
3421
3408
  return np.cumsum(img_arr, axis=0)
3422
3409
  else:
@@ -3559,8 +3546,7 @@ class GeometryMixin(object):
3559
3546
  constants = functools.partial(GeometryMixin._compute_framewise_geometry_idx, grid=grid, verbose=verbose)
3560
3547
  for cnt, result in enumerate(pool.imap(constants, data, chunksize=1)):
3561
3548
  results.append(result)
3562
- pool.join();
3563
- pool.terminate();
3549
+ terminate_cpu_pool(pool=pool, force=False)
3564
3550
  del data
3565
3551
 
3566
3552
  results = np.vstack(results)[:, 1:].astype(np.int32)
@@ -18,7 +18,7 @@ from collections import ChainMap
18
18
  import cv2
19
19
  import pandas as pd
20
20
  from numba import float64, int64, jit, njit, prange, uint8
21
- from shapely.geometry import MultiPolygon, Polygon
21
+ from shapely.geometry import Polygon
22
22
  from skimage.metrics import structural_similarity
23
23
 
24
24
  from simba.utils.checks import (check_file_exist_and_readable, check_float,
@@ -27,16 +27,14 @@ from simba.utils.checks import (check_file_exist_and_readable, check_float,
27
27
  check_int, check_str, check_valid_array,
28
28
  check_valid_boolean, check_valid_lst,
29
29
  check_valid_tuple, is_img_bw, is_img_greyscale)
30
+ from simba.utils.data import terminate_cpu_pool
30
31
  from simba.utils.enums import Defaults, Formats, GeometryEnum, Options
31
- from simba.utils.errors import (ArrayError, FFMPEGCodecGPUError,
32
- FrameRangeError, InvalidInputError,
33
- NotDirectoryError)
32
+ from simba.utils.errors import ArrayError, FrameRangeError, InvalidInputError
34
33
  from simba.utils.printing import SimbaTimer, stdout_success
35
34
  from simba.utils.read_write import (find_core_cnt,
36
35
  find_files_of_filetypes_in_directory,
37
36
  get_fn_ext, get_video_meta_data,
38
- read_frm_of_video,
39
- read_img_batch_from_video_gpu, write_df)
37
+ read_frm_of_video)
40
38
 
41
39
 
42
40
  class ImageMixin(object):
@@ -546,8 +544,8 @@ class ImageMixin(object):
546
544
  pool.imap(constants, split_frm_idx, chunksize=1)
547
545
  ):
548
546
  results.append(result)
549
- pool.terminate()
550
- pool.join()
547
+
548
+ terminate_cpu_pool(pool=pool, force=False)
551
549
  results = dict(ChainMap(*results))
552
550
 
553
551
  max_value, max_frm = -np.inf, None
@@ -876,8 +874,7 @@ class ImageMixin(object):
876
874
  pool.imap(ImageMixin()._image_reader_helper, file_paths, chunksize=1)
877
875
  ):
878
876
  imgs.update(result)
879
- pool.join()
880
- pool.terminate()
877
+ terminate_cpu_pool(pool=pool, force=False)
881
878
  return imgs
882
879
 
883
880
  @staticmethod
@@ -1027,8 +1024,7 @@ class ImageMixin(object):
1027
1024
  for cnt, result in enumerate(pool.imap(constants, frm_lst, chunksize=1)):
1028
1025
  results.update(result)
1029
1026
 
1030
- pool.join()
1031
- pool.terminate()
1027
+ terminate_cpu_pool(pool=pool, force=False)
1032
1028
  return results
1033
1029
 
1034
1030
  @staticmethod
@@ -1509,8 +1505,8 @@ class ImageMixin(object):
1509
1505
  for cnt, result in enumerate(pool.imap(constants, shapes, chunksize=1)):
1510
1506
  results.append(result)
1511
1507
  results = dict(ChainMap(*results))
1512
- pool.join()
1513
- pool.terminate()
1508
+
1509
+ terminate_cpu_pool(pool=pool, force=False)
1514
1510
  results = dict(sorted(results.items(), key=lambda item: int(item[0])))
1515
1511
  timer.stop_timer()
1516
1512
  stdout_success(msg="Geometry image slicing complete.", elapsed_time=timer.elapsed_time_str, source=self.__class__.__name__)
@@ -67,7 +67,7 @@ from simba.utils.checks import (check_all_dfs_in_list_has_same_cols,
67
67
  check_valid_boolean, check_valid_dataframe,
68
68
  check_valid_lst, is_lxc_container)
69
69
  from simba.utils.data import (detect_bouts, detect_bouts_multiclass,
70
- get_library_version)
70
+ get_library_version, terminate_cpu_pool)
71
71
  from simba.utils.enums import (OS, ConfigKey, Defaults, Dtypes, Formats, Links,
72
72
  Methods, MLParamKeys, Options)
73
73
  from simba.utils.errors import (ClassifierInferenceError, CorruptedFileError,
@@ -1859,7 +1859,7 @@ class TrainModelMixin(object):
1859
1859
  shap_raw.append(shap_data[result[1]][1].drop(clf_name, axis=1))
1860
1860
  if verbose: print(f"Completed SHAP care batch (Batch {result[1] + 1}/{len(shap_data)}).")
1861
1861
 
1862
- pool.terminate(); pool.join()
1862
+ terminate_cpu_pool(pool=pool, force=False)
1863
1863
  shap_df = pd.DataFrame(data=np.row_stack(shap_results), columns=list(x_names) + ["Expected_value", "Sum", "Prediction_probability", clf_name])
1864
1864
  raw_df = pd.DataFrame(data=np.row_stack(shap_raw), columns=list(x_names))
1865
1865
  out_shap_path, out_raw_path, img_save_path, df_save_paths, summary_dfs, img = None, None, None, None, None, None
@@ -24,10 +24,9 @@ from simba.utils.checks import (check_file_exist_and_readable,
24
24
  check_if_valid_rgb_tuple, check_int, check_str,
25
25
  check_valid_boolean, check_valid_lst,
26
26
  check_video_and_data_frm_count_align)
27
- from simba.utils.data import slice_roi_dict_for_video
27
+ from simba.utils.data import slice_roi_dict_for_video, terminate_cpu_pool
28
28
  from simba.utils.enums import Formats, TextOptions
29
- from simba.utils.errors import (BodypartColumnNotFoundError, NoFilesFoundError,
30
- ROICoordinatesNotFoundError)
29
+ from simba.utils.errors import BodypartColumnNotFoundError, NoFilesFoundError
31
30
  from simba.utils.printing import stdout_success
32
31
  from simba.utils.read_write import (concatenate_videos_in_folder,
33
32
  find_core_cnt, get_fn_ext,
@@ -315,8 +314,7 @@ class ROIfeatureVisualizerMultiprocess(ConfigReader):
315
314
  print(f"Joining {self.video_name} multi-processed video...")
316
315
  concatenate_videos_in_folder(in_folder=self.save_temp_dir, save_path=self.save_path, video_format="mp4", remove_splits=True, gpu=self.gpu)
317
316
  self.timer.stop_timer()
318
- pool.terminate()
319
- pool.join()
317
+ terminate_cpu_pool(pool=pool, force=False)
320
318
  stdout_success(msg=f"Video {self.video_name} complete. Video saved in directory {self.roi_features_save_dir}.", elapsed_time=self.timer.elapsed_time_str)
321
319
 
322
320
 
@@ -14,9 +14,9 @@ from simba.mixins.plotting_mixin import PlottingMixin
14
14
  from simba.utils.checks import (check_float, check_if_valid_rgb_tuple,
15
15
  check_int, check_str, check_that_column_exist,
16
16
  check_valid_lst)
17
- from simba.utils.data import detect_bouts
18
- from simba.utils.enums import Formats, TagNames, TextOptions
19
- from simba.utils.errors import NoFilesFoundError, NoSpecifiedOutputError
17
+ from simba.utils.data import detect_bouts, terminate_cpu_pool
18
+ from simba.utils.enums import Formats, TextOptions
19
+ from simba.utils.errors import NoSpecifiedOutputError
20
20
  from simba.utils.printing import SimbaTimer, log_event, stdout_success
21
21
  from simba.utils.read_write import (concatenate_videos_in_folder,
22
22
  find_core_cnt, get_fn_ext,
@@ -218,8 +218,7 @@ class ClassifierValidationClipsMultiprocess(ConfigReader):
218
218
  for cnt, result in enumerate(
219
219
  pool.imap(constants, clip_data, chunksize=self.multiprocess_chunksize)):
220
220
  print(f"Bout {cnt+1} complete...")
221
- pool.terminate()
222
- pool.join()
221
+ terminate_cpu_pool(pool=pool, force=False)
223
222
 
224
223
  if self.concat_video:
225
224
  print(f"Joining {file_name} multiprocessed video...")