simba-uw-tf-dev 4.6.2__py3-none-any.whl → 4.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. simba/assets/.recent_projects.txt +1 -0
  2. simba/assets/lookups/tooptips.json +6 -1
  3. simba/assets/lookups/yolo_schematics/yolo_mitra.csv +9 -0
  4. simba/data_processors/agg_clf_counter_mp.py +52 -53
  5. simba/data_processors/blob_location_computer.py +1 -1
  6. simba/data_processors/circling_detector.py +30 -13
  7. simba/data_processors/cuda/geometry.py +45 -27
  8. simba/data_processors/cuda/image.py +1648 -1598
  9. simba/data_processors/cuda/statistics.py +72 -26
  10. simba/data_processors/cuda/timeseries.py +1 -1
  11. simba/data_processors/cue_light_analyzer.py +5 -9
  12. simba/data_processors/egocentric_aligner.py +25 -7
  13. simba/data_processors/freezing_detector.py +55 -47
  14. simba/data_processors/kleinberg_calculator.py +61 -29
  15. simba/feature_extractors/feature_subsets.py +14 -7
  16. simba/feature_extractors/mitra_feature_extractor.py +2 -2
  17. simba/feature_extractors/straub_tail_analyzer.py +4 -6
  18. simba/labelling/standard_labeller.py +1 -1
  19. simba/mixins/config_reader.py +5 -2
  20. simba/mixins/geometry_mixin.py +22 -36
  21. simba/mixins/image_mixin.py +24 -28
  22. simba/mixins/plotting_mixin.py +28 -10
  23. simba/mixins/statistics_mixin.py +48 -11
  24. simba/mixins/timeseries_features_mixin.py +1 -1
  25. simba/mixins/train_model_mixin.py +68 -33
  26. simba/model/inference_batch.py +2 -2
  27. simba/model/yolo_seg_inference.py +3 -3
  28. simba/outlier_tools/skip_outlier_correction.py +1 -1
  29. simba/plotting/ROI_feature_visualizer_mp.py +3 -5
  30. simba/plotting/clf_validator_mp.py +4 -5
  31. simba/plotting/cue_light_visualizer.py +6 -7
  32. simba/plotting/directing_animals_visualizer_mp.py +2 -3
  33. simba/plotting/distance_plotter_mp.py +378 -378
  34. simba/plotting/gantt_creator.py +29 -10
  35. simba/plotting/gantt_creator_mp.py +96 -33
  36. simba/plotting/geometry_plotter.py +270 -272
  37. simba/plotting/heat_mapper_clf_mp.py +4 -6
  38. simba/plotting/heat_mapper_location_mp.py +2 -2
  39. simba/plotting/light_dark_box_plotter.py +2 -2
  40. simba/plotting/path_plotter_mp.py +26 -29
  41. simba/plotting/plot_clf_results_mp.py +455 -454
  42. simba/plotting/pose_plotter_mp.py +28 -29
  43. simba/plotting/probability_plot_creator_mp.py +288 -288
  44. simba/plotting/roi_plotter_mp.py +31 -31
  45. simba/plotting/single_run_model_validation_video_mp.py +427 -427
  46. simba/plotting/spontaneous_alternation_plotter.py +2 -3
  47. simba/plotting/yolo_pose_track_visualizer.py +32 -27
  48. simba/plotting/yolo_pose_visualizer.py +35 -36
  49. simba/plotting/yolo_seg_visualizer.py +2 -3
  50. simba/pose_importers/simba_blob_importer.py +3 -3
  51. simba/roi_tools/roi_aggregate_stats_mp.py +5 -4
  52. simba/roi_tools/roi_clf_calculator_mp.py +4 -4
  53. simba/sandbox/analyze_runtimes.py +30 -0
  54. simba/sandbox/cuda/egocentric_rotator.py +374 -374
  55. simba/sandbox/get_cpu_pool.py +5 -0
  56. simba/sandbox/proboscis_to_tip.py +28 -0
  57. simba/sandbox/test_directionality.py +47 -0
  58. simba/sandbox/test_nonstatic_directionality.py +27 -0
  59. simba/sandbox/test_pycharm_cuda.py +51 -0
  60. simba/sandbox/test_simba_install.py +41 -0
  61. simba/sandbox/test_static_directionality.py +26 -0
  62. simba/sandbox/test_static_directionality_2d.py +26 -0
  63. simba/sandbox/verify_env.py +42 -0
  64. simba/third_party_label_appenders/transform/coco_keypoints_to_yolo.py +3 -3
  65. simba/third_party_label_appenders/transform/coco_keypoints_to_yolo_bbox.py +2 -2
  66. simba/third_party_label_appenders/transform/simba_to_yolo.py +8 -5
  67. simba/ui/pop_ups/clf_plot_pop_up.py +2 -2
  68. simba/ui/pop_ups/fsttc_pop_up.py +27 -25
  69. simba/ui/pop_ups/gantt_pop_up.py +31 -6
  70. simba/ui/pop_ups/kleinberg_pop_up.py +39 -40
  71. simba/ui/pop_ups/run_machine_models_popup.py +21 -21
  72. simba/ui/pop_ups/simba_to_yolo_keypoints_popup.py +2 -2
  73. simba/ui/pop_ups/video_processing_pop_up.py +37 -29
  74. simba/ui/pop_ups/yolo_inference_popup.py +1 -1
  75. simba/ui/pop_ups/yolo_pose_train_popup.py +1 -1
  76. simba/ui/tkinter_functions.py +3 -0
  77. simba/utils/custom_feature_extractor.py +1 -1
  78. simba/utils/data.py +90 -14
  79. simba/utils/enums.py +1 -0
  80. simba/utils/errors.py +441 -440
  81. simba/utils/lookups.py +1203 -1203
  82. simba/utils/printing.py +124 -124
  83. simba/utils/read_write.py +3769 -3721
  84. simba/utils/yolo.py +10 -1
  85. simba/video_processors/blob_tracking_executor.py +2 -2
  86. simba/video_processors/clahe_ui.py +1 -1
  87. simba/video_processors/egocentric_video_rotator.py +44 -41
  88. simba/video_processors/multi_cropper.py +1 -1
  89. simba/video_processors/video_processing.py +75 -33
  90. simba/video_processors/videos_to_frames.py +43 -33
  91. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/METADATA +4 -3
  92. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/RECORD +96 -85
  93. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/LICENSE +0 -0
  94. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/WHEEL +0 -0
  95. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/entry_points.txt +0 -0
  96. {simba_uw_tf_dev-4.6.2.dist-info → simba_uw_tf_dev-4.7.2.dist-info}/top_level.txt +0 -0
