simba-uw-tf-dev 4.6.3__py3-none-any.whl → 4.6.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. simba/assets/lookups/tooptips.json +6 -1
  2. simba/data_processors/cuda/geometry.py +45 -27
  3. simba/data_processors/cuda/image.py +1620 -1600
  4. simba/data_processors/cuda/statistics.py +17 -9
  5. simba/data_processors/egocentric_aligner.py +24 -6
  6. simba/data_processors/kleinberg_calculator.py +61 -29
  7. simba/feature_extractors/feature_subsets.py +12 -5
  8. simba/feature_extractors/straub_tail_analyzer.py +0 -2
  9. simba/mixins/statistics_mixin.py +9 -2
  10. simba/plotting/gantt_creator_mp.py +7 -5
  11. simba/plotting/pose_plotter_mp.py +7 -3
  12. simba/plotting/roi_plotter_mp.py +4 -3
  13. simba/plotting/yolo_pose_track_visualizer.py +3 -2
  14. simba/plotting/yolo_pose_visualizer.py +5 -4
  15. simba/sandbox/analyze_runtimes.py +30 -0
  16. simba/sandbox/cuda/egocentric_rotator.py +374 -374
  17. simba/sandbox/proboscis_to_tip.py +28 -0
  18. simba/sandbox/test_directionality.py +47 -0
  19. simba/sandbox/test_nonstatic_directionality.py +27 -0
  20. simba/sandbox/test_pycharm_cuda.py +51 -0
  21. simba/sandbox/test_simba_install.py +41 -0
  22. simba/sandbox/test_static_directionality.py +26 -0
  23. simba/sandbox/test_static_directionality_2d.py +26 -0
  24. simba/sandbox/verify_env.py +42 -0
  25. simba/ui/pop_ups/fsttc_pop_up.py +27 -25
  26. simba/ui/pop_ups/kleinberg_pop_up.py +39 -40
  27. simba/ui/tkinter_functions.py +3 -0
  28. simba/utils/data.py +0 -1
  29. simba/utils/errors.py +441 -440
  30. simba/utils/lookups.py +1203 -1203
  31. simba/utils/printing.py +124 -125
  32. simba/utils/read_write.py +43 -14
  33. simba/video_processors/egocentric_video_rotator.py +41 -36
  34. simba/video_processors/video_processing.py +5247 -5233
  35. simba/video_processors/videos_to_frames.py +41 -31
  36. {simba_uw_tf_dev-4.6.3.dist-info → simba_uw_tf_dev-4.6.6.dist-info}/METADATA +2 -2
  37. {simba_uw_tf_dev-4.6.3.dist-info → simba_uw_tf_dev-4.6.6.dist-info}/RECORD +41 -32
  38. {simba_uw_tf_dev-4.6.3.dist-info → simba_uw_tf_dev-4.6.6.dist-info}/LICENSE +0 -0
  39. {simba_uw_tf_dev-4.6.3.dist-info → simba_uw_tf_dev-4.6.6.dist-info}/WHEEL +0 -0
  40. {simba_uw_tf_dev-4.6.3.dist-info → simba_uw_tf_dev-4.6.6.dist-info}/entry_points.txt +0 -0
  41. {simba_uw_tf_dev-4.6.3.dist-info → simba_uw_tf_dev-4.6.6.dist-info}/top_level.txt +0 -0
