simba-uw-tf-dev 4.7.1__py3-none-any.whl → 4.7.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- simba/SimBA.py +13 -4
- simba/assets/icons/left_arrow_green.png +0 -0
- simba/assets/icons/left_arrow_red.png +0 -0
- simba/assets/icons/right_arrow_green.png +0 -0
- simba/assets/icons/right_arrow_red.png +0 -0
- simba/assets/lookups/yolo_schematics/yolo_mitra.csv +9 -0
- simba/mixins/geometry_mixin.py +357 -302
- simba/mixins/image_mixin.py +129 -4
- simba/mixins/train_model_mixin.py +1 -4
- simba/model/inference_batch.py +1 -1
- simba/model/yolo_fit.py +22 -15
- simba/model/yolo_pose_inference.py +7 -2
- simba/outlier_tools/skip_outlier_correction.py +2 -2
- simba/plotting/heat_mapper_clf_mp.py +45 -23
- simba/plotting/plot_clf_results.py +2 -1
- simba/plotting/plot_clf_results_mp.py +456 -455
- simba/roi_tools/roi_utils.py +2 -2
- simba/sandbox/convert_h264_to_mp4_lossless.py +129 -0
- simba/sandbox/extract_and_convert_videos.py +257 -0
- simba/sandbox/remove_end_of_video.py +80 -0
- simba/sandbox/video_timelaps.py +291 -0
- simba/third_party_label_appenders/transform/simba_to_yolo.py +8 -5
- simba/ui/import_pose_frame.py +13 -13
- simba/ui/pop_ups/clf_plot_pop_up.py +1 -1
- simba/ui/pop_ups/run_machine_models_popup.py +22 -22
- simba/ui/pop_ups/simba_to_yolo_keypoints_popup.py +2 -2
- simba/ui/pop_ups/video_processing_pop_up.py +3638 -3469
- simba/ui/pop_ups/yolo_inference_popup.py +1 -1
- simba/ui/pop_ups/yolo_pose_train_popup.py +1 -1
- simba/ui/tkinter_functions.py +3 -1
- simba/ui/video_timelaps.py +454 -0
- simba/utils/lookups.py +67 -1
- simba/utils/read_write.py +10 -3
- simba/video_processors/batch_process_create_ffmpeg_commands.py +0 -1
- simba/video_processors/video_processing.py +160 -39
- {simba_uw_tf_dev-4.7.1.dist-info → simba_uw_tf_dev-4.7.5.dist-info}/METADATA +1 -1
- {simba_uw_tf_dev-4.7.1.dist-info → simba_uw_tf_dev-4.7.5.dist-info}/RECORD +41 -31
- {simba_uw_tf_dev-4.7.1.dist-info → simba_uw_tf_dev-4.7.5.dist-info}/LICENSE +0 -0
- {simba_uw_tf_dev-4.7.1.dist-info → simba_uw_tf_dev-4.7.5.dist-info}/WHEEL +0 -0
- {simba_uw_tf_dev-4.7.1.dist-info → simba_uw_tf_dev-4.7.5.dist-info}/entry_points.txt +0 -0
- {simba_uw_tf_dev-4.7.1.dist-info → simba_uw_tf_dev-4.7.5.dist-info}/top_level.txt +0 -0
simba/mixins/image_mixin.py
CHANGED
@@ -18,6 +18,7 @@ from collections import ChainMap
 import cv2
 import pandas as pd
 from numba import float64, int64, jit, njit, prange, uint8
+from PIL import Image, ImageDraw, ImageFont
 from shapely.geometry import Polygon
 from skimage.metrics import structural_similarity

@@ -30,11 +31,12 @@ from simba.utils.checks import (check_file_exist_and_readable, check_float,
 from simba.utils.data import terminate_cpu_pool
 from simba.utils.enums import Defaults, Formats, GeometryEnum, Options
 from simba.utils.errors import ArrayError, FrameRangeError, InvalidInputError
+from simba.utils.lookups import get_fonts
 from simba.utils.printing import SimbaTimer, stdout_success
 from simba.utils.read_write import (find_core_cnt,
                                     find_files_of_filetypes_in_directory,
                                     get_fn_ext, get_video_meta_data,
-                                    read_frm_of_video)
+                                    read_frm_of_video, seconds_to_timestamp)


 class ImageMixin(object):
@@ -2052,18 +2054,141 @@ class ImageMixin(object):

         return denoised_img

+    @staticmethod
+    def get_timelapse_img(video_path: Union[str, os.PathLike],
+                          frame_cnt: int = 25,
+                          size: Optional[int] = None,
+                          crop_ratio: int = 50) -> np.ndarray:

+        """
+        Creates timelapse image from video.

+        .. image:: _static/img/get_timelapse_img.png
+           :width: 600
+           :align: center

+        :param Union[str, os.PathLike] video_path: Path to the video to cerate the timelapse image from.
+        :param int frame_cnt: Number of frames to grab from the video. There will be an even interval between each frame.
+        :param Optional[int] size: The total width in pixels of the final timelapse image. If None, uses the video width (adjusted for crop_ratio).
+        :param int crop_ratio: The percent of each original video (from the left) to show.
+        :return np.ndarray: The timelapse image as a numpy array

+        :example:
+        >>> img = ImageMixin.get_timelapse_img(video_path=r"E:\troubleshooting\mitra_emergence\project_folder\clip_test\Box1_180mISOcontrol_Females_clipped_progress_bar.mp4", size=100)
+        """

+        video_meta = get_video_meta_data(video_path=video_path, raise_error=True)
+        frm_ids = [int(i * video_meta['frame_count'] / frame_cnt) for i in range(frame_cnt)]
+        cap = cv2.VideoCapture(video_path)
+        frms = [read_frm_of_video(video_path=cap, frame_index=x, use_ffmpeg=False) for x in frm_ids]
+
+        effective_video_width = int(video_meta['width'] * (crop_ratio / 100))
+        if size is None:
+            size = effective_video_width
+        per_frame_width_after_crop = size / frame_cnt
+        per_frame_width_before_crop = per_frame_width_after_crop / (crop_ratio / 100)
+        scale_factor = per_frame_width_before_crop / frms[0].shape[1]
+        scaled_frms = [cv2.resize(x, None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_LINEAR) for x in frms]
+        if crop_ratio is not None:
+            scaled_frms = [ImageMixin.segment_img_vertical(img=x, pct=crop_ratio, left=True) for x in scaled_frms]
+
+        return cv2.hconcat(scaled_frms)

+    @staticmethod
+    def create_time_ruler(width: int,
+                          video_path: Union[str, os.PathLike],
+                          height: int = 60,
+                          num_divisions: int = 6,
+                          font: str = 'Arial',
+                          bg_color: Tuple[int, int, int] = (255, 255, 255),
+                          line_color: Tuple[int, int, int] = (128, 128, 128),
+                          text_color: Tuple[int, int, int] = (0, 0, 0),
+                          padding: int = 60,
+                          show_time: bool = True) -> np.ndarray:
+        """
+        Create a horizontal ruler/scale bar with tick marks and labels.
+
+        .. image:: _static/img/create_time_ruler.png
+           :width: 600
+           :align: center

+        :param int width: Width of the ruler in pixels (should match timelapse image width if one is used)
+        :param Union[str, os.PathLike] video_path: Path to video file to get metadata from
+        :param int height: Height of the ruler in pixels. Default 60.
+        :param int num_divisions: Number of major divisions on the ruler. Default 6.
+        :param str font: Font name to use for labels. Default 'Algerian'.
+        :param Tuple[int, int, int] bg_color: Background color (R, G, B). Default white.
+        :param Tuple[int, int, int] line_color: Color for tick marks and lines (R, G, B). Default grey.
+        :param Tuple[int, int, int] text_color: Color for text labels (R, G, B). Default black.
+        :param bool show_time: If True, show time labels, else show frame numbers. Default True.
+        :return: Ruler image as numpy array (BGR format for OpenCV compatibility)
+        :rtype: np.ndarray

+        :example:
+        >>> ruler = ImageMixin.create_time_ruler(width=1920, video_path='path/to/video.mp4', height=60, num_divisions=6)
+        """

-
-
-
+        check_file_exist_and_readable(file_path=video_path)
+        check_int(name='width', value=width, min_value=1, raise_error=True)
+        check_int(name='height', value=height, min_value=1, raise_error=True)
+        check_int(name='num_divisions', value=num_divisions, min_value=1, raise_error=True)
+        check_int(name='padding', value=padding, min_value=0, raise_error=True)
+        check_str(name='font', value=font, allow_blank=False, raise_error=True)
+        check_if_valid_rgb_tuple(data=bg_color, raise_error=True, source=ImageMixin.create_time_ruler.__name__)
+        check_if_valid_rgb_tuple(data=line_color, raise_error=True, source=ImageMixin.create_time_ruler.__name__)
+        check_if_valid_rgb_tuple(data=text_color, raise_error=True, source=ImageMixin.create_time_ruler.__name__)
+
+        video_meta = get_video_meta_data(video_path=video_path, raise_error=True)
+        total_width = width + (2 * padding)
+
+        img = Image.new('RGB', (total_width, height), color=bg_color)
+        draw = ImageDraw.Draw(img)
+        font_dict = get_fonts()
+        try:
+            font_path = font_dict[font]
+            pil_font = ImageFont.truetype(font_path, size=12)
+        except (KeyError, OSError):
+            pil_font = ImageFont.load_default()
+        major_tick_height, half_tick_height = height * 0.6, height * 0.4
+        quarter_tick_height, eighth_tick_height = height * 0.25, height * 0.15
+
+        for i in range(num_divisions + 1):
+            x = padding + int(i * width / num_divisions)
+            draw.line([(x, 0), (x, major_tick_height)], fill=line_color, width=2)
+            if show_time and video_meta['video_length_s'] is not None:
+                seconds_at_division = i * video_meta['video_length_s'] / num_divisions
+                label = seconds_to_timestamp(seconds=seconds_at_division)
+            elif video_meta['frame_count'] is not None:
+                label = str(int(i * video_meta['frame_count'] / num_divisions))
+            else:
+                label = str(i)
+            bbox = draw.textbbox((0, 0), label, font=pil_font)
+            text_width = bbox[2] - bbox[0]
+            if i == 0:
+                draw.text((x, major_tick_height + 5), label, fill=text_color, font=pil_font)
+            elif i == num_divisions:
+                draw.text((x - text_width, major_tick_height + 5), label, fill=text_color, font=pil_font)
+            else:
+                draw.text((x - text_width // 2, major_tick_height + 5), label, fill=text_color, font=pil_font)
+            if i < num_divisions:
+                x_half = padding + int((i + 0.5) * width / num_divisions)
+                draw.line([(x_half, 0), (x_half, half_tick_height)], fill=line_color, width=1)
+                for q in [0.25, 0.75]:
+                    x_quarter = padding + int((i + q) * width / num_divisions)
+                    draw.line([(x_quarter, 0), (x_quarter, quarter_tick_height)], fill=line_color, width=1)
+                for e in [0.125, 0.375, 0.625, 0.875]:
+                    x_eighth = padding + int((i + e) * width / num_divisions)
+                    draw.line([(x_eighth, 0), (x_eighth, eighth_tick_height)], fill=line_color, width=1)
+
+        draw.line([(0, height - 1), (total_width, height - 1)], fill=line_color, width=1)
+        img_bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+        return img_bgr
+
+
+# img = ImageMixin.create_time_ruler(width=1920, video_path=r"E:\troubleshooting\mitra_emergence\project_folder\clip_test\Box1_180mISOcontrol_Females_clipped_progress_bar.mp4", height=60, num_divisions=6)
+#
+# cv2.imshow('sadasdas', img)
+# cv2.waitKey(40000)

 #x = ImageMixin.get_blob_locations(video_path=r"C:\troubleshooting\RAT_NOR\project_folder\videos\2022-06-20_NOB_DOT_4_downsampled_bg_subtracted.mp4", gpu=True)
 # imgs = ImageMixin().read_all_img_in_dir(dir='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/videos/examples')
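A minimal usage sketch (not part of the diff) showing how the two new statics above could be combined; the video path is hypothetical, and padding=0 is assumed so the ruler's total width matches the timelapse strip before stacking:

# Usage sketch for the new ImageMixin statics; VIDEO_PATH is a placeholder and
# cv2.vconcat assumes both images are BGR uint8 with equal width (true at padding=0).
import cv2
from simba.mixins.image_mixin import ImageMixin

VIDEO_PATH = 'project_folder/videos/Video_1.mp4'  # hypothetical example path
strip = ImageMixin.get_timelapse_img(video_path=VIDEO_PATH, frame_cnt=25, size=1920, crop_ratio=50)
ruler = ImageMixin.create_time_ruler(width=strip.shape[1], video_path=VIDEO_PATH, height=60, num_divisions=6, padding=0)
combined = cv2.vconcat([strip, ruler])  # timelapse strip on top, time ruler underneath
cv2.imwrite('timelapse_with_ruler.png', combined)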
simba/mixins/train_model_mixin.py
CHANGED

@@ -1070,10 +1070,7 @@ class TrainModelMixin(object):
             MissingUserInputWarning(msg=f'Skipping {str(config.get("SML settings", "target_name_" + str(n + 1)))} classifier analysis: missing information (e.g., no discrimination threshold and/or minimum bout set in the project_config.ini',source=self.__class__.__name__)

         if len(model_dict.keys()) == 0:
-            raise NoDataError(
-                msg=f"There are no models with accurate data specified in the RUN MODELS menu. Specify the model information to SimBA RUN MODELS menu to use them to analyze videos",
-                source=self.get_model_info.__name__,
-            )
+            raise NoDataError(msg=f"There are no models with accurate data specified in the RUN MODELS menu. Specify the model information to SimBA RUN MODELS menu to use them to analyze videos. PLease check the model paths, thresholds, and minimum bout lengths.", source=self.get_model_info.__name__)
         else:
             return model_dict

simba/model/inference_batch.py
CHANGED
@@ -101,7 +101,7 @@ class InferenceBatch(TrainModelMixin, ConfigReader):
             video_timer.stop_timer()
             print(f"Predictions created for {file_name} (frame count: {len(in_df)}, elapsed time: {video_timer.elapsed_time_str}) ...")
         self.timer.stop_timer()
-        stdout_success(msg=f"Machine predictions complete. Files saved in {self.save_dir} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
+        stdout_success(msg=f"Machine predictions complete for {len(self.feature_file_paths)} file(s). Files saved in {self.save_dir} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)

 if __name__ == "__main__" and not hasattr(sys, 'ps1'):
     parser = argparse.ArgumentParser(description="Perform classifications according to rules defined in SImAB project_config.ini.")
simba/model/yolo_fit.py
CHANGED
@@ -1,5 +1,6 @@
 import os
 import sys
+from contextlib import redirect_stderr, redirect_stdout

 os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
 import argparse
@@ -21,7 +22,8 @@ from simba.utils.checks import (check_file_exist_and_readable,
                                 check_valid_boolean, check_valid_device)
 from simba.utils.enums import Options
 from simba.utils.errors import SimBAGPUError, SimBAPAckageVersionError
-from simba.utils.
+from simba.utils.printing import stdout_information
+from simba.utils.read_write import find_core_cnt, get_current_time
 from simba.utils.yolo import load_yolo_model


@@ -108,20 +110,25 @@ class FitYolo():


     def run(self):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Temporarily redirect stdout/stderr to terminal to ensure ultralytics output goes to terminal
+        # sys.__stdout__ and sys.__stderr__ are the original terminal streams
+        stdout_information(msg=f'[{get_current_time()}] Please follow the YOLO pose model training in the terminal from where SimBA was launched ...', source=self.__class__.__name__)
+        stdout_information(msg=f'[{get_current_time()}] Results will be stored in the {self.save_path} directory ..', source=self.__class__.__name__)
+        with redirect_stdout(sys.__stdout__), redirect_stderr(sys.__stderr__):
+            model = load_yolo_model(weights_path=self.weights_path,
+                                    verbose=self.verbose,
+                                    format=self.format,
+                                    device=self.device)
+
+            model.train(data=self.model_yaml,
+                        epochs=self.epochs,
+                        project=self.save_path,
+                        batch=self.batch,
+                        plots=self.plots,
+                        imgsz=self.imgsz,
+                        workers=self.workers,
+                        device=self.device,
+                        patience=self.patience)


 if __name__ == "__main__" and not hasattr(sys, 'ps1'):
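The rewritten run() above routes ultralytics' training output around SimBA's GUI-captured streams by pointing contextlib's redirectors at the interpreter's original terminal streams. A standalone sketch of that pattern, independent of SimBA (all names illustrative):

# Standalone sketch of the stream-redirection pattern in FitYolo.run(): while
# sys.stdout is replaced (as a GUI console does), writes inside the context
# manager still reach the real terminal through sys.__stdout__/sys.__stderr__.
import io
import sys
from contextlib import redirect_stderr, redirect_stdout

gui_console = io.StringIO()
sys.stdout = gui_console                       # stand-in for a GUI-captured stdout
print('captured by the GUI console')
with redirect_stdout(sys.__stdout__), redirect_stderr(sys.__stderr__):
    print('reaches the original terminal')     # e.g., ultralytics progress bars
sys.stdout = sys.__stdout__                    # restore
print('GUI console received:', gui_console.getvalue().strip())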
simba/model/yolo_pose_inference.py
CHANGED

@@ -34,7 +34,7 @@ from simba.utils.errors import (CountError, InvalidFilepathError,
                                 InvalidFileTypeError, SimBAGPUError,
                                 SimBAPAckageVersionError)
 from simba.utils.lookups import get_current_time
-from simba.utils.printing import SimbaTimer, stdout_success
+from simba.utils.printing import SimbaTimer, stdout_information, stdout_success
 from simba.utils.read_write import (find_files_of_filetypes_in_directory,
                                     get_video_meta_data, recursive_file_search)
 from simba.utils.warnings import FileExistWarning, NoDataFoundWarning
@@ -182,7 +182,12 @@ class YOLOPoseInference():
         results = {}
         class_dict = self.model.names
         timer = SimbaTimer(start=True)
-
+        if self.save_dir is not None:
+            msg = f'[{get_current_time()}] Starting tracking inference for {len(self.video_path)} video(s). Results will be saved in {self.save_dir} ... '
+        else:
+            msg = f'[{get_current_time()}] Starting tracking inference for {len(self.video_path)} video(s) ... '
+        stdout_information(msg=msg, source=self.__class__.__name__)
+        stdout_information(msg='Follow progress in OS terminal window ...', source=self.__class__.__name__)
         for video_cnt, path in enumerate(self.video_path):
             video_timer = SimbaTimer(start=True)
             _, video_name, _ = get_fn_ext(filepath=path)
simba/outlier_tools/skip_outlier_correction.py
CHANGED

@@ -5,7 +5,7 @@ from typing import Union

 from simba.mixins.config_reader import ConfigReader
 from simba.utils.checks import check_if_filepath_list_is_empty
-from simba.utils.printing import SimbaTimer, stdout_success
+from simba.utils.printing import SimbaTimer, stdout_information, stdout_success
 from simba.utils.read_write import get_fn_ext, read_df, write_df


@@ -43,7 +43,7 @@ class OutlierCorrectionSkipper(ConfigReader):
             save_path = os.path.join(self.outlier_corrected_dir, f"{video_name}.{self.file_type}")
             write_df(df=data_df, file_type=self.file_type, save_path=save_path)
             video_timer.stop_timer()
-
+            stdout_information(msg=f"Skipped outlier correction for video {video_name} (Video {file_cnt+1}/{len(self.input_csv_paths)})", elapsed_time=video_timer.elapsed_time_str)
         self.timer.stop_timer()
         stdout_success(msg=f"Skipped outlier correction for {len(self.input_csv_paths)} files", elapsed_time=self.timer.elapsed_time_str)

simba/plotting/heat_mapper_clf_mp.py
CHANGED

@@ -17,7 +17,7 @@ from simba.utils.checks import (
     check_all_file_names_are_represented_in_video_log,
     check_filepaths_in_iterable_exist, check_int, check_str,
     check_valid_boolean, check_valid_dataframe, check_valid_dict)
-from simba.utils.data import terminate_cpu_pool
+from simba.utils.data import get_cpu_pool, terminate_cpu_pool
 from simba.utils.enums import Formats
 from simba.utils.errors import InvalidInputError, NoSpecifiedOutputError
 from simba.utils.printing import SimbaTimer, stdout_success
@@ -149,6 +149,7 @@ class HeatMapperClfMultiprocess(ConfigReader, PlottingMixin):
     def run(self):
         print(f"Processing {len(self.data_paths)} video(s)...")
         check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.data_paths)
+        pool = get_cpu_pool(core_cnt=self.core_cnt, source=self.__class__.__name__)
         for file_cnt, file_path in enumerate(self.data_paths):
             video_timer = SimbaTimer(start=True)
             _, self.video_name, _ = get_fn_ext(file_path)
@@ -173,7 +174,8 @@ class HeatMapperClfMultiprocess(ConfigReader, PlottingMixin):
             if len(np.unique(clf_data)) == 1:
                 raise InvalidInputError(msg=f'Cannot plot heatmap for behavior {self.clf_name} in video {self.video_name}. The behavior is classified as {np.unique(clf_data)} in every single frame.')
             grid, aspect_ratio = GeometryMixin.bucket_img_into_grid_square(img_size=(self.width, self.height), bucket_grid_size_mm=self.bin_size, px_per_mm=self.px_per_mm, add_correction=False, verbose=False)
-
+
+            clf_data = GeometryMixin().cumsum_bool_geometries(data=bp_data, geometries=grid, bool_data=clf_data, fps=self.fps, verbose=False, core_cnt=self.core_cnt, pool=pool)
             if self.max_scale == "auto":
                 self.max_scale = max(1, self.__calculate_max_scale(clf_array=clf_data))
             if self.final_img_setting:
@@ -197,32 +199,29 @@ class HeatMapperClfMultiprocess(ConfigReader, PlottingMixin):
                 frm_per_core_w_batch.append((batch_cnt, frm_range, frame_arrays[batch_cnt]))
             del frame_arrays
             print(f"Creating heatmaps, multiprocessing (chunksize: {self.multiprocess_chunksize}, cores: {self.core_cnt})...")
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            terminate_cpu_pool(pool=pool, force=False)
-
+            constants = functools.partial(_heatmap_multiprocessor,
+                                          video_setting=self.video_setting,
+                                          frame_setting=self.frame_setting,
+                                          style_attr=self.style_attr,
+                                          fps=self.fps,
+                                          video_temp_dir=self.temp_folder,
+                                          frame_dir=self.frames_save_dir,
+                                          max_scale=self.max_scale,
+                                          aspect_ratio=aspect_ratio,
+                                          clf_name=self.clf_name,
+                                          size=(self.width, self.height),
+                                          video_name=self.video_name,
+                                          make_clf_heatmap_plot=self.make_clf_heatmap_plot)
+
+
+            for cnt, batch in enumerate(pool.imap(constants, frm_per_core_w_batch, chunksize=self.multiprocess_chunksize)):
+                print(f'Batch core {batch+1}/{self.core_cnt} complete (Video {self.video_name})... ')
             if self.video_setting:
                 print(f"Joining {self.video_name} multiprocessed video...")
                 concatenate_videos_in_folder(in_folder=self.temp_folder, save_path=self.save_video_path)
-
             video_timer.stop_timer()
             print(f"Heatmap video {self.video_name} complete, (elapsed time: {video_timer.elapsed_time_str}s) ...")
-
+        terminate_cpu_pool(pool=pool, force=False, source=self.__class__.__name__)
         self.timer.stop_timer()
         stdout_success(msg=f"Heatmap visualizations for {len(self.data_paths)} video(s) created in {self.heatmap_clf_location_dir} directory", elapsed_time=self.timer.elapsed_time_str)

@@ -261,3 +260,26 @@ class HeatMapperClfMultiprocess(ConfigReader, PlottingMixin):
 # core_cnt=5,
 # files_found=['/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/csv/machine_results/Together_1.csv'])
 # test.create_heatmaps()
+# if __name__ == "__main__":
+#     x = HeatMapperClfMultiprocess(config_path=r"E:\troubleshooting\mitra_emergence\project_folder\project_config.ini",
+#                                   bodypart='nose',
+#                                   clf_name='GROOMING',
+#                                   style_attr={'palette': 'jet', 'shading': 'gouraud', 'bin_size': 25, 'max_scale': 'auto'},
+#                                   final_img_setting=True,
+#                                   video_setting=False,
+#                                   frame_setting=False,
+#                                   core_cnt=12,
+#                                   data_paths=[r"E:\troubleshooting\mitra_emergence\project_folder\csv\machine_results\Box1_180mISOcontrol_Females.csv"])
+#
+#     x.run()
+
+# def __init__(self,
+#              config_path: Union[str, os.PathLike],
+#              bodypart: str,
+#              clf_name: str,
+#              data_paths: List[str],
+#              style_attr: dict,
+#              final_img_setting: bool = True,
+#              video_setting: bool = False,
+#              frame_setting: bool = False,
+#              core_cnt: int = -1):
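Net effect of the heat_mapper changes above: one worker pool is created at the top of run() via get_cpu_pool, reused both inside cumsum_bool_geometries and for the per-video pool.imap calls, and terminated once after the loop instead of per video. A generic, self-contained sketch of that lifecycle (the worker and data below are placeholders, not SimBA functions):

# Pool-lifecycle sketch mirroring the diff: create the pool once, reuse it
# across every video with functools.partial + imap, terminate once at the end.
import functools
import multiprocessing

def _worker(batch, scale):                     # placeholder for _heatmap_multiprocessor
    batch_id, values = batch
    return batch_id, [v * scale for v in values]

if __name__ == '__main__':
    pool = multiprocessing.Pool(processes=4)   # stands in for get_cpu_pool()
    videos = {'vid_1': [(0, [1, 2]), (1, [3, 4])], 'vid_2': [(0, [5, 6])]}
    for name, batches in videos.items():       # the same pool serves every video
        task = functools.partial(_worker, scale=2)
        for batch_id, result in pool.imap(task, batches, chunksize=1):
            print(f'{name}: batch {batch_id + 1} complete -> {result}')
    pool.terminate()                           # stands in for terminate_cpu_pool()
    pool.join()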
simba/plotting/plot_clf_results.py
CHANGED

@@ -236,7 +236,8 @@ class PlotSklearnResultsSingleCore(ConfigReader, TrainModelMixin, PlottingMixin)
             self.add_spacer += 1
         if self.show_confidence:
             col_name = f'Probability_{clf_name}'
-
+            conf = round(self.data_df.loc[frm_idx, col_name], 4)
+            conf_txt = f'{clf_name} CONFIDENCE {conf:.4f}'
             self.frame = PlottingMixin().put_text(img=self.frame, text=conf_txt, pos=(TextOptions.BORDER_BUFFER_Y.value, ((self.video_meta_data["height"] - self.video_meta_data["height"]) + self.video_space_size * self.add_spacer)), font_size=self.video_font_size, font_thickness=self.video_text_thickness, font=self.font, text_bg_alpha=self.video_text_opacity, text_color_bg=self.text_bg_color, text_color=self.text_color)
             self.add_spacer += 1
             self.frame = PlottingMixin().put_text(img=self.frame, text="ENSEMBLE PREDICTION:", pos=(TextOptions.BORDER_BUFFER_Y.value, ((self.video_meta_data["height"] - self.video_meta_data["height"]) + self.video_space_size * self.add_spacer)), font_size=self.video_font_size, font_thickness=self.video_text_thickness, font=self.font, text_bg_alpha=self.video_text_opacity, text_color_bg=self.text_bg_color, text_color=self.text_color)