simba-uw-tf-dev 4.6.2__py3-none-any.whl → 4.6.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of simba-uw-tf-dev might be problematic. Click here for more details.
- simba/assets/lookups/tooptips.json +6 -1
- simba/data_processors/agg_clf_counter_mp.py +52 -53
- simba/data_processors/cuda/image.py +3 -1
- simba/data_processors/cue_light_analyzer.py +5 -9
- simba/data_processors/kleinberg_calculator.py +57 -29
- simba/mixins/geometry_mixin.py +14 -28
- simba/mixins/image_mixin.py +10 -14
- simba/mixins/train_model_mixin.py +2 -2
- simba/plotting/ROI_feature_visualizer_mp.py +3 -5
- simba/plotting/clf_validator_mp.py +4 -5
- simba/plotting/cue_light_visualizer.py +6 -7
- simba/plotting/directing_animals_visualizer_mp.py +2 -3
- simba/plotting/distance_plotter_mp.py +378 -378
- simba/plotting/gantt_creator_mp.py +61 -31
- simba/plotting/geometry_plotter.py +270 -272
- simba/plotting/heat_mapper_clf_mp.py +2 -4
- simba/plotting/heat_mapper_location_mp.py +2 -2
- simba/plotting/light_dark_box_plotter.py +2 -2
- simba/plotting/path_plotter_mp.py +26 -29
- simba/plotting/plot_clf_results_mp.py +455 -454
- simba/plotting/pose_plotter_mp.py +28 -29
- simba/plotting/probability_plot_creator_mp.py +288 -288
- simba/plotting/roi_plotter_mp.py +31 -31
- simba/plotting/single_run_model_validation_video_mp.py +427 -427
- simba/plotting/spontaneous_alternation_plotter.py +2 -3
- simba/plotting/yolo_pose_track_visualizer.py +32 -27
- simba/plotting/yolo_pose_visualizer.py +35 -36
- simba/plotting/yolo_seg_visualizer.py +2 -3
- simba/roi_tools/roi_aggregate_stats_mp.py +4 -3
- simba/roi_tools/roi_clf_calculator_mp.py +3 -3
- simba/sandbox/get_cpu_pool.py +5 -0
- simba/ui/pop_ups/kleinberg_pop_up.py +39 -41
- simba/ui/tkinter_functions.py +3 -0
- simba/utils/data.py +89 -12
- simba/utils/enums.py +1 -0
- simba/utils/printing.py +124 -124
- simba/utils/read_write.py +3730 -3721
- simba/video_processors/egocentric_video_rotator.py +2 -4
- simba/video_processors/video_processing.py +19 -8
- simba/video_processors/videos_to_frames.py +1 -1
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.4.dist-info}/METADATA +1 -1
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.4.dist-info}/RECORD +46 -45
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.4.dist-info}/LICENSE +0 -0
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.4.dist-info}/WHEEL +0 -0
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.4.dist-info}/entry_points.txt +0 -0
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.6.4.dist-info}/top_level.txt +0 -0
|
@@ -1,32 +1,37 @@
|
|
|
1
1
|
__author__ = "Simon Nilsson; sronilsson@gmail.com"
|
|
2
2
|
|
|
3
|
-
import time
|
|
4
3
|
import warnings
|
|
5
4
|
|
|
6
5
|
warnings.simplefilter(action="ignore", category=FutureWarning)
|
|
7
6
|
import functools
|
|
7
|
+
import gc
|
|
8
8
|
import multiprocessing
|
|
9
9
|
import os
|
|
10
10
|
import platform
|
|
11
|
+
from copy import deepcopy
|
|
11
12
|
from typing import List, Optional, Union
|
|
12
13
|
|
|
13
14
|
import cv2
|
|
15
|
+
import matplotlib
|
|
14
16
|
import numpy as np
|
|
15
17
|
import pandas as pd
|
|
16
18
|
|
|
19
|
+
matplotlib.use('Agg')
|
|
20
|
+
|
|
17
21
|
from simba.mixins.config_reader import ConfigReader
|
|
18
22
|
from simba.mixins.plotting_mixin import PlottingMixin
|
|
19
23
|
from simba.utils.checks import (
|
|
20
24
|
check_all_file_names_are_represented_in_video_log,
|
|
21
25
|
check_file_exist_and_readable, check_int, check_str,
|
|
22
26
|
check_that_column_exist, check_valid_boolean, check_valid_lst)
|
|
23
|
-
from simba.utils.data import create_color_palette, detect_bouts
|
|
27
|
+
from simba.utils.data import (create_color_palette, detect_bouts, get_cpu_pool,
|
|
28
|
+
terminate_cpu_pool)
|
|
24
29
|
from simba.utils.enums import Formats, Options
|
|
25
30
|
from simba.utils.errors import NoSpecifiedOutputError
|
|
26
31
|
from simba.utils.printing import SimbaTimer, stdout_success
|
|
27
32
|
from simba.utils.read_write import (concatenate_videos_in_folder,
|
|
28
33
|
create_directory, find_core_cnt,
|
|
29
|
-
get_fn_ext, read_df)
|
|
34
|
+
get_current_time, get_fn_ext, read_df)
|
|
30
35
|
|
|
31
36
|
HEIGHT = "height"
|
|
32
37
|
WIDTH = "width"
|
|
@@ -78,11 +83,17 @@ def gantt_creator_mp(data: np.array,
|
|
|
78
83
|
cv2.imwrite(frame_save_path, plot)
|
|
79
84
|
if video_setting:
|
|
80
85
|
video_writer.write(plot)
|
|
81
|
-
|
|
86
|
+
# Clear memory after each frame
|
|
87
|
+
del plot
|
|
88
|
+
if current_frm % 100 == 0: # Periodic garbage collection to prevent memory buildup
|
|
89
|
+
gc.collect()
|
|
90
|
+
print(f"[{get_current_time()}] Gantt frame created: {current_frm + 1}, Video: {video_name}, Processing core: {batch_id + 1}")
|
|
82
91
|
|
|
83
92
|
if video_setting:
|
|
84
93
|
video_writer.release()
|
|
94
|
+
del video_writer
|
|
85
95
|
|
|
96
|
+
gc.collect()
|
|
86
97
|
return batch_id
|
|
87
98
|
|
|
88
99
|
|
|
@@ -120,7 +131,7 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
|
|
|
120
131
|
|
|
121
132
|
def __init__(self,
|
|
122
133
|
config_path: Union[str, os.PathLike],
|
|
123
|
-
data_paths: List[Union[str, os.PathLike]],
|
|
134
|
+
data_paths: Optional[Union[Union[str, os.PathLike], List[Union[str, os.PathLike]]]] = None,
|
|
124
135
|
frame_setting: Optional[bool] = False,
|
|
125
136
|
video_setting: Optional[bool] = False,
|
|
126
137
|
last_frm_setting: Optional[bool] = True,
|
|
@@ -129,14 +140,13 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
|
|
|
129
140
|
font_size: int = 8,
|
|
130
141
|
font_rotation: int = 45,
|
|
131
142
|
palette: str = 'Set1',
|
|
132
|
-
core_cnt:
|
|
143
|
+
core_cnt: int = -1,
|
|
133
144
|
hhmmss: bool = False):
|
|
134
145
|
|
|
135
146
|
check_file_exist_and_readable(file_path=config_path)
|
|
136
147
|
if (not frame_setting) and (not video_setting) and (not last_frm_setting):
|
|
137
148
|
raise NoSpecifiedOutputError(msg="SIMBA ERROR: Please select gantt videos, frames, and/or last frame.", source=self.__class__.__name__)
|
|
138
149
|
check_file_exist_and_readable(file_path=config_path)
|
|
139
|
-
check_valid_lst(data=data_paths, source=f'{self.__class__.__name__} data_paths', valid_dtypes=(str,), min_len=1)
|
|
140
150
|
check_int(value=width, min_value=1, name=f'{self.__class__.__name__} width')
|
|
141
151
|
check_int(value=height, min_value=1, name=f'{self.__class__.__name__} height')
|
|
142
152
|
check_int(value=font_size, min_value=1, name=f'{self.__class__.__name__} font_size')
|
|
@@ -144,11 +154,18 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
|
|
|
144
154
|
check_valid_boolean(value=hhmmss, source=f'{self.__class__.__name__} hhmmss', raise_error=False)
|
|
145
155
|
palettes = Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value
|
|
146
156
|
check_str(name=f'{self.__class__.__name__} palette', value=palette, options=palettes)
|
|
147
|
-
for file_path in data_paths: check_file_exist_and_readable(file_path=file_path)
|
|
148
157
|
check_int(name=f"{self.__class__.__name__} core_cnt",value=core_cnt, min_value=-1, unaccepted_vals=[0], max_value=find_core_cnt()[0])
|
|
149
158
|
self.core_cnt = find_core_cnt()[0] if core_cnt == -1 or core_cnt > find_core_cnt()[0] else core_cnt
|
|
150
159
|
self.width, self.height, self.font_size, self.font_rotation, self.hhmmss = width, height, font_size, font_rotation, hhmmss
|
|
151
160
|
ConfigReader.__init__(self, config_path=config_path, create_logger=False)
|
|
161
|
+
if isinstance(data_paths, list):
|
|
162
|
+
check_valid_lst(data=data_paths, source=f'{self.__class__.__name__} data_paths', valid_dtypes=(str,), min_len=1)
|
|
163
|
+
elif isinstance(data_paths, str):
|
|
164
|
+
check_file_exist_and_readable(file_path=data_paths)
|
|
165
|
+
data_paths = [data_paths]
|
|
166
|
+
else:
|
|
167
|
+
data_paths = deepcopy(self.machine_results_paths)
|
|
168
|
+
for file_path in data_paths: check_file_exist_and_readable(file_path=file_path)
|
|
152
169
|
PlottingMixin.__init__(self)
|
|
153
170
|
self.clr_lst = create_color_palette(pallete_name=palette, increments=len(self.body_parts_lst) + 1, as_int=True, as_rgb_ratio=True)
|
|
154
171
|
self.frame_setting, self.video_setting, self.data_paths, self.last_frm_setting = frame_setting, video_setting,data_paths, last_frm_setting
|
|
@@ -159,6 +176,10 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
|
|
|
159
176
|
|
|
160
177
|
def run(self):
|
|
161
178
|
check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.data_paths)
|
|
179
|
+
if self.video_setting or self.frame_setting:
|
|
180
|
+
self.pool = get_cpu_pool(core_cnt=self.core_cnt, maxtasksperchild=self.maxtasksperchild, verbose=True, source=self.__class__.__name__)
|
|
181
|
+
else:
|
|
182
|
+
self.pool = None
|
|
162
183
|
for file_cnt, file_path in enumerate(self.data_paths):
|
|
163
184
|
video_timer = SimbaTimer(start=True)
|
|
164
185
|
_, self.video_name, _ = get_fn_ext(file_path)
|
|
@@ -192,33 +213,31 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
|
|
|
192
213
|
if self.video_setting or self.frame_setting:
|
|
193
214
|
frame_data = np.array_split(list(range(0, len(self.data_df))), self.core_cnt)
|
|
194
215
|
frame_data = [(i, x) for i, x in enumerate(frame_data)]
|
|
195
|
-
print(f"Creating gantt, multiprocessing (chunksize: {(self.multiprocess_chunksize)}, cores: {self.core_cnt})...")
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
print(f'Batch {result+1/self.core_cnt} complete...')
|
|
214
|
-
pool.terminate()
|
|
215
|
-
pool.join()
|
|
216
|
+
print(f"[{get_current_time()}] Creating gantt, multiprocessing (chunksize: {(self.multiprocess_chunksize)}, cores: {self.core_cnt})...")
|
|
217
|
+
constants = functools.partial(gantt_creator_mp,
|
|
218
|
+
video_setting=self.video_setting,
|
|
219
|
+
frame_setting=self.frame_setting,
|
|
220
|
+
video_save_dir=self.temp_folder,
|
|
221
|
+
frame_folder_dir=self.save_frame_folder_dir,
|
|
222
|
+
bouts_df=self.bouts_df,
|
|
223
|
+
clf_names=self.clf_names,
|
|
224
|
+
fps=self.fps,
|
|
225
|
+
width=self.width,
|
|
226
|
+
height=self.height,
|
|
227
|
+
font_size=self.font_size,
|
|
228
|
+
font_rotation=self.font_rotation,
|
|
229
|
+
video_name=self.video_name,
|
|
230
|
+
palette=self.clr_lst,
|
|
231
|
+
hhmmss=self.hhmmss)
|
|
232
|
+
for cnt, result in enumerate(self.pool.imap(constants, frame_data, chunksize=self.multiprocess_chunksize)):
|
|
233
|
+
print(f'[{get_current_time()}] Batch {result+1}/{self.core_cnt} complete...')
|
|
216
234
|
if self.video_setting:
|
|
217
|
-
print(f"Joining {self.video_name} multiprocessed video...")
|
|
235
|
+
print(f"[{get_current_time()}] Joining {self.video_name} multiprocessed video...")
|
|
218
236
|
concatenate_videos_in_folder(in_folder=self.temp_folder, save_path=self.save_video_path)
|
|
219
237
|
video_timer.stop_timer()
|
|
220
238
|
print(f"Gantt video {self.video_name} complete (elapsed time: {video_timer.elapsed_time_str}s) ...")
|
|
221
239
|
|
|
240
|
+
terminate_cpu_pool(pool=self.pool, force=False, source=self.__class__.__name__)
|
|
222
241
|
self.timer.stop_timer()
|
|
223
242
|
stdout_success(msg=f"Gantt visualizations for {len(self.data_paths)} videos created in {self.gantt_plot_dir} directory", elapsed_time=self.timer.elapsed_time_str)
|
|
224
243
|
|
|
@@ -235,7 +254,18 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
|
|
|
235
254
|
# font_rotation= 45)
|
|
236
255
|
# test.run()
|
|
237
256
|
|
|
238
|
-
|
|
257
|
+
if __name__ == "__main__":
|
|
258
|
+
test = GanttCreatorMultiprocess(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini",
|
|
259
|
+
frame_setting=False,
|
|
260
|
+
video_setting=True,
|
|
261
|
+
data_paths=r"D:\troubleshooting\maplight_ri\project_folder\csv\machine_results\Trial_1_C24_D1_1.csv",
|
|
262
|
+
last_frm_setting=False,
|
|
263
|
+
width=640,
|
|
264
|
+
height= 480,
|
|
265
|
+
font_size=10,
|
|
266
|
+
font_rotation= 45,
|
|
267
|
+
core_cnt=16)
|
|
268
|
+
test.run()
|
|
239
269
|
|
|
240
270
|
|
|
241
271
|
# test = GanttCreatorMultiprocess(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini',
|