@@ -39,7 +39,7 @@ from simba.utils.data import create_color_palette, detect_bouts
39
39
  from simba.utils.enums import Formats, Keys, Options, Paths
40
40
  from simba.utils.errors import InvalidInputError
41
41
  from simba.utils.lookups import (get_categorical_palettes, get_color_dict,
42
- get_named_colors)
42
+ get_fonts, get_named_colors)
43
43
  from simba.utils.printing import SimbaTimer, stdout_success
44
44
  from simba.utils.read_write import (find_files_of_filetypes_in_directory,
45
45
  get_fn_ext, get_video_meta_data, read_df,
@@ -342,16 +342,28 @@ class PlottingMixin(object):
342
342
  height: int = 480,
343
343
  font_size: int = 8,
344
344
  font_rotation: int = 45,
345
+ font: Optional[str] = None,
345
346
  save_path: Optional[str] = None,
347
+ edge_clr: Optional[str] = 'black',
346
348
  hhmmss: bool = False) -> Union[None, np.ndarray]:
347
349
 
348
350
  video_timer = SimbaTimer(start=True)
349
351
  colour_tuple_x = list(np.arange(3.5, 203.5, 5))
352
+ original_font_family = copy(plt.rcParams['font.family']) if isinstance(plt.rcParams['font.family'], list) else plt.rcParams['font.family']
353
+
354
+ if font is not None:
355
+ available_fonts = get_fonts()
356
+ if font in available_fonts:
357
+ matplotlib.font_manager._get_font.cache_clear()
358
+ plt.rcParams['font.family'] = font
359
+ else:
360
+ matplotlib.font_manager._get_font.cache_clear()
361
+ plt.rcParams['font.family'] = [font, 'sans-serif']
362
+
350
363
  fig, ax = plt.subplots()
351
- fig.patch.set_facecolor('#fafafa')
352
- ax.set_facecolor('#ffffff')
353
364
  fig.patch.set_facecolor('white')
354
- plt.title(video_name, fontsize=font_size + 6, pad=15, fontweight='bold')
365
+
366
+ plt.title(video_name, fontsize=font_size + 6, pad=25, fontweight='bold')
355
367
  ax.spines['top'].set_visible(False)
356
368
  ax.spines['right'].set_visible(False)
357
369
  ax.spines['left'].set_color('#666666')
@@ -367,7 +379,7 @@ class PlottingMixin(object):
367
379
  if event[0] == x:
368
380
  ix = clf_names.index(x)
369
381
  data_event = event[1][["Start_time", "Bout_time"]]
370
- ax.broken_barh(data_event.values, (colour_tuple_x[ix], 3), facecolors=palette[ix])
382
+ ax.broken_barh(data_event.values, (colour_tuple_x[ix], 3), facecolors=palette[ix], edgecolor=edge_clr, linewidth=0.5, alpha=0.9)
371
383
 
372
384
  x_ticks_seconds = np.round(np.linspace(0, x_length / fps, 6))
373
385
  x_ticks_locs = x_ticks_seconds
@@ -375,18 +387,19 @@ class PlottingMixin(object):
375
387
  if hhmmss:
376
388
  x_lbls = [seconds_to_timestamp(sec) for sec in x_ticks_seconds]
377
389
  else:
378
- x_lbls = x_ticks_seconds
390
+ x_lbls = [int(x) for x in x_ticks_seconds]
379
391
 
380
- #x_ticks_locs = x_lbls = np.round(np.linspace(0, x_length / fps, 6))
381
392
  ax.set_xticks(x_ticks_locs)
382
393
  ax.set_xticklabels(x_lbls)
383
394
  ax.set_ylim(0, colour_tuple_x[len(clf_names)])
384
395
  ax.set_yticks(np.arange(5, 5 * len(clf_names) + 1, 5))
385
- ax.set_yticklabels(clf_names, rotation=font_rotation)
396
+ ax.set_yticklabels(clf_names, rotation=font_rotation, ha='right', va='center')
386
397
  ax.tick_params(axis="both", labelsize=font_size)
387
398
  plt.xlabel(x_label, fontsize=font_size + 3)
388
- ax.yaxis.grid(True, linewidth=1.5, color='gray', alpha=0.4, linestyle='--')
389
- #ax.yaxis.grid(True)
399
+ ax.grid(True, axis='both', linewidth=1.0, color='gray', alpha=0.2, linestyle='--', which='major')
400
+
401
+ plt.subplots_adjust(left=0.1, right=0.95, top=0.9, bottom=0.15)
402
+ plt.tight_layout()
390
403
  buffer_ = io.BytesIO()
391
404
  plt.savefig(buffer_, format="png")
392
405
  buffer_.seek(0)
@@ -397,6 +410,11 @@ class PlottingMixin(object):
397
410
  frame = np.uint8(open_cv_image)
398
411
  buffer_.close()
399
412
  plt.close('all')