simba/utils/printing.py CHANGED
@@ -1,125 +1,124 @@
1
- __author__ = "Simon Nilsson; sronilsson@gmail.com"
2
-
3
- try:
4
- from typing import Literal
5
- except:
6
- from typing_extensions import Literal
7
-
8
-
9
- import logging
10
- import time
11
- from typing import Optional
12
- from datetime import datetime
13
-
14
- from simba.utils.enums import Defaults, TagNames
15
-
16
-
17
- def stdout_success(msg: str, source: Optional[str] = "", elapsed_time: Optional[str] = None) -> None:
18
- """
19
- Helper to parse msg of completed operation to SimBA main interface.
20
-
21
- :param str msg: Message to be parsed.
22
- :param Optional[str] source: Optional string indicating the source method or function of the msg for logging.
23
- :param Optional[str] elapsed_time: Optional string indicating the runtime of the completed operation.
24
- :return None:
25
- """
26
-
27
- log_event(logger_name=f"{source}.{stdout_success.__name__}", log_type=TagNames.COMPLETE.value, msg=msg)
28
- if elapsed_time:
29
- print(f"[{datetime.now().strftime('%H:%M:%S')}] SIMBA COMPLETE: {msg} (elapsed time: {elapsed_time}s) {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.COMPLETE.value}")
30
- else:
31
- print(f"[{datetime.now().strftime('%H:%M:%S')}] SIMBA COMPLETE: {msg} {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.COMPLETE.value}")
32
-
33
-
34
- def stdout_warning(msg: str, elapsed_time: Optional[str] = None) -> None:
35
- """
36
- Helper to parse warning msg to SimBA main interface.
37
-
38
- :param str msg: Message to be parsed.
39
- :param Optional[str] source: Optional string indicating the source method or function of the msg for logging.
40
- :param elapsed_time: Optional string indicating the runtime.
41
- :return None:
42
- """
43
-
44
- if elapsed_time:
45
- print(f"[{datetime.now().strftime('%H:%M:%S')}] SIMBA WARNING: {msg} (elapsed time: {elapsed_time}s) {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.WARNING.value}")
46
- else:
47
- print(f"[{datetime.now().strftime('%H:%M:%S')}] SIMBA WARNING: {msg} {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.WARNING.value}")
48
-
49
-
50
- def stdout_trash(msg: str, source: Optional[str] = "", elapsed_time: Optional[str] = None) -> None:
51
- """
52
- Helper to parse msg of delete operation to SimBA main interface.
53
-
54
- :param str msg: Message to be parsed.
55
- :param Optional[str] source: Optional string indicating the source method or function of the operation for logging.
56
- :param elapsed_time: Optional string indicating the runtime.
57
- :return None:
58
- """
59
-
60
- log_event(logger_name=f"{source}.{stdout_trash.__name__}", log_type=TagNames.TRASH.value, msg=msg)
61
- if elapsed_time:
62
- print(f"[{datetime.now().strftime('%H:%M:%S')}] SIMBA COMPLETE: {msg} (elapsed time: {elapsed_time}s) {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.TRASH.value}")
63
- else:
64
- print(f"[{datetime.now().strftime('%H:%M:%S')}] SIMBA COMPLETE: {msg} {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.TRASH.value}")
65
-
66
-
67
- def stdout_information(msg: str, source: Optional[str] = "", elapsed_time: Optional[str] = None) -> None:
68
- """
69
- Helper to parse information msg to SimBA main interface. E.g., how many monitors and their resolutions which is available.
70
-
71
- :param str msg: Message to be parsed.
72
- :param Optional[str] source: Optional string indicating the source method or function of the operation for logging.
73
- :param elapsed_time: Optional string indicating the runtime.
74
- :return None:
75
- """
76
-
77
- log_event(logger_name=f"{source}.{stdout_trash.__name__}", log_type=TagNames.INFORMATION.value, msg=msg)
78
- if elapsed_time:
79
- print(f"[{datetime.now().strftime('%H:%M:%S')}] SIMBA COMPLETE: {msg} (elapsed time: {elapsed_time}s) {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.INFORMATION.value}")
80
- else:
81
- print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg} {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.INFORMATION.value}")
82
-
83
-
84
- class SimbaTimer(object):
85
- """Timer class for keeping track of start and end-times of calls"""
86
-
87
- def __init__(self, start: bool = False):
88
- if start:
89
- self.start_timer()
90
-
91
- def start_timer(self):
92
- self.timer = time.time()
93
-
94
- def stop_timer(self):
95
- if not hasattr(self, "timer"):
96
- self.elapsed_time = -1
97
- self.elapsed_time_str = "-1"
98
- else:
99
- self.elapsed_time = round(time.time() - self.timer, 4)
100
- self.elapsed_time_str = str(self.elapsed_time)
101
-
102
-
103
- def log_event(logger_name: str, log_type: Literal["CLASS_INIT", "error", "warning"], msg: str):
104
- logger = logging.getLogger(str(logger_name))
105
- if log_type == TagNames.CLASS_INIT.value:
106
- logger.info(f"{TagNames.CLASS_INIT.value}||{msg}")
107
- elif log_type == TagNames.ERROR.value:
108
- logger.error(f"{TagNames.ERROR.value}||{msg}")
109
- elif log_type == TagNames.WARNING.value:
110
- logger.warning(f"{TagNames.WARNING.value}||{msg}")
111
- elif log_type == TagNames.TRASH.value:
112
- logger.info(f"{TagNames.TRASH.value}||{msg}")
113
- elif log_type == TagNames.COMPLETE.value:
114
- logger.info(f"{TagNames.COMPLETE.value}||{msg}")
115
-
116
-
117
- def perform_timing(func):
118
- def decorator(*args, **kwargs):
119
- timer = SimbaTimer(start=True)
120
- results = func(*args, **kwargs, _timer=timer)
121
- timer.stop_timer()
122
- results["timer"] = timer.elapsed_time_str
123
- return results
124
-
125
- return decorator
1
+ __author__ = "Simon Nilsson; sronilsson@gmail.com"
2
+
3
+ try:
4
+ from typing import Literal
5
+ except:
6
+ from typing_extensions import Literal
7
+
8
+ import logging
9
+ import time
10
+ from datetime import datetime
11
+ from typing import Optional
12
+
13
+ from simba.utils.enums import Defaults, TagNames
14
+
15
+
16
def stdout_success(msg: str, source: Optional[str] = "", elapsed_time: Optional[str] = None) -> None:
    """
    Report a completed operation to the SimBA main interface.

    The event is first handed to ``log_event`` (logger name ``"<source>.stdout_success"``), then printed to stdout
    prefixed with a ``HH:MM:SS`` timestamp and suffixed with the string-split delimiter and the COMPLETE tag.

    :param str msg: Message to be displayed.
    :param Optional[str] source: Name of the calling method or function; used as the prefix of the logger name.
    :param Optional[str] elapsed_time: Runtime of the completed operation. Left out of the printout when empty or None.
    :return None:
    """

    log_event(logger_name=f"{source}.{stdout_success.__name__}", log_type=TagNames.COMPLETE.value, msg=msg)
    timestamp = datetime.now().strftime('%H:%M:%S')
    tag_suffix = f"{Defaults.STR_SPLIT_DELIMITER.value}{TagNames.COMPLETE.value}"
    if not elapsed_time:
        print(f"[{timestamp}] SIMBA COMPLETE: {msg} {tag_suffix}")
        return
    print(f"[{timestamp}] SIMBA COMPLETE: {msg} (elapsed time: {elapsed_time}s) {tag_suffix}")
