simba-uw-tf-dev 4.6.2__py3-none-any.whl → 4.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. simba/assets/.recent_projects.txt +1 -0
  2. simba/assets/lookups/tooptips.json +6 -1
  3. simba/assets/lookups/yolo_schematics/yolo_mitra.csv +9 -0
  4. simba/data_processors/agg_clf_counter_mp.py +52 -53
  5. simba/data_processors/blob_location_computer.py +1 -1
  6. simba/data_processors/circling_detector.py +30 -13
  7. simba/data_processors/cuda/geometry.py +45 -27
  8. simba/data_processors/cuda/image.py +1648 -1598
  9. simba/data_processors/cuda/statistics.py +72 -26
  10. simba/data_processors/cuda/timeseries.py +1 -1
  11. simba/data_processors/cue_light_analyzer.py +5 -9
  12. simba/data_processors/egocentric_aligner.py +25 -7
  13. simba/data_processors/freezing_detector.py +55 -47
  14. simba/data_processors/kleinberg_calculator.py +61 -29
  15. simba/feature_extractors/feature_subsets.py +14 -7
  16. simba/feature_extractors/mitra_feature_extractor.py +2 -2
  17. simba/feature_extractors/straub_tail_analyzer.py +4 -6
  18. simba/labelling/standard_labeller.py +1 -1
  19. simba/mixins/config_reader.py +5 -2
  20. simba/mixins/geometry_mixin.py +22 -36
  21. simba/mixins/image_mixin.py +24 -28
  22. simba/mixins/plotting_mixin.py +28 -10
  23. simba/mixins/statistics_mixin.py +48 -11
  24. simba/mixins/timeseries_features_mixin.py +1 -1
  25. simba/mixins/train_model_mixin.py +68 -33
  26. simba/model/inference_batch.py +2 -2
  27. simba/model/yolo_seg_inference.py +3 -3
  28. simba/outlier_tools/skip_outlier_correction.py +1 -1
  29. simba/plotting/ROI_feature_visualizer_mp.py +3 -5
  30. simba/plotting/clf_validator_mp.py +4 -5
  31. simba/plotting/cue_light_visualizer.py +6 -7
  32. simba/plotting/directing_animals_visualizer_mp.py +2 -3
  33. simba/plotting/distance_plotter_mp.py +378 -378
  34. simba/plotting/gantt_creator.py +29 -10
  35. simba/plotting/gantt_creator_mp.py +96 -33
  36. simba/plotting/geometry_plotter.py +270 -272
  37. simba/plotting/heat_mapper_clf_mp.py +4 -6
  38. simba/plotting/heat_mapper_location_mp.py +2 -2
  39. simba/plotting/light_dark_box_plotter.py +2 -2
  40. simba/plotting/path_plotter_mp.py +26 -29
  41. simba/plotting/plot_clf_results_mp.py +455 -454
  42. simba/plotting/pose_plotter_mp.py +28 -29
  43. simba/plotting/probability_plot_creator_mp.py +288 -288
  44. simba/plotting/roi_plotter_mp.py +31 -31
  45. simba/plotting/single_run_model_validation_video_mp.py +427 -427
  46. simba/plotting/spontaneous_alternation_plotter.py +2 -3
  47. simba/plotting/yolo_pose_track_visualizer.py +32 -27
  48. simba/plotting/yolo_pose_visualizer.py +35 -36
  49. simba/plotting/yolo_seg_visualizer.py +2 -3
  50. simba/pose_importers/simba_blob_importer.py +3 -3
  51. simba/roi_tools/roi_aggregate_stats_mp.py +5 -4
  52. simba/roi_tools/roi_clf_calculator_mp.py +4 -4
  53. simba/sandbox/analyze_runtimes.py +30 -0
  54. simba/sandbox/cuda/egocentric_rotator.py +374 -374
  55. simba/sandbox/get_cpu_pool.py +5 -0
  56. simba/sandbox/proboscis_to_tip.py +28 -0
  57. simba/sandbox/test_directionality.py +47 -0
  58. simba/sandbox/test_nonstatic_directionality.py +27 -0
  59. simba/sandbox/test_pycharm_cuda.py +51 -0
  60. simba/sandbox/test_simba_install.py +41 -0
  61. simba/sandbox/test_static_directionality.py +26 -0
  62. simba/sandbox/test_static_directionality_2d.py +26 -0
  63. simba/sandbox/verify_env.py +42 -0
  64. simba/third_party_label_appenders/transform/coco_keypoints_to_yolo.py +3 -3
  65. simba/third_party_label_appenders/transform/coco_keypoints_to_yolo_bbox.py +2 -2
  66. simba/third_party_label_appenders/transform/simba_to_yolo.py +8 -5
  67. simba/ui/pop_ups/clf_plot_pop_up.py +2 -2
  68. simba/ui/pop_ups/fsttc_pop_up.py +27 -25
  69. simba/ui/pop_ups/gantt_pop_up.py +31 -6
  70. simba/ui/pop_ups/kleinberg_pop_up.py +39 -40
  71. simba/ui/pop_ups/run_machine_models_popup.py +21 -21
  72. simba/ui/pop_ups/simba_to_yolo_keypoints_popup.py +2 -2
  73. simba/ui/pop_ups/video_processing_pop_up.py +37 -29
  74. simba/ui/pop_ups/yolo_inference_popup.py +1 -1
  75. simba/ui/pop_ups/yolo_pose_train_popup.py +1 -1
  76. simba/ui/tkinter_functions.py +3 -0
  77. simba/utils/custom_feature_extractor.py +1 -1
  78. simba/utils/data.py +90 -14
  79. simba/utils/enums.py +1 -0
  80. simba/utils/errors.py +441 -440
  81. simba/utils/lookups.py +1203 -1203
  82. simba/utils/printing.py +124 -124
  83. simba/utils/read_write.py +3769 -3721
  84. simba/utils/yolo.py +10 -1
  85. simba/video_processors/blob_tracking_executor.py +2 -2
  86. simba/video_processors/clahe_ui.py +1 -1
  87. simba/video_processors/egocentric_video_rotator.py +44 -41
  88. simba/video_processors/multi_cropper.py +1 -1
  89. simba/video_processors/video_processing.py +75 -33
  90. simba/video_processors/videos_to_frames.py +43 -33
  91. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/METADATA +4 -3
  92. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/RECORD +96 -85
  93. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/LICENSE +0 -0
  94. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/WHEEL +0 -0
  95. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/entry_points.txt +0 -0
  96. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/top_level.txt +0 -0
