simba-uw-tf-dev 4.6.6__py3-none-any.whl → 4.6.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of simba-uw-tf-dev might be problematic. Click here for more details.
- simba/assets/.recent_projects.txt +1 -0
- simba/data_processors/blob_location_computer.py +1 -1
- simba/data_processors/circling_detector.py +30 -13
- simba/data_processors/cuda/image.py +53 -25
- simba/data_processors/cuda/statistics.py +57 -19
- simba/data_processors/cuda/timeseries.py +1 -1
- simba/data_processors/egocentric_aligner.py +1 -1
- simba/data_processors/freezing_detector.py +54 -50
- simba/feature_extractors/feature_subsets.py +2 -2
- simba/feature_extractors/mitra_feature_extractor.py +2 -2
- simba/feature_extractors/straub_tail_analyzer.py +4 -4
- simba/labelling/standard_labeller.py +1 -1
- simba/mixins/config_reader.py +5 -2
- simba/mixins/geometry_mixin.py +8 -8
- simba/mixins/image_mixin.py +14 -14
- simba/mixins/plotting_mixin.py +28 -10
- simba/mixins/statistics_mixin.py +39 -9
- simba/mixins/timeseries_features_mixin.py +1 -1
- simba/mixins/train_model_mixin.py +65 -27
- simba/model/inference_batch.py +1 -1
- simba/model/yolo_seg_inference.py +3 -3
- simba/outlier_tools/skip_outlier_correction.py +1 -1
- simba/plotting/gantt_creator.py +29 -10
- simba/plotting/gantt_creator_mp.py +50 -17
- simba/plotting/heat_mapper_clf_mp.py +2 -2
- simba/pose_importers/simba_blob_importer.py +3 -3
- simba/roi_tools/roi_aggregate_stats_mp.py +1 -1
- simba/roi_tools/roi_clf_calculator_mp.py +1 -1
- simba/third_party_label_appenders/transform/coco_keypoints_to_yolo.py +3 -3
- simba/third_party_label_appenders/transform/coco_keypoints_to_yolo_bbox.py +2 -2
- simba/ui/pop_ups/clf_plot_pop_up.py +2 -2
- simba/ui/pop_ups/gantt_pop_up.py +31 -6
- simba/ui/pop_ups/video_processing_pop_up.py +1 -1
- simba/utils/custom_feature_extractor.py +1 -1
- simba/utils/data.py +2 -2
- simba/utils/read_write.py +32 -18
- simba/utils/yolo.py +10 -1
- simba/video_processors/blob_tracking_executor.py +2 -2
- simba/video_processors/clahe_ui.py +1 -1
- simba/video_processors/egocentric_video_rotator.py +3 -3
- simba/video_processors/multi_cropper.py +1 -1
- simba/video_processors/video_processing.py +27 -10
- simba/video_processors/videos_to_frames.py +2 -2
- {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/METADATA +3 -2
- {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/RECORD +49 -49
- {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/LICENSE +0 -0
- {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/WHEEL +0 -0
- {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/entry_points.txt +0 -0
- {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/top_level.txt +0 -0
|
@@ -44,10 +44,10 @@ class StraubTailAnalyzer(ConfigReader):
|
|
|
44
44
|
.. [1] Lazaro et al., Brainwide Genetic Capture for Conscious State Transitions, `biorxiv`, doi: https://doi.org/10.1101/2025.03.28.646066
|
|
45
45
|
|
|
46
46
|
:example:
|
|
47
|
-
>>> runner = StraubTailAnalyzer(config_path=r"C
|
|
48
|
-
>>> data_dir=r'C
|
|
49
|
-
>>> video_dir=r'C
|
|
50
|
-
>>> save_dir=r'C
|
|
47
|
+
>>> runner = StraubTailAnalyzer(config_path=r"C:/troubleshooting/mitra/project_folder/project_config.ini",
|
|
48
|
+
>>> data_dir=r'C:/troubleshooting/mitra/project_folder/videos/additional/bg_removed/rotated',
|
|
49
|
+
>>> video_dir=r'C:/troubleshooting/mitra/project_folder/videos/additional/bg_removed/rotated',
|
|
50
|
+
>>> save_dir=r'C:/troubleshooting/mitra/project_folder/videos/additional/bg_removed/rotated/tail_features_additional',
|
|
51
51
|
>>> anchor_points=('tail_base', 'tail_center', 'tail_tip'),
|
|
52
52
|
>>> body_parts=('nose', 'left_ear', 'right_ear', 'right_side', 'left_side', 'tail_base'))
|
|
53
53
|
>>> runner.run()
|
|
@@ -64,7 +64,7 @@ class LabellingInterface(ConfigReader):
|
|
|
64
64
|
:param bool continuing: Set True to resume annotations from an existing targets file. Defaults to False.
|
|
65
65
|
|
|
66
66
|
:example:
|
|
67
|
-
>>> _ = LabellingInterface(config_path=r"C
|
|
67
|
+
>>> _ = LabellingInterface(config_path=r"C:/troubleshooting/mitra/project_folder/project_config.ini", file_path=r"C:/troubleshooting/mitra/project_folder/videos/501_MA142_Gi_CNO_0521.mp4", thresholds=None, continuing=False)
|
|
68
68
|
"""
|
|
69
69
|
|
|
70
70
|
def __init__(self,
|
simba/mixins/config_reader.py
CHANGED
|
@@ -41,8 +41,8 @@ from simba.utils.read_write import (find_core_cnt, get_all_clf_names,
|
|
|
41
41
|
get_fn_ext, read_config_file, read_df,
|
|
42
42
|
read_project_path_and_file_type, write_df)
|
|
43
43
|
from simba.utils.warnings import (BodypartColumnNotFoundWarning,
|
|
44
|
-
|
|
45
|
-
NoFileFoundWarning)
|
|
44
|
+
DuplicateNamesWarning, InvalidValueWarning,
|
|
45
|
+
NoDataFoundWarning, NoFileFoundWarning)
|
|
46
46
|
|
|
47
47
|
|
|
48
48
|
class ConfigReader(object):
|
|
@@ -610,11 +610,14 @@ class ConfigReader(object):
|
|
|
610
610
|
>>> config_reader.get_bp_headers()
|
|
611
611
|
"""
|
|
612
612
|
|
|
613
|
+
duplicates = list({x for x in self.body_parts_lst if self.body_parts_lst.count(x) > 1})
|
|
614
|
+
if len(duplicates) > 0: DuplicateNamesWarning(msg=f'The pose configuration file at {self.body_parts_path} contains duplicate entries: {duplicates}', source=self.__class__.__name__)
|
|
613
615
|
self.bp_headers = []
|
|
614
616
|
for bp in self.body_parts_lst:
|
|
615
617
|
c1, c2, c3 = (f"{bp}_x", f"{bp}_y", f"{bp}_p")
|
|
616
618
|
self.bp_headers.extend((c1, c2, c3))
|
|
617
619
|
|
|
620
|
+
|
|
618
621
|
def read_config_entry(
|
|
619
622
|
self,
|
|
620
623
|
config: ConfigParser,
|
simba/mixins/geometry_mixin.py
CHANGED
|
@@ -1556,7 +1556,7 @@ class GeometryMixin(object):
|
|
|
1556
1556
|
:rtype: List[float]
|
|
1557
1557
|
|
|
1558
1558
|
:example:
|
|
1559
|
-
>>> df = read_df(file_path=r"C
|
|
1559
|
+
>>> df = read_df(file_path=r"C:/troubleshooting/two_black_animals_14bp/project_folder/csv/outlier_corrected_movement_location/Together_2.csv", file_type='csv').astype(int)
|
|
1560
1560
|
>>> animal_1_cols = [x for x in df.columns if '_1_' in x and not '_p' in x]
|
|
1561
1561
|
>>> animal_2_cols = [x for x in df.columns if '_2_' in x and not '_p' in x]
|
|
1562
1562
|
>>> animal_1_arr = df[animal_1_cols].values.reshape(len(df), int(len(animal_1_cols)/ 2), 2)
|
|
@@ -1622,7 +1622,7 @@ class GeometryMixin(object):
|
|
|
1622
1622
|
:return List[float]: List of overlap between corresponding Polygons. If overlap 1, else 0.
|
|
1623
1623
|
|
|
1624
1624
|
:example:
|
|
1625
|
-
>>> df = read_df(file_path=r"C
|
|
1625
|
+
>>> df = read_df(file_path=r"C:/troubleshooting/two_black_animals_14bp/project_folder/csv/outlier_corrected_movement_location/Together_2.csv", file_type='csv').astype(int)
|
|
1626
1626
|
>>> animal_1_cols = [x for x in df.columns if '_1_' in x and not '_p' in x]
|
|
1627
1627
|
>>> animal_2_cols = [x for x in df.columns if '_2_' in x and not '_p' in x]
|
|
1628
1628
|
>>> animal_1_arr = df[animal_1_cols].values.reshape(len(df), int(len(animal_1_cols)/ 2), 2)
|
|
@@ -1693,7 +1693,7 @@ class GeometryMixin(object):
|
|
|
1693
1693
|
:rtype: List[float]
|
|
1694
1694
|
|
|
1695
1695
|
:example:
|
|
1696
|
-
>>> df = read_df(file_path=r"C
|
|
1696
|
+
>>> df = read_df(file_path=r"C:/troubleshooting/two_black_animals_14bp/project_folder/csv/outlier_corrected_movement_location/Together_2.csv", file_type='csv').astype(int)
|
|
1697
1697
|
>>> animal_1_cols = [x for x in df.columns if '_1_' in x and not '_p' in x]
|
|
1698
1698
|
>>> animal_2_cols = [x for x in df.columns if '_2_' in x and not '_p' in x]
|
|
1699
1699
|
>>> animal_1_arr = df[animal_1_cols].values.reshape(len(df), int(len(animal_1_cols)/ 2), 2)
|
|
@@ -1763,7 +1763,7 @@ class GeometryMixin(object):
|
|
|
1763
1763
|
:rtype: List[Polygon]
|
|
1764
1764
|
|
|
1765
1765
|
:example:
|
|
1766
|
-
>>> df = read_df(file_path=r"C
|
|
1766
|
+
>>> df = read_df(file_path=r"C:/troubleshooting/two_black_animals_14bp/project_folder/csv/outlier_corrected_movement_location/Together_2.csv", file_type='csv').astype(int)
|
|
1767
1767
|
>>> animal_1_cols = [x for x in df.columns if '_1_' in x and not '_p' in x]
|
|
1768
1768
|
>>> animal_1_arr = df[animal_1_cols].values.reshape(len(df), int(len(animal_1_cols)/ 2), 2)
|
|
1769
1769
|
>>> animal_1_geo = GeometryMixin.bodyparts_to_polygon(data=animal_1_arr)
|
|
@@ -3525,10 +3525,10 @@ class GeometryMixin(object):
|
|
|
3525
3525
|
:rtype: Tuple[Dict[Tuple[int, int], Dict[Tuple[int, int], float]], Dict[Tuple[int, int], Dict[Tuple[int, int], int]]]
|
|
3526
3526
|
|
|
3527
3527
|
:example:
|
|
3528
|
-
>>> video_meta_data = get_video_meta_data(video_path=r"C
|
|
3528
|
+
>>> video_meta_data = get_video_meta_data(video_path=r"C:/troubleshooting/mitra/project_folder/videos/708_MA149_Gq_CNO_0515.mp4")
|
|
3529
3529
|
>>> w, h = video_meta_data['width'], video_meta_data['height']
|
|
3530
3530
|
>>> grid = GeometryMixin().bucket_img_into_grid_square(bucket_grid_size=(5, 5), bucket_grid_size_mm=None, img_size=(h, w), verbose=False)[0]
|
|
3531
|
-
>>> data = read_df(file_path=r'C
|
|
3531
|
+
>>> data = read_df(file_path=r'C:/troubleshooting/mitra/project_folder/csv/outlier_corrected_movement_location/708_MA149_Gq_CNO_0515.csv', file_type='csv')[['Nose_x', 'Nose_y']].values
|
|
3532
3532
|
>>> transition_probabilities, _ = geometry_transition_probabilities(data=data, grid=grid)
|
|
3533
3533
|
"""
|
|
3534
3534
|
|
|
@@ -3990,7 +3990,7 @@ class GeometryMixin(object):
|
|
|
3990
3990
|
:rtype: np.ndarray
|
|
3991
3991
|
|
|
3992
3992
|
:example:
|
|
3993
|
-
>>> data_path = r"C
|
|
3993
|
+
>>> data_path = r"C:/troubleshooting/mitra/project_folder/csv/outlier_corrected_movement_location/FRR_gq_Saline_0624.csv"
|
|
3994
3994
|
>>> animal_data = read_df(file_path=data_path, file_type='csv', usecols=['Nose_x', 'Nose_y', 'Tail_base_x', 'Tail_base_y', 'Left_side_x', 'Left_side_y', 'Right_side_x', 'Right_side_y']).values.reshape(-1, 4, 2)[0:20].astype(np.int32)
|
|
3995
3995
|
>>> animal_polygons = GeometryMixin().bodyparts_to_polygon(data=animal_data)
|
|
3996
3996
|
>>> GeometryMixin.geometries_to_exterior_keypoints(geometries=animal_polygons)
|
|
@@ -4160,7 +4160,7 @@ class GeometryMixin(object):
|
|
|
4160
4160
|
:rtype: Union[None, Dict[Any, dict]]
|
|
4161
4161
|
|
|
4162
4162
|
:example I:
|
|
4163
|
-
>>> results = GeometryMixin.sleap_csv_to_geometries(data=r"C
|
|
4163
|
+
>>> results = GeometryMixin.sleap_csv_to_geometries(data=r"C:/troubleshooting/ants/pose_data/ant.csv")
|
|
4164
4164
|
>>> # Results structure: {track_id: {frame_idx: Polygon, ...}, ...}
|
|
4165
4165
|
|
|
4166
4166
|
:example II
|
simba/mixins/image_mixin.py
CHANGED
|
@@ -57,17 +57,16 @@ class ImageMixin(object):
|
|
|
57
57
|
pass
|
|
58
58
|
|
|
59
59
|
@staticmethod
|
|
60
|
-
def brightness_intensity(imgs: List[np.ndarray], ignore_black:
|
|
60
|
+
def brightness_intensity(imgs: Union[List[np.ndarray], np.ndarray], ignore_black: bool = True, verbose: bool = False) -> np.ndarray:
|
|
61
61
|
"""
|
|
62
62
|
Compute the average brightness intensity within each image within a list.
|
|
63
63
|
|
|
64
64
|
For example, (i) create a list of images containing a light cue ROI, (ii) compute brightness in each image, (iii) perform kmeans on brightness, and get the frames when the light cue is on vs off.
|
|
65
65
|
|
|
66
66
|
.. seealso::
|
|
67
|
-
For GPU acceleration, see :func:`simba.data_processors.cuda.image.img_stack_brightness`.
|
|
68
|
-
For geometry based brightness, see :func:`simba.mixins.geometry_mixin.GeometryMixin.get_geometry_brightness_intensity`
|
|
67
|
+
For GPU acceleration, see :func:`simba.data_processors.cuda.image.img_stack_brightness`. For geometry based brightness, see :func:`simba.mixins.geometry_mixin.GeometryMixin.get_geometry_brightness_intensity`
|
|
69
68
|
|
|
70
|
-
:param List[np.ndarray] imgs: List of images as arrays to calculate average brightness intensity within.
|
|
69
|
+
:param Union[List[np.ndarray], np.ndarray] imgs: List of images as arrays or 3/4d array of images to calculate average brightness intensity within.
|
|
71
70
|
:param Optional[bool] ignore_black: If True, ignores black pixels. If the images are sliced non-rectangular geometric shapes created by ``slice_shapes_in_img``, then pixels that don't belong to the shape has been masked in black.
|
|
72
71
|
:returns: List of floats of size len(imgs) with brightness intensities.
|
|
73
72
|
:rtype: List[float]
|
|
@@ -77,14 +76,12 @@ class ImageMixin(object):
|
|
|
77
76
|
>>> ImageMixin.brightness_intensity(imgs=[img], ignore_black=False)
|
|
78
77
|
>>> [159.0]
|
|
79
78
|
"""
|
|
80
|
-
results = []
|
|
81
|
-
check_instance(source=f"{ImageMixin().brightness_intensity.__name__} imgs", instance=imgs, accepted_types=list)
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
accepted_types=np.ndarray,
|
|
87
|
-
)
|
|
79
|
+
results, timer = [], SimbaTimer(start=True)
|
|
80
|
+
check_instance(source=f"{ImageMixin().brightness_intensity.__name__} imgs", instance=imgs, accepted_types=(list, np.ndarray,))
|
|
81
|
+
if isinstance(imgs, np.ndarray): imgs = np.array(imgs)
|
|
82
|
+
for img_cnt in range(imgs.shape[0]):
|
|
83
|
+
img = imgs[img_cnt]
|
|
84
|
+
check_instance(source=f"{ImageMixin().brightness_intensity.__name__} img {img_cnt}", instance=img, accepted_types=np.ndarray)
|
|
88
85
|
if len(img) == 0:
|
|
89
86
|
results.append(0)
|
|
90
87
|
else:
|
|
@@ -92,7 +89,10 @@ class ImageMixin(object):
|
|
|
92
89
|
results.append(np.ceil(np.average(img[img != 0])))
|
|
93
90
|
else:
|
|
94
91
|
results.append(np.ceil(np.average(img)))
|
|
95
|
-
|
|
92
|
+
b = np.array(results).astype(np.float32)
|
|
93
|
+
timer.stop_timer()
|
|
94
|
+
if verbose: print(f'Brightness computed in {b.shape[0]} images (elapsed time {timer.elapsed_time_str}s)')
|
|
95
|
+
|
|
96
96
|
|
|
97
97
|
@staticmethod
|
|
98
98
|
def gaussian_blur(img: np.ndarray, kernel_size: Optional[Tuple] = (9, 9)) -> np.ndarray:
|
|
@@ -1898,7 +1898,7 @@ class ImageMixin(object):
|
|
|
1898
1898
|
:rtype: np.ndarray
|
|
1899
1899
|
|
|
1900
1900
|
:example:
|
|
1901
|
-
>>> VIDEO_PATH = r"D
|
|
1901
|
+
>>> VIDEO_PATH = r"D:/EPM_2/EPM_1.mp4"
|
|
1902
1902
|
>>> img = read_img_batch_from_video(video_path=VIDEO_PATH, greyscale=True, start_frm=0, end_frm=15, core_cnt=1)
|
|
1903
1903
|
>>> imgs = np.stack(list(img.values()))
|
|
1904
1904
|
>>> resized_img = resize_img_stack(imgs=imgs)
|
simba/mixins/plotting_mixin.py
CHANGED
|
@@ -39,7 +39,7 @@ from simba.utils.data import create_color_palette, detect_bouts
|
|
|
39
39
|
from simba.utils.enums import Formats, Keys, Options, Paths
|
|
40
40
|
from simba.utils.errors import InvalidInputError
|
|
41
41
|
from simba.utils.lookups import (get_categorical_palettes, get_color_dict,
|
|
42
|
-
get_named_colors)
|
|
42
|
+
get_fonts, get_named_colors)
|
|
43
43
|
from simba.utils.printing import SimbaTimer, stdout_success
|
|
44
44
|
from simba.utils.read_write import (find_files_of_filetypes_in_directory,
|
|
45
45
|
get_fn_ext, get_video_meta_data, read_df,
|
|
@@ -342,16 +342,28 @@ class PlottingMixin(object):
|
|
|
342
342
|
height: int = 480,
|
|
343
343
|
font_size: int = 8,
|
|
344
344
|
font_rotation: int = 45,
|
|
345
|
+
font: Optional[str] = None,
|
|
345
346
|
save_path: Optional[str] = None,
|
|
347
|
+
edge_clr: Optional[str] = 'black',
|
|
346
348
|
hhmmss: bool = False) -> Union[None, np.ndarray]:
|
|
347
349
|
|
|
348
350
|
video_timer = SimbaTimer(start=True)
|
|
349
351
|
colour_tuple_x = list(np.arange(3.5, 203.5, 5))
|
|
352
|
+
original_font_family = copy(plt.rcParams['font.family']) if isinstance(plt.rcParams['font.family'], list) else plt.rcParams['font.family']
|
|
353
|
+
|
|
354
|
+
if font is not None:
|
|
355
|
+
available_fonts = get_fonts()
|
|
356
|
+
if font in available_fonts:
|
|
357
|
+
matplotlib.font_manager._get_font.cache_clear()
|
|
358
|
+
plt.rcParams['font.family'] = font
|
|
359
|
+
else:
|
|
360
|
+
matplotlib.font_manager._get_font.cache_clear()
|
|
361
|
+
plt.rcParams['font.family'] = [font, 'sans-serif']
|
|
362
|
+
|
|
350
363
|
fig, ax = plt.subplots()
|
|
351
|
-
fig.patch.set_facecolor('#fafafa')
|
|
352
|
-
ax.set_facecolor('#ffffff')
|
|
353
364
|
fig.patch.set_facecolor('white')
|
|
354
|
-
|
|
365
|
+
|
|
366
|
+
plt.title(video_name, fontsize=font_size + 6, pad=25, fontweight='bold')
|
|
355
367
|
ax.spines['top'].set_visible(False)
|
|
356
368
|
ax.spines['right'].set_visible(False)
|
|
357
369
|
ax.spines['left'].set_color('#666666')
|
|
@@ -367,7 +379,7 @@ class PlottingMixin(object):
|
|
|
367
379
|
if event[0] == x:
|
|
368
380
|
ix = clf_names.index(x)
|
|
369
381
|
data_event = event[1][["Start_time", "Bout_time"]]
|
|
370
|
-
ax.broken_barh(data_event.values, (colour_tuple_x[ix], 3), facecolors=palette[ix])
|
|
382
|
+
ax.broken_barh(data_event.values, (colour_tuple_x[ix], 3), facecolors=palette[ix], edgecolor=edge_clr, linewidth=0.5, alpha=0.9)
|
|
371
383
|
|
|
372
384
|
x_ticks_seconds = np.round(np.linspace(0, x_length / fps, 6))
|
|
373
385
|
x_ticks_locs = x_ticks_seconds
|
|
@@ -375,18 +387,19 @@ class PlottingMixin(object):
|
|
|
375
387
|
if hhmmss:
|
|
376
388
|
x_lbls = [seconds_to_timestamp(sec) for sec in x_ticks_seconds]
|
|
377
389
|
else:
|
|
378
|
-
x_lbls = x_ticks_seconds
|
|
390
|
+
x_lbls = [int(x) for x in x_ticks_seconds]
|
|
379
391
|
|
|
380
|
-
#x_ticks_locs = x_lbls = np.round(np.linspace(0, x_length / fps, 6))
|
|
381
392
|
ax.set_xticks(x_ticks_locs)
|
|
382
393
|
ax.set_xticklabels(x_lbls)
|
|
383
394
|
ax.set_ylim(0, colour_tuple_x[len(clf_names)])
|
|
384
395
|
ax.set_yticks(np.arange(5, 5 * len(clf_names) + 1, 5))
|
|
385
|
-
ax.set_yticklabels(clf_names, rotation=font_rotation)
|
|
396
|
+
ax.set_yticklabels(clf_names, rotation=font_rotation, ha='right', va='center')
|
|
386
397
|
ax.tick_params(axis="both", labelsize=font_size)
|
|
387
398
|
plt.xlabel(x_label, fontsize=font_size + 3)
|
|
388
|
-
ax.
|
|
389
|
-
|
|
399
|
+
ax.grid(True, axis='both', linewidth=1.0, color='gray', alpha=0.2, linestyle='--', which='major')
|
|
400
|
+
|
|
401
|
+
plt.subplots_adjust(left=0.1, right=0.95, top=0.9, bottom=0.15)
|
|
402
|
+
plt.tight_layout()
|
|
390
403
|
buffer_ = io.BytesIO()
|
|
391
404
|
plt.savefig(buffer_, format="png")
|
|
392
405
|
buffer_.seek(0)
|
|
@@ -397,6 +410,11 @@ class PlottingMixin(object):
|
|
|
397
410
|
frame = np.uint8(open_cv_image)
|
|
398
411
|
buffer_.close()
|
|
399
412
|
plt.close('all')
|
|
413
|
+
|
|
414
|
+
if font is not None:
|
|
415
|
+
plt.rcParams['font.family'] = original_font_family
|
|
416
|
+
matplotlib.font_manager._get_font.cache_clear()
|
|
417
|
+
|
|
400
418
|
if save_path is not None:
|
|
401
419
|
cv2.imwrite(save_path, frame)
|
|
402
420
|
video_timer.stop_timer()
|
simba/mixins/statistics_mixin.py
CHANGED
|
@@ -3278,10 +3278,34 @@ class Statistics(FeatureExtractionMixin):
|
|
|
3278
3278
|
|
|
3279
3279
|
Youden's J statistic is a measure of the overall performance of a binary classification test, taking into account both sensitivity (true positive rate) and specificity (true negative rate).
|
|
3280
3280
|
|
|
3281
|
-
|
|
3282
|
-
|
|
3283
|
-
|
|
3281
|
+
The Youden's J statistic is calculated as:
|
|
3282
|
+
|
|
3283
|
+
.. math::
|
|
3284
|
+
J = \text{sensitivity} + \text{specificity} - 1
|
|
3285
|
+
|
|
3286
|
+
where:
|
|
3287
|
+
|
|
3288
|
+
- :math:`\text{sensitivity} = \frac{TP}{TP + FN}` is the true positive rate
|
|
3289
|
+
- :math:`\text{specificity} = \frac{TN}{TN + FP}` is the true negative rate
|
|
3290
|
+
|
|
3291
|
+
The statistic ranges from -1 to 1, where:
|
|
3292
|
+
- :math:`J = 1` indicates perfect classification
|
|
3293
|
+
- :math:`J = 0` indicates the test performs no better than random
|
|
3294
|
+
- :math:`J < 0` indicates the test performs worse than random
|
|
3295
|
+
|
|
3296
|
+
:param sample_1: The first binary array (ground truth or reference).
|
|
3297
|
+
:param sample_2: The second binary array (predictions or test results).
|
|
3298
|
+
:return: Youden's J statistic. Returns NaN if either sensitivity or specificity cannot be calculated (division by zero).
|
|
3284
3299
|
:rtype: float
|
|
3300
|
+
|
|
3301
|
+
:references:
|
|
3302
|
+
.. [1] Youden, W. J. (1950). Index for rating diagnostic tests. Cancer, 3(1), 32-35.
|
|
3303
|
+
https://acsjournals.onlinelibrary.wiley.com/doi/abs/10.1002/1097-0142(1950)3:1%3C32::AID-CNCR2820030106%3E3.0.CO;2-3
|
|
3304
|
+
|
|
3305
|
+
:example:
|
|
3306
|
+
>>> y_true = np.array([1, 1, 0, 0, 1, 0, 1, 1, 0, 0])
|
|
3307
|
+
>>> y_pred = np.array([1, 1, 0, 1, 1, 0, 1, 0, 0, 0])
|
|
3308
|
+
>>> j = Statistics.youden_j(sample_1=y_true, sample_2=y_pred)
|
|
3285
3309
|
"""
|
|
3286
3310
|
|
|
3287
3311
|
check_valid_array(data=sample_1, source=f'{Statistics.youden_j.__name__} sample_1', accepted_ndims=(1,), accepted_values=[0, 1])
|
|
@@ -4257,7 +4281,7 @@ class Statistics(FeatureExtractionMixin):
|
|
|
4257
4281
|
return separation_trace / compactness
|
|
4258
4282
|
|
|
4259
4283
|
@staticmethod
|
|
4260
|
-
def i_index(x: np.ndarray, y: np.ndarray):
|
|
4284
|
+
def i_index(x: np.ndarray, y: np.ndarray, verbose: bool = False) -> float:
|
|
4261
4285
|
|
|
4262
4286
|
"""
|
|
4263
4287
|
Calculate the I-Index for evaluating clustering quality.
|
|
@@ -4282,9 +4306,10 @@ class Statistics(FeatureExtractionMixin):
|
|
|
4282
4306
|
>>> X, y = make_blobs(n_samples=5000, centers=20, n_features=3, random_state=0, cluster_std=0.1)
|
|
4283
4307
|
>>> Statistics.i_index(x=X, y=y)
|
|
4284
4308
|
"""
|
|
4309
|
+
timer = SimbaTimer(start=True)
|
|
4285
4310
|
check_valid_array(data=x, accepted_ndims=(2,), accepted_dtypes=Formats.NUMERIC_DTYPES.value)
|
|
4286
4311
|
check_valid_array(data=y, accepted_ndims=(1,), accepted_dtypes=Formats.NUMERIC_DTYPES.value, accepted_axis_0_shape=[x.shape[0], ])
|
|
4287
|
-
|
|
4312
|
+
cluster_cnt = get_unique_values_in_iterable(data=y, name=Statistics.i_index.__name__, min=2)
|
|
4288
4313
|
unique_y = np.unique(y)
|
|
4289
4314
|
n_y = unique_y.shape[0]
|
|
4290
4315
|
global_centroid = np.mean(x, axis=0)
|
|
@@ -4296,7 +4321,12 @@ class Statistics(FeatureExtractionMixin):
|
|
|
4296
4321
|
cluster_centroid = np.mean(cluster_obs, axis=0)
|
|
4297
4322
|
swc += np.sum(np.linalg.norm(cluster_obs - cluster_centroid, axis=1) ** 2)
|
|
4298
4323
|
|
|
4299
|
-
|
|
4324
|
+
|
|
4325
|
+
i_index = np.float32(sst / (n_y * swc))
|
|
4326
|
+
timer.stop_timer()
|
|
4327
|
+
if verbose: print(f'I-index for {x.shape[0]} observations in {cluster_cnt} clusters computed (elapsed time: {timer.elapsed_time_str}s)')
|
|
4328
|
+
return i_index
|
|
4329
|
+
|
|
4300
4330
|
|
|
4301
4331
|
@staticmethod
|
|
4302
4332
|
def sd_index(x: np.ndarray, y: np.ndarray) -> float:
|
|
@@ -5298,7 +5328,7 @@ class Statistics(FeatureExtractionMixin):
|
|
|
5298
5328
|
"""
|
|
5299
5329
|
Compute one-way ANOVAs comparing each column (axis 1) on two arrays.
|
|
5300
5330
|
|
|
5301
|
-
..
|
|
5331
|
+
.. note::
|
|
5302
5332
|
Use for computing and presenting aggregate statistics. Not suitable for featurization.
|
|
5303
5333
|
|
|
5304
5334
|
.. seealso::
|
|
@@ -5336,7 +5366,7 @@ class Statistics(FeatureExtractionMixin):
|
|
|
5336
5366
|
"""
|
|
5337
5367
|
Compute Kruskal-Wallis comparing each column (axis 1) on two arrays.
|
|
5338
5368
|
|
|
5339
|
-
..
|
|
5369
|
+
.. note::
|
|
5340
5370
|
Use for computing and presenting aggregate statistics. Not suitable for featurization.
|
|
5341
5371
|
|
|
5342
5372
|
.. seealso::
|
|
@@ -5373,7 +5403,7 @@ class Statistics(FeatureExtractionMixin):
|
|
|
5373
5403
|
"""
|
|
5374
5404
|
Compute pairwise grouped Tukey-HSD tests.
|
|
5375
5405
|
|
|
5376
|
-
..
|
|
5406
|
+
.. note::
|
|
5377
5407
|
Use for computing and presenting aggregate statistics. Not suitable for featurization.
|
|
5378
5408
|
|
|
5379
5409
|
:param np.ndarray data: 2D array with observations rowwise (axis 0) and features columnwise (axis 1)
|
|
@@ -2198,7 +2198,7 @@ class TimeseriesFeatureMixin(object):
|
|
|
2198
2198
|
:example:
|
|
2199
2199
|
>>> x = np.random.randint(0, 100, (400, 2))
|
|
2200
2200
|
>>> results_1 = TimeseriesFeatureMixin.sliding_entropy_of_directional_changes(x=x, bins=16, window_size=5.0, sample_rate=30)
|
|
2201
|
-
>>> x = pd.read_csv(r"C
|
|
2201
|
+
>>> x = pd.read_csv(r"C:/troubleshooting/two_black_animals_14bp/project_folder/csv/input_csv/Together_1.csv")[['Ear_left_1_x', 'Ear_left_1_y']].values
|
|
2202
2202
|
>>> results_2 = TimeseriesFeatureMixin.sliding_entropy_of_directional_changes(x=x, bins=16, window_size=5.0, sample_rate=30)
|
|
2203
2203
|
"""
|
|
2204
2204
|
|
|
@@ -77,10 +77,10 @@ from simba.utils.errors import (ClassifierInferenceError, CorruptedFileError,
|
|
|
77
77
|
SamplingError, SimBAModuleNotFoundError)
|
|
78
78
|
from simba.utils.lookups import get_meta_data_file_headers, get_table
|
|
79
79
|
from simba.utils.printing import SimbaTimer, stdout_success
|
|
80
|
-
from simba.utils.read_write import (find_core_cnt,
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
str_2_bool)
|
|
80
|
+
from simba.utils.read_write import (find_core_cnt, get_current_time,
|
|
81
|
+
get_fn_ext, get_memory_usage_of_df,
|
|
82
|
+
get_pkg_version, read_config_entry,
|
|
83
|
+
read_df, read_meta_file, str_2_bool)
|
|
84
84
|
from simba.utils.warnings import (GPUToolsWarning, MissingUserInputWarning,
|
|
85
85
|
MultiProcessingFailedWarning,
|
|
86
86
|
NoModuleWarning, NotEnoughDataWarning,
|
|
@@ -1383,18 +1383,39 @@ class TrainModelMixin(object):
|
|
|
1383
1383
|
x_df: Union[pd.DataFrame, np.ndarray],
|
|
1384
1384
|
multiclass: bool = False,
|
|
1385
1385
|
model_name: Optional[str] = None,
|
|
1386
|
-
data_path: Optional[Union[str, os.PathLike]] = None
|
|
1386
|
+
data_path: Optional[Union[str, os.PathLike]] = None,
|
|
1387
|
+
verbose: bool = False) -> np.ndarray:
|
|
1387
1388
|
|
|
1388
1389
|
"""
|
|
1389
|
-
|
|
1390
|
-
|
|
1391
|
-
|
|
1392
|
-
|
|
1393
|
-
|
|
1394
|
-
|
|
1395
|
-
|
|
1390
|
+
Helper to predict class probabilities using a fitted random forest classifier.
|
|
1391
|
+
|
|
1392
|
+
Computes prediction probabilities for binary or multiclass classification using either
|
|
1393
|
+
scikit-learn or cuML RandomForestClassifier. For binary classifiers, returns the
|
|
1394
|
+
probability of the positive class (class 1). For multiclass classifiers, returns
|
|
1395
|
+
probabilities for all classes.
|
|
1396
|
+
|
|
1397
|
+
.. csv-table::
|
|
1398
|
+
:header: EXPECTED RUNTIMES
|
|
1399
|
+
:file: ../../docs/tables/clf_predict_proba.csv
|
|
1400
|
+
:widths: 10, 45, 45
|
|
1401
|
+
:align: center
|
|
1402
|
+
:header-rows: 1
|
|
1403
|
+
|
|
1404
|
+
.. seealso::
|
|
1405
|
+
To fit a classifier, see :func:`simba.mixins.train_model_mixin.TrainModelMixin.clf_fit`
|
|
1406
|
+
To define a classifier, see :func:`simba.mixins.train_model_mixin.TrainModelMixin.clf_define`
|
|
1407
|
+
|
|
1408
|
+
:param Union[RandomForestClassifier, cuRF] clf: Fitted random forest classifier object from sklearn or cuml.
|
|
1409
|
+
:param Union[pd.DataFrame, np.ndarray] x_df: Features for data to predict. DataFrame or array of shape (n_samples, n_features).
|
|
1410
|
+
:param bool multiclass: If True, the classifier predicts more than 2 classes. If False, binary classifier (default: False).
|
|
1411
|
+
:param Optional[str] model_name: Name of the model for error messages and logging. Default: None.
|
|
1412
|
+
:param Optional[Union[str, os.PathLike]] data_path: Path to the data file being processed, used in error messages. Default: None.
|
|
1413
|
+
:param bool verbose: If True, print inference progress and timing information. Default: False.
|
|
1414
|
+
:return np.ndarray: Prediction probabilities. For binary classifiers: 1D array of shape (n_samples,) with probability of positive class. For multiclass: 2D array of shape (n_samples, n_classes) with probabilities for each class.
|
|
1415
|
+
|
|
1396
1416
|
"""
|
|
1397
1417
|
|
|
1418
|
+
timer = SimbaTimer(start=True)
|
|
1398
1419
|
if hasattr(clf, "n_features_"):
|
|
1399
1420
|
clf_n_features = clf.n_features_
|
|
1400
1421
|
elif hasattr(clf, "n_features_in_"):
|
|
@@ -1420,6 +1441,8 @@ class TrainModelMixin(object):
|
|
|
1420
1441
|
p_vals = clf.predict_proba(x_df)
|
|
1421
1442
|
if multiclass and (clf.n_classes_ != p_vals.shape[1]):
|
|
1422
1443
|
raise ClassifierInferenceError(msg=f"The classifier {model_name} (data path: {data_path}) is a multiclassifier expected to create {clf.n_classes_} behavior probabilities. However, it produced probabilities for {p_vals.shape[1]} behaviors. See The SimBA GitHub FAQ page or Gitter for more information and suggested fixes.", source=self.__class__.__name__)
|
|
1444
|
+
timer.stop_timer()
|
|
1445
|
+
if verbose: print(f'Inference for model {model_name} over {x_df.shape[0]} observations complete ({timer.elapsed_time_str}s).')
|
|
1423
1446
|
if not multiclass:
|
|
1424
1447
|
if isinstance(p_vals, pd.DataFrame):
|
|
1425
1448
|
return p_vals[1].values
|
|
@@ -1447,7 +1470,7 @@ class TrainModelMixin(object):
|
|
|
1447
1470
|
bootstrap: Optional[bool] = True,
|
|
1448
1471
|
verbose: Optional[int] = 1,
|
|
1449
1472
|
class_weight: Optional[dict] = None,
|
|
1450
|
-
cuda: Optional[bool] = False) -> RandomForestClassifier:
|
|
1473
|
+
cuda: Optional[bool] = False) -> Union[RandomForestClassifier, cuRF]:
|
|
1451
1474
|
|
|
1452
1475
|
if not cuda:
|
|
1453
1476
|
# NOTE: LOKY ISSUES ON WINDOWS WITH SCIKIT IF THE CORE COUNT EXCEEDS 61.
|
|
@@ -1482,20 +1505,32 @@ class TrainModelMixin(object):
|
|
|
1482
1505
|
clf: Union[RandomForestClassifier, cuRF],
|
|
1483
1506
|
x_df: pd.DataFrame,
|
|
1484
1507
|
y_df: pd.DataFrame,
|
|
1485
|
-
) -> RandomForestClassifier:
|
|
1508
|
+
verbose: bool = False) -> Union[RandomForestClassifier, cuRF]:
|
|
1486
1509
|
|
|
1487
1510
|
"""
|
|
1488
|
-
Helper to fit clf model
|
|
1511
|
+
Helper to fit clf model.
|
|
1489
1512
|
|
|
1490
|
-
|
|
1513
|
+
.. csv-table::
|
|
1514
|
+
:header: EXPECTED RUNTIMES
|
|
1515
|
+
:file: ../../docs/tables/clf_fit.csv
|
|
1516
|
+
:widths: 20, 20, 30, 30
|
|
1517
|
+
:align: center
|
|
1518
|
+
:header-rows: 1
|
|
1519
|
+
|
|
1520
|
+
.. seealso::
|
|
1521
|
+
To define a cuml/sklearn object, see :func:`simba.mixins.train_model_mixin.TrainModelMixin.clf_define`
|
|
1522
|
+
|
|
1523
|
+
:param clf: Un-fitted random forest classifier object, either from sklearn or cuml.
|
|
1491
1524
|
:param pd.DataFrame x_df: Pandas dataframe with features.
|
|
1492
1525
|
:param pd.DataFrame y_df: Pandas dataframe/Series with target
|
|
1493
1526
|
:return: Fitted random forest classifier object
|
|
1494
1527
|
:rtype: RandomForestClassifier
|
|
1495
1528
|
"""
|
|
1496
1529
|
|
|
1530
|
+
timer = SimbaTimer(start=True)
|
|
1497
1531
|
nan_features = x_df[~x_df.applymap(np.isreal).all(1)]
|
|
1498
1532
|
nan_target = y_df.loc[pd.to_numeric(y_df).isna()]
|
|
1533
|
+
using_cuda = True if CUML in str(clf.__class__.__module__).lower() else False
|
|
1499
1534
|
if len(nan_features) > 0:
|
|
1500
1535
|
raise FaultyTrainingSetError(
|
|
1501
1536
|
msg=f"{len(nan_features)} frame(s) in your project_folder/csv/targets_inserted directory contains FEATURES with non-numerical values",
|
|
@@ -1504,9 +1539,16 @@ class TrainModelMixin(object):
|
|
|
1504
1539
|
raise FaultyTrainingSetError(
|
|
1505
1540
|
msg=f"{len(nan_target)} frame(s) in your project_folder/csv/targets_inserted directory contains ANNOTATIONS with non-numerical values",
|
|
1506
1541
|
source=self.__class__.__name__)
|
|
1542
|
+
if verbose: print(f'[{get_current_time()}] Fitting classifier for {len(x_df)} observations (cuda: {"True" if using_cuda else "False"})...')
|
|
1543
|
+
if using_cuda:
|
|
1544
|
+
x_data = x_df.values if isinstance(x_df, pd.DataFrame) else x_df
|
|
1545
|
+
y_data = y_df.values if isinstance(y_df, (pd.DataFrame, pd.Series)) else y_df
|
|
1546
|
+
clf.fit(x_data, y_data)
|
|
1547
|
+
else:
|
|
1548
|
+
clf.fit(x_df, y_df)
|
|
1507
1549
|
|
|
1508
|
-
|
|
1509
|
-
|
|
1550
|
+
timer.stop_timer()
|
|
1551
|
+
if verbose: print(f'[{get_current_time()}] Classifier fitted in {timer.elapsed_time_str}s.')
|
|
1510
1552
|
return clf
|
|
1511
1553
|
|
|
1512
1554
|
@staticmethod
|
|
@@ -1563,9 +1605,7 @@ class TrainModelMixin(object):
|
|
|
1563
1605
|
:rtype: Tuple[pd.DataFrame, List[int]]
|
|
1564
1606
|
|
|
1565
1607
|
"""
|
|
1566
|
-
if (platform.system() == "Darwin") and (
|
|
1567
|
-
multiprocessing.get_start_method() != "spawn"
|
|
1568
|
-
):
|
|
1608
|
+
if (platform.system() == "Darwin") and (multiprocessing.get_start_method() != "spawn"):
|
|
1569
1609
|
multiprocessing.set_start_method("spawn", force=True)
|
|
1570
1610
|
cpu_cnt, _ = find_core_cnt()
|
|
1571
1611
|
df_lst, frame_numbers_lst = [], []
|
|
@@ -1592,9 +1632,7 @@ class TrainModelMixin(object):
|
|
|
1592
1632
|
:, ~df_concat.columns.str.contains("^Unnamed")
|
|
1593
1633
|
].astype(np.float32)
|
|
1594
1634
|
memory_size = get_memory_usage_of_df(df=df_concat)
|
|
1595
|
-
print(
|
|
1596
|
-
f'Dataset size: {memory_size["megabytes"]}MB / {memory_size["gigabytes"]}GB'
|
|
1597
|
-
)
|
|
1635
|
+
print(f'Dataset size: {memory_size["megabytes"]}MB / {memory_size["gigabytes"]}GB')
|
|
1598
1636
|
|
|
1599
1637
|
return df_concat, frame_numbers_lst
|
|
1600
1638
|
|
|
@@ -2607,9 +2645,9 @@ class TrainModelMixin(object):
|
|
|
2607
2645
|
:param bool plot: If True, create SHAP aggregation and plots.
|
|
2608
2646
|
|
|
2609
2647
|
:example:
|
|
2610
|
-
>>> CONFIG_PATH = r"C
|
|
2611
|
-
>>> RF_PATH = r"C
|
|
2612
|
-
>>> DATA_PATH = r"C
|
|
2648
|
+
>>> CONFIG_PATH = r"C:/troubleshooting/mitra/project_folder/project_config.ini"
|
|
2649
|
+
>>> RF_PATH = r"C:/troubleshooting/mitra/models/validations/straub_tail_5_new/straub_tail_5.sav"
|
|
2650
|
+
>>> DATA_PATH = r"C:/troubleshooting/mitra/project_folder/csv/targets_inserted/new_straub/appended/501_MA142_Gi_CNO_0514.csv"
|
|
2613
2651
|
>>> config = ConfigReader(config_path=CONFIG_PATH)
|
|
2614
2652
|
>>> df = read_df(file_path=DATA_PATH, file_type='csv')
|
|
2615
2653
|
>>> y = df['straub_tail']
|
simba/model/inference_batch.py
CHANGED
|
@@ -45,7 +45,7 @@ class InferenceBatch(TrainModelMixin, ConfigReader):
|
|
|
45
45
|
>>> inferencer.run()
|
|
46
46
|
|
|
47
47
|
:example II:
|
|
48
|
-
>>> inferencer = InferenceBatch(config_path=r"D
|
|
48
|
+
>>> inferencer = InferenceBatch(config_path=r"D:/troubleshooting/mitra/project_folder/project_config.ini", features_dir=r"D:/troubleshooting/mitra/project_folder/videos/bg_removed/rotated/tail_features/APPENDED")
|
|
49
49
|
>>> inferencer.run()
|
|
50
50
|
"""
|
|
51
51
|
|
|
@@ -55,9 +55,9 @@ class YOLOSegmentationInference():
|
|
|
55
55
|
To visualize the segmentation results, see :func:`simba.plotting.yolo_seg_visualizer.YOLOSegmentationVisualizer`
|
|
56
56
|
|
|
57
57
|
:example:
|
|
58
|
-
>>> weights_path = r"D
|
|
59
|
-
>>> video_path = r"D
|
|
60
|
-
>>> save_dir=r"D
|
|
58
|
+
>>> weights_path = r"D:/platea/yolo_071525/mdl/train3/weights/best.pt"
|
|
59
|
+
>>> video_path = r"D:/platea/platea_videos/videos/clipped/10B_Mouse_5-choice_MustTouchTrainingNEWFINAL_a7.mp4"
|
|
60
|
+
>>> save_dir = r"D:/platea/platea_videos/videos/yolo_results"
|
|
61
61
|
>>> runner = YOLOSegmentationInference(weights_path=weights_path, video_path=video_path, save_dir=save_dir, verbose=True, device=0, format=None, stream=True, batch_size=10, imgsz=320, interpolate=True, threshold=0.8, retina_msk=True)
|
|
62
62
|
>>> runner.run()
|
|
63
63
|
|
|
@@ -47,5 +47,5 @@ class OutlierCorrectionSkipper(ConfigReader):
|
|
|
47
47
|
self.timer.stop_timer()
|
|
48
48
|
stdout_success(msg=f"Skipped outlier correction for {len(self.input_csv_paths)} files", elapsed_time=self.timer.elapsed_time_str)
|
|
49
49
|
|
|
50
|
-
# test = OutlierCorrectionSkipper(config_path=
|
|
50
|
+
# test = OutlierCorrectionSkipper(config_path=r"E:\troubleshooting\mitra_emergence\project_folder\project_config.ini")
|
|
51
51
|
# test.run()
|