simba-uw-tf-dev 4.6.6__py3-none-any.whl → 4.6.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. simba/assets/.recent_projects.txt +1 -0
  2. simba/data_processors/blob_location_computer.py +1 -1
  3. simba/data_processors/circling_detector.py +30 -13
  4. simba/data_processors/cuda/image.py +53 -25
  5. simba/data_processors/cuda/statistics.py +57 -19
  6. simba/data_processors/cuda/timeseries.py +1 -1
  7. simba/data_processors/egocentric_aligner.py +1 -1
  8. simba/data_processors/freezing_detector.py +54 -50
  9. simba/feature_extractors/feature_subsets.py +2 -2
  10. simba/feature_extractors/mitra_feature_extractor.py +2 -2
  11. simba/feature_extractors/straub_tail_analyzer.py +4 -4
  12. simba/labelling/standard_labeller.py +1 -1
  13. simba/mixins/config_reader.py +5 -2
  14. simba/mixins/geometry_mixin.py +8 -8
  15. simba/mixins/image_mixin.py +14 -14
  16. simba/mixins/plotting_mixin.py +28 -10
  17. simba/mixins/statistics_mixin.py +39 -9
  18. simba/mixins/timeseries_features_mixin.py +1 -1
  19. simba/mixins/train_model_mixin.py +65 -27
  20. simba/model/inference_batch.py +1 -1
  21. simba/model/yolo_seg_inference.py +3 -3
  22. simba/outlier_tools/skip_outlier_correction.py +1 -1
  23. simba/plotting/gantt_creator.py +29 -10
  24. simba/plotting/gantt_creator_mp.py +50 -17
  25. simba/plotting/heat_mapper_clf_mp.py +2 -2
  26. simba/pose_importers/simba_blob_importer.py +3 -3
  27. simba/roi_tools/roi_aggregate_stats_mp.py +1 -1
  28. simba/roi_tools/roi_clf_calculator_mp.py +1 -1
  29. simba/third_party_label_appenders/transform/coco_keypoints_to_yolo.py +3 -3
  30. simba/third_party_label_appenders/transform/coco_keypoints_to_yolo_bbox.py +2 -2
  31. simba/ui/pop_ups/clf_plot_pop_up.py +2 -2
  32. simba/ui/pop_ups/gantt_pop_up.py +31 -6
  33. simba/ui/pop_ups/video_processing_pop_up.py +1 -1
  34. simba/utils/custom_feature_extractor.py +1 -1
  35. simba/utils/data.py +2 -2
  36. simba/utils/read_write.py +32 -18
  37. simba/utils/yolo.py +10 -1
  38. simba/video_processors/blob_tracking_executor.py +2 -2
  39. simba/video_processors/clahe_ui.py +1 -1
  40. simba/video_processors/egocentric_video_rotator.py +3 -3
  41. simba/video_processors/multi_cropper.py +1 -1
  42. simba/video_processors/video_processing.py +27 -10
  43. simba/video_processors/videos_to_frames.py +2 -2
  44. {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/METADATA +3 -2
  45. {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/RECORD +49 -49
  46. {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/LICENSE +0 -0
  47. {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/WHEEL +0 -0
  48. {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/entry_points.txt +0 -0
  49. {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,7 @@ __author__ = "Simon Nilsson; sronilsson@gmail.com"
2
2
 
3
3
  import os
4
4
  import shutil
5
+ from copy import deepcopy
5
6
  from typing import List, Optional, Union
6
7
 
7
8
  import cv2
@@ -16,7 +17,7 @@ from simba.utils.checks import (
16
17
  from simba.utils.data import create_color_palette, detect_bouts
17
18
  from simba.utils.enums import Formats, Options
18
19
  from simba.utils.errors import NoSpecifiedOutputError
19
- from simba.utils.lookups import get_named_colors
20
+ from simba.utils.lookups import get_fonts, get_named_colors
20
21
  from simba.utils.printing import stdout_success
21
22
  from simba.utils.read_write import get_fn_ext, read_df
22
23
 
@@ -60,16 +61,18 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
60
61
 
61
62
  def __init__(self,
62
63
  config_path: Union[str, os.PathLike],
63
- data_paths: List[Union[str, os.PathLike]],
64
+ data_paths: Optional[Union[Union[str, os.PathLike], List[Union[str, os.PathLike]]]] = None,
64
65
  width: int = 640,
65
66
  height: int = 480,
66
67
  font_size: int = 8,
67
68
  font_rotation: int = 45,
69
+ font: Optional[str] = None,
68
70
  palette: str = 'Set1',
69
- frame_setting: Optional[bool] = False,
70
- video_setting: Optional[bool] = False,
71
- last_frm_setting: Optional[bool] = True,
72
- hhmmss: Optional[bool] = True):
71
+ frame_setting: bool = False,
72
+ video_setting: bool = False,
73
+ last_frm_setting: bool = True,
74
+ hhmmss: bool = True,
75
+ clf_names: Optional[List[str]] = None):
73
76
 
74
77
  if ((frame_setting != True) and (video_setting != True) and (last_frm_setting != True)):
75
78
  raise NoSpecifiedOutputError(msg="SIMBA ERROR: Please select gantt videos, frames, and/or last frame.")
@@ -78,7 +81,13 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
78
81
  check_int(value=height, min_value=1, name=f'{self.__class__.__name__} height')
79
82
  check_int(value=font_size, min_value=1, name=f'{self.__class__.__name__} font_size')
80
83
  check_int(value=font_rotation, min_value=0, max_value=180, name=f'{self.__class__.__name__} font_rotation')
81
- check_valid_lst(data=data_paths, source=f'{self.__class__.__name__} data_paths', valid_dtypes=(str,), min_len=1)
84
+ if isinstance(data_paths, list):
85
+ check_valid_lst(data=data_paths, source=f'{self.__class__.__name__} data_paths', valid_dtypes=(str,), min_len=1)
86
+ elif isinstance(data_paths, str):
87
+ check_file_exist_and_readable(file_path=data_paths)
88
+ data_paths = [data_paths]
89
+ else:
90
+ data_paths = deepcopy(self.machine_results_paths)
82
91
  check_valid_boolean(value=hhmmss, source=f'{self.__class__.__name__} hhmmss', raise_error=False)
83
92
  palettes = Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value
84
93
  check_str(name=f'{self.__class__.__name__} palette', value=palette, options=palettes)
@@ -90,7 +99,12 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
90
99
  if not os.path.exists(self.gantt_plot_dir): os.makedirs(self.gantt_plot_dir)
91
100
  self.frame_setting, self.video_setting, self.last_frm_setting = frame_setting, video_setting, last_frm_setting
92
101
  self.width, self.height, self.font_size, self.font_rotation = width, height, font_size, font_rotation
93
- self.data_paths, self.hhmmss = data_paths, hhmmss
102
+ if font is not None:
103
+ check_str(name=f'{self.__class__.__name__} font', value=font, options=list(get_fonts().keys()), raise_error=True)
104
+ if clf_names is not None:
105
+ check_valid_lst(data=clf_names, source=f'{self.__class__.__name__} clf_names', valid_dtypes=(str,), valid_values=self.clf_names, min_len=1, raise_error=True)
106
+ self.clf_names = clf_names
107
+ self.data_paths, self.hhmmss, self.font = data_paths, hhmmss, font
94
108
  self.colours = get_named_colors()
95
109
  self.colour_tuple_x = list(np.arange(3.5, 203.5, 5))
96
110
 
@@ -121,6 +135,7 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
121
135
  font_size=self.font_size,
122
136
  font_rotation=self.font_rotation,
123
137
  video_name=self.video_name,
138
+ font=self.font,
124
139
  save_path=os.path.join(self.gantt_plot_dir, f"{self.video_name }_final_image.png"),
125
140
  palette=self.clr_lst,
126
141
  hhmmss=self.hhmmss)
@@ -135,6 +150,7 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
135
150
  width=self.width,
136
151
  height=self.height,
137
152
  font_size=self.font_size,
153
+ font=self.font,
138
154
  font_rotation=self.font_rotation,
139
155
  video_name=self.video_name,
140
156
  palette=self.clr_lst,
@@ -156,13 +172,16 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
156
172
  # test = GanttCreatorSingleProcess(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
157
173
  # frame_setting=False,
158
174
  # video_setting=False,
159
- # data_paths=[r"C:\troubleshooting\mitra\project_folder\csv\machine_results\592_MA147_Gq_CNO_0515.csv"],
175
+ # data_paths=[r"C:\troubleshooting\mitra\project_folder\csv\machine_results\501_MA142_Gi_CNO_0516.csv"],
160
176
  # last_frm_setting=True,
161
177
  # width=640,
162
178
  # height= 480,
163
179
  # font_size=10,
180
+ # font=None,
164
181
  # font_rotation=45,
165
- # palette='Set1')
182
+ # hhmmss=False,
183
+ # palette='Set1',
184
+ # clf_names=['straub_tail'])
166
185
  # test.run()
167
186
 
168
187
 
@@ -8,16 +8,37 @@ import gc
8
8
  import multiprocessing
9
9
  import os
10
10
  import platform
11
+ import sys
11
12
  from copy import deepcopy
12
13
  from typing import List, Optional, Union
13
14
 
14
15
  import cv2
15
- import matplotlib
16
+
17
+ is_pycharm_ipython = True
18
+ try:
19
+ module_names = list(sys.modules.keys())
20
+ if any('pydev' in str(mod).lower() for mod in module_names):
21
+ is_pycharm_ipython = True
22
+ elif 'IPython' in sys.modules or 'ipython' in sys.modules:
23
+ is_pycharm_ipython = True
24
+ else:
25
+ is_pycharm_ipython = False
26
+ except Exception:
27
+ is_pycharm_ipython = True
28
+
29
+ if not is_pycharm_ipython:
30
+ if 'MPLBACKEND' not in os.environ: os.environ['MPLBACKEND'] = 'Agg'
31
+ try:
32
+ import matplotlib
33
+ matplotlib.use('Agg', force=False)
34
+ except (RecursionError, RuntimeError, ValueError, SystemError):
35
+ import matplotlib
36
+ else:
37
+ import matplotlib
38
+
16
39
  import numpy as np
17
40
  import pandas as pd
18
41
 
19
- matplotlib.use('Agg')
20
-
21
42
  from simba.mixins.config_reader import ConfigReader
22
43
  from simba.mixins.plotting_mixin import PlottingMixin
23
44
  from simba.utils.checks import (
@@ -28,6 +49,7 @@ from simba.utils.data import (create_color_palette, detect_bouts, get_cpu_pool,
28
49
  terminate_cpu_pool)
29
50
  from simba.utils.enums import Formats, Options
30
51
  from simba.utils.errors import NoSpecifiedOutputError
52
+ from simba.utils.lookups import get_fonts
31
53
  from simba.utils.printing import SimbaTimer, stdout_success
32
54
  from simba.utils.read_write import (concatenate_videos_in_folder,
33
55
  create_directory, find_core_cnt,
@@ -51,6 +73,7 @@ def gantt_creator_mp(data: np.array,
51
73
  width: int,
52
74
  height: int,
53
75
  font_size: int,
76
+ font: str,
54
77
  font_rotation: int,
55
78
  palette: np.ndarray,
56
79
  hhmmss: bool):
@@ -72,6 +95,7 @@ def gantt_creator_mp(data: np.array,
72
95
  width=width,
73
96
  height=height,
74
97
  font_size=font_size,
98
+ font=font,
75
99
  font_rotation=font_rotation,
76
100
  video_name=video_name,
77
101
  save_path=None,
@@ -139,9 +163,11 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
139
163
  height: int = 480,
140
164
  font_size: int = 8,
141
165
  font_rotation: int = 45,
166
+ font: Optional[str] = None,
142
167
  palette: str = 'Set1',
143
168
  core_cnt: int = -1,
144
- hhmmss: bool = False):
169
+ hhmmss: bool = False,
170
+ clf_names: Optional[List[str]] = None):
145
171
 
146
172
  check_file_exist_and_readable(file_path=config_path)
147
173
  if (not frame_setting) and (not video_setting) and (not last_frm_setting):
@@ -157,6 +183,8 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
157
183
  check_int(name=f"{self.__class__.__name__} core_cnt",value=core_cnt, min_value=-1, unaccepted_vals=[0], max_value=find_core_cnt()[0])
158
184
  self.core_cnt = find_core_cnt()[0] if core_cnt == -1 or core_cnt > find_core_cnt()[0] else core_cnt
159
185
  self.width, self.height, self.font_size, self.font_rotation, self.hhmmss = width, height, font_size, font_rotation, hhmmss
186
+ if font is not None:
187
+ check_str(name=f'{self.__class__.__name__} font', value=font, options=list(get_fonts().keys()), raise_error=True)
160
188
  ConfigReader.__init__(self, config_path=config_path, create_logger=False)
161
189
  if isinstance(data_paths, list):
162
190
  check_valid_lst(data=data_paths, source=f'{self.__class__.__name__} data_paths', valid_dtypes=(str,), min_len=1)
@@ -166,9 +194,12 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
166
194
  else:
167
195
  data_paths = deepcopy(self.machine_results_paths)
168
196
  for file_path in data_paths: check_file_exist_and_readable(file_path=file_path)
197
+ if clf_names is not None:
198
+ check_valid_lst(data=clf_names, source=f'{self.__class__.__name__} clf_names', valid_dtypes=(str,), valid_values=self.clf_names, min_len=1, raise_error=True)
199
+ self.clf_names = clf_names
169
200
  PlottingMixin.__init__(self)
170
201
  self.clr_lst = create_color_palette(pallete_name=palette, increments=len(self.body_parts_lst) + 1, as_int=True, as_rgb_ratio=True)
171
- self.frame_setting, self.video_setting, self.data_paths, self.last_frm_setting = frame_setting, video_setting,data_paths, last_frm_setting
202
+ self.frame_setting, self.video_setting, self.data_paths, self.last_frm_setting, self.font = frame_setting, video_setting,data_paths, last_frm_setting, font
172
203
  if not os.path.exists(self.gantt_plot_dir): os.makedirs(self.gantt_plot_dir)
173
204
  if platform.system() == "Darwin":
174
205
  multiprocessing.set_start_method("spawn", force=True)
@@ -206,6 +237,7 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
206
237
  font_size=self.font_size,
207
238
  font_rotation=self.font_rotation,
208
239
  video_name=self.video_name,
240
+ font=self.font,
209
241
  save_path=os.path.join(self.gantt_plot_dir, f"{self.video_name}_final_image.png"),
210
242
  palette=self.clr_lst,
211
243
  hhmmss=self.hhmmss)
@@ -225,6 +257,7 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
225
257
  width=self.width,
226
258
  height=self.height,
227
259
  font_size=self.font_size,
260
+ font=self.font,
228
261
  font_rotation=self.font_rotation,
229
262
  video_name=self.video_name,
230
263
  palette=self.clr_lst,
@@ -254,18 +287,18 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
254
287
  # font_rotation= 45)
255
288
  # test.run()
256
289
 
257
- if __name__ == "__main__":
258
- test = GanttCreatorMultiprocess(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini",
259
- frame_setting=False,
260
- video_setting=True,
261
- data_paths=r"D:\troubleshooting\maplight_ri\project_folder\csv\machine_results\Trial_1_C24_D1_1.csv",
262
- last_frm_setting=False,
263
- width=640,
264
- height= 480,
265
- font_size=10,
266
- font_rotation= 45,
267
- core_cnt=16)
268
- test.run()
290
+ # if __name__ == "__main__":
291
+ # test = GanttCreatorMultiprocess(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini",
292
+ # frame_setting=False,
293
+ # video_setting=True,
294
+ # data_paths=r"D:\troubleshooting\maplight_ri\project_folder\csv\machine_results\Trial_1_C24_D1_1.csv",
295
+ # last_frm_setting=False,
296
+ # width=640,
297
+ # height= 480,
298
+ # font_size=10,
299
+ # font_rotation= 45,
300
+ # core_cnt=16)
301
+ # test.run()
269
302
 
270
303
 
271
304
  # test = GanttCreatorMultiprocess(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini',
@@ -98,14 +98,14 @@ class HeatMapperClfMultiprocess(ConfigReader, PlottingMixin):
98
98
 
99
99
 
100
100
  :example II:
101
- >>> test = HeatMapperClfMultiprocess(config_path=r"C:\troubleshooting\RAT_NOR\project_folder\project_config.ini",
101
+ >>> test = HeatMapperClfMultiprocess(config_path=r"C:/troubleshooting/RAT_NOR/project_folder/project_config.ini",
102
102
  >>> style_attr = {'palette': 'jet', 'shading': 'gouraud', 'bin_size': 50, 'max_scale': 'auto'},
103
103
  >>> final_img_setting=True,
104
104
  >>> video_setting=True,
105
105
  >>> frame_setting=True,
106
106
  >>> bodypart='Ear_left',
107
107
  >>> clf_name='straub_tail',
108
- >>> data_paths=[r"C:\troubleshooting\RAT_NOR\project_folder\csv\test\2022-06-20_NOB_DOT_4.csv"])
108
+ >>> data_paths=[r"C:/troubleshooting/RAT_NOR/project_folder/csv/test/2022-06-20_NOB_DOT_4.csv"])
109
109
  >>> test.run()
110
110
  """
111
111
 
@@ -44,10 +44,10 @@ class SimBABlobImporter(ConfigReader):
44
44
  :param Optional[bool] verbose: If True, prints progress messages. Default: True.
45
45
 
46
46
  :example:
47
- >>> r = SimBABlobImporter(config_path=r"C:\troubleshooting\simba_blob_project\project_folder\project_config.ini", data_path=r'C:\troubleshooting\simba_blob_project\data')
47
+ >>> r = SimBABlobImporter(config_path=r"C:/troubleshooting/simba_blob_project/project_folder/project_config.ini", data_path=r'C:/troubleshooting/simba_blob_project/data')
48
48
  >>> r.run()
49
- >>> r = SimBABlobImporter(config_path=r"C:\troubleshooting\simba_blob_project\project_folder\project_config.ini",
50
- ... data_path=r'C:\troubleshooting\simba_blob_project\data',
49
+ >>> r = SimBABlobImporter(config_path=r"C:/troubleshooting/simba_blob_project/project_folder/project_config.ini",
50
+ ... data_path=r'C:/troubleshooting/simba_blob_project/data',
51
51
  ... smoothing_settings={'method': 'savitzky-golay', 'time_window': 100},
52
52
  ... interpolation_settings={'method': 'nearest', 'type': 'body-parts'})
53
53
  >>> r.run()
@@ -168,7 +168,7 @@ class ROIAggregateStatisticsAnalyzerMultiprocess(ConfigReader, FeatureExtraction
168
168
  :param save_path (str | os.PathLike, optional): Path to save summary statistics.
169
169
 
170
170
  :example:
171
- >>> analyzer = ROIAggregateStatisticsAnalyzerMultiprocess(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini", body_parts=['Center'], first_entry_time=True, threshold=0.0, calculate_distances=True, transpose=False, detailed_bout_data=True)
171
+ >>> analyzer = ROIAggregateStatisticsAnalyzerMultiprocess(config_path=r"C:/troubleshooting/mitra/project_folder/project_config.ini", body_parts=['Center'], first_entry_time=True, threshold=0.0, calculate_distances=True, transpose=False, detailed_bout_data=True)
172
172
  >>> analyzer.run()
173
173
  >>> analyzer.save()
174
174
  """
@@ -150,7 +150,7 @@ class ROIClfCalculatorMultiprocess(ConfigReader):
150
150
  `GitHub tutorial <https://github.com/sgoldenlab/simba/blob/master/docs/Scenario2.md#part-4--analyze-machine-results>`__.
151
151
 
152
152
  :example:
153
- >>> analyzer = ROIClfCalculatorMultiprocess(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini", bp_names=['resident_NOSE'], clf_names=['attack'], clf_time=True, started_bout_cnt=True, ended_bout_cnt=False, bout_table=True, transpose=True, core_cnt=20)
153
+ >>> analyzer = ROIClfCalculatorMultiprocess(config_path=r"D:/troubleshooting/maplight_ri/project_folder/project_config.ini", bp_names=['resident_NOSE'], clf_names=['attack'], clf_time=True, started_bout_cnt=True, ended_bout_cnt=False, bout_table=True, transpose=True, core_cnt=20)
154
154
  >>> analyzer.run()
155
155
  >>> analyzer.save()
156
156
  """
@@ -54,15 +54,15 @@ class COCOKeypoints2Yolo:
54
54
  :return: None
55
55
 
56
56
  :example:
57
- >>> runner = COCOKeypoints2Yolo(coco_path=r"D:\cvat_annotations\frames\coco_keypoints_1\s1\annotations\s1.json", img_dir=r"D:\cvat_annotations\frames\simon", save_dir=r"D:\cvat_annotations\frames\yolo_keypoints", clahe=True)
57
+ >>> runner = COCOKeypoints2Yolo(coco_path=r"D:/cvat_annotations/frames/coco_keypoints_1/s1/annotations/s1.json", img_dir=r"D:/cvat_annotations/frames/simon", save_dir=r"D:/cvat_annotations/frames/yolo_keypoints", clahe=True)
58
58
  >>> runner.run()
59
59
 
60
60
  :example II:
61
- >>> runner = COCOKeypoints2Yolo(coco_path=r"D:\cvat_annotations\frames\coco_keypoints_1\merged.json", img_dir=r"D:\cvat_annotations\frames", save_dir=r"D:\cvat_annotations\frames\yolo", clahe=False)
61
+ >>> runner = COCOKeypoints2Yolo(coco_path=r"D:/cvat_annotations/frames/coco_keypoints_1/merged.json", img_dir=r"D:/cvat_annotations/frames", save_dir=r"D:/cvat_annotations/frames/yolo", clahe=False)
62
62
  >>> runner.run()
63
63
 
64
64
  :example III:
65
- >>> runner = COCOKeypoints2Yolo(coco_path=r"E:\netholabs_videos\mosaics\subset\to_annotate\2d_mosaic_batch_1.json", img_dir=r"E:\netholabs_videos\mosaics\subset\to_annotate", save_dir=r"E:\netholabs_videos\mosaics\yolo_mdl", clahe=False)
65
+ >>> runner = COCOKeypoints2Yolo(coco_path=r"E:/netholabs_videos/mosaics/subset/to_annotate/2d_mosaic_batch_1.json", img_dir=r"E:/netholabs_videos/mosaics/subset/to_annotate", save_dir=r"E:/netholabs_videos/mosaics/yolo_mdl", clahe=False)
66
66
  >>> runner.run()
67
67
 
68
68
  :references:
@@ -58,11 +58,11 @@ class COCOKeypoints2YoloBbox:
58
58
  :return: None
59
59
 
60
60
  :example:
61
- >>> runner = COCOKeypoints2Yolo(coco_path=r"D:\cvat_annotations\frames\coco_keypoints_1\s1\annotations\s1.json", img_dir=r"D:\cvat_annotations\frames\simon", save_dir=r"D:\cvat_annotations\frames\yolo_keypoints", clahe=True)
61
+ >>> runner = COCOKeypoints2YoloBbox(coco_path=r"D:/cvat_annotations/frames/coco_keypoints_1/s1/annotations/s1.json", img_dir=r"D:/cvat_annotations/frames/simon", save_dir=r"D:/cvat_annotations/frames/yolo_keypoints", clahe=True)
62
62
  >>> runner.run()
63
63
 
64
64
  :example II:
65
- >>> runner = COCOKeypoints2Yolo(coco_path=r"D:\cvat_annotations\frames\coco_keypoints_1\merged.json", img_dir=r"D:\cvat_annotations\frames", save_dir=r"D:\cvat_annotations\frames\yolo", clahe=False)
65
+ >>> runner = COCOKeypoints2YoloBbox(coco_path=r"D:/cvat_annotations/frames/coco_keypoints_1/merged.json", img_dir=r"D:/cvat_annotations/frames", save_dir=r"D:/cvat_annotations/frames/yolo", clahe=False)
66
66
  >>> runner.run()
67
67
 
68
68
  :references:
@@ -49,8 +49,8 @@ class SklearnVisualizationPopUp(PopUpMixin, ConfigReader):
49
49
  pose_palettes = Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value
50
50
  PopUpMixin.__init__(self, title="VISUALIZE CLASSIFICATION (SKLEARN) RESULTS", icon='photos')
51
51
  bp_threshold_frm = CreateLabelFrameWithIcon(parent=self.main_frm, header="BODY-PART VISUALIZATION THRESHOLD", icon_name='threshold', icon_link=Links.SKLEARN_PLOTS.value, padx=5, pady=5, relief='solid')
52
- self.bp_threshold_lbl = SimBALabel(parent=bp_threshold_frm, txt="Body-parts detected below the set threshold won't be shown in the output videos.", font=Formats.FONT_REGULAR_ITALICS.value)
53
- self.bp_threshold_entry = Entry_Box(parent=bp_threshold_frm, fileDescription='BODY-PART PROBABILITY THRESHOLD: ', labelwidth=40, entry_box_width=15, value=0.00, img='green_dice')
52
+ self.bp_threshold_lbl = SimBALabel(parent=bp_threshold_frm, txt="Body-parts detected below the set threshold won't be shown in the output videos (use 0.0 to see all body-part predictions)", font=Formats.FONT_REGULAR_ITALICS.value)
53
+ self.bp_threshold_entry = Entry_Box(parent=bp_threshold_frm, fileDescription='BODY-PART PROBABILITY THRESHOLD: ', labelwidth=40, entry_box_width=15, value=0.00, img='green_dice', justify='center')
54
54
  self.get_bp_probability_threshold()
55
55
 
56
56
  bp_threshold_frm.grid(row=0, column=0, sticky=NW)
@@ -13,6 +13,7 @@ from simba.ui.tkinter_functions import (CreateLabelFrameWithIcon, SimbaButton,
13
13
  from simba.utils.checks import check_if_filepath_list_is_empty
14
14
  from simba.utils.enums import Formats, Links, Options
15
15
  from simba.utils.errors import NoSpecifiedOutputError
16
+ from simba.utils.lookups import get_fonts
16
17
  from simba.utils.read_write import find_files_of_filetypes_in_directory
17
18
 
18
19
 
@@ -29,7 +30,9 @@ class GanttPlotPopUp(PopUpMixin, ConfigReader):
29
30
  check_if_filepath_list_is_empty(filepaths=self.machine_results_paths,error_msg=f"SIMBA ERROR: Zero files found in the {self.machine_results_dir} directory. Create classification results before visualizing gantt charts",)
30
31
  palettes = Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value
31
32
  self.data_paths = find_files_of_filetypes_in_directory(directory=self.machine_results_dir, extensions=[f'.{self.file_type}'], as_dict=True)
32
- max_file_name_len = max(len(k) for k in self.data_paths) + 5
33
+ max_file_name_len, fonts = max(len(k) for k in self.data_paths) + 5, list(get_fonts(sort_alphabetically=True).keys())
34
+ fonts.insert(0, 'AUTO')
35
+ default_font = 'Arial' if 'Arial' in fonts else 'AUTO'
33
36
  PopUpMixin.__init__(self, config_path=config_path, title="VISUALIZE GANTT PLOTS", icon='gantt_small')
34
37
 
35
38
  self.style_settings_frm = CreateLabelFrameWithIcon(parent=self.main_frm, header="STYLE SETTINGS", icon_name='settings', icon_link=Links.GANTT_PLOTS.value, relief='solid', padx=5, pady=5)
@@ -39,6 +42,15 @@ class GanttPlotPopUp(PopUpMixin, ConfigReader):
39
42
  self.palette_dropdown = SimBADropDown(parent=self.style_settings_frm, dropdown_options=palettes, label='COLOR PALETTE: ', label_width=30, dropdown_width=30, value='Set1', img='palette_small')
40
43
  self.time_format_dropdown = SimBADropDown(parent=self.style_settings_frm, dropdown_options=['SECONDS', 'HH:MM:SS'], label='X-AXIS TIME FORMAT: ', label_width=30, dropdown_width=30, value='SECONDS', img='timer_2')
41
44
  self.core_dropdown = SimBADropDown(parent=self.style_settings_frm, dropdown_options=list(range(1, self.cpu_cnt+1)), label='CPU CORES: ', label_width=30, dropdown_width=30, value=int(self.cpu_cnt/2), img='cpu_small')
45
+ self.font_dropdown = SimBADropDown(parent=self.style_settings_frm, dropdown_options=fonts, label='FONT: ', label_width=30, dropdown_width=30, value=default_font, img='font')
46
+
47
+
48
+ self.clf_frm = CreateLabelFrameWithIcon(parent=self.main_frm, header="BEHAVIORS", icon_name='forest', icon_link=Links.GANTT_PLOTS.value, relief='solid', padx=5, pady=5)
49
+ self.clf_choices = {}
50
+ for cnt, clf_name in enumerate(self.clf_names):
51
+ gantt_frames_cb, self.gantt_frames_var = SimbaCheckbox(parent=self.clf_frm, txt=clf_name, val=True)
52
+ self.clf_choices[clf_name] = self.gantt_frames_var
53
+ gantt_frames_cb.grid(row=cnt, column=0, sticky=NW)
42
54
 
43
55
  self.settings_frm = CreateLabelFrameWithIcon(parent=self.main_frm, header="VISUALIZATION SETTINGS", icon_name='eye', icon_link=Links.GANTT_PLOTS.value, relief='solid', padx=5, pady=5)
44
56
  gantt_frames_cb, self.gantt_frames_var = SimbaCheckbox(parent=self.settings_frm, txt='CREATE FRAMES', txt_img='frames', val=False)
@@ -57,18 +69,20 @@ class GanttPlotPopUp(PopUpMixin, ConfigReader):
57
69
  self.font_rotation_dropdown.grid(row=2, sticky=NW)
58
70
  self.palette_dropdown.grid(row=3, sticky=NW)
59
71
  self.time_format_dropdown.grid(row=4, sticky=NW)
60
- self.core_dropdown.grid(row=5, sticky=NW)
72
+ self.font_dropdown.grid(row=5, sticky=NW)
73
+ self.core_dropdown.grid(row=6, sticky=NW)
61
74
 
62
- self.settings_frm.grid(row=1, sticky=NW, padx=10, pady=10)
75
+ self.clf_frm.grid(row=1, sticky=NW, padx=10, pady=10)
76
+ self.settings_frm.grid(row=2, sticky=NW, padx=10, pady=10)
63
77
  gantt_videos_cb.grid(row=0, sticky=NW)
64
78
  gantt_frames_cb.grid(row=1, sticky=W)
65
79
  gantt_last_frame_cb.grid(row=2, sticky=NW)
66
80
 
67
- self.run_single_video_frm.grid(row=2, sticky=NW)
81
+ self.run_single_video_frm.grid(row=3, sticky=NW)
68
82
  self.run_single_video_btn.grid(row=0, sticky=NW)
69
83
  self.single_video_dropdown.grid(row=0, column=1, sticky=NW)
70
84
 
71
- self.run_multiple_videos.grid(row=3, sticky=NW)
85
+ self.run_multiple_videos.grid(row=4, sticky=NW)
72
86
  self.run_multiple_video_btn.grid(row=0, sticky=NW)
73
87
 
74
88
  self.main_frm.mainloop()
@@ -85,8 +99,14 @@ class GanttPlotPopUp(PopUpMixin, ConfigReader):
85
99
  video_setting = self.gantt_videos_var.get()
86
100
  last_frm_setting = self.gantt_last_frame_var.get()
87
101
  palette = self.palette_dropdown.get_value()
102
+ font = self.font_dropdown.get_value()
103
+ font = None if font == 'AUTO' else font
88
104
  hhmmss = True if self.time_format_dropdown.get_value() == 'HH:MM:SS' else False
89
-
105
+ clf_names = []
106
+ for clf_name, clf_val in self.clf_choices.items():
107
+ if clf_val.get(): clf_names.append(clf_name)
108
+ if len(clf_names) < 1:
109
+ raise NoSpecifiedOutputError(msg="Select AT LEAST one behavior name.")
90
110
  if not frame_setting and not video_setting and not last_frm_setting:
91
111
  raise NoSpecifiedOutputError(msg="SIMBA ERROR: Please select gantt videos, frames, and/or last frame.")
92
112
 
@@ -103,6 +123,8 @@ class GanttPlotPopUp(PopUpMixin, ConfigReader):
103
123
  data_paths=data_paths,
104
124
  width=width,
105
125
  height=height,
126
+ font=font,
127
+ clf_names=clf_names,
106
128
  font_size=font_size,
107
129
  font_rotation=font_rotation,
108
130
  core_cnt=core_cnt,
@@ -116,7 +138,9 @@ class GanttPlotPopUp(PopUpMixin, ConfigReader):
116
138
  last_frm_setting=last_frm_setting,
117
139
  data_paths=data_paths,
118
140
  width=width,
141
+ font=font,
119
142
  height=height,
143
+ clf_names=clf_names,
120
144
  font_size=font_size,
121
145
  font_rotation=font_rotation,
122
146
  palette=palette)
@@ -124,6 +148,7 @@ class GanttPlotPopUp(PopUpMixin, ConfigReader):
124
148
 
125
149
 
126
150
  #_ = GanttPlotPopUp(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini")
151
+ #_ = GanttPlotPopUp(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
127
152
  # _ = GanttPlotPopUp(config_path=r"/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini")
128
153
 
129
154
  # _ = GanttPlotPopUp(config_path=r"C:\troubleshooting\RAT_NOR\project_folder\project_config.ini")
@@ -1708,7 +1708,7 @@ class ClipMultipleVideosByTimestamps(PopUpMixin):
1708
1708
  check_that_hhmmss_start_is_before_end(start_time=start, end_time=end, name=video_name)
1709
1709
  check_if_hhmmss_timestamp_is_valid_part_of_video(timestamp=start, video_path=self.video_paths[video_name])
1710
1710
  check_if_hhmmss_timestamp_is_valid_part_of_video(timestamp=end, video_path=self.video_paths[video_name])
1711
- clip_video_in_range(file_path=self.video_paths[video_name], start_time=start, end_time=end, out_dir=self.save_dir, overwrite=True, include_clip_time_in_filename=False, gpu=gpu, quality=quality_pct)
1711
+ clip_video_in_range(file_path=self.video_paths[video_name], start_time=start, end_time=end, out_dir=self.save_dir, overwrite=True, include_clip_time_in_filename=False, gpu=gpu, quality=quality_pct, codec='libx264', verbose=True)
1712
1712
  timer.stop_timer()
1713
1713
  stdout_success(msg=f"{len(self.entry_boxes)} videos clipped by time-stamps and saved in {self.save_dir}", elapsed_time=timer.elapsed_time_str,)
1714
1714
 
@@ -30,7 +30,7 @@ class CustomFeatureExtractor(ConfigReader):
30
30
  4. Handle cases of multiple classes and missing configuration arguments.
31
31
  5. Invokes the feature extraction process if conditions are met.
32
32
 
33
- .. notes::
33
+ .. note::
34
34
 
35
35
  `Tutorial <https://github.com/sgoldenlab/simba/blob/master/docs/extractFeatures.md>`_.
36
36
 
simba/utils/data.py CHANGED
@@ -1785,8 +1785,8 @@ def fft_lowpass_filter(data: np.ndarray, cut_off: float = 0.1) -> np.ndarray:
1785
1785
 
1786
1786
  :example:
1787
1787
  >>> from simba.utils.read_write import read_df
1788
- >>> IN_PATH = r"C:\troubleshooting\RAT_NOR\project_folder\csv\outlier_corrected_movement_location\2022-06-20_NOB_DOT_4.csv"
1789
- >>> OUT_PATH = r"C:\troubleshooting\RAT_NOR\project_folder\csv\outlier_corrected_movement_location\2022-06-20_NOB_DOT_4_filtered.csv"
1788
+ >>> IN_PATH = r"C:/troubleshooting/RAT_NOR/project_folder/csv/outlier_corrected_movement_location/2022-06-20_NOB_DOT_4.csv"
1789
+ >>> OUT_PATH = r"C:/troubleshooting/RAT_NOR/project_folder/csv/outlier_corrected_movement_location/2022-06-20_NOB_DOT_4_filtered.csv"
1790
1790
  >>> df = read_df(file_path=IN_PATH)
1791
1791
  >>> data = df.values
1792
1792
  >>> x = fft_lowpass_filter(data=data, cut_off=0.1)
simba/utils/read_write.py CHANGED
@@ -97,13 +97,20 @@ def read_df(file_path: Union[str, os.PathLike],
97
97
  .. note::
98
98
  For improved runtime, defaults to :external:py:meth:`pyarrow.csv.read_csv` if file type is ``csv``.
99
99
 
100
- :parameter str file_path: Path to data file
101
- :parameter str file_type: Type of data. OPTIONS: 'parquet', 'csv', 'pickle'.
102
- :parameter Optional[bool]: If the input file has an initial index column. Default: True.
103
- :parameter Optional[List[str]] remove_columns: If not None, then remove columns in lits.
104
- :parameter Optional[List[str]] usecols: If not None, then keep columns in list.
105
- :parameter bool check_multiindex: check file is multi-index headers. Default: False.
106
- :parameter int multi_index_headers_to_keep: If reading multi-index file, and we want to keep one of the dropped multi-index levels as the header in the output file, specify the index of the multiindex hader as int.
100
+ .. csv-table::
101
+ :header: EXPECTED RUNTIMES
102
+ :file: ../../docs/tables/read_df.csv
103
+ :widths: 10, 45, 45
104
+ :align: center
105
+ :header-rows: 1
106
+
107
+ :param str file_path: Path to data file
108
+ :param str file_type: Type of data. OPTIONS: 'parquet', 'csv', 'pickle'.
109
+ :param Optional[bool]: If the input file has an initial index column. Default: True.
110
+ :param Optional[List[str]] remove_columns: If not None, then remove columns in list.
111
+ :param Optional[List[str]] usecols: If not None, then keep columns in list.
112
+ :param bool check_multiindex: check file is multi-index headers. Default: False.
113
+ :param int multi_index_headers_to_keep: If reading multi-index file, and we want to keep one of the dropped multi-index levels as the header in the output file, specify the index of the multi-index header as int.
107
114
  :return: Table data in pd.DataFrame format.
108
115
  :rtype: pd.DataFrame
109
116
 
@@ -207,11 +214,18 @@ def write_df(df: pd.DataFrame,
207
214
  .. note::
208
215
  For improved runtime, defaults to ``pyarrow.csv`` if file_type == ``csv``.
209
216
 
210
- :parameter pd.DataFrame df: Pandas dataframe to save to disk.
211
- :parameter str file_type: Type of data. OPTIONS: ``parquet``, ``csv``, ``pickle``.
212
- :parameter str save_path: Location where to store the data.
213
- :parameter bool check_multiindex: check if input file is multi-index headers. Default: False.
214
- :parameter bool verbose: Prints message on completion. Default: False.
217
+ .. csv-table::
218
+ :header: EXPECTED RUNTIMES
219
+ :file: ../../docs/tables/write_df.csv
220
+ :widths: 10, 45, 45
221
+ :align: center
222
+ :header-rows: 1
223
+
224
+ :param pd.DataFrame df: Pandas dataframe to save to disk.
225
+ :param str file_type: Type of data. OPTIONS: ``parquet``, ``csv``, ``pickle``.
226
+ :param str save_path: Location where to store the data.
227
+ :param bool check_multiindex: check if input file is multi-index headers. Default: False.
228
+ :param bool verbose: Prints message on completion. Default: False.
215
229
 
216
230
  :example:
217
231
  >>> write_df(df=df, file_type='csv', save_path='project_folder/csv/input_csv/Video_1.csv')
@@ -1130,8 +1144,8 @@ def get_file_name_info_in_directory(directory: Union[str, os.PathLike], file_typ
1130
1144
  :return dict: All found files as values and file base names as keys.
1131
1145
 
1132
1146
  :example:
1133
- >>> get_file_name_info_in_directory(directory='C:\project_folder\csv\machine_results', file_type='csv')
1134
- >>> {'Video_1': 'C:\project_folder\csv\machine_results\Video_1'}
1147
+ >>> get_file_name_info_in_directory(directory='C:/project_folder/csv/machine_results', file_type='csv')
1148
+ >>> {'Video_1': 'C:/project_folder/csv/machine_results/Video_1'}
1135
1149
  """
1136
1150
 
1137
1151
  results = {}
@@ -2630,7 +2644,7 @@ def bento_file_reader(file_path: Union[str, os.PathLike],
2630
2644
  :rtype: Dict[str, pd.DataFrame]
2631
2645
 
2632
2646
  :example:
2633
- >>> bento_file_reader(file_path=r"C:\troubleshooting\bento_test\bento_files\20240812_crumpling3.annot")
2647
+ >>> bento_file_reader(file_path=r"C:/troubleshooting/bento_test/bento_files/20240812_crumpling3.annot")
2634
2648
  """
2635
2649
 
2636
2650
  def _orient_columns_melt(df: pd.DataFrame) -> pd.DataFrame:
@@ -2953,7 +2967,7 @@ def labelme_to_dlc(labelme_dir: Union[str, os.PathLike],
2953
2967
  :return: None
2954
2968
 
2955
2969
  :example:
2956
- >>> labelme_dir = r'D:\ts_annotations'
2970
+ >>> labelme_dir = r'D:/ts_annotations'
2957
2971
  >>> labelme_to_dlc(labelme_dir=labelme_dir)
2958
2972
  """
2959
2973
 
@@ -3597,8 +3611,8 @@ def osf_download(project_id: str, save_dir: Union[str, os.PathLike], storage: st
3597
3611
  :param bool overwrite: If True, overwrite existing files. If False, skip existing files (default: False).
3598
3612
 
3599
3613
  :example:
3600
- >>> osf_download(project_id="7fgwn", save_dir=r'E:\rgb_white_vs_black_imgs')
3601
- >>> osf_download(project_id="kym42", save_dir=r'E:\crim13_imgs', overwrite=True)
3614
+ >>> osf_download(project_id="7fgwn", save_dir=r'E:/rgb_white_vs_black_imgs')
3615
+ >>> osf_download(project_id="kym42", save_dir=r'E:/crim13_imgs', overwrite=True)
3602
3616
  """
3603
3617
 
3604
3618
  _ = get_pkg_version(pkg='osfclient', raise_error=True)
simba/utils/yolo.py CHANGED
@@ -47,6 +47,9 @@ def fit_yolo(weights_path: Union[str, os.PathLike],
47
47
  `Download initial weights <https://huggingface.co/Ultralytics>`__.
48
48
  `Example model_yaml <https://github.com/sgoldenlab/simba/blob/master/misc/ex_yolo_model.yaml>`__.
49
49
 
50
+ .. seealso::
51
+ For the recommended wrapper class with parameter validation, see :class:`simba.model.yolo_fit.FitYolo`.
52
+
50
53
  :param weights_path: Path to the pre-trained YOLO model weights (usually a `.pt` file). Example weights can be found `here <https://huggingface.co/Ultralytics>`__.
51
54
  :param model_yaml: YAML file containing paths to the training, validation, and testing datasets and the object class mappings. Example YAML file can be found `here <https://github.com/sgoldenlab/simba/blob/master/misc/ex_yolo_model.yaml>`__.
52
55
  :param save_path: Directory path where the trained model, logs, and results will be saved.
@@ -55,7 +58,7 @@ def fit_yolo(weights_path: Union[str, os.PathLike],
55
58
  :return: None. The trained model and associated training logs are saved in the specified `save_path`.
56
59
 
57
60
  :example:
58
- >>> fit_yolo(initial_weights=r"C:\troubleshooting\coco_data\weights\yolov8n-obb.pt", data=r"C:\troubleshooting\coco_data\model.yaml", save_path=r"C:\troubleshooting\coco_data\mdl", batch=16)
61
+ >>> fit_yolo(initial_weights=r"C:/troubleshooting/coco_data/weights/yolov8n-obb.pt", data=r"C:/troubleshooting/coco_data/model.yaml", save_path=r"C:/troubleshooting/coco_data/mdl", batch=16)
59
62
  """
60
63
 
61
64
  if not _is_cuda_available()[0]:
@@ -83,6 +86,9 @@ def load_yolo_model(weights_path: Union[str, os.PathLike],
83
86
  """
84
87
  Load a YOLO model.
85
88
 
89
+ .. seealso::
90
+ For recommended wrapper classes that use this function, see :class:`simba.model.yolo_fit.FitYolo`, :class:`simba.model.yolo_inference.YoloInference`, :class:`simba.model.yolo_pose_inference.YOLOPoseInference`, :class:`simba.model.yolo_seg_inference.YOLOSegmentationInference`, and :class:`simba.model.yolo_pose_track_inference.YOLOPoseTrackInference`.
91
+
86
92
  :param Union[str, os.PathLike] weights_path: Path to model weights (.pt, .engine, etc).
87
93
  :param bool verbose: Whether to print loading info.
88
94
  :param Optional[str] format: Export format, one of VALID_FORMATS or None to skip export.
@@ -169,6 +175,9 @@ def yolo_predict(model: YOLO,
169
175
  """
170
176
  Produce YOLO predictions.
171
177
 
178
+ .. seealso::
179
+ For recommended wrapper classes that use this function, see :class:`simba.model.yolo_inference.YoloInference`, :class:`simba.model.yolo_pose_inference.YOLOPoseInference`, and :class:`simba.model.yolo_seg_inference.YOLOSegmentationInference`.
180
+
172
181
  :param YOLO model: Loaded ultralytics.YOLO model. Returned by :func:`~simba.utils.yolo.load_yolo_model`.
173
182
  :param Union[str, os.PathLike, np.ndarray] source: Path to video, video stream, directory, image, or image as loaded array.
174
183
  :param bool half: Whether to use half precision (FP16) for inference to speed up processing.