simba-uw-tf-dev 4.7.1__py3-none-any.whl → 4.7.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. simba/SimBA.py +13 -4
  2. simba/assets/icons/left_arrow_green.png +0 -0
  3. simba/assets/icons/left_arrow_red.png +0 -0
  4. simba/assets/icons/right_arrow_green.png +0 -0
  5. simba/assets/icons/right_arrow_red.png +0 -0
  6. simba/assets/lookups/yolo_schematics/yolo_mitra.csv +9 -0
  7. simba/mixins/geometry_mixin.py +357 -302
  8. simba/mixins/image_mixin.py +129 -4
  9. simba/mixins/train_model_mixin.py +1 -4
  10. simba/model/inference_batch.py +1 -1
  11. simba/model/yolo_fit.py +22 -15
  12. simba/model/yolo_pose_inference.py +7 -2
  13. simba/outlier_tools/skip_outlier_correction.py +2 -2
  14. simba/plotting/heat_mapper_clf_mp.py +45 -23
  15. simba/plotting/plot_clf_results.py +2 -1
  16. simba/plotting/plot_clf_results_mp.py +456 -455
  17. simba/roi_tools/roi_utils.py +2 -2
  18. simba/sandbox/convert_h264_to_mp4_lossless.py +129 -0
  19. simba/sandbox/extract_and_convert_videos.py +257 -0
  20. simba/sandbox/remove_end_of_video.py +80 -0
  21. simba/sandbox/video_timelaps.py +291 -0
  22. simba/third_party_label_appenders/transform/simba_to_yolo.py +8 -5
  23. simba/ui/import_pose_frame.py +13 -13
  24. simba/ui/pop_ups/clf_plot_pop_up.py +1 -1
  25. simba/ui/pop_ups/run_machine_models_popup.py +22 -22
  26. simba/ui/pop_ups/simba_to_yolo_keypoints_popup.py +2 -2
  27. simba/ui/pop_ups/video_processing_pop_up.py +3638 -3469
  28. simba/ui/pop_ups/yolo_inference_popup.py +1 -1
  29. simba/ui/pop_ups/yolo_pose_train_popup.py +1 -1
  30. simba/ui/tkinter_functions.py +3 -1
  31. simba/ui/video_timelaps.py +454 -0
  32. simba/utils/lookups.py +67 -1
  33. simba/utils/read_write.py +10 -3
  34. simba/video_processors/batch_process_create_ffmpeg_commands.py +0 -1
  35. simba/video_processors/video_processing.py +160 -39
  36. {simba_uw_tf_dev-4.7.1.dist-info → simba_uw_tf_dev-4.7.5.dist-info}/METADATA +1 -1
  37. {simba_uw_tf_dev-4.7.1.dist-info → simba_uw_tf_dev-4.7.5.dist-info}/RECORD +41 -31
  38. {simba_uw_tf_dev-4.7.1.dist-info → simba_uw_tf_dev-4.7.5.dist-info}/LICENSE +0 -0
  39. {simba_uw_tf_dev-4.7.1.dist-info → simba_uw_tf_dev-4.7.5.dist-info}/WHEEL +0 -0
  40. {simba_uw_tf_dev-4.7.1.dist-info → simba_uw_tf_dev-4.7.5.dist-info}/entry_points.txt +0 -0
  41. {simba_uw_tf_dev-4.7.1.dist-info → simba_uw_tf_dev-4.7.5.dist-info}/top_level.txt +0 -0
simba/mixins/image_mixin.py CHANGED
@@ -18,6 +18,7 @@ from collections import ChainMap
18
18
  import cv2
19
19
  import pandas as pd
20
20
  from numba import float64, int64, jit, njit, prange, uint8
21
+ from PIL import Image, ImageDraw, ImageFont
21
22
  from shapely.geometry import Polygon
22
23
  from skimage.metrics import structural_similarity
23
24
 
@@ -30,11 +31,12 @@ from simba.utils.checks import (check_file_exist_and_readable, check_float,
30
31
  from simba.utils.data import terminate_cpu_pool
31
32
  from simba.utils.enums import Defaults, Formats, GeometryEnum, Options
32
33
  from simba.utils.errors import ArrayError, FrameRangeError, InvalidInputError
34
+ from simba.utils.lookups import get_fonts
33
35
  from simba.utils.printing import SimbaTimer, stdout_success
34
36
  from simba.utils.read_write import (find_core_cnt,
35
37
  find_files_of_filetypes_in_directory,
36
38
  get_fn_ext, get_video_meta_data,
37
- read_frm_of_video)
39
+ read_frm_of_video, seconds_to_timestamp)
38
40
 