@@ -1,2 +1,3 @@
1
+ E:/troubleshooting/mitra_emergence/project_folder/project_config.ini
1
2
  C:/troubleshooting/meberled/project_folder/project_config.ini
2
3
  C:/troubleshooting/mitra/project_folder/project_config.ini
@@ -38,5 +38,10 @@
38
38
  "EGOCENTRIC_ANCHOR": "This body-part will be placed in the center of the video",
39
39
  "EGOCENTRIC_DIRECTION_ANCHOR": "This body-part will be placed at N degrees relative to the anchor",
40
40
  "EGOCENTRIC_DIRECTION": "The anchor body-part will always be placed at these degrees relative to the center anchor",
41
- "CORE_COUNT": "Higher core counts speeds up processing but may require more RAM memory"
41
+ "CORE_COUNT": "Higher core counts speeds up processing but may require more RAM memory",
42
+ "KLEINBERG_SIGMA": "Higher values (e.g., 2-3) produce fewer but longer bursts; lower values (e.g., 1.1-1.5) detect more frequent, shorter bursts. Must be > 1.01",
43
+ "KLEINBERG_GAMMA": "Higher values (e.g., 0.5-1.0) reduce total burst count by making downward transitions costly; lower values (e.g., 0.1-0.3) allow more flexible state changes",
44
+ "KLEINBERG_HIERARCHY": "Hierarchy level to extract bursts from (0=lowest, higher=more selective).\n Level 0 captures all bursts; level 1-2 typically filters noise; level 3+ selects only the most prominent, sustained bursts.\nHigher levels yield fewer but more confident detections",
45
+ "KLEINBERG_HIERARCHY_SEARCH": "If True, searches for target hierarchy level within detected burst periods,\n falling back to lower levels if target not found. If False, extracts only bursts at the exact specified hierarchy level.\n Recommended when target hierarchy may be sparse.",
46
+ "KLEINBERG_SAVE_ORIGINALS": "If True, saves the original data in a new sub-directory of \nthe project_folder/csv/machine_results directory"
42
47
  }
