simba-uw-tf-dev 4.5.8__py3-none-any.whl → 4.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. simba/SimBA.py +2 -2
  2. simba/assets/.recent_projects.txt +1 -0
  3. simba/assets/icons/frames_2.png +0 -0
  4. simba/assets/lookups/tooptips.json +15 -1
  5. simba/data_processors/agg_clf_counter_mp.py +52 -53
  6. simba/data_processors/blob_location_computer.py +1 -1
  7. simba/data_processors/circling_detector.py +30 -13
  8. simba/data_processors/cuda/geometry.py +45 -27
  9. simba/data_processors/cuda/image.py +1648 -1598
  10. simba/data_processors/cuda/statistics.py +72 -26
  11. simba/data_processors/cuda/timeseries.py +1 -1
  12. simba/data_processors/cue_light_analyzer.py +5 -9
  13. simba/data_processors/egocentric_aligner.py +25 -7
  14. simba/data_processors/freezing_detector.py +55 -47
  15. simba/data_processors/kleinberg_calculator.py +61 -29
  16. simba/feature_extractors/feature_subsets.py +14 -7
  17. simba/feature_extractors/mitra_feature_extractor.py +2 -2
  18. simba/feature_extractors/straub_tail_analyzer.py +4 -6
  19. simba/labelling/standard_labeller.py +1 -1
  20. simba/mixins/config_reader.py +5 -2
  21. simba/mixins/geometry_mixin.py +22 -36
  22. simba/mixins/image_mixin.py +24 -28
  23. simba/mixins/plotting_mixin.py +28 -10
  24. simba/mixins/statistics_mixin.py +48 -11
  25. simba/mixins/timeseries_features_mixin.py +1 -1
  26. simba/mixins/train_model_mixin.py +67 -29
  27. simba/model/inference_batch.py +1 -1
  28. simba/model/yolo_seg_inference.py +3 -3
  29. simba/outlier_tools/skip_outlier_correction.py +1 -1
  30. simba/plotting/ROI_feature_visualizer_mp.py +3 -5
  31. simba/plotting/clf_validator_mp.py +4 -5
  32. simba/plotting/cue_light_visualizer.py +6 -7
  33. simba/plotting/directing_animals_visualizer_mp.py +2 -3
  34. simba/plotting/distance_plotter_mp.py +378 -378
  35. simba/plotting/frame_mergerer_ffmpeg.py +137 -196
  36. simba/plotting/gantt_creator.py +29 -10
  37. simba/plotting/gantt_creator_mp.py +96 -33
  38. simba/plotting/geometry_plotter.py +270 -272
  39. simba/plotting/heat_mapper_clf_mp.py +4 -6
  40. simba/plotting/heat_mapper_location_mp.py +2 -2
  41. simba/plotting/light_dark_box_plotter.py +2 -2
  42. simba/plotting/path_plotter_mp.py +26 -29
  43. simba/plotting/plot_clf_results_mp.py +455 -454
  44. simba/plotting/pose_plotter_mp.py +28 -29
  45. simba/plotting/probability_plot_creator_mp.py +288 -288
  46. simba/plotting/roi_plotter_mp.py +31 -31
  47. simba/plotting/single_run_model_validation_video_mp.py +427 -427
  48. simba/plotting/spontaneous_alternation_plotter.py +2 -3
  49. simba/plotting/yolo_pose_track_visualizer.py +32 -27
  50. simba/plotting/yolo_pose_visualizer.py +35 -36
  51. simba/plotting/yolo_seg_visualizer.py +2 -3
  52. simba/pose_importers/simba_blob_importer.py +3 -3
  53. simba/roi_tools/roi_aggregate_stats_mp.py +5 -4
  54. simba/roi_tools/roi_clf_calculator_mp.py +4 -4
  55. simba/sandbox/analyze_runtimes.py +30 -0
  56. simba/sandbox/cuda/egocentric_rotator.py +374 -0
  57. simba/sandbox/get_cpu_pool.py +5 -0
  58. simba/sandbox/proboscis_to_tip.py +28 -0
  59. simba/sandbox/test_directionality.py +47 -0
  60. simba/sandbox/test_nonstatic_directionality.py +27 -0
  61. simba/sandbox/test_pycharm_cuda.py +51 -0
  62. simba/sandbox/test_simba_install.py +41 -0
  63. simba/sandbox/test_static_directionality.py +26 -0
  64. simba/sandbox/test_static_directionality_2d.py +26 -0
  65. simba/sandbox/verify_env.py +42 -0
  66. simba/third_party_label_appenders/transform/coco_keypoints_to_yolo.py +3 -3
  67. simba/third_party_label_appenders/transform/coco_keypoints_to_yolo_bbox.py +2 -2
  68. simba/ui/pop_ups/clf_add_remove_print_pop_up.py +37 -30
  69. simba/ui/pop_ups/clf_plot_pop_up.py +2 -2
  70. simba/ui/pop_ups/egocentric_alignment_pop_up.py +20 -21
  71. simba/ui/pop_ups/fsttc_pop_up.py +27 -25
  72. simba/ui/pop_ups/gantt_pop_up.py +31 -6
  73. simba/ui/pop_ups/interpolate_pop_up.py +2 -4
  74. simba/ui/pop_ups/kleinberg_pop_up.py +39 -40
  75. simba/ui/pop_ups/multiple_videos_to_frames_popup.py +10 -11
  76. simba/ui/pop_ups/single_video_to_frames_popup.py +10 -10
  77. simba/ui/pop_ups/video_processing_pop_up.py +186 -174
  78. simba/ui/tkinter_functions.py +10 -1
  79. simba/utils/custom_feature_extractor.py +1 -1
  80. simba/utils/data.py +90 -14
  81. simba/utils/enums.py +1 -0
  82. simba/utils/errors.py +441 -440
  83. simba/utils/lookups.py +1203 -1203
  84. simba/utils/printing.py +124 -124
  85. simba/utils/read_write.py +3769 -3721
  86. simba/utils/yolo.py +10 -1
  87. simba/video_processors/blob_tracking_executor.py +2 -2
  88. simba/video_processors/clahe_ui.py +66 -23
  89. simba/video_processors/egocentric_video_rotator.py +46 -44
  90. simba/video_processors/multi_cropper.py +1 -1
  91. simba/video_processors/video_processing.py +5264 -5300
  92. simba/video_processors/videos_to_frames.py +43 -32
  93. {simba_uw_tf_dev-4.5.8.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/METADATA +4 -3
  94. {simba_uw_tf_dev-4.5.8.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/RECORD +98 -86
  95. {simba_uw_tf_dev-4.5.8.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/LICENSE +0 -0
  96. {simba_uw_tf_dev-4.5.8.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/WHEEL +0 -0
  97. {simba_uw_tf_dev-4.5.8.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/entry_points.txt +0 -0
  98. {simba_uw_tf_dev-4.5.8.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/top_level.txt +0 -0
simba/SimBA.py CHANGED
@@ -966,8 +966,8 @@ class App(object):
  video_process_menu.add_cascade(label="Drop body-parts from tracking data", compound="left", image=self.menu_icons["trash"]["img"], command=DropTrackingDataPopUp, font=Formats.FONT_REGULAR.value)
  extract_frames_menu = Menu(video_process_menu, font=Formats.FONT_REGULAR.value)
  extract_frames_menu.add_command(label="Extract defined frames", command=ExtractSpecificFramesPopUp, font=Formats.FONT_REGULAR.value, image=self.menu_icons["frames"]["img"], compound="left")
- extract_frames_menu.add_command(label="Extract frames from single video", command=SingleVideo2FramesPopUp, font=Formats.FONT_REGULAR.value, image=self.menu_icons["video"]["img"], compound="left")
- extract_frames_menu.add_command(label="Extract frames from multiple video", command=MultipleVideos2FramesPopUp, font=Formats.FONT_REGULAR.value, image=self.menu_icons["stack"]["img"], compound="left")
+ extract_frames_menu.add_command(label="Extract frames from single video", command=SingleVideo2FramesPopUp, font=Formats.FONT_REGULAR.value, image=self.menu_icons["frames_2"]["img"], compound="left")
+ extract_frames_menu.add_command(label="Extract frames from multiple videos", command=MultipleVideos2FramesPopUp, font=Formats.FONT_REGULAR.value, image=self.menu_icons["stack"]["img"], compound="left")
  extract_frames_menu.add_command(label="Extract frames from seq files", command=ExtractSEQFramesPopUp, font=Formats.FONT_REGULAR.value, image=self.menu_icons["fire"]["img"], compound="left")
  video_process_menu.add_cascade(label="Extract frames...", compound="left", image=self.menu_icons["frames"]["img"], menu=extract_frames_menu, font=Formats.FONT_REGULAR.value)
simba/assets/.recent_projects.txt CHANGED
@@ -1,2 +1,3 @@
+ E:/troubleshooting/mitra_emergence/project_folder/project_config.ini
  C:/troubleshooting/meberled/project_folder/project_config.ini
  C:/troubleshooting/mitra/project_folder/project_config.ini
simba/assets/icons/frames_2.png CHANGED
Binary file
simba/assets/lookups/tooptips.json CHANGED
@@ -29,5 +29,19 @@
  "LOCATION_FRAME_COUNT": "The location in the video where the \n frame count is positioned.",
  "ROTATE_FILL_COLOR": "When video is rotated, there may be empty space not \n covered by the video. What color should this space have?",
  "VIDEO_DIR": "Directory containing videos.",
- "SAVE_DIR": "Directory where to save results."
+ "SAVE_DIR": "Directory where to save results.",
+ "CONCAT_HEIGHT": "If join involves aligning videos horizontally, this values \nwill be used to ensure videos have the same height. ",
+ "CONCAT_WIDTH": "If join involves aligning videos vertically, this values \nwill be used to ensure videos have the same width. ",
+ "CONCAT_RES_HEADER": "When stacking videos horizontally and/or vertically, \n the videos need to be same height and/or width. Here, \n select what resolution to use.",
+ "EGOCENTRIC_DATA_DIR": "Folder containing pose-estimation CSV data.\n Can be sub-directory in 'project_folder/csv' folder.\n Should contain same file names as the VIDEO files",
+ "EGOCENTRIC_VIDEO_DIR": "Folder containing videos.\n Should contain same file names as the DATA files.",
+ "EGOCENTRIC_ANCHOR": "This body-part will be placed in the center of the video",
+ "EGOCENTRIC_DIRECTION_ANCHOR": "This body-part will be placed at N degrees relative to the anchor",
+ "EGOCENTRIC_DIRECTION": "The anchor body-part will always be placed at these degrees relative to the center anchor",
+ "CORE_COUNT": "Higher core counts speeds up processing but may require more RAM memory",
+ "KLEINBERG_SIGMA": "Higher values (e.g., 2-3) produce fewer but longer bursts; lower values (e.g., 1.1-1.5) detect more frequent, shorter bursts. Must be > 1.01",
+ "KLEINBERG_GAMMA": "Higher values (e.g., 0.5-1.0) reduce total burst count by making downward transitions costly; lower values (e.g., 0.1-0.3) allow more flexible state changes",
+ "KLEINBERG_HIERARCHY": "Hierarchy level to extract bursts from (0=lowest, higher=more selective).\n Level 0 captures all bursts; level 1-2 typically filters noise; level 3+ selects only the most prominent, sustained bursts.\nHigher levels yield fewer but more confident detections",
+ "KLEINBERG_HIERARCHY_SEARCH": "If True, searches for target hierarchy level within detected burst periods,\n falling back to lower levels if target not found. If False, extracts only bursts at the exact specified hierarchy level.\n Recommended when target hierarchy may be sparse.",
+ "KLEINBERG_SAVE_ORIGINALS": "If True, saves the original data in a new sub-directory of \nthe project_folder/csv/machine_results directory"
  }
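
The new Kleinberg tooltips above describe how sigma, gamma, and the hierarchy level shape burst detection. A minimal, hedged sketch of how these parameters might be passed to SimBA's Kleinberg smoother follows; the class and argument names are assumptions inferred from the tooltip keys and the kleinberg_calculator.py module listed in this diff, not confirmed by it.

    # Hypothetical sketch: class/argument names are assumptions; the project path is a placeholder.
    from simba.data_processors.kleinberg_calculator import KleinbergCalculator

    smoother = KleinbergCalculator(config_path=r"C:/my_project/project_folder/project_config.ini",
                                   classifier_names=['attack'],
                                   sigma=2.0,                  # must be > 1.01; higher -> fewer, longer bursts (KLEINBERG_SIGMA)
                                   gamma=0.3,                  # higher -> costlier downward transitions, fewer bursts (KLEINBERG_GAMMA)
                                   hierarchy=1,                # 0 keeps every burst; higher levels keep only sustained bursts (KLEINBERG_HIERARCHY)
                                   hierarchical_search=False)  # True falls back to lower levels when the target level is sparse
    smoother.run()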
simba/data_processors/agg_clf_counter_mp.py CHANGED
@@ -20,7 +20,7 @@ from simba.utils.checks import (
  check_all_file_names_are_represented_in_video_log,
  check_file_exist_and_readable, check_if_dir_exists, check_int,
  check_valid_boolean, check_valid_dataframe, check_valid_lst)
- from simba.utils.data import detect_bouts
+ from simba.utils.data import detect_bouts, terminate_cpu_pool
  from simba.utils.enums import TagNames
  from simba.utils.errors import NoChoosenMeasurementError
  from simba.utils.printing import SimbaTimer, log_event, stdout_success
@@ -210,8 +210,7 @@ class AggregateClfCalculatorMultiprocess(ConfigReader):
  self.bouts_df_lst.append(batch_bouts_df_lst)
  print(f"Data batch core {batch_id+1} / {self.core_cnt} complete...")
  self.bouts_df_lst = [df for sub in self.bouts_df_lst for df in sub]
- pool.join()
- pool.terminate()
+ terminate_cpu_pool(pool=pool, force=False)

  def save(self) -> None:
  """
@@ -242,56 +241,56 @@ class AggregateClfCalculatorMultiprocess(ConfigReader):
  self.timer.stop_timer()
  stdout_success(msg=f"Data aggregate log saved at {self.save_path}", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)

- if __name__ == "__main__" and not hasattr(sys, 'ps1'):
- parser = argparse.ArgumentParser(description='Compute aggregate descriptive statistics from classification data.')
- parser.add_argument('--config_path', type=str, required=True, help='Path to SimBA project config file')
- parser.add_argument('--classifiers', type=str, nargs='+', required=True, help='List of classifier names to analyze')
- parser.add_argument('--data_dir', type=str, default=None, help='Directory containing machine results CSV files (default: project machine_results directory)')
- parser.add_argument('--detailed_bout_data', action='store_true', help='Save detailed bout data for each bout')
- parser.add_argument('--transpose', action='store_true', help='Create output with one video per row')
- parser.add_argument('--no_first_occurrence', action='store_true', help='Disable first occurrence calculation')
- parser.add_argument('--no_event_count', action='store_true', help='Disable event count calculation')
- parser.add_argument('--no_total_event_duration', action='store_true', help='Disable total event duration calculation')
- parser.add_argument('--no_mean_event_duration', action='store_true', help='Disable mean event duration calculation')
- parser.add_argument('--no_median_event_duration', action='store_true', help='Disable median event duration calculation')
- parser.add_argument('--no_mean_interval_duration', action='store_true', help='Disable mean interval duration calculation')
- parser.add_argument('--no_median_interval_duration', action='store_true', help='Disable median interval duration calculation')
- parser.add_argument('--frame_count', action='store_true', help='Include frame count in output')
- parser.add_argument('--video_length', action='store_true', help='Include video length in output')
-
- args = parser.parse_args()
-
- clf_calculator = AggregateClfCalculatorMultiprocess(
- config_path=args.config_path,
- classifiers=args.classifiers,
- data_dir=args.data_dir,
- detailed_bout_data=args.detailed_bout_data,
- transpose=args.transpose,
- first_occurrence=not args.no_first_occurrence,
- event_count=not args.no_event_count,
- total_event_duration=not args.no_total_event_duration,
- mean_event_duration=not args.no_mean_event_duration,
- median_event_duration=not args.no_median_event_duration,
- mean_interval_duration=not args.no_mean_interval_duration,
- median_interval_duration=not args.no_median_interval_duration,
- frame_count=args.frame_count,
- video_length=args.video_length
- )
- clf_calculator.run()
- clf_calculator.save()
-
- # if __name__ == "__main__":
- # test = AggregateClfCalculatorMultiprocess(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini",
- # classifiers=['attack'],
- # transpose=True,
- # mean_event_duration = True,
- # median_event_duration = True,
- # mean_interval_duration = True,
- # median_interval_duration = True,
- # detailed_bout_data=True,
- # core_cnt=12)
- # test.run()
- # test.save()
+ # if __name__ == "__main__" and not hasattr(sys, 'ps1'):
+ # parser = argparse.ArgumentParser(description='Compute aggregate descriptive statistics from classification data.')
+ # parser.add_argument('--config_path', type=str, required=True, help='Path to SimBA project config file')
+ # parser.add_argument('--classifiers', type=str, nargs='+', required=True, help='List of classifier names to analyze')
+ # parser.add_argument('--data_dir', type=str, default=None, help='Directory containing machine results CSV files (default: project machine_results directory)')
+ # parser.add_argument('--detailed_bout_data', action='store_true', help='Save detailed bout data for each bout')
+ # parser.add_argument('--transpose', action='store_true', help='Create output with one video per row')
+ # parser.add_argument('--no_first_occurrence', action='store_true', help='Disable first occurrence calculation')
+ # parser.add_argument('--no_event_count', action='store_true', help='Disable event count calculation')
+ # parser.add_argument('--no_total_event_duration', action='store_true', help='Disable total event duration calculation')
+ # parser.add_argument('--no_mean_event_duration', action='store_true', help='Disable mean event duration calculation')
+ # parser.add_argument('--no_median_event_duration', action='store_true', help='Disable median event duration calculation')
+ # parser.add_argument('--no_mean_interval_duration', action='store_true', help='Disable mean interval duration calculation')
+ # parser.add_argument('--no_median_interval_duration', action='store_true', help='Disable median interval duration calculation')
+ # parser.add_argument('--frame_count', action='store_true', help='Include frame count in output')
+ # parser.add_argument('--video_length', action='store_true', help='Include video length in output')
+ #
+ # args = parser.parse_args()
+ #
+ # clf_calculator = AggregateClfCalculatorMultiprocess(
+ # config_path=args.config_path,
+ # classifiers=args.classifiers,
+ # data_dir=args.data_dir,
+ # detailed_bout_data=args.detailed_bout_data,
+ # transpose=args.transpose,
+ # first_occurrence=not args.no_first_occurrence,
+ # event_count=not args.no_event_count,
+ # total_event_duration=not args.no_total_event_duration,
+ # mean_event_duration=not args.no_mean_event_duration,
+ # median_event_duration=not args.no_median_event_duration,
+ # mean_interval_duration=not args.no_mean_interval_duration,
+ # median_interval_duration=not args.no_median_interval_duration,
+ # frame_count=args.frame_count,
+ # video_length=args.video_length
+ # )
+ # clf_calculator.run()
+ # clf_calculator.save()
+
+ if __name__ == "__main__":
+ test = AggregateClfCalculatorMultiprocess(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini",
+ classifiers=['attack'],
+ transpose=True,
+ mean_event_duration = True,
+ median_event_duration = True,
+ mean_interval_duration = True,
+ median_interval_duration = True,
+ detailed_bout_data=True,
+ core_cnt=12)
+ test.run()
+ test.save()


  # test = AggregateClfCalculator(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
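
agg_clf_counter_mp.py now delegates worker-pool cleanup to terminate_cpu_pool(pool=pool, force=False) from simba.utils.data rather than calling pool.join()/pool.terminate() directly. The helper's implementation is not included in this diff; the sketch below is only an assumption of what such a wrapper could look like, for readers following the change.

    # Hypothetical sketch of a terminate_cpu_pool-style helper; the real
    # simba.utils.data implementation is not shown in this diff.
    from multiprocessing.pool import Pool

    def terminate_cpu_pool(pool: Pool, force: bool = False) -> None:
        """Close (or forcibly terminate) a worker pool and reap its processes."""
        if force:
            pool.terminate()  # stop workers immediately, discarding pending work
        else:
            pool.close()      # stop accepting new tasks; let running tasks finish
        pool.join()           # wait for the worker processes to exit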
simba/data_processors/blob_location_computer.py CHANGED
@@ -51,7 +51,7 @@ class BlobLocationComputer(object):
  :param Optional[bool] multiprocessing: If True, video background subtraction will be done using multiprocessing. Default is False.

  :example:
- >>> x = BlobLocationComputer(data_path=r"C:\troubleshooting\RAT_NOR\project_folder\videos\2022-06-20_NOB_DOT_4_downsampled_bg_subtracted.mp4", multiprocessing=True, gpu=True, batch_size=2000, save_dir=r"C:\blob_positions")
+ >>> x = BlobLocationComputer(data_path=r"C:/troubleshooting/RAT_NOR/project_folder/videos/2022-06-20_NOB_DOT_4_downsampled_bg_subtracted.mp4", multiprocessing=True, gpu=True, batch_size=2000, save_dir=r"C:/blob_positions")
  >>> x.run()
  """
  def __init__(self,
simba/data_processors/circling_detector.py CHANGED
@@ -11,12 +11,13 @@ from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin
  from simba.mixins.timeseries_features_mixin import TimeseriesFeatureMixin
  from simba.utils.checks import (
  check_all_file_names_are_represented_in_video_log, check_if_dir_exists,
- check_int, check_str, check_valid_dataframe)
+ check_str, check_valid_dataframe)
  from simba.utils.data import detect_bouts, plug_holes_shortest_bout
  from simba.utils.enums import Formats
  from simba.utils.printing import stdout_success
  from simba.utils.read_write import (find_files_of_filetypes_in_directory,
- get_fn_ext, read_df, read_video_info)
+ get_current_time, get_fn_ext, read_df,
+ read_video_info)

  CIRCLING = 'CIRCLING'

@@ -58,30 +59,34 @@ class CirclingDetector(ConfigReader):
  """

  def __init__(self,
- data_dir: Union[str, os.PathLike],
  config_path: Union[str, os.PathLike],
  nose_name: Optional[str] = 'nose',
+ data_dir: Optional[Union[str, os.PathLike]] = None,
  left_ear_name: Optional[str] = 'left_ear',
  right_ear_name: Optional[str] = 'right_ear',
  tail_base_name: Optional[str] = 'tail_base',
  center_name: Optional[str] = 'center',
- time_threshold: Optional[int] = 10,
- circular_range_threshold: Optional[int] = 320,
+ time_threshold: Optional[int] = 7,
+ circular_range_threshold: Optional[int] = 350,
+ shortest_bout: int = 100,
  movement_threshold: Optional[int] = 60,
  save_dir: Optional[Union[str, os.PathLike]] = None):

- check_if_dir_exists(in_dir=data_dir)
  for bp_name in [nose_name, left_ear_name, right_ear_name, tail_base_name]: check_str(name='body part name', value=bp_name, allow_blank=False)
- self.data_paths = find_files_of_filetypes_in_directory(directory=data_dir, extensions=['.csv'])
  ConfigReader.__init__(self, config_path=config_path, read_video_info=True, create_logger=False)
+ if data_dir is not None:
+ check_if_dir_exists(in_dir=data_dir)
+ else:
+ data_dir = self.outlier_corrected_dir
+ self.data_paths = find_files_of_filetypes_in_directory(directory=data_dir, extensions=['.csv'])
  self.nose_heads = [f'{nose_name}_x'.lower(), f'{nose_name}_y'.lower()]
  self.left_ear_heads = [f'{left_ear_name}_x'.lower(), f'{left_ear_name}_y'.lower()]
  self.right_ear_heads = [f'{right_ear_name}_x'.lower(), f'{right_ear_name}_y'.lower()]
  self.center_heads = [f'{center_name}_x'.lower(), f'{center_name}_y'.lower()]
  self.required_field = self.nose_heads + self.left_ear_heads + self.right_ear_heads
- self.save_dir = save_dir
+ self.save_dir, self.shortest_bout = save_dir, shortest_bout
  if self.save_dir is None:
- self.save_dir = os.path.join(self.logs_path, f'circling_data_{self.datetime}')
+ self.save_dir = os.path.join(self.logs_path, f'circling_data_{time_threshold}s_{circular_range_threshold}d_{movement_threshold}mm_{self.datetime}')
  os.makedirs(self.save_dir)
  else:
  check_if_dir_exists(in_dir=self.save_dir)
@@ -93,7 +98,7 @@ class CirclingDetector(ConfigReader):
  check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.data_paths)
  for file_cnt, file_path in enumerate(self.data_paths):
  video_name = get_fn_ext(filepath=file_path)[1]
- print(f'Analyzing {video_name} ({file_cnt+1}/{len(self.data_paths)})...')
+ print(f'[{get_current_time()}] Analyzing circling {video_name}... (video {file_cnt+1}/{len(self.data_paths)})')
  save_file_path = os.path.join(self.save_dir, f'{video_name}.csv')
  df = read_df(file_path=file_path, file_type='csv').reset_index(drop=True)
  _, px_per_mm, fps = read_video_info(video_info_df=self.video_info_df, video_name=video_name)
@@ -115,11 +120,24 @@
  circling_idx = np.argwhere(sliding_circular_range >= self.circular_range_threshold).astype(np.int32).flatten()
  movement_idx = np.argwhere(movement_sum >= self.movement_threshold).astype(np.int32).flatten()
  circling_idx = [x for x in movement_idx if x in circling_idx]
+ df[f'Probability_{CIRCLING}'] = 0
  df[CIRCLING] = 0
  df.loc[circling_idx, CIRCLING] = 1
+ df.loc[circling_idx, f'Probability_{CIRCLING}'] = 1
+ df = plug_holes_shortest_bout(data_df=df, clf_name=CIRCLING, fps=fps, shortest_bout=self.shortest_bout)
  bouts = detect_bouts(data_df=df, target_lst=[CIRCLING], fps=fps)
- df = plug_holes_shortest_bout(data_df=df, clf_name=CIRCLING, fps=fps, shortest_bout=100)
+ if len(bouts) > 0:
+ df[CIRCLING] = 0
+ circling_idx = list(bouts.apply(lambda x: list(range(int(x["Start_frame"]), int(x["End_frame"]) + 1)), 1))
+ circling_idx = [x for xs in circling_idx for x in xs]
+ df.loc[circling_idx, CIRCLING] = 1
+ df.loc[circling_idx, f'Probability_{CIRCLING}'] = 1
+ else:
+ df[CIRCLING] = 0
+ circling_idx = []
+
  df.to_csv(save_file_path)
+ #print(video_name, len(circling_idx), round(len(circling_idx) / fps, 4), df[CIRCLING].sum())
  agg_results.loc[len(agg_results)] = [video_name, len(circling_idx), round(len(circling_idx) / fps, 4), len(bouts), round((len(circling_idx) / len(df)) * 100, 4), len(df), round(len(df)/fps, 2) ]

  agg_results.to_csv(agg_results_path)
@@ -127,7 +145,6 @@

  #
  #
- # detector = CirclingDetector(data_dir=r'C:\troubleshooting\mitra\project_folder\csv\outlier_corrected_movement_location',
- # config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
+ # detector = CirclingDetector(config_path=r"E:\troubleshooting\mitra_emergence\project_folder\project_config.ini")
  # detector.run()
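
The detector now writes a Probability_CIRCLING column, plugs sub-threshold gaps with plug_holes_shortest_bout before re-labelling frames from the bouts returned by detect_bouts, and falls back to the project's outlier-corrected directory when data_dir is omitted. A minimal usage sketch, assuming the 4.7.1 signature shown in the hunks above (the project path is a placeholder):

    # Sketch only: defaults mirror the updated __init__ above; the path is a placeholder.
    from simba.data_processors.circling_detector import CirclingDetector

    detector = CirclingDetector(config_path=r"C:/my_project/project_folder/project_config.ini",
                                data_dir=None,                 # falls back to self.outlier_corrected_dir
                                time_threshold=7,              # new default (previously 10)
                                circular_range_threshold=350,  # new default (previously 320)
                                shortest_bout=100)             # passed to plug_holes_shortest_bout
    detector.run()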
 
simba/data_processors/cuda/geometry.py CHANGED
@@ -8,6 +8,7 @@ from numba import cuda, njit

  from simba.utils.checks import check_float, check_valid_array
  from simba.utils.enums import Formats
+ from simba.utils.printing import SimbaTimer

  try:
  import cupy as cp
@@ -401,20 +402,21 @@ def find_midpoints(x: np.ndarray,


  @cuda.jit()
- def _directionality_to_static_targets_kernel(left_ear, right_ear, nose, target, results):
+ def _directionality_to_static_targets_kernel(left_ear, right_ear, nose, target_x, target_y, results):
  i = cuda.grid(1)
- if i > left_ear.shape[0]:
+ if i >= left_ear.shape[0]:
  return
  else:
- LE, RE = left_ear[i], right_ear[i]
- N, Tx, Ty = nose[i], target[0], target[1]
-
- Px = abs(LE[0] - Tx)
- Py = abs(LE[1] - Ty)
- Qx = abs(RE[0] - Tx)
- Qy = abs(RE[1] - Ty)
- Nx = abs(N[0] - Tx)
- Ny = abs(N[1] - Ty)
+ LE = left_ear[i]
+ RE = right_ear[i]
+ N = nose[i]
+
+ Px = abs(LE[0] - target_x)
+ Py = abs(LE[1] - target_y)
+ Qx = abs(RE[0] - target_x)
+ Qy = abs(RE[1] - target_y)
+ Nx = abs(N[0] - target_x)
+ Ny = abs(N[1] - target_y)
  Ph = math.sqrt(Px * Px + Py * Py)
  Qh = math.sqrt(Qx * Qx + Qy * Qy)
  Nh = math.sqrt(Nx * Nx + Ny * Ny)
@@ -438,7 +440,8 @@ def _directionality_to_static_targets_kernel(left_ear, right_ear, nose, target,
  def directionality_to_static_targets(left_ear: np.ndarray,
  right_ear: np.ndarray,
  nose: np.ndarray,
- target: np.ndarray) -> np.ndarray:
+ target: np.ndarray,
+ verbose: bool = False) -> np.ndarray:
  """
  GPU helper to calculate if an animal is directing towards a static location (e.g., ROI centroid), given the target location and the left ear, right ear, and nose coordinates of the observer.

@@ -487,32 +490,38 @@ def directionality_to_static_targets(left_ear: np.ndarray,
  >>> directionality_to_static_targets(left_ear=left_ear, right_ear=right_ear, nose=nose, target=target)

  """
-
+ timer = SimbaTimer(start=True)
  left_ear = np.ascontiguousarray(left_ear).astype(np.int32)
  right_ear = np.ascontiguousarray(right_ear).astype(np.int32)
  nose = np.ascontiguousarray(nose).astype(np.int32)
  target = np.ascontiguousarray(target).astype(np.int32)

+ target_x = int(target[0])
+ target_y = int(target[1])
+
  left_ear_dev = cuda.to_device(left_ear)
  right_ear_dev = cuda.to_device(right_ear)
  nose_dev = cuda.to_device(nose)
- target_dev = cuda.to_device(target)
  results = cuda.device_array((left_ear.shape[0], 4), dtype=np.int32)
  bpg = (left_ear.shape[0] + (THREADS_PER_BLOCK - 1)) // THREADS_PER_BLOCK
- _directionality_to_static_targets_kernel[bpg, THREADS_PER_BLOCK](left_ear_dev, right_ear_dev, nose_dev, target_dev, results)
+ _directionality_to_static_targets_kernel[bpg, THREADS_PER_BLOCK](left_ear_dev, right_ear_dev, nose_dev, target_x, target_y, results)

  results = results.copy_to_host()
+ timer.stop_timer()
+ if verbose: print(f'Directionality to static target computed in for {left_ear.shape[0]} observations (elapsed time: {timer.elapsed_time_str}s)')
  return results


  @cuda.jit()
  def _directionality_to_nonstatic_targets_kernel(left_ear, right_ear, nose, target, results):
  i = cuda.grid(1)
- if i > left_ear.shape[0]:
+ if i >= left_ear.shape[0]:
  return
  else:
- LE, RE = left_ear[i], right_ear[i]
- N, T = nose[i], target[i]
+ LE = left_ear[i]
+ RE = right_ear[i]
+ N = nose[i]
+ T = target[i]

  Px = abs(LE[0] - T[0])
  Py = abs(LE[1] - T[1])
@@ -543,11 +552,19 @@ def _directionality_to_nonstatic_targets_kernel(left_ear, right_ear, nose, targe
  def directionality_to_nonstatic_target(left_ear: np.ndarray,
  right_ear: np.ndarray,
  nose: np.ndarray,
- target: np.ndarray) -> np.ndarray:
+ target: np.ndarray,
+ verbose: bool = False) -> np.ndarray:

  """
  GPU method to calculate if an animal is directing towards a moving point location given the target location and the left ear, right ear, and nose coordinates of the observer.

+ .. csv-table::
+ :header: EXPECTED RUNTIMES
+ :file: ../../../docs/tables/directionality_to_nonstatic_target_cuda.csv
+ :widths: 30, 30, 20, 10, 10
+ :align: center
+ :class: simba-table
+ :header-rows: 1

  .. image:: _static/img/directing_moving_targets.png
  :width: 400
@@ -573,27 +590,28 @@ def directionality_to_nonstatic_target(left_ear: np.ndarray,
  >>> right_ear = np.random.randint(0, 500, (100, 2))
  >>> nose = np.random.randint(0, 500, (100, 2))
  >>> target = np.random.randint(0, 500, (100, 2))
- >>> directionality_to_static_targets(left_ear=left_ear, right_ear=right_ear, nose=nose, target=target)
+ >>> directionality_to_nonstatic_target(left_ear=left_ear, right_ear=right_ear, nose=nose, target=target)
  """

- left_ear = np.ascontiguousarray(left_ear).astype(np.int32)
- right_ear = np.ascontiguousarray(right_ear).astype(np.int32)
- nose = np.ascontiguousarray(nose).astype(np.int32)
- target = np.ascontiguousarray(target).astype(np.int32)
+ timer = SimbaTimer(start=True)
+ left_ear = np.ascontiguousarray(left_ear).astype(np.int64)
+ right_ear = np.ascontiguousarray(right_ear).astype(np.int64)
+ nose = np.ascontiguousarray(nose).astype(np.int64)
+ target = np.ascontiguousarray(target).astype(np.int64)

  left_ear_dev = cuda.to_device(left_ear)
  right_ear_dev = cuda.to_device(right_ear)
  nose_dev = cuda.to_device(nose)
  target_dev = cuda.to_device(target)
- results = cuda.device_array((left_ear.shape[0], 4), dtype=np.int32)
+ results = cuda.device_array((left_ear.shape[0], 4), dtype=np.int64)
  bpg = (left_ear.shape[0] + (THREADS_PER_BLOCK - 1)) // THREADS_PER_BLOCK
  _directionality_to_nonstatic_targets_kernel[bpg, THREADS_PER_BLOCK](left_ear_dev, right_ear_dev, nose_dev, target_dev, results)

  results = results.copy_to_host()
+ timer.stop_timer()
+ if verbose: print(f'Directionality to moving target computed in for {left_ear.shape[0]} observations (elapsed time: {timer.elapsed_time_str}s)')
  return results




-
-
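
Both GPU directionality helpers above now guard the thread index with i >= left_ear.shape[0], accept a verbose flag that reports elapsed time via SimbaTimer, and, for the static-target variant, pass the target as two scalars rather than a device array. A minimal usage sketch under the assumption of a CUDA-capable device, mirroring the docstring examples:

    # Sketch only: shapes follow the docstring examples; requires a CUDA-capable GPU
    # and the numba/cupy environment imported by simba.data_processors.cuda.geometry.
    import numpy as np
    from simba.data_processors.cuda.geometry import (directionality_to_nonstatic_target,
                                                     directionality_to_static_targets)

    left_ear = np.random.randint(0, 500, (100, 2))
    right_ear = np.random.randint(0, 500, (100, 2))
    nose = np.random.randint(0, 500, (100, 2))

    static_target = np.array([250, 250])                 # one fixed (x, y) target, e.g. an ROI centroid
    moving_target = np.random.randint(0, 500, (100, 2))  # one (x, y) target per frame

    static_results = directionality_to_static_targets(left_ear=left_ear, right_ear=right_ear,
                                                      nose=nose, target=static_target, verbose=True)
    moving_results = directionality_to_nonstatic_target(left_ear=left_ear, right_ear=right_ear,
                                                        nose=nose, target=moving_target, verbose=True)
    # Each result array has one row per frame and four integer columns; see the
    # kernels and docstrings above for the exact column semantics.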