31
+
32
+
33
def stdout_warning(msg: str, elapsed_time: Optional[str] = None) -> None:
    """
    Print a warning message to the SimBA main interface.

    The message is written to stdout prefixed with a ``HH:MM:SS`` timestamp and suffixed with the string-split
    delimiter and the WARNING tag. Unlike the other ``stdout_*`` helpers in this module, this function takes no
    ``source`` argument and does not call ``log_event``.

    :param str msg: Message to be displayed.
    :param Optional[str] elapsed_time: Optional runtime. Left out of the printout when empty or None.
    :return None:
    """

    timestamp = datetime.now().strftime('%H:%M:%S')
    tag_suffix = f"{Defaults.STR_SPLIT_DELIMITER.value}{TagNames.WARNING.value}"
    if not elapsed_time:
        print(f"[{timestamp}] SIMBA WARNING: {msg} {tag_suffix}")
        return
    print(f"[{timestamp}] SIMBA WARNING: {msg} (elapsed time: {elapsed_time}s) {tag_suffix}")
47
+
48
+
49
def stdout_trash(msg: str, source: Optional[str] = "", elapsed_time: Optional[str] = None) -> None:
    """
    Report a completed delete operation to the SimBA main interface.

    The event is first handed to ``log_event`` (logger name ``"<source>.stdout_trash"``), then printed to stdout
    prefixed with a ``HH:MM:SS`` timestamp and suffixed with the string-split delimiter and the TRASH tag.
    The visible prefix is ``SIMBA COMPLETE:``, the same as for ``stdout_success``; only the trailing tag differs.

    :param str msg: Message to be displayed.
    :param Optional[str] source: Name of the calling method or function; used as the prefix of the logger name.
    :param Optional[str] elapsed_time: Optional runtime. Left out of the printout when empty or None.
    :return None:
    """

    log_event(logger_name=f"{source}.{stdout_trash.__name__}", log_type=TagNames.TRASH.value, msg=msg)
    timestamp = datetime.now().strftime('%H:%M:%S')
    tag_suffix = f"{Defaults.STR_SPLIT_DELIMITER.value}{TagNames.TRASH.value}"
    if not elapsed_time:
        print(f"[{timestamp}] SIMBA COMPLETE: {msg} {tag_suffix}")
        return
    print(f"[{timestamp}] SIMBA COMPLETE: {msg} (elapsed time: {elapsed_time}s) {tag_suffix}")
