simba-uw-tf-dev 4.6.2__py3-none-any.whl → 4.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of simba-uw-tf-dev might be problematic. Click here for more details.
- simba/data_processors/agg_clf_counter_mp.py +52 -53
- simba/data_processors/cuda/image.py +3 -1
- simba/data_processors/cue_light_analyzer.py +5 -9
- simba/mixins/geometry_mixin.py +14 -28
- simba/mixins/image_mixin.py +10 -14
- simba/mixins/train_model_mixin.py +2 -2
- simba/plotting/ROI_feature_visualizer_mp.py +3 -5
- simba/plotting/clf_validator_mp.py +4 -5
- simba/plotting/cue_light_visualizer.py +6 -7
- simba/plotting/directing_animals_visualizer_mp.py +2 -3
- simba/plotting/distance_plotter_mp.py +378 -378
- simba/plotting/gantt_creator_mp.py +59 -31
- simba/plotting/geometry_plotter.py +270 -272
- simba/plotting/heat_mapper_clf_mp.py +2 -4
- simba/plotting/heat_mapper_location_mp.py +2 -2
- simba/plotting/light_dark_box_plotter.py +2 -2
- simba/plotting/path_plotter_mp.py +26 -29
- simba/plotting/plot_clf_results_mp.py +455 -454
- simba/plotting/pose_plotter_mp.py +27 -32
- simba/plotting/probability_plot_creator_mp.py +288 -288
- simba/plotting/roi_plotter_mp.py +29 -30
- simba/plotting/single_run_model_validation_video_mp.py +427 -427
- simba/plotting/spontaneous_alternation_plotter.py +2 -3
- simba/plotting/yolo_pose_track_visualizer.py +31 -27
- simba/plotting/yolo_pose_visualizer.py +32 -34
- simba/plotting/yolo_seg_visualizer.py +2 -3
- simba/roi_tools/roi_aggregate_stats_mp.py +4 -3
- simba/roi_tools/roi_clf_calculator_mp.py +3 -3
- simba/sandbox/get_cpu_pool.py +5 -0
- simba/utils/data.py +89 -12
- simba/utils/enums.py +1 -0
- simba/utils/printing.py +9 -8
- simba/utils/read_write.py +3726 -3721
- simba/video_processors/egocentric_video_rotator.py +2 -4
- simba/video_processors/video_processing.py +19 -8
- simba/video_processors/videos_to_frames.py +1 -1
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.3.dist-info}/METADATA +1 -1
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.3.dist-info}/RECORD +42 -41
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.3.dist-info}/LICENSE +0 -0
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.3.dist-info}/WHEEL +0 -0
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.3.dist-info}/entry_points.txt +0 -0
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.3.dist-info}/top_level.txt +0 -0
|
@@ -14,7 +14,7 @@ from simba.data_processors.spontaneous_alternation_calculator import \
|
|
|
14
14
|
from simba.mixins.config_reader import ConfigReader
|
|
15
15
|
from simba.utils.checks import (check_file_exist_and_readable, check_float,
|
|
16
16
|
check_int, check_str, check_valid_lst)
|
|
17
|
-
from simba.utils.data import detect_bouts
|
|
17
|
+
from simba.utils.data import detect_bouts, terminate_cpu_pool
|
|
18
18
|
from simba.utils.enums import Formats, Paths, TextOptions
|
|
19
19
|
from simba.utils.errors import AnimalNumberError, InvalidInputError
|
|
20
20
|
from simba.utils.printing import stdout_success
|
|
@@ -296,8 +296,7 @@ class SpontaneousAlternationsPlotter(ConfigReader):
|
|
|
296
296
|
pool.imap(constants, frm_index, chunksize=self.multiprocess_chunksize)
|
|
297
297
|
):
|
|
298
298
|
print(f"Section {cnt} complete...")
|
|
299
|
-
pool
|
|
300
|
-
pool.join()
|
|
299
|
+
terminate_cpu_pool(pool=pool, force=False)
|
|
301
300
|
print(f"Joining {sa_computer.video_name} multiprocessed video...")
|
|
302
301
|
concatenate_videos_in_folder(in_folder=self.temp_folder, save_path=save_path)
|
|
303
302
|
self.timer.stop_timer()
|
|
@@ -13,9 +13,10 @@ from simba.mixins.plotting_mixin import PlottingMixin
|
|
|
13
13
|
from simba.utils.checks import (check_file_exist_and_readable, check_float,
|
|
14
14
|
check_if_dir_exists, check_int,
|
|
15
15
|
check_valid_boolean, check_valid_dataframe)
|
|
16
|
+
from simba.utils.data import terminate_cpu_pool, get_cpu_pool
|
|
16
17
|
from simba.utils.enums import Defaults, Formats
|
|
17
18
|
from simba.utils.errors import InvalidFilepathError, NoFilesFoundError
|
|
18
|
-
from simba.utils.lookups import get_random_color_palette, intermittent_palette
|
|
19
|
+
from simba.utils.lookups import get_random_color_palette, intermittent_palette, get_current_time
|
|
19
20
|
from simba.utils.printing import SimbaTimer, stdout_success
|
|
20
21
|
from simba.utils.read_write import (concatenate_videos_in_folder,
|
|
21
22
|
create_directory,
|
|
@@ -53,13 +54,12 @@ def _yolo_keypoint_track_visualizer(frm_ids: np.ndarray,
|
|
|
53
54
|
video_save_path = os.path.join(save_dir, f'{batch_id}.mp4')
|
|
54
55
|
video_writer = cv2.VideoWriter(video_save_path, fourcc, video_meta_data["fps"], (video_meta_data["width"], video_meta_data["height"]))
|
|
55
56
|
while current_frm <= end_frm:
|
|
56
|
-
print(f'Processing frame {current_frm}/{video_meta_data["frame_count"]} (batch: {batch_id})...')
|
|
57
|
+
print(f'[{get_current_time()}] Processing frame {current_frm}/{video_meta_data["frame_count"]} (batch: {batch_id}, video name: {video_meta_data["video_name"]})...')
|
|
57
58
|
img = read_frm_of_video(video_path=video_path, frame_index=current_frm, raise_error=False)
|
|
58
59
|
if img is not None:
|
|
59
60
|
frm_data = data.loc[data[FRAME] == current_frm]
|
|
60
61
|
frm_data = frm_data[frm_data[CONFIDENCE] > threshold]
|
|
61
62
|
for cnt, (row, row_data) in enumerate(frm_data.iterrows()):
|
|
62
|
-
|
|
63
63
|
clrs = np.array(palettes[int(row_data[TRACK])]).astype(np.int32)
|
|
64
64
|
bbox_cords = row_data[BOX_CORD_FIELDS].values.astype(np.int32).reshape(-1, 2)
|
|
65
65
|
kp_coords = row_data.drop(EXPECTED_COLS).values.astype(np.int32).reshape(-1, 3)[:, :-1]
|
|
@@ -159,6 +159,8 @@ class YOLOPoseTrackVisualizer():
|
|
|
159
159
|
self.threshold, self.circle_size, self.thickness, self.show_bbox, self.overwrite = threshold, circle_size, thickness, bbox, overwrite
|
|
160
160
|
|
|
161
161
|
def run(self):
|
|
162
|
+
self.pool = get_cpu_pool(core_cnt=self.core_cnt, maxtasksperchild=Defaults.MAXIMUM_MAX_TASK_PER_CHILD.value, source=self.__class__.__name__)
|
|
163
|
+
self.timer = SimbaTimer(start=True)
|
|
162
164
|
for video_cnt, (video_name, data_path) in enumerate(self.data_paths.items()):
|
|
163
165
|
print(f'Visualizing YOLO pose tracks in video {video_name} ({video_cnt+1}/{len(self.data_paths.keys())}) ...')
|
|
164
166
|
video_timer = SimbaTimer(start=True)
|
|
@@ -189,23 +191,25 @@ class YOLOPoseTrackVisualizer():
|
|
|
189
191
|
|
|
190
192
|
frm_batches = np.array_split(np.array(list(range(0, df_frm_cnt))), self.core_cnt)
|
|
191
193
|
frm_batches = [(i, j) for i, j in enumerate(frm_batches)]
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
print(f'Video batch {result+1}/{self.core_cnt} complete...')
|
|
204
|
-
pool.terminate()
|
|
205
|
-
pool.join()
|
|
194
|
+
constants = functools.partial(_yolo_keypoint_track_visualizer,
|
|
195
|
+
data=self.data_df,
|
|
196
|
+
threshold=self.threshold,
|
|
197
|
+
video_path=self.video_paths[video_name],
|
|
198
|
+
save_dir=video_temp_dir,
|
|
199
|
+
circle_size=video_circle_size,
|
|
200
|
+
thickness=video_thickness,
|
|
201
|
+
palettes=video_palettes,
|
|
202
|
+
show_bbox=self.show_bbox)
|
|
203
|
+
for cnt, result in enumerate(self.pool.imap(constants, frm_batches, chunksize=1)):
|
|
204
|
+
print(f'[{get_current_time()}] Video batch {result+1}/{self.core_cnt} complete...')
|
|
206
205
|
video_timer.stop_timer()
|
|
207
206
|
concatenate_videos_in_folder(in_folder=video_temp_dir, save_path=save_path, gpu=True)
|
|
208
207
|
stdout_success(msg=f'YOLO track pose video saved at {save_path}', source=self.__class__.__name__, elapsed_time=video_timer.elapsed_time_str)
|
|
208
|
+
terminate_cpu_pool(pool=self.pool, force=False, source=self.__class__.__name__)
|
|
209
|
+
self.timer.stop_timer()
|
|
210
|
+
stdout_success(msg=f'YOLO track pose video data for {len(self.data_paths.keys())} videos saved in {self.save_dir}', source=self.__class__.__name__, elapsed_time=self.timer.elapsed_time_str)
|
|
211
|
+
|
|
212
|
+
|
|
209
213
|
#
|
|
210
214
|
# if __name__ == "__main__" and not hasattr(sys, 'ps1'):
|
|
211
215
|
# parser = argparse.ArgumentParser(description="Visualize YOLO pose tracking CSV outputs on their source videos.")
|
|
@@ -247,13 +251,13 @@ class YOLOPoseTrackVisualizer():
|
|
|
247
251
|
# #kp_vis.run()
|
|
248
252
|
|
|
249
253
|
|
|
250
|
-
if __name__ == "__main__":
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
254
|
+
# if __name__ == "__main__":
|
|
255
|
+
# VIDEO_PATH = r"E:\netholabs_videos\primeintellect_100_videos"
|
|
256
|
+
# DATA_PATH = r"E:\netholabs_videos\primeintellect_100_largest"
|
|
257
|
+
# SAVE_DIR = r"E:\netholabs_videos\primeintellect_100_videos\out"
|
|
258
|
+
# kp_vis = YOLOPoseTrackVisualizer(data_path=DATA_PATH,
|
|
259
|
+
# video_path=VIDEO_PATH,
|
|
260
|
+
# save_dir=SAVE_DIR,
|
|
261
|
+
# core_cnt=8,
|
|
262
|
+
# bbox=True)
|
|
263
|
+
# kp_vis.run()
|
|
@@ -14,7 +14,7 @@ from simba.utils.checks import (check_file_exist_and_readable, check_float,
|
|
|
14
14
|
check_if_dir_exists, check_int,
|
|
15
15
|
check_valid_boolean, check_valid_dataframe,
|
|
16
16
|
check_valid_lst, check_valid_tuple)
|
|
17
|
-
from simba.utils.data import create_color_palette
|
|
17
|
+
from simba.utils.data import create_color_palette, terminate_cpu_pool, get_cpu_pool
|
|
18
18
|
from simba.utils.enums import Defaults, Options
|
|
19
19
|
from simba.utils.errors import (CountError, DataHeaderError, FrameRangeError,
|
|
20
20
|
InvalidInputError, NoDataError)
|
|
@@ -24,8 +24,7 @@ from simba.utils.read_write import (concatenate_videos_in_folder,
|
|
|
24
24
|
find_files_of_filetypes_in_directory,
|
|
25
25
|
get_fn_ext, get_video_meta_data,
|
|
26
26
|
read_frm_of_video, recursive_file_search,
|
|
27
|
-
remove_a_folder)
|
|
28
|
-
from simba.utils.warnings import InvalidValueWarning
|
|
27
|
+
remove_a_folder, get_current_time)
|
|
29
28
|
|
|
30
29
|
FRAME = 'FRAME'
|
|
31
30
|
CLASS_ID = 'CLASS_ID'
|
|
@@ -58,7 +57,7 @@ def _yolo_keypoint_visualizer(frm_ids: np.ndarray,
|
|
|
58
57
|
if TRACK in data.columns:
|
|
59
58
|
data = data.drop([TRACK], axis=1)
|
|
60
59
|
while current_frm <= end_frm:
|
|
61
|
-
print(f'Processing frame {current_frm}/{video_meta_data["frame_count"]} (batch: {batch_id}, video: {video_meta_data["video_name"]})...')
|
|
60
|
+
print(f'[{get_current_time()}] Processing frame {current_frm}/{video_meta_data["frame_count"]} (batch: {batch_id}, video: {video_meta_data["video_name"]})...')
|
|
62
61
|
img = read_frm_of_video(video_path=video_path, frame_index=current_frm)
|
|
63
62
|
frm_data = data.loc[data[FRAME] == current_frm]
|
|
64
63
|
frm_data = frm_data[frm_data[CONFIDENCE] > threshold]
|
|
@@ -206,6 +205,7 @@ class YOLOPoseVisualizer():
|
|
|
206
205
|
self.timer = SimbaTimer(start=True)
|
|
207
206
|
|
|
208
207
|
def run(self):
|
|
208
|
+
self.pool = get_cpu_pool(core_cnt=self.core_cnt, maxtasksperchild=Defaults.MAXIMUM_MAX_TASK_PER_CHILD.value, verbose=True, source=self.__class__.__name__)
|
|
209
209
|
for video_cnt, (video_name, data_path) in enumerate(self.data_paths.items()):
|
|
210
210
|
video_timer = SimbaTimer(start=True)
|
|
211
211
|
self.video_temp_dir = os.path.join(self.save_dir, video_name, "temp")
|
|
@@ -248,26 +248,24 @@ class YOLOPoseVisualizer():
|
|
|
248
248
|
thickness = deepcopy(self.thickness)
|
|
249
249
|
frm_batches = np.array_split(np.array(list(range(0, self.df_frm_cnt))), self.core_cnt)
|
|
250
250
|
frm_batches = [(i, j) for i, j in enumerate(frm_batches)]
|
|
251
|
-
if self.verbose: print(f'Visualizing video {self.video_meta_data["video_name"]} (frame count: {self.video_meta_data["frame_count"]})...')
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
print(f'Video batch {result+1}/{self.core_cnt} complete...')
|
|
265
|
-
pool.terminate()
|
|
266
|
-
pool.join()
|
|
251
|
+
if self.verbose: print(f'[{get_current_time()}] Visualizing video {self.video_meta_data["video_name"]} (frame count: {self.video_meta_data["frame_count"]})...')
|
|
252
|
+
constants = functools.partial(_yolo_keypoint_visualizer,
|
|
253
|
+
data=self.data_df,
|
|
254
|
+
threshold=self.threshold,
|
|
255
|
+
video_path=self.video_paths[video_name],
|
|
256
|
+
save_dir=self.video_temp_dir,
|
|
257
|
+
circle_size=circle_size,
|
|
258
|
+
thickness=thickness,
|
|
259
|
+
palettes=self.clrs,
|
|
260
|
+
bbox=self.bbox,
|
|
261
|
+
skeleton=self.skeleton)
|
|
262
|
+
for cnt, result in enumerate(self.pool.imap(constants, frm_batches, chunksize=1)):
|
|
263
|
+
print(f'[{get_current_time()}] Video batch {result+1}/{self.core_cnt} complete...')
|
|
267
264
|
video_timer.stop_timer()
|
|
268
265
|
concatenate_videos_in_folder(in_folder=self.video_temp_dir, save_path=self.save_path, gpu=True)
|
|
269
266
|
stdout_success(msg=f'YOLO pose video saved at {self.save_path} (Video {video_cnt+1}/{len(list(self.data_paths.keys()))})', source=self.__class__.__name__, elapsed_time=video_timer.elapsed_time_str)
|
|
270
267
|
|
|
268
|
+
terminate_cpu_pool(pool=self.pool, force=False, source=self.__class__.__name__)
|
|
271
269
|
self.timer.stop_timer()
|
|
272
270
|
stdout_success(msg=f'{len(list(self.data_paths.keys()))} YOLO pose video saved in directory {self.save_dir}', source=self.__class__.__name__, elapsed_time=self.timer.elapsed_time_str)
|
|
273
271
|
|
|
@@ -413,18 +411,18 @@ class YOLOPoseVisualizer():
|
|
|
413
411
|
# kp_vis.run()
|
|
414
412
|
|
|
415
413
|
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
414
|
+
if __name__ == "__main__":
|
|
415
|
+
video_path = r"E:\netholabs_videos\primeintellect_100_videos\cage_1_date_2025_08_28_hour_20_minute_21.avi"
|
|
416
|
+
data_path = r"E:\netholabs_videos\primeintellect_100_largest\cage_1_date_2025_08_28_hour_20_minute_21.csv"
|
|
417
|
+
save_dir = r'E:\netholabs_videos\test_order'
|
|
418
|
+
kp_vis = YOLOPoseVisualizer(data_path=data_path,
|
|
419
|
+
video_path=video_path,
|
|
420
|
+
save_dir=save_dir,
|
|
421
|
+
core_cnt=14,
|
|
422
|
+
palettes=('tab20',),
|
|
423
|
+
recursive=True,
|
|
424
|
+
sample_n=None)
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
kp_vis.run()
|
|
430
428
|
|
|
@@ -12,7 +12,7 @@ from simba.utils.checks import (check_file_exist_and_readable, check_float,
|
|
|
12
12
|
check_int, check_valid_boolean,
|
|
13
13
|
check_valid_dataframe, check_valid_lst,
|
|
14
14
|
check_valid_tuple)
|
|
15
|
-
from simba.utils.data import create_color_palette
|
|
15
|
+
from simba.utils.data import create_color_palette, terminate_cpu_pool
|
|
16
16
|
from simba.utils.enums import Defaults, Options
|
|
17
17
|
from simba.utils.errors import CountError, DataHeaderError, FrameRangeError
|
|
18
18
|
from simba.utils.printing import SimbaTimer, stdout_success
|
|
@@ -140,8 +140,7 @@ class YOLOSegmentationVisualizer():
|
|
|
140
140
|
shape_opacity=self.shape_opacity)
|
|
141
141
|
for cnt, result in enumerate(pool.imap(constants, frm_batches, chunksize=1)):
|
|
142
142
|
print(f'Video batch {result+1}/{self.core_cnt} complete...')
|
|
143
|
-
pool
|
|
144
|
-
pool.join()
|
|
143
|
+
terminate_cpu_pool(pool=pool, force=False)
|
|
145
144
|
video_timer.stop_timer()
|
|
146
145
|
concatenate_videos_in_folder(in_folder=self.video_temp_dir, save_path=self.save_path, gpu=True)
|
|
147
146
|
stdout_success(msg=f'YOLO pose video saved at {self.save_path}', source=self.__class__.__name__, elapsed_time=video_timer.elapsed_time_str)
|
|
@@ -17,7 +17,8 @@ from simba.utils.checks import (
|
|
|
17
17
|
check_all_file_names_are_represented_in_video_log,
|
|
18
18
|
check_file_exist_and_readable, check_float, check_if_dir_exists, check_int,
|
|
19
19
|
check_that_column_exist, check_valid_boolean, check_valid_lst)
|
|
20
|
-
from simba.utils.data import detect_bouts, slice_roi_dict_for_video
|
|
20
|
+
from simba.utils.data import (detect_bouts, slice_roi_dict_for_video,
|
|
21
|
+
terminate_cpu_pool)
|
|
21
22
|
from simba.utils.enums import ROI_SETTINGS, Formats, Keys
|
|
22
23
|
from simba.utils.errors import CountError, ROICoordinatesNotFoundError
|
|
23
24
|
from simba.utils.printing import SimbaTimer, stdout_success
|
|
@@ -297,8 +298,8 @@ class ROIAggregateStatisticsAnalyzerMultiprocess(ConfigReader, FeatureExtraction
|
|
|
297
298
|
self.results.append(result); self.detailed_dfs.append(detailed_dfs)
|
|
298
299
|
print(f"Data batch core {batch_id} / {self.core_cnt} complete...")
|
|
299
300
|
self.results = pd.concat(self.results, axis=0).reset_index(drop=True)
|
|
300
|
-
pool
|
|
301
|
-
|
|
301
|
+
terminate_cpu_pool(pool=pool, force=False)
|
|
302
|
+
|
|
302
303
|
|
|
303
304
|
def save(self):
|
|
304
305
|
self.__clean_results()
|
|
@@ -16,7 +16,8 @@ from simba.utils.checks import (
|
|
|
16
16
|
check_all_file_names_are_represented_in_video_log,
|
|
17
17
|
check_file_exist_and_readable, check_if_dir_exists, check_int,
|
|
18
18
|
check_valid_boolean, check_valid_dataframe, check_valid_lst)
|
|
19
|
-
from simba.utils.data import detect_bouts, slice_roi_dict_for_video
|
|
19
|
+
from simba.utils.data import (detect_bouts, slice_roi_dict_for_video,
|
|
20
|
+
terminate_cpu_pool)
|
|
20
21
|
from simba.utils.enums import ROI_SETTINGS, Keys
|
|
21
22
|
from simba.utils.errors import InvalidInputError, NoROIDataError
|
|
22
23
|
from simba.utils.lookups import get_current_time
|
|
@@ -237,8 +238,7 @@ class ROIClfCalculatorMultiprocess(ConfigReader):
|
|
|
237
238
|
self.bouts_results.append(batch_bout_results)
|
|
238
239
|
print(f"Data batch core {batch_id + 1} / {self.core_cnt} complete...")
|
|
239
240
|
self.bouts_results = pd.concat(self.bouts_results, axis=0).reset_index(drop=True) if len(self.bouts_results) > 0 else None
|
|
240
|
-
pool
|
|
241
|
-
pool.terminate()
|
|
241
|
+
terminate_cpu_pool(pool=pool, force=False)
|
|
242
242
|
|
|
243
243
|
def save(self):
|
|
244
244
|
self.timer.stop_timer()
|
simba/utils/data.py
CHANGED
|
@@ -5,6 +5,7 @@ import configparser
|
|
|
5
5
|
import gc
|
|
6
6
|
import io
|
|
7
7
|
import os
|
|
8
|
+
import platform
|
|
8
9
|
import subprocess
|
|
9
10
|
from copy import deepcopy
|
|
10
11
|
from datetime import datetime
|
|
@@ -38,15 +39,19 @@ from simba.utils.checks import (check_file_exist_and_readable, check_float,
|
|
|
38
39
|
check_if_valid_rgb_tuple, check_instance,
|
|
39
40
|
check_int, check_str, check_that_column_exist,
|
|
40
41
|
check_that_hhmmss_start_is_before_end,
|
|
41
|
-
check_valid_array,
|
|
42
|
-
check_valid_dataframe,
|
|
43
|
-
|
|
42
|
+
check_valid_array, check_valid_boolean,
|
|
43
|
+
check_valid_cpu_pool, check_valid_dataframe,
|
|
44
|
+
check_valid_lst)
|
|
45
|
+
from simba.utils.enums import (OS, ConfigKey, Defaults, Dtypes, Formats, Keys,
|
|
46
|
+
Options)
|
|
44
47
|
from simba.utils.errors import (BodypartColumnNotFoundError, CountError,
|
|
45
48
|
InvalidFileTypeError, InvalidInputError,
|
|
46
49
|
NoFilesFoundError, NoROIDataError,
|
|
47
50
|
SimBAModuleNotFoundError)
|
|
51
|
+
from simba.utils.lookups import get_current_time
|
|
48
52
|
from simba.utils.printing import stdout_success, stdout_warning
|
|
49
|
-
from simba.utils.read_write import (
|
|
53
|
+
from simba.utils.read_write import (find_core_cnt, find_video_of_file,
|
|
54
|
+
get_current_time, get_fn_ext,
|
|
50
55
|
get_video_meta_data, read_config_entry,
|
|
51
56
|
read_config_file, read_df,
|
|
52
57
|
read_project_path_and_file_type,
|
|
@@ -1813,33 +1818,105 @@ def fft_lowpass_filter(data: np.ndarray, cut_off: float = 0.1) -> np.ndarray:
|
|
|
1813
1818
|
return results.astype(data.dtype)
|
|
1814
1819
|
|
|
1815
1820
|
|
|
1816
|
-
def terminate_cpu_pool(pool: multiprocessing.pool.Pool,
                       force: bool = False,
                       verbose: bool = True,
                       source: Optional[str] = None) -> None:
    """
    Safely terminates a multiprocessing.Pool instance with optional graceful shutdown.

    .. note::
       If pool is None or fails ``check_valid_cpu_pool``, the function returns without action. ``ValueError``,
       ``AssertionError`` and ``AttributeError`` raised while shutting the pool down are caught and ignored.
       ``pool.terminate()`` is still attempted even if the graceful ``close()``/``join()`` step fails, so worker
       processes are not left running.

    :param multiprocessing.pool.Pool pool: The multiprocessing pool to terminate. If None, function returns without action.
    :param bool force: If True, skips graceful shutdown (close/join) and immediately terminates. Default: False.
    :param bool verbose: If True, prints a termination message with timestamp once the pool has been terminated. Default: True.
    :param Optional[str] source: Optional identifier string for logging purposes (e.g., 'VideoProcessor'). Default: None.

    :example:
    >>> import multiprocessing
    >>> pool = multiprocessing.Pool(4)
    >>> terminate_cpu_pool(pool=pool, force=False, verbose=True, source='FeatureExtractor')
    """
    if pool is None:
        return
    if not check_valid_cpu_pool(value=pool, source=terminate_cpu_pool.__name__, raise_error=False):
        return
    # ``_processes`` is a private Pool attribute; it is only used for the log message, so read it defensively.
    core_cnt = getattr(pool, '_processes', None)
    if not force:
        try:
            # Graceful path: stop accepting new work and wait for in-flight tasks to finish.
            pool.close()
            pool.join()
        except (ValueError, AssertionError, AttributeError):
            # Fall through to terminate() so worker processes are still reaped.
            pass
    terminated = False
    try:
        pool.terminate()
        terminated = True
    except (ValueError, AssertionError, AttributeError):
        pass
    if terminated and verbose:
        core_str = f'{core_cnt} core ' if core_cnt is not None else ''
        source_str = f' {source}' if source else ''
        print(f'[{get_current_time()}] {core_str}SimBA CPU pool{source_str} terminated.')
    gc.collect()
|
|
1841
1855
|
|
|
1842
1856
|
|
|
1857
|
+
|
|
1858
|
+
def get_cpu_pool(core_cnt: int = -1,
                 maxtasksperchild: int = Defaults.MAXIMUM_MAX_TASK_PER_CHILD.value,
                 context: Optional[Literal['fork', 'spawn', 'forkserver']] = None,
                 verbose: bool = True,
                 source: Optional[str] = None) -> multiprocessing.Pool:
    """
    Creates and returns a multiprocessing.Pool instance with platform-appropriate defaults and validation.

    .. note::
       When ``context`` is None, the already-set global start method is used if there is one. Otherwise the
       platform default is used: 'spawn' on Windows and macOS, 'fork' elsewhere. If the requested start method
       is unavailable on this platform, the platform default is tried. If that is also unavailable, a plain
       ``multiprocessing.Pool`` is created. When called from a non-main process, the pool is capped at a single
       process.

    :param int core_cnt: Number of worker processes. -1 uses all available cores. Values above the available core count are clamped. Default: -1.
    :param int maxtasksperchild: Maximum number of tasks a worker process can complete before being replaced. Default: From Defaults.MAXIMUM_MAX_TASK_PER_CHILD.
    :param Optional[Literal['fork', 'spawn', 'forkserver']] context: Multiprocessing start method. None uses the existing or platform default. Default: None.
    :param bool verbose: If True, prints a pool creation message with timestamp after the pool has been created. Default: True.
    :param Optional[str] source: Optional identifier string for logging purposes (e.g., 'VideoProcessor'). Default: None.
    :return: Configured multiprocessing.Pool instance.
    :rtype: multiprocessing.Pool

    :example:
    >>> pool = get_cpu_pool(core_cnt=4, source='FeatureExtractor')
    >>> pool = get_cpu_pool(core_cnt=-1, context='spawn', verbose=True)
    >>> pool = get_cpu_pool(core_cnt=8, maxtasksperchild=100, source='VideoProcessor')
    """
    check_int(name=f'{get_cpu_pool.__name__} core_cnt', min_value=-1, unaccepted_vals=[0], value=core_cnt, raise_error=True)
    check_int(name=f'{get_cpu_pool.__name__} maxtasksperchild', min_value=1, value=maxtasksperchild, raise_error=True)
    check_valid_boolean(value=verbose, source=f'{get_cpu_pool.__name__} verbose', raise_error=True)
    if source is not None:
        check_str(name=f'{get_cpu_pool.__name__} source', value=source, raise_error=True, allow_blank=True)
    # Validate the requested start method up front so a bad value fails before anything is reported or created.
    if context is not None:
        check_str(name=f'{get_cpu_pool.__name__} context', value=context, options=('fork', 'spawn', 'forkserver'), raise_error=True)

    # Inside a non-main process, restrict the pool to a single worker.
    if multiprocessing.current_process().name != 'MainProcess':
        core_cnt = 1
    available_cores = find_core_cnt()[0]
    if core_cnt == -1 or core_cnt > available_cores:
        core_cnt = available_cores

    system = platform.system()
    platform_default = OS.SPAWN.value if system in (OS.WINDOWS.value, OS.MAC.value) else OS.FORK.value
    if context is None:
        existing_method = multiprocessing.get_start_method(allow_none=True)
        context = existing_method if existing_method is not None else platform_default

    # get_context raises ValueError when the start method is not available on this platform (e.g. 'fork' on Windows).
    try:
        ctx = multiprocessing.get_context(context)
    except ValueError:
        try:
            ctx = multiprocessing.get_context(platform_default)
        except ValueError:
            ctx = None

    if ctx is not None:
        pool = ctx.Pool(processes=core_cnt, maxtasksperchild=maxtasksperchild)
    else:
        pool = multiprocessing.Pool(processes=core_cnt, maxtasksperchild=maxtasksperchild)

    if verbose:
        source_str = f' {source}' if source else ''
        print(f'[{get_current_time()}] {core_cnt} core SimBA CPU pool{source_str} started.')
    return pool
|
|
1917
|
+
|
|
1918
|
+
|
|
1919
|
+
#get_cpu_pool()
|
|
1843
1920
|
# run_user_defined_feature_extraction_class(config_path='/Users/simon/Desktop/envs/troubleshooting/circular_features_zebrafish/project_folder/project_config.ini', file_path='/Users/simon/Desktop/fish_feature_extractor_2023_version_5.py')
|
|
1844
1921
|
|
|
1845
1922
|
|
simba/utils/enums.py
CHANGED
|
@@ -127,6 +127,7 @@ class OS(Enum):
|
|
|
127
127
|
LINUX = "Linux"
|
|
128
128
|
MAC = "Darwin"
|
|
129
129
|
SPAWN = 'spawn'
|
|
130
|
+
FORK = 'fork'
|
|
130
131
|
PYTHON_VER = str(f"{sys.version_info.major}.{sys.version_info.minor}")
|
|
131
132
|
try:
|
|
132
133
|
SIMBA_VERSION = pkg_resources.get_distribution("simba-uw-tf-dev").version
|
simba/utils/printing.py
CHANGED
|
@@ -9,6 +9,7 @@ except:
|
|
|
9
9
|
import logging
|
|
10
10
|
import time
|
|
11
11
|
from typing import Optional
|
|
12
|
+
from datetime import datetime
|
|
12
13
|
|
|
13
14
|
from simba.utils.enums import Defaults, TagNames
|
|
14
15
|
|
|
@@ -25,9 +26,9 @@ def stdout_success(msg: str, source: Optional[str] = "", elapsed_time: Optional[
|
|
|
25
26
|
|
|
26
27
|
log_event(logger_name=f"{source}.{stdout_success.__name__}", log_type=TagNames.COMPLETE.value, msg=msg)
|
|
27
28
|
if elapsed_time:
|
|
28
|
-
print(f"SIMBA COMPLETE: {msg} (elapsed time: {elapsed_time}s) {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.COMPLETE.value}")
|
|
29
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] SIMBA COMPLETE: {msg} (elapsed time: {elapsed_time}s) {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.COMPLETE.value}")
|
|
29
30
|
else:
|
|
30
|
-
print(f"SIMBA COMPLETE: {msg} {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.COMPLETE.value}")
|
|
31
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] SIMBA COMPLETE: {msg} {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.COMPLETE.value}")
|
|
31
32
|
|
|
32
33
|
|
|
33
34
|
def stdout_warning(msg: str, elapsed_time: Optional[str] = None) -> None:
|
|
@@ -41,9 +42,9 @@ def stdout_warning(msg: str, elapsed_time: Optional[str] = None) -> None:
|
|
|
41
42
|
"""
|
|
42
43
|
|
|
43
44
|
if elapsed_time:
|
|
44
|
-
print(f"SIMBA WARNING: {msg} (elapsed time: {elapsed_time}s) {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.WARNING.value}")
|
|
45
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] SIMBA WARNING: {msg} (elapsed time: {elapsed_time}s) {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.WARNING.value}")
|
|
45
46
|
else:
|
|
46
|
-
print(f"SIMBA WARNING: {msg} {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.WARNING.value}")
|
|
47
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] SIMBA WARNING: {msg} {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.WARNING.value}")
|
|
47
48
|
|
|
48
49
|
|
|
49
50
|
def stdout_trash(msg: str, source: Optional[str] = "", elapsed_time: Optional[str] = None) -> None:
|
|
@@ -58,9 +59,9 @@ def stdout_trash(msg: str, source: Optional[str] = "", elapsed_time: Optional[st
|
|
|
58
59
|
|
|
59
60
|
log_event(logger_name=f"{source}.{stdout_trash.__name__}", log_type=TagNames.TRASH.value, msg=msg)
|
|
60
61
|
if elapsed_time:
|
|
61
|
-
print(f"SIMBA COMPLETE: {msg} (elapsed time: {elapsed_time}s) {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.TRASH.value}")
|
|
62
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] SIMBA COMPLETE: {msg} (elapsed time: {elapsed_time}s) {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.TRASH.value}")
|
|
62
63
|
else:
|
|
63
|
-
print(f"SIMBA COMPLETE: {msg} {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.TRASH.value}")
|
|
64
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] SIMBA COMPLETE: {msg} {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.TRASH.value}")
|
|
64
65
|
|
|
65
66
|
|
|
66
67
|
def stdout_information(msg: str, source: Optional[str] = "", elapsed_time: Optional[str] = None) -> None:
|
|
@@ -75,9 +76,9 @@ def stdout_information(msg: str, source: Optional[str] = "", elapsed_time: Optio
|
|
|
75
76
|
|
|
76
77
|
log_event(logger_name=f"{source}.{stdout_trash.__name__}", log_type=TagNames.INFORMATION.value, msg=msg)
|
|
77
78
|
if elapsed_time:
|
|
78
|
-
print(f"SIMBA COMPLETE: {msg} (elapsed time: {elapsed_time}s) {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.INFORMATION.value}")
|
|
79
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] SIMBA COMPLETE: {msg} (elapsed time: {elapsed_time}s) {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.INFORMATION.value}")
|
|
79
80
|
else:
|
|
80
|
-
print(f"{msg} {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.INFORMATION.value}")
|
|
81
|
+
print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg} {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.INFORMATION.value}")
|
|
81
82
|
|
|
82
83
|
|
|
83
84
|
class SimbaTimer(object):
|