39
41
 
40
42
  class ImageMixin(object):
@@ -2052,18 +2054,141 @@ class ImageMixin(object):
2052
2054
 
2053
2055
  return denoised_img
2054
2056
 
2057
+ @staticmethod
2058
+ def get_timelapse_img(video_path: Union[str, os.PathLike],
2059
+ frame_cnt: int = 25,
2060
+ size: Optional[int] = None,
2061
+ crop_ratio: int = 50) -> np.ndarray:
2055
2062
 
2063
+ """
2064
+ Creates timelapse image from video.
2056
2065
 
2066
+ .. image:: _static/img/get_timelapse_img.png
2067
+ :width: 600
2068
+ :align: center
2057
2069
 
2070
+ :param Union[str, os.PathLike] video_path: Path to the video to cerate the timelapse image from.
2071
+ :param int frame_cnt: Number of frames to grab from the video. There will be an even interval between each frame.
2072
+ :param Optional[int] size: The total width in pixels of the final timelapse image. If None, uses the video width (adjusted for crop_ratio).
2073
+ :param int crop_ratio: The percent of each original video (from the left) to show.
2074
+ :return np.ndarray: The timelapse image as a numpy array
2058
2075
 
2076
+ :example:
2077
+ >>> img = ImageMixin.get_timelapse_img(video_path=r"E:\troubleshooting\mitra_emergence\project_folder\clip_test\Box1_180mISOcontrol_Females_clipped_progress_bar.mp4", size=100)
2078
+ """
2059
2079
 
2080
+ video_meta = get_video_meta_data(video_path=video_path, raise_error=True)
2081
+ frm_ids = [int(i * video_meta['frame_count'] / frame_cnt) for i in range(frame_cnt)]
2082
+ cap = cv2.VideoCapture(video_path)
2083
+ frms = [read_frm_of_video(video_path=cap, frame_index=x, use_ffmpeg=False) for x in frm_ids]
2084
+
2085
+ effective_video_width = int(video_meta['width'] * (crop_ratio / 100))
2086
+ if size is None:
2087
+ size = effective_video_width
2088
+ per_frame_width_after_crop = size / frame_cnt
2089
+ per_frame_width_before_crop = per_frame_width_after_crop / (crop_ratio / 100)
2090
+ scale_factor = per_frame_width_before_crop / frms[0].shape[1]
2091
+ scaled_frms = [cv2.resize(x, None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_LINEAR) for x in frms]
2092
+ if crop_ratio is not None:
2093
+ scaled_frms = [ImageMixin.segment_img_vertical(img=x, pct=crop_ratio, left=True) for x in scaled_frms]
2094
+
2095
+ return cv2.hconcat(scaled_frms)
2060
2096
 
2097
+ @staticmethod
2098
+ def create_time_ruler(width: int,
2099
+ video_path: Union[str, os.PathLike],
2100
+ height: int = 60,
2101
+ num_divisions: int = 6,
2102
+ font: str = 'Arial',
2103
+ bg_color: Tuple[int, int, int] = (255, 255, 255),
2104
+ line_color: Tuple[int, int, int] = (128, 128, 128),
2105
+ text_color: Tuple[int, int, int] = (0, 0, 0),
2106
+ padding: int = 60,
2107
+ show_time: bool = True) -> np.ndarray:
2108
+ """
2109
+ Create a horizontal ruler/scale bar with tick marks and labels.
2110
+
2111
+ .. image:: _static/img/create_time_ruler.png
2112
+ :width: 600
2113
+ :align: center
2061
2114
 
2115
+ :param int width: Width of the ruler in pixels (should match timelapse image width if one is used)
2116
+ :param Union[str, os.PathLike] video_path: Path to video file to get metadata from
2117
+ :param int height: Height of the ruler in pixels. Default 60.
2118
+ :param int num_divisions: Number of major divisions on the ruler. Default 6.
2119
+ :param str font: Font name to use for labels. Default 'Algerian'.
2120
+ :param Tuple[int, int, int] bg_color: Background color (R, G, B). Default white.
2121
+ :param Tuple[int, int, int] line_color: Color for tick marks and lines (R, G, B). Default grey.
2122
+ :param Tuple[int, int, int] text_color: Color for text labels (R, G, B). Default black.
2123
+ :param bool show_time: If True, show time labels, else show frame numbers. Default True.
2124
+ :return: Ruler image as numpy array (BGR format for OpenCV compatibility)
2125
+ :rtype: np.ndarray
2062
2126
 
2127
+ :example:
2128
+ >>> ruler = ImageMixin.create_time_ruler(width=1920, video_path='path/to/video.mp4', height=60, num_divisions=6)
2129
+ """
2063
2130
 
