simba-uw-tf-dev 4.6.6__py3-none-any.whl → 4.6.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of simba-uw-tf-dev might be problematic. Click here for more details.

Files changed (49)
  1. simba/assets/.recent_projects.txt +1 -0
  2. simba/data_processors/blob_location_computer.py +1 -1
  3. simba/data_processors/circling_detector.py +30 -13
  4. simba/data_processors/cuda/image.py +53 -25
  5. simba/data_processors/cuda/statistics.py +57 -19
  6. simba/data_processors/cuda/timeseries.py +1 -1
  7. simba/data_processors/egocentric_aligner.py +1 -1
  8. simba/data_processors/freezing_detector.py +54 -50
  9. simba/feature_extractors/feature_subsets.py +2 -2
  10. simba/feature_extractors/mitra_feature_extractor.py +2 -2
  11. simba/feature_extractors/straub_tail_analyzer.py +4 -4
  12. simba/labelling/standard_labeller.py +1 -1
  13. simba/mixins/config_reader.py +5 -2
  14. simba/mixins/geometry_mixin.py +8 -8
  15. simba/mixins/image_mixin.py +14 -14
  16. simba/mixins/plotting_mixin.py +28 -10
  17. simba/mixins/statistics_mixin.py +39 -9
  18. simba/mixins/timeseries_features_mixin.py +1 -1
  19. simba/mixins/train_model_mixin.py +65 -27
  20. simba/model/inference_batch.py +1 -1
  21. simba/model/yolo_seg_inference.py +3 -3
  22. simba/outlier_tools/skip_outlier_correction.py +1 -1
  23. simba/plotting/gantt_creator.py +29 -10
  24. simba/plotting/gantt_creator_mp.py +50 -17
  25. simba/plotting/heat_mapper_clf_mp.py +2 -2
  26. simba/pose_importers/simba_blob_importer.py +3 -3
  27. simba/roi_tools/roi_aggregate_stats_mp.py +1 -1
  28. simba/roi_tools/roi_clf_calculator_mp.py +1 -1
  29. simba/third_party_label_appenders/transform/coco_keypoints_to_yolo.py +3 -3
  30. simba/third_party_label_appenders/transform/coco_keypoints_to_yolo_bbox.py +2 -2
  31. simba/ui/pop_ups/clf_plot_pop_up.py +2 -2
  32. simba/ui/pop_ups/gantt_pop_up.py +31 -6
  33. simba/ui/pop_ups/video_processing_pop_up.py +1 -1
  34. simba/utils/custom_feature_extractor.py +1 -1
  35. simba/utils/data.py +2 -2
  36. simba/utils/read_write.py +32 -18
  37. simba/utils/yolo.py +10 -1
  38. simba/video_processors/blob_tracking_executor.py +2 -2
  39. simba/video_processors/clahe_ui.py +1 -1
  40. simba/video_processors/egocentric_video_rotator.py +3 -3
  41. simba/video_processors/multi_cropper.py +1 -1
  42. simba/video_processors/video_processing.py +27 -10
  43. simba/video_processors/videos_to_frames.py +2 -2
  44. {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/METADATA +3 -2
  45. {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/RECORD +49 -49
  46. {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/LICENSE +0 -0
  47. {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/WHEEL +0 -0
  48. {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/entry_points.txt +0 -0
  49. {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/top_level.txt +0 -0
@@ -44,10 +44,10 @@ class StraubTailAnalyzer(ConfigReader):
44
44
  .. [1] Lazaro et al., Brainwide Genetic Capture for Conscious State Transitions, `biorxiv`, doi: https://doi.org/10.1101/2025.03.28.646066
45
45
 
46
46
  :example:
47
- >>> runner = StraubTailAnalyzer(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
48
- >>> data_dir=r'C:\troubleshooting\mitra\project_folder\videos\additional\bg_removed\rotated',
49
- >>> video_dir=r'C:\troubleshooting\mitra\project_folder\videos\additional\bg_removed\rotated',
50
- >>> save_dir=r'C:\troubleshooting\mitra\project_folder\videos\additional\bg_removed\rotated\tail_features_additional',
47
+ >>> runner = StraubTailAnalyzer(config_path=r"C:/troubleshooting/mitra/project_folder/project_config.ini",
48
+ >>> data_dir=r'C:/troubleshooting/mitra/project_folder/videos/additional/bg_removed/rotated',
49
+ >>> video_dir=r'C:/troubleshooting/mitra/project_folder/videos/additional/bg_removed/rotated',
50
+ >>> save_dir=r'C:/troubleshooting/mitra/project_folder/videos/additional/bg_removed/rotated/tail_features_additional',
51
51
  >>> anchor_points=('tail_base', 'tail_center', 'tail_tip'),
52
52
  >>> body_parts=('nose', 'left_ear', 'right_ear', 'right_side', 'left_side', 'tail_base'))
53
53
  >>> runner.run()
@@ -64,7 +64,7 @@ class LabellingInterface(ConfigReader):
64
64
  :param bool continuing: Set True to resume annotations from an existing targets file. Defaults to False.
65
65
 
66
66
  :example:
67
- >>> _ = LabellingInterface(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini", file_path=r"C:\troubleshooting\mitra\project_folder\videos\501_MA142_Gi_CNO_0521.mp4", thresholds=None, continuing=False)
67
+ >>> _ = LabellingInterface(config_path=r"C:/troubleshooting/mitra/project_folder/project_config.ini", file_path=r"C:/troubleshooting/mitra/project_folder/videos/501_MA142_Gi_CNO_0521.mp4", thresholds=None, continuing=False)
68
68
  """
69
69
 
70
70
  def __init__(self,
@@ -41,8 +41,8 @@ from simba.utils.read_write import (find_core_cnt, get_all_clf_names,
41
41
  get_fn_ext, read_config_file, read_df,
42
42
  read_project_path_and_file_type, write_df)
43
43
  from simba.utils.warnings import (BodypartColumnNotFoundWarning,
44
- InvalidValueWarning, NoDataFoundWarning,
45
- NoFileFoundWarning)
44
+ DuplicateNamesWarning, InvalidValueWarning,
45
+ NoDataFoundWarning, NoFileFoundWarning)
46
46
 
47
47
 
48
48
  class ConfigReader(object):
@@ -610,11 +610,14 @@ class ConfigReader(object):
610
610
  >>> config_reader.get_bp_headers()
611
611
  """
612
612
 
613
+ duplicates = list({x for x in self.body_parts_lst if self.body_parts_lst.count(x) > 1})
614
+ if len(duplicates) > 0: DuplicateNamesWarning(msg=f'The pose configuration file at {self.body_parts_path} contains duplicate entries: {duplicates}', source=self.__class__.__name__)
613
615
  self.bp_headers = []
614
616
  for bp in self.body_parts_lst:
615
617
  c1, c2, c3 = (f"{bp}_x", f"{bp}_y", f"{bp}_p")
616
618
  self.bp_headers.extend((c1, c2, c3))
617
619
 
620
+
618
621
  def read_config_entry(
619
622
  self,
620
623
  config: ConfigParser,
@@ -1556,7 +1556,7 @@ class GeometryMixin(object):
1556
1556
  :rtype: List[float]
1557
1557
 
1558
1558
  :example:
1559
- >>> df = read_df(file_path=r"C:\troubleshooting\two_black_animals_14bp\project_folder\csv\outlier_corrected_movement_location\Together_2.csv", file_type='csv').astype(int)
1559
+ >>> df = read_df(file_path=r"C:/troubleshooting/two_black_animals_14bp/project_folder/csv/outlier_corrected_movement_location/Together_2.csv", file_type='csv').astype(int)
1560
1560
  >>> animal_1_cols = [x for x in df.columns if '_1_' in x and not '_p' in x]
1561
1561
  >>> animal_2_cols = [x for x in df.columns if '_2_' in x and not '_p' in x]
1562
1562
  >>> animal_1_arr = df[animal_1_cols].values.reshape(len(df), int(len(animal_1_cols)/ 2), 2)
@@ -1622,7 +1622,7 @@ class GeometryMixin(object):
1622
1622
  :return List[float]: List of overlap between corresponding Polygons. If overlap 1, else 0.
1623
1623
 
1624
1624
  :example:
1625
- >>> df = read_df(file_path=r"C:\troubleshooting\two_black_animals_14bp\project_folder\csv\outlier_corrected_movement_location\Together_2.csv", file_type='csv').astype(int)
1625
+ >>> df = read_df(file_path=r"C:/troubleshooting/two_black_animals_14bp/project_folder/csv/outlier_corrected_movement_location/Together_2.csv", file_type='csv').astype(int)
1626
1626
  >>> animal_1_cols = [x for x in df.columns if '_1_' in x and not '_p' in x]
1627
1627
  >>> animal_2_cols = [x for x in df.columns if '_2_' in x and not '_p' in x]
1628
1628
  >>> animal_1_arr = df[animal_1_cols].values.reshape(len(df), int(len(animal_1_cols)/ 2), 2)
@@ -1693,7 +1693,7 @@ class GeometryMixin(object):
1693
1693
  :rtype: List[float]
1694
1694
 
1695
1695
  :example:
1696
- >>> df = read_df(file_path=r"C:\troubleshooting\two_black_animals_14bp\project_folder\csv\outlier_corrected_movement_location\Together_2.csv", file_type='csv').astype(int)
1696
+ >>> df = read_df(file_path=r"C:/troubleshooting/two_black_animals_14bp/project_folder/csv/outlier_corrected_movement_location/Together_2.csv", file_type='csv').astype(int)
1697
1697
  >>> animal_1_cols = [x for x in df.columns if '_1_' in x and not '_p' in x]
1698
1698
  >>> animal_2_cols = [x for x in df.columns if '_2_' in x and not '_p' in x]
1699
1699
  >>> animal_1_arr = df[animal_1_cols].values.reshape(len(df), int(len(animal_1_cols)/ 2), 2)
@@ -1763,7 +1763,7 @@ class GeometryMixin(object):
1763
1763
  :rtype: List[Polygon]
1764
1764
 
1765
1765
  :example:
1766
- >>> df = read_df(file_path=r"C:\troubleshooting\two_black_animals_14bp\project_folder\csv\outlier_corrected_movement_location\Together_2.csv", file_type='csv').astype(int)
1766
+ >>> df = read_df(file_path=r"C:/troubleshooting/two_black_animals_14bp/project_folder/csv/outlier_corrected_movement_location/Together_2.csv", file_type='csv').astype(int)
1767
1767
  >>> animal_1_cols = [x for x in df.columns if '_1_' in x and not '_p' in x]
1768
1768
  >>> animal_1_arr = df[animal_1_cols].values.reshape(len(df), int(len(animal_1_cols)/ 2), 2)
1769
1769
  >>> animal_1_geo = GeometryMixin.bodyparts_to_polygon(data=animal_1_arr)
@@ -3525,10 +3525,10 @@ class GeometryMixin(object):
3525
3525
  :rtype: Tuple[Dict[Tuple[int, int], Dict[Tuple[int, int], float]], Dict[Tuple[int, int], Dict[Tuple[int, int], int]]]
3526
3526
 
3527
3527
  :example:
3528
- >>> video_meta_data = get_video_meta_data(video_path=r"C:\troubleshooting\mitra\project_folder\videos\708_MA149_Gq_CNO_0515.mp4")
3528
+ >>> video_meta_data = get_video_meta_data(video_path=r"C:/troubleshooting/mitra/project_folder/videos/708_MA149_Gq_CNO_0515.mp4")
3529
3529
  >>> w, h = video_meta_data['width'], video_meta_data['height']
3530
3530
  >>> grid = GeometryMixin().bucket_img_into_grid_square(bucket_grid_size=(5, 5), bucket_grid_size_mm=None, img_size=(h, w), verbose=False)[0]
3531
- >>> data = read_df(file_path=r'C:\troubleshooting\mitra\project_folder\csv\outlier_corrected_movement_location\708_MA149_Gq_CNO_0515.csv', file_type='csv')[['Nose_x', 'Nose_y']].values
3531
+ >>> data = read_df(file_path=r'C:/troubleshooting/mitra/project_folder/csv/outlier_corrected_movement_location/708_MA149_Gq_CNO_0515.csv', file_type='csv')[['Nose_x', 'Nose_y']].values
3532
3532
  >>> transition_probabilities, _ = geometry_transition_probabilities(data=data, grid=grid)
3533
3533
  """
3534
3534
 
@@ -3990,7 +3990,7 @@ class GeometryMixin(object):
3990
3990
  :rtype: np.ndarray
3991
3991
 
3992
3992
  :example:
3993
- >>> data_path = r"C:\troubleshooting\mitra\project_folder\csv\outlier_corrected_movement_location\FRR_gq_Saline_0624.csv"
3993
+ >>> data_path = r"C:/troubleshooting/mitra/project_folder/csv/outlier_corrected_movement_location/FRR_gq_Saline_0624.csv"
3994
3994
  >>> animal_data = read_df(file_path=data_path, file_type='csv', usecols=['Nose_x', 'Nose_y', 'Tail_base_x', 'Tail_base_y', 'Left_side_x', 'Left_side_y', 'Right_side_x', 'Right_side_y']).values.reshape(-1, 4, 2)[0:20].astype(np.int32)
3995
3995
  >>> animal_polygons = GeometryMixin().bodyparts_to_polygon(data=animal_data)
3996
3996
  >>> GeometryMixin.geometries_to_exterior_keypoints(geometries=animal_polygons)
@@ -4160,7 +4160,7 @@ class GeometryMixin(object):
4160
4160
  :rtype: Union[None, Dict[Any, dict]]
4161
4161
 
4162
4162
  :example I:
4163
- >>> results = GeometryMixin.sleap_csv_to_geometries(data=r"C:\troubleshooting\ants\pose_data\ant.csv")
4163
+ >>> results = GeometryMixin.sleap_csv_to_geometries(data=r"C:/troubleshooting/ants/pose_data/ant.csv")
4164
4164
  >>> # Results structure: {track_id: {frame_idx: Polygon, ...}, ...}
4165
4165
 
4166
4166
  :example II
@@ -57,17 +57,16 @@ class ImageMixin(object):
57
57
  pass
58
58
 
59
59
  @staticmethod
60
- def brightness_intensity(imgs: List[np.ndarray], ignore_black: Optional[bool] = True) -> List[float]:
60
+ def brightness_intensity(imgs: Union[List[np.ndarray], np.ndarray], ignore_black: bool = True, verbose: bool = False) -> np.ndarray:
61
61
  """
62
62
  Compute the average brightness intensity within each image within a list.
63
63
 
64
64
  For example, (i) create a list of images containing a light cue ROI, (ii) compute brightness in each image, (iii) perform kmeans on brightness, and get the frames when the light cue is on vs off.
65
65
 
66
66
  .. seealso::
67
- For GPU acceleration, see :func:`simba.data_processors.cuda.image.img_stack_brightness`.
68
- For geometry based brightness, see :func:`simba.mixins.geometry_mixin.GeometryMixin.get_geometry_brightness_intensity`
67
+ For GPU acceleration, see :func:`simba.data_processors.cuda.image.img_stack_brightness`. For geometry based brightness, see :func:`simba.mixins.geometry_mixin.GeometryMixin.get_geometry_brightness_intensity`
69
68
 
70
- :param List[np.ndarray] imgs: List of images as arrays to calculate average brightness intensity within.
69
+ :param Union[List[np.ndarray], np.ndarray] imgs: List of images as arrays or 3/4d array of images to calculate average brightness intensity within.
71
70
  :param Optional[bool] ignore_black: If True, ignores black pixels. If the images are sliced non-rectangular geometric shapes created by ``slice_shapes_in_img``, then pixels that don't belong to the shape has been masked in black.
72
71
  :returns: List of floats of size len(imgs) with brightness intensities.
73
72
  :rtype: List[float]
@@ -77,14 +76,12 @@ class ImageMixin(object):
77
76
  >>> ImageMixin.brightness_intensity(imgs=[img], ignore_black=False)
78
77
  >>> [159.0]
79
78
  """
80
- results = []
81
- check_instance(source=f"{ImageMixin().brightness_intensity.__name__} imgs", instance=imgs, accepted_types=list)
82
- for cnt, img in enumerate(imgs):
83
- check_instance(
84
- source=f"{ImageMixin().brightness_intensity.__name__} img {cnt}",
85
- instance=img,
86
- accepted_types=np.ndarray,
87
- )
79
+ results, timer = [], SimbaTimer(start=True)
80
+ check_instance(source=f"{ImageMixin().brightness_intensity.__name__} imgs", instance=imgs, accepted_types=(list, np.ndarray,))
81
+ if isinstance(imgs, np.ndarray): imgs = np.array(imgs)
82
+ for img_cnt in range(imgs.shape[0]):
83
+ img = imgs[img_cnt]
84
+ check_instance(source=f"{ImageMixin().brightness_intensity.__name__} img {img_cnt}", instance=img, accepted_types=np.ndarray)
88
85
  if len(img) == 0:
89
86
  results.append(0)
90
87
  else:
@@ -92,7 +89,10 @@ class ImageMixin(object):
92
89
  results.append(np.ceil(np.average(img[img != 0])))
93
90
  else:
94
91
  results.append(np.ceil(np.average(img)))
95
- return results
92
+ b = np.array(results).astype(np.float32)
93
+ timer.stop_timer()
94
+ if verbose: print(f'Brightness computed in {b.shape[0]} images (elapsed time {timer.elapsed_time_str}s)')
95
+
96
96
 
97
97
  @staticmethod
98
98
  def gaussian_blur(img: np.ndarray, kernel_size: Optional[Tuple] = (9, 9)) -> np.ndarray:
@@ -1898,7 +1898,7 @@ class ImageMixin(object):
1898
1898
  :rtype: np.ndarray
1899
1899
 
1900
1900
  :example:
1901
- >>> VIDEO_PATH = r"D:\EPM_2\EPM_1.mp4"
1901
+ >>> VIDEO_PATH = r"D:/EPM_2/EPM_1.mp4"
1902
1902
  >>> img = read_img_batch_from_video(video_path=VIDEO_PATH, greyscale=True, start_frm=0, end_frm=15, core_cnt=1)
1903
1903
  >>> imgs = np.stack(list(img.values()))
1904
1904
  >>> resized_img = resize_img_stack(imgs=imgs)
@@ -39,7 +39,7 @@ from simba.utils.data import create_color_palette, detect_bouts
39
39
  from simba.utils.enums import Formats, Keys, Options, Paths
40
40
  from simba.utils.errors import InvalidInputError
41
41
  from simba.utils.lookups import (get_categorical_palettes, get_color_dict,
42
- get_named_colors)
42
+ get_fonts, get_named_colors)
43
43
  from simba.utils.printing import SimbaTimer, stdout_success
44
44
  from simba.utils.read_write import (find_files_of_filetypes_in_directory,
45
45
  get_fn_ext, get_video_meta_data, read_df,
@@ -342,16 +342,28 @@ class PlottingMixin(object):
342
342
  height: int = 480,
343
343
  font_size: int = 8,
344
344
  font_rotation: int = 45,
345
+ font: Optional[str] = None,
345
346
  save_path: Optional[str] = None,
347
+ edge_clr: Optional[str] = 'black',
346
348
  hhmmss: bool = False) -> Union[None, np.ndarray]:
347
349
 
348
350
  video_timer = SimbaTimer(start=True)
349
351
  colour_tuple_x = list(np.arange(3.5, 203.5, 5))
352
+ original_font_family = copy(plt.rcParams['font.family']) if isinstance(plt.rcParams['font.family'], list) else plt.rcParams['font.family']
353
+
354
+ if font is not None:
355
+ available_fonts = get_fonts()
356
+ if font in available_fonts:
357
+ matplotlib.font_manager._get_font.cache_clear()
358
+ plt.rcParams['font.family'] = font
359
+ else:
360
+ matplotlib.font_manager._get_font.cache_clear()
361
+ plt.rcParams['font.family'] = [font, 'sans-serif']
362
+
350
363
  fig, ax = plt.subplots()
351
- fig.patch.set_facecolor('#fafafa')
352
- ax.set_facecolor('#ffffff')
353
364
  fig.patch.set_facecolor('white')
354
- plt.title(video_name, fontsize=font_size + 6, pad=15, fontweight='bold')
365
+
366
+ plt.title(video_name, fontsize=font_size + 6, pad=25, fontweight='bold')
355
367
  ax.spines['top'].set_visible(False)
356
368
  ax.spines['right'].set_visible(False)
357
369
  ax.spines['left'].set_color('#666666')
@@ -367,7 +379,7 @@ class PlottingMixin(object):
367
379
  if event[0] == x:
368
380
  ix = clf_names.index(x)
369
381
  data_event = event[1][["Start_time", "Bout_time"]]
370
- ax.broken_barh(data_event.values, (colour_tuple_x[ix], 3), facecolors=palette[ix])
382
+ ax.broken_barh(data_event.values, (colour_tuple_x[ix], 3), facecolors=palette[ix], edgecolor=edge_clr, linewidth=0.5, alpha=0.9)
371
383
 
372
384
  x_ticks_seconds = np.round(np.linspace(0, x_length / fps, 6))
373
385
  x_ticks_locs = x_ticks_seconds
@@ -375,18 +387,19 @@ class PlottingMixin(object):
375
387
  if hhmmss:
376
388
  x_lbls = [seconds_to_timestamp(sec) for sec in x_ticks_seconds]
377
389
  else:
378
- x_lbls = x_ticks_seconds
390
+ x_lbls = [int(x) for x in x_ticks_seconds]
379
391
 
380
- #x_ticks_locs = x_lbls = np.round(np.linspace(0, x_length / fps, 6))
381
392
  ax.set_xticks(x_ticks_locs)
382
393
  ax.set_xticklabels(x_lbls)
383
394
  ax.set_ylim(0, colour_tuple_x[len(clf_names)])
384
395
  ax.set_yticks(np.arange(5, 5 * len(clf_names) + 1, 5))
385
- ax.set_yticklabels(clf_names, rotation=font_rotation)
396
+ ax.set_yticklabels(clf_names, rotation=font_rotation, ha='right', va='center')
386
397
  ax.tick_params(axis="both", labelsize=font_size)
387
398
  plt.xlabel(x_label, fontsize=font_size + 3)
388
- ax.yaxis.grid(True, linewidth=1.5, color='gray', alpha=0.4, linestyle='--')
389
- #ax.yaxis.grid(True)
399
+ ax.grid(True, axis='both', linewidth=1.0, color='gray', alpha=0.2, linestyle='--', which='major')
400
+
401
+ plt.subplots_adjust(left=0.1, right=0.95, top=0.9, bottom=0.15)
402
+ plt.tight_layout()
390
403
  buffer_ = io.BytesIO()
391
404
  plt.savefig(buffer_, format="png")
392
405
  buffer_.seek(0)
@@ -397,6 +410,11 @@ class PlottingMixin(object):
397
410
  frame = np.uint8(open_cv_image)
398
411
  buffer_.close()
399
412
  plt.close('all')
413
+
414
+ if font is not None:
415
+ plt.rcParams['font.family'] = original_font_family
416
+ matplotlib.font_manager._get_font.cache_clear()
417
+
400
418
  if save_path is not None:
401
419
  cv2.imwrite(save_path, frame)
402
420
  video_timer.stop_timer()
@@ -3278,10 +3278,34 @@ class Statistics(FeatureExtractionMixin):
3278
3278
 
3279
3279
  Youden's J statistic is a measure of the overall performance of a binary classification test, taking into account both sensitivity (true positive rate) and specificity (true negative rate).
3280
3280
 
3281
- :param sample_1: The first binary array.
3282
- :param sample_2: The second binary array.
3283
- :return: Youden's J statistic.
3281
+ The Youden's J statistic is calculated as:
3282
+
3283
+ .. math::
3284
+ J = \text{sensitivity} + \text{specificity} - 1
3285
+
3286
+ where:
3287
+
3288
+ - :math:`\text{sensitivity} = \frac{TP}{TP + FN}` is the true positive rate
3289
+ - :math:`\text{specificity} = \frac{TN}{TN + FP}` is the true negative rate
3290
+
3291
+ The statistic ranges from -1 to 1, where:
3292
+ - :math:`J = 1` indicates perfect classification
3293
+ - :math:`J = 0` indicates the test performs no better than random
3294
+ - :math:`J < 0` indicates the test performs worse than random
3295
+
3296
+ :param sample_1: The first binary array (ground truth or reference).
3297
+ :param sample_2: The second binary array (predictions or test results).
3298
+ :return: Youden's J statistic. Returns NaN if either sensitivity or specificity cannot be calculated (division by zero).
3284
3299
  :rtype: float
3300
+
3301
+ :references:
3302
+ .. [1] Youden, W. J. (1950). Index for rating diagnostic tests. Cancer, 3(1), 32-35.
3303
+ https://acsjournals.onlinelibrary.wiley.com/doi/abs/10.1002/1097-0142(1950)3:1%3C32::AID-CNCR2820030106%3E3.0.CO;2-3
3304
+
3305
+ :example:
3306
+ >>> y_true = np.array([1, 1, 0, 0, 1, 0, 1, 1, 0, 0])
3307
+ >>> y_pred = np.array([1, 1, 0, 1, 1, 0, 1, 0, 0, 0])
3308
+ >>> j = Statistics.youden_j(sample_1=y_true, sample_2=y_pred)
3285
3309
  """
3286
3310
 
3287
3311
  check_valid_array(data=sample_1, source=f'{Statistics.youden_j.__name__} sample_1', accepted_ndims=(1,), accepted_values=[0, 1])
@@ -4257,7 +4281,7 @@ class Statistics(FeatureExtractionMixin):
4257
4281
  return separation_trace / compactness
4258
4282
 
4259
4283
  @staticmethod
4260
- def i_index(x: np.ndarray, y: np.ndarray):
4284
+ def i_index(x: np.ndarray, y: np.ndarray, verbose: bool = False) -> float:
4261
4285
 
4262
4286
  """
4263
4287
  Calculate the I-Index for evaluating clustering quality.
@@ -4282,9 +4306,10 @@ class Statistics(FeatureExtractionMixin):
4282
4306
  >>> X, y = make_blobs(n_samples=5000, centers=20, n_features=3, random_state=0, cluster_std=0.1)
4283
4307
  >>> Statistics.i_index(x=X, y=y)
4284
4308
  """
4309
+ timer = SimbaTimer(start=True)
4285
4310
  check_valid_array(data=x, accepted_ndims=(2,), accepted_dtypes=Formats.NUMERIC_DTYPES.value)
4286
4311
  check_valid_array(data=y, accepted_ndims=(1,), accepted_dtypes=Formats.NUMERIC_DTYPES.value, accepted_axis_0_shape=[x.shape[0], ])
4287
- _ = get_unique_values_in_iterable(data=y, name=Statistics.i_index.__name__, min=2)
4312
+ cluster_cnt = get_unique_values_in_iterable(data=y, name=Statistics.i_index.__name__, min=2)
4288
4313
  unique_y = np.unique(y)
4289
4314
  n_y = unique_y.shape[0]
4290
4315
  global_centroid = np.mean(x, axis=0)
@@ -4296,7 +4321,12 @@ class Statistics(FeatureExtractionMixin):
4296
4321
  cluster_centroid = np.mean(cluster_obs, axis=0)
4297
4322
  swc += np.sum(np.linalg.norm(cluster_obs - cluster_centroid, axis=1) ** 2)
4298
4323
 
4299
- return sst / (n_y * swc)
4324
+
4325
+ i_index = np.float32(sst / (n_y * swc))
4326
+ timer.stop_timer()
4327
+ if verbose: print(f'I-index for {x.shape[0]} observations in {cluster_cnt} clusters computed (elapsed time: {timer.elapsed_time_str}s)')
4328
+ return i_index
4329
+
4300
4330
 
4301
4331
  @staticmethod
4302
4332
  def sd_index(x: np.ndarray, y: np.ndarray) -> float:
@@ -5298,7 +5328,7 @@ class Statistics(FeatureExtractionMixin):
5298
5328
  """
5299
5329
  Compute one-way ANOVAs comparing each column (axis 1) on two arrays.
5300
5330
 
5301
- .. notes::
5331
+ .. note::
5302
5332
  Use for computing and presenting aggregate statistics. Not suitable for featurization.
5303
5333
 
5304
5334
  .. seealso::
@@ -5336,7 +5366,7 @@ class Statistics(FeatureExtractionMixin):
5336
5366
  """
5337
5367
  Compute Kruskal-Wallis comparing each column (axis 1) on two arrays.
5338
5368
 
5339
- .. notes::
5369
+ .. note::
5340
5370
  Use for computing and presenting aggregate statistics. Not suitable for featurization.
5341
5371
 
5342
5372
  .. seealso::
@@ -5373,7 +5403,7 @@ class Statistics(FeatureExtractionMixin):
5373
5403
  """
5374
5404
  Compute pairwise grouped Tukey-HSD tests.
5375
5405
 
5376
- .. notes::
5406
+ .. note::
5377
5407
  Use for computing and presenting aggregate statistics. Not suitable for featurization.
5378
5408
 
5379
5409
  :param np.ndarray data: 2D array with observations rowwise (axis 0) and features columnwise (axis 1)
@@ -2198,7 +2198,7 @@ class TimeseriesFeatureMixin(object):
2198
2198
  :example:
2199
2199
  >>> x = np.random.randint(0, 100, (400, 2))
2200
2200
  >>> results_1 = TimeseriesFeatureMixin.sliding_entropy_of_directional_changes(x=x, bins=16, window_size=5.0, sample_rate=30)
2201
- >>> x = pd.read_csv(r"C:\troubleshooting\two_black_animals_14bp\project_folder\csv\input_csv\Together_1.csv")[['Ear_left_1_x', 'Ear_left_1_y']].values
2201
+ >>> x = pd.read_csv(r"C:/troubleshooting/two_black_animals_14bp/project_folder/csv/input_csv/Together_1.csv")[['Ear_left_1_x', 'Ear_left_1_y']].values
2202
2202
  >>> results_2 = TimeseriesFeatureMixin.sliding_entropy_of_directional_changes(x=x, bins=16, window_size=5.0, sample_rate=30)
2203
2203
  """
2204
2204
 
@@ -77,10 +77,10 @@ from simba.utils.errors import (ClassifierInferenceError, CorruptedFileError,
77
77
  SamplingError, SimBAModuleNotFoundError)
78
78
  from simba.utils.lookups import get_meta_data_file_headers, get_table
79
79
  from simba.utils.printing import SimbaTimer, stdout_success
80
- from simba.utils.read_write import (find_core_cnt, get_fn_ext,
81
- get_memory_usage_of_df, get_pkg_version,
82
- read_config_entry, read_df, read_meta_file,
83
- str_2_bool)
80
+ from simba.utils.read_write import (find_core_cnt, get_current_time,
81
+ get_fn_ext, get_memory_usage_of_df,
82
+ get_pkg_version, read_config_entry,
83
+ read_df, read_meta_file, str_2_bool)
84
84
  from simba.utils.warnings import (GPUToolsWarning, MissingUserInputWarning,
85
85
  MultiProcessingFailedWarning,
86
86
  NoModuleWarning, NotEnoughDataWarning,
@@ -1383,18 +1383,39 @@ class TrainModelMixin(object):
1383
1383
  x_df: Union[pd.DataFrame, np.ndarray],
1384
1384
  multiclass: bool = False,
1385
1385
  model_name: Optional[str] = None,
1386
- data_path: Optional[Union[str, os.PathLike]] = None) -> np.ndarray:
1386
+ data_path: Optional[Union[str, os.PathLike]] = None,
1387
+ verbose: bool = False) -> np.ndarray:
1387
1388
 
1388
1389
  """
1389
- :param RandomForestClassifier clf: Random forest classifier object
1390
- :param Union[pd.DataFrame, np.ndarray] x_df: Features for data to predict as a dataframe or array of size (M,N).
1391
- :param bool multiclass: If True, the classifier predicts more than 2 targets. Else, boolean classifier.
1392
- :param Optional[str] model_name: Name of model
1393
- :param Optional[str] data_path: Path to model on disk
1394
- :return np.ndarray: 2D array with frame represented by rows and present/absent probabilities as columns
1395
- :raises FeatureNumberMismatchError: If shape of x_df and clf.n_features_ or n_features_in_ show mismatch
1390
+ Helper to predict class probabilities using a fitted random forest classifier.
1391
+
1392
+ Computes prediction probabilities for binary or multiclass classification using either
1393
+ scikit-learn or cuML RandomForestClassifier. For binary classifiers, returns the
1394
+ probability of the positive class (class 1). For multiclass classifiers, returns
1395
+ probabilities for all classes.
1396
+
1397
+ .. csv-table::
1398
+ :header: EXPECTED RUNTIMES
1399
+ :file: ../../docs/tables/clf_predict_proba.csv
1400
+ :widths: 10, 45, 45
1401
+ :align: center
1402
+ :header-rows: 1
1403
+
1404
+ .. seealso::
1405
+ To fit a classifier, see :func:`simba.mixins.train_model_mixin.TrainModelMixin.clf_fit`
1406
+ To define a classifier, see :func:`simba.mixins.train_model_mixin.TrainModelMixin.clf_define`
1407
+
1408
+ :param Union[RandomForestClassifier, cuRF] clf: Fitted random forest classifier object from sklearn or cuml.
1409
+ :param Union[pd.DataFrame, np.ndarray] x_df: Features for data to predict. DataFrame or array of shape (n_samples, n_features).
1410
+ :param bool multiclass: If True, the classifier predicts more than 2 classes. If False, binary classifier (default: False).
1411
+ :param Optional[str] model_name: Name of the model for error messages and logging. Default: None.
1412
+ :param Optional[Union[str, os.PathLike]] data_path: Path to the data file being processed, used in error messages. Default: None.
1413
+ :param bool verbose: If True, print inference progress and timing information. Default: False.
1414
+ :return np.ndarray: Prediction probabilities. For binary classifiers: 1D array of shape (n_samples,) with probability of positive class. For multiclass: 2D array of shape (n_samples, n_classes) with probabilities for each class.
1415
+
1396
1416
  """
1397
1417
 
1418
+ timer = SimbaTimer(start=True)
1398
1419
  if hasattr(clf, "n_features_"):
1399
1420
  clf_n_features = clf.n_features_
1400
1421
  elif hasattr(clf, "n_features_in_"):
@@ -1420,6 +1441,8 @@ class TrainModelMixin(object):
1420
1441
  p_vals = clf.predict_proba(x_df)
1421
1442
  if multiclass and (clf.n_classes_ != p_vals.shape[1]):
1422
1443
  raise ClassifierInferenceError(msg=f"The classifier {model_name} (data path: {data_path}) is a multiclassifier expected to create {clf.n_classes_} behavior probabilities. However, it produced probabilities for {p_vals.shape[1]} behaviors. See The SimBA GitHub FAQ page or Gitter for more information and suggested fixes.", source=self.__class__.__name__)
1444
+ timer.stop_timer()
1445
+ if verbose: print(f'Inference for model {model_name} over {x_df.shape[0]} observations complete ({timer.elapsed_time_str}s).')
1423
1446
  if not multiclass:
1424
1447
  if isinstance(p_vals, pd.DataFrame):
1425
1448
  return p_vals[1].values
@@ -1447,7 +1470,7 @@ class TrainModelMixin(object):
1447
1470
  bootstrap: Optional[bool] = True,
1448
1471
  verbose: Optional[int] = 1,
1449
1472
  class_weight: Optional[dict] = None,
1450
- cuda: Optional[bool] = False) -> RandomForestClassifier:
1473
+ cuda: Optional[bool] = False) -> Union[RandomForestClassifier, cuRF]:
1451
1474
 
1452
1475
  if not cuda:
1453
1476
  # NOTE: LOKY ISSUES ON WINDOWS WITH SCIKIT IF THE CORE COUNT EXCEEDS 61.
@@ -1482,20 +1505,32 @@ class TrainModelMixin(object):
1482
1505
  clf: Union[RandomForestClassifier, cuRF],
1483
1506
  x_df: pd.DataFrame,
1484
1507
  y_df: pd.DataFrame,
1485
- ) -> RandomForestClassifier:
1508
+ verbose: bool = False) -> Union[RandomForestClassifier, cuRF]:
1486
1509
 
1487
1510
  """
1488
- Helper to fit clf model
1511
+ Helper to fit clf model.
1489
1512
 
1490
- :param clf: Un-fitted random forest classifier object
1513
+ .. csv-table::
1514
+ :header: EXPECTED RUNTIMES
1515
+ :file: ../../docs/tables/clf_fit.csv
1516
+ :widths: 20, 20, 30, 30
1517
+ :align: center
1518
+ :header-rows: 1
1519
+
1520
+ .. seealso::
1521
+ To define a cuml/sklearn object, see :func:`simba.mixins.train_model_mixin.TrainModelMixin.clf_define`
1522
+
1523
+ :param clf: Un-fitted random forest classifier object, either from sklearn or cuml.
1491
1524
  :param pd.DataFrame x_df: Pandas dataframe with features.
1492
1525
  :param pd.DataFrame y_df: Pandas dataframe/Series with target
1493
1526
  :return: Fitted random forest classifier object
1494
1527
  :rtype: RandomForestClassifier
1495
1528
  """
1496
1529
 
1530
+ timer = SimbaTimer(start=True)
1497
1531
  nan_features = x_df[~x_df.applymap(np.isreal).all(1)]
1498
1532
  nan_target = y_df.loc[pd.to_numeric(y_df).isna()]
1533
+ using_cuda = True if CUML in str(clf.__class__.__module__).lower() else False
1499
1534
  if len(nan_features) > 0:
1500
1535
  raise FaultyTrainingSetError(
1501
1536
  msg=f"{len(nan_features)} frame(s) in your project_folder/csv/targets_inserted directory contains FEATURES with non-numerical values",
@@ -1504,9 +1539,16 @@ class TrainModelMixin(object):
1504
1539
  raise FaultyTrainingSetError(
1505
1540
  msg=f"{len(nan_target)} frame(s) in your project_folder/csv/targets_inserted directory contains ANNOTATIONS with non-numerical values",
1506
1541
  source=self.__class__.__name__)
1542
+ if verbose: print(f'[{get_current_time()}] Fitting classifier for {len(x_df)} observations (cuda: {"True" if using_cuda else "False"})...')
1543
+ if using_cuda:
1544
+ x_data = x_df.values if isinstance(x_df, pd.DataFrame) else x_df
1545
+ y_data = y_df.values if isinstance(y_df, (pd.DataFrame, pd.Series)) else y_df
1546
+ clf.fit(x_data, y_data)
1547
+ else:
1548
+ clf.fit(x_df, y_df)
1507
1549
 
1508
- clf.fit(x_df, y_df)
1509
-
1550
+ timer.stop_timer()
1551
+ if verbose: print(f'[{get_current_time()}] Classifier fitted in {timer.elapsed_time_str}s.')
1510
1552
  return clf
1511
1553
 
1512
1554
  @staticmethod
@@ -1563,9 +1605,7 @@ class TrainModelMixin(object):
1563
1605
  :rtype: Tuple[pd.DataFrame, List[int]]
1564
1606
 
1565
1607
  """
1566
- if (platform.system() == "Darwin") and (
1567
- multiprocessing.get_start_method() != "spawn"
1568
- ):
1608
+ if (platform.system() == "Darwin") and (multiprocessing.get_start_method() != "spawn"):
1569
1609
  multiprocessing.set_start_method("spawn", force=True)
1570
1610
  cpu_cnt, _ = find_core_cnt()
1571
1611
  df_lst, frame_numbers_lst = [], []
@@ -1592,9 +1632,7 @@ class TrainModelMixin(object):
1592
1632
  :, ~df_concat.columns.str.contains("^Unnamed")
1593
1633
  ].astype(np.float32)
1594
1634
  memory_size = get_memory_usage_of_df(df=df_concat)
1595
- print(
1596
- f'Dataset size: {memory_size["megabytes"]}MB / {memory_size["gigabytes"]}GB'
1597
- )
1635
+ print(f'Dataset size: {memory_size["megabytes"]}MB / {memory_size["gigabytes"]}GB')
1598
1636
 
1599
1637
  return df_concat, frame_numbers_lst
1600
1638
 
@@ -2607,9 +2645,9 @@ class TrainModelMixin(object):
2607
2645
  :param bool plot: If True, create SHAP aggregation and plots.
2608
2646
 
2609
2647
  :example:
2610
- >>> CONFIG_PATH = r"C:\troubleshooting\mitra\project_folder\project_config.ini"
2611
- >>> RF_PATH = r"C:\troubleshooting\mitra\models\validations\straub_tail_5_new\straub_tail_5.sav"
2612
- >>> DATA_PATH = r"C:\troubleshooting\mitra\project_folder\csv\targets_inserted\new_straub\appended\501_MA142_Gi_CNO_0514.csv"
2648
+ >>> CONFIG_PATH = r"C:/troubleshooting/mitra/project_folder/project_config.ini"
2649
+ >>> RF_PATH = r"C:/troubleshooting/mitra/models/validations/straub_tail_5_new/straub_tail_5.sav"
2650
+ >>> DATA_PATH = r"C:/troubleshooting/mitra/project_folder/csv/targets_inserted/new_straub/appended/501_MA142_Gi_CNO_0514.csv"
2613
2651
  >>> config = ConfigReader(config_path=CONFIG_PATH)
2614
2652
  >>> df = read_df(file_path=DATA_PATH, file_type='csv')
2615
2653
  >>> y = df['straub_tail']
@@ -45,7 +45,7 @@ class InferenceBatch(TrainModelMixin, ConfigReader):
45
45
  >>> inferencer.run()
46
46
 
47
47
  :example II:
48
- >>> inferencer = InferenceBatch(config_path=r"D:\troubleshooting\mitra\project_folder\project_config.ini", features_dir=r"D:\troubleshooting\mitra\project_folder\videos\bg_removed\rotated\tail_features\APPENDED")
48
+ >>> inferencer = InferenceBatch(config_path=r"D:/troubleshooting/mitra/project_folder/project_config.ini", features_dir=r"D:/troubleshooting/mitra/project_folder/videos/bg_removed/rotated/tail_features/APPENDED")
49
49
  >>> inferencer.run()
50
50
  """
51
51
 
@@ -55,9 +55,9 @@ class YOLOSegmentationInference():
55
55
  To visualize the segmentation results, see :func:`simba.plotting.yolo_seg_visualizer.YOLOSegmentationVisualizer`
56
56
 
57
57
  :example:
58
- >>> weights_path = r"D:\platea\yolo_071525\mdl\train3\weights\best.pt"
59
- >>> video_path = r"D:\platea\platea_videos\videos\clipped\10B_Mouse_5-choice_MustTouchTrainingNEWFINAL_a7.mp4"
60
- >>> save_dir=r"D:\platea\platea_videos\videos\yolo_results"
58
+ >>> weights_path = r"D:/platea/yolo_071525/mdl/train3/weights/best.pt"
59
+ >>> video_path = r"D:/platea/platea_videos/videos/clipped/10B_Mouse_5-choice_MustTouchTrainingNEWFINAL_a7.mp4"
60
+ >>> save_dir = r"D:/platea/platea_videos/videos/yolo_results"
61
61
  >>> runner = YOLOSegmentationInference(weights_path=weights_path, video_path=video_path, save_dir=save_dir, verbose=True, device=0, format=None, stream=True, batch_size=10, imgsz=320, interpolate=True, threshold=0.8, retina_msk=True)
62
62
  >>> runner.run()
63
63
 
@@ -47,5 +47,5 @@ class OutlierCorrectionSkipper(ConfigReader):
47
47
  self.timer.stop_timer()
48
48
  stdout_success(msg=f"Skipped outlier correction for {len(self.input_csv_paths)} files", elapsed_time=self.timer.elapsed_time_str)
49
49
 
50
- # test = OutlierCorrectionSkipper(config_path='/Users/simon/Desktop/envs/troubleshooting/naresh/project_folder/project_config.ini')
50
+ # test = OutlierCorrectionSkipper(config_path=r"E:\troubleshooting\mitra_emergence\project_folder\project_config.ini")
51
51
  # test.run()