64
+
65
+
66
def stdout_information(msg: str, source: Optional[str] = "", elapsed_time: Optional[str] = None) -> None:
    """
    Print an information message to the SimBA main interface. E.g., how many monitors are available and their resolutions.

    The event is first handed to ``log_event`` (logger name ``"<source>.stdout_information"``), then printed to stdout
    prefixed with a ``HH:MM:SS`` timestamp and suffixed with the string-split delimiter and the INFORMATION tag.

    :param str msg: Message to be displayed.
    :param Optional[str] source: Name of the calling method or function; used as the prefix of the logger name.
    :param Optional[str] elapsed_time: Optional runtime. Left out of the printout when empty or None.
    :return None:
    """

    # The logger is named after this function so the event is attributed to the information helper
    # rather than to stdout_trash.
    log_event(logger_name=f"{source}.{stdout_information.__name__}", log_type=TagNames.INFORMATION.value, msg=msg)
    if elapsed_time:
        # NOTE(review): only this branch carries the "SIMBA COMPLETE:" prefix; the branch below prints the bare
        # message. Left unchanged because consumers of the stdout text may rely on it - confirm the intent.
        print(f"[{datetime.now().strftime('%H:%M:%S')}] SIMBA COMPLETE: {msg} (elapsed time: {elapsed_time}s) {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.INFORMATION.value}")
    else:
        print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg} {Defaults.STR_SPLIT_DELIMITER.value}{TagNames.INFORMATION.value}")
81
+
82
+
83
class SimbaTimer(object):
    """
    Timer class for keeping track of the start and end times of calls.

    ``start_timer`` stores the current ``time.time()`` value in ``self.timer``. ``stop_timer`` stores the seconds
    elapsed since then, rounded to 4 decimals, in ``self.elapsed_time`` and its string form in
    ``self.elapsed_time_str``. Neither ``elapsed_*`` attribute exists before ``stop_timer`` has been called.

    :param bool start: If True, the timer is started on construction. Default False.
    """

    def __init__(self, start: bool = False):
        if start:
            self.start_timer()

    def start_timer(self):
        """Record the current time as the start of the measured interval."""
        self.timer = time.time()

    def stop_timer(self):
        """Compute the elapsed time since ``start_timer``; yields ``-1`` / ``"-1"`` if the timer was never started."""
        if hasattr(self, "timer"):
            seconds = round(time.time() - self.timer, 4)
            self.elapsed_time = seconds
            self.elapsed_time_str = str(seconds)
            return
        # No start time was ever recorded: fall back to sentinel values instead of raising.
        self.elapsed_time = -1
        self.elapsed_time_str = "-1"
100
+
101
+
102
def log_event(logger_name: str, log_type: str, msg: str):
    """
    Write an event to the logger named ``logger_name``.

    The record text has the form ``"<tag>||<msg>"``. ERROR tags are logged at error level, WARNING tags at warning
    level, and CLASS_INIT, TRASH, COMPLETE and INFORMATION tags at info level. A ``log_type`` matching none of
    these tags is ignored without raising.

    :param str logger_name: Name of the logger to write to; converted with ``str()`` before lookup.
    :param str log_type: One of the ``TagNames`` values listed above.
    :param str msg: Message to be logged.
    :return None:
    """

    logger = logging.getLogger(str(logger_name))
    # Ordered (tag, emitter) pairs; the first matching tag wins.
    # INFORMATION is included so events sent with that tag are recorded rather than falling through unmatched.
    routes = (
        (TagNames.CLASS_INIT, logger.info),
        (TagNames.ERROR, logger.error),
        (TagNames.WARNING, logger.warning),
        (TagNames.TRASH, logger.info),
        (TagNames.COMPLETE, logger.info),
        (TagNames.INFORMATION, logger.info),
    )
    for tag, emit in routes:
        if log_type == tag.value:
            emit(f"{tag.value}||{msg}")
            return
