simba-uw-tf-dev 4.7.1-py3-none-any.whl → 4.7.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- simba/SimBA.py +13 -4
- simba/assets/icons/left_arrow_green.png +0 -0
- simba/assets/icons/left_arrow_red.png +0 -0
- simba/assets/icons/right_arrow_green.png +0 -0
- simba/assets/icons/right_arrow_red.png +0 -0
- simba/assets/lookups/yolo_schematics/yolo_mitra.csv +9 -0
- simba/mixins/geometry_mixin.py +357 -302
- simba/mixins/image_mixin.py +129 -4
- simba/mixins/train_model_mixin.py +1 -4
- simba/model/inference_batch.py +1 -1
- simba/model/yolo_fit.py +22 -15
- simba/model/yolo_pose_inference.py +7 -2
- simba/outlier_tools/skip_outlier_correction.py +2 -2
- simba/plotting/heat_mapper_clf_mp.py +45 -23
- simba/plotting/plot_clf_results.py +2 -1
- simba/plotting/plot_clf_results_mp.py +456 -455
- simba/roi_tools/roi_utils.py +2 -2
- simba/sandbox/convert_h264_to_mp4_lossless.py +129 -0
- simba/sandbox/extract_and_convert_videos.py +257 -0
- simba/sandbox/remove_end_of_video.py +80 -0
- simba/sandbox/video_timelaps.py +291 -0
- simba/third_party_label_appenders/transform/simba_to_yolo.py +8 -5
- simba/ui/import_pose_frame.py +13 -13
- simba/ui/pop_ups/clf_plot_pop_up.py +1 -1
- simba/ui/pop_ups/run_machine_models_popup.py +22 -22
- simba/ui/pop_ups/simba_to_yolo_keypoints_popup.py +2 -2
- simba/ui/pop_ups/video_processing_pop_up.py +3638 -3469
- simba/ui/pop_ups/yolo_inference_popup.py +1 -1
- simba/ui/pop_ups/yolo_pose_train_popup.py +1 -1
- simba/ui/tkinter_functions.py +3 -1
- simba/ui/video_timelaps.py +454 -0
- simba/utils/lookups.py +67 -1
- simba/utils/read_write.py +10 -3
- simba/video_processors/batch_process_create_ffmpeg_commands.py +0 -1
- simba/video_processors/video_processing.py +160 -39
- {simba_uw_tf_dev-4.7.1.dist-info → simba_uw_tf_dev-4.7.5.dist-info}/METADATA +1 -1
- {simba_uw_tf_dev-4.7.1.dist-info → simba_uw_tf_dev-4.7.5.dist-info}/RECORD +41 -31
- {simba_uw_tf_dev-4.7.1.dist-info → simba_uw_tf_dev-4.7.5.dist-info}/LICENSE +0 -0
- {simba_uw_tf_dev-4.7.1.dist-info → simba_uw_tf_dev-4.7.5.dist-info}/WHEEL +0 -0
- {simba_uw_tf_dev-4.7.1.dist-info → simba_uw_tf_dev-4.7.5.dist-info}/entry_points.txt +0 -0
- {simba_uw_tf_dev-4.7.1.dist-info → simba_uw_tf_dev-4.7.5.dist-info}/top_level.txt +0 -0
@@ -1,456 +1,457 @@
-__author__ = "Simon Nilsson; sronilsson@gmail.com"
-
-import functools
-import multiprocessing
-import os
-import platform
-from copy import deepcopy
-from typing import List, Optional, Tuple, Union
-
-import cv2
-import numpy as np
-import pandas as pd
-
-from simba.mixins.config_reader import ConfigReader
-from simba.mixins.geometry_mixin import GeometryMixin
-from simba.mixins.plotting_mixin import PlottingMixin
-from simba.mixins.train_model_mixin import TrainModelMixin
-from simba.utils.checks import (check_float, check_if_valid_rgb_tuple,
-                                check_int, check_nvidea_gpu_available,
-                                check_str, check_that_column_exist,
-                                check_valid_boolean,
-                                check_video_and_data_frm_count_align)
-from simba.utils.data import (create_color_palette, detect_bouts, get_cpu_pool,
-                              terminate_cpu_pool)
-from simba.utils.enums import ConfigKey, Dtypes, Options, TagNames, TextOptions
-from simba.utils.errors import (InvalidInputError, NoDataError,
-                                NoSpecifiedOutputError)
-from simba.utils.lookups import get_current_time
-from simba.utils.printing import SimbaTimer, log_event, stdout_success
-from simba.utils.read_write import (concatenate_videos_in_folder,
-                                    create_directory,
-                                    find_all_videos_in_project, find_core_cnt,
-                                    get_fn_ext, get_video_meta_data,
-                                    read_config_entry, read_df)
-from simba.utils.warnings import FrameRangeWarning
-
-
-def _multiprocess_sklearn_video(data: pd.DataFrame,
-                                bp_dict: dict,
-                                video_save_dir: str,
-                                frame_save_dir: str,
-                                clf_cumsum: dict,
-                                rotate: bool,
-                                video_path: str,
-                                print_timers: bool,
-                                video_setting: bool,
-                                frame_setting: bool,
-                                pose_threshold: float,
-                                clf_confidence: Union[dict, None],
-                                show_pose: bool,
-                                show_animal_names: bool,
-                                show_bbox: bool,
-                                circle_size: int,
-                                font_size: int,
-                                space_size: int,
-                                text_thickness: int,
-                                text_opacity: float,
-                                text_bg_clr: Tuple[int, int, int],
-                                text_color: Tuple[int, int, int],
-                                pose_clr_lst: List[Tuple[int, int, int]],
-                                show_gantt: Optional[int],
-                                bouts_df: Optional[pd.DataFrame],
-                                final_gantt: Optional[np.ndarray],
-                                gantt_clrs: List[Tuple[float, float, float]],
-                                clf_names: List[str],
-                                verbose:bool):
-
-    fourcc, font = cv2.VideoWriter_fourcc(*"mp4v"), cv2.FONT_HERSHEY_DUPLEX
-    video_meta_data = get_video_meta_data(video_path=video_path)
-    if rotate:
-        video_meta_data["height"], video_meta_data["width"] = (video_meta_data['width'], video_meta_data['height'])
-    cap = cv2.VideoCapture(video_path)
-    batch, data = data
-    start_frm, current_frm, end_frm = (data["index"].iloc[0], data["index"].iloc[0], data["index"].iloc[-1])
-    if video_setting:
-        video_save_path = os.path.join(video_save_dir, f"{batch}.mp4")
-        if show_gantt is None:
-            video_writer = cv2.VideoWriter(video_save_path, fourcc, video_meta_data["fps"], (video_meta_data["width"], video_meta_data["height"]))
-        else:
-            video_writer = cv2.VideoWriter(video_save_path, fourcc, video_meta_data["fps"], (int(video_meta_data["width"] + final_gantt.shape[1]), video_meta_data["height"]))
-    cap.set(1, start_frm)
-    while current_frm < end_frm:
-        ret, img = cap.read()
-        if ret:
-            clr_cnt = 0
-            for animal_cnt, (animal_name, animal_data) in enumerate(bp_dict.items()):
-                if show_pose:
-                    for bp_no in range(len(animal_data["X_bps"])):
-                        x_bp, y_bp, p_bp = (animal_data["X_bps"][bp_no], animal_data["Y_bps"][bp_no], animal_data["P_bps"][bp_no])
-                        bp_cords = data.loc[current_frm, [x_bp, y_bp, p_bp]]
-                        if bp_cords[p_bp] >= pose_threshold:
-                            img = cv2.circle(img, (int(bp_cords[x_bp]), int(bp_cords[y_bp])), circle_size, pose_clr_lst[clr_cnt], -1)
-                        clr_cnt += 1
-                if show_animal_names:
-                    x_bp, y_bp, p_bp = (animal_data["X_bps"][0], animal_data["Y_bps"][0], animal_data["P_bps"][0])
-                    bp_cords = data.loc[current_frm, [x_bp, y_bp, p_bp]]
-                    img = cv2.putText(img, animal_name, (int(bp_cords[x_bp]), int(bp_cords[y_bp])), font, font_size, pose_clr_lst[0], text_thickness)
-                if show_bbox:
-                    animal_headers = [val for pair in zip(animal_data["X_bps"], animal_data["Y_bps"]) for val in pair]
-                    animal_cords = data.loc[current_frm, animal_headers].values.reshape(-1, 2).astype(np.int32)
-                    try:
-                        bbox = GeometryMixin().keypoints_to_axis_aligned_bounding_box(keypoints=animal_cords.reshape(-1, len(animal_cords), 2).astype(np.int32))
-                        img = cv2.polylines(img, [bbox], True, pose_clr_lst[animal_cnt], thickness=circle_size, lineType=cv2.LINE_AA)
-                    except Exception as e:
-                        #print(e.args)
-                        pass
-            if rotate:
-                img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
-            if show_gantt == 1:
-                img = np.concatenate((img, final_gantt), axis=1)
-            elif show_gantt == 2:
-                bout_rows = bouts_df.loc[bouts_df["End_frame"] <= current_frm]
-                gantt_plot = PlottingMixin().make_gantt_plot(x_length=current_frm + 1,
-                                                             bouts_df=bout_rows,
-                                                             clf_names=clf_names,
-                                                             fps=video_meta_data['fps'],
-                                                             width=video_meta_data['width'],
-                                                             height=video_meta_data['height'],
-                                                             font_size=12,
-                                                             font_rotation=90,
-                                                             video_name=video_meta_data['video_name'],
-                                                             save_path=None,
-                                                             palette=gantt_clrs)
-                img = np.concatenate((img, gantt_plot), axis=1)
-            if print_timers:
-                img = PlottingMixin().put_text(img=img, text="TIMERS:", pos=(TextOptions.BORDER_BUFFER_Y.value, ((video_meta_data["height"] - video_meta_data["height"]) + space_size)), font_size=font_size, font_thickness=text_thickness, font=font, text_bg_alpha=text_opacity, text_color_bg=text_bg_clr, text_color=text_color)
-            add_spacer = 2
-            for clf_name, clf_time_df in clf_cumsum.items():
-                frame_results = clf_time_df.loc[current_frm]
-                clf_time = round(frame_results / video_meta_data['fps'], 2)
-                if print_timers:
-                    img = PlottingMixin().put_text(img=img, text=f"{clf_name} {clf_time}",pos=(TextOptions.BORDER_BUFFER_Y.value, ((video_meta_data["height"] - video_meta_data["height"]) + space_size * add_spacer)), font_size=font_size, font_thickness=text_thickness, font=font, text_bg_alpha=text_opacity, text_color_bg=text_bg_clr, text_color=text_color)
-                    add_spacer += 1
-                if clf_confidence is not None:
[removed lines 135-455 are truncated in the registry diff rendering and cannot be reconstructed]
+__author__ = "Simon Nilsson; sronilsson@gmail.com"
+
+import functools
+import multiprocessing
+import os
+import platform
+from copy import deepcopy
+from typing import List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+import pandas as pd
+
+from simba.mixins.config_reader import ConfigReader
+from simba.mixins.geometry_mixin import GeometryMixin
+from simba.mixins.plotting_mixin import PlottingMixin
+from simba.mixins.train_model_mixin import TrainModelMixin
+from simba.utils.checks import (check_float, check_if_valid_rgb_tuple,
+                                check_int, check_nvidea_gpu_available,
+                                check_str, check_that_column_exist,
+                                check_valid_boolean,
+                                check_video_and_data_frm_count_align)
+from simba.utils.data import (create_color_palette, detect_bouts, get_cpu_pool,
+                              terminate_cpu_pool)
+from simba.utils.enums import ConfigKey, Dtypes, Options, TagNames, TextOptions
+from simba.utils.errors import (InvalidInputError, NoDataError,
+                                NoSpecifiedOutputError)
+from simba.utils.lookups import get_current_time
+from simba.utils.printing import SimbaTimer, log_event, stdout_success
+from simba.utils.read_write import (concatenate_videos_in_folder,
+                                    create_directory,
+                                    find_all_videos_in_project, find_core_cnt,
+                                    get_fn_ext, get_video_meta_data,
+                                    read_config_entry, read_df)
+from simba.utils.warnings import FrameRangeWarning
+
+
+def _multiprocess_sklearn_video(data: pd.DataFrame,
+                                bp_dict: dict,
+                                video_save_dir: str,
+                                frame_save_dir: str,
+                                clf_cumsum: dict,
+                                rotate: bool,
+                                video_path: str,
+                                print_timers: bool,
+                                video_setting: bool,
+                                frame_setting: bool,
+                                pose_threshold: float,
+                                clf_confidence: Union[dict, None],
+                                show_pose: bool,
+                                show_animal_names: bool,
+                                show_bbox: bool,
+                                circle_size: int,
+                                font_size: int,
+                                space_size: int,
+                                text_thickness: int,
+                                text_opacity: float,
+                                text_bg_clr: Tuple[int, int, int],
+                                text_color: Tuple[int, int, int],
+                                pose_clr_lst: List[Tuple[int, int, int]],
+                                show_gantt: Optional[int],
+                                bouts_df: Optional[pd.DataFrame],
+                                final_gantt: Optional[np.ndarray],
+                                gantt_clrs: List[Tuple[float, float, float]],
+                                clf_names: List[str],
+                                verbose:bool):
+
+    fourcc, font = cv2.VideoWriter_fourcc(*"mp4v"), cv2.FONT_HERSHEY_DUPLEX
+    video_meta_data = get_video_meta_data(video_path=video_path)
+    if rotate:
+        video_meta_data["height"], video_meta_data["width"] = (video_meta_data['width'], video_meta_data['height'])
+    cap = cv2.VideoCapture(video_path)
+    batch, data = data
+    start_frm, current_frm, end_frm = (data["index"].iloc[0], data["index"].iloc[0], data["index"].iloc[-1])
+    if video_setting:
+        video_save_path = os.path.join(video_save_dir, f"{batch}.mp4")
+        if show_gantt is None:
+            video_writer = cv2.VideoWriter(video_save_path, fourcc, video_meta_data["fps"], (video_meta_data["width"], video_meta_data["height"]))
+        else:
+            video_writer = cv2.VideoWriter(video_save_path, fourcc, video_meta_data["fps"], (int(video_meta_data["width"] + final_gantt.shape[1]), video_meta_data["height"]))
+    cap.set(1, start_frm)
+    while current_frm < end_frm:
+        ret, img = cap.read()
+        if ret:
+            clr_cnt = 0
+            for animal_cnt, (animal_name, animal_data) in enumerate(bp_dict.items()):
+                if show_pose:
+                    for bp_no in range(len(animal_data["X_bps"])):
+                        x_bp, y_bp, p_bp = (animal_data["X_bps"][bp_no], animal_data["Y_bps"][bp_no], animal_data["P_bps"][bp_no])
+                        bp_cords = data.loc[current_frm, [x_bp, y_bp, p_bp]]
+                        if bp_cords[p_bp] >= pose_threshold:
+                            img = cv2.circle(img, (int(bp_cords[x_bp]), int(bp_cords[y_bp])), circle_size, pose_clr_lst[clr_cnt], -1)
+                        clr_cnt += 1
+                if show_animal_names:
+                    x_bp, y_bp, p_bp = (animal_data["X_bps"][0], animal_data["Y_bps"][0], animal_data["P_bps"][0])
+                    bp_cords = data.loc[current_frm, [x_bp, y_bp, p_bp]]
+                    img = cv2.putText(img, animal_name, (int(bp_cords[x_bp]), int(bp_cords[y_bp])), font, font_size, pose_clr_lst[0], text_thickness)
+                if show_bbox:
+                    animal_headers = [val for pair in zip(animal_data["X_bps"], animal_data["Y_bps"]) for val in pair]
+                    animal_cords = data.loc[current_frm, animal_headers].values.reshape(-1, 2).astype(np.int32)
+                    try:
+                        bbox = GeometryMixin().keypoints_to_axis_aligned_bounding_box(keypoints=animal_cords.reshape(-1, len(animal_cords), 2).astype(np.int32))
+                        img = cv2.polylines(img, [bbox], True, pose_clr_lst[animal_cnt], thickness=circle_size, lineType=cv2.LINE_AA)
+                    except Exception as e:
+                        #print(e.args)
+                        pass
+            if rotate:
+                img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
+            if show_gantt == 1:
+                img = np.concatenate((img, final_gantt), axis=1)
+            elif show_gantt == 2:
+                bout_rows = bouts_df.loc[bouts_df["End_frame"] <= current_frm]
+                gantt_plot = PlottingMixin().make_gantt_plot(x_length=current_frm + 1,
+                                                             bouts_df=bout_rows,
+                                                             clf_names=clf_names,
+                                                             fps=video_meta_data['fps'],
+                                                             width=video_meta_data['width'],
+                                                             height=video_meta_data['height'],
+                                                             font_size=12,
+                                                             font_rotation=90,
+                                                             video_name=video_meta_data['video_name'],
+                                                             save_path=None,
+                                                             palette=gantt_clrs)
+                img = np.concatenate((img, gantt_plot), axis=1)
+            if print_timers:
+                img = PlottingMixin().put_text(img=img, text="TIMERS:", pos=(TextOptions.BORDER_BUFFER_Y.value, ((video_meta_data["height"] - video_meta_data["height"]) + space_size)), font_size=font_size, font_thickness=text_thickness, font=font, text_bg_alpha=text_opacity, text_color_bg=text_bg_clr, text_color=text_color)
+            add_spacer = 2
+            for clf_name, clf_time_df in clf_cumsum.items():
+                frame_results = clf_time_df.loc[current_frm]
+                clf_time = round(frame_results / video_meta_data['fps'], 2)
+                if print_timers:
+                    img = PlottingMixin().put_text(img=img, text=f"{clf_name} {clf_time}",pos=(TextOptions.BORDER_BUFFER_Y.value, ((video_meta_data["height"] - video_meta_data["height"]) + space_size * add_spacer)), font_size=font_size, font_thickness=text_thickness, font=font, text_bg_alpha=text_opacity, text_color_bg=text_bg_clr, text_color=text_color)
+                    add_spacer += 1
+                if clf_confidence is not None:
+                    conf = round(clf_confidence[clf_name][current_frm], 4)
+                    frm_clf_conf_txt = f'{clf_name} CONFIDENCE: {conf:.4f}'
+                    img = PlottingMixin().put_text(img=img, text=frm_clf_conf_txt,pos=(TextOptions.BORDER_BUFFER_Y.value, ((video_meta_data["height"] - video_meta_data["height"]) + space_size * add_spacer)), font_size=font_size, font_thickness=text_thickness, font=font, text_bg_alpha=text_opacity, text_color_bg=text_bg_clr, text_color=text_color)
+                    add_spacer += 1
+
+            img = PlottingMixin().put_text(img=img, text="ENSEMBLE PREDICTION:", pos=(TextOptions.BORDER_BUFFER_Y.value, ((video_meta_data["height"] - video_meta_data["height"]) + space_size * add_spacer)), font_size=font_size, font_thickness=text_thickness, font=font, text_bg_alpha=text_opacity, text_color_bg=text_bg_clr, text_color=text_color)
+            add_spacer += 1
+            for clf_name in clf_cumsum.keys():
+                if data.loc[current_frm, clf_name] == 1:
+                    img = PlottingMixin().put_text(img=img, text=clf_name, pos=(TextOptions.BORDER_BUFFER_Y.value, (video_meta_data["height"] - video_meta_data["height"]) + space_size * add_spacer), font_size=font_size, font_thickness=text_thickness, font=font, text_color=TextOptions.COLOR.value, text_bg_alpha=text_opacity)
+                    add_spacer += 1
+            if video_setting:
+                video_writer.write(img.astype(np.uint8))
+            if frame_setting:
+                frame_save_name = os.path.join(frame_save_dir, f"{current_frm}.png")
+                cv2.imwrite(frame_save_name, img)
+            current_frm += 1
+            if verbose: print(f"[{get_current_time()}] Multi-processing video frame {current_frm}/{video_meta_data['frame_count']} (core batch: {batch}, video name: {video_meta_data['video_name']})...")
+        else:
+            FrameRangeWarning(msg=f'Could not read frame {current_frm} in video {video_path}. Stopping video creation.')
+            break
+
+    cap.release()
+    if video_setting:
+        video_writer.release()
+    return batch
+
+
+class PlotSklearnResultsMultiProcess(ConfigReader, TrainModelMixin, PlottingMixin):
+    """
+    Plot classification results on videos using multiprocessing. Results are stored in the
+    `project_folder/frames/output/sklearn_results` directory of the SimBA project.
+
+    This class creates annotated videos/frames showing classifier predictions overlaid on pose-estimation data,
+    with optional Gantt charts, timers, and bounding boxes. Processing is parallelized across multiple CPU cores
+    for faster rendering of large video datasets.
+
+    .. seealso::
+       `Tutorial <https://github.com/sgoldenlab/simba/blob/master/docs/tutorial.md#step-10-sklearn-visualization>`__.
+       For single-core processing, see :meth:`simba.plotting.plot_clf_results.PlotSklearnResultsSingleCore`.
+
+    .. image:: _static/img/sklearn_visualization.gif
+       :width: 600
+       :align: center
+
+    .. video:: _static/img/T1.webm
+       :width: 1000
+       :autoplay:
+       :loop:
+
+    .. youtube:: Frq6mMcaHBc
+       :width: 640
+       :height: 480
+       :align: center
+
+    :param Union[str, os.PathLike] config_path: Path to SimBA project config file in Configparser format.
+    :param bool video_setting: If True, creates compressed MP4 videos. Default True.
+    :param bool frame_setting: If True, saves individual annotated frames as PNG images. Default False.
+    :param Optional[Union[List[Union[str, os.PathLike]], Union[str, os.PathLike]]] video_paths: Path(s) to video file(s) to process. If None, processes all videos found in the project's video directory. Default None.
+    :param bool rotate: If True, rotates output videos 90 degrees clockwise. Default False.
+    :param bool animal_names: If True, displays animal names on the video frames. Default False.
+    :param bool show_pose: If True, overlays pose-estimation keypoints on the video. Default True.
+    :param Optional[Union[int, float]] font_size: Font size for text overlays. If None, auto-computed based on video resolution. Default None.
+    :param Optional[Union[int, float]] space_size: Vertical spacing between text lines. If None, auto-computed. Default None.
+    :param Optional[Union[int, float]] text_thickness: Thickness of text characters. If None, uses default. Default None.
+    :param Optional[Union[int, float]] text_opacity: Opacity of text background (0.0-1.0). If None, defaults to 0.8. Default None.
+    :param Optional[Union[int, float]] circle_size: Radius of pose keypoint circles. If None, auto-computed based on video resolution. Default None.
+    :param Optional[str] pose_palette: Name of color palette for pose keypoints. Must be from :class:`simba.utils.enums.Options.PALETTE_OPTIONS_CATEGORICAL` or :class:`simba.utils.enums.Options.PALETTE_OPTIONS`. Default 'Set1'.
+    :param bool print_timers: If True, displays cumulative time for each classifier behavior on each frame. Default True.
+    :param bool show_bbox: If True, draws axis-aligned bounding boxes around detected animals. Default False.
+    :param Optional[int] show_gantt: If 1, appends static Gantt chart to video. If 2, appends dynamic Gantt chart that updates per frame. If None, no Gantt chart. Default None.
+    :param Tuple[int, int, int] text_clr: RGB color tuple for text foreground. Default (255, 255, 255) (white).
+    :param Tuple[int, int, int] text_bg_clr: RGB color tuple for text background. Default (0, 0, 0) (black).
+    :param bool gpu: If True, uses GPU acceleration for video concatenation (requires CUDA-capable GPU). Default False.
+    :param int core_cnt: Number of CPU cores to use for parallel processing. Pass -1 to use all available cores. Default -1.
+
+    :example:
+    >>> clf_plotter = PlotSklearnResultsMultiProcess(
+    ...     config_path='/Users/simon/Desktop/envs/simba/troubleshooting/beepboop174/project_folder/project_config.ini',
+    ...     video_setting=True,
+    ...     frame_setting=False,
+    ...     video_paths='Trial_10.mp4',
+    ...     rotate=False,
+    ...     show_pose=True,
+    ...     show_bbox=True,
+    ...     print_timers=True,
+    ...     show_gantt=1,
+    ...     core_cnt=5
+    ... )
+    >>> clf_plotter.run()
+    """
+
+    def __init__(self,
+                 config_path: Union[str, os.PathLike],
+                 video_setting: bool = True,
+                 frame_setting: bool = False,
+                 video_paths: Optional[Union[List[Union[str, os.PathLike]], Union[str, os.PathLike]]] = None,
+                 rotate: bool = False,
+                 animal_names: bool = False,
+                 show_pose: bool = True,
+                 show_confidence: bool = False,
+                 font_size: Optional[Union[int, float]] = None,
+                 space_size: Optional[Union[int, float]] = None,
+                 text_thickness: Optional[Union[int, float]] = None,
+                 text_opacity: Optional[Union[int, float]] = None,
+                 circle_size: Optional[Union[int, float]] = None,
+                 pose_palette: Optional[str] = 'Set1',
+                 print_timers: bool = True,
+                 show_bbox: bool = False,
+                 show_gantt: Optional[int] = None,
+                 text_clr: Tuple[int, int, int] = (255, 255, 255),
+                 text_bg_clr: Tuple[int, int, int] = (0, 0, 0),
+                 gpu: bool = False,
+                 verbose: bool = True,
+                 core_cnt: int = -1):
+
+
+        ConfigReader.__init__(self, config_path=config_path)
+        TrainModelMixin.__init__(self)
+        PlottingMixin.__init__(self)
+        log_event(logger_name=str(__class__.__name__), log_type=TagNames.CLASS_INIT.value, msg=self.create_log_msg_from_init_args(locals=locals()))
+        for i in [video_setting, frame_setting, rotate, print_timers, animal_names, show_pose, gpu, show_bbox, show_confidence]:
+            check_valid_boolean(value=i, source=self.__class__.__name__, raise_error=True)
+        if (not video_setting) and (not frame_setting):
+            raise NoSpecifiedOutputError(msg="Please choose to create a video and/or frames. SimBA found that you ticked neither video and/or frames", source=self.__class__.__name__)
+        if font_size is not None: check_float(name=f'{self.__class__.__name__} font_size', value=font_size, min_value=0.1)
+        if space_size is not None: check_float(name=f'{self.__class__.__name__} space_size', value=space_size, min_value=0.1)
+        if text_thickness is not None: check_float(name=f'{self.__class__.__name__} text_thickness', value=text_thickness, min_value=0.1)
+        if circle_size is not None: check_float(name=f'{self.__class__.__name__} text_thickness', value=circle_size, min_value=0.1)
+        if circle_size is not None: check_float(name=f'{self.__class__.__name__} text_thickness', value=circle_size, min_value=0.1)
+        if text_opacity is not None: check_float(name=f'{self.__class__.__name__} text_opacity', value=text_opacity, min_value=0.1)
+        pose_palettes = Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value
+        check_str(name=f'{self.__class__.__name__} pose_palette', value=pose_palette, options=pose_palettes)
+        self.clr_lst = create_color_palette(pallete_name=pose_palette, increments=len(self.body_parts_lst)+1)
+        check_if_valid_rgb_tuple(data=text_clr, source=f'{self.__class__.__name__} text_clr')
+        check_if_valid_rgb_tuple(data=text_bg_clr, source=f'{self.__class__.__name__} text_bg_clr')
+        check_valid_boolean(value=verbose, source=f'{self.__class__.__name__} verbose', raise_error=True)
+        if show_gantt is not None:
+            check_int(name=f"{self.__class__.__name__} show_gantt", value=show_gantt, max_value=2, min_value=1)
+        self.video_setting, self.frame_setting, self.rotate, self.print_timers = video_setting, frame_setting, rotate, print_timers
+        self.circle_size, self.font_size, self.animal_names, self.text_opacity = circle_size, font_size, animal_names, text_opacity
+        self.text_thickness, self.space_size, self.show_pose, self.pose_palette, self.verbose = text_thickness, space_size, show_pose, pose_palette, verbose
+        self.text_color, self.text_bg_color, self.show_bbox, self.show_gantt, self.show_confidence = text_clr, text_bg_clr, show_bbox, show_gantt, show_confidence
+        self.gpu = True if check_nvidea_gpu_available() and gpu else False
+        self.pose_threshold = read_config_entry(self.config, ConfigKey.THRESHOLD_SETTINGS.value, ConfigKey.SKLEARN_BP_PROB_THRESH.value, Dtypes.FLOAT.value, 0.00)
+        if not os.path.exists(self.sklearn_plot_dir):
+            os.makedirs(self.sklearn_plot_dir)
+        if isinstance(video_paths, str): self.video_paths = [video_paths]
+        elif isinstance(video_paths, list): self.video_paths = video_paths
+        elif video_paths is None:
+            self.video_paths = find_all_videos_in_project(videos_dir=self.video_dir)
+            if len(self.video_paths) == 0:
+                raise NoDataError(msg=f'Cannot create classification videos. No videos exist in {self.video_dir} directory', source=self.__class__.__name__)
+        else:
+            raise InvalidInputError(msg=f'video_paths has to be a path of a list of paths. Got {type(video_paths)}', source=self.__class__.__name__)
+
+        for video_path in self.video_paths:
+            video_name = get_fn_ext(filepath=video_path)[1]
+            data_path = os.path.join(self.machine_results_dir, f'{video_name}.{self.file_type}')
+            if not os.path.isfile(data_path): raise NoDataError(msg=f'Cannot create classification videos for {video_name}. Expected classification data at location {data_path} but file does not exist', source=self.__class__.__name__)
+        check_int(name=f'{self.__class__.__name__} core_cnt', value=core_cnt, min_value=-1, unaccepted_vals=[0])
+        self.core_cnt = find_core_cnt()[0] if int(core_cnt) == -1 or int(core_cnt) > find_core_cnt()[0] else int(core_cnt)
+        self.conf_cols = [f'Probability_{x}' for x in self.clf_names]
+        if platform.system() == "Darwin":
+            multiprocessing.set_start_method("spawn", force=True)
+
+    def __get_print_settings(self):
+        optimal_circle_size = self.get_optimal_circle_size(frame_size=(self.video_meta_data["width"], self.video_meta_data["height"]), circle_frame_ratio=100)
+        longest_str = str(max(['TIMERS:', 'ENSEMBLE PREDICTION:'] + self.clf_names, key=len))
+        self.video_text_thickness = TextOptions.TEXT_THICKNESS.value if self.text_thickness is None else int(max(self.text_thickness, 1))
+        optimal_font_size, _, optimal_spacing_scale = self.get_optimal_font_scales(text=longest_str, accepted_px_width=int(self.video_meta_data["width"] / 3), accepted_px_height=int(self.video_meta_data["height"] / 10), text_thickness=self.video_text_thickness)
+        self.video_circle_size = optimal_circle_size if self.circle_size is None else int(max(1, self.circle_size))
+        self.video_font_size = optimal_font_size if self.font_size is None else self.font_size
+        self.video_space_size = optimal_spacing_scale if self.space_size is None else int(max(self.space_size, 1))
+        self.video_text_opacity = 0.8 if self.text_opacity is None else float(self.text_opacity)
+
+    def run(self):
+        if self.verbose: print(f'Creating {len(self.video_paths)} classification visualization(s) using {self.core_cnt} cores... ({get_current_time()})')
+        self.pool = get_cpu_pool(core_cnt=self.core_cnt, source=self.__class__.__name__, )
+        for video_cnt, video_path in enumerate(self.video_paths):
+            video_timer = SimbaTimer(start=True)
+            _, self.video_name, _ = get_fn_ext(video_path)
+            if self.verbose: print(f"[{get_current_time()}] Creating classification visualization for video {self.video_name}...")
+            self.data_path = os.path.join(self.machine_results_dir, f'{self.video_name}.{self.file_type}')
+            self.data_df = read_df(self.data_path, self.file_type).reset_index(drop=True).fillna(0)
+            if self.show_pose: check_that_column_exist(df=self.data_df, column_name=self.bp_col_names, file_name=self.data_path)
+            if self.show_confidence: check_that_column_exist(df=self.data_df, column_name=self.conf_cols, file_name=self.data_path)
+            self.video_meta_data = get_video_meta_data(video_path=video_path)
+            height, width = deepcopy(self.video_meta_data["height"]), deepcopy(self.video_meta_data["width"])
+            self.save_path = os.path.join(self.sklearn_plot_dir, f"{self.video_name}.mp4")
+            self.video_frame_dir, self.video_temp_dir = None, None
+            if self.video_setting:
+                self.video_save_path = os.path.join(self.sklearn_plot_dir, f"{self.video_name}.mp4")
+                self.video_temp_dir = os.path.join(self.sklearn_plot_dir, self.video_name, "temp")
+                create_directory(paths=self.video_temp_dir, overwrite=True)
+            if self.frame_setting:
+                self.video_frame_dir = os.path.join(self.sklearn_plot_dir, self.video_name)
+                create_directory(paths=self.video_temp_dir, overwrite=True)
+            if self.rotate:
+                self.video_meta_data["height"], self.video_meta_data["width"] = (width, height)
+            check_video_and_data_frm_count_align(video=video_path, data=self.data_df, name=self.video_name, raise_error=False)
+            check_that_column_exist(df=self.data_df, column_name=self.clf_names, file_name=self.data_path)
+            self.__get_print_settings()
+            if self.show_gantt is not None:
+                self.gantt_clrs = create_color_palette(pallete_name=self.pose_palette, increments=len(self.clf_names) + 1, as_int=True, as_rgb_ratio=True)
+                self.bouts_df = detect_bouts(data_df=self.data_df, target_lst=list(self.clf_names), fps=int(self.video_meta_data["fps"]))
+                self.final_gantt_img = PlottingMixin().make_gantt_plot(x_length=len(self.data_df) + 1, bouts_df=self.bouts_df, clf_names=self.clf_names, fps=self.video_meta_data["fps"], width=self.video_meta_data["width"], height=self.video_meta_data["height"], font_size=12, font_rotation=90, video_name=self.video_meta_data["video_name"], save_path=None, palette=self.gantt_clrs)
+                self.final_gantt_img = self.resize_gantt(self.final_gantt_img, self.video_meta_data["height"])
+            else:
+                self.bouts_df, self.final_gantt_img, self.gantt_clrs = None, None, None
+
+
+            self.clf_cumsums, self.clf_p = {}, {} if self.show_confidence else None
+            for clf_name in self.clf_names:
+                self.clf_cumsums[clf_name] = self.data_df[clf_name].cumsum()
+                if self.show_confidence: self.clf_p[clf_name] = np.round(self.data_df[f'Probability_{clf_name}'].values.reshape(-1), 4)
+
+            self.data_df["index"] = self.data_df.index
+            data = np.array_split(self.data_df, self.core_cnt)
+            data = [(cnt, x) for (cnt, x) in enumerate(data)]
+
+            constants = functools.partial(_multiprocess_sklearn_video,
+                                          bp_dict=self.animal_bp_dict,
+                                          video_save_dir=self.video_temp_dir,
+                                          frame_save_dir=self.video_frame_dir,
+                                          clf_cumsum=self.clf_cumsums,
+                                          rotate=self.rotate,
+                                          video_path=video_path,
+                                          clf_confidence=self.clf_p,
+                                          print_timers=self.print_timers,
+                                          video_setting=self.video_setting,
+                                          frame_setting=self.frame_setting,
+                                          pose_threshold=self.pose_threshold,
+                                          show_pose=self.show_pose,
+                                          show_animal_names=self.animal_names,
+                                          circle_size=self.video_circle_size,
+                                          font_size=self.video_font_size,
+                                          space_size=self.video_space_size,
+                                          text_thickness=self.video_text_thickness,
+                                          text_opacity=self.video_text_opacity,
+                                          text_bg_clr=self.text_bg_color,
+                                          text_color=self.text_color,
+                                          pose_clr_lst=self.clr_lst,
+                                          show_bbox=self.show_bbox,
+                                          show_gantt=self.show_gantt,
+                                          bouts_df=self.bouts_df,
+                                          final_gantt=self.final_gantt_img,
+                                          gantt_clrs=self.gantt_clrs,
+                                          clf_names=self.clf_names,
+                                          verbose=self.verbose)
+
+            for cnt, result in enumerate(self.pool.imap(constants, data, chunksize=self.multiprocess_chunksize)):
+                if self.verbose: print(f"[{get_current_time()}] Image batch {result} complete, Video {(video_cnt + 1)}/{len(self.video_paths)}...")
+
+            if self.video_setting:
+                if self.verbose: print(f"Joining {self.video_name} multiprocessed video...")
+                concatenate_videos_in_folder(in_folder=self.video_temp_dir, save_path=self.video_save_path, gpu=self.gpu, verbose=self.verbose)
+            video_timer.stop_timer()
+            print(f"Video {self.video_name} complete (elapsed time: {video_timer.elapsed_time_str}s)...")
+
+        terminate_cpu_pool(pool=self.pool, force=False)
+        self.timer.stop_timer()
+        if self.video_setting:
+            stdout_success(msg=f"{len(self.video_paths)} video(s) saved in {self.sklearn_plot_dir} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
+        if self.frame_setting:
+            stdout_success(f"Frames for {len(self.video_paths)} videos saved in sub-folders within {self.sklearn_plot_dir} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
+
+
+
+if __name__ == "__main__":
+    clf_plotter = PlotSklearnResultsMultiProcess(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini",
+                                                 video_setting=True,
+                                                 frame_setting=False,
+                                                 video_paths=None,
+                                                 print_timers=True,
+                                                 rotate=False,
+                                                 core_cnt=21,
+                                                 animal_names=False,
+                                                 show_bbox=True,
+                                                 show_gantt=None)
+    clf_plotter.run()
+
+
+
+
+# if __name__ == "__main__":
+#     clf_plotter = PlotSklearnResultsMultiProcess(config_path=r"/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini",
+#                                                  video_setting=True,
+#                                                  frame_setting=False,
+#                                                  video_paths=r"/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/videos/Together_1.mp4",
+#                                                  print_timers=True,
+#                                                  rotate=False,
+#                                                  animal_names=False,
+#                                                  show_bbox=True,
+#                                                  show_gantt=None)
+#     clf_plotter.run()
+
+
+
+
+#text_settings = {'circle_scale': 5, 'font_size': 0.528, 'spacing_scale': 28, 'text_thickness': 2}
+# clf_plotter = PlotSklearnResultsMultiProcess(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/beepboop174/project_folder/project_config.ini',
+#                                              video_setting=True,
+#                                              frame_setting=False,
+#                                              rotate=False,
+#                                              video_file_path='592_MA147_Gq_CNO_0515.mp4',
+#                                              cores=-1,
+#                                              text_settings=False)
+# clf_plotter.run()
+#
+
+# clf_plotter = PlotSklearnResultsMultiProcess(config_path='/Users/simon/Desktop/envs/troubleshooting/DLC_2_Black_animals/project_folder/project_config.ini', video_setting=True, frame_setting=False, rotate=False, video_file_path='Together_1.avi', cores=5)
+# clf_plotter.run()
+
+# if __name__ == "__main__":
+#     clf_plotter = PlotSklearnResultsMultiProcess(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
+#                                                  video_setting = True,
+#                                                  frame_setting = False,
+#                                                  rotate = False,
+#                                                  core_cnt = 6,
+#                                                  show_confidence=True,
+#                                                  video_paths=r"C:\troubleshooting\mitra\project_folder\videos\501_MA142_Gi_CNO_0521.mp4")
 # clf_plotter.run()
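
The functional addition to plot_clf_results_mp.py in this release is the show_confidence flag, which reads the Probability_<classifier> columns from the machine-results file and prints the per-frame classifier confidence beneath the cumulative behavior timers. A minimal usage sketch, assuming a standard SimBA project layout (the config path below is a placeholder; the __main__ guard matters because the class spawns worker processes):

from simba.plotting.plot_clf_results_mp import PlotSklearnResultsMultiProcess

if __name__ == "__main__":
    # Hypothetical project path; substitute your own project_config.ini.
    plotter = PlotSklearnResultsMultiProcess(config_path=r"/path/to/project_folder/project_config.ini",
                                             video_setting=True,    # write MP4s to frames/output/sklearn_results
                                             frame_setting=False,   # skip per-frame PNG export
                                             show_pose=True,        # overlay body-part keypoints
                                             show_confidence=True,  # new in 4.7.5: per-frame Probability_{clf} overlay
                                             print_timers=True,     # cumulative behavior timers
                                             core_cnt=-1)           # -1 = use all available cores
    plotter.run()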