413
+
414
+ if font is not None:
415
+ plt.rcParams['font.family'] = original_font_family
416
+ matplotlib.font_manager._get_font.cache_clear()
417
+
400
418
  if save_path is not None:
401
419
  cv2.imwrite(save_path, frame)
402
420
  video_timer.stop_timer()
@@ -8,6 +8,8 @@ from sklearn.metrics import (adjusted_mutual_info_score, adjusted_rand_score,
8
8
  fowlkes_mallows_score)
9
9
  from sklearn.neighbors import LocalOutlierFactor
10
10
 
11
+ from simba.utils.printing import SimbaTimer
12
+
11
13
  try:
12
14
  from typing import Literal
13
15
  except:
@@ -538,7 +540,8 @@ class Statistics(FeatureExtractionMixin):
538
540
  sample_1: np.ndarray,
539
541
  sample_2: np.ndarray,
540
542
  fill_value: Optional[int] = 1,
541
- bucket_method: Literal["fd", "doane", "auto", "scott", "stone", "rice", "sturges", "sqrt"] = "auto") -> float:
543
+ bucket_method: Literal["fd", "doane", "auto", "scott", "stone", "rice", "sturges", "sqrt"] = "auto",
544
+ verbose: bool = False) -> float:
542
545
 
543
546
  r"""
544
547
  Compute Kullback-Leibler divergence between two distributions.
@@ -562,6 +565,7 @@ class Statistics(FeatureExtractionMixin):
562
565
  :returns: Kullback-Leibler divergence between ``sample_1`` and ``sample_2``
563
566
  :rtype: float
564
567
  """
568
+ timer = SimbaTimer(start=True)
565
569
  check_valid_array(data=sample_1, source=Statistics.kullback_leibler_divergence.__name__, accepted_ndims=(1,), accepted_dtypes=Formats.NUMERIC_DTYPES.value)
566
570
  check_valid_array(data=sample_2, source=Statistics.kullback_leibler_divergence.__name__, accepted_ndims=(1,), accepted_dtypes=Formats.NUMERIC_DTYPES.value)
567
571
  check_str(name=f"{self.__class__.__name__} bucket_method", value=bucket_method, options=Options.BUCKET_METHODS.value)
@@ -573,7 +577,10 @@ class Statistics(FeatureExtractionMixin):
573
577
  sample_1_hist[sample_1_hist == 0] = fill_value
574
578
  sample_2_hist[sample_2_hist == 0] = fill_value
575
579
  sample_1_hist, sample_2_hist = sample_1_hist / np.sum(sample_1_hist), sample_2_hist / np.sum(sample_2_hist)
576
- return stats.entropy(pk=sample_1_hist, qk=sample_2_hist)
580
+ kl = stats.entropy(pk=sample_1_hist, qk=sample_2_hist)
581
+ timer.stop_timer()
582
+ if verbose: print(f'KL divergence performed on {sample_1.shape[0]} observations (elapsed time: {timer.elapsed_time_str}s)')
583
+ return kl
577
584
 