@@ -0,0 +1,9 @@
1
+ left_ear
2
+ right_ear
3
+ nose
4
+ left_side
5
+ right_side
6
+ tail_base
7
+ center
8
+ tail_center
9
+ tail_tip
@@ -20,7 +20,7 @@ from simba.utils.checks import (
20
20
  check_all_file_names_are_represented_in_video_log,
21
21
  check_file_exist_and_readable, check_if_dir_exists, check_int,
22
22
  check_valid_boolean, check_valid_dataframe, check_valid_lst)
23
- from simba.utils.data import detect_bouts
23
+ from simba.utils.data import detect_bouts, terminate_cpu_pool
24
24
  from simba.utils.enums import TagNames
25
25
  from simba.utils.errors import NoChoosenMeasurementError
26
26
  from simba.utils.printing import SimbaTimer, log_event, stdout_success
@@ -210,8 +210,7 @@ class AggregateClfCalculatorMultiprocess(ConfigReader):
210
210
  self.bouts_df_lst.append(batch_bouts_df_lst)
211
211
  print(f"Data batch core {batch_id+1} / {self.core_cnt} complete...")
212
212
  self.bouts_df_lst = [df for sub in self.bouts_df_lst for df in sub]
213
- pool.join()
214
- pool.terminate()
213
+ terminate_cpu_pool(pool=pool, force=False)
215
214
 
216
215
  def save(self) -> None:
217
216
  """
@@ -242,56 +241,56 @@ class AggregateClfCalculatorMultiprocess(ConfigReader):
242
241
  self.timer.stop_timer()
243
242
  stdout_success(msg=f"Data aggregate log saved at {self.save_path}", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
244
243
 
245
- if __name__ == "__main__" and not hasattr(sys, 'ps1'):
246
- parser = argparse.ArgumentParser(description='Compute aggregate descriptive statistics from classification data.')
247
- parser.add_argument('--config_path', type=str, required=True, help='Path to SimBA project config file')
248
- parser.add_argument('--classifiers', type=str, nargs='+', required=True, help='List of classifier names to analyze')
249
- parser.add_argument('--data_dir', type=str, default=None, help='Directory containing machine results CSV files (default: project machine_results directory)')
250
- parser.add_argument('--detailed_bout_data', action='store_true', help='Save detailed bout data for each bout')
251
- parser.add_argument('--transpose', action='store_true', help='Create output with one video per row')
252
- parser.add_argument('--no_first_occurrence', action='store_true', help='Disable first occurrence calculation')
253
- parser.add_argument('--no_event_count', action='store_true', help='Disable event count calculation')
254
- parser.add_argument('--no_total_event_duration', action='store_true', help='Disable total event duration calculation')
255
- parser.add_argument('--no_mean_event_duration', action='store_true', help='Disable mean event duration calculation')
256
- parser.add_argument('--no_median_event_duration', action='store_true', help='Disable median event duration calculation')
257
- parser.add_argument('--no_mean_interval_duration', action='store_true', help='Disable mean interval duration calculation')
258
- parser.add_argument('--no_median_interval_duration', action='store_true', help='Disable median interval duration calculation')
259
- parser.add_argument('--frame_count', action='store_true', help='Include frame count in output')
260
- parser.add_argument('--video_length', action='store_true', help='Include video length in output')
261
-
262
- args = parser.parse_args()
263
-
264
- clf_calculator = AggregateClfCalculatorMultiprocess(
265
- config_path=args.config_path,
266
- classifiers=args.classifiers,
267
- data_dir=args.data_dir,
268
- detailed_bout_data=args.detailed_bout_data,
269
- transpose=args.transpose,
270
- first_occurrence=not args.no_first_occurrence,
271
- event_count=not args.no_event_count,
272
- total_event_duration=not args.no_total_event_duration,
273
- mean_event_duration=not args.no_mean_event_duration,
274
- median_event_duration=not args.no_median_event_duration,
275
- mean_interval_duration=not args.no_mean_interval_duration,
276
- median_interval_duration=not args.no_median_interval_duration,
277
- frame_count=args.frame_count,
278
- video_length=args.video_length
279
- )
280
- clf_calculator.run()
281
- clf_calculator.save()
282
-
283
- # if __name__ == "__main__":
284
- # test = AggregateClfCalculatorMultiprocess(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini",
285
- # classifiers=['attack'],
286
- # transpose=True,
287
- # mean_event_duration = True,
288
- # median_event_duration = True,
289
- # mean_interval_duration = True,
290
- # median_interval_duration = True,
291
- # detailed_bout_data=True,
292
- # core_cnt=12)
293
- # test.run()
294
- # test.save()
244
+ # if __name__ == "__main__" and not hasattr(sys, 'ps1'):
245
+ # parser = argparse.ArgumentParser(description='Compute aggregate descriptive statistics from classification data.')
246
+ # parser.add_argument('--config_path', type=str, required=True, help='Path to SimBA project config file')
247
+ # parser.add_argument('--classifiers', type=str, nargs='+', required=True, help='List of classifier names to analyze')
248
+ # parser.add_argument('--data_dir', type=str, default=None, help='Directory containing machine results CSV files (default: project machine_results directory)')
249
+ # parser.add_argument('--detailed_bout_data', action='store_true', help='Save detailed bout data for each bout')
250
+ # parser.add_argument('--transpose', action='store_true', help='Create output with one video per row')
251
+ # parser.add_argument('--no_first_occurrence', action='store_true', help='Disable first occurrence calculation')
252
+ # parser.add_argument('--no_event_count', action='store_true', help='Disable event count calculation')
253
+ # parser.add_argument('--no_total_event_duration', action='store_true', help='Disable total event duration calculation')
254
+ # parser.add_argument('--no_mean_event_duration', action='store_true', help='Disable mean event duration calculation')
255
+ # parser.add_argument('--no_median_event_duration', action='store_true', help='Disable median event duration calculation')
256
+ # parser.add_argument('--no_mean_interval_duration', action='store_true', help='Disable mean interval duration calculation')
257
+ # parser.add_argument('--no_median_interval_duration', action='store_true', help='Disable median interval duration calculation')
258
+ # parser.add_argument('--frame_count', action='store_true', help='Include frame count in output')
259
+ # parser.add_argument('--video_length', action='store_true', help='Include video length in output')
260
+ #
261
+ # args = parser.parse_args()
262
+ #
263
+ # clf_calculator = AggregateClfCalculatorMultiprocess(
264
+ # config_path=args.config_path,
265
+ # classifiers=args.classifiers,
266
+ # data_dir=args.data_dir,
267
+ # detailed_bout_data=args.detailed_bout_data,
268
+ # transpose=args.transpose,
269
+ # first_occurrence=not args.no_first_occurrence,
270
+ # event_count=not args.no_event_count,
271
+ # total_event_duration=not args.no_total_event_duration,
272
+ # mean_event_duration=not args.no_mean_event_duration,
273
+ # median_event_duration=not args.no_median_event_duration,
274
+ # mean_interval_duration=not args.no_mean_interval_duration,
275
+ # median_interval_duration=not args.no_median_interval_duration,
276
+ # frame_count=args.frame_count,
277
+ # video_length=args.video_length
278
+ # )
279
+ # clf_calculator.run()
280
+ # clf_calculator.save()
281
+
282
+ if __name__ == "__main__":
283
+ test = AggregateClfCalculatorMultiprocess(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini",
284
+ classifiers=['attack'],
285
+ transpose=True,
286
+ mean_event_duration = True,
287
+ median_event_duration = True,
288
+ mean_interval_duration = True,
289
+ median_interval_duration = True,
290
+ detailed_bout_data=True,
291
+ core_cnt=12)
292
+ test.run()
293
+ test.save()
295
294
 
296
295
 
297
296
  # test = AggregateClfCalculator(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
@@ -51,7 +51,7 @@ class BlobLocationComputer(object):
51
51
  :param Optional[bool] multiprocessing: If True, video background subtraction will be done using multiprocessing. Default is False.
52
52
 
53
53
  :example:
54
- >>> x = BlobLocationComputer(data_path=r"C:\troubleshooting\RAT_NOR\project_folder\videos\2022-06-20_NOB_DOT_4_downsampled_bg_subtracted.mp4", multiprocessing=True, gpu=True, batch_size=2000, save_dir=r"C:\blob_positions")
54
+ >>> x = BlobLocationComputer(data_path=r"C:/troubleshooting/RAT_NOR/project_folder/videos/2022-06-20_NOB_DOT_4_downsampled_bg_subtracted.mp4", multiprocessing=True, gpu=True, batch_size=2000, save_dir=r"C:/blob_positions")
55
55
  >>> x.run()
56
56
  """