114
+
115
+
116
def perform_timing(func):
    """
    Decorator that times ``func`` and attaches the runtime to its return value.

    A started ``SimbaTimer`` is passed to ``func`` as the keyword argument ``_timer``. When ``func`` returns, the
    timer is stopped and its ``elapsed_time_str`` is stored under the ``"timer"`` key of the returned object.

    :param func: Callable that accepts a ``_timer`` keyword argument and returns an object supporting item assignment (e.g., a dict).
    :return: Wrapper with the same call signature as ``func`` that returns the result of ``func`` with the ``"timer"`` entry added.
    """

    # Local import keeps this block self-contained; the module does not import functools at the top.
    from functools import wraps

    @wraps(func)  # preserve __name__, __doc__ etc. of the decorated function
    def timed(*args, **kwargs):
        timer = SimbaTimer(start=True)
        results = func(*args, **kwargs, _timer=timer)
        timer.stop_timer()
        results["timer"] = timer.elapsed_time_str
        return results

    return timed
simba/utils/read_write.py CHANGED
@@ -560,7 +560,11 @@ def get_video_info_ffmpeg(video_path: Union[str, os.PathLike]) -> Dict[str, Any]
560
560
 
561
561
  def remove_a_folder(folder_dir: Union[str, os.PathLike], ignore_errors: Optional[bool] = True) -> None:
562
562
  """Helper to remove a directory"""
563
- check_if_dir_exists(in_dir=folder_dir, source=remove_a_folder.__name__)
563
+ valid_dir = check_if_dir_exists(in_dir=folder_dir, source=remove_a_folder.__name__, raise_error=False)
564
+ if not valid_dir and not ignore_errors:
565
+ raise NotDirectoryError(msg=f'Cannot delete directory {folder_dir}: The directory does not exist', source=remove_a_folder.__name__)
566
+ if not valid_dir and ignore_errors:
567
+ return
564
568
  try:
565
569
  shutil.rmtree(folder_dir, ignore_errors=ignore_errors)
566
570
  except Exception as e:
@@ -2456,6 +2460,13 @@ def read_img_batch_from_video_gpu(video_path: Union[str, os.PathLike],
2456
2460
  """
2457
2461
  Reads a batch of frames from a video file using GPU acceleration.
2458
2462
 
2463
+ .. csv-table::
2464
+ :header: EXPECTED RUNTIMES
2465
+ :file: ../../docs/tables/read_img_batch_from_video_gpu.csv
2466
+ :widths: 10, 45, 45
2467
+ :align: center
2468
+ :header-rows: 1
2469
+
2459
2470
  This function uses FFmpeg with CUDA acceleration to read frames from a specified range in a video file. It supports both RGB and greyscale video formats. Frames are returned as a dictionary where the keys are
2460
2471
  frame indices and the values are NumPy arrays representing the image data.
2461
2472
 
@@ -2464,7 +2475,7 @@ def read_img_batch_from_video_gpu(video_path: Union[str, os.PathLike],
2464
2475
  If you expect that the video you are reading in is black and white, set ``black_and_white`` to True to round any of these wonky values to 0 and 255.
2465
2476
 
2466
2477
  .. seealso::
2467
- For CPU multicore acceleration, see :func:`simba.mixins.image_mixin.ImageMixin.read_img_batch_from_video`
2478
+ For CPU multicore acceleration, see :func:`simba.mixins.image_mixin.ImageMixin.read_img_batch_from_video` or :func:`simba.utils.read_write.read_img_batch_from_video`.
2468
2479
 
2469
2480
  :param video_path: Path to the video file. Can be a string or an os.PathLike object.
2470
2481
  :param start_frm: The starting frame index to read. If None, starts from the beginning of the video.
@@ -2475,6 +2486,7 @@ def read_img_batch_from_video_gpu(video_path: Union[str, os.PathLike],
2475
2486
  :return: A dictionary where keys are frame indices (integers) and values are NumPy arrays containing the image data of each frame.
2476
2487
  """
2477
2488
 
2489
+ timer = SimbaTimer(start=True)
2478
2490
  check_file_exist_and_readable(file_path=video_path)
2479
2491
  video_meta_data = get_video_meta_data(video_path=video_path, fps_as_int=False)
2480
2492
  if start_frm is not None:
@@ -2547,6 +2559,10 @@ def read_img_batch_from_video_gpu(video_path: Union[str, os.PathLike],
2547
2559
  binary_frms[frm_id] = np.where(frames[frm_id] > 127, 255, 0).astype(np.uint8)
2548
2560
  frames = binary_frms
2549
2561
 
2562
+ timer.stop_timer()
2563
+ if verbose:
2564
+ print(f'[{get_current_time()}] Read frames {start_frm}-{end_frm} (video: {video_name}, elapsed time: {timer.elapsed_time_str}s)')
2565
+
2550
2566
  return frames
2551
2567
 
2552
2568
 
@@ -3152,7 +3168,7 @@ def _read_img_batch_from_video_helper(frm_idx: np.ndarray, video_path: Union[str
3152
3168
  cap.set(1, current_frm)
3153
3169
  while current_frm < end_frm:
3154
3170
  if verbose:
3155
- print(f'Reading frame {current_frm}/{video_meta_data["frame_count"]} ({video_meta_data["video_name"]})...')
3171
+ print(f'[{get_current_time()}] Reading frame {current_frm} ({video_meta_data["video_name"]})...')
3156
3172
  img = cap.read()[1]
3157
3173
  if img is not None:
3158
3174
  if greyscale or black_and_white or clahe:
@@ -3184,6 +3200,14 @@ def read_img_batch_from_video(video_path: Union[str, os.PathLike],
3184
3200
  """
3185
3201
  Read a batch of frames from a video file. This method reads frames from a specified range of frames within a video file using multiprocessing.
3186
3202
 
3203
+ .. csv-table::
3204
+ :header: EXPECTED RUNTIMES
3205
+ :file: ../../docs/tables/read_img_batch_from_video.csv
3206
+ :widths: 10, 45, 45
3207
+ :align: center
3208
+ :header-rows: 1
3209
+
3210
+
3187
3211
  .. seealso::
3188
3212
  For GPU acceleration, see :func:`simba.utils.read_write.read_img_batch_from_video_gpu`
3189
3213
 
@@ -3205,6 +3229,8 @@ def read_img_batch_from_video(video_path: Union[str, os.PathLike],
3205
3229
  >>> read_img_batch_from_video(video_path='/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/videos/Together_1.avi', start_frm=0, end_frm=50)
3206
3230
  """
3207
3231
 
3232
+
3233
+ timer = SimbaTimer(start=True)
3208
3234
  if platform.system() == "Darwin":
3209
3235
  if not multiprocessing.get_start_method(allow_none=True):
3210
3236
  multiprocessing.set_start_method("fork", force=True)
@@ -3226,19 +3252,22 @@ def read_img_batch_from_video(video_path: Union[str, os.PathLike],
3226
3252
  if end_frm <= start_frm:
3227
3253
  FrameRangeError(msg=f"Start frame ({start_frm}) has to be before end frame ({end_frm})", source=read_img_batch_from_video.__name__)
3228
3254
  frm_lst = np.array_split(np.arange(start_frm, end_frm + 1), core_cnt)
3255
+ pool = multiprocessing.Pool(core_cnt, maxtasksperchild=Defaults.LARGE_MAX_TASK_PER_CHILD.value)
3229
3256
  results = {}
3230
- with multiprocessing.Pool(core_cnt, maxtasksperchild=Defaults.LARGE_MAX_TASK_PER_CHILD.value) as pool:
3231
- constants = functools.partial(_read_img_batch_from_video_helper,
3232
- video_path=video_path,
3233
- greyscale=greyscale,
3234
- black_and_white=black_and_white,
3235
- clahe=clahe,
3236
- verbose=verbose)
3237
- for cnt, result in enumerate(pool.imap(constants, frm_lst, chunksize=1)):
3238
- results.update(result)
3239
- pool.join()
3257
+ constants = functools.partial(_read_img_batch_from_video_helper,
3258
+ video_path=video_path,
3259
+ greyscale=greyscale,
3260
+ black_and_white=black_and_white,
3261
+ clahe=clahe,
3262
+ verbose=verbose)
3263
+ for cnt, result in enumerate(pool.imap(constants, frm_lst, chunksize=1)):
3264
+ results.update(result)
3240
3265
  pool.close()
3241
- #terminate_cpu_pool(pool=pool, force=False)
3266
+ pool.join()
3267
+ pool.terminate()
3268
+ timer.stop_timer()
3269
+ if verbose:
3270
+ print(f'[{get_current_time()}] Read frames {start_frm}-{end_frm} (video: {video_meta_data["video_name"]}, elapsed time: {timer.elapsed_time_str}s)')
3242
3271
  return results
3243
3272
 
3244
3273
  def read_yolo_bp_names_file(file_path: Union[str, os.PathLike]) -> Tuple[str]:
@@ -9,10 +9,12 @@ import numpy as np
9
9
  from simba.utils.checks import (check_file_exist_and_readable,
10
10
  check_if_dir_exists, check_if_valid_rgb_tuple,
11
11
  check_int, check_valid_array,
12
- check_valid_boolean, check_valid_tuple)
12
+ check_valid_boolean, check_valid_cpu_pool,
13
+ check_valid_tuple)
13
14
  from simba.utils.data import (align_target_warpaffine_vectors,
14
15
  center_rotation_warpaffine_vectors,
15
- egocentrically_align_pose, terminate_cpu_pool)
16
+ egocentrically_align_pose, get_cpu_pool,
17
+ terminate_cpu_pool)
16
18
  from simba.utils.enums import Defaults, Formats
17
19
  from simba.utils.printing import SimbaTimer, stdout_success
18
20
  from simba.utils.read_write import (concatenate_videos_in_folder,
@@ -114,7 +116,8 @@ class EgocentricVideoRotator():
114
116
  fill_clr: Tuple[int, int, int] = (0, 0, 0),
115
117
  core_cnt: int = -1,
116
118
  save_path: Optional[Union[str, os.PathLike]] = None,
117
- gpu: Optional[bool] = True):
119
+ gpu: Optional[bool] = True,
120
+ pool: bool = None):
118
121
 
119
122
  check_file_exist_and_readable(file_path=video_path)
120
123
  self.video_meta_data = get_video_meta_data(video_path=video_path)
@@ -125,10 +128,14 @@ class EgocentricVideoRotator():
125
128
  check_valid_boolean(value=[verbose], source=f'{self.__class__.__name__} verbose')
126
129
  check_if_valid_rgb_tuple(data=fill_clr)
127
130
  check_int(name=f'{self.__class__.__name__} core_cnt', value=core_cnt, min_value=-1, unaccepted_vals=[0])
128
- if core_cnt > find_core_cnt()[0] or core_cnt == -1:
129
- self.core_cnt = find_core_cnt()[0]
131
+ if core_cnt > find_core_cnt()[0] or core_cnt == -1: self.core_cnt = find_core_cnt()[0]
132
+ else: self.core_cnt = core_cnt
133
+ if pool is not None:
134
+ check_valid_cpu_pool(value=pool, source=self.__class__.__name__, max_cores=find_core_cnt()[0], min_cores=2, raise_error=True)
135
+ self.pool_termination_flag = True
130
136
  else:
131
- self.core_cnt = core_cnt
137
+ self.pool_termination_flag = False
138
+ self.pool = get_cpu_pool(core_cnt=self.core_cnt, source=self.__class__.__name__) if pool is None else pool
132
139
  video_dir, self.video_name, _ = get_fn_ext(filepath=video_path)
133
140
  if save_path is not None:
134
141
  self.save_dir = os.path.dirname(save_path)
@@ -151,37 +158,35 @@ class EgocentricVideoRotator():
151
158
  frm_list = np.arange(0, self.video_meta_data['frame_count'])
152
159
  frm_list = np.array_split(frm_list, self.core_cnt)
153
160
  frm_list = [(cnt, x) for cnt, x in enumerate(frm_list)]
154
- if self.verbose:
155
- print(f"Creating rotated video {self.video_name}, multiprocessing (chunksize: {1}, cores: {self.core_cnt})...")
156
- with multiprocessing.Pool(self.core_cnt, maxtasksperchild=Defaults.LARGE_MAX_TASK_PER_CHILD.value) as pool:
157
- constants = functools.partial(egocentric_video_aligner,
158
- temp_dir=temp_dir,
159
- video_name=self.video_name,
160
- video_path=self.video_path,
161
- centers=self.centers,
162
- rotation_vectors=self.rotation_vectors,
163
- target=self.anchor_loc,
164
- verbose=self.verbose,
165
- fill_clr=self.fill_clr,
166
- gpu=self.gpu)
167
- for cnt, result in enumerate(pool.imap(constants, frm_list, chunksize=1)):
168
- if self.verbose:
169
- print(f"Rotate batch {result}/{self.core_cnt} complete...")
170
- terminate_cpu_pool(pool=pool, force=False)
161
+ if self.verbose: print(f"Creating rotated video {self.video_name}, multiprocessing (chunksize: {1}, cores: {self.core_cnt})...")
162
+
163
+ constants = functools.partial(egocentric_video_aligner,
164
+ temp_dir=temp_dir,
165
+ video_name=self.video_name,
166
+ video_path=self.video_path,
167
+ centers=self.centers,
168
+ rotation_vectors=self.rotation_vectors,
169
+ target=self.anchor_loc,
170
+ verbose=self.verbose,
171
+ fill_clr=self.fill_clr,
172
+ gpu=self.gpu)
173
+ for cnt, result in enumerate(self.pool.imap(constants, frm_list, chunksize=1)):
174
+ if self.verbose: print(f"Rotate batch {result}/{self.core_cnt} complete...")
175
+ if self.pool_termination_flag: terminate_cpu_pool(pool=self.pool, force=False)
171
176
  concatenate_videos_in_folder(in_folder=temp_dir, save_path=self.save_path, remove_splits=True, gpu=self.gpu, verbose=self.verbose)
172
177
  video_timer.stop_timer()
173
178
  stdout_success(msg=f"Egocentric rotation video {self.save_path} complete", elapsed_time=video_timer.elapsed_time_str, source=self.__class__.__name__)
174
179
 
175
- if __name__ == "__main__":
176
- DATA_PATH = r"C:\Users\sroni\OneDrive\Desktop\desktop\rotate_ex\data\501_MA142_Gi_Saline_0513.csv"
177
- VIDEO_PATH = r"C:\Users\sroni\OneDrive\Desktop\desktop\rotate_ex\videos\501_MA142_Gi_Saline_0513.mp4"
178
- SAVE_PATH = r"C:\Users\sroni\OneDrive\Desktop\desktop\rotate_ex\videos\501_MA142_Gi_Saline_0513_rotated.mp4"
179
- ANCHOR_LOC = np.array([250, 250])
180
-
181
- df = read_df(file_path=DATA_PATH, file_type='csv')
182
- bp_cols = [x for x in df.columns if not x.endswith('_p')]
183
- data = df[bp_cols].values.reshape(len(df), int(len(bp_cols)/2), 2).astype(np.int32)
184
-
185
- _, centers, rotation_vectors = egocentrically_align_pose(data=data, anchor_1_idx=5, anchor_2_idx=2, anchor_location=ANCHOR_LOC, direction=0)
186
- rotater = EgocentricVideoRotator(video_path=VIDEO_PATH, centers=centers, rotation_vectors=rotation_vectors, anchor_location=(400, 100), save_path=SAVE_PATH, verbose=True, core_cnt=16)
187
- rotater.run()
180
+ # if __name__ == "__main__":
181
+ # DATA_PATH = r"C:\Users\sroni\OneDrive\Desktop\desktop\rotate_ex\data\501_MA142_Gi_Saline_0513.csv"
182
+ # VIDEO_PATH = r"C:\Users\sroni\OneDrive\Desktop\desktop\rotate_ex\videos\501_MA142_Gi_Saline_0513.mp4"
183
+ # SAVE_PATH = r"C:\Users\sroni\OneDrive\Desktop\desktop\rotate_ex\videos\501_MA142_Gi_Saline_0513_rotated.mp4"
184
+ # ANCHOR_LOC = np.array([250, 250])
185
+ #
186
+ # df = read_df(file_path=DATA_PATH, file_type='csv')
187
+ # bp_cols = [x for x in df.columns if not x.endswith('_p')]
188
+ # data = df[bp_cols].values.reshape(len(df), int(len(bp_cols)/2), 2).astype(np.int32)
189
+ #
190
+ # _, centers, rotation_vectors = egocentrically_align_pose(data=data, anchor_1_idx=5, anchor_2_idx=2, anchor_location=ANCHOR_LOC, direction=0)
191
+ # rotater = EgocentricVideoRotator(video_path=VIDEO_PATH, centers=centers, rotation_vectors=rotation_vectors, anchor_location=(400, 100), save_path=SAVE_PATH, verbose=True, core_cnt=16)
192
+ # rotater.run()