578
585
  def rolling_kullback_leibler_divergence(
579
586
  self,
@@ -3271,10 +3278,34 @@ class Statistics(FeatureExtractionMixin):
3271
3278
 
3272
3279
  Youden's J statistic is a measure of the overall performance of a binary classification test, taking into account both sensitivity (true positive rate) and specificity (true negative rate).
3273
3280
 
3274
- :param sample_1: The first binary array.
3275
- :param sample_2: The second binary array.
3276
- :return: Youden's J statistic.
3281
+ The Youden's J statistic is calculated as:
3282
+
3283
+ .. math::
3284
+ J = \text{sensitivity} + \text{specificity} - 1
3285
+
3286
+ where:
3287
+
3288
+ - :math:`\text{sensitivity} = \frac{TP}{TP + FN}` is the true positive rate
3289
+ - :math:`\text{specificity} = \frac{TN}{TN + FP}` is the true negative rate
3290
+
3291
+ The statistic ranges from -1 to 1, where:
3292
+ - :math:`J = 1` indicates perfect classification
3293
+ - :math:`J = 0` indicates the test performs no better than random
3294
+ - :math:`J < 0` indicates the test performs worse than random
3295
+
3296
+ :param sample_1: The first binary array (ground truth or reference).
3297
+ :param sample_2: The second binary array (predictions or test results).
3298
+ :return: Youden's J statistic. Returns NaN if either sensitivity or specificity cannot be calculated (division by zero).
3277
3299
  :rtype: float
3300
+
3301
+ :references:
3302
+ .. [1] Youden, W. J. (1950). Index for rating diagnostic tests. Cancer, 3(1), 32-35.
3303
+ https://acsjournals.onlinelibrary.wiley.com/doi/abs/10.1002/1097-0142(1950)3:1%3C32::AID-CNCR2820030106%3E3.0.CO;2-3
3304
+
3305
+ :example:
3306
+ >>> y_true = np.array([1, 1, 0, 0, 1, 0, 1, 1, 0, 0])
3307
+ >>> y_pred = np.array([1, 1, 0, 1, 1, 0, 1, 0, 0, 0])
3308
+ >>> j = Statistics.youden_j(sample_1=y_true, sample_2=y_pred)
3278
3309
  """
3279
3310
 
3280
3311
  check_valid_array(data=sample_1, source=f'{Statistics.youden_j.__name__} sample_1', accepted_ndims=(1,), accepted_values=[0, 1])
@@ -4250,7 +4281,7 @@ class Statistics(FeatureExtractionMixin):
4250
4281
  return separation_trace / compactness
4251
4282
 
4252
4283
  @staticmethod
4253
- def i_index(x: np.ndarray, y: np.ndarray):
4284
+ def i_index(x: np.ndarray, y: np.ndarray, verbose: bool = False) -> float:
4254
4285
 
4255
4286
  """
4256
4287
  Calculate the I-Index for evaluating clustering quality.
@@ -4275,9 +4306,10 @@ class Statistics(FeatureExtractionMixin):
4275
4306
  >>> X, y = make_blobs(n_samples=5000, centers=20, n_features=3, random_state=0, cluster_std=0.1)
4276
4307
  >>> Statistics.i_index(x=X, y=y)
4277
4308
  """
4309
+ timer = SimbaTimer(start=True)
4278
4310
  check_valid_array(data=x, accepted_ndims=(2,), accepted_dtypes=Formats.NUMERIC_DTYPES.value)
4279
4311
  check_valid_array(data=y, accepted_ndims=(1,), accepted_dtypes=Formats.NUMERIC_DTYPES.value, accepted_axis_0_shape=[x.shape[0], ])
4280
- _ = get_unique_values_in_iterable(data=y, name=Statistics.i_index.__name__, min=2)
4312
+ cluster_cnt = get_unique_values_in_iterable(data=y, name=Statistics.i_index.__name__, min=2)
4281
4313
  unique_y = np.unique(y)
4282
4314
  n_y = unique_y.shape[0]
4283
4315
  global_centroid = np.mean(x, axis=0)
@@ -4289,7 +4321,12 @@ class Statistics(FeatureExtractionMixin):
4289
4321
  cluster_centroid = np.mean(cluster_obs, axis=0)
4290
4322
  swc += np.sum(np.linalg.norm(cluster_obs - cluster_centroid, axis=1) ** 2)
4291
4323
 
4292
- return sst / (n_y * swc)
4324
+
4325
+ i_index = np.float32(sst / (n_y * swc))
4326
+ timer.stop_timer()
4327
+ if verbose: print(f'I-index for {x.shape[0]} observations in {cluster_cnt} clusters computed (elapsed time: {timer.elapsed_time_str}s)')
4328
+ return i_index
4329
+
4293
4330
 
4294
4331
  @staticmethod
4295
4332
  def sd_index(x: np.ndarray, y: np.ndarray) -> float:
@@ -5291,7 +5328,7 @@ class Statistics(FeatureExtractionMixin):
5291
5328
  """
5292
5329
  Compute one-way ANOVAs comparing each column (axis 1) on two arrays.
5293
5330
 
5294
- .. notes::
5331
+ .. note::
5295
5332
  Use for computing and presenting aggregate statistics. Not suitable for featurization.
5296
5333
 
5297
5334
  .. seealso::
@@ -5329,7 +5366,7 @@ class Statistics(FeatureExtractionMixin):
5329
5366
  """
5330
5367
  Compute Kruskal-Wallis comparing each column (axis 1) on two arrays.
5331
5368
 
5332
- .. notes::
5369
+ .. note::
5333
5370
  Use for computing and presenting aggregate statistics. Not suitable for featurization.
5334
5371
 
5335
5372
  .. seealso::
@@ -5366,7 +5403,7 @@ class Statistics(FeatureExtractionMixin):
5366
5403
  """
5367
5404
  Compute pairwise grouped Tukey-HSD tests.
5368
5405
 
5369
- .. notes::
5406
+ .. note::
5370
5407
  Use for computing and presenting aggregate statistics. Not suitable for featurization.
5371
5408
 
5372
5409
  :param np.ndarray data: 2D array with observations rowwise (axis 0) and features columnwise (axis 1)
@@ -2198,7 +2198,7 @@ class TimeseriesFeatureMixin(object):
2198
2198
  :example:
2199
2199
  >>> x = np.random.randint(0, 100, (400, 2))
2200
2200
  >>> results_1 = TimeseriesFeatureMixin.sliding_entropy_of_directional_changes(x=x, bins=16, window_size=5.0, sample_rate=30)
2201
- >>> x = pd.read_csv(r"C:\troubleshooting\two_black_animals_14bp\project_folder\csv\input_csv\Together_1.csv")[['Ear_left_1_x', 'Ear_left_1_y']].values
2201
+ >>> x = pd.read_csv(r"C:/troubleshooting/two_black_animals_14bp/project_folder/csv/input_csv/Together_1.csv")[['Ear_left_1_x', 'Ear_left_1_y']].values
2202
2202
  >>> results_2 = TimeseriesFeatureMixin.sliding_entropy_of_directional_changes(x=x, bins=16, window_size=5.0, sample_rate=30)
2203
2203
  """
2204
2204
 
@@ -67,7 +67,7 @@ from simba.utils.checks import (check_all_dfs_in_list_has_same_cols,
67
67
  check_valid_boolean, check_valid_dataframe,
68
68
  check_valid_lst, is_lxc_container)
69
69
  from simba.utils.data import (detect_bouts, detect_bouts_multiclass,
70
- get_library_version)
70
+ get_library_version, terminate_cpu_pool)
71
71
  from simba.utils.enums import (OS, ConfigKey, Defaults, Dtypes, Formats, Links,
72
72
  Methods, MLParamKeys, Options)
73
73
  from simba.utils.errors import (ClassifierInferenceError, CorruptedFileError,
@@ -77,10 +77,10 @@ from simba.utils.errors import (ClassifierInferenceError, CorruptedFileError,
77
77
  SamplingError, SimBAModuleNotFoundError)
78
78
  from simba.utils.lookups import get_meta_data_file_headers, get_table
79
79
  from simba.utils.printing import SimbaTimer, stdout_success
80
- from simba.utils.read_write import (find_core_cnt, get_fn_ext,
81
- get_memory_usage_of_df, get_pkg_version,
82
- read_config_entry, read_df, read_meta_file,
83
- str_2_bool)
80
+ from simba.utils.read_write import (find_core_cnt, get_current_time,
81
+ get_fn_ext, get_memory_usage_of_df,
82
+ get_pkg_version, read_config_entry,
83
+ read_df, read_meta_file, str_2_bool)
84
84
  from simba.utils.warnings import (GPUToolsWarning, MissingUserInputWarning,
85
85
  MultiProcessingFailedWarning,
86
86
  NoModuleWarning, NotEnoughDataWarning,
@@ -1070,10 +1070,7 @@ class TrainModelMixin(object):
1070
1070
  MissingUserInputWarning(msg=f'Skipping {str(config.get("SML settings", "target_name_" + str(n + 1)))} classifier analysis: missing information (e.g., no discrimination threshold and/or minimum bout set in the project_config.ini',source=self.__class__.__name__)
1071
1071
 
1072
1072
  if len(model_dict.keys()) == 0:
1073
- raise NoDataError(
1074
- msg=f"There are no models with accurate data specified in the RUN MODELS menu. Specify the model information to SimBA RUN MODELS menu to use them to analyze videos",
1075
- source=self.get_model_info.__name__,
1076
- )
1073
+ raise NoDataError(msg=f"There are no models with accurate data specified in the RUN MODELS menu. Specify the model information to SimBA RUN MODELS menu to use them to analyze videos. PLease check the model paths, thresholds, and minimum bout lengths.", source=self.get_model_info.__name__)
1077
1074
  else:
1078
1075
  return model_dict
1079
1076
 
@@ -1383,18 +1380,39 @@ class TrainModelMixin(object):
1383
1380
  x_df: Union[pd.DataFrame, np.ndarray],
1384
1381
  multiclass: bool = False,
1385
1382
  model_name: Optional[str] = None,
1386
- data_path: Optional[Union[str, os.PathLike]] = None) -> np.ndarray:
1383
+ data_path: Optional[Union[str, os.PathLike]] = None,
1384
+ verbose: bool = False) -> np.ndarray:
1387
1385
 
1388
1386
  """
1389
- :param RandomForestClassifier clf: Random forest classifier object
1390
- :param Union[pd.DataFrame, np.ndarray] x_df: Features for data to predict as a dataframe or array of size (M,N).
1391
- :param bool multiclass: If True, the classifier predicts more than 2 targets. Else, boolean classifier.
1392
- :param Optional[str] model_name: Name of model
1393
- :param Optional[str] data_path: Path to model on disk
1394
- :return np.ndarray: 2D array with frame represented by rows and present/absent probabilities as columns
1395
- :raises FeatureNumberMismatchError: If shape of x_df and clf.n_features_ or n_features_in_ show mismatch
1387
+ Helper to predict class probabilities using a fitted random forest classifier.
1388
+
1389
+ Computes prediction probabilities for binary or multiclass classification using either
1390
+ scikit-learn or cuML RandomForestClassifier. For binary classifiers, returns the
1391
+ probability of the positive class (class 1). For multiclass classifiers, returns
1392
+ probabilities for all classes.
1393
+
1394
+ .. csv-table::
1395
+ :header: EXPECTED RUNTIMES
1396
+ :file: ../../docs/tables/clf_predict_proba.csv
1397
+ :widths: 10, 45, 45
1398
+ :align: center
1399
+ :header-rows: 1
1400
+
1401
+ .. seealso::
1402
+ To fit a classifier, see :func:`simba.mixins.train_model_mixin.TrainModelMixin.clf_fit`
1403
+ To define a classifier, see :func:`simba.mixins.train_model_mixin.TrainModelMixin.clf_define`
1404
+
1405
+ :param Union[RandomForestClassifier, cuRF] clf: Fitted random forest classifier object from sklearn or cuml.
1406
+ :param Union[pd.DataFrame, np.ndarray] x_df: Features for data to predict. DataFrame or array of shape (n_samples, n_features).
1407
+ :param bool multiclass: If True, the classifier predicts more than 2 classes. If False, binary classifier (default: False).
1408
+ :param Optional[str] model_name: Name of the model for error messages and logging. Default: None.
1409
+ :param Optional[Union[str, os.PathLike]] data_path: Path to the data file being processed, used in error messages. Default: None.
1410
+ :param bool verbose: If True, print inference progress and timing information. Default: False.
1411
+ :return np.ndarray: Prediction probabilities. For binary classifiers: 1D array of shape (n_samples,) with probability of positive class. For multiclass: 2D array of shape (n_samples, n_classes) with probabilities for each class.
1412
+
1396
1413
  """
1397
1414
 
1415
+ timer = SimbaTimer(start=True)
1398
1416
  if hasattr(clf, "n_features_"):
1399
1417
  clf_n_features = clf.n_features_
1400
1418
  elif hasattr(clf, "n_features_in_"):
@@ -1420,6 +1438,8 @@ class TrainModelMixin(object):
1420
1438
  p_vals = clf.predict_proba(x_df)
1421
1439
  if multiclass and (clf.n_classes_ != p_vals.shape[1]):
1422
1440
  raise ClassifierInferenceError(msg=f"The classifier {model_name} (data path: {data_path}) is a multiclassifier expected to create {clf.n_classes_} behavior probabilities. However, it produced probabilities for {p_vals.shape[1]} behaviors. See The SimBA GitHub FAQ page or Gitter for more information and suggested fixes.", source=self.__class__.__name__)
1441
+ timer.stop_timer()
1442
+ if verbose: print(f'Inference for model {model_name} over {x_df.shape[0]} observations complete ({timer.elapsed_time_str}s).')
1423
1443
  if not multiclass:
1424
1444
  if isinstance(p_vals, pd.DataFrame):
1425
1445
  return p_vals[1].values
@@ -1447,7 +1467,7 @@ class TrainModelMixin(object):
1447
1467
  bootstrap: Optional[bool] = True,
1448
1468
  verbose: Optional[int] = 1,
1449
1469
  class_weight: Optional[dict] = None,
1450
- cuda: Optional[bool] = False) -> RandomForestClassifier:
1470
+ cuda: Optional[bool] = False) -> Union[RandomForestClassifier, cuRF]:
1451
1471
 
1452
1472
  if not cuda:
1453
1473
  # NOTE: LOKY ISSUES ON WINDOWS WITH SCIKIT IF THE CORE COUNT EXCEEDS 61.
@@ -1482,20 +1502,32 @@ class TrainModelMixin(object):
1482
1502
  clf: Union[RandomForestClassifier, cuRF],
1483
1503
  x_df: pd.DataFrame,
1484
1504
  y_df: pd.DataFrame,
1485
- ) -> RandomForestClassifier:
1505
+ verbose: bool = False) -> Union[RandomForestClassifier, cuRF]:
1486
1506
 
1487
1507
  """
1488
- Helper to fit clf model
1508
+ Helper to fit clf model.
1489
1509
 
1490
- :param clf: Un-fitted random forest classifier object
1510
+ .. csv-table::
1511
+ :header: EXPECTED RUNTIMES
1512
+ :file: ../../docs/tables/clf_fit.csv
1513
+ :widths: 20, 20, 30, 30
1514
+ :align: center
1515
+ :header-rows: 1
1516
+
1517
+ .. seealso::
1518
+ To define a cuml/sklearn object, see :func:`simba.mixins.train_model_mixin.TrainModelMixin.clf_define`
1519
+
1520
+ :param clf: Un-fitted random forest classifier object, either from sklearn or cuml.
1491
1521
  :param pd.DataFrame x_df: Pandas dataframe with features.
1492
1522
  :param pd.DataFrame y_df: Pandas dataframe/Series with target
1493
1523
  :return: Fitted random forest classifier object
1494
1524
  :rtype: RandomForestClassifier
1495
1525
  """
1496
1526
 
1527
+ timer = SimbaTimer(start=True)
1497
1528
  nan_features = x_df[~x_df.applymap(np.isreal).all(1)]
1498
1529
  nan_target = y_df.loc[pd.to_numeric(y_df).isna()]
1530
+ using_cuda = True if CUML in str(clf.__class__.__module__).lower() else False
1499
1531
  if len(nan_features) > 0:
1500
1532
  raise FaultyTrainingSetError(
1501
1533
  msg=f"{len(nan_features)} frame(s) in your project_folder/csv/targets_inserted directory contains FEATURES with non-numerical values",
@@ -1504,9 +1536,16 @@ class TrainModelMixin(object):
1504
1536
  raise FaultyTrainingSetError(
1505
1537
  msg=f"{len(nan_target)} frame(s) in your project_folder/csv/targets_inserted directory contains ANNOTATIONS with non-numerical values",
1506
1538
  source=self.__class__.__name__)
1539
+ if verbose: print(f'[{get_current_time()}] Fitting classifier for {len(x_df)} observations (cuda: {"True" if using_cuda else "False"})...')
1540
+ if using_cuda:
1541
+ x_data = x_df.values if isinstance(x_df, pd.DataFrame) else x_df
1542
+ y_data = y_df.values if isinstance(y_df, (pd.DataFrame, pd.Series)) else y_df
1543
+ clf.fit(x_data, y_data)
1544
+ else:
1545
+ clf.fit(x_df, y_df)
1507
1546
 
1508
- clf.fit(x_df, y_df)
1509
-
1547
+ timer.stop_timer()
1548
+ if verbose: print(f'[{get_current_time()}] Classifier fitted in {timer.elapsed_time_str}s.')
1510
1549
  return clf
1511
1550
 
1512
1551
  @staticmethod
@@ -1563,9 +1602,7 @@ class TrainModelMixin(object):
1563
1602
  :rtype: Tuple[pd.DataFrame, List[int]]
1564
1603
 
1565
1604
  """
1566
- if (platform.system() == "Darwin") and (
1567
- multiprocessing.get_start_method() != "spawn"
1568
- ):
1605
+ if (platform.system() == "Darwin") and (multiprocessing.get_start_method() != "spawn"):
1569
1606
  multiprocessing.set_start_method("spawn", force=True)
1570
1607
  cpu_cnt, _ = find_core_cnt()
1571
1608
  df_lst, frame_numbers_lst = [], []
@@ -1592,9 +1629,7 @@ class TrainModelMixin(object):
1592
1629
  :, ~df_concat.columns.str.contains("^Unnamed")
1593
1630
  ].astype(np.float32)
