simba-uw-tf-dev 4.6.2__py3-none-any.whl → 4.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- simba/assets/.recent_projects.txt +1 -0
- simba/assets/lookups/tooptips.json +6 -1
- simba/data_processors/agg_clf_counter_mp.py +52 -53
- simba/data_processors/blob_location_computer.py +1 -1
- simba/data_processors/circling_detector.py +30 -13
- simba/data_processors/cuda/geometry.py +45 -27
- simba/data_processors/cuda/image.py +1648 -1598
- simba/data_processors/cuda/statistics.py +72 -26
- simba/data_processors/cuda/timeseries.py +1 -1
- simba/data_processors/cue_light_analyzer.py +5 -9
- simba/data_processors/egocentric_aligner.py +25 -7
- simba/data_processors/freezing_detector.py +55 -47
- simba/data_processors/kleinberg_calculator.py +61 -29
- simba/feature_extractors/feature_subsets.py +14 -7
- simba/feature_extractors/mitra_feature_extractor.py +2 -2
- simba/feature_extractors/straub_tail_analyzer.py +4 -6
- simba/labelling/standard_labeller.py +1 -1
- simba/mixins/config_reader.py +5 -2
- simba/mixins/geometry_mixin.py +22 -36
- simba/mixins/image_mixin.py +24 -28
- simba/mixins/plotting_mixin.py +28 -10
- simba/mixins/statistics_mixin.py +48 -11
- simba/mixins/timeseries_features_mixin.py +1 -1
- simba/mixins/train_model_mixin.py +67 -29
- simba/model/inference_batch.py +1 -1
- simba/model/yolo_seg_inference.py +3 -3
- simba/outlier_tools/skip_outlier_correction.py +1 -1
- simba/plotting/ROI_feature_visualizer_mp.py +3 -5
- simba/plotting/clf_validator_mp.py +4 -5
- simba/plotting/cue_light_visualizer.py +6 -7
- simba/plotting/directing_animals_visualizer_mp.py +2 -3
- simba/plotting/distance_plotter_mp.py +378 -378
- simba/plotting/gantt_creator.py +29 -10
- simba/plotting/gantt_creator_mp.py +96 -33
- simba/plotting/geometry_plotter.py +270 -272
- simba/plotting/heat_mapper_clf_mp.py +4 -6
- simba/plotting/heat_mapper_location_mp.py +2 -2
- simba/plotting/light_dark_box_plotter.py +2 -2
- simba/plotting/path_plotter_mp.py +26 -29
- simba/plotting/plot_clf_results_mp.py +455 -454
- simba/plotting/pose_plotter_mp.py +28 -29
- simba/plotting/probability_plot_creator_mp.py +288 -288
- simba/plotting/roi_plotter_mp.py +31 -31
- simba/plotting/single_run_model_validation_video_mp.py +427 -427
- simba/plotting/spontaneous_alternation_plotter.py +2 -3
- simba/plotting/yolo_pose_track_visualizer.py +32 -27
- simba/plotting/yolo_pose_visualizer.py +35 -36
- simba/plotting/yolo_seg_visualizer.py +2 -3
- simba/pose_importers/simba_blob_importer.py +3 -3
- simba/roi_tools/roi_aggregate_stats_mp.py +5 -4
- simba/roi_tools/roi_clf_calculator_mp.py +4 -4
- simba/sandbox/analyze_runtimes.py +30 -0
- simba/sandbox/cuda/egocentric_rotator.py +374 -374
- simba/sandbox/get_cpu_pool.py +5 -0
- simba/sandbox/proboscis_to_tip.py +28 -0
- simba/sandbox/test_directionality.py +47 -0
- simba/sandbox/test_nonstatic_directionality.py +27 -0
- simba/sandbox/test_pycharm_cuda.py +51 -0
- simba/sandbox/test_simba_install.py +41 -0
- simba/sandbox/test_static_directionality.py +26 -0
- simba/sandbox/test_static_directionality_2d.py +26 -0
- simba/sandbox/verify_env.py +42 -0
- simba/third_party_label_appenders/transform/coco_keypoints_to_yolo.py +3 -3
- simba/third_party_label_appenders/transform/coco_keypoints_to_yolo_bbox.py +2 -2
- simba/ui/pop_ups/clf_plot_pop_up.py +2 -2
- simba/ui/pop_ups/fsttc_pop_up.py +27 -25
- simba/ui/pop_ups/gantt_pop_up.py +31 -6
- simba/ui/pop_ups/kleinberg_pop_up.py +39 -40
- simba/ui/pop_ups/video_processing_pop_up.py +37 -29
- simba/ui/tkinter_functions.py +3 -0
- simba/utils/custom_feature_extractor.py +1 -1
- simba/utils/data.py +90 -14
- simba/utils/enums.py +1 -0
- simba/utils/errors.py +441 -440
- simba/utils/lookups.py +1203 -1203
- simba/utils/printing.py +124 -124
- simba/utils/read_write.py +3769 -3721
- simba/utils/yolo.py +10 -1
- simba/video_processors/blob_tracking_executor.py +2 -2
- simba/video_processors/clahe_ui.py +1 -1
- simba/video_processors/egocentric_video_rotator.py +44 -41
- simba/video_processors/multi_cropper.py +1 -1
- simba/video_processors/video_processing.py +5264 -5222
- simba/video_processors/videos_to_frames.py +43 -33
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/METADATA +4 -3
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/RECORD +90 -80
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/LICENSE +0 -0
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/WHEEL +0 -0
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/entry_points.txt +0 -0
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/top_level.txt +0 -0
simba/assets/lookups/tooptips.json
@@ -38,5 +38,10 @@
   "EGOCENTRIC_ANCHOR": "This body-part will be placed in the center of the video",
   "EGOCENTRIC_DIRECTION_ANCHOR": "This body-part will be placed at N degrees relative to the anchor",
   "EGOCENTRIC_DIRECTION": "The anchor body-part will always be placed at these degrees relative to the center anchor",
-  "CORE_COUNT": "Higher core counts speeds up processing but may require more RAM memory"
+  "CORE_COUNT": "Higher core counts speeds up processing but may require more RAM memory",
+  "KLEINBERG_SIGMA": "Higher values (e.g., 2-3) produce fewer but longer bursts; lower values (e.g., 1.1-1.5) detect more frequent, shorter bursts. Must be > 1.01",
+  "KLEINBERG_GAMMA": "Higher values (e.g., 0.5-1.0) reduce total burst count by making downward transitions costly; lower values (e.g., 0.1-0.3) allow more flexible state changes",
+  "KLEINBERG_HIERARCHY": "Hierarchy level to extract bursts from (0=lowest, higher=more selective).\n Level 0 captures all bursts; level 1-2 typically filters noise; level 3+ selects only the most prominent, sustained bursts.\nHigher levels yield fewer but more confident detections",
+  "KLEINBERG_HIERARCHY_SEARCH": "If True, searches for target hierarchy level within detected burst periods,\n falling back to lower levels if target not found. If False, extracts only bursts at the exact specified hierarchy level.\n Recommended when target hierarchy may be sparse.",
+  "KLEINBERG_SAVE_ORIGINALS": "If True, saves the original data in a new sub-directory of \nthe project_folder/csv/machine_results directory"
 }
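The five new KLEINBERG_* tooltips document the sigma, gamma, and hierarchy parameters that SimBA's Kleinberg smoothing exposes (see simba/data_processors/kleinberg_calculator.py and simba/ui/pop_ups/kleinberg_pop_up.py in the file list above). To make the sigma/gamma trade-off concrete, here is a minimal two-state sketch of Kleinberg's (2002) burst cost model. It is not SimBA's implementation; the real calculator appears to use the multi-state model (which is where the tooltip's hierarchy levels come from), and the function name, defaults, and example data below are invented for illustration.

```python
import math
from typing import List


def kleinberg_two_state(gaps: List[float], s: float = 2.0, gamma: float = 0.3) -> List[int]:
    """Label each inter-event gap as baseline (0) or burst (1) with a two-state Kleinberg model."""
    n = len(gaps)
    total = float(sum(gaps))
    # The burst state expects events to arrive s-times faster than the baseline state.
    rates = [(n / total) * (s ** q) for q in (0, 1)]

    def emit(q: int, gap: float) -> float:
        # Negative log of an exponential density with rate rates[q].
        return -math.log(rates[q]) + rates[q] * gap

    def trans(prev: int, cur: int) -> float:
        # Entering the burst state costs gamma * ln(n); leaving it is free.
        return (cur - prev) * gamma * math.log(n) if cur > prev else 0.0

    # Viterbi over the two states.
    cost, back = [0.0, float("inf")], []
    for gap in gaps:
        new_cost, choices = [], []
        for cur in (0, 1):
            prev = min((0, 1), key=lambda p: cost[p] + trans(p, cur))
            new_cost.append(cost[prev] + trans(prev, cur) + emit(cur, gap))
            choices.append(prev)
        cost, back = new_cost, back + [choices]

    # Trace the cheapest state sequence back from the final step.
    state = 0 if cost[0] <= cost[1] else 1
    states = [state]
    for choices in reversed(back[1:]):
        state = choices[state]
        states.append(state)
    return states[::-1]


# Runs of short gaps (events close together) enter the burst state; long gaps pull it back to baseline.
print(kleinberg_two_state([1, 1, 1, 10, 12, 1, 1, 1, 1, 15], s=2.0, gamma=0.3))
```

In Kleinberg's original formulation gamma prices the transition into higher (burst) states, whereas the tooltip describes the effect on downward transitions; either way, a larger gamma means fewer state changes and fewer bursts, while a larger sigma demands a proportionally higher event rate before the burst state pays off.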
simba/data_processors/agg_clf_counter_mp.py
@@ -20,7 +20,7 @@ from simba.utils.checks import (
     check_all_file_names_are_represented_in_video_log,
     check_file_exist_and_readable, check_if_dir_exists, check_int,
     check_valid_boolean, check_valid_dataframe, check_valid_lst)
-from simba.utils.data import detect_bouts
+from simba.utils.data import detect_bouts, terminate_cpu_pool
 from simba.utils.enums import TagNames
 from simba.utils.errors import NoChoosenMeasurementError
 from simba.utils.printing import SimbaTimer, log_event, stdout_success
@@ -210,8 +210,7 @@ class AggregateClfCalculatorMultiprocess(ConfigReader):
             self.bouts_df_lst.append(batch_bouts_df_lst)
             print(f"Data batch core {batch_id+1} / {self.core_cnt} complete...")
         self.bouts_df_lst = [df for sub in self.bouts_df_lst for df in sub]
-        pool
-        pool.terminate()
+        terminate_cpu_pool(pool=pool, force=False)
 
     def save(self) -> None:
         """
@@ -242,56 +241,56 @@ class AggregateClfCalculatorMultiprocess(ConfigReader):
         self.timer.stop_timer()
         stdout_success(msg=f"Data aggregate log saved at {self.save_path}", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
 
-if __name__ == "__main__" and not hasattr(sys, 'ps1'):
- … (49 further removed lines, old 246-294, whose content is not captured in this diff view)
+# if __name__ == "__main__" and not hasattr(sys, 'ps1'):
+# parser = argparse.ArgumentParser(description='Compute aggregate descriptive statistics from classification data.')
+# parser.add_argument('--config_path', type=str, required=True, help='Path to SimBA project config file')
+# parser.add_argument('--classifiers', type=str, nargs='+', required=True, help='List of classifier names to analyze')
+# parser.add_argument('--data_dir', type=str, default=None, help='Directory containing machine results CSV files (default: project machine_results directory)')
+# parser.add_argument('--detailed_bout_data', action='store_true', help='Save detailed bout data for each bout')
+# parser.add_argument('--transpose', action='store_true', help='Create output with one video per row')
+# parser.add_argument('--no_first_occurrence', action='store_true', help='Disable first occurrence calculation')
+# parser.add_argument('--no_event_count', action='store_true', help='Disable event count calculation')
+# parser.add_argument('--no_total_event_duration', action='store_true', help='Disable total event duration calculation')
+# parser.add_argument('--no_mean_event_duration', action='store_true', help='Disable mean event duration calculation')
+# parser.add_argument('--no_median_event_duration', action='store_true', help='Disable median event duration calculation')
+# parser.add_argument('--no_mean_interval_duration', action='store_true', help='Disable mean interval duration calculation')
+# parser.add_argument('--no_median_interval_duration', action='store_true', help='Disable median interval duration calculation')
+# parser.add_argument('--frame_count', action='store_true', help='Include frame count in output')
+# parser.add_argument('--video_length', action='store_true', help='Include video length in output')
+#
+# args = parser.parse_args()
+#
+# clf_calculator = AggregateClfCalculatorMultiprocess(
+#     config_path=args.config_path,
+#     classifiers=args.classifiers,
+#     data_dir=args.data_dir,
+#     detailed_bout_data=args.detailed_bout_data,
+#     transpose=args.transpose,
+#     first_occurrence=not args.no_first_occurrence,
+#     event_count=not args.no_event_count,
+#     total_event_duration=not args.no_total_event_duration,
+#     mean_event_duration=not args.no_mean_event_duration,
+#     median_event_duration=not args.no_median_event_duration,
+#     mean_interval_duration=not args.no_mean_interval_duration,
+#     median_interval_duration=not args.no_median_interval_duration,
+#     frame_count=args.frame_count,
+#     video_length=args.video_length
+# )
+# clf_calculator.run()
+# clf_calculator.save()
+
+if __name__ == "__main__":
+    test = AggregateClfCalculatorMultiprocess(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini",
+                                              classifiers=['attack'],
+                                              transpose=True,
+                                              mean_event_duration = True,
+                                              median_event_duration = True,
+                                              mean_interval_duration = True,
+                                              median_interval_duration = True,
+                                              detailed_bout_data=True,
+                                              core_cnt=12)
+    test.run()
+    test.save()
 
 
 # test = AggregateClfCalculator(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
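In the hunks above, the bare `pool` / `pool.terminate()` lines are replaced by the new terminate_cpu_pool(pool=pool, force=False) helper imported from simba.utils.data. The helper's body is not part of this diff, so the sketch below is only a guess at its intent, built from the standard multiprocessing teardown calls: a force flag choosing between an immediate terminate() and a graceful close()/join().

```python
from multiprocessing.pool import Pool


def terminate_cpu_pool(pool: Pool, force: bool = False) -> None:
    """Hypothetical stand-in for simba.utils.data.terminate_cpu_pool (not shown in this diff)."""
    if force:
        pool.terminate()  # stop workers immediately; queued tasks are discarded
    else:
        pool.close()      # accept no new tasks ...
        pool.join()       # ... and wait for in-flight work to finish before returning
```

Called with force=False after the batch results have been collected, as the hunk does, this would amount to a clean close/join rather than a hard kill.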
simba/data_processors/blob_location_computer.py
@@ -51,7 +51,7 @@ class BlobLocationComputer(object):
     :param Optional[bool] multiprocessing: If True, video background subtraction will be done using multiprocessing. Default is False.
 
     :example:
-    >>> x = BlobLocationComputer(data_path=r"C
+    >>> x = BlobLocationComputer(data_path=r"C:/troubleshooting/RAT_NOR/project_folder/videos/2022-06-20_NOB_DOT_4_downsampled_bg_subtracted.mp4", multiprocessing=True, gpu=True, batch_size=2000, save_dir=r"C:/blob_positions")
     >>> x.run()
     """
     def __init__(self,
simba/data_processors/circling_detector.py
@@ -11,12 +11,13 @@ from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin
 from simba.mixins.timeseries_features_mixin import TimeseriesFeatureMixin
 from simba.utils.checks import (
     check_all_file_names_are_represented_in_video_log, check_if_dir_exists,
-
+    check_str, check_valid_dataframe)
 from simba.utils.data import detect_bouts, plug_holes_shortest_bout
 from simba.utils.enums import Formats
 from simba.utils.printing import stdout_success
 from simba.utils.read_write import (find_files_of_filetypes_in_directory,
-                                    get_fn_ext, read_df,
+                                    get_current_time, get_fn_ext, read_df,
+                                    read_video_info)
 
 CIRCLING = 'CIRCLING'
 
@@ -58,30 +59,34 @@ class CirclingDetector(ConfigReader):
     """
 
     def __init__(self,
-                 data_dir: Union[str, os.PathLike],
                  config_path: Union[str, os.PathLike],
                  nose_name: Optional[str] = 'nose',
+                 data_dir: Optional[Union[str, os.PathLike]] = None,
                  left_ear_name: Optional[str] = 'left_ear',
                  right_ear_name: Optional[str] = 'right_ear',
                  tail_base_name: Optional[str] = 'tail_base',
                  center_name: Optional[str] = 'center',
-                 time_threshold: Optional[int] =
-                 circular_range_threshold: Optional[int] =
+                 time_threshold: Optional[int] = 7,
+                 circular_range_threshold: Optional[int] = 350,
+                 shortest_bout: int = 100,
                  movement_threshold: Optional[int] = 60,
                  save_dir: Optional[Union[str, os.PathLike]] = None):
 
-        check_if_dir_exists(in_dir=data_dir)
         for bp_name in [nose_name, left_ear_name, right_ear_name, tail_base_name]: check_str(name='body part name', value=bp_name, allow_blank=False)
-        self.data_paths = find_files_of_filetypes_in_directory(directory=data_dir, extensions=['.csv'])
         ConfigReader.__init__(self, config_path=config_path, read_video_info=True, create_logger=False)
+        if data_dir is not None:
+            check_if_dir_exists(in_dir=data_dir)
+        else:
+            data_dir = self.outlier_corrected_dir
+        self.data_paths = find_files_of_filetypes_in_directory(directory=data_dir, extensions=['.csv'])
         self.nose_heads = [f'{nose_name}_x'.lower(), f'{nose_name}_y'.lower()]
         self.left_ear_heads = [f'{left_ear_name}_x'.lower(), f'{left_ear_name}_y'.lower()]
         self.right_ear_heads = [f'{right_ear_name}_x'.lower(), f'{right_ear_name}_y'.lower()]
         self.center_heads = [f'{center_name}_x'.lower(), f'{center_name}_y'.lower()]
         self.required_field = self.nose_heads + self.left_ear_heads + self.right_ear_heads
-        self.save_dir = save_dir
+        self.save_dir, self.shortest_bout = save_dir, shortest_bout
         if self.save_dir is None:
-            self.save_dir = os.path.join(self.logs_path, f'circling_data_{self.datetime}')
+            self.save_dir = os.path.join(self.logs_path, f'circling_data_{time_threshold}s_{circular_range_threshold}d_{movement_threshold}mm_{self.datetime}')
             os.makedirs(self.save_dir)
         else:
             check_if_dir_exists(in_dir=self.save_dir)
@@ -93,7 +98,7 @@ class CirclingDetector(ConfigReader):
         check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.data_paths)
         for file_cnt, file_path in enumerate(self.data_paths):
             video_name = get_fn_ext(filepath=file_path)[1]
-            print(f'Analyzing {video_name} ({file_cnt+1}/{len(self.data_paths)})
+            print(f'[{get_current_time()}] Analyzing circling {video_name}... (video {file_cnt+1}/{len(self.data_paths)})')
             save_file_path = os.path.join(self.save_dir, f'{video_name}.csv')
             df = read_df(file_path=file_path, file_type='csv').reset_index(drop=True)
             _, px_per_mm, fps = read_video_info(video_info_df=self.video_info_df, video_name=video_name)
@@ -115,11 +120,24 @@ class CirclingDetector(ConfigReader):
             circling_idx = np.argwhere(sliding_circular_range >= self.circular_range_threshold).astype(np.int32).flatten()
             movement_idx = np.argwhere(movement_sum >= self.movement_threshold).astype(np.int32).flatten()
             circling_idx = [x for x in movement_idx if x in circling_idx]
+            df[f'Probability_{CIRCLING}'] = 0
             df[CIRCLING] = 0
             df.loc[circling_idx, CIRCLING] = 1
+            df.loc[circling_idx, f'Probability_{CIRCLING}'] = 1
+            df = plug_holes_shortest_bout(data_df=df, clf_name=CIRCLING, fps=fps, shortest_bout=self.shortest_bout)
             bouts = detect_bouts(data_df=df, target_lst=[CIRCLING], fps=fps)
-
+            if len(bouts) > 0:
+                df[CIRCLING] = 0
+                circling_idx = list(bouts.apply(lambda x: list(range(int(x["Start_frame"]), int(x["End_frame"]) + 1)), 1))
+                circling_idx = [x for xs in circling_idx for x in xs]
+                df.loc[circling_idx, CIRCLING] = 1
+                df.loc[circling_idx, f'Probability_{CIRCLING}'] = 1
+            else:
+                df[CIRCLING] = 0
+                circling_idx = []
+
             df.to_csv(save_file_path)
+            #print(video_name, len(circling_idx), round(len(circling_idx) / fps, 4), df[CIRCLING].sum())
             agg_results.loc[len(agg_results)] = [video_name, len(circling_idx), round(len(circling_idx) / fps, 4), len(bouts), round((len(circling_idx) / len(df)) * 100, 4), len(df), round(len(df)/fps, 2) ]
 
         agg_results.to_csv(agg_results_path)
@@ -127,7 +145,6 @@ class CirclingDetector(ConfigReader):
 
 #
 #
-# detector = CirclingDetector(
-# config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
+# detector = CirclingDetector(config_path=r"E:\troubleshooting\mitra_emergence\project_folder\project_config.ini")
 # detector.run()
 
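The CirclingDetector changes above add a Probability_CIRCLING column, smooth the raw frame labels with plug_holes_shortest_bout, and then rebuild the per-frame labels from the Start_frame/End_frame ranges returned by detect_bouts. Both helpers live in simba.utils.data and are not shown in this diff; the sketch below is a generic stand-in that only illustrates the two steps the hunk relies on, turning a binary column into bout rows and expanding those rows back into frame indices. The function name, column set, and test data are assumptions.

```python
import numpy as np
import pandas as pd


def bouts_from_binary(labels: pd.Series, fps: float) -> pd.DataFrame:
    """One row per run of consecutive 1s; column names mirror the hunk's Start_frame/End_frame fields."""
    arr = labels.to_numpy().astype(int)
    padded = np.concatenate(([0], arr, [0]))          # pad so edge bouts are detected
    starts = np.flatnonzero(np.diff(padded) == 1)     # rising edges
    ends = np.flatnonzero(np.diff(padded) == -1) - 1  # falling edges
    return pd.DataFrame({"Start_frame": starts,
                         "End_frame": ends,
                         "Bout_time_s": (ends - starts + 1) / fps})


labels = pd.Series([0, 1, 1, 1, 0, 0, 1, 1, 0])
bouts = bouts_from_binary(labels, fps=30)
# Expand surviving bouts back into per-frame indices, as the hunk does with bouts.apply(...).
frame_idx = [f for _, b in bouts.iterrows()
             for f in range(int(b["Start_frame"]), int(b["End_frame"]) + 1)]
print(bouts)
print(frame_idx)  # [1, 2, 3, 6, 7]
```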
simba/data_processors/cuda/geometry.py
@@ -8,6 +8,7 @@ from numba import cuda, njit
 
 from simba.utils.checks import check_float, check_valid_array
 from simba.utils.enums import Formats
+from simba.utils.printing import SimbaTimer
 
 try:
     import cupy as cp
@@ -401,20 +402,21 @@ def find_midpoints(x: np.ndarray,
 
 
 @cuda.jit()
-def _directionality_to_static_targets_kernel(left_ear, right_ear, nose,
+def _directionality_to_static_targets_kernel(left_ear, right_ear, nose, target_x, target_y, results):
     i = cuda.grid(1)
-    if i
+    if i >= left_ear.shape[0]:
        return
     else:
-        LE
-
-
-
-
-
-
-
-
+        LE = left_ear[i]
+        RE = right_ear[i]
+        N = nose[i]
+
+        Px = abs(LE[0] - target_x)
+        Py = abs(LE[1] - target_y)
+        Qx = abs(RE[0] - target_x)
+        Qy = abs(RE[1] - target_y)
+        Nx = abs(N[0] - target_x)
+        Ny = abs(N[1] - target_y)
         Ph = math.sqrt(Px * Px + Py * Py)
         Qh = math.sqrt(Qx * Qx + Qy * Qy)
         Nh = math.sqrt(Nx * Nx + Ny * Ny)
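Beyond passing the static target as two scalars, the kernel rewrite above adds the out-of-range thread guard (if i >= left_ear.shape[0]: return) and unpacks the per-frame ear and nose points before computing the distance triplet. The sketch below reproduces only the numba.cuda launch pattern on a simpler per-frame computation (nose-to-target distance); it is not the directionality test itself, and THREADS_PER_BLOCK is assumed to match the module-level constant geometry.py already uses.

```python
import math

import numpy as np
from numba import cuda

THREADS_PER_BLOCK = 256  # assumption: stands in for the constant defined in geometry.py


@cuda.jit()
def _nose_to_target_kernel(nose, target_x, target_y, results):
    i = cuda.grid(1)
    if i >= nose.shape[0]:  # same out-of-range guard the 4.7.1 kernels add
        return
    dx = nose[i, 0] - target_x
    dy = nose[i, 1] - target_y
    results[i] = math.sqrt(dx * dx + dy * dy)


def nose_to_target(nose: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Per-frame Euclidean distance from the nose to a static target (illustration only)."""
    nose = np.ascontiguousarray(nose).astype(np.int32)
    target_x, target_y = int(target[0]), int(target[1])  # scalars, as in the fixed kernel signature
    nose_dev = cuda.to_device(nose)
    results = cuda.device_array(nose.shape[0], dtype=np.float32)
    bpg = (nose.shape[0] + (THREADS_PER_BLOCK - 1)) // THREADS_PER_BLOCK
    _nose_to_target_kernel[bpg, THREADS_PER_BLOCK](nose_dev, target_x, target_y, results)
    return results.copy_to_host()


# Requires a CUDA-capable GPU.
nose = np.random.randint(0, 500, (100, 2))
print(nose_to_target(nose, target=np.array([250, 250]))[:5])
```

The same guard and blocks-per-grid arithmetic reappear in the non-static kernel and the host wrappers in the hunks that follow.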
@@ -438,7 +440,8 @@ def _directionality_to_static_targets_kernel(left_ear, right_ear, nose, target,
 def directionality_to_static_targets(left_ear: np.ndarray,
                                      right_ear: np.ndarray,
                                      nose: np.ndarray,
-                                     target: np.ndarray
+                                     target: np.ndarray,
+                                     verbose: bool = False) -> np.ndarray:
     """
     GPU helper to calculate if an animal is directing towards a static location (e.g., ROI centroid), given the target location and the left ear, right ear, and nose coordinates of the observer.
 
@@ -487,32 +490,38 @@ def directionality_to_static_targets(left_ear: np.ndarray,
     >>> directionality_to_static_targets(left_ear=left_ear, right_ear=right_ear, nose=nose, target=target)
 
     """
-
+    timer = SimbaTimer(start=True)
     left_ear = np.ascontiguousarray(left_ear).astype(np.int32)
     right_ear = np.ascontiguousarray(right_ear).astype(np.int32)
     nose = np.ascontiguousarray(nose).astype(np.int32)
     target = np.ascontiguousarray(target).astype(np.int32)
 
+    target_x = int(target[0])
+    target_y = int(target[1])
+
     left_ear_dev = cuda.to_device(left_ear)
     right_ear_dev = cuda.to_device(right_ear)
     nose_dev = cuda.to_device(nose)
-    target_dev = cuda.to_device(target)
     results = cuda.device_array((left_ear.shape[0], 4), dtype=np.int32)
     bpg = (left_ear.shape[0] + (THREADS_PER_BLOCK - 1)) // THREADS_PER_BLOCK
-    _directionality_to_static_targets_kernel[bpg, THREADS_PER_BLOCK](left_ear_dev, right_ear_dev, nose_dev,
+    _directionality_to_static_targets_kernel[bpg, THREADS_PER_BLOCK](left_ear_dev, right_ear_dev, nose_dev, target_x, target_y, results)
 
     results = results.copy_to_host()
+    timer.stop_timer()
+    if verbose: print(f'Directionality to static target computed in for {left_ear.shape[0]} observations (elapsed time: {timer.elapsed_time_str}s)')
     return results
 
 
 @cuda.jit()
 def _directionality_to_nonstatic_targets_kernel(left_ear, right_ear, nose, target, results):
     i = cuda.grid(1)
-    if i
+    if i >= left_ear.shape[0]:
         return
     else:
-        LE
-
+        LE = left_ear[i]
+        RE = right_ear[i]
+        N = nose[i]
+        T = target[i]
 
         Px = abs(LE[0] - T[0])
         Py = abs(LE[1] - T[1])
@@ -543,11 +552,19 @@ def _directionality_to_nonstatic_targets_kernel(left_ear, right_ear, nose, targe
 def directionality_to_nonstatic_target(left_ear: np.ndarray,
                                        right_ear: np.ndarray,
                                        nose: np.ndarray,
-                                       target: np.ndarray
+                                       target: np.ndarray,
+                                       verbose: bool = False) -> np.ndarray:
 
     """
     GPU method to calculate if an animal is directing towards a moving point location given the target location and the left ear, right ear, and nose coordinates of the observer.
 
+    .. csv-table::
+       :header: EXPECTED RUNTIMES
+       :file: ../../../docs/tables/directionality_to_nonstatic_target_cuda.csv
+       :widths: 30, 30, 20, 10, 10
+       :align: center
+       :class: simba-table
+       :header-rows: 1
 
     .. image:: _static/img/directing_moving_targets.png
        :width: 400
@@ -573,27 +590,28 @@ def directionality_to_nonstatic_target(left_ear: np.ndarray,
     >>> right_ear = np.random.randint(0, 500, (100, 2))
     >>> nose = np.random.randint(0, 500, (100, 2))
     >>> target = np.random.randint(0, 500, (100, 2))
-    >>>
+    >>> directionality_to_nonstatic_target(left_ear=left_ear, right_ear=right_ear, nose=nose, target=target)
     """
 
-
-
-
-
+    timer = SimbaTimer(start=True)
+    left_ear = np.ascontiguousarray(left_ear).astype(np.int64)
+    right_ear = np.ascontiguousarray(right_ear).astype(np.int64)
+    nose = np.ascontiguousarray(nose).astype(np.int64)
+    target = np.ascontiguousarray(target).astype(np.int64)
 
     left_ear_dev = cuda.to_device(left_ear)
     right_ear_dev = cuda.to_device(right_ear)
     nose_dev = cuda.to_device(nose)
     target_dev = cuda.to_device(target)
-    results = cuda.device_array((left_ear.shape[0], 4), dtype=np.
+    results = cuda.device_array((left_ear.shape[0], 4), dtype=np.int64)
     bpg = (left_ear.shape[0] + (THREADS_PER_BLOCK - 1)) // THREADS_PER_BLOCK
     _directionality_to_nonstatic_targets_kernel[bpg, THREADS_PER_BLOCK](left_ear_dev, right_ear_dev, nose_dev, target_dev, results)
 
     results = results.copy_to_host()
+    timer.stop_timer()
+    if verbose: print(f'Directionality to moving target computed in for {left_ear.shape[0]} observations (elapsed time: {timer.elapsed_time_str}s)')
     return results
 
 
 
 
-
-