simba-uw-tf-dev 4.6.2__py3-none-any.whl → 4.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- simba/assets/.recent_projects.txt +1 -0
- simba/assets/lookups/tooptips.json +6 -1
- simba/data_processors/agg_clf_counter_mp.py +52 -53
- simba/data_processors/blob_location_computer.py +1 -1
- simba/data_processors/circling_detector.py +30 -13
- simba/data_processors/cuda/geometry.py +45 -27
- simba/data_processors/cuda/image.py +1648 -1598
- simba/data_processors/cuda/statistics.py +72 -26
- simba/data_processors/cuda/timeseries.py +1 -1
- simba/data_processors/cue_light_analyzer.py +5 -9
- simba/data_processors/egocentric_aligner.py +25 -7
- simba/data_processors/freezing_detector.py +55 -47
- simba/data_processors/kleinberg_calculator.py +61 -29
- simba/feature_extractors/feature_subsets.py +14 -7
- simba/feature_extractors/mitra_feature_extractor.py +2 -2
- simba/feature_extractors/straub_tail_analyzer.py +4 -6
- simba/labelling/standard_labeller.py +1 -1
- simba/mixins/config_reader.py +5 -2
- simba/mixins/geometry_mixin.py +22 -36
- simba/mixins/image_mixin.py +24 -28
- simba/mixins/plotting_mixin.py +28 -10
- simba/mixins/statistics_mixin.py +48 -11
- simba/mixins/timeseries_features_mixin.py +1 -1
- simba/mixins/train_model_mixin.py +67 -29
- simba/model/inference_batch.py +1 -1
- simba/model/yolo_seg_inference.py +3 -3
- simba/outlier_tools/skip_outlier_correction.py +1 -1
- simba/plotting/ROI_feature_visualizer_mp.py +3 -5
- simba/plotting/clf_validator_mp.py +4 -5
- simba/plotting/cue_light_visualizer.py +6 -7
- simba/plotting/directing_animals_visualizer_mp.py +2 -3
- simba/plotting/distance_plotter_mp.py +378 -378
- simba/plotting/gantt_creator.py +29 -10
- simba/plotting/gantt_creator_mp.py +96 -33
- simba/plotting/geometry_plotter.py +270 -272
- simba/plotting/heat_mapper_clf_mp.py +4 -6
- simba/plotting/heat_mapper_location_mp.py +2 -2
- simba/plotting/light_dark_box_plotter.py +2 -2
- simba/plotting/path_plotter_mp.py +26 -29
- simba/plotting/plot_clf_results_mp.py +455 -454
- simba/plotting/pose_plotter_mp.py +28 -29
- simba/plotting/probability_plot_creator_mp.py +288 -288
- simba/plotting/roi_plotter_mp.py +31 -31
- simba/plotting/single_run_model_validation_video_mp.py +427 -427
- simba/plotting/spontaneous_alternation_plotter.py +2 -3
- simba/plotting/yolo_pose_track_visualizer.py +32 -27
- simba/plotting/yolo_pose_visualizer.py +35 -36
- simba/plotting/yolo_seg_visualizer.py +2 -3
- simba/pose_importers/simba_blob_importer.py +3 -3
- simba/roi_tools/roi_aggregate_stats_mp.py +5 -4
- simba/roi_tools/roi_clf_calculator_mp.py +4 -4
- simba/sandbox/analyze_runtimes.py +30 -0
- simba/sandbox/cuda/egocentric_rotator.py +374 -374
- simba/sandbox/get_cpu_pool.py +5 -0
- simba/sandbox/proboscis_to_tip.py +28 -0
- simba/sandbox/test_directionality.py +47 -0
- simba/sandbox/test_nonstatic_directionality.py +27 -0
- simba/sandbox/test_pycharm_cuda.py +51 -0
- simba/sandbox/test_simba_install.py +41 -0
- simba/sandbox/test_static_directionality.py +26 -0
- simba/sandbox/test_static_directionality_2d.py +26 -0
- simba/sandbox/verify_env.py +42 -0
- simba/third_party_label_appenders/transform/coco_keypoints_to_yolo.py +3 -3
- simba/third_party_label_appenders/transform/coco_keypoints_to_yolo_bbox.py +2 -2
- simba/ui/pop_ups/clf_plot_pop_up.py +2 -2
- simba/ui/pop_ups/fsttc_pop_up.py +27 -25
- simba/ui/pop_ups/gantt_pop_up.py +31 -6
- simba/ui/pop_ups/kleinberg_pop_up.py +39 -40
- simba/ui/pop_ups/video_processing_pop_up.py +37 -29
- simba/ui/tkinter_functions.py +3 -0
- simba/utils/custom_feature_extractor.py +1 -1
- simba/utils/data.py +90 -14
- simba/utils/enums.py +1 -0
- simba/utils/errors.py +441 -440
- simba/utils/lookups.py +1203 -1203
- simba/utils/printing.py +124 -124
- simba/utils/read_write.py +3769 -3721
- simba/utils/yolo.py +10 -1
- simba/video_processors/blob_tracking_executor.py +2 -2
- simba/video_processors/clahe_ui.py +1 -1
- simba/video_processors/egocentric_video_rotator.py +44 -41
- simba/video_processors/multi_cropper.py +1 -1
- simba/video_processors/video_processing.py +5264 -5222
- simba/video_processors/videos_to_frames.py +43 -33
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/METADATA +4 -3
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/RECORD +90 -80
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/LICENSE +0 -0
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/WHEEL +0 -0
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/entry_points.txt +0 -0
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.1.dist-info}/top_level.txt +0 -0
|
@@ -1,427 +1,427 @@
|
|
|
1
|
-
__author__ = "Simon Nilsson; sronilsson@gmail.com"
|
|
2
|
-
|
|
3
|
-
import warnings
|
|
4
|
-
|
|
5
|
-
warnings.filterwarnings("ignore", category=FutureWarning)
|
|
6
|
-
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
7
|
-
import functools
|
|
8
|
-
import multiprocessing
|
|
9
|
-
import os
|
|
10
|
-
import platform
|
|
11
|
-
from copy import deepcopy
|
|
12
|
-
from typing import List, Optional, Tuple, Union
|
|
13
|
-
|
|
14
|
-
import cv2
|
|
15
|
-
import imutils
|
|
16
|
-
import pandas as pd
|
|
17
|
-
|
|
18
|
-
try:
|
|
19
|
-
from typing import Literal
|
|
20
|
-
except:
|
|
21
|
-
from typing_extensions import Literal
|
|
22
|
-
|
|
23
|
-
import matplotlib
|
|
24
|
-
import matplotlib.pyplot as plt
|
|
25
|
-
import numpy as np
|
|
26
|
-
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
|
27
|
-
|
|
28
|
-
from simba.mixins.config_reader import ConfigReader
|
|
29
|
-
from simba.mixins.geometry_mixin import GeometryMixin
|
|
30
|
-
from simba.mixins.plotting_mixin import PlottingMixin
|
|
31
|
-
from simba.mixins.train_model_mixin import TrainModelMixin
|
|
32
|
-
from simba.utils.checks import (check_file_exist_and_readable, check_float,
|
|
33
|
-
check_int, check_str, check_valid_boolean,
|
|
34
|
-
check_video_and_data_frm_count_align)
|
|
35
|
-
from simba.utils.data import create_color_palette, plug_holes_shortest_bout
|
|
36
|
-
|
|
37
|
-
from simba.utils.
|
|
38
|
-
from simba.utils.
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
cv2.
|
|
81
|
-
cv2.
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
ax.
|
|
104
|
-
ax.
|
|
105
|
-
ax.
|
|
106
|
-
ax.
|
|
107
|
-
ax.
|
|
108
|
-
ax.
|
|
109
|
-
ax.yaxis.
|
|
110
|
-
|
|
111
|
-
canvas
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
addSpacer
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
addSpacer
|
|
163
|
-
|
|
164
|
-
addSpacer
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
gantt_img =
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
:param Union[str, os.PathLike]
|
|
202
|
-
:param Union[str, os.PathLike]
|
|
203
|
-
:param
|
|
204
|
-
:param bool
|
|
205
|
-
:param
|
|
206
|
-
:param Optional[
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
:param Optional[int]
|
|
211
|
-
:param Optional[int]
|
|
212
|
-
:param Optional[
|
|
213
|
-
:param float
|
|
214
|
-
:param
|
|
215
|
-
:param int
|
|
216
|
-
:param
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
-
|
|
220
|
-
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
:
|
|
226
|
-
:
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
:
|
|
231
|
-
:
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
>>>
|
|
236
|
-
|
|
237
|
-
...
|
|
238
|
-
...
|
|
239
|
-
...
|
|
240
|
-
...
|
|
241
|
-
...
|
|
242
|
-
...
|
|
243
|
-
...
|
|
244
|
-
...
|
|
245
|
-
...
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
check_file_exist_and_readable(file_path=
|
|
274
|
-
check_file_exist_and_readable(file_path=
|
|
275
|
-
|
|
276
|
-
check_valid_boolean(value=[
|
|
277
|
-
check_valid_boolean(value=[
|
|
278
|
-
check_valid_boolean(value=[
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
if
|
|
282
|
-
if
|
|
283
|
-
if
|
|
284
|
-
if
|
|
285
|
-
check_float(name=f
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
self.
|
|
300
|
-
self.
|
|
301
|
-
self.
|
|
302
|
-
self.
|
|
303
|
-
self.
|
|
304
|
-
self.
|
|
305
|
-
self.
|
|
306
|
-
self.
|
|
307
|
-
self.
|
|
308
|
-
self.
|
|
309
|
-
self.
|
|
310
|
-
self.
|
|
311
|
-
|
|
312
|
-
self.
|
|
313
|
-
self.
|
|
314
|
-
self.
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
self.
|
|
326
|
-
self.
|
|
327
|
-
self.
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
self.
|
|
332
|
-
self.data_df[self.
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
self.
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
self.
|
|
341
|
-
self.final_gantt_img = self.
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
data =
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
pool
|
|
375
|
-
concatenate_videos_in_folder(in_folder=self.temp_dir, save_path=self.video_save_path)
|
|
376
|
-
self.timer.stop_timer()
|
|
377
|
-
stdout_success(msg=f"Video complete, saved at {self.video_save_path}", elapsed_time=self.timer.elapsed_time_str)
|
|
378
|
-
|
|
379
|
-
#
|
|
380
|
-
# if __name__ == "__main__":
|
|
381
|
-
# test = ValidateModelOneVideoMultiprocess(config_path=r"D:\troubleshooting\mitra\project_folder\project_config.ini",
|
|
382
|
-
# feature_path=r"D:\troubleshooting\mitra\project_folder\csv\features_extracted\592_MA147_CNO1_0515.csv",
|
|
383
|
-
# model_path=r"C:\troubleshooting\mitra\models\validations\rearing_5\rearing.sav",
|
|
384
|
-
# create_gantt=2,
|
|
385
|
-
# show_pose=True,
|
|
386
|
-
# show_animal_names=True,
|
|
387
|
-
# core_cnt=13,
|
|
388
|
-
# show_clf_confidence=True,
|
|
389
|
-
# discrimination_threshold=0.20)
|
|
390
|
-
# test.run()
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
#
|
|
394
|
-
# if __name__ == "__main__":
|
|
395
|
-
# test = ValidateModelOneVideoMultiprocess(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
|
|
396
|
-
# feature_file_path=r"C:\troubleshooting\mitra\project_folder\csv\features_extracted\844_MA131_gq_CNO_0624.csv",
|
|
397
|
-
# model_path=r"C:\troubleshooting\mitra\models\validations\lay-on-belly_1\lay-on-belly.sav",
|
|
398
|
-
# discrimination_threshold=0.35,
|
|
399
|
-
# shortest_bout=200,
|
|
400
|
-
# cores=-1,
|
|
401
|
-
# settings={'pose': True, 'animal_names': False, 'styles': None},
|
|
402
|
-
# create_gantt=2)
|
|
403
|
-
# test.run()
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
# test = ValidateModelOneVideoMultiprocess(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/project_config.ini',
|
|
410
|
-
# feature_file_path='/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/csv/features_extracted/SI_DAY3_308_CD1_PRESENT.csv',
|
|
411
|
-
# model_path='/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/models/generated_models/Running.sav',
|
|
412
|
-
# discrimination_threshold=0.6,
|
|
413
|
-
# shortest_bout=50,
|
|
414
|
-
# cores=6,
|
|
415
|
-
# settings={'pose': True, 'animal_names': True, 'styles': None},
|
|
416
|
-
# create_gantt=None)
|
|
417
|
-
# test.run()
|
|
418
|
-
|
|
419
|
-
# test = ValidateModelOneVideoMultiprocess(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
|
|
420
|
-
# feature_file_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/csv/features_extracted/Together_1.csv',
|
|
421
|
-
# model_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/models/generated_models/Attack.sav',
|
|
422
|
-
# discrimination_threshold=0.6,
|
|
423
|
-
# shortest_bout=50,
|
|
424
|
-
# cores=6,
|
|
425
|
-
# settings={'pose': True, 'animal_names': True, 'styles': None},
|
|
426
|
-
# create_gantt=None)
|
|
427
|
-
# test.run()
|
|
1
|
+
__author__ = "Simon Nilsson; sronilsson@gmail.com"
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
|
|
5
|
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
|
6
|
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
7
|
+
import functools
|
|
8
|
+
import multiprocessing
|
|
9
|
+
import os
|
|
10
|
+
import platform
|
|
11
|
+
from copy import deepcopy
|
|
12
|
+
from typing import List, Optional, Tuple, Union
|
|
13
|
+
|
|
14
|
+
import cv2
|
|
15
|
+
import imutils
|
|
16
|
+
import pandas as pd
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
from typing import Literal
|
|
20
|
+
except:
|
|
21
|
+
from typing_extensions import Literal
|
|
22
|
+
|
|
23
|
+
import matplotlib
|
|
24
|
+
import matplotlib.pyplot as plt
|
|
25
|
+
import numpy as np
|
|
26
|
+
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
|
27
|
+
|
|
28
|
+
from simba.mixins.config_reader import ConfigReader
|
|
29
|
+
from simba.mixins.geometry_mixin import GeometryMixin
|
|
30
|
+
from simba.mixins.plotting_mixin import PlottingMixin
|
|
31
|
+
from simba.mixins.train_model_mixin import TrainModelMixin
|
|
32
|
+
from simba.utils.checks import (check_file_exist_and_readable, check_float,
|
|
33
|
+
check_int, check_str, check_valid_boolean,
|
|
34
|
+
check_video_and_data_frm_count_align)
|
|
35
|
+
from simba.utils.data import (create_color_palette, plug_holes_shortest_bout,
|
|
36
|
+
terminate_cpu_pool)
|
|
37
|
+
from simba.utils.enums import Options, TextOptions
|
|
38
|
+
from simba.utils.printing import SimbaTimer, stdout_success
|
|
39
|
+
from simba.utils.read_write import (concatenate_videos_in_folder,
|
|
40
|
+
create_directory, find_core_cnt,
|
|
41
|
+
get_fn_ext, get_video_meta_data, read_df,
|
|
42
|
+
read_pickle, write_df)
|
|
43
|
+
from simba.utils.warnings import FrameRangeWarning
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _validation_video_mp(data: tuple,
                         bp_dict: dict,
                         video_save_dir: str,
                         video_path: str,
                         text_thickness: int,
                         text_opacity: float,
                         font_size: int,
                         text_spacing: int,
                         circle_size: int,
                         show_pose: bool,
                         show_animal_bounding_boxes: bool,
                         show_animal_names: bool,
                         gantt_setting: Union[int, None],
                         final_gantt: Optional[np.ndarray],
                         clf_data: np.ndarray,
                         clrs: List[List],
                         clf_name: str,
                         bouts_df: pd.DataFrame,
                         conf_data: np.ndarray):
    """
    Multiprocessing worker that renders one batch of frames of a classifier validation video.

    Reads the frames of ``video_path`` covered by the batch and draws several overlays on them:
    pose key-points, animal names, animal bounding boxes, behavior timer / probability / prediction
    text, and optionally a Gantt chart. The result is written to ``{video_save_dir}/{batch_id}.mp4``.

    :param tuple data: ``(batch_id, batch_df)``. ``batch_df`` rows are indexed by frame number; its first and last index define the frame range rendered by this worker.
    :param dict bp_dict: Maps animal name -> dict with ``"X_bps"`` and ``"Y_bps"`` lists of column names in ``batch_df``.
    :param str video_save_dir: Directory where the batch video ``{batch_id}.mp4`` is written.
    :param str video_path: Path to the source video.
    :param int text_thickness: Thickness used for overlay text, animal names and bounding boxes. If None, ``TextOptions.TEXT_THICKNESS`` is used for the overlay text rows.
    :param float text_opacity: Opacity of the background box drawn behind each overlay text row.
    :param int font_size: Font scale of the overlay text.
    :param int text_spacing: Vertical pixel spacing between overlay text rows.
    :param int circle_size: Radius of the pose key-point circles.
    :param bool show_pose: If True, draw pose key-points.
    :param bool show_animal_bounding_boxes: If True, draw an axis-aligned bounding box around each animal (best effort).
    :param bool show_animal_names: If True, draw each animal name at its first body-part.
    :param Optional[int] gantt_setting: None = no Gantt chart; 1 = append the static ``final_gantt`` to every frame; 2 = re-draw the Gantt chart for every frame.
    :param Optional[np.ndarray] final_gantt: Pre-rendered Gantt image. Its width is added to the output frame width, and its size sets the per-frame Gantt figure size.
    :param np.ndarray clf_data: Per-frame classifications indexed by absolute frame number (a value of 1 marks the behavior as present).
    :param List[List] clrs: ``clrs[animal_idx][bodypart_idx]`` drawing colors.
    :param str clf_name: Classifier name shown in the overlays.
    :param pd.DataFrame bouts_df: Bout table with ``Event``, ``Start_time``, ``Bout_time`` and ``End_frame`` columns.
    :param Optional[np.ndarray] conf_data: Per-frame classification probabilities, or None to hide the probability row.
    :return: ``batch_id`` of the rendered batch.
    """

    def _put_text(img: np.ndarray,
                  text: str,
                  pos: Tuple[int, int],
                  font_size: int,
                  font_thickness: Optional[int] = 2,
                  font: Optional[int] = cv2.FONT_HERSHEY_DUPLEX,
                  text_color: Optional[Tuple[int, int, int]] = (255, 255, 255),
                  text_color_bg: Optional[Tuple[int, int, int]] = (0, 0, 0),
                  text_bg_alpha: float = 0.8):
        # Draw `text` with its baseline-left corner at `pos`, on top of a semi-transparent background box.
        # Returns a new image; the input image is not modified.
        x, y = pos
        (w, h), px_buffer = cv2.getTextSize(text, font, font_size, font_thickness)
        overlay, output = img.copy(), img.copy()
        cv2.rectangle(overlay, (x, y - h), (x + w, y + px_buffer), text_color_bg, -1)
        cv2.addWeighted(overlay, text_bg_alpha, output, 1 - text_bg_alpha, 0, output)
        cv2.putText(output, text, (x, y), font, font_size, text_color, font_thickness)
        return output

    def _create_gantt(bouts_df: pd.DataFrame,
                      clf_name: str,
                      image_index: int,
                      fps: int,
                      header_font_size: int = 24,
                      label_font_size: int = 12):
        # Render a Gantt chart of all bouts ending at or before `image_index` as an RGB image.
        # The figure size is derived from `final_gantt` and `dpi` (enclosing scope).
        fig, ax = plt.subplots(figsize=(final_gantt.shape[1] / dpi, final_gantt.shape[0] / dpi))
        matplotlib.font_manager._get_font.cache_clear()
        rel_rows = bouts_df.loc[bouts_df["End_frame"] <= image_index]
        for _, event_df in rel_rows.groupby("Event"):
            ax.broken_barh(event_df[["Start_time", "Bout_time"]].values, (4, 4), facecolors="red")
        x_length = (round(image_index / fps)) + 1
        if x_length < 10:
            x_length = 10
        ax.set_xlim(0, x_length)
        ax.set_ylim([0, 12])
        ax.set_xlabel("Session (s)", fontsize=label_font_size)
        ax.set_ylabel(clf_name, fontsize=label_font_size)
        ax.set_title(f"{clf_name} GANTT CHART", fontsize=header_font_size)
        ax.set_yticks([])
        ax.yaxis.set_ticklabels([])
        ax.yaxis.grid(True)
        canvas = FigureCanvas(fig)
        canvas.draw()
        img = np.array(np.uint8(np.array(canvas.renderer._renderer)))[:, :, :3]
        plt.close(fig)
        return img

    dpi = plt.rcParams["figure.dpi"]
    fourcc, font = cv2.VideoWriter_fourcc(*"mp4v"), cv2.FONT_HERSHEY_DUPLEX
    # Overlay rows honor the caller's text thickness; fall back to the project default when it is not given.
    overlay_thickness = TextOptions.TEXT_THICKNESS.value if text_thickness is None else int(max(text_thickness, 1))
    video_meta_data = get_video_meta_data(video_path=video_path, fps_as_int=False)
    batch_id, batch_data = data[0], data[1]
    start_frm, end_frm = batch_data.index[0], batch_data.index[-1]
    current_frm = start_frm
    video_save_path = os.path.join(video_save_dir, f"{batch_id}.mp4")
    if gantt_setting is not None:
        video_size = (int(video_meta_data["width"] + final_gantt.shape[1]), int(video_meta_data["height"]))
    else:
        video_size = (int(video_meta_data["width"]), int(video_meta_data["height"]))

    # Prefix sums: clf_cumsum[i] == sum(clf_data[0:i]). This replaces a per-frame
    # np.sum(clf_data[0:current_frm]), which made rendering quadratic in the frame index.
    clf_arr = np.asarray(clf_data)
    clf_cumsum = np.concatenate(([0], np.cumsum(clf_arr)))
    text_x = TextOptions.BORDER_BUFFER_Y.value

    cap = cv2.VideoCapture(video_path)
    writer = cv2.VideoWriter(video_save_path, fourcc, video_meta_data["fps"], video_size)
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frm)
        while (current_frm <= end_frm) & (current_frm <= video_meta_data["frame_count"]):
            # Number of classified frames strictly before the current frame (slice semantics clamp at the end).
            clf_frm_cnt = clf_cumsum[min(current_frm, clf_arr.shape[0])]
            ret, img = cap.read()
            if not ret:
                FrameRangeWarning(msg=f'Frame {current_frm} could not be read in video {video_path}. The video contains {video_meta_data["frame_count"]} frames while the data file contains data for {len(batch_data)} frames. Consider re-encoding the video, or make sure the pose-estimation data and associated video contains the same number of frames. ', source=_validation_video_mp.__name__)
                break
            frm_timer = SimbaTimer(start=True)
            if show_pose:
                for animal_cnt, animal_data in enumerate(bp_dict.values()):
                    for bp_cnt, (x_header, y_header) in enumerate(zip(animal_data["X_bps"], animal_data["Y_bps"])):
                        x, y = tuple(batch_data.loc[current_frm, [x_header, y_header]])
                        cv2.circle(img, (int(x), int(y)), circle_size, clrs[animal_cnt][bp_cnt], -1)
            if show_animal_names:
                for animal_cnt, (animal_name, animal_data) in enumerate(bp_dict.items()):
                    x_header, y_header = animal_data["X_bps"][0], animal_data["Y_bps"][0]
                    x, y = tuple(batch_data.loc[current_frm, [x_header, y_header]])
                    cv2.putText(img, animal_name, (int(x), int(y)), font, font_size, clrs[animal_cnt][0], text_thickness)
            if show_animal_bounding_boxes:
                for animal_cnt, animal_data in enumerate(bp_dict.values()):
                    animal_headers = [val for pair in zip(animal_data["X_bps"], animal_data["Y_bps"]) for val in pair]
                    animal_cords = batch_data.loc[current_frm, animal_headers].values.reshape(-1, 2).astype(np.int32)
                    try:
                        bbox = GeometryMixin().keypoints_to_axis_aligned_bounding_box(keypoints=animal_cords.reshape(-1, len(animal_cords), 2).astype(np.int32))
                        cv2.polylines(img, [bbox], True, clrs[animal_cnt][0], thickness=text_thickness, lineType=-1)
                    except Exception:
                        # Best effort: skip this animal's box for this frame if it cannot be computed or drawn.
                        pass
            target_timer = round((1 / video_meta_data["fps"]) * clf_frm_cnt, 2)
            img = _put_text(img=img, text="BEHAVIOR TIMER:", pos=(text_x, text_spacing), font_size=font_size, font_thickness=overlay_thickness, text_bg_alpha=text_opacity)
            add_spacer = 2
            img = _put_text(img=img, text=f"{clf_name} {target_timer}s", pos=(text_x, text_spacing * add_spacer), font_size=font_size, font_thickness=overlay_thickness, text_bg_alpha=text_opacity)
            add_spacer += 1
            if conf_data is not None:
                img = _put_text(img=img, text=f"{clf_name} PROBABILITY: {round(conf_data[current_frm], 4)}", pos=(text_x, text_spacing * add_spacer), font_size=font_size, font_thickness=overlay_thickness, text_bg_alpha=text_opacity)
                add_spacer += 1
            img = _put_text(img=img, text="ENSEMBLE PREDICTION:", pos=(text_x, text_spacing * add_spacer), font_size=font_size, font_thickness=overlay_thickness, text_bg_alpha=text_opacity)
            add_spacer += 1
            if clf_data[current_frm] == 1:
                img = _put_text(img=img, text=clf_name, pos=(text_x, text_spacing * add_spacer), font_size=font_size, font_thickness=overlay_thickness, text_color=TextOptions.COLOR.value, text_bg_alpha=text_opacity)
                add_spacer += 1
            if gantt_setting == 1:
                img = np.concatenate((img, final_gantt), axis=1)
            elif gantt_setting == 2:
                gantt_img = _create_gantt(bouts_df, clf_name, current_frm, video_meta_data["fps"], header_font_size=9, label_font_size=12)
                gantt_img = imutils.resize(gantt_img, height=video_meta_data["height"])
                img = np.concatenate((img, gantt_img), axis=1)
            img = cv2.resize(img, video_size, interpolation=cv2.INTER_LINEAR)
            writer.write(np.uint8(img))
            current_frm += 1
            frm_timer.stop_timer()
            print(f"Multi-processing video frame {current_frm} on core {batch_id}...(elapsed time: {frm_timer.elapsed_time_str}s)")
    finally:
        # Always release the capture and writer, even if rendering fails mid-batch.
        cap.release()
        writer.release()
    return batch_id
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class ValidateModelOneVideoMultiprocess(ConfigReader, PlottingMixin, TrainModelMixin):
|
|
190
|
+
"""
|
|
191
|
+
Create classifier validation video for a single input video using multiprocessing for improved performance.
|
|
192
|
+
|
|
193
|
+
This class generates validation videos that overlay classifier predictions, pose estimations, and
|
|
194
|
+
optional Gantt charts onto the original video using multiple CPU cores for faster processing.
|
|
195
|
+
Results are stored in the `project_folder/frames/output/validation` directory.
|
|
196
|
+
|
|
197
|
+
.. note::
|
|
198
|
+
This multiprocess version provides significant speed improvements over the single-core
|
|
199
|
+
:class:`simba.plotting.single_run_model_validation_video.ValidateModelOneVideo` class.
|
|
200
|
+
|
|
201
|
+
:param Union[str, os.PathLike] config_path: Path to SimBA project config file in Configparser format.
|
|
202
|
+
:param Union[str, os.PathLike] feature_path: Path to SimBA file (parquet or CSV) containing pose-estimation and feature data.
|
|
203
|
+
:param Union[str, os.PathLike] model_path: Path to pickled classifier object (.sav file).
|
|
204
|
+
:param bool show_pose: If True, overlay pose estimation keypoints on the video. Default: True.
|
|
205
|
+
:param bool show_animal_names: If True, display animal names near the first body part. Default: False.
|
|
206
|
+
:param Optional[int] font_size: Font size for text overlays. If None, automatically calculated based on video dimensions.
|
|
207
|
+
:param Optional[str] bp_palette: Optional name of the palette to use to color the animal body-parts (e.g., Pastel1). If None, ``spring`` is used.
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
:param Optional[int] circle_size: Size of pose estimation circles. If None, automatically calculated based on video dimensions.
|
|
211
|
+
:param Optional[int] text_spacing: Spacing between text lines. If None, automatically calculated.
|
|
212
|
+
:param Optional[int] text_thickness: Thickness of text overlay. If None, uses default value.
|
|
213
|
+
:param Optional[float] text_opacity: Opacity of text overlays (0.1-1.0). If None, defaults to 0.8.
|
|
214
|
+
:param float discrimination_threshold: Classification probability threshold (0.0-1.0). Default: 0.0.
|
|
215
|
+
:param int shortest_bout: Minimum classified bout length in milliseconds. Bouts shorter than this will be reclassified as absent. Default: 0.
|
|
216
|
+
:param int core_cnt: Number of CPU cores to use for processing. If -1, uses all available cores. Default: -1.
|
|
217
|
+
:param Optional[Union[None, int]] create_gantt: Gantt chart creation option:
|
|
218
|
+
|
|
219
|
+
- None: No Gantt chart
|
|
220
|
+
- 1: Static Gantt chart (final frame only, faster)
|
|
221
|
+
- 2: Dynamic Gantt chart (updated per frame)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
.. youtube:: UOLSj7DGKRo
|
|
225
|
+
:width: 640
|
|
226
|
+
:height: 480
|
|
227
|
+
:align: center
|
|
228
|
+
|
|
229
|
+
.. video:: _static/img/T1.webm
|
|
230
|
+
:width: 1000
|
|
231
|
+
:autoplay:
|
|
232
|
+
:loop:
|
|
233
|
+
|
|
234
|
+
:example:
|
|
235
|
+
>>> # Create multiprocess validation video with dynamic Gantt chart
|
|
236
|
+
>>> validator = ValidateModelOneVideoMultiprocess(
|
|
237
|
+
... config_path=r'/path/to/project_config.ini',
|
|
238
|
+
... feature_path=r'/path/to/features.csv',
|
|
239
|
+
... model_path=r'/path/to/classifier.sav',
|
|
240
|
+
... show_pose=True,
|
|
241
|
+
... show_animal_names=True,
|
|
242
|
+
... discrimination_threshold=0.6,
|
|
243
|
+
... shortest_bout=500,
|
|
244
|
+
... core_cnt=4,
|
|
245
|
+
... create_gantt=2
|
|
246
|
+
... )
|
|
247
|
+
>>> validator.run()
|
|
248
|
+
"""
|
|
249
|
+
|
|
250
|
+
def __init__(self,
             config_path: Union[str, os.PathLike],
             feature_path: Union[str, os.PathLike],
             model_path: Union[str, os.PathLike],
             show_pose: bool = True,
             show_animal_names: bool = False,
             show_animal_bounding_boxes: bool = False,
             show_clf_confidence: bool = False,
             font_size: Optional[int] = None,
             circle_size: Optional[int] = None,
             text_spacing: Optional[int] = None,
             text_thickness: Optional[int] = None,
             text_opacity: Optional[float] = None,
             bp_palette: Optional[str] = None,
             discrimination_threshold: float = 0.0,
             shortest_bout: int = 0,
             core_cnt: int = -1,
             create_gantt: Optional[int] = None):
    """
    Validate the arguments, load the classifier and feature data, and prepare output paths.

    Argument validation is delegated to SimBA's ``check_*`` helpers, which raise on invalid input
    (missing files, non-boolean flags, out-of-range numbers, unknown palette names).

    :param bool show_animal_bounding_boxes: If True, draw an axis-aligned bounding box around each animal.
    :param bool show_clf_confidence: If True, display the per-frame classification probability.
    """
    ConfigReader.__init__(self, config_path=config_path)
    PlottingMixin.__init__(self)
    TrainModelMixin.__init__(self)
    cls_name = self.__class__.__name__
    for file_path in (config_path, feature_path, model_path):
        check_file_exist_and_readable(file_path=file_path)
    for flag_name, flag_value in (("show_pose", show_pose),
                                  ("show_animal_names", show_animal_names),
                                  ("show_animal_bounding_boxes", show_animal_bounding_boxes),
                                  ("show_clf_confidence", show_clf_confidence)):
        check_valid_boolean(value=[flag_value], source=f'{cls_name} {flag_name}', raise_error=True)
    check_int(name=f"{cls_name} core_cnt", value=core_cnt, min_value=-1, unaccepted_vals=[0])
    if font_size is not None: check_int(name=f'{cls_name} font_size', value=font_size)
    if circle_size is not None: check_int(name=f'{cls_name} circle_size', value=circle_size)
    if text_spacing is not None: check_int(name=f'{cls_name} text_spacing', value=text_spacing)
    # Opacity is used as an alpha blending weight, so it is bounded to the documented 0.1-1.0 range.
    if text_opacity is not None: check_float(name=f'{cls_name} text_opacity', value=text_opacity, min_value=0.1, max_value=1.0)
    if text_thickness is not None: check_float(name=f'{cls_name} text_thickness', value=text_thickness, min_value=0.1)
    check_float(name=f"{cls_name} discrimination_threshold", value=discrimination_threshold, min_value=0, max_value=1.0)
    check_int(name=f"{cls_name} shortest_bout", value=shortest_bout, min_value=0)
    if create_gantt is not None:
        check_int(name=f"{cls_name} create gantt", value=create_gantt, max_value=2, min_value=1)
    os.makedirs(self.single_validation_video_save_dir, exist_ok=True)
    if bp_palette is not None:
        check_str(name=f'{cls_name} bp_palette', value=bp_palette, options=(Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value))
        # One color per body-part of a single animal (plus one spare); every animal gets the same palette.
        palette_increments = int(len(self.body_parts_lst) / self.animal_cnt) + 1
        self.bp_palette = [create_color_palette(pallete_name=bp_palette, increments=palette_increments, as_int=True) for _ in range(self.animal_cnt)]
    else:
        self.bp_palette = deepcopy(self.clr_lst)
    _, self.feature_filename, ext = get_fn_ext(feature_path)
    self.video_path = self.find_video_of_file(self.video_dir, self.feature_filename)
    self.video_meta_data = get_video_meta_data(video_path=self.video_path, fps_as_int=False)
    model_filename = os.path.basename(model_path)
    # Strip only a trailing ".sav"; str.replace would also remove ".sav" occurring inside the name.
    self.clf_name = model_filename[:-len(".sav")] if model_filename.endswith(".sav") else model_filename
    self.feature_file_path = feature_path
    self.vid_output_path = os.path.join(self.single_validation_video_save_dir, f"{self.feature_filename} {self.clf_name}.mp4")
    # NOTE(review): the extension is always ".csv", although run() writes this file with self.file_type — confirm intended.
    self.clf_data_save_path = os.path.join(self.clf_data_validation_dir, f"{self.feature_filename }.csv")
    self.show_pose, self.show_animal_names = show_pose, show_animal_names
    self.font_size, self.circle_size, self.text_spacing, self.show_clf_confidence = font_size, circle_size, text_spacing, show_clf_confidence
    self.text_opacity, self.text_thickness, self.show_animal_bounding_boxes = text_opacity, text_thickness, show_animal_bounding_boxes
    # NOTE(review): the model file is unpickled; only load .sav files from trusted sources.
    self.clf = read_pickle(data_path=model_path, verbose=True)
    self.data_df = read_df(feature_path, self.file_type)
    self.x_df = self.drop_bp_cords(df=self.data_df)
    self.discrimination_threshold, self.shortest_bout, self.create_gantt = float(discrimination_threshold), shortest_bout, create_gantt
    check_video_and_data_frm_count_align(video=self.video_path, data=self.data_df, name=self.feature_filename, raise_error=False)
    available_cores = find_core_cnt()[0]
    self.core_cnt = available_cores if core_cnt == -1 or core_cnt > available_cores else core_cnt
    self.temp_dir = os.path.join(self.single_validation_video_save_dir, "temp")
    self.video_save_path = os.path.join(self.single_validation_video_save_dir, f"{self.feature_filename}.mp4")
    create_directory(paths=self.temp_dir, overwrite=True)
    if platform.system() == "Darwin":
        multiprocessing.set_start_method("spawn", force=True)
|
|
319
|
+
|
|
320
|
+
def _get_styles(self):
    """
    Resolve the drawing styles used when rendering the validation video.

    Any style the user supplied (``self.text_thickness``, ``self.font_size``, ``self.text_spacing``,
    ``self.circle_size``, ``self.text_opacity``) takes precedence. Styles left as ``None`` fall back
    to a value computed from the video resolution, or to a fixed default.

    Sets ``self.video_text_thickness``, ``self.video_circle_size``, ``self.video_font_size``,
    ``self.video_space_size`` and ``self.video_text_opacity``.
    """
    if self.text_thickness is not None:
        self.video_text_thickness = int(max(self.text_thickness, 1))
    else:
        self.video_text_thickness = TextOptions.TEXT_THICKNESS.value

    frm_width, frm_height = self.video_meta_data["width"], self.video_meta_data["height"]

    # Size the font to the widest label that may be printed, so that it fits a third of the frame width.
    candidate_labels = ['TIMERS:', 'ENSEMBLE PREDICTION:'] + self.clf_names
    widest_label = str(max(candidate_labels, key=len))
    auto_font_size, _unused_scale, auto_spacing = self.get_optimal_font_scales(text=widest_label,
                                                                               accepted_px_width=int(frm_width / 3),
                                                                               accepted_px_height=int(frm_height / 10),
                                                                               text_thickness=self.video_text_thickness)
    auto_circle_size = self.get_optimal_circle_size(frame_size=(frm_width, frm_height), circle_frame_ratio=100)

    if self.circle_size is not None:
        self.video_circle_size = int(self.circle_size)
    else:
        self.video_circle_size = auto_circle_size

    if self.font_size is not None:
        self.video_font_size = self.font_size
    else:
        self.video_font_size = auto_font_size

    if self.text_spacing is not None:
        self.video_space_size = int(max(self.text_spacing, 1))
    else:
        self.video_space_size = auto_spacing

    if self.text_opacity is not None:
        self.video_text_opacity = float(self.text_opacity)
    else:
        self.video_text_opacity = 0.8
|
|
329
|
+
|
|
330
|
+
def run(self):
    """
    Classify the feature file, save the predictions and render the validation video.

    Steps:
        1. Compute classification probabilities with ``self.clf``. Binarize them at
           ``self.discrimination_threshold`` (strictly greater-than gives 1).
        2. If ``self.shortest_bout > 1``, pass the classifications through ``plug_holes_shortest_bout``.
        3. Write the data, including the probability and classification columns, to ``self.clf_data_save_path``.
        4. If ``self.create_gantt`` is not None, build a gantt chart image resized to the video height.
        5. Split the frames into up to ``self.core_cnt`` contiguous batches. Render each batch in a
           worker process (``_validation_video_mp``) into ``self.temp_dir``. Concatenate the batch
           videos into ``self.video_save_path``.
    """
    self.prob_col_name = f"Probability_{self.clf_name}"
    self.data_df[self.prob_col_name] = self.clf_predict_proba(clf=self.clf, x_df=self.x_df, model_name=self.clf_name, data_path=self.feature_file_path)
    self.data_df[self.clf_name] = np.where(self.data_df[self.prob_col_name] > self.discrimination_threshold, 1, 0)
    if self.shortest_bout > 1:
        self.data_df = plug_holes_shortest_bout(data_df=self.data_df, clf_name=self.clf_name, fps=self.video_meta_data['fps'], shortest_bout=self.shortest_bout)
    _ = write_df(df=self.data_df, file_type=self.file_type, save_path=self.clf_data_save_path)
    print(f"Predictions created for video {self.feature_filename} (creating video, follow progress in OS terminal)...")
    self._get_styles()
    if self.create_gantt is not None:
        self.bouts_df = self.get_bouts_for_gantt(data_df=self.data_df, clf_name=self.clf_name, fps=self.video_meta_data['fps'])
        self.final_gantt_img = self.create_gantt_img(self.bouts_df, self.clf_name, len(self.data_df), self.video_meta_data['fps'], f"Behavior gantt chart (entire session, length (s): {self.video_meta_data['video_length_s']}, frames: {self.video_meta_data['frame_count']})", header_font_size=9, label_font_size=12)
        self.final_gantt_img = self.resize_gantt(self.final_gantt_img, self.video_meta_data["height"])
    else:
        self.bouts_df, self.final_gantt_img = None, None

    # Never render more rows than the video has frames.
    self.data_df = self.data_df.head(min(len(self.data_df), self.video_meta_data["frame_count"]))
    # Read confidences after truncation so they line up with the classification array sent to the workers.
    conf_data = self.data_df[self.prob_col_name].values if self.show_clf_confidence else None

    # Split row positions, not the DataFrame itself, into contiguous near-equal batches.
    # np.array_split on a DataFrame calls DataFrame.swapaxes, which pandas deprecated in 2.1.
    # .iloc keeps the original index labels, so each chunk matches the old split exactly.
    # Empty batches (more cores than frames) are dropped so no worker gets an empty chunk.
    row_batches = [rows for rows in np.array_split(np.arange(len(self.data_df)), self.core_cnt) if len(rows) > 0]
    data = [(batch_idx, self.data_df.iloc[rows]) for batch_idx, rows in enumerate(row_batches)]

    with multiprocessing.Pool(self.core_cnt, maxtasksperchild=self.maxtasksperchild) as pool:
        constants = functools.partial(_validation_video_mp,
                                      bp_dict=self.animal_bp_dict,
                                      video_save_dir=self.temp_dir,
                                      text_thickness=self.video_text_thickness,
                                      text_opacity=self.video_text_opacity,
                                      font_size=self.video_font_size,
                                      text_spacing=self.video_space_size,
                                      circle_size=self.video_circle_size,
                                      video_path=self.video_path,
                                      show_pose=self.show_pose,
                                      show_animal_names=self.show_animal_names,
                                      show_animal_bounding_boxes=self.show_animal_bounding_boxes,
                                      gantt_setting=self.create_gantt,
                                      final_gantt=self.final_gantt_img,
                                      clf_data=self.data_df[self.clf_name].values,
                                      clrs=self.bp_palette,
                                      clf_name=self.clf_name,
                                      bouts_df=self.bouts_df,
                                      conf_data=conf_data)

        for result in pool.imap(constants, data, chunksize=self.multiprocess_chunksize):
            print(f"Image batch {result} complete, Video {self.feature_filename}...")
        terminate_cpu_pool(pool=pool, force=False)
    concatenate_videos_in_folder(in_folder=self.temp_dir, save_path=self.video_save_path)
    self.timer.stop_timer()
    stdout_success(msg=f"Video complete, saved at {self.video_save_path}", elapsed_time=self.timer.elapsed_time_str)
|
|
378
|
+
|
|
379
|
+
#
|
|
380
|
+
# if __name__ == "__main__":
|
|
381
|
+
# test = ValidateModelOneVideoMultiprocess(config_path=r"D:\troubleshooting\mitra\project_folder\project_config.ini",
|
|
382
|
+
# feature_path=r"D:\troubleshooting\mitra\project_folder\csv\features_extracted\592_MA147_CNO1_0515.csv",
|
|
383
|
+
# model_path=r"C:\troubleshooting\mitra\models\validations\rearing_5\rearing.sav",
|
|
384
|
+
# create_gantt=2,
|
|
385
|
+
# show_pose=True,
|
|
386
|
+
# show_animal_names=True,
|
|
387
|
+
# core_cnt=13,
|
|
388
|
+
# show_clf_confidence=True,
|
|
389
|
+
# discrimination_threshold=0.20)
|
|
390
|
+
# test.run()
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
#
|
|
394
|
+
# if __name__ == "__main__":
|
|
395
|
+
# test = ValidateModelOneVideoMultiprocess(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
|
|
396
|
+
# feature_file_path=r"C:\troubleshooting\mitra\project_folder\csv\features_extracted\844_MA131_gq_CNO_0624.csv",
|
|
397
|
+
# model_path=r"C:\troubleshooting\mitra\models\validations\lay-on-belly_1\lay-on-belly.sav",
|
|
398
|
+
# discrimination_threshold=0.35,
|
|
399
|
+
# shortest_bout=200,
|
|
400
|
+
# cores=-1,
|
|
401
|
+
# settings={'pose': True, 'animal_names': False, 'styles': None},
|
|
402
|
+
# create_gantt=2)
|
|
403
|
+
# test.run()
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
# test = ValidateModelOneVideoMultiprocess(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/project_config.ini',
|
|
410
|
+
# feature_file_path='/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/csv/features_extracted/SI_DAY3_308_CD1_PRESENT.csv',
|
|
411
|
+
# model_path='/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/models/generated_models/Running.sav',
|
|
412
|
+
# discrimination_threshold=0.6,
|
|
413
|
+
# shortest_bout=50,
|
|
414
|
+
# cores=6,
|
|
415
|
+
# settings={'pose': True, 'animal_names': True, 'styles': None},
|
|
416
|
+
# create_gantt=None)
|
|
417
|
+
# test.run()
|
|
418
|
+
|
|
419
|
+
# test = ValidateModelOneVideoMultiprocess(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
|
|
420
|
+
# feature_file_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/csv/features_extracted/Together_1.csv',
|
|
421
|
+
# model_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/models/generated_models/Attack.sav',
|
|
422
|
+
# discrimination_threshold=0.6,
|
|
423
|
+
# shortest_bout=50,
|
|
424
|
+
# cores=6,
|
|
425
|
+
# settings={'pose': True, 'animal_names': True, 'styles': None},
|
|
426
|
+
# create_gantt=None)
|
|
427
|
+
# test.run()
|