1594
1631
  memory_size = get_memory_usage_of_df(df=df_concat)
1595
- print(
1596
- f'Dataset size: {memory_size["megabytes"]}MB / {memory_size["gigabytes"]}GB'
1597
- )
1632
+ print(f'Dataset size: {memory_size["megabytes"]}MB / {memory_size["gigabytes"]}GB')
1598
1633
 
1599
1634
  return df_concat, frame_numbers_lst
1600
1635
 
@@ -1859,7 +1894,7 @@ class TrainModelMixin(object):
1859
1894
  shap_raw.append(shap_data[result[1]][1].drop(clf_name, axis=1))
1860
1895
  if verbose: print(f"Completed SHAP care batch (Batch {result[1] + 1}/{len(shap_data)}).")
1861
1896
 
1862
- pool.terminate(); pool.join()
1897
+ terminate_cpu_pool(pool=pool, force=False)
1863
1898
  shap_df = pd.DataFrame(data=np.row_stack(shap_results), columns=list(x_names) + ["Expected_value", "Sum", "Prediction_probability", clf_name])
1864
1899
  raw_df = pd.DataFrame(data=np.row_stack(shap_raw), columns=list(x_names))
1865
1900
  out_shap_path, out_raw_path, img_save_path, df_save_paths, summary_dfs, img = None, None, None, None, None, None
