simba-uw-tf-dev 4.6.2__py3-none-any.whl → 4.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- simba/assets/.recent_projects.txt +1 -0
- simba/assets/lookups/tooptips.json +6 -1
- simba/assets/lookups/yolo_schematics/yolo_mitra.csv +9 -0
- simba/data_processors/agg_clf_counter_mp.py +52 -53
- simba/data_processors/blob_location_computer.py +1 -1
- simba/data_processors/circling_detector.py +30 -13
- simba/data_processors/cuda/geometry.py +45 -27
- simba/data_processors/cuda/image.py +1648 -1598
- simba/data_processors/cuda/statistics.py +72 -26
- simba/data_processors/cuda/timeseries.py +1 -1
- simba/data_processors/cue_light_analyzer.py +5 -9
- simba/data_processors/egocentric_aligner.py +25 -7
- simba/data_processors/freezing_detector.py +55 -47
- simba/data_processors/kleinberg_calculator.py +61 -29
- simba/feature_extractors/feature_subsets.py +14 -7
- simba/feature_extractors/mitra_feature_extractor.py +2 -2
- simba/feature_extractors/straub_tail_analyzer.py +4 -6
- simba/labelling/standard_labeller.py +1 -1
- simba/mixins/config_reader.py +5 -2
- simba/mixins/geometry_mixin.py +22 -36
- simba/mixins/image_mixin.py +24 -28
- simba/mixins/plotting_mixin.py +28 -10
- simba/mixins/statistics_mixin.py +48 -11
- simba/mixins/timeseries_features_mixin.py +1 -1
- simba/mixins/train_model_mixin.py +68 -33
- simba/model/inference_batch.py +2 -2
- simba/model/yolo_seg_inference.py +3 -3
- simba/outlier_tools/skip_outlier_correction.py +1 -1
- simba/plotting/ROI_feature_visualizer_mp.py +3 -5
- simba/plotting/clf_validator_mp.py +4 -5
- simba/plotting/cue_light_visualizer.py +6 -7
- simba/plotting/directing_animals_visualizer_mp.py +2 -3
- simba/plotting/distance_plotter_mp.py +378 -378
- simba/plotting/gantt_creator.py +29 -10
- simba/plotting/gantt_creator_mp.py +96 -33
- simba/plotting/geometry_plotter.py +270 -272
- simba/plotting/heat_mapper_clf_mp.py +4 -6
- simba/plotting/heat_mapper_location_mp.py +2 -2
- simba/plotting/light_dark_box_plotter.py +2 -2
- simba/plotting/path_plotter_mp.py +26 -29
- simba/plotting/plot_clf_results_mp.py +455 -454
- simba/plotting/pose_plotter_mp.py +28 -29
- simba/plotting/probability_plot_creator_mp.py +288 -288
- simba/plotting/roi_plotter_mp.py +31 -31
- simba/plotting/single_run_model_validation_video_mp.py +427 -427
- simba/plotting/spontaneous_alternation_plotter.py +2 -3
- simba/plotting/yolo_pose_track_visualizer.py +32 -27
- simba/plotting/yolo_pose_visualizer.py +35 -36
- simba/plotting/yolo_seg_visualizer.py +2 -3
- simba/pose_importers/simba_blob_importer.py +3 -3
- simba/roi_tools/roi_aggregate_stats_mp.py +5 -4
- simba/roi_tools/roi_clf_calculator_mp.py +4 -4
- simba/sandbox/analyze_runtimes.py +30 -0
- simba/sandbox/cuda/egocentric_rotator.py +374 -374
- simba/sandbox/get_cpu_pool.py +5 -0
- simba/sandbox/proboscis_to_tip.py +28 -0
- simba/sandbox/test_directionality.py +47 -0
- simba/sandbox/test_nonstatic_directionality.py +27 -0
- simba/sandbox/test_pycharm_cuda.py +51 -0
- simba/sandbox/test_simba_install.py +41 -0
- simba/sandbox/test_static_directionality.py +26 -0
- simba/sandbox/test_static_directionality_2d.py +26 -0
- simba/sandbox/verify_env.py +42 -0
- simba/third_party_label_appenders/transform/coco_keypoints_to_yolo.py +3 -3
- simba/third_party_label_appenders/transform/coco_keypoints_to_yolo_bbox.py +2 -2
- simba/third_party_label_appenders/transform/simba_to_yolo.py +8 -5
- simba/ui/pop_ups/clf_plot_pop_up.py +2 -2
- simba/ui/pop_ups/fsttc_pop_up.py +27 -25
- simba/ui/pop_ups/gantt_pop_up.py +31 -6
- simba/ui/pop_ups/kleinberg_pop_up.py +39 -40
- simba/ui/pop_ups/run_machine_models_popup.py +21 -21
- simba/ui/pop_ups/simba_to_yolo_keypoints_popup.py +2 -2
- simba/ui/pop_ups/video_processing_pop_up.py +37 -29
- simba/ui/pop_ups/yolo_inference_popup.py +1 -1
- simba/ui/pop_ups/yolo_pose_train_popup.py +1 -1
- simba/ui/tkinter_functions.py +3 -0
- simba/utils/custom_feature_extractor.py +1 -1
- simba/utils/data.py +90 -14
- simba/utils/enums.py +1 -0
- simba/utils/errors.py +441 -440
- simba/utils/lookups.py +1203 -1203
- simba/utils/printing.py +124 -124
- simba/utils/read_write.py +3769 -3721
- simba/utils/yolo.py +10 -1
- simba/video_processors/blob_tracking_executor.py +2 -2
- simba/video_processors/clahe_ui.py +1 -1
- simba/video_processors/egocentric_video_rotator.py +44 -41
- simba/video_processors/multi_cropper.py +1 -1
- simba/video_processors/video_processing.py +75 -33
- simba/video_processors/videos_to_frames.py +43 -33
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/METADATA +4 -3
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/RECORD +96 -85
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/LICENSE +0 -0
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/WHEEL +0 -0
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/entry_points.txt +0 -0
- {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/top_level.txt +0 -0
simba/mixins/plotting_mixin.py
CHANGED
|
@@ -39,7 +39,7 @@ from simba.utils.data import create_color_palette, detect_bouts
|
|
|
39
39
|
from simba.utils.enums import Formats, Keys, Options, Paths
|
|
40
40
|
from simba.utils.errors import InvalidInputError
|
|
41
41
|
from simba.utils.lookups import (get_categorical_palettes, get_color_dict,
|
|
42
|
-
get_named_colors)
|
|
42
|
+
get_fonts, get_named_colors)
|
|
43
43
|
from simba.utils.printing import SimbaTimer, stdout_success
|
|
44
44
|
from simba.utils.read_write import (find_files_of_filetypes_in_directory,
|
|
45
45
|
get_fn_ext, get_video_meta_data, read_df,
|
|
@@ -342,16 +342,28 @@ class PlottingMixin(object):
|
|
|
342
342
|
height: int = 480,
|
|
343
343
|
font_size: int = 8,
|
|
344
344
|
font_rotation: int = 45,
|
|
345
|
+
font: Optional[str] = None,
|
|
345
346
|
save_path: Optional[str] = None,
|
|
347
|
+
edge_clr: Optional[str] = 'black',
|
|
346
348
|
hhmmss: bool = False) -> Union[None, np.ndarray]:
|
|
347
349
|
|
|
348
350
|
video_timer = SimbaTimer(start=True)
|
|
349
351
|
colour_tuple_x = list(np.arange(3.5, 203.5, 5))
|
|
352
|
+
original_font_family = copy(plt.rcParams['font.family']) if isinstance(plt.rcParams['font.family'], list) else plt.rcParams['font.family']
|
|
353
|
+
|
|
354
|
+
if font is not None:
|
|
355
|
+
available_fonts = get_fonts()
|
|
356
|
+
if font in available_fonts:
|
|
357
|
+
matplotlib.font_manager._get_font.cache_clear()
|
|
358
|
+
plt.rcParams['font.family'] = font
|
|
359
|
+
else:
|
|
360
|
+
matplotlib.font_manager._get_font.cache_clear()
|
|
361
|
+
plt.rcParams['font.family'] = [font, 'sans-serif']
|
|
362
|
+
|
|
350
363
|
fig, ax = plt.subplots()
|
|
351
|
-
fig.patch.set_facecolor('#fafafa')
|
|
352
|
-
ax.set_facecolor('#ffffff')
|
|
353
364
|
fig.patch.set_facecolor('white')
|
|
354
|
-
|
|
365
|
+
|
|
366
|
+
plt.title(video_name, fontsize=font_size + 6, pad=25, fontweight='bold')
|
|
355
367
|
ax.spines['top'].set_visible(False)
|
|
356
368
|
ax.spines['right'].set_visible(False)
|
|
357
369
|
ax.spines['left'].set_color('#666666')
|
|
@@ -367,7 +379,7 @@ class PlottingMixin(object):
|
|
|
367
379
|
if event[0] == x:
|
|
368
380
|
ix = clf_names.index(x)
|
|
369
381
|
data_event = event[1][["Start_time", "Bout_time"]]
|
|
370
|
-
ax.broken_barh(data_event.values, (colour_tuple_x[ix], 3), facecolors=palette[ix])
|
|
382
|
+
ax.broken_barh(data_event.values, (colour_tuple_x[ix], 3), facecolors=palette[ix], edgecolor=edge_clr, linewidth=0.5, alpha=0.9)
|
|
371
383
|
|
|
372
384
|
x_ticks_seconds = np.round(np.linspace(0, x_length / fps, 6))
|
|
373
385
|
x_ticks_locs = x_ticks_seconds
|
|
@@ -375,18 +387,19 @@ class PlottingMixin(object):
|
|
|
375
387
|
if hhmmss:
|
|
376
388
|
x_lbls = [seconds_to_timestamp(sec) for sec in x_ticks_seconds]
|
|
377
389
|
else:
|
|
378
|
-
x_lbls = x_ticks_seconds
|
|
390
|
+
x_lbls = [int(x) for x in x_ticks_seconds]
|
|
379
391
|
|
|
380
|
-
#x_ticks_locs = x_lbls = np.round(np.linspace(0, x_length / fps, 6))
|
|
381
392
|
ax.set_xticks(x_ticks_locs)
|
|
382
393
|
ax.set_xticklabels(x_lbls)
|
|
383
394
|
ax.set_ylim(0, colour_tuple_x[len(clf_names)])
|
|
384
395
|
ax.set_yticks(np.arange(5, 5 * len(clf_names) + 1, 5))
|
|
385
|
-
ax.set_yticklabels(clf_names, rotation=font_rotation)
|
|
396
|
+
ax.set_yticklabels(clf_names, rotation=font_rotation, ha='right', va='center')
|
|
386
397
|
ax.tick_params(axis="both", labelsize=font_size)
|
|
387
398
|
plt.xlabel(x_label, fontsize=font_size + 3)
|
|
388
|
-
ax.
|
|
389
|
-
|
|
399
|
+
ax.grid(True, axis='both', linewidth=1.0, color='gray', alpha=0.2, linestyle='--', which='major')
|
|
400
|
+
|
|
401
|
+
plt.subplots_adjust(left=0.1, right=0.95, top=0.9, bottom=0.15)
|
|
402
|
+
plt.tight_layout()
|
|
390
403
|
buffer_ = io.BytesIO()
|
|
391
404
|
plt.savefig(buffer_, format="png")
|
|
392
405
|
buffer_.seek(0)
|
|
@@ -397,6 +410,11 @@ class PlottingMixin(object):
|
|
|
397
410
|
frame = np.uint8(open_cv_image)
|
|
398
411
|
buffer_.close()
|
|
399
412
|
plt.close('all')
|
|
413
|
+
|
|
414
|
+
if font is not None:
|
|
415
|
+
plt.rcParams['font.family'] = original_font_family
|
|
416
|
+
matplotlib.font_manager._get_font.cache_clear()
|
|
417
|
+
|
|
400
418
|
if save_path is not None:
|
|
401
419
|
cv2.imwrite(save_path, frame)
|
|
402
420
|
video_timer.stop_timer()
|
simba/mixins/statistics_mixin.py
CHANGED
|
@@ -8,6 +8,8 @@ from sklearn.metrics import (adjusted_mutual_info_score, adjusted_rand_score,
|
|
|
8
8
|
fowlkes_mallows_score)
|
|
9
9
|
from sklearn.neighbors import LocalOutlierFactor
|
|
10
10
|
|
|
11
|
+
from simba.utils.printing import SimbaTimer
|
|
12
|
+
|
|
11
13
|
try:
|
|
12
14
|
from typing import Literal
|
|
13
15
|
except:
|
|
@@ -538,7 +540,8 @@ class Statistics(FeatureExtractionMixin):
|
|
|
538
540
|
sample_1: np.ndarray,
|
|
539
541
|
sample_2: np.ndarray,
|
|
540
542
|
fill_value: Optional[int] = 1,
|
|
541
|
-
bucket_method: Literal["fd", "doane", "auto", "scott", "stone", "rice", "sturges", "sqrt"] = "auto"
|
|
543
|
+
bucket_method: Literal["fd", "doane", "auto", "scott", "stone", "rice", "sturges", "sqrt"] = "auto",
|
|
544
|
+
verbose: bool = False) -> float:
|
|
542
545
|
|
|
543
546
|
r"""
|
|
544
547
|
Compute Kullback-Leibler divergence between two distributions.
|
|
@@ -562,6 +565,7 @@ class Statistics(FeatureExtractionMixin):
|
|
|
562
565
|
:returns: Kullback-Leibler divergence between ``sample_1`` and ``sample_2``
|
|
563
566
|
:rtype: float
|
|
564
567
|
"""
|
|
568
|
+
timer = SimbaTimer(start=True)
|
|
565
569
|
check_valid_array(data=sample_1, source=Statistics.kullback_leibler_divergence.__name__, accepted_ndims=(1,), accepted_dtypes=Formats.NUMERIC_DTYPES.value)
|
|
566
570
|
check_valid_array(data=sample_2, source=Statistics.kullback_leibler_divergence.__name__, accepted_ndims=(1,), accepted_dtypes=Formats.NUMERIC_DTYPES.value)
|
|
567
571
|
check_str(name=f"{self.__class__.__name__} bucket_method", value=bucket_method, options=Options.BUCKET_METHODS.value)
|
|
@@ -573,7 +577,10 @@ class Statistics(FeatureExtractionMixin):
|
|
|
573
577
|
sample_1_hist[sample_1_hist == 0] = fill_value
|
|
574
578
|
sample_2_hist[sample_2_hist == 0] = fill_value
|
|
575
579
|
sample_1_hist, sample_2_hist = sample_1_hist / np.sum(sample_1_hist), sample_2_hist / np.sum(sample_2_hist)
|
|
576
|
-
|
|
580
|
+
kl = stats.entropy(pk=sample_1_hist, qk=sample_2_hist)
|
|
581
|
+
timer.stop_timer()
|
|
582
|
+
if verbose: print(f'KL divergence performed on {sample_1.shape[0]} observations (elapsed time: {timer.elapsed_time_str}s)')
|
|
583
|
+
return kl
|
|
577
584
|
|
|
578
585
|
def rolling_kullback_leibler_divergence(
|
|
579
586
|
self,
|
|
@@ -3271,10 +3278,34 @@ class Statistics(FeatureExtractionMixin):
|
|
|
3271
3278
|
|
|
3272
3279
|
Youden's J statistic is a measure of the overall performance of a binary classification test, taking into account both sensitivity (true positive rate) and specificity (true negative rate).
|
|
3273
3280
|
|
|
3274
|
-
|
|
3275
|
-
|
|
3276
|
-
|
|
3281
|
+
The Youden's J statistic is calculated as:
|
|
3282
|
+
|
|
3283
|
+
.. math::
|
|
3284
|
+
J = \text{sensitivity} + \text{specificity} - 1
|
|
3285
|
+
|
|
3286
|
+
where:
|
|
3287
|
+
|
|
3288
|
+
- :math:`\text{sensitivity} = \frac{TP}{TP + FN}` is the true positive rate
|
|
3289
|
+
- :math:`\text{specificity} = \frac{TN}{TN + FP}` is the true negative rate
|
|
3290
|
+
|
|
3291
|
+
The statistic ranges from -1 to 1, where:
|
|
3292
|
+
- :math:`J = 1` indicates perfect classification
|
|
3293
|
+
- :math:`J = 0` indicates the test performs no better than random
|
|
3294
|
+
- :math:`J < 0` indicates the test performs worse than random
|
|
3295
|
+
|
|
3296
|
+
:param sample_1: The first binary array (ground truth or reference).
|
|
3297
|
+
:param sample_2: The second binary array (predictions or test results).
|
|
3298
|
+
:return: Youden's J statistic. Returns NaN if either sensitivity or specificity cannot be calculated (division by zero).
|
|
3277
3299
|
:rtype: float
|
|
3300
|
+
|
|
3301
|
+
:references:
|
|
3302
|
+
.. [1] Youden, W. J. (1950). Index for rating diagnostic tests. Cancer, 3(1), 32-35.
|
|
3303
|
+
https://acsjournals.onlinelibrary.wiley.com/doi/abs/10.1002/1097-0142(1950)3:1%3C32::AID-CNCR2820030106%3E3.0.CO;2-3
|
|
3304
|
+
|
|
3305
|
+
:example:
|
|
3306
|
+
>>> y_true = np.array([1, 1, 0, 0, 1, 0, 1, 1, 0, 0])
|
|
3307
|
+
>>> y_pred = np.array([1, 1, 0, 1, 1, 0, 1, 0, 0, 0])
|
|
3308
|
+
>>> j = Statistics.youden_j(sample_1=y_true, sample_2=y_pred)
|
|
3278
3309
|
"""
|
|
3279
3310
|
|
|
3280
3311
|
check_valid_array(data=sample_1, source=f'{Statistics.youden_j.__name__} sample_1', accepted_ndims=(1,), accepted_values=[0, 1])
|
|
@@ -4250,7 +4281,7 @@ class Statistics(FeatureExtractionMixin):
|
|
|
4250
4281
|
return separation_trace / compactness
|
|
4251
4282
|
|
|
4252
4283
|
@staticmethod
|
|
4253
|
-
def i_index(x: np.ndarray, y: np.ndarray):
|
|
4284
|
+
def i_index(x: np.ndarray, y: np.ndarray, verbose: bool = False) -> float:
|
|
4254
4285
|
|
|
4255
4286
|
"""
|
|
4256
4287
|
Calculate the I-Index for evaluating clustering quality.
|
|
@@ -4275,9 +4306,10 @@ class Statistics(FeatureExtractionMixin):
|
|
|
4275
4306
|
>>> X, y = make_blobs(n_samples=5000, centers=20, n_features=3, random_state=0, cluster_std=0.1)
|
|
4276
4307
|
>>> Statistics.i_index(x=X, y=y)
|
|
4277
4308
|
"""
|
|
4309
|
+
timer = SimbaTimer(start=True)
|
|
4278
4310
|
check_valid_array(data=x, accepted_ndims=(2,), accepted_dtypes=Formats.NUMERIC_DTYPES.value)
|
|
4279
4311
|
check_valid_array(data=y, accepted_ndims=(1,), accepted_dtypes=Formats.NUMERIC_DTYPES.value, accepted_axis_0_shape=[x.shape[0], ])
|
|
4280
|
-
|
|
4312
|
+
cluster_cnt = get_unique_values_in_iterable(data=y, name=Statistics.i_index.__name__, min=2)
|
|
4281
4313
|
unique_y = np.unique(y)
|
|
4282
4314
|
n_y = unique_y.shape[0]
|
|
4283
4315
|
global_centroid = np.mean(x, axis=0)
|
|
@@ -4289,7 +4321,12 @@ class Statistics(FeatureExtractionMixin):
|
|
|
4289
4321
|
cluster_centroid = np.mean(cluster_obs, axis=0)
|
|
4290
4322
|
swc += np.sum(np.linalg.norm(cluster_obs - cluster_centroid, axis=1) ** 2)
|
|
4291
4323
|
|
|
4292
|
-
|
|
4324
|
+
|
|
4325
|
+
i_index = np.float32(sst / (n_y * swc))
|
|
4326
|
+
timer.stop_timer()
|
|
4327
|
+
if verbose: print(f'I-index for {x.shape[0]} observations in {cluster_cnt} clusters computed (elapsed time: {timer.elapsed_time_str}s)')
|
|
4328
|
+
return i_index
|
|
4329
|
+
|
|
4293
4330
|
|
|
4294
4331
|
@staticmethod
|
|
4295
4332
|
def sd_index(x: np.ndarray, y: np.ndarray) -> float:
|
|
@@ -5291,7 +5328,7 @@ class Statistics(FeatureExtractionMixin):
|
|
|
5291
5328
|
"""
|
|
5292
5329
|
Compute one-way ANOVAs comparing each column (axis 1) on two arrays.
|
|
5293
5330
|
|
|
5294
|
-
..
|
|
5331
|
+
.. note::
|
|
5295
5332
|
Use for computing and presenting aggregate statistics. Not suitable for featurization.
|
|
5296
5333
|
|
|
5297
5334
|
.. seealso::
|
|
@@ -5329,7 +5366,7 @@ class Statistics(FeatureExtractionMixin):
|
|
|
5329
5366
|
"""
|
|
5330
5367
|
Compute Kruskal-Wallis comparing each column (axis 1) on two arrays.
|
|
5331
5368
|
|
|
5332
|
-
..
|
|
5369
|
+
.. note::
|
|
5333
5370
|
Use for computing and presenting aggregate statistics. Not suitable for featurization.
|
|
5334
5371
|
|
|
5335
5372
|
.. seealso::
|
|
@@ -5366,7 +5403,7 @@ class Statistics(FeatureExtractionMixin):
|
|
|
5366
5403
|
"""
|
|
5367
5404
|
Compute pairwise grouped Tukey-HSD tests.
|
|
5368
5405
|
|
|
5369
|
-
..
|
|
5406
|
+
.. note::
|
|
5370
5407
|
Use for computing and presenting aggregate statistics. Not suitable for featurization.
|
|
5371
5408
|
|
|
5372
5409
|
:param np.ndarray data: 2D array with observations rowwise (axis 0) and features columnwise (axis 1)
|
|
@@ -2198,7 +2198,7 @@ class TimeseriesFeatureMixin(object):
|
|
|
2198
2198
|
:example:
|
|
2199
2199
|
>>> x = np.random.randint(0, 100, (400, 2))
|
|
2200
2200
|
>>> results_1 = TimeseriesFeatureMixin.sliding_entropy_of_directional_changes(x=x, bins=16, window_size=5.0, sample_rate=30)
|
|
2201
|
-
>>> x = pd.read_csv(r"C
|
|
2201
|
+
>>> x = pd.read_csv(r"C:/troubleshooting/two_black_animals_14bp/project_folder/csv/input_csv/Together_1.csv")[['Ear_left_1_x', 'Ear_left_1_y']].values
|
|
2202
2202
|
>>> results_2 = TimeseriesFeatureMixin.sliding_entropy_of_directional_changes(x=x, bins=16, window_size=5.0, sample_rate=30)
|
|
2203
2203
|
"""
|
|
2204
2204
|
|
|
@@ -67,7 +67,7 @@ from simba.utils.checks import (check_all_dfs_in_list_has_same_cols,
|
|
|
67
67
|
check_valid_boolean, check_valid_dataframe,
|
|
68
68
|
check_valid_lst, is_lxc_container)
|
|
69
69
|
from simba.utils.data import (detect_bouts, detect_bouts_multiclass,
|
|
70
|
-
get_library_version)
|
|
70
|
+
get_library_version, terminate_cpu_pool)
|
|
71
71
|
from simba.utils.enums import (OS, ConfigKey, Defaults, Dtypes, Formats, Links,
|
|
72
72
|
Methods, MLParamKeys, Options)
|
|
73
73
|
from simba.utils.errors import (ClassifierInferenceError, CorruptedFileError,
|
|
@@ -77,10 +77,10 @@ from simba.utils.errors import (ClassifierInferenceError, CorruptedFileError,
|
|
|
77
77
|
SamplingError, SimBAModuleNotFoundError)
|
|
78
78
|
from simba.utils.lookups import get_meta_data_file_headers, get_table
|
|
79
79
|
from simba.utils.printing import SimbaTimer, stdout_success
|
|
80
|
-
from simba.utils.read_write import (find_core_cnt,
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
str_2_bool)
|
|
80
|
+
from simba.utils.read_write import (find_core_cnt, get_current_time,
|
|
81
|
+
get_fn_ext, get_memory_usage_of_df,
|
|
82
|
+
get_pkg_version, read_config_entry,
|
|
83
|
+
read_df, read_meta_file, str_2_bool)
|
|
84
84
|
from simba.utils.warnings import (GPUToolsWarning, MissingUserInputWarning,
|
|
85
85
|
MultiProcessingFailedWarning,
|
|
86
86
|
NoModuleWarning, NotEnoughDataWarning,
|
|
@@ -1070,10 +1070,7 @@ class TrainModelMixin(object):
|
|
|
1070
1070
|
MissingUserInputWarning(msg=f'Skipping {str(config.get("SML settings", "target_name_" + str(n + 1)))} classifier analysis: missing information (e.g., no discrimination threshold and/or minimum bout set in the project_config.ini',source=self.__class__.__name__)
|
|
1071
1071
|
|
|
1072
1072
|
if len(model_dict.keys()) == 0:
|
|
1073
|
-
raise NoDataError(
|
|
1074
|
-
msg=f"There are no models with accurate data specified in the RUN MODELS menu. Specify the model information to SimBA RUN MODELS menu to use them to analyze videos",
|
|
1075
|
-
source=self.get_model_info.__name__,
|
|
1076
|
-
)
|
|
1073
|
+
raise NoDataError(msg=f"There are no models with accurate data specified in the RUN MODELS menu. Specify the model information to SimBA RUN MODELS menu to use them to analyze videos. PLease check the model paths, thresholds, and minimum bout lengths.", source=self.get_model_info.__name__)
|
|
1077
1074
|
else:
|
|
1078
1075
|
return model_dict
|
|
1079
1076
|
|
|
@@ -1383,18 +1380,39 @@ class TrainModelMixin(object):
|
|
|
1383
1380
|
x_df: Union[pd.DataFrame, np.ndarray],
|
|
1384
1381
|
multiclass: bool = False,
|
|
1385
1382
|
model_name: Optional[str] = None,
|
|
1386
|
-
data_path: Optional[Union[str, os.PathLike]] = None
|
|
1383
|
+
data_path: Optional[Union[str, os.PathLike]] = None,
|
|
1384
|
+
verbose: bool = False) -> np.ndarray:
|
|
1387
1385
|
|
|
1388
1386
|
"""
|
|
1389
|
-
|
|
1390
|
-
|
|
1391
|
-
|
|
1392
|
-
|
|
1393
|
-
|
|
1394
|
-
|
|
1395
|
-
|
|
1387
|
+
Helper to predict class probabilities using a fitted random forest classifier.
|
|
1388
|
+
|
|
1389
|
+
Computes prediction probabilities for binary or multiclass classification using either
|
|
1390
|
+
scikit-learn or cuML RandomForestClassifier. For binary classifiers, returns the
|
|
1391
|
+
probability of the positive class (class 1). For multiclass classifiers, returns
|
|
1392
|
+
probabilities for all classes.
|
|
1393
|
+
|
|
1394
|
+
.. csv-table::
|
|
1395
|
+
:header: EXPECTED RUNTIMES
|
|
1396
|
+
:file: ../../docs/tables/clf_predict_proba.csv
|
|
1397
|
+
:widths: 10, 45, 45
|
|
1398
|
+
:align: center
|
|
1399
|
+
:header-rows: 1
|
|
1400
|
+
|
|
1401
|
+
.. seealso::
|
|
1402
|
+
To fit a classifier, see :func:`simba.mixins.train_model_mixin.TrainModelMixin.clf_fit`
|
|
1403
|
+
To define a classifier, see :func:`simba.mixins.train_model_mixin.TrainModelMixin.clf_define`
|
|
1404
|
+
|
|
1405
|
+
:param Union[RandomForestClassifier, cuRF] clf: Fitted random forest classifier object from sklearn or cuml.
|
|
1406
|
+
:param Union[pd.DataFrame, np.ndarray] x_df: Features for data to predict. DataFrame or array of shape (n_samples, n_features).
|
|
1407
|
+
:param bool multiclass: If True, the classifier predicts more than 2 classes. If False, binary classifier (default: False).
|
|
1408
|
+
:param Optional[str] model_name: Name of the model for error messages and logging. Default: None.
|
|
1409
|
+
:param Optional[Union[str, os.PathLike]] data_path: Path to the data file being processed, used in error messages. Default: None.
|
|
1410
|
+
:param bool verbose: If True, print inference progress and timing information. Default: False.
|
|
1411
|
+
:return np.ndarray: Prediction probabilities. For binary classifiers: 1D array of shape (n_samples,) with probability of positive class. For multiclass: 2D array of shape (n_samples, n_classes) with probabilities for each class.
|
|
1412
|
+
|
|
1396
1413
|
"""
|
|
1397
1414
|
|
|
1415
|
+
timer = SimbaTimer(start=True)
|
|
1398
1416
|
if hasattr(clf, "n_features_"):
|
|
1399
1417
|
clf_n_features = clf.n_features_
|
|
1400
1418
|
elif hasattr(clf, "n_features_in_"):
|
|
@@ -1420,6 +1438,8 @@ class TrainModelMixin(object):
|
|
|
1420
1438
|
p_vals = clf.predict_proba(x_df)
|
|
1421
1439
|
if multiclass and (clf.n_classes_ != p_vals.shape[1]):
|
|
1422
1440
|
raise ClassifierInferenceError(msg=f"The classifier {model_name} (data path: {data_path}) is a multiclassifier expected to create {clf.n_classes_} behavior probabilities. However, it produced probabilities for {p_vals.shape[1]} behaviors. See The SimBA GitHub FAQ page or Gitter for more information and suggested fixes.", source=self.__class__.__name__)
|
|
1441
|
+
timer.stop_timer()
|
|
1442
|
+
if verbose: print(f'Inference for model {model_name} over {x_df.shape[0]} observations complete ({timer.elapsed_time_str}s).')
|
|
1423
1443
|
if not multiclass:
|
|
1424
1444
|
if isinstance(p_vals, pd.DataFrame):
|
|
1425
1445
|
return p_vals[1].values
|
|
@@ -1447,7 +1467,7 @@ class TrainModelMixin(object):
|
|
|
1447
1467
|
bootstrap: Optional[bool] = True,
|
|
1448
1468
|
verbose: Optional[int] = 1,
|
|
1449
1469
|
class_weight: Optional[dict] = None,
|
|
1450
|
-
cuda: Optional[bool] = False) -> RandomForestClassifier:
|
|
1470
|
+
cuda: Optional[bool] = False) -> Union[RandomForestClassifier, cuRF]:
|
|
1451
1471
|
|
|
1452
1472
|
if not cuda:
|
|
1453
1473
|
# NOTE: LOKY ISSUES ON WINDOWS WITH SCIKIT IF THE CORE COUNT EXCEEDS 61.
|
|
@@ -1482,20 +1502,32 @@ class TrainModelMixin(object):
|
|
|
1482
1502
|
clf: Union[RandomForestClassifier, cuRF],
|
|
1483
1503
|
x_df: pd.DataFrame,
|
|
1484
1504
|
y_df: pd.DataFrame,
|
|
1485
|
-
) -> RandomForestClassifier:
|
|
1505
|
+
verbose: bool = False) -> Union[RandomForestClassifier, cuRF]:
|
|
1486
1506
|
|
|
1487
1507
|
"""
|
|
1488
|
-
Helper to fit clf model
|
|
1508
|
+
Helper to fit clf model.
|
|
1489
1509
|
|
|
1490
|
-
|
|
1510
|
+
.. csv-table::
|
|
1511
|
+
:header: EXPECTED RUNTIMES
|
|
1512
|
+
:file: ../../docs/tables/clf_fit.csv
|
|
1513
|
+
:widths: 20, 20, 30, 30
|
|
1514
|
+
:align: center
|
|
1515
|
+
:header-rows: 1
|
|
1516
|
+
|
|
1517
|
+
.. seealso::
|
|
1518
|
+
To define a cuml/sklearn object, see :func:`simba.mixins.train_model_mixin.TrainModelMixin.clf_define`
|
|
1519
|
+
|
|
1520
|
+
:param clf: Un-fitted random forest classifier object, either from sklearn or cuml.
|
|
1491
1521
|
:param pd.DataFrame x_df: Pandas dataframe with features.
|
|
1492
1522
|
:param pd.DataFrame y_df: Pandas dataframe/Series with target
|
|
1493
1523
|
:return: Fitted random forest classifier object
|
|
1494
1524
|
:rtype: RandomForestClassifier
|
|
1495
1525
|
"""
|
|
1496
1526
|
|
|
1527
|
+
timer = SimbaTimer(start=True)
|
|
1497
1528
|
nan_features = x_df[~x_df.applymap(np.isreal).all(1)]
|
|
1498
1529
|
nan_target = y_df.loc[pd.to_numeric(y_df).isna()]
|
|
1530
|
+
using_cuda = True if CUML in str(clf.__class__.__module__).lower() else False
|
|
1499
1531
|
if len(nan_features) > 0:
|
|
1500
1532
|
raise FaultyTrainingSetError(
|
|
1501
1533
|
msg=f"{len(nan_features)} frame(s) in your project_folder/csv/targets_inserted directory contains FEATURES with non-numerical values",
|
|
@@ -1504,9 +1536,16 @@ class TrainModelMixin(object):
|
|
|
1504
1536
|
raise FaultyTrainingSetError(
|
|
1505
1537
|
msg=f"{len(nan_target)} frame(s) in your project_folder/csv/targets_inserted directory contains ANNOTATIONS with non-numerical values",
|
|
1506
1538
|
source=self.__class__.__name__)
|
|
1539
|
+
if verbose: print(f'[{get_current_time()}] Fitting classifier for {len(x_df)} observations (cuda: {"True" if using_cuda else "False"})...')
|
|
1540
|
+
if using_cuda:
|
|
1541
|
+
x_data = x_df.values if isinstance(x_df, pd.DataFrame) else x_df
|
|
1542
|
+
y_data = y_df.values if isinstance(y_df, (pd.DataFrame, pd.Series)) else y_df
|
|
1543
|
+
clf.fit(x_data, y_data)
|
|
1544
|
+
else:
|
|
1545
|
+
clf.fit(x_df, y_df)
|
|
1507
1546
|
|
|
1508
|
-
|
|
1509
|
-
|
|
1547
|
+
timer.stop_timer()
|
|
1548
|
+
if verbose: print(f'[{get_current_time()}] Classifier fitted in {timer.elapsed_time_str}s.')
|
|
1510
1549
|
return clf
|
|
1511
1550
|
|
|
1512
1551
|
@staticmethod
|
|
@@ -1563,9 +1602,7 @@ class TrainModelMixin(object):
|
|
|
1563
1602
|
:rtype: Tuple[pd.DataFrame, List[int]]
|
|
1564
1603
|
|
|
1565
1604
|
"""
|
|
1566
|
-
if (platform.system() == "Darwin") and (
|
|
1567
|
-
multiprocessing.get_start_method() != "spawn"
|
|
1568
|
-
):
|
|
1605
|
+
if (platform.system() == "Darwin") and (multiprocessing.get_start_method() != "spawn"):
|
|
1569
1606
|
multiprocessing.set_start_method("spawn", force=True)
|
|
1570
1607
|
cpu_cnt, _ = find_core_cnt()
|
|
1571
1608
|
df_lst, frame_numbers_lst = [], []
|
|
@@ -1592,9 +1629,7 @@ class TrainModelMixin(object):
|
|
|
1592
1629
|
:, ~df_concat.columns.str.contains("^Unnamed")
|
|
1593
1630
|
].astype(np.float32)
|
|
1594
1631
|
memory_size = get_memory_usage_of_df(df=df_concat)
|
|
1595
|
-
print(
|
|
1596
|
-
f'Dataset size: {memory_size["megabytes"]}MB / {memory_size["gigabytes"]}GB'
|
|
1597
|
-
)
|
|
1632
|
+
print(f'Dataset size: {memory_size["megabytes"]}MB / {memory_size["gigabytes"]}GB')
|
|
1598
1633
|
|
|
1599
1634
|
return df_concat, frame_numbers_lst
|
|
1600
1635
|
|
|
@@ -1859,7 +1894,7 @@ class TrainModelMixin(object):
|
|
|
1859
1894
|
shap_raw.append(shap_data[result[1]][1].drop(clf_name, axis=1))
|
|
1860
1895
|
if verbose: print(f"Completed SHAP care batch (Batch {result[1] + 1}/{len(shap_data)}).")
|
|
1861
1896
|
|
|
1862
|
-
pool
|
|
1897
|
+
terminate_cpu_pool(pool=pool, force=False)
|
|
1863
1898
|
shap_df = pd.DataFrame(data=np.row_stack(shap_results), columns=list(x_names) + ["Expected_value", "Sum", "Prediction_probability", clf_name])
|
|
1864
1899
|
raw_df = pd.DataFrame(data=np.row_stack(shap_raw), columns=list(x_names))
|
|
1865
1900
|
out_shap_path, out_raw_path, img_save_path, df_save_paths, summary_dfs, img = None, None, None, None, None, None
|
|
@@ -2607,9 +2642,9 @@ class TrainModelMixin(object):
|
|
|
2607
2642
|
:param bool plot: If True, create SHAP aggregation and plots.
|
|
2608
2643
|
|
|
2609
2644
|
:example:
|
|
2610
|
-
>>> CONFIG_PATH = r"C
|
|
2611
|
-
>>> RF_PATH = r"C
|
|
2612
|
-
>>> DATA_PATH = r"C
|
|
2645
|
+
>>> CONFIG_PATH = r"C:/troubleshooting/mitra/project_folder/project_config.ini"
|
|
2646
|
+
>>> RF_PATH = r"C:/troubleshooting/mitra/models/validations/straub_tail_5_new/straub_tail_5.sav"
|
|
2647
|
+
>>> DATA_PATH = r"C:/troubleshooting/mitra/project_folder/csv/targets_inserted/new_straub/appended/501_MA142_Gi_CNO_0514.csv"
|
|
2613
2648
|
>>> config = ConfigReader(config_path=CONFIG_PATH)
|
|
2614
2649
|
>>> df = read_df(file_path=DATA_PATH, file_type='csv')
|
|
2615
2650
|
>>> y = df['straub_tail']
|
simba/model/inference_batch.py
CHANGED
|
@@ -45,7 +45,7 @@ class InferenceBatch(TrainModelMixin, ConfigReader):
|
|
|
45
45
|
>>> inferencer.run()
|
|
46
46
|
|
|
47
47
|
:example II:
|
|
48
|
-
>>> inferencer = InferenceBatch(config_path=r"D
|
|
48
|
+
>>> inferencer = InferenceBatch(config_path=r"D:/troubleshooting/mitra/project_folder/project_config.ini", features_dir=r"D:/troubleshooting/mitra/project_folder/videos/bg_removed/rotated/tail_features/APPENDED")
|
|
49
49
|
>>> inferencer.run()
|
|
50
50
|
"""
|
|
51
51
|
|
|
@@ -101,7 +101,7 @@ class InferenceBatch(TrainModelMixin, ConfigReader):
|
|
|
101
101
|
video_timer.stop_timer()
|
|
102
102
|
print(f"Predictions created for {file_name} (frame count: {len(in_df)}, elapsed time: {video_timer.elapsed_time_str}) ...")
|
|
103
103
|
self.timer.stop_timer()
|
|
104
|
-
stdout_success(msg=f"Machine predictions complete. Files saved in {self.save_dir} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
|
|
104
|
+
stdout_success(msg=f"Machine predictions complete for {len(self.feature_file_paths)} file(s). Files saved in {self.save_dir} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
|
|
105
105
|
|
|
106
106
|
if __name__ == "__main__" and not hasattr(sys, 'ps1'):
|
|
107
107
|
parser = argparse.ArgumentParser(description="Perform classifications according to rules defined in SImAB project_config.ini.")
|
|
@@ -55,9 +55,9 @@ class YOLOSegmentationInference():
|
|
|
55
55
|
To visualize the segmentation results, see :func:`simba.plotting.yolo_seg_visualizer.YOLOSegmentationVisualizer`
|
|
56
56
|
|
|
57
57
|
:example:
|
|
58
|
-
>>> weights_path = r"D
|
|
59
|
-
>>> video_path = r"D
|
|
60
|
-
>>> save_dir=r"D
|
|
58
|
+
>>> weights_path = r"D:/platea/yolo_071525/mdl/train3/weights/best.pt"
|
|
59
|
+
>>> video_path = r"D:/platea/platea_videos/videos/clipped/10B_Mouse_5-choice_MustTouchTrainingNEWFINAL_a7.mp4"
|
|
60
|
+
>>> save_dir = r"D:/platea/platea_videos/videos/yolo_results"
|
|
61
61
|
>>> runner = YOLOSegmentationInference(weights_path=weights_path, video_path=video_path, save_dir=save_dir, verbose=True, device=0, format=None, stream=True, batch_size=10, imgsz=320, interpolate=True, threshold=0.8, retina_msk=True)
|
|
62
62
|
>>> runner.run()
|
|
63
63
|
|
|
@@ -47,5 +47,5 @@ class OutlierCorrectionSkipper(ConfigReader):
|
|
|
47
47
|
self.timer.stop_timer()
|
|
48
48
|
stdout_success(msg=f"Skipped outlier correction for {len(self.input_csv_paths)} files", elapsed_time=self.timer.elapsed_time_str)
|
|
49
49
|
|
|
50
|
-
# test = OutlierCorrectionSkipper(config_path=
|
|
50
|
+
# test = OutlierCorrectionSkipper(config_path=r"E:\troubleshooting\mitra_emergence\project_folder\project_config.ini")
|
|
51
51
|
# test.run()
|
|
@@ -24,10 +24,9 @@ from simba.utils.checks import (check_file_exist_and_readable,
|
|
|
24
24
|
check_if_valid_rgb_tuple, check_int, check_str,
|
|
25
25
|
check_valid_boolean, check_valid_lst,
|
|
26
26
|
check_video_and_data_frm_count_align)
|
|
27
|
-
from simba.utils.data import slice_roi_dict_for_video
|
|
27
|
+
from simba.utils.data import slice_roi_dict_for_video, terminate_cpu_pool
|
|
28
28
|
from simba.utils.enums import Formats, TextOptions
|
|
29
|
-
from simba.utils.errors import
|
|
30
|
-
ROICoordinatesNotFoundError)
|
|
29
|
+
from simba.utils.errors import BodypartColumnNotFoundError, NoFilesFoundError
|
|
31
30
|
from simba.utils.printing import stdout_success
|
|
32
31
|
from simba.utils.read_write import (concatenate_videos_in_folder,
|
|
33
32
|
find_core_cnt, get_fn_ext,
|
|
@@ -315,8 +314,7 @@ class ROIfeatureVisualizerMultiprocess(ConfigReader):
|
|
|
315
314
|
print(f"Joining {self.video_name} multi-processed video...")
|
|
316
315
|
concatenate_videos_in_folder(in_folder=self.save_temp_dir, save_path=self.save_path, video_format="mp4", remove_splits=True, gpu=self.gpu)
|
|
317
316
|
self.timer.stop_timer()
|
|
318
|
-
pool
|
|
319
|
-
pool.join()
|
|
317
|
+
terminate_cpu_pool(pool=pool, force=False)
|
|
320
318
|
stdout_success(msg=f"Video {self.video_name} complete. Video saved in directory {self.roi_features_save_dir}.", elapsed_time=self.timer.elapsed_time_str)
|
|
321
319
|
|
|
322
320
|
|
|
@@ -14,9 +14,9 @@ from simba.mixins.plotting_mixin import PlottingMixin
|
|
|
14
14
|
from simba.utils.checks import (check_float, check_if_valid_rgb_tuple,
|
|
15
15
|
check_int, check_str, check_that_column_exist,
|
|
16
16
|
check_valid_lst)
|
|
17
|
-
from simba.utils.data import detect_bouts
|
|
18
|
-
from simba.utils.enums import Formats,
|
|
19
|
-
from simba.utils.errors import
|
|
17
|
+
from simba.utils.data import detect_bouts, terminate_cpu_pool
|
|
18
|
+
from simba.utils.enums import Formats, TextOptions
|
|
19
|
+
from simba.utils.errors import NoSpecifiedOutputError
|
|
20
20
|
from simba.utils.printing import SimbaTimer, log_event, stdout_success
|
|
21
21
|
from simba.utils.read_write import (concatenate_videos_in_folder,
|
|
22
22
|
find_core_cnt, get_fn_ext,
|
|
@@ -218,8 +218,7 @@ class ClassifierValidationClipsMultiprocess(ConfigReader):
|
|
|
218
218
|
for cnt, result in enumerate(
|
|
219
219
|
pool.imap(constants, clip_data, chunksize=self.multiprocess_chunksize)):
|
|
220
220
|
print(f"Bout {cnt+1} complete...")
|
|
221
|
-
pool
|
|
222
|
-
pool.join()
|
|
221
|
+
terminate_cpu_pool(pool=pool, force=False)
|
|
223
222
|
|
|
224
223
|
if self.concat_video:
|
|
225
224
|
print(f"Joining {file_name} multiprocessed video...")
|
|
@@ -1,10 +1,9 @@
|
|
|
1
1
|
__author__ = "Simon Nilsson; sronilsson@gmail.com"
|
|
2
2
|
|
|
3
3
|
import functools
|
|
4
|
-
import itertools
|
|
5
4
|
import multiprocessing
|
|
6
5
|
import os
|
|
7
|
-
from typing import List,
|
|
6
|
+
from typing import List, Union
|
|
8
7
|
|
|
9
8
|
import cv2
|
|
10
9
|
import numpy as np
|
|
@@ -16,14 +15,15 @@ from simba.utils.checks import (check_file_exist_and_readable, check_int,
|
|
|
16
15
|
check_valid_boolean, check_valid_dataframe,
|
|
17
16
|
check_valid_lst)
|
|
18
17
|
from simba.utils.data import (create_color_palettes, detect_bouts,
|
|
19
|
-
slice_roi_dict_from_attribute
|
|
20
|
-
|
|
18
|
+
slice_roi_dict_from_attribute,
|
|
19
|
+
terminate_cpu_pool)
|
|
20
|
+
from simba.utils.enums import Defaults, Formats, TextOptions
|
|
21
21
|
from simba.utils.errors import NoROIDataError, NoSpecifiedOutputError
|
|
22
22
|
from simba.utils.printing import stdout_success
|
|
23
23
|
from simba.utils.read_write import (concatenate_videos_in_folder,
|
|
24
24
|
create_directory, find_core_cnt,
|
|
25
25
|
get_fn_ext, get_video_meta_data, read_df,
|
|
26
|
-
read_frm_of_video
|
|
26
|
+
read_frm_of_video)
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
def _plot_cue_light_data(frm_idxs: list,
|
|
@@ -197,8 +197,7 @@ class CueLightVisualizer(ConfigReader):
|
|
|
197
197
|
for cnt, result in enumerate(pool.imap(constants, self.frame_chunks, chunksize=self.multiprocess_chunksize)):
|
|
198
198
|
if self.verbose:
|
|
199
199
|
print(f'Batch {int(result+1/self.core_cnt)} complete...')
|
|
200
|
-
pool
|
|
201
|
-
pool.join()
|
|
200
|
+
terminate_cpu_pool(pool=pool, force=False)
|
|
202
201
|
self.timer.stop_timer()
|
|
203
202
|
if self.video_setting:
|
|
204
203
|
print(f"Joining {self.video_name} multiprocessed video...")
|
|
@@ -19,7 +19,7 @@ from simba.utils.checks import (check_file_exist_and_readable,
|
|
|
19
19
|
check_if_valid_rgb_tuple, check_int,
|
|
20
20
|
check_valid_lst,
|
|
21
21
|
check_video_and_data_frm_count_align)
|
|
22
|
-
from simba.utils.data import create_color_palettes
|
|
22
|
+
from simba.utils.data import create_color_palettes, terminate_cpu_pool
|
|
23
23
|
from simba.utils.enums import OS, Formats, Keys, TextOptions
|
|
24
24
|
from simba.utils.errors import (AnimalNumberError, InvalidInputError,
|
|
25
25
|
NoFilesFoundError)
|
|
@@ -226,8 +226,7 @@ class DirectingOtherAnimalsVisualizerMultiprocess(ConfigReader, PlottingMixin):
|
|
|
226
226
|
print(f"Joining {self.video_name} multi-processed video...")
|
|
227
227
|
concatenate_videos_in_folder(in_folder=self.save_temp_path, save_path=self.save_path, video_format="mp4", remove_splits=True)
|
|
228
228
|
self.timer.stop_timer()
|
|
229
|
-
pool
|
|
230
|
-
pool.join()
|
|
229
|
+
terminate_cpu_pool(pool=pool, force=False)
|
|
231
230
|
stdout_success(msg=f"Video {self.video_name} complete. Video saved in {self.directing_animals_video_output_path} directory", elapsed_time=self.timer.elapsed_time_str)
|
|
232
231
|
|
|
233
232
|
|