57
57
  def __init__(self,
@@ -11,12 +11,13 @@ from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin
11
11
  from simba.mixins.timeseries_features_mixin import TimeseriesFeatureMixin
12
12
  from simba.utils.checks import (
13
13
  check_all_file_names_are_represented_in_video_log, check_if_dir_exists,
14
- check_int, check_str, check_valid_dataframe)
14
+ check_str, check_valid_dataframe)
15
15
  from simba.utils.data import detect_bouts, plug_holes_shortest_bout
16
16
  from simba.utils.enums import Formats
17
17
  from simba.utils.printing import stdout_success
18
18
  from simba.utils.read_write import (find_files_of_filetypes_in_directory,
19
- get_fn_ext, read_df, read_video_info)
19
+ get_current_time, get_fn_ext, read_df,
20
+ read_video_info)
20
21
 
21
22
  CIRCLING = 'CIRCLING'
22
23
 
@@ -58,30 +59,34 @@ class CirclingDetector(ConfigReader):
58
59
  """
59
60
 
60
61
  def __init__(self,
61
- data_dir: Union[str, os.PathLike],
62
62
  config_path: Union[str, os.PathLike],
63
63
  nose_name: Optional[str] = 'nose',
64
+ data_dir: Optional[Union[str, os.PathLike]] = None,
64
65
  left_ear_name: Optional[str] = 'left_ear',
65
66
  right_ear_name: Optional[str] = 'right_ear',
66
67
  tail_base_name: Optional[str] = 'tail_base',
67
68
  center_name: Optional[str] = 'center',
68
- time_threshold: Optional[int] = 10,
69
- circular_range_threshold: Optional[int] = 320,
69
+ time_threshold: Optional[int] = 7,
70
+ circular_range_threshold: Optional[int] = 350,
71
+ shortest_bout: int = 100,
70
72
  movement_threshold: Optional[int] = 60,
71
73
  save_dir: Optional[Union[str, os.PathLike]] = None):
72
74
 
73
- check_if_dir_exists(in_dir=data_dir)
74
75
  for bp_name in [nose_name, left_ear_name, right_ear_name, tail_base_name]: check_str(name='body part name', value=bp_name, allow_blank=False)
75
- self.data_paths = find_files_of_filetypes_in_directory(directory=data_dir, extensions=['.csv'])
76
76
  ConfigReader.__init__(self, config_path=config_path, read_video_info=True, create_logger=False)
77
+ if data_dir is not None:
78
+ check_if_dir_exists(in_dir=data_dir)
79
+ else:
80
+ data_dir = self.outlier_corrected_dir
81
+ self.data_paths = find_files_of_filetypes_in_directory(directory=data_dir, extensions=['.csv'])
77
82
  self.nose_heads = [f'{nose_name}_x'.lower(), f'{nose_name}_y'.lower()]
78
83
  self.left_ear_heads = [f'{left_ear_name}_x'.lower(), f'{left_ear_name}_y'.lower()]
79
84
  self.right_ear_heads = [f'{right_ear_name}_x'.lower(), f'{right_ear_name}_y'.lower()]
80
85
  self.center_heads = [f'{center_name}_x'.lower(), f'{center_name}_y'.lower()]
81
86
  self.required_field = self.nose_heads + self.left_ear_heads + self.right_ear_heads
82
- self.save_dir = save_dir
87
+ self.save_dir, self.shortest_bout = save_dir, shortest_bout
83
88
  if self.save_dir is None:
84
- self.save_dir = os.path.join(self.logs_path, f'circling_data_{self.datetime}')
89
+ self.save_dir = os.path.join(self.logs_path, f'circling_data_{time_threshold}s_{circular_range_threshold}d_{movement_threshold}mm_{self.datetime}')
85
90
  os.makedirs(self.save_dir)
86
91
  else:
87
92
  check_if_dir_exists(in_dir=self.save_dir)
@@ -93,7 +98,7 @@ class CirclingDetector(ConfigReader):
93
98
  check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.data_paths)
94
99
  for file_cnt, file_path in enumerate(self.data_paths):
95
100
  video_name = get_fn_ext(filepath=file_path)[1]
96
- print(f'Analyzing {video_name} ({file_cnt+1}/{len(self.data_paths)})...')
101
+ print(f'[{get_current_time()}] Analyzing circling {video_name}... (video {file_cnt+1}/{len(self.data_paths)})')
97
102
  save_file_path = os.path.join(self.save_dir, f'{video_name}.csv')
98
103
  df = read_df(file_path=file_path, file_type='csv').reset_index(drop=True)
99
104
  _, px_per_mm, fps = read_video_info(video_info_df=self.video_info_df, video_name=video_name)
@@ -115,11 +120,24 @@ class CirclingDetector(ConfigReader):
115
120
  circling_idx = np.argwhere(sliding_circular_range >= self.circular_range_threshold).astype(np.int32).flatten()
116
121
  movement_idx = np.argwhere(movement_sum >= self.movement_threshold).astype(np.int32).flatten()
117
122
  circling_idx = [x for x in movement_idx if x in circling_idx]
123
+ df[f'Probability_{CIRCLING}'] = 0
118
124
  df[CIRCLING] = 0
119
125
  df.loc[circling_idx, CIRCLING] = 1
126
+ df.loc[circling_idx, f'Probability_{CIRCLING}'] = 1
127
+ df = plug_holes_shortest_bout(data_df=df, clf_name=CIRCLING, fps=fps, shortest_bout=self.shortest_bout)
120
128
  bouts = detect_bouts(data_df=df, target_lst=[CIRCLING], fps=fps)
121
- df = plug_holes_shortest_bout(data_df=df, clf_name=CIRCLING, fps=fps, shortest_bout=100)
129
+ if len(bouts) > 0:
130
+ df[CIRCLING] = 0
131
+ circling_idx = list(bouts.apply(lambda x: list(range(int(x["Start_frame"]), int(x["End_frame"]) + 1)), 1))
132
+ circling_idx = [x for xs in circling_idx for x in xs]
133
+ df.loc[circling_idx, CIRCLING] = 1
134
+ df.loc[circling_idx, f'Probability_{CIRCLING}'] = 1
135
+ else:
136
+ df[CIRCLING] = 0
137
+ circling_idx = []
138
+
122
139
  df.to_csv(save_file_path)
140
+ #print(video_name, len(circling_idx), round(len(circling_idx) / fps, 4), df[CIRCLING].sum())
123
141
  agg_results.loc[len(agg_results)] = [video_name, len(circling_idx), round(len(circling_idx) / fps, 4), len(bouts), round((len(circling_idx) / len(df)) * 100, 4), len(df), round(len(df)/fps, 2) ]
124
142
 
125
143
  agg_results.to_csv(agg_results_path)
@@ -127,7 +145,6 @@ class CirclingDetector(ConfigReader):
127
145
 
128
146
  #
129
147
  #
130
- # detector = CirclingDetector(data_dir=r'C:\troubleshooting\mitra\project_folder\csv\outlier_corrected_movement_location',
131
- # config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
148
+ # detector = CirclingDetector(config_path=r"E:\troubleshooting\mitra_emergence\project_folder\project_config.ini")
132
149
  # detector.run()
133
150
 
@@ -8,6 +8,7 @@ from numba import cuda, njit
8
8
 
9
9
  from simba.utils.checks import check_float, check_valid_array
10
10
  from simba.utils.enums import Formats
11
+ from simba.utils.printing import SimbaTimer
11
12
 
12
13
  try:
13
14
  import cupy as cp
@@ -401,20 +402,21 @@ def find_midpoints(x: np.ndarray,
401
402
 
402
403
 
403
404
  @cuda.jit()
404
- def _directionality_to_static_targets_kernel(left_ear, right_ear, nose, target, results):
405
+ def _directionality_to_static_targets_kernel(left_ear, right_ear, nose, target_x, target_y, results):
405
406
  i = cuda.grid(1)
406
- if i > left_ear.shape[0]:
407
+ if i >= left_ear.shape[0]:
407
408
  return
408
409
  else:
409
- LE, RE = left_ear[i], right_ear[i]
410
- N, Tx, Ty = nose[i], target[0], target[1]
411
-
412
- Px = abs(LE[0] - Tx)
413
- Py = abs(LE[1] - Ty)
414
- Qx = abs(RE[0] - Tx)
415
- Qy = abs(RE[1] - Ty)
416
- Nx = abs(N[0] - Tx)
417
- Ny = abs(N[1] - Ty)
410
+ LE = left_ear[i]
411
+ RE = right_ear[i]
412
+ N = nose[i]
413
+
414
+ Px = abs(LE[0] - target_x)
415
+ Py = abs(LE[1] - target_y)
416
+ Qx = abs(RE[0] - target_x)
417
+ Qy = abs(RE[1] - target_y)
418
+ Nx = abs(N[0] - target_x)
419
+ Ny = abs(N[1] - target_y)
418
420
  Ph = math.sqrt(Px * Px + Py * Py)
419
421
  Qh = math.sqrt(Qx * Qx + Qy * Qy)
420
422
  Nh = math.sqrt(Nx * Nx + Ny * Ny)
@@ -438,7 +440,8 @@ def _directionality_to_static_targets_kernel(left_ear, right_ear, nose, target,
438
440
  def directionality_to_static_targets(left_ear: np.ndarray,
439
441
  right_ear: np.ndarray,
440
442
  nose: np.ndarray,
441
- target: np.ndarray) -> np.ndarray:
443
+ target: np.ndarray,
444
+ verbose: bool = False) -> np.ndarray:
442
445
  """