@@ -2607,9 +2642,9 @@ class TrainModelMixin(object):
2607
2642
  :param bool plot: If True, create SHAP aggregation and plots.
2608
2643
 
2609
2644
  :example:
2610
- >>> CONFIG_PATH = r"C:\troubleshooting\mitra\project_folder\project_config.ini"
2611
- >>> RF_PATH = r"C:\troubleshooting\mitra\models\validations\straub_tail_5_new\straub_tail_5.sav"
2612
- >>> DATA_PATH = r"C:\troubleshooting\mitra\project_folder\csv\targets_inserted\new_straub\appended\501_MA142_Gi_CNO_0514.csv"
2645
+ >>> CONFIG_PATH = r"C:/troubleshooting/mitra/project_folder/project_config.ini"
2646
+ >>> RF_PATH = r"C:/troubleshooting/mitra/models/validations/straub_tail_5_new/straub_tail_5.sav"
2647
+ >>> DATA_PATH = r"C:/troubleshooting/mitra/project_folder/csv/targets_inserted/new_straub/appended/501_MA142_Gi_CNO_0514.csv"
2613
2648
  >>> config = ConfigReader(config_path=CONFIG_PATH)
2614
2649
  >>> df = read_df(file_path=DATA_PATH, file_type='csv')
2615
2650
  >>> y = df['straub_tail']
