simba-uw-tf-dev 4.6.7__py3-none-any.whl → 4.6.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- simba/assets/.recent_projects.txt +1 -0
- simba/data_processors/circling_detector.py +30 -13
- simba/data_processors/cuda/image.py +42 -18
- simba/data_processors/cuda/statistics.py +2 -3
- simba/data_processors/freezing_detector.py +54 -50
- simba/feature_extractors/mitra_feature_extractor.py +2 -2
- simba/mixins/config_reader.py +5 -2
- simba/mixins/plotting_mixin.py +28 -10
- simba/outlier_tools/skip_outlier_correction.py +1 -1
- simba/plotting/gantt_creator.py +29 -10
- simba/plotting/gantt_creator_mp.py +50 -17
- simba/sandbox/analyze_runtimes.py +30 -30
- simba/sandbox/test_directionality.py +47 -47
- simba/sandbox/test_nonstatic_directionality.py +27 -27
- simba/sandbox/test_pycharm_cuda.py +51 -51
- simba/sandbox/test_simba_install.py +41 -41
- simba/sandbox/test_static_directionality.py +26 -26
- simba/sandbox/test_static_directionality_2d.py +26 -26
- simba/sandbox/verify_env.py +42 -42
- simba/ui/pop_ups/clf_plot_pop_up.py +2 -2
- simba/ui/pop_ups/gantt_pop_up.py +31 -6
- simba/ui/pop_ups/video_processing_pop_up.py +1 -1
- simba/video_processors/video_processing.py +1 -1
- {simba_uw_tf_dev-4.6.7.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/METADATA +1 -1
- {simba_uw_tf_dev-4.6.7.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/RECORD +29 -29
- {simba_uw_tf_dev-4.6.7.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/LICENSE +0 -0
- {simba_uw_tf_dev-4.6.7.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/WHEEL +0 -0
- {simba_uw_tf_dev-4.6.7.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/entry_points.txt +0 -0
- {simba_uw_tf_dev-4.6.7.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/top_level.txt +0 -0
|
@@ -11,12 +11,13 @@ from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin
|
|
|
11
11
|
from simba.mixins.timeseries_features_mixin import TimeseriesFeatureMixin
|
|
12
12
|
from simba.utils.checks import (
|
|
13
13
|
check_all_file_names_are_represented_in_video_log, check_if_dir_exists,
|
|
14
|
-
|
|
14
|
+
check_str, check_valid_dataframe)
|
|
15
15
|
from simba.utils.data import detect_bouts, plug_holes_shortest_bout
|
|
16
16
|
from simba.utils.enums import Formats
|
|
17
17
|
from simba.utils.printing import stdout_success
|
|
18
18
|
from simba.utils.read_write import (find_files_of_filetypes_in_directory,
|
|
19
|
-
get_fn_ext, read_df,
|
|
19
|
+
get_current_time, get_fn_ext, read_df,
|
|
20
|
+
read_video_info)
|
|
20
21
|
|
|
21
22
|
CIRCLING = 'CIRCLING'
|
|
22
23
|
|
|
@@ -58,30 +59,34 @@ class CirclingDetector(ConfigReader):
|
|
|
58
59
|
"""
|
|
59
60
|
|
|
60
61
|
def __init__(self,
|
|
61
|
-
data_dir: Union[str, os.PathLike],
|
|
62
62
|
config_path: Union[str, os.PathLike],
|
|
63
63
|
nose_name: Optional[str] = 'nose',
|
|
64
|
+
data_dir: Optional[Union[str, os.PathLike]] = None,
|
|
64
65
|
left_ear_name: Optional[str] = 'left_ear',
|
|
65
66
|
right_ear_name: Optional[str] = 'right_ear',
|
|
66
67
|
tail_base_name: Optional[str] = 'tail_base',
|
|
67
68
|
center_name: Optional[str] = 'center',
|
|
68
|
-
time_threshold: Optional[int] =
|
|
69
|
-
circular_range_threshold: Optional[int] =
|
|
69
|
+
time_threshold: Optional[int] = 7,
|
|
70
|
+
circular_range_threshold: Optional[int] = 350,
|
|
71
|
+
shortest_bout: int = 100,
|
|
70
72
|
movement_threshold: Optional[int] = 60,
|
|
71
73
|
save_dir: Optional[Union[str, os.PathLike]] = None):
|
|
72
74
|
|
|
73
|
-
check_if_dir_exists(in_dir=data_dir)
|
|
74
75
|
for bp_name in [nose_name, left_ear_name, right_ear_name, tail_base_name]: check_str(name='body part name', value=bp_name, allow_blank=False)
|
|
75
|
-
self.data_paths = find_files_of_filetypes_in_directory(directory=data_dir, extensions=['.csv'])
|
|
76
76
|
ConfigReader.__init__(self, config_path=config_path, read_video_info=True, create_logger=False)
|
|
77
|
+
if data_dir is not None:
|
|
78
|
+
check_if_dir_exists(in_dir=data_dir)
|
|
79
|
+
else:
|
|
80
|
+
data_dir = self.outlier_corrected_dir
|
|
81
|
+
self.data_paths = find_files_of_filetypes_in_directory(directory=data_dir, extensions=['.csv'])
|
|
77
82
|
self.nose_heads = [f'{nose_name}_x'.lower(), f'{nose_name}_y'.lower()]
|
|
78
83
|
self.left_ear_heads = [f'{left_ear_name}_x'.lower(), f'{left_ear_name}_y'.lower()]
|
|
79
84
|
self.right_ear_heads = [f'{right_ear_name}_x'.lower(), f'{right_ear_name}_y'.lower()]
|
|
80
85
|
self.center_heads = [f'{center_name}_x'.lower(), f'{center_name}_y'.lower()]
|
|
81
86
|
self.required_field = self.nose_heads + self.left_ear_heads + self.right_ear_heads
|
|
82
|
-
self.save_dir = save_dir
|
|
87
|
+
self.save_dir, self.shortest_bout = save_dir, shortest_bout
|
|
83
88
|
if self.save_dir is None:
|
|
84
|
-
self.save_dir = os.path.join(self.logs_path, f'circling_data_{self.datetime}')
|
|
89
|
+
self.save_dir = os.path.join(self.logs_path, f'circling_data_{time_threshold}s_{circular_range_threshold}d_{movement_threshold}mm_{self.datetime}')
|
|
85
90
|
os.makedirs(self.save_dir)
|
|
86
91
|
else:
|
|
87
92
|
check_if_dir_exists(in_dir=self.save_dir)
|
|
@@ -93,7 +98,7 @@ class CirclingDetector(ConfigReader):
|
|
|
93
98
|
check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.data_paths)
|
|
94
99
|
for file_cnt, file_path in enumerate(self.data_paths):
|
|
95
100
|
video_name = get_fn_ext(filepath=file_path)[1]
|
|
96
|
-
print(f'Analyzing {video_name} ({file_cnt+1}/{len(self.data_paths)})
|
|
101
|
+
print(f'[{get_current_time()}] Analyzing circling {video_name}... (video {file_cnt+1}/{len(self.data_paths)})')
|
|
97
102
|
save_file_path = os.path.join(self.save_dir, f'{video_name}.csv')
|
|
98
103
|
df = read_df(file_path=file_path, file_type='csv').reset_index(drop=True)
|
|
99
104
|
_, px_per_mm, fps = read_video_info(video_info_df=self.video_info_df, video_name=video_name)
|
|
@@ -115,11 +120,24 @@ class CirclingDetector(ConfigReader):
|
|
|
115
120
|
circling_idx = np.argwhere(sliding_circular_range >= self.circular_range_threshold).astype(np.int32).flatten()
|
|
116
121
|
movement_idx = np.argwhere(movement_sum >= self.movement_threshold).astype(np.int32).flatten()
|
|
117
122
|
circling_idx = [x for x in movement_idx if x in circling_idx]
|
|
123
|
+
df[f'Probability_{CIRCLING}'] = 0
|
|
118
124
|
df[CIRCLING] = 0
|
|
119
125
|
df.loc[circling_idx, CIRCLING] = 1
|
|
126
|
+
df.loc[circling_idx, f'Probability_{CIRCLING}'] = 1
|
|
127
|
+
df = plug_holes_shortest_bout(data_df=df, clf_name=CIRCLING, fps=fps, shortest_bout=self.shortest_bout)
|
|
120
128
|
bouts = detect_bouts(data_df=df, target_lst=[CIRCLING], fps=fps)
|
|
121
|
-
|
|
129
|
+
if len(bouts) > 0:
|
|
130
|
+
df[CIRCLING] = 0
|
|
131
|
+
circling_idx = list(bouts.apply(lambda x: list(range(int(x["Start_frame"]), int(x["End_frame"]) + 1)), 1))
|
|
132
|
+
circling_idx = [x for xs in circling_idx for x in xs]
|
|
133
|
+
df.loc[circling_idx, CIRCLING] = 1
|
|
134
|
+
df.loc[circling_idx, f'Probability_{CIRCLING}'] = 1
|
|
135
|
+
else:
|
|
136
|
+
df[CIRCLING] = 0
|
|
137
|
+
circling_idx = []
|
|
138
|
+
|
|
122
139
|
df.to_csv(save_file_path)
|
|
140
|
+
#print(video_name, len(circling_idx), round(len(circling_idx) / fps, 4), df[CIRCLING].sum())
|
|
123
141
|
agg_results.loc[len(agg_results)] = [video_name, len(circling_idx), round(len(circling_idx) / fps, 4), len(bouts), round((len(circling_idx) / len(df)) * 100, 4), len(df), round(len(df)/fps, 2) ]
|
|
124
142
|
|
|
125
143
|
agg_results.to_csv(agg_results_path)
|
|
@@ -127,7 +145,6 @@ class CirclingDetector(ConfigReader):
|
|
|
127
145
|
|
|
128
146
|
#
|
|
129
147
|
#
|
|
130
|
-
# detector = CirclingDetector(
|
|
131
|
-
# config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
|
|
148
|
+
# detector = CirclingDetector(config_path=r"E:\troubleshooting\mitra_emergence\project_folder\project_config.ini")
|
|
132
149
|
# detector.run()
|
|
133
150
|
|
|
@@ -332,12 +332,22 @@ def _digital(data, results):
|
|
|
332
332
|
def img_stack_brightness(x: np.ndarray,
|
|
333
333
|
method: Optional[Literal['photometric', 'digital']] = 'digital',
|
|
334
334
|
ignore_black: bool = True,
|
|
335
|
-
verbose: bool = False
|
|
335
|
+
verbose: bool = False,
|
|
336
|
+
batch_size: int = 2500) -> np.ndarray:
|
|
336
337
|
"""
|
|
337
338
|
Calculate the average brightness of a stack of images using a specified method.
|
|
338
339
|
|
|
339
340
|
Useful for analyzing light cues or brightness changes over time. For example, compute brightness in images containing a light cue ROI, then perform clustering (e.g., k-means) on brightness values to identify frames when the light cue is on vs off.
|
|
340
341
|
|
|
342
|
+
.. csv-table::
|
|
343
|
+
:header: EXPECTED RUNTIMES
|
|
344
|
+
:file: ../../../docs/tables/img_stack_brightness_gpu.csv
|
|
345
|
+
:widths: 10, 45, 45
|
|
346
|
+
:align: center
|
|
347
|
+
:class: simba-table
|
|
348
|
+
:header-rows: 1
|
|
349
|
+
|
|
350
|
+
|
|
341
351
|
- **Photometric Method**: The brightness is calculated using the formula:
|
|
342
352
|
|
|
343
353
|
.. math::
|
|
@@ -365,27 +375,41 @@ def img_stack_brightness(x: np.ndarray,
|
|
|
365
375
|
|
|
366
376
|
check_instance(source=img_stack_brightness.__name__, instance=x, accepted_types=(np.ndarray,))
|
|
367
377
|
check_if_valid_img(data=x[0], source=img_stack_brightness.__name__)
|
|
378
|
+
check_int(name=f'{img_stack_brightness.__name__} batch_size', value=batch_size, allow_zero=False, allow_negative=False, raise_error=True)
|
|
368
379
|
x, timer = np.ascontiguousarray(x).astype(np.uint8), SimbaTimer(start=True)
|
|
380
|
+
results = []
|
|
369
381
|
if x.ndim == 4:
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
382
|
+
batch_results_dev = cuda.device_array((batch_size, x.shape[1], x.shape[2]), dtype=np.uint8)
|
|
383
|
+
for batch_cnt, l in enumerate(range(0, x.shape[0], batch_size)):
|
|
384
|
+
r = l + batch_size
|
|
385
|
+
batch_x = x[l:r]
|
|
386
|
+
if batch_x.ndim == 4:
|
|
387
|
+
grid_x = (batch_x.shape[1] + 16 - 1) // 16
|
|
388
|
+
grid_y = (batch_x.shape[2] + 16 - 1) // 16
|
|
389
|
+
grid_z = batch_x.shape[0]
|
|
390
|
+
threads_per_block = (16, 16, 1)
|
|
391
|
+
blocks_per_grid = (grid_y, grid_x, grid_z)
|
|
392
|
+
x_dev = cuda.to_device(batch_x)
|
|
393
|
+
if method == PHOTOMETRIC:
|
|
394
|
+
_photometric[blocks_per_grid, threads_per_block](x_dev, batch_results_dev)
|
|
395
|
+
else:
|
|
396
|
+
_digital[blocks_per_grid, threads_per_block](x_dev, batch_results_dev)
|
|
397
|
+
batch_results_host = batch_results_dev.copy_to_host()[:batch_x.shape[0]]
|
|
398
|
+
batch_results_cp = cp.asarray(batch_results_host)
|
|
399
|
+
if ignore_black:
|
|
400
|
+
mask = batch_results_cp != 0
|
|
401
|
+
batch_results_cp = cp.where(mask, batch_results_cp, cp.nan)
|
|
402
|
+
batch_results = cp.nanmean(batch_results_cp, axis=(1, 2))
|
|
403
|
+
batch_results = cp.where(cp.isnan(batch_results), 0, batch_results)
|
|
404
|
+
batch_results = batch_results.get()
|
|
405
|
+
else:
|
|
406
|
+
batch_results = cp.mean(batch_results_cp, axis=(1, 2)).get()
|
|
379
407
|
else:
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
masked_array = np.ma.masked_equal(results, 0)
|
|
384
|
-
results = np.mean(masked_array, axis=(1, 2)).filled(0)
|
|
385
|
-
else:
|
|
386
|
-
results = deepcopy(x)
|
|
387
|
-
results = np.mean(results, axis=(1, 2))
|
|
408
|
+
batch_results = deepcopy(x)
|
|
409
|
+
batch_results = np.mean(batch_results, axis=(1, 2))
|
|
410
|
+
results.append(batch_results)
|
|
388
411
|
timer.stop_timer()
|
|
412
|
+
results = np.concatenate(results) if len(results) > 0 else np.array([])
|
|
389
413
|
if verbose: print(f'Brightness computed in {results.shape[0]} images (elapsed time {timer.elapsed_time_str}s)')
|
|
390
414
|
return results
|
|
391
415
|
|
|
@@ -19,7 +19,6 @@ from scipy.spatial import ConvexHull
|
|
|
19
19
|
from simba.utils.read_write import get_unique_values_in_iterable, read_df
|
|
20
20
|
from simba.utils.warnings import GPUToolsWarning
|
|
21
21
|
|
|
22
|
-
|
|
23
22
|
try:
|
|
24
23
|
import cupy as cp
|
|
25
24
|
from cupyx.scipy.spatial.distance import cdist
|
|
@@ -44,8 +43,8 @@ except:
|
|
|
44
43
|
|
|
45
44
|
from simba.data_processors.cuda.utils import _cuda_are_rows_equal
|
|
46
45
|
from simba.mixins.statistics_mixin import Statistics
|
|
47
|
-
from simba.utils.checks import (check_int, check_str,
|
|
48
|
-
|
|
46
|
+
from simba.utils.checks import (check_float, check_int, check_str,
|
|
47
|
+
check_valid_array, check_valid_tuple)
|
|
49
48
|
from simba.utils.data import bucket_data
|
|
50
49
|
from simba.utils.enums import Formats
|
|
51
50
|
|
|
@@ -1,10 +1,8 @@
|
|
|
1
1
|
import os
|
|
2
2
|
from typing import Optional, Union
|
|
3
|
-
|
|
4
3
|
import numpy as np
|
|
5
4
|
import pandas as pd
|
|
6
5
|
from numba import typed
|
|
7
|
-
|
|
8
6
|
from simba.mixins.config_reader import ConfigReader
|
|
9
7
|
from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin
|
|
10
8
|
from simba.mixins.timeseries_features_mixin import TimeseriesFeatureMixin
|
|
@@ -14,65 +12,73 @@ from simba.utils.checks import (
|
|
|
14
12
|
from simba.utils.data import detect_bouts, plug_holes_shortest_bout
|
|
15
13
|
from simba.utils.enums import Formats
|
|
16
14
|
from simba.utils.printing import stdout_success
|
|
17
|
-
from simba.utils.read_write import (find_files_of_filetypes_in_directory,
|
|
18
|
-
get_fn_ext, read_df, read_video_info)
|
|
15
|
+
from simba.utils.read_write import (find_files_of_filetypes_in_directory, get_fn_ext, read_df, read_video_info, get_current_time)
|
|
19
16
|
|
|
20
17
|
NAPE_X, NAPE_Y = 'nape_x', 'nape_y'
|
|
21
18
|
FREEZING = 'FREEZING'
|
|
22
19
|
|
|
23
20
|
class FreezingDetector(ConfigReader):
|
|
24
|
-
|
|
25
21
|
"""
|
|
26
|
-
Detect freezing behavior using heuristic rules.
|
|
27
|
-
|
|
22
|
+
Detect freezing behavior using heuristic rules based on movement velocity thresholds.
|
|
23
|
+
Analyzes pose-estimation data to detect freezing episodes by computing the mean velocity
|
|
24
|
+
of key body parts (nape, nose, and tail-base) and identifying periods where movement falls below
|
|
25
|
+
a specified threshold for a minimum duration.
|
|
28
26
|
.. important::
|
|
29
|
-
|
|
30
27
|
Freezing is detected as `present` when **the velocity (computed from the mean movement of the nape, nose, and tail-base body-parts) falls below
|
|
31
|
-
the movement threshold for the duration (and longer) of the
|
|
32
|
-
|
|
28
|
+
the movement threshold for the duration (and longer) of the specified time-window**.
|
|
33
29
|
Freezing is detected as `absent` when not present.
|
|
34
|
-
|
|
35
30
|
.. note::
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
:param Union[str, os.PathLike] data_dir: Path to directory containing pose-estimated body-part data in CSV format.
|
|
40
|
-
:param Union[str, os.PathLike] config_path: Path to SimBA project config file.
|
|
41
|
-
:param Optional[str] nose_name: The name of the pose-estimated nose body-part. Defaults to 'nose'.
|
|
42
|
-
:param Optional[str] left_ear_name: The name of the pose-estimated left ear body-part. Defaults to '
|
|
43
|
-
:param Optional[str] right_ear_name: The name of the pose-estimated right ear body-part. Defaults to 'right_ear'.
|
|
44
|
-
:param Optional[str] tail_base_name: The name of the pose-estimated tail base body-part. Defaults to 'tail_base'.
|
|
45
|
-
:param Optional[int] time_window: The time window in
|
|
46
|
-
:param Optional[int] movement_threshold:
|
|
47
|
-
:param Optional[
|
|
48
|
-
|
|
31
|
+
The method uses the left and right ear body-parts to compute the `nape` location of the animal
|
|
32
|
+
as the midpoint between the ears. The nape, nose, and tail-base movements are averaged to compute
|
|
33
|
+
overall animal movement velocity.
|
|
34
|
+
:param Union[str, os.PathLike] data_dir: Path to directory containing pose-estimated body-part data in CSV format. Each CSV file should contain pose estimation data for one video.
|
|
35
|
+
:param Union[str, os.PathLike] config_path: Path to SimBA project config file (`.ini` format) containing project settings and video information.
|
|
36
|
+
:param Optional[str] nose_name: The name of the pose-estimated nose body-part column (without _x/_y suffix). Defaults to 'nose'.
|
|
37
|
+
:param Optional[str] left_ear_name: The name of the pose-estimated left ear body-part column (without _x/_y suffix). Defaults to 'Left_ear'.
|
|
38
|
+
:param Optional[str] right_ear_name: The name of the pose-estimated right ear body-part column (without _x/_y suffix). Defaults to 'right_ear'.
|
|
39
|
+
:param Optional[str] tail_base_name: The name of the pose-estimated tail base body-part column (without _x/_y suffix). Defaults to 'tail_base'.
|
|
40
|
+
:param Optional[int] time_window: The minimum time window in seconds that movement must be below the threshold to be considered freezing. Only freezing bouts lasting at least this duration are retained. Defaults to 3.
|
|
41
|
+
:param Optional[int] movement_threshold: Movement threshold in millimeters per second. Frames with mean velocity below this threshold are considered potential freezing. Defaults to 5.
|
|
42
|
+
:param Optional[int] shortest_bout: Minimum duration in milliseconds for a freezing bout to be considered valid. Shorter bouts are filtered out. Defaults to 100.
|
|
43
|
+
:param Optional[Union[str, os.PathLike]] save_dir: Directory where to store the results. If None, then results are stored in a timestamped subdirectory within the ``logs`` directory of the SimBA project.
|
|
44
|
+
:returns: None. Results are saved to CSV files in the specified save directory:
|
|
45
|
+
- Individual video results: One CSV file per video with freezing annotations added as a 'FREEZING' column (1 = freezing, 0 = not freezing)
|
|
46
|
+
- Aggregate results: `aggregate_freezing_results.csv` containing summary statistics for all videos
|
|
49
47
|
:example:
|
|
50
|
-
>>> FreezingDetector(
|
|
51
|
-
|
|
48
|
+
>>> FreezingDetector(
|
|
49
|
+
... data_dir=r'D:\\troubleshooting\\mitra\\project_folder\\csv\\outlier_corrected_movement_location',
|
|
50
|
+
... config_path=r"D:\\troubleshooting\\mitra\\project_folder\\project_config.ini",
|
|
51
|
+
... time_window=3,
|
|
52
|
+
... movement_threshold=5,
|
|
53
|
+
... shortest_bout=100
|
|
54
|
+
... ).run()
|
|
52
55
|
References
|
|
53
56
|
----------
|
|
57
|
+
..
|
|
54
58
|
.. [1] Sabnis et al., Visual detection of seizures in mice using supervised machine learning, `biorxiv`, doi: https://doi.org/10.1101/2024.05.29.596520.
|
|
55
59
|
.. [2] Lopez et al., Region-specific Nucleus Accumbens Dopamine Signals Encode Distinct Aspects of Avoidance Learning, `biorxiv`, doi: https://doi.org/10.1101/2024.08.28.610149
|
|
56
|
-
.. [3] Lopez, Gabriela C., Louis D. Van Camp, Ryan F. Kovaleski, et al.
|
|
60
|
+
.. [3] Lopez, Gabriela C., Louis D. Van Camp, Ryan F. Kovaleski, et al. "Region-Specific Nucleus Accumbens Dopamine Signals Encode Distinct Aspects of Avoidance Learning." `Cell Biology`, Volume 35, Issue 10p2433-2443.e5May 19, 2025. DOI: 10.1016/j.cub.2025.04.006
|
|
57
61
|
.. [4] Lazaro et al., Brainwide Genetic Capture for Conscious State Transitions, `biorxiv`, doi: https://doi.org/10.1101/2025.03.28.646066
|
|
62
|
+
.. [5] Sabnis et al., Visual detection of seizures in mice using supervised machine learning, 2025, Cell Reports Methods 5, 101242 December 15, 2025.
|
|
58
63
|
"""
|
|
59
|
-
|
|
60
64
|
def __init__(self,
|
|
61
|
-
data_dir: Union[str, os.PathLike],
|
|
62
65
|
config_path: Union[str, os.PathLike],
|
|
63
|
-
nose_name:
|
|
64
|
-
left_ear_name:
|
|
65
|
-
right_ear_name:
|
|
66
|
-
tail_base_name:
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
66
|
+
nose_name: str = 'nose',
|
|
67
|
+
left_ear_name: str = 'Left_ear',
|
|
68
|
+
right_ear_name: str = 'right_ear',
|
|
69
|
+
tail_base_name: str = 'tail_base',
|
|
70
|
+
data_dir: Optional[Union[str, os.PathLike]] = None,
|
|
71
|
+
time_window: int = 4,
|
|
72
|
+
movement_threshold: int = 5,
|
|
73
|
+
shortest_bout: int = 100,
|
|
70
74
|
save_dir: Optional[Union[str, os.PathLike]] = None):
|
|
71
|
-
|
|
72
|
-
check_if_dir_exists(in_dir=data_dir)
|
|
73
75
|
for bp_name in [nose_name, left_ear_name, right_ear_name, tail_base_name]: check_str(name='body part name', value=bp_name, allow_blank=False)
|
|
74
|
-
self.data_paths = find_files_of_filetypes_in_directory(directory=data_dir, extensions=['.csv'])
|
|
75
76
|
ConfigReader.__init__(self, config_path=config_path, read_video_info=True, create_logger=False)
|
|
77
|
+
if data_dir is not None:
|
|
78
|
+
check_if_dir_exists(in_dir=data_dir)
|
|
79
|
+
else:
|
|
80
|
+
data_dir = self.outlier_corrected_dir
|
|
81
|
+
self.data_paths = find_files_of_filetypes_in_directory(directory=data_dir, extensions=['.csv'])
|
|
76
82
|
self.nose_heads = [f'{nose_name}_x'.lower(), f'{nose_name}_y'.lower()]
|
|
77
83
|
self.left_ear_heads = [f'{left_ear_name}_x'.lower(), f'{left_ear_name}_y'.lower()]
|
|
78
84
|
self.right_ear_heads = [f'{right_ear_name}_x'.lower(), f'{right_ear_name}_y'.lower()]
|
|
@@ -82,21 +88,19 @@ class FreezingDetector(ConfigReader):
|
|
|
82
88
|
check_int(name='movement_threshold', value=movement_threshold, min_value=1)
|
|
83
89
|
self.save_dir = save_dir
|
|
84
90
|
if self.save_dir is None:
|
|
85
|
-
self.save_dir = os.path.join(self.logs_path, f'freezing_data_time_{time_window}s_{self.datetime}')
|
|
91
|
+
self.save_dir = os.path.join(self.logs_path, f'freezing_data_time_{time_window}s_{movement_threshold}mm_{self.datetime}')
|
|
86
92
|
os.makedirs(self.save_dir)
|
|
87
93
|
else:
|
|
88
94
|
check_if_dir_exists(in_dir=self.save_dir)
|
|
89
95
|
self.time_window, self.movement_threshold = time_window, movement_threshold
|
|
90
96
|
self.movement_threshold, self.shortest_bout = movement_threshold, shortest_bout
|
|
91
|
-
self.run()
|
|
92
|
-
|
|
93
97
|
def run(self):
|
|
94
98
|
agg_results = pd.DataFrame(columns=['VIDEO', 'FREEZING FRAMES', 'FREEZING TIME (S)', 'FREEZING BOUT COUNTS', 'FREEZING PCT OF SESSION', 'VIDEO TOTAL FRAMES', 'VIDEO TOTAL TIME (S)'])
|
|
95
|
-
agg_results_path = os.path.join(self.save_dir, '
|
|
99
|
+
agg_results_path = os.path.join(self.save_dir, f'aggregate_freezing_results_{self.datetime}.csv')
|
|
96
100
|
check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.data_paths)
|
|
97
101
|
for file_cnt, file_path in enumerate(self.data_paths):
|
|
98
102
|
video_name = get_fn_ext(filepath=file_path)[1]
|
|
99
|
-
print(f'Analyzing {video_name}...({file_cnt+1}/{len(self.data_paths)})')
|
|
103
|
+
print(f'[{get_current_time()}] Analyzing freezing {video_name}...(video {file_cnt+1}/{len(self.data_paths)})')
|
|
100
104
|
save_file_path = os.path.join(self.save_dir, f'{video_name}.csv')
|
|
101
105
|
df = read_df(file_path=file_path, file_type='csv').reset_index(drop=True)
|
|
102
106
|
_, px_per_mm, fps = read_video_info(vid_info_df=self.video_info_df, video_name=video_name)
|
|
@@ -118,23 +122,23 @@ class FreezingDetector(ConfigReader):
|
|
|
118
122
|
mean_movement = np.mean(movement, axis=1)
|
|
119
123
|
mm_s = TimeseriesFeatureMixin.sliding_descriptive_statistics(data=mean_movement.astype(np.float32), window_sizes=np.array([1], dtype=np.float64), sample_rate=int(fps), statistics=typed.List(["sum"]))[0].flatten()
|
|
120
124
|
freezing_idx = np.argwhere(mm_s <= self.movement_threshold).astype(np.int32).flatten()
|
|
125
|
+
df[f'Probability_{FREEZING}'] = 0
|
|
121
126
|
df[FREEZING] = 0
|
|
122
127
|
df.loc[freezing_idx, FREEZING] = 1
|
|
123
128
|
df = plug_holes_shortest_bout(data_df=df, clf_name=FREEZING, fps=fps, shortest_bout=self.shortest_bout)
|
|
124
129
|
bouts = detect_bouts(data_df=df, target_lst=[FREEZING], fps=fps)
|
|
125
130
|
bouts = bouts[bouts['Bout_time'] >= self.time_window]
|
|
126
131
|
if len(bouts) > 0:
|
|
132
|
+
df[FREEZING] = 0
|
|
127
133
|
freezing_idx = list(bouts.apply(lambda x: list(range(int(x["Start_frame"]), int(x["End_frame"]) + 1)), 1))
|
|
128
134
|
freezing_idx = [x for xs in freezing_idx for x in xs]
|
|
129
135
|
df.loc[freezing_idx, FREEZING] = 1
|
|
136
|
+
df.loc[freezing_idx, f'Probability_{FREEZING}'] = 1
|
|
130
137
|
else:
|
|
138
|
+
df[FREEZING] = 0
|
|
131
139
|
freezing_idx = []
|
|
132
140
|
df.to_csv(save_file_path)
|
|
141
|
+
print(video_name, len(freezing_idx), round(len(freezing_idx) / fps, 4), df[FREEZING].sum())
|
|
133
142
|
agg_results.loc[len(agg_results)] = [video_name, len(freezing_idx), round(len(freezing_idx) / fps, 4), len(bouts), round((len(freezing_idx) / len(df)) * 100, 4), len(df), round(len(df)/fps, 2) ]
|
|
134
|
-
|
|
135
143
|
agg_results.to_csv(agg_results_path)
|
|
136
|
-
stdout_success(msg=f'Results saved in {self.save_dir} directory.')
|
|
137
|
-
|
|
138
|
-
#
|
|
139
|
-
# FreezingDetector(data_dir=r'C:\troubleshooting\mitra\project_folder\csv\outlier_corrected_movement_location',
|
|
140
|
-
# config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
|
|
144
|
+
self.timer.stop_timer(); stdout_success(msg=f'Results saved in {self.save_dir} directory.', elapsed_time=self.timer.elapsed_time_str)
|
|
@@ -28,7 +28,7 @@ RIGHT_EAR = 'right_ear'
|
|
|
28
28
|
CENTER = 'center'
|
|
29
29
|
TAIL_BASE = 'tail_base'
|
|
30
30
|
TAIL_CENTER = 'tail_center'
|
|
31
|
-
TAIL_TIP = '
|
|
31
|
+
TAIL_TIP = 'tail_end'
|
|
32
32
|
|
|
33
33
|
TIME_WINDOWS = np.array([0.25, 0.5, 1.0, 2.0])
|
|
34
34
|
|
|
@@ -207,7 +207,7 @@ class MitraFeatureExtractor(ConfigReader,
|
|
|
207
207
|
|
|
208
208
|
|
|
209
209
|
|
|
210
|
-
# feature_extractor = MitraFeatureExtractor(config_path=r"
|
|
210
|
+
# feature_extractor = MitraFeatureExtractor(config_path=r"E:\troubleshooting\mitra_emergence\project_folder\project_config.ini")
|
|
211
211
|
# feature_extractor.run()
|
|
212
212
|
|
|
213
213
|
|
simba/mixins/config_reader.py
CHANGED
|
@@ -41,8 +41,8 @@ from simba.utils.read_write import (find_core_cnt, get_all_clf_names,
|
|
|
41
41
|
get_fn_ext, read_config_file, read_df,
|
|
42
42
|
read_project_path_and_file_type, write_df)
|
|
43
43
|
from simba.utils.warnings import (BodypartColumnNotFoundWarning,
|
|
44
|
-
|
|
45
|
-
NoFileFoundWarning)
|
|
44
|
+
DuplicateNamesWarning, InvalidValueWarning,
|
|
45
|
+
NoDataFoundWarning, NoFileFoundWarning)
|
|
46
46
|
|
|
47
47
|
|
|
48
48
|
class ConfigReader(object):
|
|
@@ -610,11 +610,14 @@ class ConfigReader(object):
|
|
|
610
610
|
>>> config_reader.get_bp_headers()
|
|
611
611
|
"""
|
|
612
612
|
|
|
613
|
+
duplicates = list({x for x in self.body_parts_lst if self.body_parts_lst.count(x) > 1})
|
|
614
|
+
if len(duplicates) > 0: DuplicateNamesWarning(msg=f'The pose configuration file at {self.body_parts_path} contains duplicate entries: {duplicates}', source=self.__class__.__name__)
|
|
613
615
|
self.bp_headers = []
|
|
614
616
|
for bp in self.body_parts_lst:
|
|
615
617
|
c1, c2, c3 = (f"{bp}_x", f"{bp}_y", f"{bp}_p")
|
|
616
618
|
self.bp_headers.extend((c1, c2, c3))
|
|
617
619
|
|
|
620
|
+
|
|
618
621
|
def read_config_entry(
|
|
619
622
|
self,
|
|
620
623
|
config: ConfigParser,
|
simba/mixins/plotting_mixin.py
CHANGED
|
@@ -39,7 +39,7 @@ from simba.utils.data import create_color_palette, detect_bouts
|
|
|
39
39
|
from simba.utils.enums import Formats, Keys, Options, Paths
|
|
40
40
|
from simba.utils.errors import InvalidInputError
|
|
41
41
|
from simba.utils.lookups import (get_categorical_palettes, get_color_dict,
|
|
42
|
-
get_named_colors)
|
|
42
|
+
get_fonts, get_named_colors)
|
|
43
43
|
from simba.utils.printing import SimbaTimer, stdout_success
|
|
44
44
|
from simba.utils.read_write import (find_files_of_filetypes_in_directory,
|
|
45
45
|
get_fn_ext, get_video_meta_data, read_df,
|
|
@@ -342,16 +342,28 @@ class PlottingMixin(object):
|
|
|
342
342
|
height: int = 480,
|
|
343
343
|
font_size: int = 8,
|
|
344
344
|
font_rotation: int = 45,
|
|
345
|
+
font: Optional[str] = None,
|
|
345
346
|
save_path: Optional[str] = None,
|
|
347
|
+
edge_clr: Optional[str] = 'black',
|
|
346
348
|
hhmmss: bool = False) -> Union[None, np.ndarray]:
|
|
347
349
|
|
|
348
350
|
video_timer = SimbaTimer(start=True)
|
|
349
351
|
colour_tuple_x = list(np.arange(3.5, 203.5, 5))
|
|
352
|
+
original_font_family = copy(plt.rcParams['font.family']) if isinstance(plt.rcParams['font.family'], list) else plt.rcParams['font.family']
|
|
353
|
+
|
|
354
|
+
if font is not None:
|
|
355
|
+
available_fonts = get_fonts()
|
|
356
|
+
if font in available_fonts:
|
|
357
|
+
matplotlib.font_manager._get_font.cache_clear()
|
|
358
|
+
plt.rcParams['font.family'] = font
|
|
359
|
+
else:
|
|
360
|
+
matplotlib.font_manager._get_font.cache_clear()
|
|
361
|
+
plt.rcParams['font.family'] = [font, 'sans-serif']
|
|
362
|
+
|
|
350
363
|
fig, ax = plt.subplots()
|
|
351
|
-
fig.patch.set_facecolor('#fafafa')
|
|
352
|
-
ax.set_facecolor('#ffffff')
|
|
353
364
|
fig.patch.set_facecolor('white')
|
|
354
|
-
|
|
365
|
+
|
|
366
|
+
plt.title(video_name, fontsize=font_size + 6, pad=25, fontweight='bold')
|
|
355
367
|
ax.spines['top'].set_visible(False)
|
|
356
368
|
ax.spines['right'].set_visible(False)
|
|
357
369
|
ax.spines['left'].set_color('#666666')
|
|
@@ -367,7 +379,7 @@ class PlottingMixin(object):
|
|
|
367
379
|
if event[0] == x:
|
|
368
380
|
ix = clf_names.index(x)
|
|
369
381
|
data_event = event[1][["Start_time", "Bout_time"]]
|
|
370
|
-
ax.broken_barh(data_event.values, (colour_tuple_x[ix], 3), facecolors=palette[ix])
|
|
382
|
+
ax.broken_barh(data_event.values, (colour_tuple_x[ix], 3), facecolors=palette[ix], edgecolor=edge_clr, linewidth=0.5, alpha=0.9)
|
|
371
383
|
|
|
372
384
|
x_ticks_seconds = np.round(np.linspace(0, x_length / fps, 6))
|
|
373
385
|
x_ticks_locs = x_ticks_seconds
|
|
@@ -375,18 +387,19 @@ class PlottingMixin(object):
|
|
|
375
387
|
if hhmmss:
|
|
376
388
|
x_lbls = [seconds_to_timestamp(sec) for sec in x_ticks_seconds]
|
|
377
389
|
else:
|
|
378
|
-
x_lbls = x_ticks_seconds
|
|
390
|
+
x_lbls = [int(x) for x in x_ticks_seconds]
|
|
379
391
|
|
|
380
|
-
#x_ticks_locs = x_lbls = np.round(np.linspace(0, x_length / fps, 6))
|
|
381
392
|
ax.set_xticks(x_ticks_locs)
|
|
382
393
|
ax.set_xticklabels(x_lbls)
|
|
383
394
|
ax.set_ylim(0, colour_tuple_x[len(clf_names)])
|
|
384
395
|
ax.set_yticks(np.arange(5, 5 * len(clf_names) + 1, 5))
|
|
385
|
-
ax.set_yticklabels(clf_names, rotation=font_rotation)
|
|
396
|
+
ax.set_yticklabels(clf_names, rotation=font_rotation, ha='right', va='center')
|
|
386
397
|
ax.tick_params(axis="both", labelsize=font_size)
|
|
387
398
|
plt.xlabel(x_label, fontsize=font_size + 3)
|
|
388
|
-
ax.
|
|
389
|
-
|
|
399
|
+
ax.grid(True, axis='both', linewidth=1.0, color='gray', alpha=0.2, linestyle='--', which='major')
|
|
400
|
+
|
|
401
|
+
plt.subplots_adjust(left=0.1, right=0.95, top=0.9, bottom=0.15)
|
|
402
|
+
plt.tight_layout()
|
|
390
403
|
buffer_ = io.BytesIO()
|
|
391
404
|
plt.savefig(buffer_, format="png")
|
|
392
405
|
buffer_.seek(0)
|
|
@@ -397,6 +410,11 @@ class PlottingMixin(object):
|
|
|
397
410
|
frame = np.uint8(open_cv_image)
|
|
398
411
|
buffer_.close()
|
|
399
412
|
plt.close('all')
|
|
413
|
+
|
|
414
|
+
if font is not None:
|
|
415
|
+
plt.rcParams['font.family'] = original_font_family
|
|
416
|
+
matplotlib.font_manager._get_font.cache_clear()
|
|
417
|
+
|
|
400
418
|
if save_path is not None:
|
|
401
419
|
cv2.imwrite(save_path, frame)
|
|
402
420
|
video_timer.stop_timer()
|
|
@@ -47,5 +47,5 @@ class OutlierCorrectionSkipper(ConfigReader):
|
|
|
47
47
|
self.timer.stop_timer()
|
|
48
48
|
stdout_success(msg=f"Skipped outlier correction for {len(self.input_csv_paths)} files", elapsed_time=self.timer.elapsed_time_str)
|
|
49
49
|
|
|
50
|
-
# test = OutlierCorrectionSkipper(config_path=
|
|
50
|
+
# test = OutlierCorrectionSkipper(config_path=r"E:\troubleshooting\mitra_emergence\project_folder\project_config.ini")
|
|
51
51
|
# test.run()
|
simba/plotting/gantt_creator.py
CHANGED
|
@@ -2,6 +2,7 @@ __author__ = "Simon Nilsson; sronilsson@gmail.com"
|
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
4
|
import shutil
|
|
5
|
+
from copy import deepcopy
|
|
5
6
|
from typing import List, Optional, Union
|
|
6
7
|
|
|
7
8
|
import cv2
|
|
@@ -16,7 +17,7 @@ from simba.utils.checks import (
|
|
|
16
17
|
from simba.utils.data import create_color_palette, detect_bouts
|
|
17
18
|
from simba.utils.enums import Formats, Options
|
|
18
19
|
from simba.utils.errors import NoSpecifiedOutputError
|
|
19
|
-
from simba.utils.lookups import get_named_colors
|
|
20
|
+
from simba.utils.lookups import get_fonts, get_named_colors
|
|
20
21
|
from simba.utils.printing import stdout_success
|
|
21
22
|
from simba.utils.read_write import get_fn_ext, read_df
|
|
22
23
|
|
|
@@ -60,16 +61,18 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
|
|
|
60
61
|
|
|
61
62
|
def __init__(self,
|
|
62
63
|
config_path: Union[str, os.PathLike],
|
|
63
|
-
data_paths: List[Union[str, os.PathLike]],
|
|
64
|
+
data_paths: Optional[Union[Union[str, os.PathLike], List[Union[str, os.PathLike]]]] = None,
|
|
64
65
|
width: int = 640,
|
|
65
66
|
height: int = 480,
|
|
66
67
|
font_size: int = 8,
|
|
67
68
|
font_rotation: int = 45,
|
|
69
|
+
font: Optional[str] = None,
|
|
68
70
|
palette: str = 'Set1',
|
|
69
|
-
frame_setting:
|
|
70
|
-
video_setting:
|
|
71
|
-
last_frm_setting:
|
|
72
|
-
hhmmss:
|
|
71
|
+
frame_setting: bool = False,
|
|
72
|
+
video_setting: bool = False,
|
|
73
|
+
last_frm_setting: bool = True,
|
|
74
|
+
hhmmss: bool = True,
|
|
75
|
+
clf_names: Optional[List[str]] = None):
|
|
73
76
|
|
|
74
77
|
if ((frame_setting != True) and (video_setting != True) and (last_frm_setting != True)):
|
|
75
78
|
raise NoSpecifiedOutputError(msg="SIMBA ERROR: Please select gantt videos, frames, and/or last frame.")
|
|
@@ -78,7 +81,13 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
|
|
|
78
81
|
check_int(value=height, min_value=1, name=f'{self.__class__.__name__} height')
|
|
79
82
|
check_int(value=font_size, min_value=1, name=f'{self.__class__.__name__} font_size')
|
|
80
83
|
check_int(value=font_rotation, min_value=0, max_value=180, name=f'{self.__class__.__name__} font_rotation')
|
|
81
|
-
|
|
84
|
+
if isinstance(data_paths, list):
|
|
85
|
+
check_valid_lst(data=data_paths, source=f'{self.__class__.__name__} data_paths', valid_dtypes=(str,), min_len=1)
|
|
86
|
+
elif isinstance(data_paths, str):
|
|
87
|
+
check_file_exist_and_readable(file_path=data_paths)
|
|
88
|
+
data_paths = [data_paths]
|
|
89
|
+
else:
|
|
90
|
+
data_paths = deepcopy(self.machine_results_paths)
|
|
82
91
|
check_valid_boolean(value=hhmmss, source=f'{self.__class__.__name__} hhmmss', raise_error=False)
|
|
83
92
|
palettes = Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value
|
|
84
93
|
check_str(name=f'{self.__class__.__name__} palette', value=palette, options=palettes)
|
|
@@ -90,7 +99,12 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
|
|
|
90
99
|
if not os.path.exists(self.gantt_plot_dir): os.makedirs(self.gantt_plot_dir)
|
|
91
100
|
self.frame_setting, self.video_setting, self.last_frm_setting = frame_setting, video_setting, last_frm_setting
|
|
92
101
|
self.width, self.height, self.font_size, self.font_rotation = width, height, font_size, font_rotation
|
|
93
|
-
|
|
102
|
+
if font is not None:
|
|
103
|
+
check_str(name=f'{self.__class__.__name__} font', value=font, options=list(get_fonts().keys()), raise_error=True)
|
|
104
|
+
if clf_names is not None:
|
|
105
|
+
check_valid_lst(data=clf_names, source=f'{self.__class__.__name__} clf_names', valid_dtypes=(str,), valid_values=self.clf_names, min_len=1, raise_error=True)
|
|
106
|
+
self.clf_names = clf_names
|
|
107
|
+
self.data_paths, self.hhmmss, self.font = data_paths, hhmmss, font
|
|
94
108
|
self.colours = get_named_colors()
|
|
95
109
|
self.colour_tuple_x = list(np.arange(3.5, 203.5, 5))
|
|
96
110
|
|
|
@@ -121,6 +135,7 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
|
|
|
121
135
|
font_size=self.font_size,
|
|
122
136
|
font_rotation=self.font_rotation,
|
|
123
137
|
video_name=self.video_name,
|
|
138
|
+
font=self.font,
|
|
124
139
|
save_path=os.path.join(self.gantt_plot_dir, f"{self.video_name }_final_image.png"),
|
|
125
140
|
palette=self.clr_lst,
|
|
126
141
|
hhmmss=self.hhmmss)
|
|
@@ -135,6 +150,7 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
|
|
|
135
150
|
width=self.width,
|
|
136
151
|
height=self.height,
|
|
137
152
|
font_size=self.font_size,
|
|
153
|
+
font=self.font,
|
|
138
154
|
font_rotation=self.font_rotation,
|
|
139
155
|
video_name=self.video_name,
|
|
140
156
|
palette=self.clr_lst,
|
|
@@ -156,13 +172,16 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
|
|
|
156
172
|
# test = GanttCreatorSingleProcess(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
|
|
157
173
|
# frame_setting=False,
|
|
158
174
|
# video_setting=False,
|
|
159
|
-
# data_paths=[r"C:\troubleshooting\mitra\project_folder\csv\machine_results\
|
|
175
|
+
# data_paths=[r"C:\troubleshooting\mitra\project_folder\csv\machine_results\501_MA142_Gi_CNO_0516.csv"],
|
|
160
176
|
# last_frm_setting=True,
|
|
161
177
|
# width=640,
|
|
162
178
|
# height= 480,
|
|
163
179
|
# font_size=10,
|
|
180
|
+
# font=None,
|
|
164
181
|
# font_rotation=45,
|
|
165
|
-
#
|
|
182
|
+
# hhmmss=False,
|
|
183
|
+
# palette='Set1',
|
|
184
|
+
# clf_names=['straub_tail'])
|
|
166
185
|
# test.run()
|
|
167
186
|
|
|
168
187
|
|