simba-uw-tf-dev 4.6.7__py3-none-any.whl → 4.6.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. simba/assets/.recent_projects.txt +1 -0
  2. simba/data_processors/circling_detector.py +30 -13
  3. simba/data_processors/cuda/image.py +42 -18
  4. simba/data_processors/cuda/statistics.py +2 -3
  5. simba/data_processors/freezing_detector.py +54 -50
  6. simba/feature_extractors/mitra_feature_extractor.py +2 -2
  7. simba/mixins/config_reader.py +5 -2
  8. simba/mixins/plotting_mixin.py +28 -10
  9. simba/outlier_tools/skip_outlier_correction.py +1 -1
  10. simba/plotting/gantt_creator.py +29 -10
  11. simba/plotting/gantt_creator_mp.py +50 -17
  12. simba/sandbox/analyze_runtimes.py +30 -30
  13. simba/sandbox/test_directionality.py +47 -47
  14. simba/sandbox/test_nonstatic_directionality.py +27 -27
  15. simba/sandbox/test_pycharm_cuda.py +51 -51
  16. simba/sandbox/test_simba_install.py +41 -41
  17. simba/sandbox/test_static_directionality.py +26 -26
  18. simba/sandbox/test_static_directionality_2d.py +26 -26
  19. simba/sandbox/verify_env.py +42 -42
  20. simba/ui/pop_ups/clf_plot_pop_up.py +2 -2
  21. simba/ui/pop_ups/gantt_pop_up.py +31 -6
  22. simba/ui/pop_ups/video_processing_pop_up.py +1 -1
  23. simba/video_processors/video_processing.py +1 -1
  24. {simba_uw_tf_dev-4.6.7.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/METADATA +1 -1
  25. {simba_uw_tf_dev-4.6.7.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/RECORD +29 -29
  26. {simba_uw_tf_dev-4.6.7.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/LICENSE +0 -0
  27. {simba_uw_tf_dev-4.6.7.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/WHEEL +0 -0
  28. {simba_uw_tf_dev-4.6.7.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/entry_points.txt +0 -0
  29. {simba_uw_tf_dev-4.6.7.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/top_level.txt +0 -0
simba/assets/.recent_projects.txt
@@ -1,2 +1,3 @@
+ E:/troubleshooting/mitra_emergence/project_folder/project_config.ini
  C:/troubleshooting/meberled/project_folder/project_config.ini
  C:/troubleshooting/mitra/project_folder/project_config.ini
simba/data_processors/circling_detector.py
@@ -11,12 +11,13 @@ from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin
  from simba.mixins.timeseries_features_mixin import TimeseriesFeatureMixin
  from simba.utils.checks import (
  check_all_file_names_are_represented_in_video_log, check_if_dir_exists,
- check_int, check_str, check_valid_dataframe)
+ check_str, check_valid_dataframe)
  from simba.utils.data import detect_bouts, plug_holes_shortest_bout
  from simba.utils.enums import Formats
  from simba.utils.printing import stdout_success
  from simba.utils.read_write import (find_files_of_filetypes_in_directory,
- get_fn_ext, read_df, read_video_info)
+ get_current_time, get_fn_ext, read_df,
+ read_video_info)

  CIRCLING = 'CIRCLING'

@@ -58,30 +59,34 @@ class CirclingDetector(ConfigReader):
  """

  def __init__(self,
- data_dir: Union[str, os.PathLike],
  config_path: Union[str, os.PathLike],
  nose_name: Optional[str] = 'nose',
+ data_dir: Optional[Union[str, os.PathLike]] = None,
  left_ear_name: Optional[str] = 'left_ear',
  right_ear_name: Optional[str] = 'right_ear',
  tail_base_name: Optional[str] = 'tail_base',
  center_name: Optional[str] = 'center',
- time_threshold: Optional[int] = 10,
- circular_range_threshold: Optional[int] = 320,
+ time_threshold: Optional[int] = 7,
+ circular_range_threshold: Optional[int] = 350,
+ shortest_bout: int = 100,
  movement_threshold: Optional[int] = 60,
  save_dir: Optional[Union[str, os.PathLike]] = None):

- check_if_dir_exists(in_dir=data_dir)
  for bp_name in [nose_name, left_ear_name, right_ear_name, tail_base_name]: check_str(name='body part name', value=bp_name, allow_blank=False)
- self.data_paths = find_files_of_filetypes_in_directory(directory=data_dir, extensions=['.csv'])
  ConfigReader.__init__(self, config_path=config_path, read_video_info=True, create_logger=False)
+ if data_dir is not None:
+ check_if_dir_exists(in_dir=data_dir)
+ else:
+ data_dir = self.outlier_corrected_dir
+ self.data_paths = find_files_of_filetypes_in_directory(directory=data_dir, extensions=['.csv'])
  self.nose_heads = [f'{nose_name}_x'.lower(), f'{nose_name}_y'.lower()]
  self.left_ear_heads = [f'{left_ear_name}_x'.lower(), f'{left_ear_name}_y'.lower()]
  self.right_ear_heads = [f'{right_ear_name}_x'.lower(), f'{right_ear_name}_y'.lower()]
  self.center_heads = [f'{center_name}_x'.lower(), f'{center_name}_y'.lower()]
  self.required_field = self.nose_heads + self.left_ear_heads + self.right_ear_heads
- self.save_dir = save_dir
+ self.save_dir, self.shortest_bout = save_dir, shortest_bout
  if self.save_dir is None:
- self.save_dir = os.path.join(self.logs_path, f'circling_data_{self.datetime}')
+ self.save_dir = os.path.join(self.logs_path, f'circling_data_{time_threshold}s_{circular_range_threshold}d_{movement_threshold}mm_{self.datetime}')
  os.makedirs(self.save_dir)
  else:
  check_if_dir_exists(in_dir=self.save_dir)
@@ -93,7 +98,7 @@ class CirclingDetector(ConfigReader):
  check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.data_paths)
  for file_cnt, file_path in enumerate(self.data_paths):
  video_name = get_fn_ext(filepath=file_path)[1]
- print(f'Analyzing {video_name} ({file_cnt+1}/{len(self.data_paths)})...')
+ print(f'[{get_current_time()}] Analyzing circling {video_name}... (video {file_cnt+1}/{len(self.data_paths)})')
  save_file_path = os.path.join(self.save_dir, f'{video_name}.csv')
  df = read_df(file_path=file_path, file_type='csv').reset_index(drop=True)
  _, px_per_mm, fps = read_video_info(video_info_df=self.video_info_df, video_name=video_name)
@@ -115,11 +120,24 @@ class CirclingDetector(ConfigReader):
  circling_idx = np.argwhere(sliding_circular_range >= self.circular_range_threshold).astype(np.int32).flatten()
  movement_idx = np.argwhere(movement_sum >= self.movement_threshold).astype(np.int32).flatten()
  circling_idx = [x for x in movement_idx if x in circling_idx]
+ df[f'Probability_{CIRCLING}'] = 0
  df[CIRCLING] = 0
  df.loc[circling_idx, CIRCLING] = 1
+ df.loc[circling_idx, f'Probability_{CIRCLING}'] = 1
+ df = plug_holes_shortest_bout(data_df=df, clf_name=CIRCLING, fps=fps, shortest_bout=self.shortest_bout)
  bouts = detect_bouts(data_df=df, target_lst=[CIRCLING], fps=fps)
- df = plug_holes_shortest_bout(data_df=df, clf_name=CIRCLING, fps=fps, shortest_bout=100)
+ if len(bouts) > 0:
+ df[CIRCLING] = 0
+ circling_idx = list(bouts.apply(lambda x: list(range(int(x["Start_frame"]), int(x["End_frame"]) + 1)), 1))
+ circling_idx = [x for xs in circling_idx for x in xs]
+ df.loc[circling_idx, CIRCLING] = 1
+ df.loc[circling_idx, f'Probability_{CIRCLING}'] = 1
+ else:
+ df[CIRCLING] = 0
+ circling_idx = []
+
  df.to_csv(save_file_path)
+ #print(video_name, len(circling_idx), round(len(circling_idx) / fps, 4), df[CIRCLING].sum())
  agg_results.loc[len(agg_results)] = [video_name, len(circling_idx), round(len(circling_idx) / fps, 4), len(bouts), round((len(circling_idx) / len(df)) * 100, 4), len(df), round(len(df)/fps, 2) ]

  agg_results.to_csv(agg_results_path)
@@ -127,7 +145,6 @@ class CirclingDetector(ConfigReader):

  #
  #
- # detector = CirclingDetector(data_dir=r'C:\troubleshooting\mitra\project_folder\csv\outlier_corrected_movement_location',
- # config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
+ # detector = CirclingDetector(config_path=r"E:\troubleshooting\mitra_emergence\project_folder\project_config.ini")
  # detector.run()

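The circling_detector.py changes above make data_dir optional (falling back to the project's outlier-corrected directory) and route detections through plug_holes_shortest_bout and detect_bouts before scoring. A minimal usage sketch of the updated constructor, with a hypothetical project path:

# Usage sketch of the updated CirclingDetector API; the config path is hypothetical.
# When data_dir is omitted, CSVs are read from the project's outlier-corrected directory.
from simba.data_processors.circling_detector import CirclingDetector

detector = CirclingDetector(config_path=r"C:\my_project\project_folder\project_config.ini",
                            time_threshold=7,              # default changed from 10 to 7 in this release
                            circular_range_threshold=350,  # default changed from 320 to 350
                            movement_threshold=60,         # millimetres, per the '..mm' tag in the save directory name
                            shortest_bout=100)             # short gaps/bouts are plugged before bout detection
detector.run()  # writes per-video CSVs plus an aggregate results CSV to save_dir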
simba/data_processors/cuda/image.py
@@ -332,12 +332,22 @@ def _digital(data, results):
  def img_stack_brightness(x: np.ndarray,
  method: Optional[Literal['photometric', 'digital']] = 'digital',
  ignore_black: bool = True,
- verbose: bool = False) -> np.ndarray:
+ verbose: bool = False,
+ batch_size: int = 2500) -> np.ndarray:
  """
  Calculate the average brightness of a stack of images using a specified method.

  Useful for analyzing light cues or brightness changes over time. For example, compute brightness in images containing a light cue ROI, then perform clustering (e.g., k-means) on brightness values to identify frames when the light cue is on vs off.

+ .. csv-table::
+ :header: EXPECTED RUNTIMES
+ :file: ../../../docs/tables/img_stack_brightness_gpu.csv
+ :widths: 10, 45, 45
+ :align: center
+ :class: simba-table
+ :header-rows: 1
+
+
  - **Photometric Method**: The brightness is calculated using the formula:

  .. math::
@@ -365,27 +375,41 @@ def img_stack_brightness(x: np.ndarray,

  check_instance(source=img_stack_brightness.__name__, instance=x, accepted_types=(np.ndarray,))
  check_if_valid_img(data=x[0], source=img_stack_brightness.__name__)
+ check_int(name=f'{img_stack_brightness.__name__} batch_size', value=batch_size, allow_zero=False, allow_negative=False, raise_error=True)
  x, timer = np.ascontiguousarray(x).astype(np.uint8), SimbaTimer(start=True)
+ results = []
  if x.ndim == 4:
- grid_x = (x.shape[1] + 16 - 1) // 16
- grid_y = (x.shape[2] + 16 - 1) // 16
- grid_z = x.shape[0]
- threads_per_block = (16, 16, 1)
- blocks_per_grid = (grid_y, grid_x, grid_z)
- x_dev = cuda.to_device(x)
- results = cuda.device_array((x.shape[0], x.shape[1], x.shape[2]), dtype=np.uint8)
- if method == PHOTOMETRIC:
- _photometric[blocks_per_grid, threads_per_block](x_dev, results)
+ batch_results_dev = cuda.device_array((batch_size, x.shape[1], x.shape[2]), dtype=np.uint8)
+ for batch_cnt, l in enumerate(range(0, x.shape[0], batch_size)):
+ r = l + batch_size
+ batch_x = x[l:r]
+ if batch_x.ndim == 4:
+ grid_x = (batch_x.shape[1] + 16 - 1) // 16
+ grid_y = (batch_x.shape[2] + 16 - 1) // 16
+ grid_z = batch_x.shape[0]
+ threads_per_block = (16, 16, 1)
+ blocks_per_grid = (grid_y, grid_x, grid_z)
+ x_dev = cuda.to_device(batch_x)
+ if method == PHOTOMETRIC:
+ _photometric[blocks_per_grid, threads_per_block](x_dev, batch_results_dev)
+ else:
+ _digital[blocks_per_grid, threads_per_block](x_dev, batch_results_dev)
+ batch_results_host = batch_results_dev.copy_to_host()[:batch_x.shape[0]]
+ batch_results_cp = cp.asarray(batch_results_host)
+ if ignore_black:
+ mask = batch_results_cp != 0
+ batch_results_cp = cp.where(mask, batch_results_cp, cp.nan)
+ batch_results = cp.nanmean(batch_results_cp, axis=(1, 2))
+ batch_results = cp.where(cp.isnan(batch_results), 0, batch_results)
+ batch_results = batch_results.get()
+ else:
+ batch_results = cp.mean(batch_results_cp, axis=(1, 2)).get()
  else:
- _digital[blocks_per_grid, threads_per_block](x_dev, results)
- results = results.copy_to_host()
- if ignore_black:
- masked_array = np.ma.masked_equal(results, 0)
- results = np.mean(masked_array, axis=(1, 2)).filled(0)
- else:
- results = deepcopy(x)
- results = np.mean(results, axis=(1, 2))
+ batch_results = deepcopy(x)
+ batch_results = np.mean(batch_results, axis=(1, 2))
+ results.append(batch_results)
  timer.stop_timer()
+ results = np.concatenate(results) if len(results) > 0 else np.array([])
  if verbose: print(f'Brightness computed in {results.shape[0]} images (elapsed time {timer.elapsed_time_str}s)')
  return results

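The new batch_size argument above splits large image stacks into fixed-size chunks before each kernel launch, so device memory stays bounded. A hedged call sketch on synthetic data (assumes a CUDA-capable GPU with numba and cupy installed, as the module already requires):

# Usage sketch of the batched GPU brightness call on synthetic frames.
import numpy as np
from simba.data_processors.cuda.image import img_stack_brightness

imgs = np.random.randint(0, 255, size=(10_000, 128, 128, 3), dtype=np.uint8)  # stack of 10k RGB frames
brightness = img_stack_brightness(x=imgs,
                                  method='digital',   # 'digital' or 'photometric' weighting, per the docstring
                                  ignore_black=True,  # zero-valued pixels are excluded from each frame's mean
                                  batch_size=2500)    # frames copied to the GPU per kernel launch
print(brightness.shape)  # (10000,): one mean-brightness value per frame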
simba/data_processors/cuda/statistics.py
@@ -19,7 +19,6 @@ from scipy.spatial import ConvexHull
  from simba.utils.read_write import get_unique_values_in_iterable, read_df
  from simba.utils.warnings import GPUToolsWarning

-
  try:
  import cupy as cp
  from cupyx.scipy.spatial.distance import cdist
@@ -44,8 +43,8 @@ except:

  from simba.data_processors.cuda.utils import _cuda_are_rows_equal
  from simba.mixins.statistics_mixin import Statistics
- from simba.utils.checks import (check_int, check_str, check_valid_array,
- check_valid_tuple, check_float)
+ from simba.utils.checks import (check_float, check_int, check_str,
+ check_valid_array, check_valid_tuple)
  from simba.utils.data import bucket_data
  from simba.utils.enums import Formats

simba/data_processors/freezing_detector.py
@@ -1,10 +1,8 @@
  import os
  from typing import Optional, Union
-
  import numpy as np
  import pandas as pd
  from numba import typed
-
  from simba.mixins.config_reader import ConfigReader
  from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin
  from simba.mixins.timeseries_features_mixin import TimeseriesFeatureMixin
@@ -14,65 +12,73 @@ from simba.utils.checks import (
  from simba.utils.data import detect_bouts, plug_holes_shortest_bout
  from simba.utils.enums import Formats
  from simba.utils.printing import stdout_success
- from simba.utils.read_write import (find_files_of_filetypes_in_directory,
- get_fn_ext, read_df, read_video_info)
+ from simba.utils.read_write import (find_files_of_filetypes_in_directory, get_fn_ext, read_df, read_video_info, get_current_time)

  NAPE_X, NAPE_Y = 'nape_x', 'nape_y'
  FREEZING = 'FREEZING'

  class FreezingDetector(ConfigReader):
-
  """
- Detect freezing behavior using heuristic rules.
-
+ Detect freezing behavior using heuristic rules based on movement velocity thresholds.
+ Analyzes pose-estimation data to detect freezing episodes by computing the mean velocity
+ of key body parts (nape, nose, and tail-base) and identifying periods where movement falls below
+ a specified threshold for a minimum duration.
  .. important::
-
  Freezing is detected as `present` when **the velocity (computed from the mean movement of the nape, nose, and tail-base body-parts) falls below
- the movement threshold for the duration (and longer) of the specied time-window**.
-
+ the movement threshold for the duration (and longer) of the specified time-window**.
  Freezing is detected as `absent` when not present.
-
  .. note::
-
- We pass the names of the left and right ears, as the method will use body-parts to compute the `nape` location of the animal.
-
- :param Union[str, os.PathLike] data_dir: Path to directory containing pose-estimated body-part data in CSV format.
- :param Union[str, os.PathLike] config_path: Path to SimBA project config file.
- :param Optional[str] nose_name: The name of the pose-estimated nose body-part. Defaults to 'nose'.
- :param Optional[str] left_ear_name: The name of the pose-estimated left ear body-part. Defaults to 'left_ear'.
- :param Optional[str] right_ear_name: The name of the pose-estimated right ear body-part. Defaults to 'right_ear'.
- :param Optional[str] tail_base_name: The name of the pose-estimated tail base body-part. Defaults to 'tail_base'.
- :param Optional[int] time_window: The time window in preceding seconds in which to evaluate freezing. Default: 3.
- :param Optional[int] movement_threshold: A movement threshold in millimeters per second. Defaults to 5.
- :param Optional[Union[str, os.PathLike]] save_dir: Directory where to store the results. If None, then results are stored in the ``logs`` directory of the SimBA project.
-
+ The method uses the left and right ear body-parts to compute the `nape` location of the animal
+ as the midpoint between the ears. The nape, nose, and tail-base movements are averaged to compute
+ overall animal movement velocity.
+ :param Union[str, os.PathLike] data_dir: Path to directory containing pose-estimated body-part data in CSV format. Each CSV file should contain pose estimation data for one video.
+ :param Union[str, os.PathLike] config_path: Path to SimBA project config file (`.ini` format) containing project settings and video information.
+ :param Optional[str] nose_name: The name of the pose-estimated nose body-part column (without _x/_y suffix). Defaults to 'nose'.
+ :param Optional[str] left_ear_name: The name of the pose-estimated left ear body-part column (without _x/_y suffix). Defaults to 'Left_ear'.
+ :param Optional[str] right_ear_name: The name of the pose-estimated right ear body-part column (without _x/_y suffix). Defaults to 'right_ear'.
+ :param Optional[str] tail_base_name: The name of the pose-estimated tail base body-part column (without _x/_y suffix). Defaults to 'tail_base'.
+ :param Optional[int] time_window: The minimum time window in seconds that movement must be below the threshold to be considered freezing. Only freezing bouts lasting at least this duration are retained. Defaults to 3.
+ :param Optional[int] movement_threshold: Movement threshold in millimeters per second. Frames with mean velocity below this threshold are considered potential freezing. Defaults to 5.
+ :param Optional[int] shortest_bout: Minimum duration in milliseconds for a freezing bout to be considered valid. Shorter bouts are filtered out. Defaults to 100.
+ :param Optional[Union[str, os.PathLike]] save_dir: Directory where to store the results. If None, then results are stored in a timestamped subdirectory within the ``logs`` directory of the SimBA project.
+ :returns: None. Results are saved to CSV files in the specified save directory:
+ - Individual video results: One CSV file per video with freezing annotations added as a 'FREEZING' column (1 = freezing, 0 = not freezing)
+ - Aggregate results: `aggregate_freezing_results.csv` containing summary statistics for all videos
  :example:
- >>> FreezingDetector(data_dir=r'D:\troubleshooting\mitra\project_folder\csv\outlier_corrected_movement_location', config_path=r"D:\troubleshooting\mitra\project_folder\project_config.ini")
-
+ >>> FreezingDetector(
+ ... data_dir=r'D:\\troubleshooting\\mitra\\project_folder\\csv\\outlier_corrected_movement_location',
+ ... config_path=r"D:\\troubleshooting\\mitra\\project_folder\\project_config.ini",
+ ... time_window=3,
+ ... movement_threshold=5,
+ ... shortest_bout=100
+ ... ).run()
  References
  ----------
+ ..
  .. [1] Sabnis et al., Visual detection of seizures in mice using supervised machine learning, `biorxiv`, doi: https://doi.org/10.1101/2024.05.29.596520.
  .. [2] Lopez et al., Region-specific Nucleus Accumbens Dopamine Signals Encode Distinct Aspects of Avoidance Learning, `biorxiv`, doi: https://doi.org/10.1101/2024.08.28.610149
- .. [3] Lopez, Gabriela C., Louis D. Van Camp, Ryan F. Kovaleski, et al. Region-Specific Nucleus Accumbens Dopamine Signals Encode Distinct Aspects of Avoidance Learning.” `Cell Biology`, Volume 35, Issue 10p2433-2443.e5May 19, 2025. DOI: 10.1016/j.cub.2025.04.006
+ .. [3] Lopez, Gabriela C., Louis D. Van Camp, Ryan F. Kovaleski, et al. "Region-Specific Nucleus Accumbens Dopamine Signals Encode Distinct Aspects of Avoidance Learning." `Cell Biology`, Volume 35, Issue 10p2433-2443.e5May 19, 2025. DOI: 10.1016/j.cub.2025.04.006
  .. [4] Lazaro et al., Brainwide Genetic Capture for Conscious State Transitions, `biorxiv`, doi: https://doi.org/10.1101/2025.03.28.646066
+ .. [5] Sabnis et al., Visual detection of seizures in mice using supervised machine learning, 2025, Cell Reports Methods 5, 101242 December 15, 2025.
  """
-
  def __init__(self,
- data_dir: Union[str, os.PathLike],
  config_path: Union[str, os.PathLike],
- nose_name: Optional[str] = 'nose',
- left_ear_name: Optional[str] = 'Left_ear',
- right_ear_name: Optional[str] = 'right_ear',
- tail_base_name: Optional[str] = 'tail_base',
- time_window: Optional[int] = 3,
- movement_threshold: Optional[int] = 5,
- shortest_bout: Optional[int] = 100,
+ nose_name: str = 'nose',
+ left_ear_name: str = 'Left_ear',
+ right_ear_name: str = 'right_ear',
+ tail_base_name: str = 'tail_base',
+ data_dir: Optional[Union[str, os.PathLike]] = None,
+ time_window: int = 4,
+ movement_threshold: int = 5,
+ shortest_bout: int = 100,
  save_dir: Optional[Union[str, os.PathLike]] = None):
-
- check_if_dir_exists(in_dir=data_dir)
  for bp_name in [nose_name, left_ear_name, right_ear_name, tail_base_name]: check_str(name='body part name', value=bp_name, allow_blank=False)
- self.data_paths = find_files_of_filetypes_in_directory(directory=data_dir, extensions=['.csv'])
  ConfigReader.__init__(self, config_path=config_path, read_video_info=True, create_logger=False)
+ if data_dir is not None:
+ check_if_dir_exists(in_dir=data_dir)
+ else:
+ data_dir = self.outlier_corrected_dir
+ self.data_paths = find_files_of_filetypes_in_directory(directory=data_dir, extensions=['.csv'])
  self.nose_heads = [f'{nose_name}_x'.lower(), f'{nose_name}_y'.lower()]
  self.left_ear_heads = [f'{left_ear_name}_x'.lower(), f'{left_ear_name}_y'.lower()]
  self.right_ear_heads = [f'{right_ear_name}_x'.lower(), f'{right_ear_name}_y'.lower()]
@@ -82,21 +88,19 @@ class FreezingDetector(ConfigReader):
  check_int(name='movement_threshold', value=movement_threshold, min_value=1)
  self.save_dir = save_dir
  if self.save_dir is None:
- self.save_dir = os.path.join(self.logs_path, f'freezing_data_time_{time_window}s_{self.datetime}')
+ self.save_dir = os.path.join(self.logs_path, f'freezing_data_time_{time_window}s_{movement_threshold}mm_{self.datetime}')
  os.makedirs(self.save_dir)
  else:
  check_if_dir_exists(in_dir=self.save_dir)
  self.time_window, self.movement_threshold = time_window, movement_threshold
  self.movement_threshold, self.shortest_bout = movement_threshold, shortest_bout
- self.run()
-
  def run(self):
  agg_results = pd.DataFrame(columns=['VIDEO', 'FREEZING FRAMES', 'FREEZING TIME (S)', 'FREEZING BOUT COUNTS', 'FREEZING PCT OF SESSION', 'VIDEO TOTAL FRAMES', 'VIDEO TOTAL TIME (S)'])
- agg_results_path = os.path.join(self.save_dir, 'aggregate_freezing_results.csv')
+ agg_results_path = os.path.join(self.save_dir, f'aggregate_freezing_results_{self.datetime}.csv')
  check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.data_paths)
  for file_cnt, file_path in enumerate(self.data_paths):
  video_name = get_fn_ext(filepath=file_path)[1]
- print(f'Analyzing {video_name}...({file_cnt+1}/{len(self.data_paths)})')
+ print(f'[{get_current_time()}] Analyzing freezing {video_name}...(video {file_cnt+1}/{len(self.data_paths)})')
  save_file_path = os.path.join(self.save_dir, f'{video_name}.csv')
  df = read_df(file_path=file_path, file_type='csv').reset_index(drop=True)
  _, px_per_mm, fps = read_video_info(vid_info_df=self.video_info_df, video_name=video_name)
@@ -118,23 +122,23 @@ class FreezingDetector(ConfigReader):
  mean_movement = np.mean(movement, axis=1)
  mm_s = TimeseriesFeatureMixin.sliding_descriptive_statistics(data=mean_movement.astype(np.float32), window_sizes=np.array([1], dtype=np.float64), sample_rate=int(fps), statistics=typed.List(["sum"]))[0].flatten()
  freezing_idx = np.argwhere(mm_s <= self.movement_threshold).astype(np.int32).flatten()
+ df[f'Probability_{FREEZING}'] = 0
  df[FREEZING] = 0
  df.loc[freezing_idx, FREEZING] = 1
  df = plug_holes_shortest_bout(data_df=df, clf_name=FREEZING, fps=fps, shortest_bout=self.shortest_bout)
  bouts = detect_bouts(data_df=df, target_lst=[FREEZING], fps=fps)
  bouts = bouts[bouts['Bout_time'] >= self.time_window]
  if len(bouts) > 0:
+ df[FREEZING] = 0
  freezing_idx = list(bouts.apply(lambda x: list(range(int(x["Start_frame"]), int(x["End_frame"]) + 1)), 1))
  freezing_idx = [x for xs in freezing_idx for x in xs]
  df.loc[freezing_idx, FREEZING] = 1
+ df.loc[freezing_idx, f'Probability_{FREEZING}'] = 1
  else:
+ df[FREEZING] = 0
  freezing_idx = []
  df.to_csv(save_file_path)
+ print(video_name, len(freezing_idx), round(len(freezing_idx) / fps, 4), df[FREEZING].sum())
  agg_results.loc[len(agg_results)] = [video_name, len(freezing_idx), round(len(freezing_idx) / fps, 4), len(bouts), round((len(freezing_idx) / len(df)) * 100, 4), len(df), round(len(df)/fps, 2) ]
-
  agg_results.to_csv(agg_results_path)
- stdout_success(msg=f'Results saved in {self.save_dir} directory.')
-
- #
- # FreezingDetector(data_dir=r'C:\troubleshooting\mitra\project_folder\csv\outlier_corrected_movement_location',
- # config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
+ self.timer.stop_timer(); stdout_success(msg=f'Results saved in {self.save_dir} directory.', elapsed_time=self.timer.elapsed_time_str)
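Because self.run() is no longer called from __init__, the detector must now be run explicitly. A minimal usage sketch with a hypothetical project path, mirroring the docstring example added above:

# Usage sketch of the updated FreezingDetector API; the config path is hypothetical.
from simba.data_processors.freezing_detector import FreezingDetector

detector = FreezingDetector(config_path=r"C:\my_project\project_folder\project_config.ini",
                            time_window=4,         # a bout must stay below threshold for at least this many seconds
                            movement_threshold=5,  # mean nape/nose/tail-base velocity in mm/s
                            shortest_bout=100)     # bouts/gaps shorter than this (ms) are plugged
detector.run()  # writes per-video CSVs plus aggregate_freezing_results_<datetime>.csv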
simba/feature_extractors/mitra_feature_extractor.py
@@ -28,7 +28,7 @@ RIGHT_EAR = 'right_ear'
  CENTER = 'center'
  TAIL_BASE = 'tail_base'
  TAIL_CENTER = 'tail_center'
- TAIL_TIP = 'tail_tip'
+ TAIL_TIP = 'tail_end'

  TIME_WINDOWS = np.array([0.25, 0.5, 1.0, 2.0])

@@ -207,7 +207,7 @@ class MitraFeatureExtractor(ConfigReader,



- # feature_extractor = MitraFeatureExtractor(config_path=r"D:\troubleshooting\mitra\project_folder\project_config.ini")
+ # feature_extractor = MitraFeatureExtractor(config_path=r"E:\troubleshooting\mitra_emergence\project_folder\project_config.ini")
  # feature_extractor.run()

simba/mixins/config_reader.py
@@ -41,8 +41,8 @@ from simba.utils.read_write import (find_core_cnt, get_all_clf_names,
  get_fn_ext, read_config_file, read_df,
  read_project_path_and_file_type, write_df)
  from simba.utils.warnings import (BodypartColumnNotFoundWarning,
- InvalidValueWarning, NoDataFoundWarning,
- NoFileFoundWarning)
+ DuplicateNamesWarning, InvalidValueWarning,
+ NoDataFoundWarning, NoFileFoundWarning)


  class ConfigReader(object):
@@ -610,11 +610,14 @@ class ConfigReader(object):
  >>> config_reader.get_bp_headers()
  """

+ duplicates = list({x for x in self.body_parts_lst if self.body_parts_lst.count(x) > 1})
+ if len(duplicates) > 0: DuplicateNamesWarning(msg=f'The pose configuration file at {self.body_parts_path} contains duplicate entries: {duplicates}', source=self.__class__.__name__)
  self.bp_headers = []
  for bp in self.body_parts_lst:
  c1, c2, c3 = (f"{bp}_x", f"{bp}_y", f"{bp}_p")
  self.bp_headers.extend((c1, c2, c3))

+
  def read_config_entry(
  self,
  config: ConfigParser,
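The get_bp_headers() addition above warns when the pose configuration lists the same body-part twice before building the _x/_y/_p headers. A standalone sketch of that duplicate check, using a made-up body-part list:

# Standalone sketch of the duplicate body-part check; the list values are illustrative only.
body_parts = ['nose', 'left_ear', 'right_ear', 'tail_base', 'nose']

# The set comprehension keeps each duplicated name once, regardless of how often it repeats.
duplicates = list({x for x in body_parts if body_parts.count(x) > 1})
print(duplicates)  # ['nose']

# The x/y/p headers are then built the same way as before:
bp_headers = []
for bp in body_parts:
    bp_headers.extend((f"{bp}_x", f"{bp}_y", f"{bp}_p"))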
simba/mixins/plotting_mixin.py
@@ -39,7 +39,7 @@ from simba.utils.data import create_color_palette, detect_bouts
  from simba.utils.enums import Formats, Keys, Options, Paths
  from simba.utils.errors import InvalidInputError
  from simba.utils.lookups import (get_categorical_palettes, get_color_dict,
- get_named_colors)
+ get_fonts, get_named_colors)
  from simba.utils.printing import SimbaTimer, stdout_success
  from simba.utils.read_write import (find_files_of_filetypes_in_directory,
  get_fn_ext, get_video_meta_data, read_df,
@@ -342,16 +342,28 @@ class PlottingMixin(object):
  height: int = 480,
  font_size: int = 8,
  font_rotation: int = 45,
+ font: Optional[str] = None,
  save_path: Optional[str] = None,
+ edge_clr: Optional[str] = 'black',
  hhmmss: bool = False) -> Union[None, np.ndarray]:

  video_timer = SimbaTimer(start=True)
  colour_tuple_x = list(np.arange(3.5, 203.5, 5))
+ original_font_family = copy(plt.rcParams['font.family']) if isinstance(plt.rcParams['font.family'], list) else plt.rcParams['font.family']
+
+ if font is not None:
+ available_fonts = get_fonts()
+ if font in available_fonts:
+ matplotlib.font_manager._get_font.cache_clear()
+ plt.rcParams['font.family'] = font
+ else:
+ matplotlib.font_manager._get_font.cache_clear()
+ plt.rcParams['font.family'] = [font, 'sans-serif']
+
  fig, ax = plt.subplots()
- fig.patch.set_facecolor('#fafafa')
- ax.set_facecolor('#ffffff')
  fig.patch.set_facecolor('white')
- plt.title(video_name, fontsize=font_size + 6, pad=15, fontweight='bold')
+
+ plt.title(video_name, fontsize=font_size + 6, pad=25, fontweight='bold')
  ax.spines['top'].set_visible(False)
  ax.spines['right'].set_visible(False)
  ax.spines['left'].set_color('#666666')
@@ -367,7 +379,7 @@ class PlottingMixin(object):
  if event[0] == x:
  ix = clf_names.index(x)
  data_event = event[1][["Start_time", "Bout_time"]]
- ax.broken_barh(data_event.values, (colour_tuple_x[ix], 3), facecolors=palette[ix])
+ ax.broken_barh(data_event.values, (colour_tuple_x[ix], 3), facecolors=palette[ix], edgecolor=edge_clr, linewidth=0.5, alpha=0.9)

  x_ticks_seconds = np.round(np.linspace(0, x_length / fps, 6))
  x_ticks_locs = x_ticks_seconds
@@ -375,18 +387,19 @@ class PlottingMixin(object):
  if hhmmss:
  x_lbls = [seconds_to_timestamp(sec) for sec in x_ticks_seconds]
  else:
- x_lbls = x_ticks_seconds
+ x_lbls = [int(x) for x in x_ticks_seconds]

- #x_ticks_locs = x_lbls = np.round(np.linspace(0, x_length / fps, 6))
  ax.set_xticks(x_ticks_locs)
  ax.set_xticklabels(x_lbls)
  ax.set_ylim(0, colour_tuple_x[len(clf_names)])
  ax.set_yticks(np.arange(5, 5 * len(clf_names) + 1, 5))
- ax.set_yticklabels(clf_names, rotation=font_rotation)
+ ax.set_yticklabels(clf_names, rotation=font_rotation, ha='right', va='center')
  ax.tick_params(axis="both", labelsize=font_size)
  plt.xlabel(x_label, fontsize=font_size + 3)
- ax.yaxis.grid(True, linewidth=1.5, color='gray', alpha=0.4, linestyle='--')
- #ax.yaxis.grid(True)
+ ax.grid(True, axis='both', linewidth=1.0, color='gray', alpha=0.2, linestyle='--', which='major')
+
+ plt.subplots_adjust(left=0.1, right=0.95, top=0.9, bottom=0.15)
+ plt.tight_layout()
  buffer_ = io.BytesIO()
  plt.savefig(buffer_, format="png")
  buffer_.seek(0)
@@ -397,6 +410,11 @@ class PlottingMixin(object):
  frame = np.uint8(open_cv_image)
  buffer_.close()
  plt.close('all')
+
+ if font is not None:
+ plt.rcParams['font.family'] = original_font_family
+ matplotlib.font_manager._get_font.cache_clear()
+
  if save_path is not None:
  cv2.imwrite(save_path, frame)
  video_timer.stop_timer()
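The font handling added above records the active matplotlib font.family, swaps in the requested font (with a sans-serif fallback when it is not among get_fonts()), and restores the original once the figure is rendered. A standalone sketch of that save/switch/restore pattern, using an example font that ships with matplotlib:

# Sketch of the save/switch/restore rcParams pattern; 'DejaVu Serif' is just an example font.
from copy import copy
import matplotlib.pyplot as plt

original_font_family = copy(plt.rcParams['font.family'])      # remember the current default
plt.rcParams['font.family'] = ['DejaVu Serif', 'sans-serif']  # requested font first, generic fallback second

fig, ax = plt.subplots()
ax.set_title('Gantt chart title rendered in the requested font')
fig.savefig('gantt_font_demo.png')
plt.close(fig)

plt.rcParams['font.family'] = original_font_family  # restore so later plots are unaffected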
simba/outlier_tools/skip_outlier_correction.py
@@ -47,5 +47,5 @@ class OutlierCorrectionSkipper(ConfigReader):
  self.timer.stop_timer()
  stdout_success(msg=f"Skipped outlier correction for {len(self.input_csv_paths)} files", elapsed_time=self.timer.elapsed_time_str)

- # test = OutlierCorrectionSkipper(config_path='/Users/simon/Desktop/envs/troubleshooting/naresh/project_folder/project_config.ini')
+ # test = OutlierCorrectionSkipper(config_path=r"E:\troubleshooting\mitra_emergence\project_folder\project_config.ini")
  # test.run()
@@ -2,6 +2,7 @@ __author__ = "Simon Nilsson; sronilsson@gmail.com"
2
2
 
3
3
  import os
4
4
  import shutil
5
+ from copy import deepcopy
5
6
  from typing import List, Optional, Union
6
7
 
7
8
  import cv2
@@ -16,7 +17,7 @@ from simba.utils.checks import (
16
17
  from simba.utils.data import create_color_palette, detect_bouts
17
18
  from simba.utils.enums import Formats, Options
18
19
  from simba.utils.errors import NoSpecifiedOutputError
19
- from simba.utils.lookups import get_named_colors
20
+ from simba.utils.lookups import get_fonts, get_named_colors
20
21
  from simba.utils.printing import stdout_success
21
22
  from simba.utils.read_write import get_fn_ext, read_df
22
23
 
@@ -60,16 +61,18 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
60
61
 
61
62
  def __init__(self,
62
63
  config_path: Union[str, os.PathLike],
63
- data_paths: List[Union[str, os.PathLike]],
64
+ data_paths: Optional[Union[Union[str, os.PathLike], List[Union[str, os.PathLike]]]] = None,
64
65
  width: int = 640,
65
66
  height: int = 480,
66
67
  font_size: int = 8,
67
68
  font_rotation: int = 45,
69
+ font: Optional[str] = None,
68
70
  palette: str = 'Set1',
69
- frame_setting: Optional[bool] = False,
70
- video_setting: Optional[bool] = False,
71
- last_frm_setting: Optional[bool] = True,
72
- hhmmss: Optional[bool] = True):
71
+ frame_setting: bool = False,
72
+ video_setting: bool = False,
73
+ last_frm_setting: bool = True,
74
+ hhmmss: bool = True,
75
+ clf_names: Optional[List[str]] = None):
73
76
 
74
77
  if ((frame_setting != True) and (video_setting != True) and (last_frm_setting != True)):
75
78
  raise NoSpecifiedOutputError(msg="SIMBA ERROR: Please select gantt videos, frames, and/or last frame.")
@@ -78,7 +81,13 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
78
81
  check_int(value=height, min_value=1, name=f'{self.__class__.__name__} height')
79
82
  check_int(value=font_size, min_value=1, name=f'{self.__class__.__name__} font_size')
80
83
  check_int(value=font_rotation, min_value=0, max_value=180, name=f'{self.__class__.__name__} font_rotation')
81
- check_valid_lst(data=data_paths, source=f'{self.__class__.__name__} data_paths', valid_dtypes=(str,), min_len=1)
84
+ if isinstance(data_paths, list):
85
+ check_valid_lst(data=data_paths, source=f'{self.__class__.__name__} data_paths', valid_dtypes=(str,), min_len=1)
86
+ elif isinstance(data_paths, str):
87
+ check_file_exist_and_readable(file_path=data_paths)
88
+ data_paths = [data_paths]
89
+ else:
90
+ data_paths = deepcopy(self.machine_results_paths)
82
91
  check_valid_boolean(value=hhmmss, source=f'{self.__class__.__name__} hhmmss', raise_error=False)
83
92
  palettes = Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value
84
93
  check_str(name=f'{self.__class__.__name__} palette', value=palette, options=palettes)
@@ -90,7 +99,12 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
90
99
  if not os.path.exists(self.gantt_plot_dir): os.makedirs(self.gantt_plot_dir)
91
100
  self.frame_setting, self.video_setting, self.last_frm_setting = frame_setting, video_setting, last_frm_setting
92
101
  self.width, self.height, self.font_size, self.font_rotation = width, height, font_size, font_rotation
93
- self.data_paths, self.hhmmss = data_paths, hhmmss
102
+ if font is not None:
103
+ check_str(name=f'{self.__class__.__name__} font', value=font, options=list(get_fonts().keys()), raise_error=True)
104
+ if clf_names is not None:
105
+ check_valid_lst(data=clf_names, source=f'{self.__class__.__name__} clf_names', valid_dtypes=(str,), valid_values=self.clf_names, min_len=1, raise_error=True)
106
+ self.clf_names = clf_names
107
+ self.data_paths, self.hhmmss, self.font = data_paths, hhmmss, font
94
108
  self.colours = get_named_colors()
95
109
  self.colour_tuple_x = list(np.arange(3.5, 203.5, 5))
96
110
 
@@ -121,6 +135,7 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
121
135
  font_size=self.font_size,
122
136
  font_rotation=self.font_rotation,
123
137
  video_name=self.video_name,
138
+ font=self.font,
124
139
  save_path=os.path.join(self.gantt_plot_dir, f"{self.video_name }_final_image.png"),
125
140
  palette=self.clr_lst,
126
141
  hhmmss=self.hhmmss)
@@ -135,6 +150,7 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
135
150
  width=self.width,
136
151
  height=self.height,
137
152
  font_size=self.font_size,
153
+ font=self.font,
138
154
  font_rotation=self.font_rotation,
139
155
  video_name=self.video_name,
140
156
  palette=self.clr_lst,
@@ -156,13 +172,16 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
156
172
  # test = GanttCreatorSingleProcess(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
157
173
  # frame_setting=False,
158
174
  # video_setting=False,
159
- # data_paths=[r"C:\troubleshooting\mitra\project_folder\csv\machine_results\592_MA147_Gq_CNO_0515.csv"],
175
+ # data_paths=[r"C:\troubleshooting\mitra\project_folder\csv\machine_results\501_MA142_Gi_CNO_0516.csv"],
160
176
  # last_frm_setting=True,
161
177
  # width=640,
162
178
  # height= 480,
163
179
  # font_size=10,
180
+ # font=None,
164
181
  # font_rotation=45,
165
- # palette='Set1')
182
+ # hhmmss=False,
183
+ # palette='Set1',
184
+ # clf_names=['straub_tail'])
166
185
  # test.run()
167
186
 
168
187
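With data_paths, font, and clf_names now optional, the single-process gantt creator can be pointed at a project config alone and falls back to the project's machine_results files. A minimal usage sketch with a hypothetical project path:

# Usage sketch of the updated GanttCreatorSingleProcess API; the config path is hypothetical.
from simba.plotting.gantt_creator import GanttCreatorSingleProcess

plotter = GanttCreatorSingleProcess(config_path=r"C:\my_project\project_folder\project_config.ini",
                                    data_paths=None,        # None: use all files in the project's machine_results directory
                                    frame_setting=False,
                                    video_setting=False,
                                    last_frm_setting=True,  # save one final-frame PNG per video
                                    font=None,              # None keeps matplotlib's default font family
                                    hhmmss=True,            # x-axis labelled as HH:MM:SS rather than seconds
                                    palette='Set1',
                                    clf_names=None)         # None keeps the project's full classifier list
plotter.run()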