@@ -45,7 +45,7 @@ class InferenceBatch(TrainModelMixin, ConfigReader):
45
45
  >>> inferencer.run()
46
46
 
47
47
  :example II:
48
- >>> inferencer = InferenceBatch(config_path=r"D:\troubleshooting\mitra\project_folder\project_config.ini", features_dir=r"D:\troubleshooting\mitra\project_folder\videos\bg_removed\rotated\tail_features\APPENDED")
48
+ >>> inferencer = InferenceBatch(config_path=r"D:/troubleshooting/mitra/project_folder/project_config.ini", features_dir=r"D:/troubleshooting/mitra/project_folder/videos/bg_removed/rotated/tail_features/APPENDED")
49
49
  >>> inferencer.run()
50
50
  """
51
51
 
@@ -101,7 +101,7 @@ class InferenceBatch(TrainModelMixin, ConfigReader):
101
101
  video_timer.stop_timer()
102
102
  print(f"Predictions created for {file_name} (frame count: {len(in_df)}, elapsed time: {video_timer.elapsed_time_str}) ...")
103
103
  self.timer.stop_timer()
104
- stdout_success(msg=f"Machine predictions complete. Files saved in {self.save_dir} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
104
+ stdout_success(msg=f"Machine predictions complete for {len(self.feature_file_paths)} file(s). Files saved in {self.save_dir} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
105
105
 
106
106
  if __name__ == "__main__" and not hasattr(sys, 'ps1'):
107
107
  parser = argparse.ArgumentParser(description="Perform classifications according to rules defined in SImAB project_config.ini.")
@@ -55,9 +55,9 @@ class YOLOSegmentationInference():
55
55
  To visualize the segmentation results, see :func:`simba.plotting.yolo_seg_visualizer.YOLOSegmentationVisualizer`
56
56
 
57
57
  :example:
58
- >>> weights_path = r"D:\platea\yolo_071525\mdl\train3\weights\best.pt"
59
- >>> video_path = r"D:\platea\platea_videos\videos\clipped\10B_Mouse_5-choice_MustTouchTrainingNEWFINAL_a7.mp4"
60
- >>> save_dir=r"D:\platea\platea_videos\videos\yolo_results"
58
+ >>> weights_path = r"D:/platea/yolo_071525/mdl/train3/weights/best.pt"
59
+ >>> video_path = r"D:/platea/platea_videos/videos/clipped/10B_Mouse_5-choice_MustTouchTrainingNEWFINAL_a7.mp4"
60
+ >>> save_dir = r"D:/platea/platea_videos/videos/yolo_results"
61
61
  >>> runner = YOLOSegmentationInference(weights_path=weights_path, video_path=video_path, save_dir=save_dir, verbose=True, device=0, format=None, stream=True, batch_size=10, imgsz=320, interpolate=True, threshold=0.8, retina_msk=True)
62
62
  >>> runner.run()
63
63
 
@@ -47,5 +47,5 @@ class OutlierCorrectionSkipper(ConfigReader):
47
47
  self.timer.stop_timer()
48
48
  stdout_success(msg=f"Skipped outlier correction for {len(self.input_csv_paths)} files", elapsed_time=self.timer.elapsed_time_str)
49
49
 
50
- # test = OutlierCorrectionSkipper(config_path='/Users/simon/Desktop/envs/troubleshooting/naresh/project_folder/project_config.ini')
50
+ # test = OutlierCorrectionSkipper(config_path=r"E:\troubleshooting\mitra_emergence\project_folder\project_config.ini")
51
51
  # test.run()
@@ -24,10 +24,9 @@ from simba.utils.checks import (check_file_exist_and_readable,
24
24
  check_if_valid_rgb_tuple, check_int, check_str,
25
25
  check_valid_boolean, check_valid_lst,
26
26
  check_video_and_data_frm_count_align)
27
- from simba.utils.data import slice_roi_dict_for_video
27
+ from simba.utils.data import slice_roi_dict_for_video, terminate_cpu_pool
28
28
  from simba.utils.enums import Formats, TextOptions
29
- from simba.utils.errors import (BodypartColumnNotFoundError, NoFilesFoundError,
30
- ROICoordinatesNotFoundError)
29
+ from simba.utils.errors import BodypartColumnNotFoundError, NoFilesFoundError
31
30
  from simba.utils.printing import stdout_success
32
31
  from simba.utils.read_write import (concatenate_videos_in_folder,
33
32
  find_core_cnt, get_fn_ext,
@@ -315,8 +314,7 @@ class ROIfeatureVisualizerMultiprocess(ConfigReader):
315
314
  print(f"Joining {self.video_name} multi-processed video...")
316
315
  concatenate_videos_in_folder(in_folder=self.save_temp_dir, save_path=self.save_path, video_format="mp4", remove_splits=True, gpu=self.gpu)
317
316
  self.timer.stop_timer()
318
- pool.terminate()
319
- pool.join()
317
+ terminate_cpu_pool(pool=pool, force=False)
320
318
  stdout_success(msg=f"Video {self.video_name} complete. Video saved in directory {self.roi_features_save_dir}.", elapsed_time=self.timer.elapsed_time_str)
321
319
 
322
320
 
@@ -14,9 +14,9 @@ from simba.mixins.plotting_mixin import PlottingMixin
14
14
  from simba.utils.checks import (check_float, check_if_valid_rgb_tuple,
15
15
  check_int, check_str, check_that_column_exist,
16
16
  check_valid_lst)
17
- from simba.utils.data import detect_bouts
18
- from simba.utils.enums import Formats, TagNames, TextOptions
19
- from simba.utils.errors import NoFilesFoundError, NoSpecifiedOutputError
17
+ from simba.utils.data import detect_bouts, terminate_cpu_pool
18
+ from simba.utils.enums import Formats, TextOptions
19
+ from simba.utils.errors import NoSpecifiedOutputError
20
20
  from simba.utils.printing import SimbaTimer, log_event, stdout_success
21
21
  from simba.utils.read_write import (concatenate_videos_in_folder,
22
22
  find_core_cnt, get_fn_ext,
@@ -218,8 +218,7 @@ class ClassifierValidationClipsMultiprocess(ConfigReader):
218
218
  for cnt, result in enumerate(
219
219
  pool.imap(constants, clip_data, chunksize=self.multiprocess_chunksize)):
220
220
  print(f"Bout {cnt+1} complete...")
221
- pool.terminate()
222
- pool.join()
221
+ terminate_cpu_pool(pool=pool, force=False)
223
222
 
224
223
  if self.concat_video:
225
224
  print(f"Joining {file_name} multiprocessed video...")
@@ -1,10 +1,9 @@
1
1
  __author__ = "Simon Nilsson; sronilsson@gmail.com"
2
2
 
3
3
  import functools
4
- import itertools
5
4
  import multiprocessing
6
5
  import os
7
- from typing import List, Optional, Union
6
+ from typing import List, Union
8
7
 
9
8
  import cv2
10
9
  import numpy as np
@@ -16,14 +15,15 @@ from simba.utils.checks import (check_file_exist_and_readable, check_int,
16
15
  check_valid_boolean, check_valid_dataframe,
17
16
  check_valid_lst)
18
17
  from simba.utils.data import (create_color_palettes, detect_bouts,
19
- slice_roi_dict_from_attribute)
20
- from simba.utils.enums import Defaults, Formats, Keys, TextOptions
18
+ slice_roi_dict_from_attribute,
19
+ terminate_cpu_pool)
20
+ from simba.utils.enums import Defaults, Formats, TextOptions
21
21
  from simba.utils.errors import NoROIDataError, NoSpecifiedOutputError
22
22
  from simba.utils.printing import stdout_success
23
23
  from simba.utils.read_write import (concatenate_videos_in_folder,
24
24
  create_directory, find_core_cnt,
25
25
  get_fn_ext, get_video_meta_data, read_df,
26
- read_frm_of_video, remove_a_folder)
26
+ read_frm_of_video)
27
27
 
28
28
 
29
29
  def _plot_cue_light_data(frm_idxs: list,
@@ -197,8 +197,7 @@ class CueLightVisualizer(ConfigReader):
197
197
  for cnt, result in enumerate(pool.imap(constants, self.frame_chunks, chunksize=self.multiprocess_chunksize)):
198
198
  if self.verbose:
199
199
  print(f'Batch {int(result+1/self.core_cnt)} complete...')
200
- pool.terminate()
201
- pool.join()
200
+ terminate_cpu_pool(pool=pool, force=False)
202
201
  self.timer.stop_timer()
203
202
  if self.video_setting:
204
203
  print(f"Joining {self.video_name} multiprocessed video...")
@@ -19,7 +19,7 @@ from simba.utils.checks import (check_file_exist_and_readable,
19
19
  check_if_valid_rgb_tuple, check_int,
20
20
  check_valid_lst,
21
21
  check_video_and_data_frm_count_align)
22
- from simba.utils.data import create_color_palettes
22
+ from simba.utils.data import create_color_palettes, terminate_cpu_pool
23
23
  from simba.utils.enums import OS, Formats, Keys, TextOptions
24
24
  from simba.utils.errors import (AnimalNumberError, InvalidInputError,
25
25
  NoFilesFoundError)
@@ -226,8 +226,7 @@ class DirectingOtherAnimalsVisualizerMultiprocess(ConfigReader, PlottingMixin):
226
226
  print(f"Joining {self.video_name} multi-processed video...")
227
227
  concatenate_videos_in_folder(in_folder=self.save_temp_path, save_path=self.save_path, video_format="mp4", remove_splits=True)
228
228
  self.timer.stop_timer()
229
- pool.terminate()
230
- pool.join()
229
+ terminate_cpu_pool(pool=pool, force=False)
231
230
  stdout_success(msg=f"Video {self.video_name} complete. Video saved in {self.directing_animals_video_output_path} directory", elapsed_time=self.timer.elapsed_time_str)
232
231
 
233
232