simba-uw-tf-dev 4.6.6__py3-none-any.whl → 4.6.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- simba/assets/.recent_projects.txt +1 -0
- simba/data_processors/blob_location_computer.py +1 -1
- simba/data_processors/circling_detector.py +30 -13
- simba/data_processors/cuda/image.py +53 -25
- simba/data_processors/cuda/statistics.py +57 -19
- simba/data_processors/cuda/timeseries.py +1 -1
- simba/data_processors/egocentric_aligner.py +1 -1
- simba/data_processors/freezing_detector.py +54 -50
- simba/feature_extractors/feature_subsets.py +2 -2
- simba/feature_extractors/mitra_feature_extractor.py +2 -2
- simba/feature_extractors/straub_tail_analyzer.py +4 -4
- simba/labelling/standard_labeller.py +1 -1
- simba/mixins/config_reader.py +5 -2
- simba/mixins/geometry_mixin.py +8 -8
- simba/mixins/image_mixin.py +14 -14
- simba/mixins/plotting_mixin.py +28 -10
- simba/mixins/statistics_mixin.py +39 -9
- simba/mixins/timeseries_features_mixin.py +1 -1
- simba/mixins/train_model_mixin.py +65 -27
- simba/model/inference_batch.py +1 -1
- simba/model/yolo_seg_inference.py +3 -3
- simba/outlier_tools/skip_outlier_correction.py +1 -1
- simba/plotting/gantt_creator.py +29 -10
- simba/plotting/gantt_creator_mp.py +50 -17
- simba/plotting/heat_mapper_clf_mp.py +2 -2
- simba/pose_importers/simba_blob_importer.py +3 -3
- simba/roi_tools/roi_aggregate_stats_mp.py +1 -1
- simba/roi_tools/roi_clf_calculator_mp.py +1 -1
- simba/third_party_label_appenders/transform/coco_keypoints_to_yolo.py +3 -3
- simba/third_party_label_appenders/transform/coco_keypoints_to_yolo_bbox.py +2 -2
- simba/ui/pop_ups/clf_plot_pop_up.py +2 -2
- simba/ui/pop_ups/gantt_pop_up.py +31 -6
- simba/ui/pop_ups/video_processing_pop_up.py +1 -1
- simba/utils/custom_feature_extractor.py +1 -1
- simba/utils/data.py +2 -2
- simba/utils/read_write.py +32 -18
- simba/utils/yolo.py +10 -1
- simba/video_processors/blob_tracking_executor.py +2 -2
- simba/video_processors/clahe_ui.py +1 -1
- simba/video_processors/egocentric_video_rotator.py +3 -3
- simba/video_processors/multi_cropper.py +1 -1
- simba/video_processors/video_processing.py +27 -10
- simba/video_processors/videos_to_frames.py +2 -2
- {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/METADATA +3 -2
- {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/RECORD +49 -49
- {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/LICENSE +0 -0
- {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/WHEEL +0 -0
- {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/entry_points.txt +0 -0
- {simba_uw_tf_dev-4.6.6.dist-info → simba_uw_tf_dev-4.6.8.dist-info}/top_level.txt +0 -0
simba/plotting/gantt_creator.py
CHANGED
|
@@ -2,6 +2,7 @@ __author__ = "Simon Nilsson; sronilsson@gmail.com"
|
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
4
|
import shutil
|
|
5
|
+
from copy import deepcopy
|
|
5
6
|
from typing import List, Optional, Union
|
|
6
7
|
|
|
7
8
|
import cv2
|
|
@@ -16,7 +17,7 @@ from simba.utils.checks import (
|
|
|
16
17
|
from simba.utils.data import create_color_palette, detect_bouts
|
|
17
18
|
from simba.utils.enums import Formats, Options
|
|
18
19
|
from simba.utils.errors import NoSpecifiedOutputError
|
|
19
|
-
from simba.utils.lookups import get_named_colors
|
|
20
|
+
from simba.utils.lookups import get_fonts, get_named_colors
|
|
20
21
|
from simba.utils.printing import stdout_success
|
|
21
22
|
from simba.utils.read_write import get_fn_ext, read_df
|
|
22
23
|
|
|
@@ -60,16 +61,18 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
|
|
|
60
61
|
|
|
61
62
|
def __init__(self,
|
|
62
63
|
config_path: Union[str, os.PathLike],
|
|
63
|
-
data_paths: List[Union[str, os.PathLike]],
|
|
64
|
+
data_paths: Optional[Union[Union[str, os.PathLike], List[Union[str, os.PathLike]]]] = None,
|
|
64
65
|
width: int = 640,
|
|
65
66
|
height: int = 480,
|
|
66
67
|
font_size: int = 8,
|
|
67
68
|
font_rotation: int = 45,
|
|
69
|
+
font: Optional[str] = None,
|
|
68
70
|
palette: str = 'Set1',
|
|
69
|
-
frame_setting:
|
|
70
|
-
video_setting:
|
|
71
|
-
last_frm_setting:
|
|
72
|
-
hhmmss:
|
|
71
|
+
frame_setting: bool = False,
|
|
72
|
+
video_setting: bool = False,
|
|
73
|
+
last_frm_setting: bool = True,
|
|
74
|
+
hhmmss: bool = True,
|
|
75
|
+
clf_names: Optional[List[str]] = None):
|
|
73
76
|
|
|
74
77
|
if ((frame_setting != True) and (video_setting != True) and (last_frm_setting != True)):
|
|
75
78
|
raise NoSpecifiedOutputError(msg="SIMBA ERROR: Please select gantt videos, frames, and/or last frame.")
|
|
@@ -78,7 +81,13 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
|
|
|
78
81
|
check_int(value=height, min_value=1, name=f'{self.__class__.__name__} height')
|
|
79
82
|
check_int(value=font_size, min_value=1, name=f'{self.__class__.__name__} font_size')
|
|
80
83
|
check_int(value=font_rotation, min_value=0, max_value=180, name=f'{self.__class__.__name__} font_rotation')
|
|
81
|
-
|
|
84
|
+
if isinstance(data_paths, list):
|
|
85
|
+
check_valid_lst(data=data_paths, source=f'{self.__class__.__name__} data_paths', valid_dtypes=(str,), min_len=1)
|
|
86
|
+
elif isinstance(data_paths, str):
|
|
87
|
+
check_file_exist_and_readable(file_path=data_paths)
|
|
88
|
+
data_paths = [data_paths]
|
|
89
|
+
else:
|
|
90
|
+
data_paths = deepcopy(self.machine_results_paths)
|
|
82
91
|
check_valid_boolean(value=hhmmss, source=f'{self.__class__.__name__} hhmmss', raise_error=False)
|
|
83
92
|
palettes = Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value
|
|
84
93
|
check_str(name=f'{self.__class__.__name__} palette', value=palette, options=palettes)
|
|
@@ -90,7 +99,12 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
|
|
|
90
99
|
if not os.path.exists(self.gantt_plot_dir): os.makedirs(self.gantt_plot_dir)
|
|
91
100
|
self.frame_setting, self.video_setting, self.last_frm_setting = frame_setting, video_setting, last_frm_setting
|
|
92
101
|
self.width, self.height, self.font_size, self.font_rotation = width, height, font_size, font_rotation
|
|
93
|
-
|
|
102
|
+
if font is not None:
|
|
103
|
+
check_str(name=f'{self.__class__.__name__} font', value=font, options=list(get_fonts().keys()), raise_error=True)
|
|
104
|
+
if clf_names is not None:
|
|
105
|
+
check_valid_lst(data=clf_names, source=f'{self.__class__.__name__} clf_names', valid_dtypes=(str,), valid_values=self.clf_names, min_len=1, raise_error=True)
|
|
106
|
+
self.clf_names = clf_names
|
|
107
|
+
self.data_paths, self.hhmmss, self.font = data_paths, hhmmss, font
|
|
94
108
|
self.colours = get_named_colors()
|
|
95
109
|
self.colour_tuple_x = list(np.arange(3.5, 203.5, 5))
|
|
96
110
|
|
|
@@ -121,6 +135,7 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
|
|
|
121
135
|
font_size=self.font_size,
|
|
122
136
|
font_rotation=self.font_rotation,
|
|
123
137
|
video_name=self.video_name,
|
|
138
|
+
font=self.font,
|
|
124
139
|
save_path=os.path.join(self.gantt_plot_dir, f"{self.video_name }_final_image.png"),
|
|
125
140
|
palette=self.clr_lst,
|
|
126
141
|
hhmmss=self.hhmmss)
|
|
@@ -135,6 +150,7 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
|
|
|
135
150
|
width=self.width,
|
|
136
151
|
height=self.height,
|
|
137
152
|
font_size=self.font_size,
|
|
153
|
+
font=self.font,
|
|
138
154
|
font_rotation=self.font_rotation,
|
|
139
155
|
video_name=self.video_name,
|
|
140
156
|
palette=self.clr_lst,
|
|
@@ -156,13 +172,16 @@ class GanttCreatorSingleProcess(ConfigReader, PlottingMixin):
|
|
|
156
172
|
# test = GanttCreatorSingleProcess(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
|
|
157
173
|
# frame_setting=False,
|
|
158
174
|
# video_setting=False,
|
|
159
|
-
# data_paths=[r"C:\troubleshooting\mitra\project_folder\csv\machine_results\
|
|
175
|
+
# data_paths=[r"C:\troubleshooting\mitra\project_folder\csv\machine_results\501_MA142_Gi_CNO_0516.csv"],
|
|
160
176
|
# last_frm_setting=True,
|
|
161
177
|
# width=640,
|
|
162
178
|
# height= 480,
|
|
163
179
|
# font_size=10,
|
|
180
|
+
# font=None,
|
|
164
181
|
# font_rotation=45,
|
|
165
|
-
#
|
|
182
|
+
# hhmmss=False,
|
|
183
|
+
# palette='Set1',
|
|
184
|
+
# clf_names=['straub_tail'])
|
|
166
185
|
# test.run()
|
|
167
186
|
|
|
168
187
|
|
|
@@ -8,16 +8,37 @@ import gc
|
|
|
8
8
|
import multiprocessing
|
|
9
9
|
import os
|
|
10
10
|
import platform
|
|
11
|
+
import sys
|
|
11
12
|
from copy import deepcopy
|
|
12
13
|
from typing import List, Optional, Union
|
|
13
14
|
|
|
14
15
|
import cv2
|
|
15
|
-
|
|
16
|
+
|
|
17
|
+
is_pycharm_ipython = True
|
|
18
|
+
try:
|
|
19
|
+
module_names = list(sys.modules.keys())
|
|
20
|
+
if any('pydev' in str(mod).lower() for mod in module_names):
|
|
21
|
+
is_pycharm_ipython = True
|
|
22
|
+
elif 'IPython' in sys.modules or 'ipython' in sys.modules:
|
|
23
|
+
is_pycharm_ipython = True
|
|
24
|
+
else:
|
|
25
|
+
is_pycharm_ipython = False
|
|
26
|
+
except Exception:
|
|
27
|
+
is_pycharm_ipython = True
|
|
28
|
+
|
|
29
|
+
if not is_pycharm_ipython:
|
|
30
|
+
if 'MPLBACKEND' not in os.environ: os.environ['MPLBACKEND'] = 'Agg'
|
|
31
|
+
try:
|
|
32
|
+
import matplotlib
|
|
33
|
+
matplotlib.use('Agg', force=False)
|
|
34
|
+
except (RecursionError, RuntimeError, ValueError, SystemError):
|
|
35
|
+
import matplotlib
|
|
36
|
+
else:
|
|
37
|
+
import matplotlib
|
|
38
|
+
|
|
16
39
|
import numpy as np
|
|
17
40
|
import pandas as pd
|
|
18
41
|
|
|
19
|
-
matplotlib.use('Agg')
|
|
20
|
-
|
|
21
42
|
from simba.mixins.config_reader import ConfigReader
|
|
22
43
|
from simba.mixins.plotting_mixin import PlottingMixin
|
|
23
44
|
from simba.utils.checks import (
|
|
@@ -28,6 +49,7 @@ from simba.utils.data import (create_color_palette, detect_bouts, get_cpu_pool,
|
|
|
28
49
|
terminate_cpu_pool)
|
|
29
50
|
from simba.utils.enums import Formats, Options
|
|
30
51
|
from simba.utils.errors import NoSpecifiedOutputError
|
|
52
|
+
from simba.utils.lookups import get_fonts
|
|
31
53
|
from simba.utils.printing import SimbaTimer, stdout_success
|
|
32
54
|
from simba.utils.read_write import (concatenate_videos_in_folder,
|
|
33
55
|
create_directory, find_core_cnt,
|
|
@@ -51,6 +73,7 @@ def gantt_creator_mp(data: np.array,
|
|
|
51
73
|
width: int,
|
|
52
74
|
height: int,
|
|
53
75
|
font_size: int,
|
|
76
|
+
font: str,
|
|
54
77
|
font_rotation: int,
|
|
55
78
|
palette: np.ndarray,
|
|
56
79
|
hhmmss: bool):
|
|
@@ -72,6 +95,7 @@ def gantt_creator_mp(data: np.array,
|
|
|
72
95
|
width=width,
|
|
73
96
|
height=height,
|
|
74
97
|
font_size=font_size,
|
|
98
|
+
font=font,
|
|
75
99
|
font_rotation=font_rotation,
|
|
76
100
|
video_name=video_name,
|
|
77
101
|
save_path=None,
|
|
@@ -139,9 +163,11 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
|
|
|
139
163
|
height: int = 480,
|
|
140
164
|
font_size: int = 8,
|
|
141
165
|
font_rotation: int = 45,
|
|
166
|
+
font: Optional[str] = None,
|
|
142
167
|
palette: str = 'Set1',
|
|
143
168
|
core_cnt: int = -1,
|
|
144
|
-
hhmmss: bool = False
|
|
169
|
+
hhmmss: bool = False,
|
|
170
|
+
clf_names: Optional[List[str]] = None):
|
|
145
171
|
|
|
146
172
|
check_file_exist_and_readable(file_path=config_path)
|
|
147
173
|
if (not frame_setting) and (not video_setting) and (not last_frm_setting):
|
|
@@ -157,6 +183,8 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
|
|
|
157
183
|
check_int(name=f"{self.__class__.__name__} core_cnt",value=core_cnt, min_value=-1, unaccepted_vals=[0], max_value=find_core_cnt()[0])
|
|
158
184
|
self.core_cnt = find_core_cnt()[0] if core_cnt == -1 or core_cnt > find_core_cnt()[0] else core_cnt
|
|
159
185
|
self.width, self.height, self.font_size, self.font_rotation, self.hhmmss = width, height, font_size, font_rotation, hhmmss
|
|
186
|
+
if font is not None:
|
|
187
|
+
check_str(name=f'{self.__class__.__name__} font', value=font, options=list(get_fonts().keys()), raise_error=True)
|
|
160
188
|
ConfigReader.__init__(self, config_path=config_path, create_logger=False)
|
|
161
189
|
if isinstance(data_paths, list):
|
|
162
190
|
check_valid_lst(data=data_paths, source=f'{self.__class__.__name__} data_paths', valid_dtypes=(str,), min_len=1)
|
|
@@ -166,9 +194,12 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
|
|
|
166
194
|
else:
|
|
167
195
|
data_paths = deepcopy(self.machine_results_paths)
|
|
168
196
|
for file_path in data_paths: check_file_exist_and_readable(file_path=file_path)
|
|
197
|
+
if clf_names is not None:
|
|
198
|
+
check_valid_lst(data=clf_names, source=f'{self.__class__.__name__} clf_names', valid_dtypes=(str,), valid_values=self.clf_names, min_len=1, raise_error=True)
|
|
199
|
+
self.clf_names = clf_names
|
|
169
200
|
PlottingMixin.__init__(self)
|
|
170
201
|
self.clr_lst = create_color_palette(pallete_name=palette, increments=len(self.body_parts_lst) + 1, as_int=True, as_rgb_ratio=True)
|
|
171
|
-
self.frame_setting, self.video_setting, self.data_paths, self.last_frm_setting = frame_setting, video_setting,data_paths, last_frm_setting
|
|
202
|
+
self.frame_setting, self.video_setting, self.data_paths, self.last_frm_setting, self.font = frame_setting, video_setting,data_paths, last_frm_setting, font
|
|
172
203
|
if not os.path.exists(self.gantt_plot_dir): os.makedirs(self.gantt_plot_dir)
|
|
173
204
|
if platform.system() == "Darwin":
|
|
174
205
|
multiprocessing.set_start_method("spawn", force=True)
|
|
@@ -206,6 +237,7 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
|
|
|
206
237
|
font_size=self.font_size,
|
|
207
238
|
font_rotation=self.font_rotation,
|
|
208
239
|
video_name=self.video_name,
|
|
240
|
+
font=self.font,
|
|
209
241
|
save_path=os.path.join(self.gantt_plot_dir, f"{self.video_name}_final_image.png"),
|
|
210
242
|
palette=self.clr_lst,
|
|
211
243
|
hhmmss=self.hhmmss)
|
|
@@ -225,6 +257,7 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
|
|
|
225
257
|
width=self.width,
|
|
226
258
|
height=self.height,
|
|
227
259
|
font_size=self.font_size,
|
|
260
|
+
font=self.font,
|
|
228
261
|
font_rotation=self.font_rotation,
|
|
229
262
|
video_name=self.video_name,
|
|
230
263
|
palette=self.clr_lst,
|
|
@@ -254,18 +287,18 @@ class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
|
|
|
254
287
|
# font_rotation= 45)
|
|
255
288
|
# test.run()
|
|
256
289
|
|
|
257
|
-
if __name__ == "__main__":
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
290
|
+
# if __name__ == "__main__":
|
|
291
|
+
# test = GanttCreatorMultiprocess(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini",
|
|
292
|
+
# frame_setting=False,
|
|
293
|
+
# video_setting=True,
|
|
294
|
+
# data_paths=r"D:\troubleshooting\maplight_ri\project_folder\csv\machine_results\Trial_1_C24_D1_1.csv",
|
|
295
|
+
# last_frm_setting=False,
|
|
296
|
+
# width=640,
|
|
297
|
+
# height= 480,
|
|
298
|
+
# font_size=10,
|
|
299
|
+
# font_rotation= 45,
|
|
300
|
+
# core_cnt=16)
|
|
301
|
+
# test.run()
|
|
269
302
|
|
|
270
303
|
|
|
271
304
|
# test = GanttCreatorMultiprocess(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini',
|
|
@@ -98,14 +98,14 @@ class HeatMapperClfMultiprocess(ConfigReader, PlottingMixin):
|
|
|
98
98
|
|
|
99
99
|
|
|
100
100
|
:example II:
|
|
101
|
-
>>> test = HeatMapperClfMultiprocess(config_path=r"C
|
|
101
|
+
>>> test = HeatMapperClfMultiprocess(config_path=r"C:/troubleshooting/RAT_NOR/project_folder/project_config.ini",
|
|
102
102
|
>>> style_attr = {'palette': 'jet', 'shading': 'gouraud', 'bin_size': 50, 'max_scale': 'auto'},
|
|
103
103
|
>>> final_img_setting=True,
|
|
104
104
|
>>> video_setting=True,
|
|
105
105
|
>>> frame_setting=True,
|
|
106
106
|
>>> bodypart='Ear_left',
|
|
107
107
|
>>> clf_name='straub_tail',
|
|
108
|
-
>>> data_paths=[r"C
|
|
108
|
+
>>> data_paths=[r"C:/troubleshooting/RAT_NOR/project_folder/csv/test/2022-06-20_NOB_DOT_4.csv"])
|
|
109
109
|
>>> test.run()
|
|
110
110
|
"""
|
|
111
111
|
|
|
@@ -44,10 +44,10 @@ class SimBABlobImporter(ConfigReader):
|
|
|
44
44
|
:param Optional[bool] verbose: If True, prints progress messages. Default: True.
|
|
45
45
|
|
|
46
46
|
:example:
|
|
47
|
-
>>> r = SimBABlobImporter(config_path=r"C
|
|
47
|
+
>>> r = SimBABlobImporter(config_path=r"C:/troubleshooting/simba_blob_project/project_folder/project_config.ini", data_path=r'C:/troubleshooting/simba_blob_project/data')
|
|
48
48
|
>>> r.run()
|
|
49
|
-
>>> r = SimBABlobImporter(config_path=r"C
|
|
50
|
-
... data_path=r'C
|
|
49
|
+
>>> r = SimBABlobImporter(config_path=r"C:/troubleshooting/simba_blob_project/project_folder/project_config.ini",
|
|
50
|
+
... data_path=r'C:/troubleshooting/simba_blob_project/data',
|
|
51
51
|
... smoothing_settings={'method': 'savitzky-golay', 'time_window': 100},
|
|
52
52
|
... interpolation_settings={'method': 'nearest', 'type': 'body-parts'})
|
|
53
53
|
>>> r.run()
|
|
@@ -168,7 +168,7 @@ class ROIAggregateStatisticsAnalyzerMultiprocess(ConfigReader, FeatureExtraction
|
|
|
168
168
|
:param save_path (str | os.PathLike, optional): Path to save summary statistics.
|
|
169
169
|
|
|
170
170
|
:example:
|
|
171
|
-
>>> analyzer = ROIAggregateStatisticsAnalyzerMultiprocess(config_path=r"C
|
|
171
|
+
>>> analyzer = ROIAggregateStatisticsAnalyzerMultiprocess(config_path=r"C:/troubleshooting/mitra/project_folder/project_config.ini", body_parts=['Center'], first_entry_time=True, threshold=0.0, calculate_distances=True, transpose=False, detailed_bout_data=True)
|
|
172
172
|
>>> analyzer.run()
|
|
173
173
|
>>> analyzer.save()
|
|
174
174
|
"""
|
|
@@ -150,7 +150,7 @@ class ROIClfCalculatorMultiprocess(ConfigReader):
|
|
|
150
150
|
'GitHub tutorial <https://github.com/sgoldenlab/simba/blob/master/docs/Scenario2.md#part-4--analyze-machine-results`__.
|
|
151
151
|
|
|
152
152
|
:example:
|
|
153
|
-
>>> analyzer = ROIClfCalculatorMultiprocess(config_path=r"D
|
|
153
|
+
>>> analyzer = ROIClfCalculatorMultiprocess(config_path=r"D:/troubleshooting/maplight_ri/project_folder/project_config.ini", bp_names=['resident_NOSE'], clf_names=['attack'], clf_time=True, started_bout_cnt=True, ended_bout_cnt=False, bout_table=True, transpose=True, core_cnt=20)
|
|
154
154
|
>>> analyzer.run()
|
|
155
155
|
>>> analyzer.save()
|
|
156
156
|
"""
|
|
@@ -54,15 +54,15 @@ class COCOKeypoints2Yolo:
|
|
|
54
54
|
:return: None
|
|
55
55
|
|
|
56
56
|
:example:
|
|
57
|
-
>>> runner = COCOKeypoints2Yolo(coco_path=r"D
|
|
57
|
+
>>> runner = COCOKeypoints2Yolo(coco_path=r"D:/cvat_annotations/frames/coco_keypoints_1/s1/annotations/s1.json", img_dir=r"D:/cvat_annotations/frames/simon", save_dir=r"D:/cvat_annotations/frames/yolo_keypoints", clahe=True)
|
|
58
58
|
>>> runner.run()
|
|
59
59
|
|
|
60
60
|
:example II:
|
|
61
|
-
>>> runner = COCOKeypoints2Yolo(coco_path=r"D
|
|
61
|
+
>>> runner = COCOKeypoints2Yolo(coco_path=r"D:/cvat_annotations/frames/coco_keypoints_1/merged.json", img_dir=r"D:/cvat_annotations/frames", save_dir=r"D:/cvat_annotations/frames/yolo", clahe=False)
|
|
62
62
|
>>> runner.run()
|
|
63
63
|
|
|
64
64
|
:example III:
|
|
65
|
-
>>> runner = COCOKeypoints2Yolo(coco_path=r"E
|
|
65
|
+
>>> runner = COCOKeypoints2Yolo(coco_path=r"E:/netholabs_videos/mosaics/subset/to_annotate/2d_mosaic_batch_1.json", img_dir=r"E:/netholabs_videos/mosaics/subset/to_annotate", save_dir=r"E:/netholabs_videos/mosaics/yolo_mdl", clahe=False)
|
|
66
66
|
>>> runner.run()
|
|
67
67
|
|
|
68
68
|
:references:
|
|
@@ -58,11 +58,11 @@ class COCOKeypoints2YoloBbox:
|
|
|
58
58
|
:return: None
|
|
59
59
|
|
|
60
60
|
:example:
|
|
61
|
-
>>> runner =
|
|
61
|
+
>>> runner = COCOKeypoints2YoloBbox(coco_path=r"D:/cvat_annotations/frames/coco_keypoints_1/s1/annotations/s1.json", img_dir=r"D:/cvat_annotations/frames/simon", save_dir=r"D:/cvat_annotations/frames/yolo_keypoints", clahe=True)
|
|
62
62
|
>>> runner.run()
|
|
63
63
|
|
|
64
64
|
:example II:
|
|
65
|
-
>>> runner =
|
|
65
|
+
>>> runner = COCOKeypoints2YoloBbox(coco_path=r"D:/cvat_annotations/frames/coco_keypoints_1/merged.json", img_dir=r"D:/cvat_annotations/frames", save_dir=r"D:/cvat_annotations/frames/yolo", clahe=False)
|
|
66
66
|
>>> runner.run()
|
|
67
67
|
|
|
68
68
|
:references:
|
|
@@ -49,8 +49,8 @@ class SklearnVisualizationPopUp(PopUpMixin, ConfigReader):
|
|
|
49
49
|
pose_palettes = Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value
|
|
50
50
|
PopUpMixin.__init__(self, title="VISUALIZE CLASSIFICATION (SKLEARN) RESULTS", icon='photos')
|
|
51
51
|
bp_threshold_frm = CreateLabelFrameWithIcon(parent=self.main_frm, header="BODY-PART VISUALIZATION THRESHOLD", icon_name='threshold', icon_link=Links.SKLEARN_PLOTS.value, padx=5, pady=5, relief='solid')
|
|
52
|
-
self.bp_threshold_lbl = SimBALabel(parent=bp_threshold_frm, txt="Body-parts detected below the set threshold won't be shown in the output videos.", font=Formats.FONT_REGULAR_ITALICS.value)
|
|
53
|
-
self.bp_threshold_entry = Entry_Box(parent=bp_threshold_frm, fileDescription='BODY-PART PROBABILITY THRESHOLD: ', labelwidth=40, entry_box_width=15, value=0.00, img='green_dice')
|
|
52
|
+
self.bp_threshold_lbl = SimBALabel(parent=bp_threshold_frm, txt="Body-parts detected below the set threshold won't be shown in the output videos (use 0.0 to see all body-part predictions)", font=Formats.FONT_REGULAR_ITALICS.value)
|
|
53
|
+
self.bp_threshold_entry = Entry_Box(parent=bp_threshold_frm, fileDescription='BODY-PART PROBABILITY THRESHOLD: ', labelwidth=40, entry_box_width=15, value=0.00, img='green_dice', justify='center')
|
|
54
54
|
self.get_bp_probability_threshold()
|
|
55
55
|
|
|
56
56
|
bp_threshold_frm.grid(row=0, column=0, sticky=NW)
|
simba/ui/pop_ups/gantt_pop_up.py
CHANGED
|
@@ -13,6 +13,7 @@ from simba.ui.tkinter_functions import (CreateLabelFrameWithIcon, SimbaButton,
|
|
|
13
13
|
from simba.utils.checks import check_if_filepath_list_is_empty
|
|
14
14
|
from simba.utils.enums import Formats, Links, Options
|
|
15
15
|
from simba.utils.errors import NoSpecifiedOutputError
|
|
16
|
+
from simba.utils.lookups import get_fonts
|
|
16
17
|
from simba.utils.read_write import find_files_of_filetypes_in_directory
|
|
17
18
|
|
|
18
19
|
|
|
@@ -29,7 +30,9 @@ class GanttPlotPopUp(PopUpMixin, ConfigReader):
|
|
|
29
30
|
check_if_filepath_list_is_empty(filepaths=self.machine_results_paths,error_msg=f"SIMBA ERROR: Zero files found in the {self.machine_results_dir} directory. Create classification results before visualizing gantt charts",)
|
|
30
31
|
palettes = Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value
|
|
31
32
|
self.data_paths = find_files_of_filetypes_in_directory(directory=self.machine_results_dir, extensions=[f'.{self.file_type}'], as_dict=True)
|
|
32
|
-
max_file_name_len = max(len(k) for k in self.data_paths) + 5
|
|
33
|
+
max_file_name_len, fonts = max(len(k) for k in self.data_paths) + 5, list(get_fonts(sort_alphabetically=True).keys())
|
|
34
|
+
fonts.insert(0, 'AUTO')
|
|
35
|
+
default_font = 'Arial' if 'Arial' in fonts else 'AUTO'
|
|
33
36
|
PopUpMixin.__init__(self, config_path=config_path, title="VISUALIZE GANTT PLOTS", icon='gantt_small')
|
|
34
37
|
|
|
35
38
|
self.style_settings_frm = CreateLabelFrameWithIcon(parent=self.main_frm, header="STYLE SETTINGS", icon_name='settings', icon_link=Links.GANTT_PLOTS.value, relief='solid', padx=5, pady=5)
|
|
@@ -39,6 +42,15 @@ class GanttPlotPopUp(PopUpMixin, ConfigReader):
|
|
|
39
42
|
self.palette_dropdown = SimBADropDown(parent=self.style_settings_frm, dropdown_options=palettes, label='COLOR PALETTE: ', label_width=30, dropdown_width=30, value='Set1', img='palette_small')
|
|
40
43
|
self.time_format_dropdown = SimBADropDown(parent=self.style_settings_frm, dropdown_options=['SECONDS', 'HH:MM:SS'], label='X-AXIS TIME FORMAT: ', label_width=30, dropdown_width=30, value='SECONDS', img='timer_2')
|
|
41
44
|
self.core_dropdown = SimBADropDown(parent=self.style_settings_frm, dropdown_options=list(range(1, self.cpu_cnt+1)), label='CPU CORES: ', label_width=30, dropdown_width=30, value=int(self.cpu_cnt/2), img='cpu_small')
|
|
45
|
+
self.font_dropdown = SimBADropDown(parent=self.style_settings_frm, dropdown_options=fonts, label='FONT: ', label_width=30, dropdown_width=30, value=default_font, img='font')
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
self.clf_frm = CreateLabelFrameWithIcon(parent=self.main_frm, header="BEHAVIORS", icon_name='forest', icon_link=Links.GANTT_PLOTS.value, relief='solid', padx=5, pady=5)
|
|
49
|
+
self.clf_choices = {}
|
|
50
|
+
for cnt, clf_name in enumerate(self.clf_names):
|
|
51
|
+
gantt_frames_cb, self.gantt_frames_var = SimbaCheckbox(parent=self.clf_frm, txt=clf_name, val=True)
|
|
52
|
+
self.clf_choices[clf_name] = self.gantt_frames_var
|
|
53
|
+
gantt_frames_cb.grid(row=cnt, column=0, sticky=NW)
|
|
42
54
|
|
|
43
55
|
self.settings_frm = CreateLabelFrameWithIcon(parent=self.main_frm, header="VISUALIZATION SETTINGS", icon_name='eye', icon_link=Links.GANTT_PLOTS.value, relief='solid', padx=5, pady=5)
|
|
44
56
|
gantt_frames_cb, self.gantt_frames_var = SimbaCheckbox(parent=self.settings_frm, txt='CREATE FRAMES', txt_img='frames', val=False)
|
|
@@ -57,18 +69,20 @@ class GanttPlotPopUp(PopUpMixin, ConfigReader):
|
|
|
57
69
|
self.font_rotation_dropdown.grid(row=2, sticky=NW)
|
|
58
70
|
self.palette_dropdown.grid(row=3, sticky=NW)
|
|
59
71
|
self.time_format_dropdown.grid(row=4, sticky=NW)
|
|
60
|
-
self.
|
|
72
|
+
self.font_dropdown.grid(row=5, sticky=NW)
|
|
73
|
+
self.core_dropdown.grid(row=6, sticky=NW)
|
|
61
74
|
|
|
62
|
-
self.
|
|
75
|
+
self.clf_frm.grid(row=1, sticky=NW, padx=10, pady=10)
|
|
76
|
+
self.settings_frm.grid(row=2, sticky=NW, padx=10, pady=10)
|
|
63
77
|
gantt_videos_cb.grid(row=0, sticky=NW)
|
|
64
78
|
gantt_frames_cb.grid(row=1, sticky=W)
|
|
65
79
|
gantt_last_frame_cb.grid(row=2, sticky=NW)
|
|
66
80
|
|
|
67
|
-
self.run_single_video_frm.grid(row=
|
|
81
|
+
self.run_single_video_frm.grid(row=3, sticky=NW)
|
|
68
82
|
self.run_single_video_btn.grid(row=0, sticky=NW)
|
|
69
83
|
self.single_video_dropdown.grid(row=0, column=1, sticky=NW)
|
|
70
84
|
|
|
71
|
-
self.run_multiple_videos.grid(row=
|
|
85
|
+
self.run_multiple_videos.grid(row=4, sticky=NW)
|
|
72
86
|
self.run_multiple_video_btn.grid(row=0, sticky=NW)
|
|
73
87
|
|
|
74
88
|
self.main_frm.mainloop()
|
|
@@ -85,8 +99,14 @@ class GanttPlotPopUp(PopUpMixin, ConfigReader):
|
|
|
85
99
|
video_setting = self.gantt_videos_var.get()
|
|
86
100
|
last_frm_setting = self.gantt_last_frame_var.get()
|
|
87
101
|
palette = self.palette_dropdown.get_value()
|
|
102
|
+
font = self.font_dropdown.get_value()
|
|
103
|
+
font = None if font == 'AUTO' else font
|
|
88
104
|
hhmmss = True if self.time_format_dropdown.get_value() == 'HH:MM:SS' else False
|
|
89
|
-
|
|
105
|
+
clf_names = []
|
|
106
|
+
for clf_name, clf_val in self.clf_choices.items():
|
|
107
|
+
if clf_val.get(): clf_names.append(clf_name)
|
|
108
|
+
if len(clf_names) < 1:
|
|
109
|
+
raise NoSpecifiedOutputError(msg="Select AT LEAST one behavior name.")
|
|
90
110
|
if not frame_setting and not video_setting and not last_frm_setting:
|
|
91
111
|
raise NoSpecifiedOutputError(msg="SIMBA ERROR: Please select gantt videos, frames, and/or last frame.")
|
|
92
112
|
|
|
@@ -103,6 +123,8 @@ class GanttPlotPopUp(PopUpMixin, ConfigReader):
|
|
|
103
123
|
data_paths=data_paths,
|
|
104
124
|
width=width,
|
|
105
125
|
height=height,
|
|
126
|
+
font=font,
|
|
127
|
+
clf_names=clf_names,
|
|
106
128
|
font_size=font_size,
|
|
107
129
|
font_rotation=font_rotation,
|
|
108
130
|
core_cnt=core_cnt,
|
|
@@ -116,7 +138,9 @@ class GanttPlotPopUp(PopUpMixin, ConfigReader):
|
|
|
116
138
|
last_frm_setting=last_frm_setting,
|
|
117
139
|
data_paths=data_paths,
|
|
118
140
|
width=width,
|
|
141
|
+
font=font,
|
|
119
142
|
height=height,
|
|
143
|
+
clf_names=clf_names,
|
|
120
144
|
font_size=font_size,
|
|
121
145
|
font_rotation=font_rotation,
|
|
122
146
|
palette=palette)
|
|
@@ -124,6 +148,7 @@ class GanttPlotPopUp(PopUpMixin, ConfigReader):
|
|
|
124
148
|
|
|
125
149
|
|
|
126
150
|
#_ = GanttPlotPopUp(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini")
|
|
151
|
+
#_ = GanttPlotPopUp(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
|
|
127
152
|
# _ = GanttPlotPopUp(config_path=r"/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini")
|
|
128
153
|
|
|
129
154
|
# _ = GanttPlotPopUp(config_path=r"C:\troubleshooting\RAT_NOR\project_folder\project_config.ini")
|
|
@@ -1708,7 +1708,7 @@ class ClipMultipleVideosByTimestamps(PopUpMixin):
|
|
|
1708
1708
|
check_that_hhmmss_start_is_before_end(start_time=start, end_time=end, name=video_name)
|
|
1709
1709
|
check_if_hhmmss_timestamp_is_valid_part_of_video(timestamp=start, video_path=self.video_paths[video_name])
|
|
1710
1710
|
check_if_hhmmss_timestamp_is_valid_part_of_video(timestamp=end, video_path=self.video_paths[video_name])
|
|
1711
|
-
clip_video_in_range(file_path=self.video_paths[video_name], start_time=start, end_time=end, out_dir=self.save_dir, overwrite=True, include_clip_time_in_filename=False, gpu=gpu, quality=quality_pct)
|
|
1711
|
+
clip_video_in_range(file_path=self.video_paths[video_name], start_time=start, end_time=end, out_dir=self.save_dir, overwrite=True, include_clip_time_in_filename=False, gpu=gpu, quality=quality_pct, codec='libx264', verbose=True)
|
|
1712
1712
|
timer.stop_timer()
|
|
1713
1713
|
stdout_success(msg=f"{len(self.entry_boxes)} videos clipped by time-stamps and saved in {self.save_dir}", elapsed_time=timer.elapsed_time_str,)
|
|
1714
1714
|
|
|
@@ -30,7 +30,7 @@ class CustomFeatureExtractor(ConfigReader):
|
|
|
30
30
|
4. Handle cases of multiple classes and missing configuration arguments.
|
|
31
31
|
5. Invokes the feature extraction process if conditions are met.
|
|
32
32
|
|
|
33
|
-
..
|
|
33
|
+
.. note::
|
|
34
34
|
|
|
35
35
|
`Tutorial <https://github.com/sgoldenlab/simba/blob/master/docs/extractFeatures.md>`_.
|
|
36
36
|
|
simba/utils/data.py
CHANGED
|
@@ -1785,8 +1785,8 @@ def fft_lowpass_filter(data: np.ndarray, cut_off: float = 0.1) -> np.ndarray:
|
|
|
1785
1785
|
|
|
1786
1786
|
:example:
|
|
1787
1787
|
>>> from simba.utils.read_write import read_df
|
|
1788
|
-
>>> IN_PATH = r"C
|
|
1789
|
-
>>> OUT_PATH = r"C
|
|
1788
|
+
>>> IN_PATH = r"C:/troubleshooting/RAT_NOR/project_folder/csv/outlier_corrected_movement_location/2022-06-20_NOB_DOT_4.csv"
|
|
1789
|
+
>>> OUT_PATH = r"C:/troubleshooting/RAT_NOR/project_folder/csv/outlier_corrected_movement_location/2022-06-20_NOB_DOT_4_filtered.csv"
|
|
1790
1790
|
>>> df = read_df(file_path=IN_PATH)
|
|
1791
1791
|
>>> data = df.values
|
|
1792
1792
|
>>> x = fft_lowpass_filter(data=data, cut_off=0.1)
|
simba/utils/read_write.py
CHANGED
|
@@ -97,13 +97,20 @@ def read_df(file_path: Union[str, os.PathLike],
|
|
|
97
97
|
.. note::
|
|
98
98
|
For improved runtime, defaults to :external:py:meth:`pyarrow.csv.write_cs` if file type is ``csv``.
|
|
99
99
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
100
|
+
.. csv-table::
|
|
101
|
+
:header: EXPECTED RUNTIMES
|
|
102
|
+
:file: ../../docs/tables/read_df.csv
|
|
103
|
+
:widths: 10, 45, 45
|
|
104
|
+
:align: center
|
|
105
|
+
:header-rows: 1
|
|
106
|
+
|
|
107
|
+
:param str file_path: Path to data file
|
|
108
|
+
:param str file_type: Type of data. OPTIONS: 'parquet', 'csv', 'pickle'.
|
|
109
|
+
:param Optional[bool]: If the input file has an initial index column. Default: True.
|
|
110
|
+
:param Optional[List[str]] remove_columns: If not None, then remove columns in lits.
|
|
111
|
+
:param Optional[List[str]] usecols: If not None, then keep columns in list.
|
|
112
|
+
:param bool check_multiindex: check file is multi-index headers. Default: False.
|
|
113
|
+
:param int multi_index_headers_to_keep: If reading multi-index file, and we want to keep one of the dropped multi-index levels as the header in the output file, specify the index of the multiindex hader as int.
|
|
107
114
|
:return: Table data in pd.DataFrame format.
|
|
108
115
|
:rtype: pd.DataFrame
|
|
109
116
|
|
|
@@ -207,11 +214,18 @@ def write_df(df: pd.DataFrame,
|
|
|
207
214
|
.. note::
|
|
208
215
|
For improved runtime, defaults to ``pyarrow.csv`` if file_type == ``csv``.
|
|
209
216
|
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
217
|
+
.. csv-table::
|
|
218
|
+
:header: EXPECTED RUNTIMES
|
|
219
|
+
:file: ../../docs/tables/write_df.csv
|
|
220
|
+
:widths: 10, 45, 45
|
|
221
|
+
:align: center
|
|
222
|
+
:header-rows: 1
|
|
223
|
+
|
|
224
|
+
:param pd.DataFrame df: Pandas dataframe to save to disk.
|
|
225
|
+
:param str file_type: Type of data. OPTIONS: ``parquet``, ``csv``, ``pickle``.
|
|
226
|
+
:param str save_path: Location where to store the data.
|
|
227
|
+
:param bool check_multiindex: Check whether the input file has multi-index headers. Default: False.
|
|
228
|
+
:param bool verbose: Prints message on completion. Default: False.
|
|
215
229
|
|
|
216
230
|
:example:
|
|
217
231
|
>>> write_df(df=df, file_type='csv', save_path='project_folder/csv/input_csv/Video_1.csv')
|
|
@@ -1130,8 +1144,8 @@ def get_file_name_info_in_directory(directory: Union[str, os.PathLike], file_typ
|
|
|
1130
1144
|
:return dict: All found files as values and file base names as keys.
|
|
1131
1145
|
|
|
1132
1146
|
:example:
|
|
1133
|
-
>>> get_file_name_info_in_directory(directory='C
|
|
1134
|
-
>>> {'Video_1': 'C
|
|
1147
|
+
>>> get_file_name_info_in_directory(directory='C:/project_folder/csv/machine_results', file_type='csv')
|
|
1148
|
+
>>> {'Video_1': 'C:/project_folder/csv/machine_results/Video_1'}
|
|
1135
1149
|
"""
|
|
1136
1150
|
|
|
1137
1151
|
results = {}
|
|
@@ -2630,7 +2644,7 @@ def bento_file_reader(file_path: Union[str, os.PathLike],
|
|
|
2630
2644
|
:rtype: Dict[str, pd.DataFrame]
|
|
2631
2645
|
|
|
2632
2646
|
:example:
|
|
2633
|
-
>>> bento_file_reader(file_path=r"C
|
|
2647
|
+
>>> bento_file_reader(file_path=r"C:/troubleshooting/bento_test/bento_files/20240812_crumpling3.annot")
|
|
2634
2648
|
"""
|
|
2635
2649
|
|
|
2636
2650
|
def _orient_columns_melt(df: pd.DataFrame) -> pd.DataFrame:
|
|
@@ -2953,7 +2967,7 @@ def labelme_to_dlc(labelme_dir: Union[str, os.PathLike],
|
|
|
2953
2967
|
:return: None
|
|
2954
2968
|
|
|
2955
2969
|
:example:
|
|
2956
|
-
>>> labelme_dir = r'D
|
|
2970
|
+
>>> labelme_dir = r'D:/ts_annotations'
|
|
2957
2971
|
>>> labelme_to_dlc(labelme_dir=labelme_dir)
|
|
2958
2972
|
"""
|
|
2959
2973
|
|
|
@@ -3597,8 +3611,8 @@ def osf_download(project_id: str, save_dir: Union[str, os.PathLike], storage: st
|
|
|
3597
3611
|
:param bool overwrite: If True, overwrite existing files. If False, skip existing files (default: False).
|
|
3598
3612
|
|
|
3599
3613
|
:example:
|
|
3600
|
-
>>> osf_download(project_id="7fgwn", save_dir=r'E
|
|
3601
|
-
>>> osf_download(project_id="kym42", save_dir=r'E
|
|
3614
|
+
>>> osf_download(project_id="7fgwn", save_dir=r'E:/rgb_white_vs_black_imgs')
|
|
3615
|
+
>>> osf_download(project_id="kym42", save_dir=r'E:/crim13_imgs', overwrite=True)
|
|
3602
3616
|
"""
|
|
3603
3617
|
|
|
3604
3618
|
_ = get_pkg_version(pkg='osfclient', raise_error=True)
|
simba/utils/yolo.py
CHANGED
|
@@ -47,6 +47,9 @@ def fit_yolo(weights_path: Union[str, os.PathLike],
|
|
|
47
47
|
`Download initial weights <https://huggingface.co/Ultralytics>`__.
|
|
48
48
|
`Example model_yaml <https://github.com/sgoldenlab/simba/blob/master/misc/ex_yolo_model.yaml>`__.
|
|
49
49
|
|
|
50
|
+
.. seealso::
|
|
51
|
+
For the recommended wrapper class with parameter validation, see :class:`simba.model.yolo_fit.FitYolo`.
|
|
52
|
+
|
|
50
53
|
:param initial_weights: Path to the pre-trained YOLO model weights (usually a `.pt` file). Example weights can be found [here](https://huggingface.co/Ultralytics).
|
|
51
54
|
:param model_yaml: YAML file containing paths to the training, validation, and testing datasets and the object class mappings. Example YAML file can be found [here](https://github.com/sgoldenlab/simba/blob/master/misc/ex_yolo_model.yaml).
|
|
52
55
|
:param save_path: Directory path where the trained model, logs, and results will be saved.
|
|
@@ -55,7 +58,7 @@ def fit_yolo(weights_path: Union[str, os.PathLike],
|
|
|
55
58
|
:return: None. The trained model and associated training logs are saved in the specified `save_path`.
|
|
56
59
|
|
|
57
60
|
:example:
|
|
58
|
-
>>> fit_yolo(initial_weights=r"C
|
|
61
|
+
>>> fit_yolo(initial_weights=r"C:/troubleshooting/coco_data/weights/yolov8n-obb.pt", data=r"C:/troubleshooting/coco_data/model.yaml", save_path=r"C:/troubleshooting/coco_data/mdl", batch=16)
|
|
59
62
|
"""
|
|
60
63
|
|
|
61
64
|
if not _is_cuda_available()[0]:
|
|
@@ -83,6 +86,9 @@ def load_yolo_model(weights_path: Union[str, os.PathLike],
|
|
|
83
86
|
"""
|
|
84
87
|
Load a YOLO model.
|
|
85
88
|
|
|
89
|
+
.. seealso::
|
|
90
|
+
For recommended wrapper classes that use this function, see :class:`simba.model.yolo_fit.FitYolo`, :class:`simba.model.yolo_inference.YoloInference`, :class:`simba.model.yolo_pose_inference.YOLOPoseInference`, :class:`simba.model.yolo_seg_inference.YOLOSegmentationInference`, and :class:`simba.model.yolo_pose_track_inference.YOLOPoseTrackInference`.
|
|
91
|
+
|
|
86
92
|
:param Union[str, os.PathLike] weights_path: Path to model weights (.pt, .engine, etc).
|
|
87
93
|
:param bool verbose: Whether to print loading info.
|
|
88
94
|
:param Optional[str] format: Export format, one of VALID_FORMATS or None to skip export.
|
|
@@ -169,6 +175,9 @@ def yolo_predict(model: YOLO,
|
|
|
169
175
|
"""
|
|
170
176
|
Produce YOLO predictions.
|
|
171
177
|
|
|
178
|
+
.. seealso::
|
|
179
|
+
For recommended wrapper classes that use this function, see :class:`simba.model.yolo_inference.YoloInference`, :class:`simba.model.yolo_pose_inference.YOLOPoseInference`, and :class:`simba.model.yolo_seg_inference.YOLOSegmentationInference`.
|
|
180
|
+
|
|
172
181
|
:param YOLO model: Loaded ultralytics.YOLO model. Returned by :func:`~simba.utils.yolo.load_yolo_model`.
|
|
173
182
|
:param Union[str, os.PathLike, np.ndarray] source: Path to video, video stream, directory, image, or image as loaded array.
|
|
174
183
|
:param bool half: Whether to use half precision (FP16) for inference to speed up processing.
|