2064
-
2065
-
2066
-
2131
+ check_file_exist_and_readable(file_path=video_path)
2132
+ check_int(name='width', value=width, min_value=1, raise_error=True)
2133
+ check_int(name='height', value=height, min_value=1, raise_error=True)
2134
+ check_int(name='num_divisions', value=num_divisions, min_value=1, raise_error=True)
2135
+ check_int(name='padding', value=padding, min_value=0, raise_error=True)
2136
+ check_str(name='font', value=font, allow_blank=False, raise_error=True)
2137
+ check_if_valid_rgb_tuple(data=bg_color, raise_error=True, source=ImageMixin.create_time_ruler.__name__)
2138
+ check_if_valid_rgb_tuple(data=line_color, raise_error=True, source=ImageMixin.create_time_ruler.__name__)
2139
+ check_if_valid_rgb_tuple(data=text_color, raise_error=True, source=ImageMixin.create_time_ruler.__name__)
2140
+
2141
+ video_meta = get_video_meta_data(video_path=video_path, raise_error=True)
2142
+ total_width = width + (2 * padding)
2143
+
2144
+ img = Image.new('RGB', (total_width, height), color=bg_color)
2145
+ draw = ImageDraw.Draw(img)
2146
+ font_dict = get_fonts()
2147
+ try:
2148
+ font_path = font_dict[font]
2149
+ pil_font = ImageFont.truetype(font_path, size=12)
2150
+ except (KeyError, OSError):
2151
+ pil_font = ImageFont.load_default()
2152
+ major_tick_height, half_tick_height = height * 0.6, height * 0.4
2153
+ quarter_tick_height, eighth_tick_height = height * 0.25, height * 0.15
2154
+
2155
+ for i in range(num_divisions + 1):
2156
+ x = padding + int(i * width / num_divisions)
2157
+ draw.line([(x, 0), (x, major_tick_height)], fill=line_color, width=2)
2158
+ if show_time and video_meta['video_length_s'] is not None:
2159
+ seconds_at_division = i * video_meta['video_length_s'] / num_divisions
2160
+ label = seconds_to_timestamp(seconds=seconds_at_division)
2161
+ elif video_meta['frame_count'] is not None:
2162
+ label = str(int(i * video_meta['frame_count'] / num_divisions))
2163
+ else:
2164
+ label = str(i)
2165
+ bbox = draw.textbbox((0, 0), label, font=pil_font)
2166
+ text_width = bbox[2] - bbox[0]
2167
+ if i == 0:
2168
+ draw.text((x, major_tick_height + 5), label, fill=text_color, font=pil_font)
2169
+ elif i == num_divisions:
2170
+ draw.text((x - text_width, major_tick_height + 5), label, fill=text_color, font=pil_font)
2171
+ else:
2172
+ draw.text((x - text_width // 2, major_tick_height + 5), label, fill=text_color, font=pil_font)
2173
+ if i < num_divisions:
2174
+ x_half = padding + int((i + 0.5) * width / num_divisions)
2175
+ draw.line([(x_half, 0), (x_half, half_tick_height)], fill=line_color, width=1)
2176
+ for q in [0.25, 0.75]:
2177
+ x_quarter = padding + int((i + q) * width / num_divisions)
2178
+ draw.line([(x_quarter, 0), (x_quarter, quarter_tick_height)], fill=line_color, width=1)
2179
+ for e in [0.125, 0.375, 0.625, 0.875]:
2180
+ x_eighth = padding + int((i + e) * width / num_divisions)
2181
+ draw.line([(x_eighth, 0), (x_eighth, eighth_tick_height)], fill=line_color, width=1)
2182
+
2183
+ draw.line([(0, height - 1), (total_width, height - 1)], fill=line_color, width=1)
2184
+ img_bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
2185
+ return img_bgr
2186
+
2187
+
2188
+ # img = ImageMixin.create_time_ruler(width=1920, video_path=r"E:\troubleshooting\mitra_emergence\project_folder\clip_test\Box1_180mISOcontrol_Females_clipped_progress_bar.mp4", height=60, num_divisions=6)
2189
+ #
2190
+ # cv2.imshow('sadasdas', img)
2191
+ # cv2.waitKey(40000)
2067
2192
 
2068
2193
  #x = ImageMixin.get_blob_locations(video_path=r"C:\troubleshooting\RAT_NOR\project_folder\videos\2022-06-20_NOB_DOT_4_downsampled_bg_subtracted.mp4", gpu=True)
2069
2194
  # imgs = ImageMixin().read_all_img_in_dir(dir='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/videos/examples')
simba/mixins/train_model_mixin.py CHANGED
@@ -1070,10 +1070,7 @@ class TrainModelMixin(object):
1070
1070
  MissingUserInputWarning(msg=f'Skipping {str(config.get("SML settings", "target_name_" + str(n + 1)))} classifier analysis: missing information (e.g., no discrimination threshold and/or minimum bout set in the project_config.ini',source=self.__class__.__name__)
1071
1071
 
1072
1072
  if len(model_dict.keys()) == 0:
1073
- raise NoDataError(
1074
- msg=f"There are no models with accurate data specified in the RUN MODELS menu. Specify the model information to SimBA RUN MODELS menu to use them to analyze videos",
1075
- source=self.get_model_info.__name__,
1076
- )
1073
+ raise NoDataError(msg=f"There are no models with accurate data specified in the RUN MODELS menu. Specify the model information to SimBA RUN MODELS menu to use them to analyze videos. PLease check the model paths, thresholds, and minimum bout lengths.", source=self.get_model_info.__name__)
1077
1074
  else:
1078
1075
  return model_dict
1079
1076
 
simba/model/inference_batch.py CHANGED
@@ -101,7 +101,7 @@ class InferenceBatch(TrainModelMixin, ConfigReader):
101
101
  video_timer.stop_timer()
102
102
  print(f"Predictions created for {file_name} (frame count: {len(in_df)}, elapsed time: {video_timer.elapsed_time_str}) ...")
103
103
  self.timer.stop_timer()
104
- stdout_success(msg=f"Machine predictions complete. Files saved in {self.save_dir} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
104
+ stdout_success(msg=f"Machine predictions complete for {len(self.feature_file_paths)} file(s). Files saved in {self.save_dir} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
105
105
 
106
106
  if __name__ == "__main__" and not hasattr(sys, 'ps1'):
107
107
  parser = argparse.ArgumentParser(description="Perform classifications according to rules defined in SImAB project_config.ini.")
simba/model/yolo_fit.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import os
2
2
  import sys
3
+ from contextlib import redirect_stderr, redirect_stdout
3
4
 
4
5
  os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
5
6
  import argparse
@@ -21,7 +22,8 @@ from simba.utils.checks import (check_file_exist_and_readable,
21
22
  check_valid_boolean, check_valid_device)
22
23
  from simba.utils.enums import Options
23
24
  from simba.utils.errors import SimBAGPUError, SimBAPAckageVersionError
24
- from simba.utils.read_write import find_core_cnt
25
+ from simba.utils.printing import stdout_information
26
+ from simba.utils.read_write import find_core_cnt, get_current_time
25
27
  from simba.utils.yolo import load_yolo_model
26
28
 
27
29
 
@@ -108,20 +110,25 @@ class FitYolo():
108
110
 
109
111
 
110
112
  def run(self):
111
- model = load_yolo_model(weights_path=self.weights_path,
112
- verbose=self.verbose,
113
- format=self.format,
114
- device=self.device)
115
-
116
- model.train(data=self.model_yaml,
117
- epochs=self.epochs,
118
- project=self.save_path,
119
- batch=self.batch,
120
- plots=self.plots,
121
- imgsz=self.imgsz,
122
- workers=self.workers,
123
- device=self.device,
124
- patience=self.patience)
113
+ # Temporarily redirect stdout/stderr to terminal to ensure ultralytics output goes to terminal
114
+ # sys.__stdout__ and sys.__stderr__ are the original terminal streams
115
+ stdout_information(msg=f'[{get_current_time()}] Please follow the YOLO pose model training in the terminal from where SimBA was launched ...', source=self.__class__.__name__)
116
+ stdout_information(msg=f'[{get_current_time()}] Results will be stored in the {self.save_path} directory ..', source=self.__class__.__name__)
117
+ with redirect_stdout(sys.__stdout__), redirect_stderr(sys.__stderr__):
118
+ model = load_yolo_model(weights_path=self.weights_path,
119
+ verbose=self.verbose,
120
+ format=self.format,
121
+ device=self.device)
122
+
123
+ model.train(data=self.model_yaml,
124
+ epochs=self.epochs,
125
+ project=self.save_path,
126
+ batch=self.batch,
127
+ plots=self.plots,
128
+ imgsz=self.imgsz,
129
+ workers=self.workers,
130
+ device=self.device,
131
+ patience=self.patience)
125
132
 
126
133
 
127
134
  if __name__ == "__main__" and not hasattr(sys, 'ps1'):
simba/model/yolo_pose_inference.py CHANGED
@@ -34,7 +34,7 @@ from simba.utils.errors import (CountError, InvalidFilepathError,
34
34
  InvalidFileTypeError, SimBAGPUError,
35
35
  SimBAPAckageVersionError)
36
36
  from simba.utils.lookups import get_current_time
37
- from simba.utils.printing import SimbaTimer, stdout_success
37
+ from simba.utils.printing import SimbaTimer, stdout_information, stdout_success
38
38
  from simba.utils.read_write import (find_files_of_filetypes_in_directory,
39
39
  get_video_meta_data, recursive_file_search)
40
40
  from simba.utils.warnings import FileExistWarning, NoDataFoundWarning
@@ -182,7 +182,12 @@ class YOLOPoseInference():
182
182
  results = {}
183
183
  class_dict = self.model.names
184
184
  timer = SimbaTimer(start=True)
185
- print(f'Starting tracking inference for {len(self.video_path)} video(s) ({get_current_time()})... ')
185
+ if self.save_dir is not None:
186
+ msg = f'[{get_current_time()}] Starting tracking inference for {len(self.video_path)} video(s). Results will be saved in {self.save_dir} ... '
187
+ else:
188
+ msg = f'[{get_current_time()}] Starting tracking inference for {len(self.video_path)} video(s) ... '
189
+ stdout_information(msg=msg, source=self.__class__.__name__)
190
+ stdout_information(msg='Follow progress in OS terminal window ...', source=self.__class__.__name__)
186
191
  for video_cnt, path in enumerate(self.video_path):
187
192
  video_timer = SimbaTimer(start=True)
188
193
  _, video_name, _ = get_fn_ext(filepath=path)
simba/outlier_tools/skip_outlier_correction.py CHANGED
@@ -5,7 +5,7 @@ from typing import Union
5
5
 
6
6
  from simba.mixins.config_reader import ConfigReader
7
7
  from simba.utils.checks import check_if_filepath_list_is_empty
8
- from simba.utils.printing import SimbaTimer, stdout_success
8
+ from simba.utils.printing import SimbaTimer, stdout_information, stdout_success
9
9
  from simba.utils.read_write import get_fn_ext, read_df, write_df
10
10
 
11
11
 
@@ -43,7 +43,7 @@ class OutlierCorrectionSkipper(ConfigReader):
43
43
  save_path = os.path.join(self.outlier_corrected_dir, f"{video_name}.{self.file_type}")
44
44
  write_df(df=data_df, file_type=self.file_type, save_path=save_path)
45
45
  video_timer.stop_timer()
46
- print(f"Skipped outlier correction for video {video_name} (elapsed time {video_timer.elapsed_time_str}s)...")
46
+ stdout_information(msg=f"Skipped outlier correction for video {video_name} (Video {file_cnt+1}/{len(self.input_csv_paths)})", elapsed_time=video_timer.elapsed_time_str)
47
47
  self.timer.stop_timer()
48
48
  stdout_success(msg=f"Skipped outlier correction for {len(self.input_csv_paths)} files", elapsed_time=self.timer.elapsed_time_str)
49
49
 
simba/plotting/heat_mapper_clf_mp.py CHANGED
@@ -17,7 +17,7 @@ from simba.utils.checks import (
17
17
  check_all_file_names_are_represented_in_video_log,
18
18
  check_filepaths_in_iterable_exist, check_int, check_str,
19
19
  check_valid_boolean, check_valid_dataframe, check_valid_dict)
20
- from simba.utils.data import terminate_cpu_pool
20
+ from simba.utils.data import get_cpu_pool, terminate_cpu_pool
21
21
  from simba.utils.enums import Formats
22
22
  from simba.utils.errors import InvalidInputError, NoSpecifiedOutputError
23
23
  from simba.utils.printing import SimbaTimer, stdout_success
@@ -149,6 +149,7 @@ class HeatMapperClfMultiprocess(ConfigReader, PlottingMixin):
149
149
  def run(self):
150
150
  print(f"Processing {len(self.data_paths)} video(s)...")
151
151
  check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.data_paths)
152
+ pool = get_cpu_pool(core_cnt=self.core_cnt, source=self.__class__.__name__)
152
153
  for file_cnt, file_path in enumerate(self.data_paths):
153
154
  video_timer = SimbaTimer(start=True)
154
155
  _, self.video_name, _ = get_fn_ext(file_path)
@@ -173,7 +174,8 @@ class HeatMapperClfMultiprocess(ConfigReader, PlottingMixin):
173
174
  if len(np.unique(clf_data)) == 1:
174
175
  raise InvalidInputError(msg=f'Cannot plot heatmap for behavior {self.clf_name} in video {self.video_name}. The behavior is classified as {np.unique(clf_data)} in every single frame.')
175
176
  grid, aspect_ratio = GeometryMixin.bucket_img_into_grid_square(img_size=(self.width, self.height), bucket_grid_size_mm=self.bin_size, px_per_mm=self.px_per_mm, add_correction=False, verbose=False)
176
- clf_data = GeometryMixin().cumsum_bool_geometries(data=bp_data, geometries=grid, bool_data=clf_data, fps=self.fps, verbose=False)
177
+
178
+ clf_data = GeometryMixin().cumsum_bool_geometries(data=bp_data, geometries=grid, bool_data=clf_data, fps=self.fps, verbose=False, core_cnt=self.core_cnt, pool=pool)
177
179
  if self.max_scale == "auto":
178
180
  self.max_scale = max(1, self.__calculate_max_scale(clf_array=clf_data))
179
181
  if self.final_img_setting:
@@ -197,32 +199,29 @@ class HeatMapperClfMultiprocess(ConfigReader, PlottingMixin):
197
199
  frm_per_core_w_batch.append((batch_cnt, frm_range, frame_arrays[batch_cnt]))
198
200
  del frame_arrays
199
201
  print(f"Creating heatmaps, multiprocessing (chunksize: {self.multiprocess_chunksize}, cores: {self.core_cnt})...")
200
- with multiprocessing.Pool(self.core_cnt, maxtasksperchild=self.maxtasksperchild) as pool:
201
- constants = functools.partial(_heatmap_multiprocessor,
202
- video_setting=self.video_setting,
203
- frame_setting=self.frame_setting,
204
- style_attr=self.style_attr,
205
- fps=self.fps,
206
- video_temp_dir=self.temp_folder,
207
- frame_dir=self.frames_save_dir,
208
- max_scale=self.max_scale,
209
- aspect_ratio=aspect_ratio,
210
- clf_name=self.clf_name,
211
- size=(self.width, self.height),
212
- video_name=self.video_name,
213
- make_clf_heatmap_plot=self.make_clf_heatmap_plot)
214
-
215
- for cnt, batch in enumerate(pool.imap(constants, frm_per_core_w_batch, chunksize=self.multiprocess_chunksize)):
216
- print(f'Batch core {batch+1}/{self.core_cnt} complete (Video {self.video_name})... ')
217
- terminate_cpu_pool(pool=pool, force=False)
218
-
202
+ constants = functools.partial(_heatmap_multiprocessor,
203
+ video_setting=self.video_setting,
204
+ frame_setting=self.frame_setting,
205
+ style_attr=self.style_attr,
206
+ fps=self.fps,
207
+ video_temp_dir=self.temp_folder,
208
+ frame_dir=self.frames_save_dir,
209
+ max_scale=self.max_scale,
210
+ aspect_ratio=aspect_ratio,
211
+ clf_name=self.clf_name,
212
+ size=(self.width, self.height),
213
+ video_name=self.video_name,
214
+ make_clf_heatmap_plot=self.make_clf_heatmap_plot)
215
+
216
+
217
+ for cnt, batch in enumerate(pool.imap(constants, frm_per_core_w_batch, chunksize=self.multiprocess_chunksize)):
218
+ print(f'Batch core {batch+1}/{self.core_cnt} complete (Video {self.video_name})... ')
219
219
  if self.video_setting:
220
220
  print(f"Joining {self.video_name} multiprocessed video...")
221
221
  concatenate_videos_in_folder(in_folder=self.temp_folder, save_path=self.save_video_path)
222
-
223
222
  video_timer.stop_timer()
224
223
  print(f"Heatmap video {self.video_name} complete, (elapsed time: {video_timer.elapsed_time_str}s) ...")
225
-
224
+ terminate_cpu_pool(pool=pool, force=False, source=self.__class__.__name__)
226
225
  self.timer.stop_timer()
227
226
  stdout_success(msg=f"Heatmap visualizations for {len(self.data_paths)} video(s) created in {self.heatmap_clf_location_dir} directory", elapsed_time=self.timer.elapsed_time_str)
228
227
 
@@ -261,3 +260,26 @@ class HeatMapperClfMultiprocess(ConfigReader, PlottingMixin):
261
260
  # core_cnt=5,
262
261
  # files_found=['/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/csv/machine_results/Together_1.csv'])
263
262
  # test.create_heatmaps()
263
+ # if __name__ == "__main__":
264
+ # x = HeatMapperClfMultiprocess(config_path=r"E:\troubleshooting\mitra_emergence\project_folder\project_config.ini",
265
+ # bodypart='nose',
266
+ # clf_name='GROOMING',
267
+ # style_attr={'palette': 'jet', 'shading': 'gouraud', 'bin_size': 25, 'max_scale': 'auto'},
268
+ # final_img_setting=True,
269
+ # video_setting=False,
270
+ # frame_setting=False,
271
+ # core_cnt=12,
272
+ # data_paths=[r"E:\troubleshooting\mitra_emergence\project_folder\csv\machine_results\Box1_180mISOcontrol_Females.csv"])
273
+ #
274
+ # x.run()
275
+
276
+ # def __init__(self,
277
+ # config_path: Union[str, os.PathLike],
278
+ # bodypart: str,
279
+ # clf_name: str,
280
+ # data_paths: List[str],
281
+ # style_attr: dict,
282
+ # final_img_setting: bool = True,
283
+ # video_setting: bool = False,
284
+ # frame_setting: bool = False,
285
+ # core_cnt: int = -1):
simba/plotting/plot_clf_results.py CHANGED
@@ -236,7 +236,8 @@ class PlotSklearnResultsSingleCore(ConfigReader, TrainModelMixin, PlottingMixin)
236
236
  self.add_spacer += 1
237
237
  if self.show_confidence:
238
238
  col_name = f'Probability_{clf_name}'
239
- conf_txt = f'{clf_name} CONFIDENCE {self.data_df.loc[frm_idx, col_name]}'
239
+ conf = round(self.data_df.loc[frm_idx, col_name], 4)
240
+ conf_txt = f'{clf_name} CONFIDENCE {conf:.4f}'
240
241
  self.frame = PlottingMixin().put_text(img=self.frame, text=conf_txt, pos=(TextOptions.BORDER_BUFFER_Y.value, ((self.video_meta_data["height"] - self.video_meta_data["height"]) + self.video_space_size * self.add_spacer)), font_size=self.video_font_size, font_thickness=self.video_text_thickness, font=self.font, text_bg_alpha=self.video_text_opacity, text_color_bg=self.text_bg_color, text_color=self.text_color)
241
242
  self.add_spacer += 1
242
243
  self.frame = PlottingMixin().put_text(img=self.frame, text="ENSEMBLE PREDICTION:", pos=(TextOptions.BORDER_BUFFER_Y.value, ((self.video_meta_data["height"] - self.video_meta_data["height"]) + self.video_space_size * self.add_spacer)), font_size=self.video_font_size, font_thickness=self.video_text_thickness, font=self.font, text_bg_alpha=self.video_text_opacity, text_color_bg=self.text_bg_color, text_color=self.text_color)