443
446
  GPU helper to calculate if an animal is directing towards a static location (e.g., ROI centroid), given the target location and the left ear, right ear, and nose coordinates of the observer.
444
447
 
@@ -487,32 +490,38 @@ def directionality_to_static_targets(left_ear: np.ndarray,
487
490
  >>> directionality_to_static_targets(left_ear=left_ear, right_ear=right_ear, nose=nose, target=target)
488
491
 
489
492
  """
490
-
493
+ timer = SimbaTimer(start=True)
491
494
  left_ear = np.ascontiguousarray(left_ear).astype(np.int32)
492
495
  right_ear = np.ascontiguousarray(right_ear).astype(np.int32)
493
496
  nose = np.ascontiguousarray(nose).astype(np.int32)
494
497
  target = np.ascontiguousarray(target).astype(np.int32)
495
498
 
499
+ target_x = int(target[0])
500
+ target_y = int(target[1])
501
+
496
502
  left_ear_dev = cuda.to_device(left_ear)
497
503
  right_ear_dev = cuda.to_device(right_ear)
498
504
  nose_dev = cuda.to_device(nose)
499
- target_dev = cuda.to_device(target)
500
505
  results = cuda.device_array((left_ear.shape[0], 4), dtype=np.int32)
501
506
  bpg = (left_ear.shape[0] + (THREADS_PER_BLOCK - 1)) // THREADS_PER_BLOCK
502
- _directionality_to_static_targets_kernel[bpg, THREADS_PER_BLOCK](left_ear_dev, right_ear_dev, nose_dev, target_dev, results)
507
+ _directionality_to_static_targets_kernel[bpg, THREADS_PER_BLOCK](left_ear_dev, right_ear_dev, nose_dev, target_x, target_y, results)
503
508
 
504
509
  results = results.copy_to_host()
510
+ timer.stop_timer()
511
+ if verbose: print(f'Directionality to static target computed in for {left_ear.shape[0]} observations (elapsed time: {timer.elapsed_time_str}s)')
505
512
  return results
506
513
 
507
514
 
508
515
  @cuda.jit()
509
516
  def _directionality_to_nonstatic_targets_kernel(left_ear, right_ear, nose, target, results):
510
517
  i = cuda.grid(1)
511
- if i > left_ear.shape[0]:
518
+ if i >= left_ear.shape[0]:
512
519
  return
513
520
  else:
514
- LE, RE = left_ear[i], right_ear[i]
515
- N, T = nose[i], target[i]
521
+ LE = left_ear[i]
522
+ RE = right_ear[i]
523
+ N = nose[i]
524
+ T = target[i]
516
525
 
517
526
  Px = abs(LE[0] - T[0])
518
527
  Py = abs(LE[1] - T[1])
@@ -543,11 +552,19 @@ def _directionality_to_nonstatic_targets_kernel(left_ear, right_ear, nose, targe
543
552
  def directionality_to_nonstatic_target(left_ear: np.ndarray,
544
553
  right_ear: np.ndarray,
545
554
  nose: np.ndarray,
546
- target: np.ndarray) -> np.ndarray:
555
+ target: np.ndarray,
556
+ verbose: bool = False) -> np.ndarray:
547
557
 
548
558
  """
549
559
  GPU method to calculate if an animal is directing towards a moving point location given the target location and the left ear, right ear, and nose coordinates of the observer.
550
560
 
561
+ .. csv-table::
562
+ :header: EXPECTED RUNTIMES
563
+ :file: ../../../docs/tables/directionality_to_nonstatic_target_cuda.csv
564
+ :widths: 30, 30, 20, 10, 10
565
+ :align: center
566
+ :class: simba-table
567
+ :header-rows: 1
551
568
 
552
569
  .. image:: _static/img/directing_moving_targets.png
553
570
  :width: 400
@@ -573,27 +590,28 @@ def directionality_to_nonstatic_target(left_ear: np.ndarray,
573
590
  >>> right_ear = np.random.randint(0, 500, (100, 2))
574
591
  >>> nose = np.random.randint(0, 500, (100, 2))
575
592
  >>> target = np.random.randint(0, 500, (100, 2))
576
- >>> directionality_to_static_targets(left_ear=left_ear, right_ear=right_ear, nose=nose, target=target)
593
+ >>> directionality_to_nonstatic_target(left_ear=left_ear, right_ear=right_ear, nose=nose, target=target)
577
594
  """
578
595
 
579
- left_ear = np.ascontiguousarray(left_ear).astype(np.int32)
580
- right_ear = np.ascontiguousarray(right_ear).astype(np.int32)
581
- nose = np.ascontiguousarray(nose).astype(np.int32)
582
- target = np.ascontiguousarray(target).astype(np.int32)
596
+ timer = SimbaTimer(start=True)
597
+ left_ear = np.ascontiguousarray(left_ear).astype(np.int64)
598
+ right_ear = np.ascontiguousarray(right_ear).astype(np.int64)
599
+ nose = np.ascontiguousarray(nose).astype(np.int64)
600
+ target = np.ascontiguousarray(target).astype(np.int64)
583
601
 
584
602
  left_ear_dev = cuda.to_device(left_ear)
585
603
  right_ear_dev = cuda.to_device(right_ear)
586
604
  nose_dev = cuda.to_device(nose)
587
605
  target_dev = cuda.to_device(target)
588
- results = cuda.device_array((left_ear.shape[0], 4), dtype=np.int32)
606
+ results = cuda.device_array((left_ear.shape[0], 4), dtype=np.int64)
589
607
  bpg = (left_ear.shape[0] + (THREADS_PER_BLOCK - 1)) // THREADS_PER_BLOCK
590
608
  _directionality_to_nonstatic_targets_kernel[bpg, THREADS_PER_BLOCK](left_ear_dev, right_ear_dev, nose_dev, target_dev, results)
591
609
 
592
610
  results = results.copy_to_host()
611
+ timer.stop_timer()
612
+ if verbose: print(f'Directionality to moving target computed in for {left_ear.shape[0]} observations (elapsed time: {timer.elapsed_time_str}s)')
593
613
  return results
594
614
 
595
615
 
596
616
 
597
617
 
598
-
599
-