simba-uw-tf-dev 4.7.5-py3-none-any.whl → 4.7.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of simba-uw-tf-dev might be problematic.

Files changed (29)
  1. simba/assets/.recent_projects.txt +2 -0
  2. simba/assets/icons/folder_2.png +0 -0
  3. simba/assets/icons/folder_video.png +0 -0
  4. simba/assets/lookups/tooptips.json +24 -2
  5. simba/mixins/feature_extraction_mixin.py +0 -2
  6. simba/model/yolo_fit.py +42 -9
  7. simba/sandbox/av1.py +5 -0
  8. simba/sandbox/clean_sleap.py +4 -0
  9. simba/sandbox/denoise_hqdn3d.py +266 -0
  10. simba/sandbox/extract_random_frames.py +126 -0
  11. simba/third_party_label_appenders/transform/coco_keypoints_to_yolo.py +1 -2
  12. simba/third_party_label_appenders/transform/sleap_csv_to_yolo.py +18 -12
  13. simba/ui/create_project_ui.py +1 -1
  14. simba/ui/pop_ups/batch_preprocess_pop_up.py +1 -1
  15. simba/ui/pop_ups/simba_to_yolo_keypoints_popup.py +96 -96
  16. simba/ui/pop_ups/sleap_annotations_to_yolo_popup.py +32 -18
  17. simba/ui/pop_ups/sleap_csv_predictions_to_yolo_popup.py +15 -14
  18. simba/ui/pop_ups/video_processing_pop_up.py +1 -1
  19. simba/ui/pop_ups/yolo_plot_results.py +146 -153
  20. simba/ui/pop_ups/yolo_pose_train_popup.py +69 -23
  21. simba/utils/checks.py +2414 -2401
  22. simba/utils/read_write.py +22 -20
  23. simba/video_processors/video_processing.py +21 -13
  24. {simba_uw_tf_dev-4.7.5.dist-info → simba_uw_tf_dev-4.7.7.dist-info}/METADATA +1 -1
  25. {simba_uw_tf_dev-4.7.5.dist-info → simba_uw_tf_dev-4.7.7.dist-info}/RECORD +29 -23
  26. {simba_uw_tf_dev-4.7.5.dist-info → simba_uw_tf_dev-4.7.7.dist-info}/LICENSE +0 -0
  27. {simba_uw_tf_dev-4.7.5.dist-info → simba_uw_tf_dev-4.7.7.dist-info}/WHEEL +0 -0
  28. {simba_uw_tf_dev-4.7.5.dist-info → simba_uw_tf_dev-4.7.7.dist-info}/entry_points.txt +0 -0
  29. {simba_uw_tf_dev-4.7.5.dist-info → simba_uw_tf_dev-4.7.7.dist-info}/top_level.txt +0 -0
simba/utils/checks.py CHANGED
@@ -1,2401 +1,2414 @@
1
- __author__ = "Simon Nilsson; sronilsson@gmail.com"
2
-
3
- import ast
4
- import glob
5
- import os
6
- import re
7
- import subprocess
8
- from pathlib import Path
9
- from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
10
-
11
- try:
12
- from typing import Literal
13
- except:
14
- from typing_extensions import Literal
15
-
16
- try:
17
- import cupy as cp
18
- except ModuleNotFoundError:
19
- import numpy as cp
20
-
21
- import multiprocessing
22
-
23
- import cv2
24
- import numpy as np
25
- import pandas as pd
26
- import trafaret as t
27
- from shapely.geometry import Polygon
28
-
29
- from simba.data_processors.cuda.utils import _is_cuda_available
30
- from simba.utils.enums import Formats, Keys, Options, UMAPParam
31
- from simba.utils.errors import (ArrayError, ColumnNotFoundError,
32
- CorruptedFileError, CountError,
33
- DirectoryNotEmptyError, FFMPEGNotFoundError,
34
- FloatError, FrameRangeError, IntegerError,
35
- InvalidFilepathError, InvalidInputError,
36
- MissingColumnsError, NoDataError,
37
- NoFilesFoundError, NoROIDataError,
38
- NotDirectoryError, ParametersFileError,
39
- SimBAGPUError, StringError)
40
- from simba.utils.warnings import (CorruptedFileWarning, FrameRangeWarning,
41
- InvalidValueWarning, NoDataFoundWarning)
42
-
43
-
44
- def check_file_exist_and_readable(file_path: Union[str, os.PathLike], raise_error: bool = True) -> bool:
45
- """
46
- Checks if a path points to a readable file.
47
-
48
- :param str file_path: Path to file on disk.
49
- :raise NoFilesFoundError: The file does not exist.
50
- :raise CorruptedFileError: The file can not be read or is zero byte size.
51
- """
52
- check_instance(source="FILE PATH", instance=file_path, accepted_types=(str, os.PathLike))
53
- if not os.path.isfile(file_path):
54
- if raise_error:
55
- raise NoFilesFoundError(msg=f"{file_path} is not a valid file path", source=check_file_exist_and_readable.__name__)
56
- else:
57
- return False
58
- elif not os.access(file_path, os.R_OK):
59
- if raise_error:
60
- raise CorruptedFileError(msg=f"{file_path} is not readable", source=check_file_exist_and_readable.__name__)
61
- else:
62
- return False
63
- elif os.stat(file_path).st_size == 0:
64
- if raise_error:
65
- raise CorruptedFileError(msg=f"{file_path} is 0 bytes and contains no data.", source=check_file_exist_and_readable.__name__)
66
- else:
67
- return False
68
- else:
69
- return True
70
-
71
-
72
- def check_int(name: str,
73
- value: Any,
74
- max_value: Optional[int] = None,
75
- min_value: Optional[int] = None,
76
- unaccepted_vals: Optional[List[int]] = None,
77
- accepted_vals: Optional[List[int]] = None,
78
- allow_negative: bool = True,
79
- allow_zero: bool = True,
80
- raise_error: Optional[bool] = True) -> Tuple[bool, str]:
81
- """
82
- Check if variable is a valid integer.
83
-
84
- Validates that a value is an integer and optionally checks it against constraints such as
85
- minimum/maximum values, accepted/unaccepted value lists, and negative/zero number restrictions.
86
-
87
- :param str name: Name of the variable being checked (used in error messages).
88
- :param Any value: The value to validate as an integer.
89
- :param Optional[int] max_value: Maximum allowed value. If None, no maximum constraint. Default None.
90
- :param Optional[int] min_value: Minimum allowed value. If None, no minimum constraint. Default None.
91
- :param Optional[List[int]] unaccepted_vals: List of integer values that are not accepted. If value is in this list, validation fails. Default None.
92
- :param Optional[List[int]] accepted_vals: List of integer values that are accepted. If value is not in this list, validation fails. Default None.
93
- :param bool allow_negative: If False, negative values will cause validation to fail. Default True.
94
- :param bool allow_zero: If False, zero values will cause validation to fail. Default True.
95
- :param Optional[bool] raise_error: If True, raises IntegerError when validation fails. If False, returns (False, error_message) tuple. Default True.
96
- :return: If `raise_error` is False, returns a tuple (bool, str) where bool indicates if value is valid, and str contains error message (empty string if valid). If `raise_error` is True and validation passes, returns (True, ""). If `raise_error` is True and validation fails, raises IntegerError.
97
- :rtype: Tuple[bool, str]
98
- :raises IntegerError: If validation fails and `raise_error` is True.
99
-
100
- :example:
101
- >>> check_int(name='My_fps', value=25, min_value=1)
102
- >>> check_int(name='Quality', value=50, min_value=0, max_value=100, raise_error=False)
103
- >>> check_int(name='Mode', value=2, accepted_vals=[1, 2, 3])
104
- >>> check_int(name='Count', value=-5, allow_negative=False)
105
- >>> check_int(name='Divisor', value=0, allow_zero=False)
106
- """
107
- msg = ""
108
- try:
109
- t.Int().check(value)
110
- except t.DataError as e:
111
- msg = f"{name} should be an integer number in SimBA, but is set to {str(value)}"
112
- if raise_error:
113
- raise IntegerError(msg=msg, source=check_int.__name__)
114
- else:
115
- return False, msg
116
- if min_value != None:
117
- if int(value) < min_value:
118
- msg = f"{name} should be MORE THAN OR EQUAL to {str(min_value)}. It is set to {str(value)}"
119
- if raise_error:
120
- raise IntegerError(msg=msg, source=check_int.__name__)
121
- else:
122
- return False, msg
123
- if max_value != None:
124
- if int(value) > max_value:
125
- msg = f"{name} should be LESS THAN OR EQUAL to {str(max_value)}. It is set to {str(value)}"
126
- if raise_error:
127
- raise IntegerError(msg=msg, source=check_int.__name__)
128
- else:
129
- return False, msg
130
- if unaccepted_vals != None:
131
- check_valid_lst(data=unaccepted_vals, source=name, valid_dtypes=(int,), min_len=1)
132
- if int(value) in unaccepted_vals:
133
- msg = f"{name} is an not an accepted value. Unaccepted values {unaccepted_vals}."
134
- if raise_error:
135
- raise IntegerError(msg=msg, source=check_int.__name__)
136
- else:
137
- return False, msg
138
- if accepted_vals != None:
139
- check_valid_lst(data=accepted_vals, source=name, valid_dtypes=(int,), min_len=1)
140
- if int(value) not in accepted_vals:
141
- msg = f"{name} is an not an accepted value. Got: {value}. Accepted values {accepted_vals}."
142
- if raise_error:
143
- raise IntegerError(msg=msg, source=check_int.__name__)
144
- else:
145
- return False, msg
146
-
147
- if not allow_negative and int(value) < 0:
148
- msg = f"{name} is negative and negative is not accepted. Got: {value}."
149
- if raise_error:
150
- raise IntegerError(msg=msg, source=check_int.__name__)
151
- else:
152
- return False, msg
153
-
154
- if not allow_zero and int(value) == 0:
155
- msg = f"{name} is zero and zero is not accepted. Got: {value}."
156
- if raise_error:
157
- raise IntegerError(msg=msg, source=check_int.__name__)
158
- else:
159
- return False, msg
160
-
161
- return True, msg
162
-
163
-
164
- def check_str(name: str,
165
- value: Any,
166
- options: Optional[Union[Tuple[Any], List[Any], Iterable[Any]]] = (),
167
- allow_blank: bool = False,
168
- invalid_options: Optional[Union[List[str], Tuple[str]]] = None,
169
- raise_error: bool = True,
170
- invalid_substrs: Optional[Union[List[str], Tuple[str]]] = None) -> Tuple[bool, str]:
171
-
172
- """
173
- Check if variable is a valid string.
174
-
175
- :param str name: Name of variable
176
- :param Any value: Value of variable
177
- :param Optional[Tuple[Any]] options: Tuple of allowed strings. If empty tuple, then any string allowed. Default: ().
178
- :param Optional[bool] allow_blank: If True, allow empty string. Default: False.
179
- :param Optional[bool] raise_error: If True, then raise error if invalid string. Default: True.
180
- :param Optional[List[str]] invalid_options: If not None, then a list of strings that are invalid.
181
- :param Optional[List[str]] invalid_substrs: If not None, then a list of characters or substrings that are not allowed in the string.
182
- :return: If `raise_error` is False, then returns size-2 Tuple, with first value being a bool representing if valid string, and second value a string representing error reason (if valid is False, else empty string).
183
- :rtype: Tuple[bool, str]
184
-
185
- :examples:
186
- >>> check_str(name='split_eval', input='gini', options=['entropy', 'gini'])
187
- """
188
-
189
- msg = ""
190
- try:
191
- t.String(allow_blank=allow_blank).check(value)
192
- except t.DataError as e:
193
- msg = f"{name} should be an string in SimBA, but is set to {str(value)}"
194
- if raise_error:
195
- raise StringError(msg=msg, source=check_str.__name__)
196
- else:
197
- return False, msg
198
- if len(options) > 0:
199
- if value not in options:
200
- msg = f"{name} is set to {value} in SimBA, but this is not a valid option: {options}"
201
- if raise_error:
202
- raise StringError(msg=msg, source=check_str.__name__)
203
- else:
204
- return False, msg
205
- else:
206
- return True, msg
207
-
208
- if invalid_options is not None:
209
- check_instance(source=f'{name} invalid_options', accepted_types=(tuple, list,), instance=invalid_options)
210
- if isinstance(invalid_options, tuple):
211
- invalid_options = list(invalid_options)
212
- check_valid_lst(data=invalid_options, valid_dtypes=(str,), min_len=1)
213
- if value in invalid_options:
214
- msg = f"{name} is set to {value} in SimBA, but this is among invalid options: {invalid_options}"
215
- if raise_error:
216
- raise StringError(msg=msg, source=check_str.__name__)
217
- else:
218
- return False, msg
219
- else:
220
- return True, msg
221
- if invalid_substrs is not None:
222
- if not isinstance(invalid_substrs, (tuple, list)):
223
- check_instance(source=f'{name} invalid_characters', accepted_types=(tuple, list,), instance=invalid_options)
224
- if isinstance(invalid_substrs, tuple):
225
- invalid_substrs = list(invalid_substrs)
226
- check_valid_lst(data=invalid_substrs, valid_dtypes=(str,), min_len=1)
227
- for substr in invalid_substrs:
228
- if substr in value:
229
- msg = f'{name} contains the characters "{substr}" . This character/substring is NOT accepted.'
230
- if raise_error:
231
- raise StringError(msg=msg, source=check_str.__name__)
232
- else:
233
- return False, msg
234
- else:
235
- return True, msg
236
-
237
- def check_float(name: str,
238
- value: Any,
239
- max_value: Optional[float] = None,
240
- min_value: Optional[float] = None,
241
- raise_error: bool = True,
242
- allow_zero: bool = True,
243
- allow_negative: bool = True) -> Tuple[bool, str]:
244
- """
245
- Check if variable is a valid float.
246
-
247
- :param str name: Name of variable
248
- :param Any value: Value of variable
249
- :param Optional[int] max_value: Maximum allowed value of the float. If None, then no maximum. Default: None.
250
- :param Optional[int]: Minimum allowed value of the float. If None, then no minimum. Default: Non
251
- :param Optional[bool] allow_zero: If True, do not allow float to be zero. Default: True and allow zero.
252
- :param Optional[bool] allow_negative: If True, do not allow float to be below zero Default: True and allow negative.
253
- :param Optional[bool] raise_error: If True, then raise error if invalid float. Default: True.
254
- :return: If `raise_error` is False, then returns size-2 tuple, with first value being a bool representing if valid float, and second value a string representing error (if valid is False, else empty string)
255
- :rtype: Tuple[bool, str]
256
-
257
-
258
- :examples:
259
- >>> check_float(name='My_float', value=0.5, max_value=1.0, min_value=0.0)
260
- """
261
-
262
- msg = ""
263
- try:
264
- t.Float().check(value)
265
- except t.DataError as e:
266
- msg = f"{name} should be a float number in SimBA, but is set to {str(value)}"
267
- if raise_error:
268
- raise FloatError(msg=msg, source=check_float.__name__)
269
- else:
270
- return False, msg
271
- if min_value != None:
272
- if float(value) < min_value:
273
- msg = f"{name} should be MORE THAN OR EQUAL to {str(min_value)}. It is set to {str(value)}"
274
- if raise_error:
275
- raise FloatError(msg=msg, source=check_float.__name__)
276
- else:
277
- return False, msg
278
- if max_value != None:
279
- if float(value) > max_value:
280
- msg = f"{name} should be LESS THAN OR EQUAL to {str(max_value)}. It is set to {str(value)}"
281
- if raise_error:
282
- raise FloatError(msg=msg, source=check_float.__name__)
283
- else:
284
- return False, msg
285
- if not allow_zero:
286
- if float(value) == 0:
287
- msg = f"{name} cannot be ZERO. It is set to {str(value)}"
288
- if raise_error:
289
- raise FloatError(msg=msg, source=check_float.__name__)
290
- else:
291
- return False, msg
292
-
293
- if not allow_negative:
294
- if float(value) < 0:
295
- msg = f"{name} cannot be BELOW zero. It is set to {str(value)}"
296
- if raise_error:
297
- raise FloatError(msg=msg, source=check_float.__name__)
298
- else:
299
- return False, msg
300
-
301
- return True, msg
302
-
303
-
304
- def check_iterable_length(source: str, val: int, exact_accepted_length: Optional[int] = None, max: Optional[int] = np.inf, min: int = 1, raise_error: bool = True) -> bool:
305
- if (not exact_accepted_length) and (not max) and (not min):
306
- if raise_error:
307
- raise InvalidInputError(msg=f"Provide exact_accepted_length or max and min values for {source}", source=check_iterable_length.__name__)
308
- else:
309
- return False
310
- if exact_accepted_length:
311
- if val != exact_accepted_length:
312
- if raise_error:
313
- raise InvalidInputError(msg=f"{source} length is {val}, expected {exact_accepted_length}", source=check_iterable_length.__name__)
314
- else:
315
- return False
316
- elif (val > max) or (val < min):
317
- if raise_error:
318
- raise InvalidInputError(msg=f"{source} value {val} does not full-fill criterion: min {min}, max{max} ", source=check_iterable_length.__name__)
319
- else:
320
- return False
321
- return True
322
-
323
-
324
- def check_instance(source: str, instance: object, accepted_types: Union[Tuple[Any], Any], raise_error: bool = True, warning: bool = True) -> bool:
325
- """
326
- Check if an instance is an acceptable type.
327
-
328
- :param str name: Arbitrary name of instance used for interpretable error msg. Can also be the name of the method.
329
- :param object instance: A data object.
330
- :param Union[Tuple[object], object] accepted_types: Accepted instance types. E.g., (Polygon, pd.DataFrame) or Polygon.
331
- :param Optional[bool] raise_error: If True, raises error of instance is not of valid type, else returns bool.
332
- :param Optional[bool] warning: If True, prints warning of instance is not of valid type, else returns bool.
333
- """
334
-
335
- if not isinstance(instance, accepted_types):
336
- msg = f"{source} requires {accepted_types}, got {type(instance)}"
337
- if raise_error:
338
- raise InvalidInputError(msg=msg, source=source)
339
- else:
340
- if warning:
341
- InvalidValueWarning(msg=msg, source=source)
342
- return False
343
- return True
344
-
345
-
346
- def get_fn_ext(filepath: Union[os.PathLike, str]) -> (str, str, str):
347
- """
348
- Split file path into three components: (i) directory, (ii) file name, and (iii) file extension.
349
-
350
- :parameter str filepath: Path to file.
351
- :return str: File directory name
352
- :return str: File name
353
- :return str: File extension
354
-
355
- :example:
356
- >>> get_fn_ext(filepath='C:/My_videos/MyVideo.mp4')
357
- >>> ('My_videos', 'MyVideo', '.mp4')
358
- """
359
- file_extension = Path(filepath).suffix
360
- try:
361
- file_name = os.path.basename(filepath.rsplit(file_extension, 1)[0])
362
- except ValueError:
363
- raise InvalidFilepathError(
364
- msg=f"{filepath} is not a valid filepath", source=get_fn_ext.__name__
365
- )
366
- dir_name = os.path.dirname(filepath)
367
- return dir_name, file_name, file_extension
368
-
369
-
370
- def check_if_filepath_list_is_empty(filepaths: List[str], error_msg: str) -> None:
371
- """
372
- Check if a list is empty
373
-
374
- :param List[str]: List of file-paths.
375
- :raise NoFilesFoundError: The list is empty.
376
- """
377
-
378
- if len(filepaths) == 0:
379
- raise NoFilesFoundError(
380
- msg=error_msg, source=check_if_filepath_list_is_empty.__name__
381
- )
382
- else:
383
- pass
384
-
385
-
386
- def check_all_file_names_are_represented_in_video_log(
387
- video_info_df: pd.DataFrame, data_paths: List[Union[str, os.PathLike]]
388
- ) -> None:
389
- """
390
- Helper to check that all files are represented in a dataframe of the SimBA `project_folder/logs/video_info.csv`
391
- file.
392
-
393
- :param pd.DataFrame video_info_df: List of file-paths.
394
- :param List[Union[str, os.PathLike]] data_paths: List of file-paths.
395
- :raise ParametersFileError: The list is empty.
396
- """
397
-
398
- missing_videos = []
399
- for file_path in data_paths:
400
- video_name = get_fn_ext(file_path)[1]
401
- if video_name not in list(video_info_df["Video"]):
402
- missing_videos.append(video_name)
403
- if len(missing_videos) > 0:
404
- raise ParametersFileError(
405
- msg=f"SimBA could not find {len(missing_videos)} video(s) in the video_info.csv file. Make sure all videos analyzed are represented in the project_folder/logs/video_info.csv file. MISSING VIDEOS: {missing_videos}"
406
- )
407
-
408
-
409
- def check_if_dir_exists(in_dir: Union[str, os.PathLike],
410
- source: Optional[str] = None,
411
- create_if_not_exist: Optional[bool] = False,
412
- raise_error: bool = True) -> Union[None, bool]:
413
- """
414
- Check if a directory path exists.
415
-
416
- :param Union[str, os.PathLike] in_dir: Putative directory path.
417
- :param Optional[str] source: String source for interpretable error messaging.
418
- :param Optional[bool] create_if_not_exist: If directory does not exist, then create it. Default False.
419
- :param Optional[bool] raise_error: If True, raise error if dir does not exist. If False return None. Default True.
420
- :raise NotDirectoryError: The directory does not exist.
421
- """
422
-
423
- if not isinstance(in_dir, (str, Path, os.PathLike)):
424
- if raise_error:
425
- raise NotDirectoryError(msg=f"{in_dir} is not a valid directory", source=check_if_dir_exists.__name__)
426
- else:
427
- return False
428
-
429
- elif not os.path.isdir(in_dir):
430
- if create_if_not_exist:
431
- try:
432
- os.makedirs(in_dir)
433
- except:
434
- pass
435
- else:
436
- if source is None:
437
- if raise_error:
438
- raise NotDirectoryError(msg=f"{in_dir} is not a valid directory", source=check_if_dir_exists.__name__)
439
- else:
440
- return False
441
- else:
442
- if raise_error:
443
- raise NotDirectoryError(msg=f"{in_dir} is not a valid directory", source=source)
444
- else:
445
- return False
446
- else:
447
- return True
448
-
449
-
450
- def check_that_column_exist(df: pd.DataFrame,
451
- column_name: Union[str, os.PathLike, List[str]],
452
- file_name: str,
453
- raise_error: bool = True) -> Union[None, bool]:
454
- """
455
- Check if single named field or a list of fields exist within a dataframe.
456
-
457
- .. seealso::
458
- Consider :func:`simba.utils.checks.check_valid_dataframe` instead.
459
-
460
- :param pd.DataFrame df: The DataFrame to check for column existence.
461
- :param Union[str, os.PathLike, List[str]] column_name: Name or names of field(s) to check for existence.
462
- :param str file_name: Path of ``df`` on disk (used for error messages).
463
- :param bool raise_error: If True, raises ColumnNotFoundError if column doesn't exist. If False, returns bool. Default: True.
464
- :return: True if all columns exist, False if any column is missing (when raise_error=False), None if raise_error=True and all columns exist.
465
- :rtype: Union[None, bool]
466
- :raises ColumnNotFoundError: The ``column_name`` does not exist within ``df``.
467
-
468
- :example:
469
- >>> df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]})
470
- >>> check_that_column_exist(df=df, column_name='A', file_name='test.csv')
471
- True
472
- >>> check_that_column_exist(df=df, column_name=['A', 'B'], file_name='test.csv')
473
- True
474
- >>> check_that_column_exist(df=df, column_name='C', file_name='test.csv', raise_error=False)
475
- False
476
- """
477
-
478
- if type(column_name) == str:
479
- column_name = [column_name]
480
- for column in column_name:
481
- if column not in df.columns:
482
- if raise_error:
483
- raise ColumnNotFoundError(column_name=column, file_name=file_name, source=check_that_column_exist.__name__)
484
- else:
485
- return False
486
- return True
487
-
488
-
489
- def check_if_valid_input(
490
- name: str, input: str, options: List[str], raise_error: bool = True
491
- ) -> (bool, str):
492
- """
493
- Check if string variable is valid option.
494
-
495
- .. seealso::
496
- Consider :func:`simba.utils.checks.check_str`.
497
-
498
- :param str name: Atrbitrary name of variable.
499
- :param Any input: Value of variable.
500
- :param List[str] options: Allowed options of ``input``
501
- :param Optional[bool] raise_error: If True, then raise error if invalid value. Default: True.
502
-
503
- :return bool: False if invalid. True if valid.
504
- :return str: If invalid, then error msg. Else, empty str.
505
-
506
- :example:
507
- >>> check_if_valid_input(name='split_eval', input='gini', options=['entropy', 'gini'])
508
- >>> (True, '')
509
- """
510
-
511
- msg = ""
512
- if input not in options:
513
- msg = f"{name} is set to {str(input)}, which is an invalid setting. OPTIONS {options}"
514
- if raise_error:
515
- raise InvalidInputError(msg=msg, source=check_if_valid_input.__name__)
516
- else:
517
- return False, msg
518
- else:
519
- return True, msg
520
-
521
-
522
- def check_minimum_roll_windows(
523
- roll_windows_values: List[int], minimum_fps: float
524
- ) -> List[int]:
525
- """
526
- Remove any rolling temporal window that are shorter than a single frame in
527
- any of the videos within the project.
528
-
529
- :param List[int] roll_windows_values: Rolling temporal windows represented as frame counts. E.g., [10, 15, 30, 60]
530
- :param float minimum_fps: The lowest fps of the videos that are to be analyzed. E.g., 10.
531
-
532
- :return List[int]: roll_windows_values without impassable windows.
533
- """
534
-
535
- for win in range(len(roll_windows_values)):
536
- if minimum_fps < roll_windows_values[win]:
537
- roll_windows_values[win] = minimum_fps
538
- else:
539
- pass
540
- roll_windows_values = list(set(roll_windows_values))
541
- return roll_windows_values
542
-
543
-
544
- def check_same_number_of_rows_in_dfs(dfs: List[pd.DataFrame]) -> bool:
545
- """
546
- Helper to check that each dataframe in list contains an equal number of rows
547
-
548
- :param List[pd.DataFrame] dfs: List of dataframes.
549
- :return bool: True if dataframes has an equal number of rows. Else False.
550
-
551
- >>> df_1, df_2 = pd.DataFrame([[1, 2], [1, 2]]), pd.DataFrame([[4, 2], [9, 3], [1, 5]])
552
- >>> check_same_number_of_rows_in_dfs(dfs=[df_1, df_2])
553
- >>> False
554
- >>> df_1, df_2 = pd.DataFrame([[1, 2], [1, 2]]), pd.DataFrame([[4, 2], [9, 3]])
555
- >>> True
556
- """
557
-
558
- row_cnt = None
559
- for df_cnt, df in enumerate(dfs):
560
- if df_cnt == 0:
561
- row_cnt = len(df)
562
- else:
563
- if len(df) != row_cnt:
564
- return False
565
- return True
566
-
567
-
568
- def check_if_headers_in_dfs_are_unique(dfs: List[pd.DataFrame]) -> List[str]:
569
- """
570
- Helper to check heaaders in multiple dataframes are unique.
571
-
572
- :param List[pd.DataFrame] dfs: List of dataframes.
573
- :return List[str]: List of columns headers seen in multiple dataframes. Empty if None.
574
-
575
- :examples:
576
- >>> df_1, df_2 = pd.DataFrame([[1, 2]], columns=['My_column_1', 'My_column_2']), pd.DataFrame([[4, 2]], columns=['My_column_3', 'My_column_1'])
577
- >>> check_if_headers_in_dfs_are_unique(dfs=[df_1, df_2])
578
- >>> ['My_column_1']
579
- """
580
- seen_headers = []
581
- for df_cnt, df in enumerate(dfs):
582
- seen_headers.extend(list(df.columns))
583
- duplicates = list(set([x for x in seen_headers if seen_headers.count(x) > 1]))
584
- return duplicates
585
-
586
-
587
- def check_if_string_value_is_valid_video_timestamp(value: str, name: str) -> None:
588
- """
589
- Helper to check if a string is in a valid HH:MM:SS format
590
-
591
- :param str value: Timestamp in HH:MM:SS format.
592
- :param str name: An arbitrary string name of the timestamp.
593
- :raises InvalidInputError: If the timestamp is in invalid format
594
-
595
- :example:
596
- >>> check_if_string_value_is_valid_video_timestamp(value='00:0b:10', name='My time stamp')
597
- >>> "InvalidInputError: My time stamp is should be in the format XX:XX:XX where X is an integer between 0-9"
598
- >>> check_if_string_value_is_valid_video_timestamp(value='00:00:10', name='My time stamp'
599
- """
600
- r = re.compile(r"^\d{2}:\d{2}:\d{2}(\.\d+)?$")
601
- if not r.match(value):
602
- raise InvalidInputError(
603
- msg=f"{name} should be in the format XX:XX:XX:XXXX or XX:XX:XX where X is an integer between 0-9. Got: {value}",
604
- source=check_if_string_value_is_valid_video_timestamp.__name__,
605
- )
606
- else:
607
- pass
608
-
609
-
610
- def check_that_hhmmss_start_is_before_end(
611
- start_time: str, end_time: str, name: str
612
- ) -> None:
613
- """
614
- Helper to check that a start time in HH:MM:SS or HH:MM:SS:MS format is before an end time in HH:MM:SS or HH:MM:SS:MS format
615
-
616
- :param str start_time: Period start time in HH:MM:SS format.
617
- :param str end_time: Period end time in HH:MM:SS format.
618
- :param int name: Name of the variable
619
- :raises InvalidInputError: If end time is before the start time.
620
-
621
- :example:
622
- >>> check_that_hhmmss_start_is_before_end(start_time='00:00:05', end_time='00:00:01', name='My time period')
623
- >>> "InvalidInputError: My time period has an end-time which is before the start-time"
624
- >>> check_that_hhmmss_start_is_before_end(start_time='00:00:01', end_time='00:00:05')
625
- """
626
-
627
- if len(start_time.split(":")) != 3:
628
- raise InvalidInputError(
629
- f"Invalid time-stamp: ({start_time}). HH:MM:SS or HH:MM:SS.MS format required"
630
- )
631
- elif len(end_time.split(":")) != 3:
632
- raise InvalidInputError(
633
- f"Invalid time-stamp: ({end_time}). HH:MM:SS or HH:MM:SS.MS format required"
634
- )
635
- start_h, start_m, start_s = start_time.split(":")
636
- end_h, end_m, end_s = end_time.split(":")
637
- start_val = int(start_h) * 3600 + int(start_m) * 60 + float(start_s)
638
- end_val = int(end_h) * 3600 + int(end_m) * 60 + float(end_s)
639
- if end_val < start_val:
640
- raise InvalidInputError(
641
- f"{name} has an end-time which is before the start-time.",
642
- source=check_that_hhmmss_start_is_before_end.__name__,
643
- )
644
-
645
-
646
- def check_nvidea_gpu_available(raise_error: bool = False) -> bool:
647
- """
648
- Helper to check of NVIDEA GPU is available via ``nvidia-smi``.
649
- returns bool: True if nvidia-smi returns not None. Else False.
650
- """
651
- try:
652
- subprocess.check_output("nvidia-smi")
653
- return True
654
- except Exception:
655
- if raise_error:
656
- raise SimBAGPUError(msg='No NVIDIA GPU detected on machine (checked by calling "nvidia-smi")', source=check_nvidea_gpu_available.__name__)
657
- return False
658
-
659
-
660
- def check_ffmpeg_available(raise_error: Optional[bool] = False) -> Union[bool, None]:
661
- """
662
- Helper to check of FFMpeg is available via subprocess ``ffmpeg``.
663
-
664
- .. seealso::
665
- To check which encoders are available in FFMpeg installation, see :func:`simba.utils.lookups.get_ffmpeg_encoders`
666
-
667
- :param Optional[bool] raise_error: If True, raises ``FFMPEGNotFoundError`` if FFmpeg can't be found. Else return False. Default False.
668
- :returns bool: True if ``ffmpeg`` returns not None and raise_error is False. Else False.
669
- """
670
-
671
- try:
672
- subprocess.call("ffmpeg", stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL)
673
- return True
674
- except Exception:
675
- if raise_error:
676
- raise FFMPEGNotFoundError(
677
- msg="FFMpeg could not be found on the instance (as evaluated via subprocess ffmpeg). Please make sure FFMpeg is installed."
678
- )
679
- else:
680
- return False
681
-
682
- def check_if_valid_rgb_str(
683
- input: str,
684
- delimiter: str = ",",
685
- return_cleaned_rgb_tuple: bool = True,
686
- reverse_returned: bool = True,
687
- ):
688
- """
689
- Helper to check if a string is a valid representation of an RGB color.
690
-
691
- :param str input: Value to check as string. E.g., '(166, 29, 12)' or '22,32,999'
692
- :param str delimiter: The delimiter between subsequent values in the rgb input string.
693
- :param bool return_cleaned_rgb_tuple: If True, and input is a valid rgb, then returns a "clean" rgb tuple: Eg. '166, 29, 12' -> (166, 29, 12). Else, returns None.
694
- :param bool reverse_returned: If True and return_cleaned_rgb_tuple is True, reverses to returned cleaned rgb tuple (e.g., RGB becomes BGR) before returning it.
695
-
696
- :example:
697
- >>> check_if_valid_rgb_str(input='(50, 25, 100)', return_cleaned_rgb_tuple=True, reverse_returned=True)
698
- >>> (100, 25, 50)
699
- """
700
-
701
- input = input.replace(" ", "")
702
- if input.count(delimiter) != 2:
703
- raise InvalidInputError(msg=f"{input} in not a valid RGB color")
704
- values = input.split(",")
705
- rgb = []
706
- for value in values:
707
- val = "".join(c for c in value if c.isdigit())
708
- check_int(
709
- name="RGB value", value=val, max_value=255, min_value=0, raise_error=True
710
- )
711
- rgb.append(val)
712
- rgb = tuple([int(x) for x in rgb])
713
-
714
- if return_cleaned_rgb_tuple:
715
- if reverse_returned:
716
- rgb = rgb[::-1]
717
- return rgb
718
-
719
-
720
- def check_if_valid_rgb_tuple(data: Tuple[int, int, int],
721
- raise_error: bool = True,
722
- source: Optional[str] = None) -> bool:
723
- check_instance(source=check_if_valid_rgb_tuple.__name__, instance=data, accepted_types=tuple, raise_error=raise_error)
724
- check_iterable_length(source=check_if_valid_rgb_tuple.__name__, val=len(data), exact_accepted_length=3, raise_error=raise_error)
725
- for i in range(len(data)):
726
- if source is None:
727
- check_int(name="RGB value", value=data[i], max_value=255, min_value=0, raise_error=raise_error)
728
- else:
729
- check_int(name=f"RGB value {source}", value=data[i], max_value=255, min_value=0, raise_error=raise_error)
730
- return True
731
-
732
-
733
- def check_if_list_contains_values(
734
- data: List[Union[float, int, str]],
735
- values: List[Union[float, int, str]],
736
- name: str,
737
- raise_error: bool = True,
738
- ) -> None:
739
- """
740
- Helper to check if values are represeted in a list. E.g., make sure annotatations of behvaior absent and present are represented in annitation column
741
-
742
- :param List[Union[float, int, str]] data: List of values. E.g., annotation column represented as list.
743
- :param List[Union[float, int, str]] values: Values to conform present. E.g., [0, 1].
744
- :param str name: Arbitrary name of the data for more useful error msg.
745
- :param bool raise_error: If True, raise error of not all values can be found in data. Else, print warning.
746
-
747
- :example:
748
- >>> check_if_list_contains_values(data=[1,2, 3, 4, 0], values=[0, 1, 6], name='My_data')
749
- """
750
-
751
- data, missing_values = list(set(data)), []
752
- for value in values:
753
- if value not in data:
754
- missing_values.append(value)
755
-
756
- if len(missing_values) > 0 and raise_error:
757
- raise NoDataError(
758
- msg=f"{name} does not contain the following expected values: {missing_values}",
759
- source=check_if_list_contains_values.__name__,
760
- )
761
-
762
- elif len(missing_values) > 0 and not raise_error:
763
- NoDataFoundWarning(
764
- msg=f"{name} does not contain the following expected values: {missing_values}",
765
- source=check_if_list_contains_values.__name__,
766
- )
767
-
768
-
769
- def check_valid_hex_color(color_hex: str, raise_error: Optional[bool] = True) -> bool:
770
- """
771
- Check if given string represents a valid hexadecimal color code.
772
-
773
- :param str color_hex: A string representing a hexadecimal color code, either in the format '#RRGGBB' or '#RGB'.
774
- :param bool raise_error: If True, raise an exception when the color_hex is invalid; if False, return False instead. Default is True.
775
- :return bool: True if the color_hex is a valid hexadecimal color code; False otherwise (if raise_error is False).
776
- :raises IntegerError: If the color_hex is an invalid hexadecimal color code and raise_error is True.
777
- """
778
-
779
- hex_regex = re.compile(r"^#([0-9a-fA-F]{6}|[0-9a-fA-F]{3})$")
780
- match = hex_regex.match(color_hex)
781
- if match is None and raise_error:
782
- raise IntegerError(
783
- msg=f"{color_hex} is an invalid hex color",
784
- source=check_valid_hex_color.__name__,
785
- )
786
- elif match is None and not raise_error:
787
- return False
788
- else:
789
- return True
790
-
791
- def check_valid_url(url: str) -> bool:
792
- """ Helper to check if a string is a valid url"""
793
- regex = re.compile(
794
- r'^(https?|ftp)://' # protocol
795
- r'(\S+(:\S*)?@)?' # user:password (optional)
796
- r'((\d{1,3}\.){3}\d{1,3}|' # IP address
797
- r'([a-zA-Z0-9.-]+\.[a-zA-Z]{2,}))' # domain name
798
- r'(:\d+)?' # port (optional)
799
- r'(/[\S]*)?$', # path (optional)
800
- re.IGNORECASE)
801
- return re.match(regex, url) is not None
802
-
803
-
804
- def check_if_2d_array_has_min_unique_values(data: np.ndarray, min: int) -> bool:
805
- """
806
- Check if a 2D NumPy array has at least a minimum number of unique rows.
807
-
808
- For example, use when creating shapely Polygons or Linestrings, which typically requires at least 2 or three unique
809
- body-part coordinates.
810
-
811
- :param np.ndarray data: Input 2D array to be checked.
812
- :param np.ndarray min: Minimum number of unique rows required.
813
- :return bool: True if the input array has at least the specified minimum number of unique rows, False otherwise.
814
-
815
- :example:
816
- >>> data = np.array([[0, 0], [0, 0], [0, 0], [0, 1]])
817
- >>> check_if_2d_array_has_min_unique_values(data=data, min=2)
818
- >>> True
819
- """
820
-
821
- if len(data.shape) != 2:
822
- raise CountError(
823
- msg=f"Requires input array of two dimensions, found {data.size}",
824
- source=check_if_2d_array_has_min_unique_values.__name__,
825
- )
826
- sliced_data = np.unique(data, axis=0)
827
- if sliced_data.shape[0] < min:
828
- return False
829
- else:
830
- return True
831
-
832
-
833
- def check_if_module_has_import(parsed_file: ast.Module, import_name: str) -> bool:
834
- """
835
- Check if a Python module has a specific import statement. For example, check if module imports `argparse` or circular statistics mixin.
836
-
837
- Used for e.g., user custom feature extraction classes in ``simba.utils.custom_feature_extractor.CustomFeatureExtractor``.
838
-
839
- :parameter ast.Module file_path: The abstract syntax tree (AST) of the Python module.
840
- :parameter str import_name: The name of the module or package to check for in the import statements.
841
- :parameter bool: True if the specified import is found in the module, False otherwise.
842
-
843
- :example:
844
- >>> parsed_file = ast.parse(Path('/simba/misc/piotr.py').read_text())
845
- >>> check_if_module_has_import(parsed_file=parsed_file, import_name='argparse')
846
- >>> True
847
- """
848
- imports = [
849
- n for n in parsed_file.body if isinstance(n, (ast.Import, ast.ImportFrom))
850
- ]
851
- for i in imports:
852
- for name in i.names:
853
- if name.name == import_name:
854
- return True
855
- return False
856
-
857
-
858
- def check_valid_extension(
859
- path: Union[str, os.PathLike], accepted_extensions: Union[List[str], str]
860
- ):
861
- """
862
- Checks if the file extension of the provided path is in the list of accepted extensions.
863
-
864
- :param Union[str, os.PathLike] file_path: The path to the file whose extension needs to be checked.
865
- :param List[str] accepted_extensions: A list of accepted file extensions. E.g., ['pickle', 'csv'].
866
- """
867
- if isinstance(accepted_extensions, (list, tuple)):
868
- check_valid_lst(data=accepted_extensions, source=f"{check_valid_extension.__name__} accepted_extensions", valid_dtypes=(str,), min_len=1)
869
- elif isinstance(accepted_extensions, str):
870
- check_str(name=f"{check_valid_extension.__name__} accepted_extensions", value=accepted_extensions)
871
- accepted_extensions = [accepted_extensions]
872
- accepted_extensions = [x.lower() for x in accepted_extensions]
873
- check_file_exist_and_readable(file_path=path)
874
- extension = get_fn_ext(filepath=path)[2][1:]
875
- if extension.lower() not in accepted_extensions:
876
- raise InvalidFilepathError(msg=f"File extension for file {path} has an invalid extension. Found {extension}, accepted: {accepted_extensions}", source=check_valid_extension.__name__)
877
-
878
-
879
- def check_if_valid_img(data: np.ndarray,
880
- source: str = "",
881
- raise_error: bool = True,
882
- greyscale: bool = False,
883
- color: bool = False) -> Union[bool, None]:
884
- """
885
- Check if a variable is a valid image.
886
-
887
- :param str source: Name of the variable and/or class origin for informative error messaging and logging.
888
- :param np.ndarray data: Data variable to check if a valid image representation.
889
- :param bool greyscale: Checks that the image is greyscale. Default False.
890
- :param bool color: Checks that the image is color. Default False.
891
- :parameter bool raise_error: If True, raise InvalidInputError if invalid image representation. Else, return bool.
892
- """
893
-
894
- check_instance(source=check_if_valid_img.__name__, instance=data, accepted_types=(np.ndarray, cp.ndarray))
895
- if (data.ndim != 2) and (data.ndim != 3):
896
- if raise_error:
897
- raise InvalidInputError(msg=f"The {source} data is not a valid image. It has {data.ndim} dimensions", source=check_if_valid_img.__name__)
898
- else:
899
- return False
900
- if data.dtype not in [np.uint8, np.uint16, np.float32, np.float64]:
901
- if raise_error:
902
- raise InvalidInputError(msg=f"The {source} data is not a valid image. It is dtype {data.dtype}", source=check_if_valid_img.__name__)
903
- else:
904
- return False
905
- if np.max(data) > 255:
906
- if raise_error:
907
- raise InvalidInputError(msg=f"The {source} data is not a valid image. Values found that are above 255: {np.max(data)}", source=check_if_valid_img.__name__)
908
- if greyscale:
909
- if (data.ndim != 2):
910
- if raise_error:
911
- raise InvalidInputError(msg=f"The {source} image is not a greyscale image. Got {data.ndim} dimensions", source=check_if_valid_img.__name__)
912
- else:
913
- return False
914
- if color:
915
- if (data.ndim != 3):
916
- if raise_error:
917
- raise InvalidInputError(msg=f"The {source} image is not a color image. Got {data.ndim} dimensions", source=check_if_valid_img.__name__)
918
- else:
919
- return False
920
-
921
-
922
-
923
- return True
924
-
925
-
926
- def check_that_dir_has_list_of_filenames(
927
- dir: Union[str, os.PathLike],
928
- file_name_lst: List[str],
929
- file_type: Optional[str] = "csv",
930
- ):
931
- """
932
- Check that all file names in a list has an equivalent file in a specified directory. E.g., check if all files in the outlier corrected folder has an equivalent file in the featurues_extracted directory.
933
-
934
- :example:
935
- >>> file_name_lst = glob.glob('/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/csv/outlier_corrected_movement' + '/*.csv')
936
- >>> check_that_dir_has_list_of_filenames(dir = '/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/csv/features_extracted', file_name_lst=file_name_lst)
937
- """
938
-
939
- files_in_dir = glob.glob(dir + f"/*.{file_type}")
940
- files_in_dir = [os.path.basename(x) for x in files_in_dir]
941
- for file_name in file_name_lst:
942
- if os.path.basename(file_name) not in files_in_dir:
943
- raise NoFilesFoundError(msg=f"File name {os.path.basename(file_name)} could not be found in the directory {dir}", source=check_that_dir_has_list_of_filenames.__name__)
944
-
945
-
946
- def check_valid_array(data: np.ndarray,
947
- source: Optional[str] = "",
948
- accepted_ndims: Optional[Union[Tuple[int], Any]] = None,
949
- accepted_sizes: Optional[List[int]] = None,
950
- accepted_axis_0_shape: Optional[Union[List[int], Tuple[int]]] = None,
951
- accepted_axis_1_shape: Optional[Union[List[int], Tuple[int]]] = None,
952
- accepted_dtypes: Optional[Union[List[Union[str, Type]], Tuple[Union[str, Type]], Iterable[Any]]] = None,
953
- accepted_values: Optional[List[Any]] = None,
954
- accepted_shapes: Optional[List[Tuple[int]]] = None,
955
- min_axis_0: Optional[int] = None,
956
- max_axis_1: Optional[int] = None,
957
- min_axis_1: Optional[int] = None,
958
- min_value: Optional[Union[float, int]] = None,
959
- max_value: Optional[Union[float, int]] = None,
960
- raise_error: bool = True) -> Union[None, bool]:
961
- """
962
- Check if the given array satisfies specified criteria regarding its dimensions, shape, and data type.
963
-
964
- :param np.ndarray data: The numpy array to be checked.
965
- :param Optional[str] source: A string identifying the source, name, or purpose of the array for interpretable error messaging.
966
- :param Optional[Union[Tuple[int], Any]] accepted_ndims: List of tuples representing acceptable dimensions. If provided, checks whether the array's number of dimensions matches any tuple in the list.
967
- :param Optional[List[int]] accepted_sizes: List of acceptable sizes for the array's shape. If provided, checks whether the length of the array's shape matches any value in the list.
968
- :param Optional[Union[List[int], Tuple[int]]] accepted_axis_0_shape: List of accepted number of rows of 2-dimensional array. Will also raise error if value passed and input is not a 2-dimensional array.
969
- :param Optional[Union[List[int], Tuple[int]]] accepted_axis_1_shape: List of accepted number of columns or fields of 2-dimensional array. Will also raise error if value passed and input is not a 2-dimensional array.
970
- :param Optional[Union[List[Union[str, Type]], Tuple[Union[str, Type]], Iterable[Any]]] accepted_dtypes: List of acceptable data types for the array. If provided, checks whether the array's data type matches any string in the list.
971
- :param Optional[List[Any]] accepted_values: List of acceptable values that can be present in the array.
972
- :param Optional[List[Tuple[int]]] accepted_shapes: List of acceptable shapes for the array. If provided, checks whether the array's shape matches any tuple in the list.
973
- :param Optional[int] min_axis_0: Minimum number of rows required for the array.
974
- :param Optional[int] max_axis_1: Maximum number of columns allowed for the array.
975
- :param Optional[int] min_axis_1: Minimum number of columns required for the array.
976
- :param Optional[Union[float, int]] min_value: Minimum value allowed in the array.
977
- :param Optional[Union[float, int]] max_value: Maximum value allowed in the array.
978
- :param bool raise_error: If True, raises ArrayError if validation fails. If False, returns bool. Default: True.
979
- :return: True if array passes all validation checks, False if validation fails (when raise_error=False), None if raise_error=True and validation passes.
980
- :rtype: Union[None, bool]
981
-
982
- :example:
983
- >>> data = np.array([[1, 2], [3, 4]])
984
- >>> check_valid_array(data, source="Example", accepted_ndims=(2,), accepted_sizes=[2], accepted_dtypes=[np.int64])
985
- True
986
- >>> check_valid_array(data, source="Example", min_axis_0=3, raise_error=False)
987
- False
988
- """
989
-
990
- check_instance(source=source, instance=data, accepted_types=np.ndarray)
991
- if accepted_ndims is not None:
992
- if data.ndim not in accepted_ndims:
993
- if raise_error:
994
- raise ArrayError(msg=f"Array not of acceptable dimensions. Found {data.ndim}, accepted: {accepted_ndims}: {source}", source=check_valid_array.__name__)
995
- else:
996
- return False
997
- if accepted_sizes is not None:
998
- if len(data.shape) not in accepted_sizes:
999
- if raise_error:
1000
- raise ArrayError(msg=f"Array not of acceptable size. Found {len(data.shape)}, accepted: {accepted_sizes}: {source}", source=check_valid_array.__name__)
1001
- else:
1002
- return False
1003
- if accepted_dtypes is not None:
1004
- if data.dtype not in accepted_dtypes:
1005
- if raise_error:
1006
- raise ArrayError(msg=f"Array not of acceptable type. Found {data.dtype}, accepted: {accepted_dtypes}: {source}", source=check_valid_array.__name__)
1007
- else:
1008
- return False
1009
- if accepted_shapes is not None:
1010
- if data.shape not in accepted_shapes:
1011
- if raise_error:
1012
- raise ArrayError(msg=f"Array not of acceptable shape. Found {data.shape}, accepted: {accepted_shapes}: {source}", source=check_valid_array.__name__)
1013
- else:
1014
- return False
1015
- if accepted_axis_0_shape is not None:
1016
- if not isinstance(accepted_axis_0_shape, (tuple, list)):
1017
- raise InvalidInputError(msg=f"accepted_axis_0_shape is invalid format. Accepted: {'list, tuple'}. Got: {type(accepted_axis_0_shape)}, {source}", source=check_valid_array.__name__)
1018
- for cnt, i in enumerate(accepted_axis_0_shape):
1019
- check_int(name=f"{source} {cnt} accepted_axis_0_shape", value=i, min_value=1)
1020
- # if data.ndim != 2:
1021
- # raise ArrayError(
1022
- # msg=f"Array not of acceptable dimension. Found {data.ndim}, accepted: 2, {source}",
1023
- # source=check_valid_array.__name__,
1024
- # )
1025
- if data.shape[0] not in accepted_axis_0_shape:
1026
- if raise_error:
1027
- raise ArrayError(msg=f"Array not of acceptable shape. Found {data.shape[0]} rows, accepted: {accepted_axis_0_shape}, {source}", source=check_valid_array.__name__)
1028
- else:
1029
- return False
1030
- if accepted_axis_1_shape is not None:
1031
- if not isinstance(accepted_axis_1_shape, (tuple, list)):
1032
- raise InvalidInputError(msg=f"accepted_axis_1_shape is invalid format. Accepted: {'list, tuple'}. Got: {type(accepted_axis_1_shape)}, {source}", source=check_valid_array.__name__)
1033
- for cnt, i in enumerate(accepted_axis_1_shape):
1034
- check_int(name=f"{source} {cnt} accepted_axis_1_shape", value=i, min_value=1)
1035
- if data.ndim != 2:
1036
- raise ArrayError(msg=f"Array not of acceptable dimension. Found {data.ndim}, accepted: 2, {source}", source=check_valid_array.__name__,)
1037
- elif data.shape[1] not in accepted_axis_1_shape:
1038
- if raise_error:
1039
- raise ArrayError( msg=f"Array not of acceptable shape. Found {data.shape[0]} columns (axis=1), accepted: {accepted_axis_1_shape}, {source}", source=check_valid_array.__name__)
1040
- else:
1041
- return False
1042
- if min_axis_0 is not None:
1043
- check_int(name=f"{source} min_axis_0", value=min_axis_0)
1044
- if data.shape[0] < min_axis_0:
1045
- if raise_error:
1046
- raise ArrayError(msg=f"Array not of acceptable shape. Found {data.shape[0]} rows, minimum accepted: {min_axis_0}, {source}", source=check_valid_array.__name__)
1047
- else:
1048
- return False
1049
-
1050
- if max_axis_1 is not None and data.ndim > 1:
1051
- check_int(name=f"{source} max_axis_1", value=max_axis_1)
1052
- if data.shape[1] > max_axis_1:
1053
- if raise_error:
1054
- raise ArrayError(msg=f"Array not of acceptable shape. Found {data.shape[1]} columns, maximum columns accepted: {max_axis_1}, {source}", source=check_valid_array.__name__)
1055
- else:
1056
- return False
1057
- if min_axis_1 is not None and data.ndim > 1:
1058
- check_int(name=f"{source} min_axis_1", value=min_axis_1)
1059
- if data.shape[1] < min_axis_1:
1060
- if raise_error:
1061
- raise ArrayError(msg=f"Array not of acceptable shape. Found {data.shape[1]} columns, minimum columns accepted: {min_axis_1}, {source}", source=check_valid_array.__name__)
1062
- else:
1063
- return False
1064
- if accepted_values is not None:
1065
- check_valid_lst(data=accepted_values, source=f"{source} accepted_values")
1066
- additional_vals = list(set(np.unique(data)) - set(accepted_values))
1067
- if len(additional_vals) > 0:
1068
- if raise_error:
1069
- raise ArrayError(msg=f"Array contains unacceptable values. Found {additional_vals}, accepted: {accepted_values}, {source}", source=check_valid_array.__name__,)
1070
- return False
1071
-
1072
- if min_value is not None:
1073
- check_float(name=f'{source} min_value', value=min_value)
1074
- if np.min(data) < min_value:
1075
- if raise_error:
1076
- raise ArrayError(msg=f"Array contains value below accepted value. Found {np.min(data)}, accepted minimum: {min_value}, {source}", source=check_valid_array.__name__, )
1077
- else:
1078
- return False
1079
-
1080
- if max_value is not None:
1081
- check_float(name=f'{source} max_value', value=max_value)
1082
- if np.max(data) > max_value:
1083
- if raise_error:
1084
- raise ArrayError(msg=f"Array contains value above accepted maximum value. Found {np.max(data)}, accepted minimum: {max_value}, {source}", source=check_valid_array.__name__, )
1085
- else:
1086
- return False
1087
- else:
1088
- return True
1089
-
1090
- def check_valid_lst(data: list,
1091
- source: Optional[str] = "",
1092
- valid_dtypes: Optional[Union[Tuple[Any], List[Any], Any]] = None,
1093
- valid_values: Optional[List[Any]] = None,
1094
- min_len: Optional[int] = 1,
1095
- max_len: Optional[int] = None,
1096
- min_value: Optional[float] = None,
1097
- exact_len: Optional[int] = None,
1098
- raise_error: Optional[bool] = True) -> bool:
1099
- """
1100
- Check the validity of a list based on passed criteria.
1101
-
1102
- :param list data: The input list to be validated.
1103
- :param Optional[str] source: A string indicating the source or context of the data for informative error messaging.
1104
- :param Optional[Union[Tuple[Any], List[Any], Any]] valid_dtypes: A tuple, list, or single type of accepted data types. If provided, check if all elements in the list have data types in this collection.
1105
- :param Optional[List[Any]] valid_values: A list of accepted list values. If provided, check if all elements in the list have matching values in this list.
1106
- :param Optional[int] min_len: The minimum allowed length of the list. Default: 1.
1107
- :param Optional[int] max_len: The maximum allowed length of the list.
1108
- :param Optional[float] min_value: The minimum value allowed for numeric elements in the list.
1109
- :param Optional[int] exact_len: The exact length required for the list. If provided, overrides min_len and max_len.
1110
- :param Optional[bool] raise_error: If True, raise an InvalidInputError if any validation fails. If False, return False instead of raising an error. Default: True.
1111
- :return bool: True if all validation criteria are met, False otherwise.
1112
-
1113
- :example:
1114
- >>> check_valid_lst(data=[1, 2, 'three'], valid_dtypes=(int, str), min_len=2, max_len=5)
1115
- True
1116
- >>> check_valid_lst(data=[1, 2, 3], valid_dtypes=(int,), exact_len=3)
1117
- True
1118
- >>> check_valid_lst(data=[1, 2, 3], min_value=0, raise_error=False)
1119
- True
1120
- """
1121
- check_instance(source=source, instance=data, accepted_types=list)
1122
- if min_len is not None:
1123
- check_int(
1124
- name=f"{source} {min_len}",
1125
- value=min_len,
1126
- min_value=0,
1127
- raise_error=raise_error,
1128
- )
1129
- if len(data) < min_len:
1130
- if raise_error:
1131
- raise InvalidInputError(
1132
- msg=f"Invalid length of list. Found {len(data)}, minimum accepted: {min_len}",
1133
- source=source,
1134
- )
1135
- else:
1136
- return False
1137
-
1138
- check_instance(source=source, instance=data, accepted_types=list)
1139
- if valid_dtypes is not None:
1140
- for dtype in set([type(x) for x in data]):
1141
- if dtype not in valid_dtypes:
1142
- if raise_error:
1143
- raise InvalidInputError(msg=f"Invalid data type found in list. Found {dtype}, accepted: {valid_dtypes}", source=source)
1144
- else:
1145
- return False
1146
-
1147
- if max_len is not None:
1148
- check_int(
1149
- name=f"{source} {max_len}",
1150
- value=max_len,
1151
- min_value=0,
1152
- raise_error=raise_error,
1153
- )
1154
- if len(data) > max_len:
1155
- if raise_error:
1156
- raise InvalidInputError(
1157
- msg=f"Invalid length of list. Found {len(data)}, maximum accepted: {min_len}",
1158
- source=source,
1159
- )
1160
- else:
1161
- return False
1162
- if exact_len is not None:
1163
- check_int(
1164
- name=f"{source} {exact_len}",
1165
- value=exact_len,
1166
- min_value=0,
1167
- raise_error=raise_error,
1168
- )
1169
- if len(data) != exact_len:
1170
- if raise_error:
1171
- raise InvalidInputError(
1172
- msg=f"Invalid length of list. Found {len(data)}, accepted: {exact_len}",
1173
- source=source,
1174
- )
1175
- else:
1176
- return False
1177
-
1178
- if valid_values != None:
1179
- check_valid_lst(
1180
- data=valid_values, source=check_valid_lst.__name__, min_len=1
1181
- )
1182
- invalids = list(set(data) - set(valid_values))
1183
- if len(invalids):
1184
- if raise_error:
1185
- raise InvalidInputError(
1186
- msg=f"Invalid list entries. Found {invalids}, accepted: {valid_values}",
1187
- source=source,
1188
- )
1189
- else:
1190
- return False
1191
-
1192
- if min_value != None:
1193
- check_float(name=check_valid_lst.__name__, value=min_value)
1194
- invalids = [x for x in data if x < min_value]
1195
- if len(invalids) > 0:
1196
- if raise_error:
1197
- raise InvalidInputError(msg=f"Invalid list entries. Found {invalids}, minimum accepted value: {min_value}", source=source)
1198
- else:
1199
- return False
1200
-
1201
- return True
1202
-
1203
-
1204
- def check_if_keys_exist_in_dict(
1205
- data: dict,
1206
- key: Union[str, int, tuple, List],
1207
- name: Optional[str] = "",
1208
- raise_error: Optional[bool] = True,
1209
- ) -> bool:
1210
- """
1211
- Check if one or more keys exist in a dictionary.
1212
-
1213
- This function validates that all specified keys are present in the given dictionary.
1214
- It can check for a single key or multiple keys at once.
1215
-
1216
- .. seealso::
1217
- Consider :func:`simba.utils.checks.check_valid_dict`
1218
-
1219
- :param dict data: The dictionary to check for key existence.
1220
- :param Union[str, int, tuple, List] key: The key(s) to check for in the dictionary. Can be a single key or a list/tuple of keys.
1221
- :param Optional[str] name: A string identifying the source or context of the data for informative error messaging. Default: "".
1222
- :param Optional[bool] raise_error: If True, raises InvalidInputError if any key is missing. If False, returns False instead of raising an error. Default: True.
1223
- :return bool: True if all keys exist in the dictionary, False if any key is missing (when raise_error=False).
1224
- :raises InvalidInputError: If any of the specified keys do not exist in the dictionary and raise_error=True.
1225
-
1226
- :example:
1227
- >>> data = {'a': 1, 'b': 2, 'c': 3}
1228
- >>> check_if_keys_exist_in_dict(data=data, key='a')
1229
- True
1230
- >>> check_if_keys_exist_in_dict(data=data, key=['a', 'b'])
1231
- True
1232
- >>> check_if_keys_exist_in_dict(data=data, key='d', raise_error=False)
1233
- False
1234
- """
1235
-
1236
- check_instance(source=name, instance=data, accepted_types=(dict,))
1237
- check_instance(
1238
- source=name,
1239
- instance=key,
1240
- accepted_types=(
1241
- str,
1242
- int,
1243
- tuple,
1244
- List,
1245
- ),
1246
- )
1247
- if not isinstance(key, (list, tuple)):
1248
- key = [key]
1249
-
1250
- for k in key:
1251
- if k not in list(data.keys()):
1252
- if raise_error:
1253
- raise InvalidInputError(
1254
- msg=f"{k} does not exist in object {name}",
1255
- source=check_if_keys_exist_in_dict.__class__.__name__,
1256
- )
1257
- else:
1258
- return False
1259
- return True
1260
-
1261
-
1262
- def check_that_directory_is_empty(directory: Union[str, os.PathLike], raise_error: Optional[bool] = True) -> None:
1263
- """
1264
- Checks if a directory is empty. If the directory has content, then returns False or raises ``DirectoryNotEmptyError``.
1265
-
1266
- :param str directory: Directory to check.
1267
- :raises DirectoryNotEmptyError: If ``directory`` contains files.
1268
- """
1269
-
1270
- check_if_dir_exists(in_dir=directory)
1271
- try:
1272
- all_files_in_folder = [
1273
- f for f in next(os.walk(directory))[2] if not f[0] == "."
1274
- ]
1275
- except StopIteration:
1276
- return 0
1277
- else:
1278
- if len(all_files_in_folder) > 0:
1279
- if raise_error:
1280
- raise DirectoryNotEmptyError(
1281
- msg=f"The {directory} is not empty and contains {str(len(all_files_in_folder))} files. Use a directory that is empty.",
1282
- source=check_that_directory_is_empty.__name__,
1283
- )
1284
- else:
1285
- return False
1286
- else:
1287
- return True
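
A minimal usage sketch for check_that_directory_is_empty, assuming it is imported from simba.utils.checks; the temporary directory and file name below are illustrative only.

# Hypothetical usage of check_that_directory_is_empty (illustrative paths only).
import os
import tempfile

from simba.utils.checks import check_that_directory_is_empty

with tempfile.TemporaryDirectory() as empty_dir:
    # A freshly created temporary directory contains no files, so the check passes.
    check_that_directory_is_empty(directory=empty_dir)
    # After adding a file, raise_error=False makes the function return False
    # instead of raising DirectoryNotEmptyError.
    with open(os.path.join(empty_dir, 'example.txt'), 'w') as f:
        f.write('no longer empty')
    print(check_that_directory_is_empty(directory=empty_dir, raise_error=False))  # False
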
1288
-
1289
-
1290
- def check_umap_hyperparameters(hyper_parameters: Dict[str, Any]) -> None:
1291
- """
1292
- Checks if a dictionary of parameters (umap, scaling, etc.) is valid for grid-search UMAP dimensionality reduction.
1293
-
1294
- :param dict hyper_parameters: Dictionary holding UMAP hyperparameters.
1295
- :raises InvalidInputError: If any input is invalid
1296
-
1297
- :example:
1298
- >>> check_umap_hyperparameters(hyper_parameters={'n_neighbors': [2], 'min_distance': [0.1], 'spread': [1], 'scaler': 'MIN-MAX', 'variance': 0.2})
1299
- """
1300
- for key in UMAPParam.HYPERPARAMETERS.value:
1301
- if key not in hyper_parameters.keys():
1302
- raise InvalidInputError(
1303
- msg=f"Hyperparameter dictionary is missing {key} entry.",
1304
- source=check_umap_hyperparameters.__name__,
1305
- )
1306
- for key in [
1307
- UMAPParam.N_NEIGHBORS.value,
1308
- UMAPParam.MIN_DISTANCE.value,
1309
- UMAPParam.SPREAD.value,
1310
- ]:
1311
- if not isinstance(hyper_parameters[key], list):
1312
- raise InvalidInputError(
1313
- msg=f"Hyperparameter dictionary key {key} has to be a list but got {type(hyper_parameters[key])}.",
1314
- source=check_umap_hyperparameters.__name__,
1315
- )
1316
- if len(hyper_parameters[key]) == 0:
1317
- raise InvalidInputError(
1318
- msg=f"Hyperparameter dictionary key {key} has 0 entries.",
1319
- source=check_umap_hyperparameters.__name__,
1320
- )
1321
- for value in hyper_parameters[key]:
1322
- if not isinstance(value, (int, float)):
1323
- raise InvalidInputError(
1324
- msg=f"Hyperparameter dictionary key {key} have to have numeric entries but got {type(value)}.",
1325
- source=check_umap_hyperparameters.__name__,
1326
- )
1327
- if hyper_parameters[UMAPParam.SCALER.value] not in Options.SCALER_OPTIONS.value:
1328
- raise InvalidInputError(
1329
- msg=f"Scaler {hyper_parameters[UMAPParam.SCALER.value]} not supported. Opitions: {Options.SCALER_OPTIONS.value}",
1330
- source=check_umap_hyperparameters.__name__,
1331
- )
1332
- check_float(
1333
- "VARIANCE THRESHOLD",
1334
- value=hyper_parameters[UMAPParam.VARIANCE.value],
1335
- min_value=0.0,
1336
- max_value=100.0,
1337
- )
1338
- def check_video_has_rois(roi_dict: Dict[str, pd.DataFrame],
1339
- roi_names: List[str] = None,
1340
- video_names: List[str] = None,
1341
- source: str = 'roi dict',
1342
- raise_error: bool = True):
1343
- """
1344
- Check that specified videos all have user-defined ROIs with specified names.
1345
-
1346
- This function validates that all specified videos contain the required ROIs (Regions of Interest)
1347
- with the specified names. It checks across all ROI types: rectangles, circles, and polygons.
1348
-
1349
- .. note::
1350
- To get roi dictionary, see :func:`simba.mixins.config_reader.ConfigReader.read_roi_data`.
1351
-
1352
- :param Dict[str, pd.DataFrame] roi_dict: Dictionary containing ROI dataframes with keys for rectangles, circles, and polygons.
1353
- :param Optional[List[str]] roi_names: List of ROI names to check for. If None, uses all unique ROI names from the data. Default: None.
1354
- :param Optional[List[str]] video_names: List of video names to check. If None, uses all unique video names from the data. Default: None.
1355
- :param str source: A string identifying the source or context for informative error messaging. Default: 'roi dict'.
1356
- :param bool raise_error: If True, raises NoROIDataError if any videos are missing required ROIs. If False, returns tuple with validation result and missing ROIs. Default: True.
1357
- :return: If raise_error=True: None if all validations pass, raises exception if validation fails. If raise_error=False: Tuple of (bool, dict) where bool indicates success and dict contains missing ROIs by video.
1358
- :rtype: Union[None, Tuple[bool, Dict[str, List[str]]]]
1359
- :raises NoROIDataError: If any videos are missing required ROIs and raise_error=True.
1360
-
1361
- :example:
1362
- >>> roi_dict = {
1363
- ... 'rectangles': pd.DataFrame({'Video': ['video1'], 'Name': ['ROI1']}),
1364
- ... 'circles': pd.DataFrame({'Video': ['video1'], 'Name': ['ROI2']}),
1365
- ... 'polygons': pd.DataFrame({'Video': ['video1'], 'Name': ['ROI3']})
1366
- ... }
1367
- >>> check_video_has_rois(roi_dict=roi_dict, roi_names=['ROI1', 'ROI2'], video_names=['video1'])
1368
- True
1369
- >>> check_video_has_rois(roi_dict=roi_dict, roi_names=['ROI1', 'ROI4'], video_names=['video1'], raise_error=False)
1370
- (False, {'video1': ['ROI4']})
1371
- """
1372
-
1373
- check_valid_dict(x=roi_dict, valid_key_dtypes=(str,), valid_values_dtypes=(pd.DataFrame,), required_keys=(Keys.ROI_RECTANGLES.value, Keys.ROI_CIRCLES.value, Keys.ROI_POLYGONS.value,),)
1374
- check_valid_dataframe(df=roi_dict[Keys.ROI_RECTANGLES.value], source=f'{check_video_has_rois.__name__} {source} roi_dict {Keys.ROI_RECTANGLES.value}', required_fields=['Video', 'Name'])
1375
- check_valid_dataframe(df=roi_dict[Keys.ROI_CIRCLES.value], source=f'{check_video_has_rois.__name__} {source} roi_dict {Keys.ROI_CIRCLES.value}', required_fields=['Video', 'Name'])
1376
- check_valid_dataframe(df=roi_dict[Keys.ROI_POLYGONS.value], source=f'{check_video_has_rois.__name__} {source} roi_dict {Keys.ROI_POLYGONS.value}', required_fields=['Video', 'Name'])
1377
- if roi_names is not None:
1378
- check_valid_lst(data=roi_names, source=f'{check_video_has_rois.__name__} {source} roi_names', valid_dtypes=(str,), min_len=1)
1379
- else:
1380
- roi_names = list(set(list(roi_dict[Keys.ROI_RECTANGLES.value]['Name'].unique()) + list(roi_dict[Keys.ROI_CIRCLES.value]['Name'].unique()) + list(roi_dict[Keys.ROI_POLYGONS.value]['Name'].unique())))
1381
- if video_names is not None:
1382
- check_valid_lst(data=video_names, source=f'{check_video_has_rois.__name__} {source} video_names', min_len=1,)
1383
- else:
1384
- video_names = list(set(list(roi_dict[Keys.ROI_RECTANGLES.value]['Video'].unique()) + list(roi_dict[Keys.ROI_CIRCLES.value]['Video'].unique()) + list(roi_dict[Keys.ROI_POLYGONS.value]['Video'].unique())))
1385
- missing_rois = {}
1386
- rois_missing = False
1387
- for video_name in video_names:
1388
- missing_rois[video_name] = []
1389
- for roi_name in roi_names:
1390
- rect_filt = roi_dict[Keys.ROI_RECTANGLES.value][(roi_dict[Keys.ROI_RECTANGLES.value]['Video'] == video_name) & (roi_dict[Keys.ROI_RECTANGLES.value]['Name'] == roi_name)]
1391
- circ_filt = roi_dict[Keys.ROI_CIRCLES.value][(roi_dict[Keys.ROI_CIRCLES.value]['Video'] == video_name) & (roi_dict[Keys.ROI_CIRCLES.value]['Name'] == roi_name)]
1392
- poly_filt = roi_dict[Keys.ROI_POLYGONS.value][(roi_dict[Keys.ROI_POLYGONS.value]['Video'] == video_name) & (roi_dict[Keys.ROI_POLYGONS.value]['Name'] == roi_name)]
1393
- if (len(rect_filt) + len(circ_filt) + len(poly_filt)) == 0:
1394
- missing_rois[video_name].append(roi_name); rois_missing = True
1395
- if rois_missing and raise_error:
1396
- raise NoROIDataError(msg=f'Some videos are missing some ROIs: {missing_rois}', source=f'{check_video_has_rois.__name__} {source}')
1397
- elif rois_missing:
1398
- return False, missing_rois
1399
- else:
1400
- return True
1401
-
1402
-
1403
- def check_if_df_field_is_boolean(df: pd.DataFrame,
1404
- field: str,
1405
- raise_error: bool = True,
1406
- bool_values: Optional[Tuple[Any]] = (0, 1),
1407
- df_name: Optional[str] = ''):
1408
- """
1409
- Check if a DataFrame field contains only boolean values.
1410
-
1411
- This function validates that a specified column in a DataFrame contains only
1412
- the expected boolean values (e.g., 0/1, True/False). It checks for any
1413
- unexpected values that are not in the allowed boolean values set.
1414
-
1415
- :param pd.DataFrame df: The DataFrame to check.
1416
- :param str field: Name of the column to validate for boolean values.
1417
- :param bool raise_error: If True, raises CountError when non-boolean values are found. If False, returns False. Default: True.
1418
- :param Optional[Tuple[Any]] bool_values: Tuple of accepted boolean values. Default: (0, 1).
1419
- :param Optional[str] df_name: Name of the DataFrame for error messaging. Default: ''.
1420
- :return: True if field contains only boolean values, False if non-boolean values found and raise_error=False.
1421
- :rtype: bool
1422
- :raises CountError: If non-boolean values are found in the field and raise_error=True.
1423
-
1424
- :example:
1425
- >>> df = pd.DataFrame({'binary_col': [0, 1, 0, 1], 'mixed_col': [0, 1, 2, 0]})
1426
- >>> check_if_df_field_is_boolean(df=df, field='binary_col')
1427
- True
1428
- >>> check_if_df_field_is_boolean(df=df, field='mixed_col', raise_error=False)
1429
- False
1430
- >>> check_if_df_field_is_boolean(df=df, field='mixed_col', bool_values=(0, 1, 2))
1431
- True
1432
- """
1433
- check_instance(source=f'{check_if_df_field_is_boolean.__name__} df', instance=df, accepted_types=(pd.DataFrame,))
1434
- check_str(name=f"{check_if_df_field_is_boolean.__name__} field", value=field)
1435
- check_that_column_exist(df=df, column_name=field, file_name=check_if_df_field_is_boolean.__name__)
1436
- additional = list((set(list(df[field])) - set(bool_values)))
1437
- if len(additional) > 0:
1438
- if raise_error:
1439
- raise CountError(msg=f"Field {field} not a boolean in {df_name}. Found values {additional}. Accepted: {bool_values}", source=check_if_df_field_is_boolean.__name__)
1440
- else:
1441
- return False
1442
- return True
1443
-
1444
-
1445
- def check_valid_dataframe(
1446
- df: pd.DataFrame,
1447
- source: Optional[str] = "",
1448
- valid_dtypes: Optional[Tuple[Any]] = None,
1449
- required_fields: Optional[List[str]] = None,
1450
- min_axis_0: Optional[int] = None,
1451
- min_axis_1: Optional[int] = None,
1452
- max_axis_0: Optional[int] = None,
1453
- max_axis_1: Optional[int] = None,
1454
- allow_duplicate_col_names = True,
1455
- ):
1456
- """
1457
- Validate a DataFrame against various criteria.
1458
-
1459
- This function performs comprehensive validation of a pandas DataFrame including
1460
- data types, dimensions, required columns, and duplicate column names. It raises
1461
- exceptions for any validation failures.
1462
-
1463
- :param pd.DataFrame df: The DataFrame to validate.
1464
- :param Optional[str] source: Source identifier for error messages. Default: "".
1465
- :param Optional[Tuple[Any]] valid_dtypes: Tuple of allowed data types. If None, no dtype validation. Default: None.
1466
- :param Optional[List[str]] required_fields: List of required column names. If None, no field validation. Default: None.
1467
- :param Optional[int] min_axis_0: Minimum number of rows required. If None, no minimum row validation. Default: None.
1468
- :param Optional[int] min_axis_1: Minimum number of columns required. If None, no minimum column validation. Default: None.
1469
- :param Optional[int] max_axis_0: Maximum number of rows allowed. If None, no maximum row validation. Default: None.
1470
- :param Optional[int] max_axis_1: Maximum number of columns allowed. If None, no maximum column validation. Default: None.
1471
- :param bool allow_duplicate_col_names: If False, raises error for duplicate column names. Default: True.
1472
- :return: None if validation passes.
1473
- :rtype: None
1474
- :raises InvalidInputError: If any validation criteria are not met.
1475
-
1476
- :example:
1477
- >>> df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]})
1478
- >>> check_valid_dataframe(df=df, required_fields=['A', 'B'], min_axis_0=1)
1479
- >>> check_valid_dataframe(df=df, valid_dtypes=(int,), max_axis_1=2)
1480
- >>> check_valid_dataframe(df=df, allow_duplicate_col_names=False)
1481
- """
1482
- check_instance(source=source, instance=df, accepted_types=(pd.DataFrame,))
1483
- if valid_dtypes is not None:
1484
- dtypes = list(set(df.dtypes))
1485
- additional = [x for x in dtypes if x not in valid_dtypes]
1486
- if len(additional) > 0:
1487
- raise InvalidInputError(
1488
- msg=f"The dataframe {source} has invalid data format(s) {additional}. Valid: {valid_dtypes}",
1489
- source=source,
1490
- )
1491
- if min_axis_1 is not None:
1492
- check_int(name=f"{source} min_axis_1", value=min_axis_1, min_value=1)
1493
- if len(df.columns) < min_axis_1:
1494
- raise InvalidInputError(
1495
- msg=f"The dataframe {source} has less than ({df.columns}) the required minimum number of columns ({min_axis_1}).",
1496
- source=source,
1497
- )
1498
- if min_axis_0 is not None:
1499
- check_int(name=f"{source} min_axis_0", value=min_axis_0, min_value=1)
1500
- if len(df) < min_axis_0:
1501
- raise InvalidInputError(
1502
- msg=f"The dataframe {source} has less than ({len(df)}) the required minimum number of rows ({min_axis_0}).",
1503
- source=source,
1504
- )
1505
- if max_axis_0 is not None:
1506
- check_int(name=f"{source} max_axis_0", value=min_axis_0, min_value=1)
1507
- if len(df) > max_axis_0:
1508
- raise InvalidInputError(
1509
- msg=f"The dataframe {source} has more than ({len(df)}) the required maximum number of rows ({max_axis_0}).",
1510
- source=source,
1511
- )
1512
- if max_axis_1 is not None:
1513
- check_int(name=f"{source} max_axis_1", value=min_axis_1, min_value=1)
1514
- if len(df.columns) > max_axis_1:
1515
- raise InvalidInputError(
1516
- msg=f"The dataframe {source} has more than ({df.columns}) the required maximum number of columns ({max_axis_1}).",
1517
- source=source,
1518
- )
1519
- if required_fields is not None:
1520
- check_valid_lst(
1521
- data=required_fields,
1522
- source=check_valid_dataframe.__name__,
1523
- valid_dtypes=(str,),
1524
- )
1525
- missing = list(set(required_fields) - set(df.columns))
1526
- if len(missing) > 0:
1527
- raise InvalidInputError(
1528
- msg=f"The dataframe {source} are missing required columns {missing}.",
1529
- source=source,
1530
- )
1531
-
1532
- if not allow_duplicate_col_names:
1533
- col_names = list(df.columns)
1534
- seen = set()
1535
- duplicate_col_names = list(set(x for x in col_names if x in seen or seen.add(x)))
1536
- if len(duplicate_col_names) > 0:
1537
- raise InvalidInputError(msg=f"The dataframe {source} has duplicate column names {duplicate_col_names}.", source=source)
1538
-
1539
-
1540
-
1541
-
1542
- def check_valid_boolean(value: Union[Any, List[Any]], source: Optional[str] = '', raise_error: Optional[bool] = True):
1543
- """
1544
- Check if a value or list of values contains only valid boolean values.
1545
-
1546
- This function validates that the input value(s) are valid Python boolean values
1547
- (True or False). It can handle single values or lists of values, and provides
1548
- flexible error handling options.
1549
-
1550
- :param Union[Any, List[Any]] value: Single value or list of values to validate for boolean type.
1551
- :param Optional[str] source: Source identifier for error messages. Default: ''.
1552
- :param Optional[bool] raise_error: If True, raises InvalidInputError when non-boolean values are found. If False, returns False. Default: True.
1553
- :return: True if all values are valid booleans, False if any non-boolean values found and raise_error=False.
1554
- :rtype: bool
1555
- :raises InvalidInputError: If non-boolean values are found and raise_error=True.
1556
-
1557
- :example:
1558
- >>> check_valid_boolean(True)
1559
- True
1560
- >>> check_valid_boolean([True, False, True])
1561
- True
1562
- >>> check_valid_boolean([True, 1, False], raise_error=False)
1563
- False
1564
- >>> check_valid_boolean('not_bool', raise_error=False)
1565
- False
1566
- """
1567
- if not isinstance(value, list):
1568
- value = [value]
1569
- for val in value:
- if not isinstance(val, bool):
- if raise_error:
- raise InvalidInputError(msg=f'{val} is not a valid boolean', source=source)
- else:
- return False
- return True  # only reached after every value has passed the boolean check
1577
-
1578
- def check_valid_tuple(x: tuple,
1579
- source: Optional[str] = "",
1580
- accepted_lengths: Optional[Tuple[int]] = None,
1581
- valid_dtypes: Optional[Tuple[Any]] = None,
1582
- minimum_length: Optional[int] = None,
1583
- accepted_values: Optional[Iterable[Any]] = None,
1584
- min_integer: Optional[int] = None):
1585
- """
1586
- Validate a tuple against various criteria.
1587
-
1588
- This function performs comprehensive validation of a tuple including
1589
- length constraints, data types, minimum values, and accepted values.
1590
- It raises exceptions for any validation failures.
1591
-
1592
- :param tuple x: The tuple to validate.
1593
- :param Optional[str] source: Source identifier for error messages. Default: "".
1594
- :param Optional[Tuple[int]] accepted_lengths: Tuple of accepted lengths. If None, no length validation. Default: None.
1595
- :param Optional[Tuple[Any]] valid_dtypes: Tuple of allowed data types for tuple elements. If None, no dtype validation. Default: None.
1596
- :param Optional[int] minimum_length: Minimum length required. If None, no minimum length validation. Default: None.
1597
- :param Optional[Iterable[Any]] accepted_values: Iterable of accepted values for tuple elements. If None, no value validation. Default: None.
1598
- :param Optional[int] min_integer: Minimum value for integer elements. If None, no integer validation. Default: None.
1599
- :return: None if validation passes.
1600
- :rtype: None
1601
- :raises InvalidInputError: If any validation criteria are not met.
1602
-
1603
- :example:
1604
- >>> check_valid_tuple(x=(1, 2, 3), accepted_lengths=(2, 3), valid_dtypes=(int,))
1605
- >>> check_valid_tuple(x=('a', 'b'), minimum_length=2, accepted_values=['a', 'b', 'c'])
1606
- >>> check_valid_tuple(x=(5, 10, 15), min_integer=5)
1607
- """
1608
-
1609
- if not isinstance(x, (tuple)):
1610
- raise InvalidInputError(msg=f"{check_valid_tuple.__name__} {source} is not a valid tuple, got: {type(x)}", source=source,)
1611
- if accepted_lengths is not None:
1612
- if len(x) not in accepted_lengths:
1613
- raise InvalidInputError(
1614
- msg=f"Tuple is not of valid lengths. Found {len(x)}. Accepted: {accepted_lengths}",
1615
- source=source,
1616
- )
1617
- if valid_dtypes is not None:
1618
- dtypes = list(set([type(v) for v in x]))
1619
- additional = [x for x in dtypes if x not in valid_dtypes]
1620
- if len(additional) > 0:
1621
- raise InvalidInputError(msg=f"The tuple {source} has invalid data format(s) {additional}. Valid: {valid_dtypes}", source=source)
1622
-
1623
- if minimum_length is not None:
1624
- check_int(name=f'{check_valid_tuple.__name__} minimum_length', value=minimum_length, min_value=1)
1625
- tuple_len = len(x)
1626
- if tuple_len < minimum_length:
1627
- raise InvalidInputError(msg=f"The tuple {source} is shorter ({tuple_len}) than the minimum required length ({minimum_length}).", source=source)
1628
-
1629
- if accepted_values is not None:
1630
- check_instance(source=f'{check_valid_tuple.__name__} accepted_values', accepted_types=(list, tuple,), instance=accepted_values)
1631
- for i in x:
1632
- if i not in accepted_values:
1633
- raise InvalidInputError(msg=f"The tuple {source} has a value that is NOT accepted: {i}, (accepted: {accepted_values}).", source=source)
1634
-
1635
- if min_integer is not None:
1636
- check_int(name=f'{check_valid_tuple.__name__} min_integer', value=min_integer)
1637
- for i in x:
1638
- if isinstance(i, int):
1639
- if i < min_integer:
1640
- raise InvalidInputError(msg=f"The tuple {source} has an integer value below the minimum allowed integer value: {i}, (minimum: {min_integer}).", source=source)
1641
-
1642
-
1643
- def check_video_and_data_frm_count_align(video: Union[str, os.PathLike, cv2.VideoCapture],
1644
- data: Union[str, os.PathLike, pd.DataFrame],
1645
- name: Optional[str] = "",
1646
- raise_error: Optional[bool] = True) -> Union[None, bool]:
1647
- """
1648
- Check if the frame count of a video matches the row count of a data file.
1649
-
1650
- :param Union[str, os.PathLike, cv2.VideoCapture] video: Path to the video file or cv2.VideoCapture object.
1651
- :param Union[str, os.PathLike, pd.DataFrame] data: Path to the data file or DataFrame containing the data.
1652
- :param Optional[str] name: Name of the video (optional for interpretable error msgs).
1653
- :param Optional[bool] raise_error: Whether to raise an error if the counts don't align (default is True). If False, prints warning.
1654
- :return None:
1655
-
1656
- :example:
1657
- >>> data_1 = '/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/csv/outlier_corrected_movement_location/SI_DAY3_308_CD1_PRESENT.csv'
1658
- >>> video_1 = '/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/frames/output/ROI_analysis/SI_DAY3_308_CD1_PRESENT.mp4'
1659
- >>> check_video_and_data_frm_count_align(video=video_1, data=data_1, raise_error=True)
1660
- """
1661
-
1662
- def _count_generator(reader):
1663
- b = reader(1024 * 1024)
1664
- while b:
1665
- yield b
1666
- b = reader(1024 * 1024)
1667
-
1668
- check_instance(
1669
- source=f"{check_video_and_data_frm_count_align.__name__} video",
1670
- instance=video,
1671
- accepted_types=(str, cv2.VideoCapture),
1672
- )
1673
- check_instance(
1674
- source=f"{check_video_and_data_frm_count_align.__name__} data",
1675
- instance=data,
1676
- accepted_types=(str, pd.DataFrame),
1677
- )
1678
- check_str(
1679
- name=f"{check_video_and_data_frm_count_align.__name__} name",
1680
- value=name,
1681
- allow_blank=True,
1682
- )
1683
- if isinstance(video, str):
1684
- check_file_exist_and_readable(file_path=video)
1685
- video = cv2.VideoCapture(video)
1686
- video_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
1687
- if isinstance(data, str):
1688
- check_file_exist_and_readable(file_path=data)
1689
- with open(data, "rb") as fp:
1690
- c_generator = _count_generator(fp.raw.read)
1691
- data_count = (sum(buffer.count(b"\n") for buffer in c_generator)) - 1
1692
- else:
1693
- data_count = len(data)
1694
- if data_count != video_count:
1695
- if not raise_error:
1696
- FrameRangeWarning(msg=f"The video {name} has {video_count} frames, but the associated data file for this video has {data_count} rows", source=check_video_and_data_frm_count_align.__name__)
1697
- return False
1698
- else:
1699
- raise FrameRangeError(msg=f"The video {name} has {video_count} frames, but the associated data file for this video has {data_count} rows", source=check_video_and_data_frm_count_align.__name__,)
1700
- return True
1701
-
1702
- def check_if_video_corrupted(video: Union[str, os.PathLike, cv2.VideoCapture],
1703
- frame_interval: Optional[int] = None,
1704
- frame_n: Optional[int] = 20,
1705
- raise_error: Optional[bool] = True) -> None:
1706
-
1707
- """
1708
- Check if a video file is corrupted by inspecting a set of its frames.
1709
-
1710
- .. note::
1711
- For decent run-time regardless of video length, pass a smaller ``frame_n`` (<100).
1712
-
1713
- :param Union[str, os.PathLike, cv2.VideoCapture] video: Path to the video file or cv2.VideoCapture OpenCV object.
1714
- :param Optional[int] frame_interval: Interval between frames to be checked. If None, ``frame_n`` will be used.
1715
- :param Optional[int] frame_n: Number of frames to be checked, will be sampled at large allowed interval. If None, ``frame_interval`` will be used.
1716
- :param Optional[bool] raise_error: Whether to raise an error if corruption is found. If False, prints warning.
1717
- :return None:
1718
-
1719
- :example:
1720
- >>> check_if_video_corrupted(video='/Users/simon/Downloads/NOR ENCODING FExMP8.mp4')
1721
- """
1722
- check_instance(source=f'{check_if_video_corrupted.__name__} video', instance=video, accepted_types=(str, cv2.VideoCapture))
1723
- if isinstance(video, str):
1724
- check_file_exist_and_readable(file_path=video)
1725
- cap = cv2.VideoCapture(video)
1726
- else:
1727
- cap = video
1728
- frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
1729
- if (frame_interval is not None and frame_n is not None) or (frame_interval is None and frame_n is None):
1730
- raise InvalidInputError(msg='Pass frame_interval OR frame_n', source=check_if_video_corrupted.__name__)
1731
- if frame_interval is not None:
1732
- frms_to_check = list(range(0, frame_count, frame_interval))
1733
- else:
1734
- frms_to_check = np.array_split(np.arange(0, frame_count), frame_n)
1735
- frms_to_check = [x[-1] for x in frms_to_check]
1736
- errors = []
1737
- for frm_id in frms_to_check:
1738
- cap.set(1, frm_id)
1739
- ret, _ = cap.read()
1740
- if not ret: errors.append(frm_id)
1741
- if len(errors) > 0:
1742
- if raise_error:
1743
- raise CorruptedFileError(msg=f'Found {len(errors)} corrupted frame(s) at indexes {errors} in video {video}', source=check_if_video_corrupted.__name__)
1744
- else:
1745
- CorruptedFileWarning(msg=f'Found {len(errors)} corrupted frame(s) at indexes {errors} in video {video}', source=check_if_video_corrupted.__name__)
1746
- else:
1747
- pass
1748
-
1749
-
1750
- def check_valid_dict(x: dict,
1751
- valid_key_dtypes: Optional[Tuple[Any]] = None,
1752
- valid_values_dtypes: Optional[Tuple[Any, ...]] = None,
1753
- valid_keys: Optional[Union[Tuple[Any], List[Any]]] = None,
1754
- max_len_keys: Optional[int] = None,
1755
- min_len_keys: Optional[int] = None,
1756
- required_keys: Optional[Tuple[Any, ...]] = None,
1757
- max_value: Optional[Union[float, int]] = None,
1758
- min_value: Optional[Union[float, int]] = None,
1759
- source: Optional[str] = None):
1760
- """
1761
- Validate a dictionary against various criteria.
1762
-
1763
- This function performs comprehensive validation of a dictionary including
1764
- key/value data types, key constraints, required keys, and numeric value ranges.
1765
- It raises exceptions for any validation failures.
1766
-
1767
- :param dict x: The dictionary to validate.
1768
- :param Optional[Tuple[Any]] valid_key_dtypes: Tuple of allowed data types for dictionary keys. If None, no key type validation. Default: None.
1769
- :param Optional[Tuple[Any, ...]] valid_values_dtypes: Tuple of allowed data types for dictionary values. If None, no value type validation. Default: None.
1770
- :param Optional[Union[Tuple[Any], List[Any]]] valid_keys: Tuple or list of valid key names. If None, no key name validation. Default: None.
1771
- :param Optional[int] max_len_keys: Maximum number of keys allowed. If None, no maximum key count validation. Default: None.
1772
- :param Optional[int] min_len_keys: Minimum number of keys required. If None, no minimum key count validation. Default: None.
1773
- :param Optional[Tuple[Any, ...]] required_keys: Tuple of required key names. If None, no required key validation. Default: None.
1774
- :param Optional[Union[float, int]] max_value: Maximum numeric value allowed for numeric values. If None, no maximum value validation. Default: None.
1775
- :param Optional[Union[float, int]] min_value: Minimum numeric value allowed for numeric values. If None, no minimum value validation. Default: None.
1776
- :param Optional[str] source: Source identifier for error messages. If None, uses function name. Default: None.
1777
- :return: None if validation passes.
1778
- :rtype: None
1779
- :raises InvalidInputError: If any validation criteria are not met.
1780
-
1781
- :example:
1782
- >>> check_valid_dict(x={'a': 1, 'b': 2}, valid_key_dtypes=(str,), valid_values_dtypes=(int,))
1783
- >>> check_valid_dict(x={'key1': 10, 'key2': 20}, required_keys=('key1',), min_value=5, max_value=25)
1784
- >>> check_valid_dict(x={'x': 1, 'y': 2}, valid_keys=('x', 'y', 'z'), min_len_keys=2)
1785
- """
1786
-
1787
-
1788
- source = check_valid_dict.__name__ if source is None else source
1789
- check_instance(source=check_valid_dict.__name__, instance=x, accepted_types=(dict,))
1790
- if valid_key_dtypes is not None:
1791
- for i in list(x.keys()):
1792
- if not isinstance(i, valid_key_dtypes):
1793
- raise InvalidInputError(msg=f'{type(i)} is not a valid key DTYPE. Valid: {valid_key_dtypes}', source=source)
1794
- if valid_values_dtypes is not None:
1795
- for i in list(x.values()):
1796
- if not isinstance(i, valid_values_dtypes):
1797
- raise InvalidInputError(msg=f'{type(i)} is not a valid value DTYPE. Valid: {valid_values_dtypes}', source=source)
1798
- if max_len_keys is not None:
1799
- check_int(name=f'{check_valid_dict.__name__} max_len_keys', min_value=1, value=max_len_keys)
1800
- key_cnt = len(list(x.keys()))
1801
- if key_cnt > max_len_keys:
1802
- raise InvalidInputError(msg=f'Dictionary has {key_cnt} keys. Maximum allowed: {max_len_keys}', source=source)
1803
- if min_len_keys is not None:
1804
- check_int(name=f'{check_valid_dict.__name__} min_len_keys', min_value=1, value=min_len_keys)
1805
- key_cnt = len(list(x.keys()))
1806
- if key_cnt < min_len_keys:
1807
- raise InvalidInputError(msg=f'Dictionary has {key_cnt} keys. Minimum allowed: {min_len_keys}', source=source)
1808
- if required_keys is not None:
1809
- for i in list(required_keys):
1810
- if i not in list(x.keys()):
1811
- raise InvalidInputError(msg=f'The required key {i} does not exist in the dictionary. Existing keys: {list(x.keys())}', source=source)
1812
- if max_value is not None:
1813
- if not isinstance(max_value, (float, int)):
1814
- raise InvalidInputError(msg=f'{check_valid_dict.__name__} max_value has to be a float or integer, got {type(max_value)}.')
1815
- for k, v in x.items():
1816
- if isinstance(v, (float, int)):
1817
- if v > max_value:
1818
- raise InvalidInputError(msg=f'The required key {k} has value {v} which is above the max allowed: {max_value}.', source=source)
1819
- if min_value is not None:
1820
- if not isinstance(min_value, (float, int)):
1821
- raise InvalidInputError(msg=f'{check_valid_dict.__name__} min_value has to be a float or integer, got {type(min_value)}.')
1822
- for k, v in x.items():
1823
- if isinstance(v, (float, int)):
1824
- if v < min_value:
1825
- raise InvalidInputError(msg=f'The required key {k} has value {v} which is less than the minimum allowed: {min_value}.', source=source)
1826
- if valid_keys is not None:
1827
- if not isinstance(valid_keys, (tuple, list)):
1828
- raise InvalidInputError(msg=f'{check_valid_dict.__name__} valid_keys has to be a tuple or list, got {type(valid_keys)}.')
1829
- invalid_keys = [i for i in x.keys() if i not in valid_keys]
1830
- if len(invalid_keys) > 0:
1831
- raise InvalidInputError(msg=f'The dictionary has keys that are invalid ({invalid_keys}). Accepted, valid keys are: {valid_keys}.', source=source)
1832
-
1833
-
1834
-
1835
-
1836
- def is_video_color(video: Union[str, os.PathLike, cv2.VideoCapture]) -> bool:
1837
- """
1838
- Determines whether a video is in color or greyscale.
1839
-
1840
- .. seealso::
1841
- :func:`simba.mixins.image_mixin.ImageMixin.is_video_color`
1842
-
1843
- :param Union[str, os.PathLike, cv2.VideoCapture] video: The video source, either a cv2.VideoCapture object or a path to a file on disk.
1844
- :return: Returns `True` if the video is in color (has more than one channel), and `False` if the video is greyscale (single channel).
1845
- :rtype: bool
1846
- """
1847
-
1848
- check_instance(source=is_video_color.__name__, instance=video, accepted_types=(str, cv2.VideoCapture))
1849
-
1850
- # Handle string path vs VideoCapture object
1851
- should_release = False
1852
- if isinstance(video, str):
1853
- check_file_exist_and_readable(file_path=video)
1854
- video = cv2.VideoCapture(video)
1855
- should_release = True
1856
-
1857
- try:
1858
- video.set(cv2.CAP_PROP_POS_FRAMES, 0)
1859
- _, frm = video.read()
1860
-
1861
- # If frame has only 2 dimensions, it's definitely greyscale
1862
- if frm.ndim == 2:
1863
- return False
1864
-
1865
- # If frame has 3 dimensions, check if it's actually greyscale
1866
- # (some greyscale videos are stored as 3-channel with identical values)
1867
- if frm.ndim == 3:
1868
- # Check if all channels are identical (indicating greyscale)
1869
- if frm.shape[2] == 3: # BGR format
1870
- # Compare B, G, and R channels
1871
- if np.array_equal(frm[:, :, 0], frm[:, :, 1]) and np.array_equal(frm[:, :, 1], frm[:, :, 2]):
1872
- return False # All channels identical = greyscale
1873
- else:
1874
- return True # Channels different = color
1875
- elif frm.shape[2] == 1:
1876
- return False # Single channel = greyscale
1877
- else:
1878
- return True # Other multi-channel formats = color
1879
-
1880
- # Default case: assume greyscale
1881
- return False
1882
-
1883
- finally:
1884
- # Clean up VideoCapture if we created it
1885
- if should_release and video.isOpened():
1886
- video.release()
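
A short usage sketch for is_video_color, assuming the function is imported from simba.utils.checks; the video path is hypothetical.

# Hypothetical usage of is_video_color (illustrative path only).
import cv2

from simba.utils.checks import is_video_color

VIDEO_PATH = '/data/videos/example.mp4'  # hypothetical path

# Passing a file path: the function opens and releases the capture itself.
print(is_video_color(video=VIDEO_PATH))

# Passing an already-open cv2.VideoCapture: the caller remains responsible for releasing it.
cap = cv2.VideoCapture(VIDEO_PATH)
print(is_video_color(video=cap))
cap.release()
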
1887
-
1888
-
1889
- def check_filepaths_in_iterable_exist(file_paths: Iterable[str],
1890
- name: Optional[str] = None):
1891
-
1892
- check_instance(source=f'{check_filepaths_in_iterable_exist.__name__} file_paths {name}', instance=file_paths, accepted_types=(list, tuple,))
1893
- if len(file_paths) == 0:
1894
- raise NoFilesFoundError(msg=f'{name} {file_paths} is empty')
1895
- for file_path in file_paths:
1896
- check_str(name=f'{check_filepaths_in_iterable_exist.__name__} {file_path} {name}', value=file_path)
1897
- if not os.path.isfile(file_path):
1898
- raise NoFilesFoundError(msg=f'{name} {file_path} is not a valid file path')
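
check_filepaths_in_iterable_exist has no docstring, so a brief usage sketch follows; the paths are hypothetical and the function is assumed to be importable from simba.utils.checks. It raises NoFilesFoundError when the iterable is empty or when any entry does not point to an existing file.

# Hypothetical usage of check_filepaths_in_iterable_exist (illustrative paths only).
from simba.utils.checks import check_filepaths_in_iterable_exist
from simba.utils.errors import NoFilesFoundError

paths = ['/data/project/csv/video_1.csv', '/data/project/csv/video_2.csv']  # hypothetical
try:
    check_filepaths_in_iterable_exist(file_paths=paths, name='outlier corrected files')
except NoFilesFoundError as e:
    print(f'One or more files are missing: {e}')
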
1899
-
1900
- def check_all_dfs_in_list_has_same_cols(dfs: List[pd.DataFrame], raise_error: bool = True, source: str = '') -> bool:
1901
- """
1902
- Check that all DataFrames in a list have the same column names.
1903
-
1904
- This function validates that all DataFrames in the provided list contain
1905
- identical column headers. It finds the intersection of all column names
1906
- and identifies any missing headers that are not present in all DataFrames.
1907
-
1908
- :param List[pd.DataFrame] dfs: List of DataFrames to validate for consistent column names.
1909
- :param bool raise_error: If True, raises MissingColumnsError when column names don't match. If False, returns False. Default: True.
1910
- :param str source: Source identifier for error messages. Default: ''.
1911
- :return: True if all DataFrames have the same column names, False if they don't match and raise_error=False.
1912
- :rtype: bool
1913
- :raises MissingColumnsError: If DataFrames have different column names and raise_error=True.
1914
-
1915
- :example:
1916
- >>> df1 = pd.DataFrame({'A': [1, 2], 'B': [3, 4]})
1917
- >>> df2 = pd.DataFrame({'A': [5, 6], 'B': [7, 8]})
1918
- >>> check_all_dfs_in_list_has_same_cols(dfs=[df1, df2])
1919
- True
1920
- >>> df3 = pd.DataFrame({'A': [1, 2], 'C': [3, 4]})
1921
- >>> check_all_dfs_in_list_has_same_cols(dfs=[df1, df3], raise_error=False)
1922
- False
1923
- """
1924
- check_valid_lst(data=dfs, source=check_all_dfs_in_list_has_same_cols.__name__, valid_dtypes=(pd.DataFrame,), min_len=1)
1925
- col_headers = [list(x.columns) for x in dfs]
1926
- common_headers = set(col_headers[0]).intersection(*col_headers[1:])
1927
- all_headers = set(item for sublist in col_headers for item in sublist)
1928
- missing_headers = list(all_headers - common_headers)
1929
- if len(missing_headers) > 0:
1930
- if raise_error:
1931
- raise MissingColumnsError(msg=f"The data in {source} directory do not contain the same headers. Some files are missing the headers: {missing_headers}", source=check_all_dfs_in_list_has_same_cols.__name__)
1932
- else:
1933
- return False
1934
- return True
1935
-
1936
-
1937
- def is_valid_video_file(file_path: Union[str, os.PathLike], raise_error: bool = True):
1938
- """
1939
- Check if a file path is a valid video file.
1940
-
1941
- This function validates that a file path exists, is readable, and can be
1942
- opened as a video file using OpenCV. It performs basic video file validation
1943
- by attempting to open the file with cv2.VideoCapture.
1944
-
1945
- :param Union[str, os.PathLike] file_path: Path to the video file to validate.
1946
- :param bool raise_error: If True, raises InvalidFilepathError when file is not a valid video. If False, returns False. Default: True.
1947
- :return: True if the file is a valid video file, False if it's not valid and raise_error=False.
1948
- :rtype: bool
1949
- :raises InvalidFilepathError: If the file is not a valid video file and raise_error=True.
1950
-
1951
- :example:
1952
- >>> is_valid_video_file('/path/to/video.mp4')
1953
- True
1954
- >>> is_valid_video_file('/path/to/invalid.txt', raise_error=False)
1955
- False
1956
- >>> is_valid_video_file('/path/to/corrupted.mp4', raise_error=False)
1957
- False
1958
- """
1959
- check_file_exist_and_readable(file_path=file_path)
1960
- try:
1961
- cap = cv2.VideoCapture(file_path)
1962
- if not cap.isOpened():
1963
- if not raise_error:
1964
- return False
1965
- else:
1966
- raise InvalidFilepathError(msg=f'The path {file_path} is not a valid video file', source=is_valid_video_file.__name__)
1967
- return True
1968
- except Exception:
1969
- if not raise_error:
1970
- return False
1971
- else:
1972
- raise InvalidFilepathError(msg=f'The path {file_path} is not a valid video file', source=is_valid_video_file.__name__)
1973
- finally:
1974
- if 'cap' in locals():
1975
- if cap.isOpened():
1976
- cap.release()
1977
-
1978
-
1979
- def check_valid_polygon(polygon: Union[np.ndarray, Polygon], raise_error: bool = True, name: Optional[str] = None) -> Union[bool, None]:
1980
- """
1981
- Validates whether the given polygon is a valid geometric shape.
1982
-
1983
- :param Union[np.ndarray, Polygon] polygon: The polygon to validate, either as a NumPy array of shape (N, 2) or a shapely Polygon object.
1984
- :param bool raise_error: If True, raises an InvalidInputError if the polygon is invalid; otherwise, returns False.
1985
- :param Optional[str] name: An optional name for the polygon to include in error messages.
1986
- :return: True if the polygon is valid, False if invalid (and raise_error is False), or None if an error is raised.
1987
- """
1988
-
1989
-
1990
- name = '' if name is None else name
1991
- check_instance(source=f'{check_valid_polygon.__name__} polygon', accepted_types=(np.ndarray, Polygon,), instance=polygon)
1992
- if isinstance(polygon, np.ndarray):
1993
- check_valid_array(data=polygon, source=f'{check_valid_polygon.__name__} polygon', accepted_ndims=(2,), accepted_dtypes=Formats.NUMERIC_DTYPES.value, min_axis_0=3, accepted_axis_1_shape=[2,])
1994
- polygon = Polygon(polygon.astype(np.int32))
1995
- if not polygon.is_valid:
1996
- if raise_error:
1997
- raise InvalidInputError(msg=f'The polygon {name} is invalid', source=check_valid_polygon.__name__)
1998
- else:
1999
- return False
2000
- else:
2001
- return True
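
A usage sketch for check_valid_polygon, assuming it is imported from simba.utils.checks; the coordinates are illustrative.

# Hypothetical usage of check_valid_polygon (illustrative coordinates).
import numpy as np
from shapely.geometry import Polygon

from simba.utils.checks import check_valid_polygon

# A convex quadrilateral passed as an (N, 2) array of vertex coordinates.
square = np.array([[0, 0], [0, 10], [10, 10], [10, 0]])
print(check_valid_polygon(polygon=square, name='square'))  # True

# A self-intersecting 'bowtie' is geometrically invalid; with raise_error=False the
# function returns False instead of raising InvalidInputError.
bowtie = Polygon([(0, 0), (10, 10), (10, 0), (0, 10)])
print(check_valid_polygon(polygon=bowtie, raise_error=False, name='bowtie'))  # False
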
2002
-
2003
-
2004
- def is_img_bw(img: np.ndarray,
2005
- raise_error: bool = True,
2006
- source: Optional[str] = '') -> bool:
2007
- """
2008
- Check if an image is binary black and white.
2009
-
2010
- This function validates that an image contains only two pixel values:
2011
- 0 (black) and 255 (white). It checks all unique pixel values in the image
2012
- and ensures they are exactly these two values.
2013
-
2014
- :param np.ndarray img: The image array to validate for binary black and white format.
2015
- :param bool raise_error: If True, raises InvalidInputError when image is not binary black and white. If False, returns False. Default: True.
2016
- :param Optional[str] source: Source identifier for error messages. Default: ''.
2017
- :return: True if the image is binary black and white, False if it's not and raise_error=False.
2018
- :rtype: bool
2019
- :raises InvalidInputError: If the image is not binary black and white and raise_error=True.
2020
-
2021
- :example:
2022
- >>> bw_img = np.array([[0, 255], [255, 0]], dtype=np.uint8)
2023
- >>> is_img_bw(bw_img)
2024
- True
2025
- >>> gray_img = np.array([[128, 200], [50, 100]], dtype=np.uint8)
2026
- >>> is_img_bw(gray_img, raise_error=False)
2027
- False
2028
- """
2029
- check_if_valid_img(data=img, source=is_img_bw.__name__, raise_error=True)
2030
- px_vals = set(list(np.sort(np.unique(img)).astype(np.int32)))
2031
- additional = list(px_vals - {0, 255})
2032
- if len(additional) > 0:
2033
- if raise_error:
2034
- raise InvalidInputError(msg=f'The image {source} is not a black-and-white image. Expected: [0, 255], got {additional}', source=is_img_bw.__name__)
2035
- else:
2036
- return False
2037
- else:
2038
- return True
2039
-
2040
- def is_img_greyscale(img: np.ndarray,
2041
- raise_error: bool = True,
2042
- source: Optional[str] = '') -> bool:
2043
- """
2044
- Check if an image is greyscale.
2045
-
2046
- This function validates that an image is in greyscale format by checking
2047
- that it has exactly 2 dimensions (height and width). Greyscale images
2048
- have a single channel and are represented as 2D arrays.
2049
-
2050
- :param np.ndarray img: The image array to validate for greyscale format.
2051
- :param bool raise_error: If True, raises InvalidInputError when image is not greyscale. If False, returns False. Default: True.
2052
- :param Optional[str] source: Source identifier for error messages. Default: ''.
2053
- :return: True if the image is greyscale, False if it's not and raise_error=False.
2054
- :rtype: bool
2055
- :raises InvalidInputError: If the image is not greyscale and raise_error=True.
2056
-
2057
- :example:
2058
- >>> gray_img = np.array([[128, 200], [50, 100]], dtype=np.uint8)
2059
- >>> is_img_greyscale(gray_img)
2060
- True
2061
- >>> color_img = np.array([[[128, 200, 50], [100, 150, 75]]], dtype=np.uint8)
2062
- >>> is_img_greyscale(color_img, raise_error=False)
2063
- False
2064
- """
2065
- check_if_valid_img(data=img, source=is_img_greyscale.__name__, raise_error=False, greyscale=True)
2066
- if not img.ndim == 2:
2067
- if raise_error:
2068
- raise InvalidInputError(msg=f'The image {source} is not a greyscale image. Expected 2 dimensions got: {img.ndim}', source=is_img_greyscale.__name__)
2069
- else:
2070
- return False
2071
- else:
2072
- return True
2073
-
2074
-
2075
- def is_wsl() -> bool:
2076
- """
2077
- Check if SimBA is running in Microsoft WSL (Windows Subsystem for Linux).
2078
-
2079
- This function detects whether the current environment is running inside
2080
- Microsoft WSL by checking the contents of /proc/version for the presence
2081
- of "microsoft" string, which indicates WSL environment.
2082
-
2083
- :return: True if running in WSL, False otherwise.
2084
- :rtype: bool
2085
-
2086
- :example:
2087
- >>> is_wsl()
2088
- False # When running on native Linux
2089
- >>> is_wsl()
2090
- True # When running in WSL
2091
- """
2092
- try:
2093
- with open("/proc/version", "r") as f:
2094
- return "microsoft" in f.read().lower()
2095
- except FileNotFoundError:
2096
- return False
2097
-
2098
- def is_windows_path(value):
2099
- """
2100
- Check if the value is a valid Windows path format.
2101
-
2102
- This function validates that a string follows the Windows path format
2103
- by checking that it starts with a drive letter followed by a colon
2104
- (e.g., "C:", "D:", etc.). It performs basic format validation without
2105
- checking if the path actually exists on the filesystem.
2106
-
2107
- :param value: The value to check for Windows path format.
2108
- :return: True if the value is a valid Windows path format, False otherwise.
2109
- :rtype: bool
2110
-
2111
- :example:
2112
- >>> is_windows_path("C:\\Users\\username\\file.txt")
2113
- True
2114
- >>> is_windows_path("D:\\data\\folder")
2115
- True
2116
- >>> is_windows_path("/home/user/file.txt")
2117
- False
2118
- >>> is_windows_path("relative/path")
2119
- False
2120
- >>> is_windows_path("")
2121
- False
2122
- """
2123
- return isinstance(value, str) and (len(value) > 1 and value[1] == ':' and value[0].isalpha())
2124
-
2125
-
2126
- def check_same_files_exist_in_all_directories(dirs: List[Union[str, os.PathLike]], raise_error: bool = False, file_type: str = "csv") -> bool:
2127
- """
2128
- Check if the same files of a given type exist in all specified directories.
2129
-
2130
- :param List[Union[str, os.PathLike]] dirs: List of directory paths to check.
2131
- :param bool raise_error: If True, raises an error when file names do not match across directories. Defaults to False.
2132
- :param str file_type: File extension (without the dot) to check for (e.g., 'csv', 'txt'). Defaults to 'csv'.
2133
- """
2134
-
2135
- check_valid_lst( data=dirs, source=f"{check_same_files_exist_in_all_directories.__name__} dirs", valid_dtypes=(str, os.PathLike), min_len=2)
2136
- file_sets = [{os.path.basename(f) for f in glob.glob(os.path.join(dir, f"*.{file_type}"))} for dir in dirs]
2137
- common_files = set.intersection(*file_sets) if file_sets else set()
2138
- if not all(files == common_files for files in file_sets):
2139
- if raise_error:
2140
- raise NoFilesFoundError( msg=f"Files of type '{file_type}' do not match across directories: {dirs}.", source=check_same_files_exist_in_all_directories.__name__)
2141
- return False
2142
- return True
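
A usage sketch for check_same_files_exist_in_all_directories, assuming it is imported from simba.utils.checks; the directory paths are hypothetical.

# Hypothetical usage of check_same_files_exist_in_all_directories (illustrative paths).
from simba.utils.checks import check_same_files_exist_in_all_directories

dirs = ['/data/project/csv/features_extracted',  # hypothetical
        '/data/project/csv/machine_results']     # hypothetical

# True when both directories hold CSV files with the same base names; with
# raise_error=False (the default) a mismatch returns False rather than raising.
print(check_same_files_exist_in_all_directories(dirs=dirs, file_type='csv', raise_error=False))
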
2143
-
2144
-
2145
-
2146
- def check_valid_img_path(path: Union[str, os.PathLike], raise_error: bool = True):
2147
- """
2148
- Check if a file path is a valid image file.
2149
-
2150
- This function validates that a file path exists, is readable, and can be
2151
- opened as an image file using OpenCV. It performs basic image file validation
2152
- by attempting to read the file with cv2.imread.
2153
-
2154
- :param Union[str, os.PathLike] path: Path to the image file to validate.
2155
- :param bool raise_error: If True, raises InvalidInputError when file is not a valid image. If False, returns False. Default: True.
2156
- :return: True if the file is a valid image file, False if it's not valid and raise_error=False.
2157
- :rtype: bool
2158
- :raises InvalidInputError: If the file is not a valid image file and raise_error=True.
2159
-
2160
- :example:
2161
- >>> check_valid_img_path('/path/to/image.jpg')
2162
- True
2163
- >>> check_valid_img_path('/path/to/invalid.txt', raise_error=False)
2164
- False
2165
- >>> check_valid_img_path('/path/to/corrupted.png', raise_error=False)
2166
- False
2167
- """
2168
- check_file_exist_and_readable(path)
2169
- try:
2170
- img = cv2.imread(path)
- if img is None:
- raise ValueError(f'cv2.imread returned None for {path}')
2171
- except Exception as e:
2172
- if raise_error:
2173
- print(e.args)
2174
- raise InvalidInputError(msg=f'{path} could not be read as a valid image file', source=check_valid_img_path.__name__)
2175
- else:
2176
- return False
2177
- return True
2178
-
2179
-
2180
-
2181
-
2182
- def check_valid_device(device: Union[Literal['cpu'], int], raise_error: bool = True) -> bool:
2183
- """
2184
- Validate a compute device specification, ensuring it is either 'cpu' or a valid GPU index.
2185
-
2186
- This function validates that a device specification is valid for use with
2187
- PyTorch/CUDA operations. It checks if the device is either 'cpu' for CPU
2188
- usage or a valid integer representing a CUDA device index.
2189
-
2190
- :param Union[Literal['cpu'], int] device: The device to validate. Should be the string 'cpu' for CPU usage, or an integer representing a CUDA device index (e.g., 0 for 'cuda:0').
2191
- :param bool raise_error: If True, raises InvalidInputError or SimBAGPUError when the device is invalid. If False, returns False instead of raising errors. Default: True.
2192
- :return: True if the device is valid, False if it's invalid and raise_error=False.
2193
- :rtype: bool
2194
- :raises InvalidInputError: If the device format is invalid and raise_error=True.
2195
- :raises SimBAGPUError: If the GPU device is not available or not valid and raise_error=True.
2196
-
2197
- :example:
2198
- >>> check_valid_device('cpu')
2199
- True
2200
- >>> check_valid_device(0) # GPU 0
2201
- True
2202
- >>> check_valid_device(5, raise_error=False) # Non-existent GPU
2203
- False
2204
- >>> check_valid_device('gpu', raise_error=False) # Invalid format
2205
- False
2206
- """
2207
- source = check_valid_device.__name__
2208
- if isinstance(device, str):
2209
- valid, msg = check_str(name=f'{source} format', value=device.lower(), options=['cpu'], raise_error=False)
2210
- if not valid:
2211
- if raise_error:
2212
- raise InvalidInputError(msg=msg, source=source)
2213
- return False
2214
- return True
2215
-
2216
- valid, msg = check_int(name=f'{source} device', value=device, min_value=0, raise_error=False)
2217
- if not valid:
2218
- if raise_error:
2219
- raise InvalidInputError(msg=msg, source=source)
2220
- return False
2221
-
2222
- gpu_available, gpus = _is_cuda_available()
2223
- if not gpu_available:
2224
- if raise_error:
2225
- raise SimBAGPUError(msg=f'No GPU detected but device {device} passed', source=source)
2226
- return False
2227
-
2228
- if device not in gpus:
2229
- if raise_error:
2230
- raise SimBAGPUError(msg=f'Unaccepted GPU device {device} passed. Accepted: {list(gpus.keys())}', source=source)
2231
- return False
- return True  # device index maps to an available GPU
-
2233
- def is_lxc_container() -> bool:
2234
- """
2235
- Helper to check if the current environment is inside a LXC Linux container.
2236
-
2237
- .. note::
2238
- See GitHub issue 457 for origin - https://github.com/sgoldenlab/simba/issues/457#issuecomment-3052631284
2239
- Thanks Heinrich2818 - https://github.com/Heinrich2818
2240
-
2241
- :return: True if current environment is a LXC linux container, False if not.
2242
- :rtype: bool
2243
- """
2244
-
2245
- try:
2246
- with open('/proc/1/cgroup') as f:
2247
- for line in f:
2248
- if 'lxc' in line:
2249
- return True
2250
- except IOError:
2251
- pass
2252
- try:
2253
- with open('/proc/self/mountinfo') as f:
2254
- for line in f:
2255
- if ' - cgroup2 ' in line:
2256
- mount_point = line.strip().split()[-1]
2257
- if 'lxc' in mount_point:
2258
- return True
2259
- except IOError:
2260
- pass
2261
- try:
2262
- with open('/proc/1/environ', 'rb') as f:
2263
- env = f.read().split(b'\0')
2264
- for e in env:
2265
- if e == b'container=lxc':
2266
- return True
2267
- except IOError:
2268
- pass
2269
- try:
2270
- import subprocess
2271
- r = subprocess.run(['systemd-detect-virt', '--container'], stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, check=False, text=True)
2272
- if r.stdout.strip() == 'lxc':
2273
- return True
2274
- except Exception:
2275
- pass
2276
- return False
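
A small illustrative helper combining is_wsl and is_lxc_container; describe_environment is hypothetical and not part of SimBA.

# Hypothetical helper built on the environment checks above.
from simba.utils.checks import is_lxc_container, is_wsl

def describe_environment() -> str:
    # Report which runtime environment SimBA appears to be running in.
    if is_wsl():
        return 'Windows Subsystem for Linux (WSL)'
    elif is_lxc_container():
        return 'LXC container'
    return 'native OS or non-LXC container'

print(describe_environment())
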
2277
-
2278
-
2279
-
2280
- def check_valid_cpu_pool(value: Any,
2281
- source: str = '',
2282
- max_cores: Optional[int] = None,
2283
- min_cores: Optional[int] = None,
2284
- accepted_cores: Optional[Union[List[int], Tuple[int, ...], int]] = None,
2285
- raise_error: bool = True) -> bool:
2286
-
2287
- """
2288
- Validates that a value is a valid multiprocessing.Pool instance and optionally checks core count constraints.
2289
-
2290
- :param Any value: The value to validate. Must be an instance of multiprocessing.pool.Pool.
2291
- :param str source: Optional source identifier for error messages. Default is empty string.
2292
- :param Optional[int] max_cores: Optional maximum number of processes allowed in the pool. If provided, validates that pool._processes <= max_cores.
2293
- :param Optional[int] min_cores: Optional minimum number of processes required in the pool. If provided, validates that pool._processes >= min_cores.
2294
- :param Optional[Union[List[int], Tuple[int, ...], int]] accepted_cores: Optional exact or list of acceptable process counts. If an int, validates that pool._processes == accepted_cores. If a list/tuple of ints, validates that pool._processes is in accepted_cores. All values must be positive integers.
2295
- :param bool raise_error: If True, raises InvalidInputError on validation failure. If False, returns False on failure. Default is True.
2296
- :return bool: True if validation passes, False if validation fails and raise_error is False.
2297
- :raises InvalidInputError: If value is not a valid Pool instance, if core count constraints are violated, if accepted_cores contains invalid types, or if raise_error is True.
2298
-
2299
- :example:
2300
- >>> import multiprocessing
2301
- >>> pool = multiprocessing.Pool(processes=4)
2302
- >>> check_valid_cpu_pool(value=pool, source='test', max_cores=8, min_cores=2)
2303
- >>> True
2304
- >>> check_valid_cpu_pool(value=pool, source='test', accepted_cores=[4, 8, 16])
2305
- >>> True
2306
- >>> check_valid_cpu_pool(value=pool, source='test', accepted_cores=4)
2307
- >>> True
2308
- """
2309
-
2310
- if not isinstance(value, (multiprocessing.pool.Pool,)):
2311
- if raise_error:
2312
- raise InvalidInputError(msg=f'Not a valid CPU pool. Expected {multiprocessing.pool.Pool}, got {type(value)}.', source=source)
2313
- else:
2314
- return False
2315
- if max_cores is not None:
2316
- check_int(name=f'{source} max_cores', value=max_cores, min_value=1)
2317
- if value._processes > max_cores:
2318
- if raise_error: raise InvalidInputError(msg=f'CPU pool has too many processes. Got {value._processes}, max {max_cores}', source=source)
2319
- else: return False
2320
- if min_cores is not None:
2321
- check_int(name=f'{source} min_cores', value=min_cores, min_value=1)
2322
- if value._processes < min_cores:
2323
- if raise_error:
2324
- raise InvalidInputError(msg=f'CPU pool has too few processes. Got {value._processes}, min {min_cores}',source=source)
2325
- else:
2326
- return False
2327
- if accepted_cores is not None:
2328
- if isinstance(accepted_cores, int):
2329
- is_valid, _ = check_int(name=f'{source} accepted_cores', value=accepted_cores, min_value=1, raise_error=raise_error)
2330
- if not is_valid:
2331
- return False
2332
- if value._processes != accepted_cores:
2333
- if raise_error:
2334
- raise InvalidInputError(msg=f'CPU pool has an unacceptable number of cores. Got {value._processes}, accepted {accepted_cores}', source=source)
2335
- else:
2336
- return False
2337
- elif isinstance(accepted_cores, (tuple, list)):
2338
- is_valid = check_valid_lst(data=list(accepted_cores), source=f'{source} accepted_cores', valid_dtypes=(int,), min_len=1, min_value=1, raise_error=raise_error)
2339
- if not is_valid:
2340
- return False
2341
- if value._processes not in accepted_cores:
2342
- if raise_error:
2343
- raise InvalidInputError(msg=f'CPU pool has an unacceptable number of cores. Got {value._processes}, accepted {accepted_cores}', source=source)
2344
- else:
2345
- return False
2346
- if min_cores is not None:
2347
- if min(accepted_cores) < min_cores:
2348
- if raise_error:
2349
- raise InvalidInputError(msg=f'accepted_cores contains values below min_cores. min_cores={min_cores}, accepted_cores={accepted_cores}', source=source)
2350
- else:
2351
- return False
2352
- if max_cores is not None:
2353
- if max(accepted_cores) > max_cores:
2354
- if raise_error:
2355
- raise InvalidInputError(msg=f'accepted_cores contains values above max_cores. max_cores={max_cores}, accepted_cores={accepted_cores}', source=source)
2356
- else:
2357
- return False
2358
- else:
2359
- raise InvalidInputError(msg=f'accepted_cores has to be an int, list of ints, or tuple of ints. Got {type(accepted_cores)}', source=source)
2360
-
2361
- return True
2362
-
2363
-
2364
- def check_valid_codec(codec: str, raise_error: bool = True, source: str = ''):
2365
- """
2366
- Validate that a codec string is available in the current FFmpeg installation.
2367
-
2368
- Checks if the provided codec name exists in the list of available FFmpeg encoders
2369
- by querying FFmpeg directly. This ensures the codec can be used for video encoding/decoding.
2370
-
2371
- .. note::
2372
- This function requires FFmpeg to be installed and available in the system PATH.
2373
- The function queries FFmpeg for available encoders at runtime, so it will reflect
2374
- the actual encoders available in your FFmpeg installation.
2375
-
2376
- .. seealso::
2377
- To get a list of all available encoders, see :func:`~simba.utils.lookups.get_ffmpeg_encoders`.
2378
- To check if FFmpeg is available, see :func:`~simba.utils.checks.check_ffmpeg_available`.
2379
-
2380
- :param str codec: The codec name to validate (e.g., 'libx264', 'h264_nvenc', 'libvpx-vp9').
2381
- :param bool raise_error: If True, raises ``InvalidInputError`` when codec is invalid. If False, returns False. Default: True.
2382
- :param str source: Source identifier for error messages. Used when raising exceptions. Default: ''.
2383
- :return: True if codec is valid, False if invalid and ``raise_error=False``.
2384
- :rtype: bool
2385
- :raises InvalidInputError: If codec is not valid and ``raise_error=True``.
2386
-
2387
- :example:
2388
- >>> check_valid_codec(codec='libx264')
2389
- >>> check_valid_codec(codec='h264_nvenc', source='my_function')
2390
- >>> is_valid = check_valid_codec(codec='invalid_codec', raise_error=False)
2391
- """
2392
- from simba.utils.lookups import get_ffmpeg_encoders; encoders = get_ffmpeg_encoders()
2393
- valid_codec = check_str(name=f'{check_valid_codec.__name__} codec', value=codec, options=encoders, allow_blank=False, raise_error=False)[0]
2394
- if not valid_codec:
2395
- if raise_error:
2396
- raise InvalidInputError(msg=f'The codec {codec} is not a valid codec in the current FFMPEG installation', source=source)
2397
- else:
2398
- return False
2399
- return True
2400
-
2401
-
1
+ __author__ = "Simon Nilsson; sronilsson@gmail.com"
2
+
3
+ import ast
4
+ import glob
5
+ import os
6
+ import re
7
+ import subprocess
8
+ from pathlib import Path
9
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
10
+
11
+ try:
12
+ from typing import Literal
13
+ except:
14
+ from typing_extensions import Literal
15
+
16
+ try:
17
+ import cupy as cp
18
+ except ModuleNotFoundError:
19
+ import numpy as cp
20
+
21
+ import multiprocessing
22
+
23
+ import cv2
24
+ import numpy as np
25
+ import pandas as pd
26
+ import trafaret as t
27
+ from shapely.geometry import Polygon
28
+
29
+ from simba.data_processors.cuda.utils import _is_cuda_available
30
+ from simba.utils.enums import Formats, Keys, Options, UMAPParam
31
+ from simba.utils.errors import (ArrayError, ColumnNotFoundError,
32
+ CorruptedFileError, CountError,
33
+ DirectoryNotEmptyError, FFMPEGNotFoundError,
34
+ FloatError, FrameRangeError, IntegerError,
35
+ InvalidFilepathError, InvalidInputError,
36
+ MissingColumnsError, NoDataError,
37
+ NoFilesFoundError, NoROIDataError,
38
+ NotDirectoryError, ParametersFileError,
39
+ SimBAGPUError, StringError)
40
+ from simba.utils.warnings import (CorruptedFileWarning, FrameRangeWarning,
41
+ InvalidValueWarning, NoDataFoundWarning)
42
+
43
+
44
+ def check_file_exist_and_readable(file_path: Union[str, os.PathLike], raise_error: bool = True) -> bool:
45
+ """
46
+ Checks if a path points to a readable file.
47
+
48
+ :param str file_path: Path to file on disk.
49
+ :raise NoFilesFoundError: The file does not exist.
50
+ :raise CorruptedFileError: The file can not be read or is zero byte size.
51
+ """
52
+ check_instance(source="FILE PATH", instance=file_path, accepted_types=(str, os.PathLike))
53
+ if not os.path.isfile(file_path):
54
+ if raise_error:
55
+ raise NoFilesFoundError(msg=f"{file_path} is not a valid file path", source=check_file_exist_and_readable.__name__)
56
+ else:
57
+ return False
58
+ elif not os.access(file_path, os.R_OK):
59
+ if raise_error:
60
+ raise CorruptedFileError(msg=f"{file_path} is not readable", source=check_file_exist_and_readable.__name__)
61
+ else:
62
+ return False
63
+ elif os.stat(file_path).st_size == 0:
64
+ if raise_error:
65
+ raise CorruptedFileError(msg=f"{file_path} is 0 bytes and contains no data.", source=check_file_exist_and_readable.__name__)
66
+ else:
67
+ return False
68
+ else:
69
+ return True
70
+
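A minimal usage sketch of the file check above, assuming a readable, non-empty file exists at the hypothetical path '/project/data/Video_1.csv':
>>> check_file_exist_and_readable(file_path='/project/data/Video_1.csv')
>>> True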
71
+
72
+ def check_int(name: str,
73
+ value: Any,
74
+ max_value: Optional[int] = None,
75
+ min_value: Optional[int] = None,
76
+ unaccepted_vals: Optional[List[int]] = None,
77
+ accepted_vals: Optional[List[int]] = None,
78
+ allow_negative: bool = True,
79
+ allow_zero: bool = True,
80
+ raise_error: Optional[bool] = True) -> Tuple[bool, str]:
81
+ """
82
+ Check if variable is a valid integer.
83
+
84
+ Validates that a value is an integer and optionally checks it against constraints such as
85
+ minimum/maximum values, accepted/unaccepted value lists, and negative/zero number restrictions.
86
+
87
+ :param str name: Name of the variable being checked (used in error messages).
88
+ :param Any value: The value to validate as an integer.
89
+ :param Optional[int] max_value: Maximum allowed value. If None, no maximum constraint. Default None.
90
+ :param Optional[int] min_value: Minimum allowed value. If None, no minimum constraint. Default None.
91
+ :param Optional[List[int]] unaccepted_vals: List of integer values that are not accepted. If value is in this list, validation fails. Default None.
92
+ :param Optional[List[int]] accepted_vals: List of integer values that are accepted. If value is not in this list, validation fails. Default None.
93
+ :param bool allow_negative: If False, negative values will cause validation to fail. Default True.
94
+ :param bool allow_zero: If False, zero values will cause validation to fail. Default True.
95
+ :param Optional[bool] raise_error: If True, raises IntegerError when validation fails. If False, returns (False, error_message) tuple. Default True.
96
+ :return: If `raise_error` is False, returns a tuple (bool, str) where bool indicates if value is valid, and str contains error message (empty string if valid). If `raise_error` is True and validation passes, returns (True, ""). If `raise_error` is True and validation fails, raises IntegerError.
97
+ :rtype: Tuple[bool, str]
98
+ :raises IntegerError: If validation fails and `raise_error` is True.
99
+
100
+ :example:
101
+ >>> check_int(name='My_fps', value=25, min_value=1)
102
+ >>> check_int(name='Quality', value=50, min_value=0, max_value=100, raise_error=False)
103
+ >>> check_int(name='Mode', value=2, accepted_vals=[1, 2, 3])
104
+ >>> check_int(name='Count', value=-5, allow_negative=False)
105
+ >>> check_int(name='Divisor', value=0, allow_zero=False)
106
+ """
107
+ msg = ""
108
+ try:
109
+ t.Int().check(value)
110
+ except t.DataError as e:
111
+ msg = f"{name} should be an integer number in SimBA, but is set to {str(value)}"
112
+ if raise_error:
113
+ raise IntegerError(msg=msg, source=check_int.__name__)
114
+ else:
115
+ return False, msg
116
+ if min_value != None:
117
+ if int(value) < min_value:
118
+ msg = f"{name} should be MORE THAN OR EQUAL to {str(min_value)}. It is set to {str(value)}"
119
+ if raise_error:
120
+ raise IntegerError(msg=msg, source=check_int.__name__)
121
+ else:
122
+ return False, msg
123
+ if max_value != None:
124
+ if int(value) > max_value:
125
+ msg = f"{name} should be LESS THAN OR EQUAL to {str(max_value)}. It is set to {str(value)}"
126
+ if raise_error:
127
+ raise IntegerError(msg=msg, source=check_int.__name__)
128
+ else:
129
+ return False, msg
130
+ if unaccepted_vals != None:
131
+ check_valid_lst(data=unaccepted_vals, source=name, valid_dtypes=(int,), min_len=1)
132
+ if int(value) in unaccepted_vals:
133
+ msg = f"{name} is an not an accepted value. Unaccepted values {unaccepted_vals}."
134
+ if raise_error:
135
+ raise IntegerError(msg=msg, source=check_int.__name__)
136
+ else:
137
+ return False, msg
138
+ if accepted_vals != None:
139
+ check_valid_lst(data=accepted_vals, source=name, valid_dtypes=(int,), min_len=1)
140
+ if int(value) not in accepted_vals:
141
+ msg = f"{name} is an not an accepted value. Got: {value}. Accepted values {accepted_vals}."
142
+ if raise_error:
143
+ raise IntegerError(msg=msg, source=check_int.__name__)
144
+ else:
145
+ return False, msg
146
+
147
+ if not allow_negative and int(value) < 0:
148
+ msg = f"{name} is negative and negative is not accepted. Got: {value}."
149
+ if raise_error:
150
+ raise IntegerError(msg=msg, source=check_int.__name__)
151
+ else:
152
+ return False, msg
153
+
154
+ if not allow_zero and int(value) == 0:
155
+ msg = f"{name} is zero and zero is not accepted. Got: {value}."
156
+ if raise_error:
157
+ raise IntegerError(msg=msg, source=check_int.__name__)
158
+ else:
159
+ return False, msg
160
+
161
+ return True, msg
162
+
163
+
164
+ def check_str(name: str,
165
+ value: Any,
166
+ options: Optional[Union[Tuple[Any], List[Any], Iterable[Any]]] = (),
167
+ allow_blank: bool = False,
168
+ invalid_options: Optional[Union[List[str], Tuple[str]]] = None,
169
+ raise_error: bool = True,
170
+ invalid_substrs: Optional[Union[List[str], Tuple[str]]] = None) -> Tuple[bool, str]:
171
+
172
+ """
173
+ Check if variable is a valid string.
174
+
175
+ :param str name: Name of variable
176
+ :param Any value: Value of variable
177
+ :param Optional[Tuple[Any]] options: Tuple of allowed strings. If empty tuple, then any string allowed. Default: ().
178
+ :param Optional[bool] allow_blank: If True, allow empty string. Default: False.
179
+ :param Optional[bool] raise_error: If True, then raise error if invalid string. Default: True.
180
+ :param Optional[List[str]] invalid_options: If not None, then a list of strings that are invalid.
181
+ :param Optional[List[str]] invalid_substrs: If not None, then a list of characters or substrings that are not allowed in the string.
182
+ :return: If `raise_error` is False, then returns size-2 Tuple, with first value being a bool representing if valid string, and second value a string representing error reason (if valid is False, else empty string).
183
+ :rtype: Tuple[bool, str]
184
+
185
+ :examples:
186
+ >>> check_str(name='split_eval', value='gini', options=['entropy', 'gini'])
187
+ """
188
+
189
+ msg = ""
190
+ try:
191
+ t.String(allow_blank=allow_blank).check(value)
192
+ except t.DataError as e:
193
+ msg = f"{name} should be an string in SimBA, but is set to {str(value)}"
194
+ if raise_error:
195
+ raise StringError(msg=msg, source=check_str.__name__)
196
+ else:
197
+ return False, msg
198
+ if len(options) > 0:
199
+ if value not in options:
200
+ msg = f"{name} is set to {value} in SimBA, but this is not a valid option: {options}"
201
+ if raise_error:
202
+ raise StringError(msg=msg, source=check_str.__name__)
203
+ else:
204
+ return False, msg
205
+ else:
206
+ return True, msg
207
+
208
+ if invalid_options is not None:
209
+ check_instance(source=f'{name} invalid_options', accepted_types=(tuple, list,), instance=invalid_options)
210
+ if isinstance(invalid_options, tuple):
211
+ invalid_options = list(invalid_options)
212
+ check_valid_lst(data=invalid_options, valid_dtypes=(str,), min_len=1)
213
+ if value in invalid_options:
214
+ msg = f"{name} is set to {value} in SimBA, but this is among invalid options: {invalid_options}"
215
+ if raise_error:
216
+ raise StringError(msg=msg, source=check_str.__name__)
217
+ else:
218
+ return False, msg
219
+ else:
220
+ return True, msg
221
+ if invalid_substrs is not None:
222
+ if not isinstance(invalid_substrs, (tuple, list)):
223
+ check_instance(source=f'{name} invalid_substrs', accepted_types=(tuple, list,), instance=invalid_substrs)
224
+ if isinstance(invalid_substrs, tuple):
225
+ invalid_substrs = list(invalid_substrs)
226
+ check_valid_lst(data=invalid_substrs, valid_dtypes=(str,), min_len=1)
227
+ for substr in invalid_substrs:
228
+ if substr in value:
229
+ msg = f'{name} contains the characters "{substr}" . This character/substring is NOT accepted.'
230
+ if raise_error:
231
+ raise StringError(msg=msg, source=check_str.__name__)
232
+ else:
233
+ return False, msg
234
+ else:
235
+ return True, msg
236
+
237
+ def check_float(name: str,
238
+ value: Any,
239
+ max_value: Optional[float] = None,
240
+ min_value: Optional[float] = None,
241
+ raise_error: bool = True,
242
+ allow_zero: bool = True,
243
+ allow_negative: bool = True) -> Tuple[bool, str]:
244
+ """
245
+ Check if variable is a valid float.
246
+
247
+ :param str name: Name of variable
248
+ :param Any value: Value of variable
249
+ :param Optional[float] max_value: Maximum allowed value of the float. If None, then no maximum. Default: None.
250
+ :param Optional[float] min_value: Minimum allowed value of the float. If None, then no minimum. Default: None.
251
+ :param Optional[bool] allow_zero: If False, do not allow the float to be zero. Default: True (zero allowed).
252
+ :param Optional[bool] allow_negative: If False, do not allow the float to be below zero. Default: True (negative values allowed).
253
+ :param Optional[bool] raise_error: If True, then raise error if invalid float. Default: True.
254
+ :return: If `raise_error` is False, then returns size-2 tuple, with first value being a bool representing if valid float, and second value a string representing error (if valid is False, else empty string)
255
+ :rtype: Tuple[bool, str]
256
+
257
+
258
+ :examples:
259
+ >>> check_float(name='My_float', value=0.5, max_value=1.0, min_value=0.0)
260
+ """
261
+
262
+ msg = ""
263
+ try:
264
+ t.Float().check(value)
265
+ except t.DataError as e:
266
+ msg = f"{name} should be a float number in SimBA, but is set to {str(value)}"
267
+ if raise_error:
268
+ raise FloatError(msg=msg, source=check_float.__name__)
269
+ else:
270
+ return False, msg
271
+ if min_value != None:
272
+ if float(value) < min_value:
273
+ msg = f"{name} should be MORE THAN OR EQUAL to {str(min_value)}. It is set to {str(value)}"
274
+ if raise_error:
275
+ raise FloatError(msg=msg, source=check_float.__name__)
276
+ else:
277
+ return False, msg
278
+ if max_value != None:
279
+ if float(value) > max_value:
280
+ msg = f"{name} should be LESS THAN OR EQUAL to {str(max_value)}. It is set to {str(value)}"
281
+ if raise_error:
282
+ raise FloatError(msg=msg, source=check_float.__name__)
283
+ else:
284
+ return False, msg
285
+ if not allow_zero:
286
+ if float(value) == 0:
287
+ msg = f"{name} cannot be ZERO. It is set to {str(value)}"
288
+ if raise_error:
289
+ raise FloatError(msg=msg, source=check_float.__name__)
290
+ else:
291
+ return False, msg
292
+
293
+ if not allow_negative:
294
+ if float(value) < 0:
295
+ msg = f"{name} cannot be BELOW zero. It is set to {str(value)}"
296
+ if raise_error:
297
+ raise FloatError(msg=msg, source=check_float.__name__)
298
+ else:
299
+ return False, msg
300
+
301
+ return True, msg
302
+
303
+
304
+ def check_iterable_length(source: str, val: int, exact_accepted_length: Optional[int] = None, max: Optional[int] = np.inf, min: int = 1, raise_error: bool = True) -> bool:
305
+ if (not exact_accepted_length) and (not max) and (not min):
306
+ if raise_error:
307
+ raise InvalidInputError(msg=f"Provide exact_accepted_length or max and min values for {source}", source=check_iterable_length.__name__)
308
+ else:
309
+ return False
310
+ if exact_accepted_length:
311
+ if val != exact_accepted_length:
312
+ if raise_error:
313
+ raise InvalidInputError(msg=f"{source} length is {val}, expected {exact_accepted_length}", source=check_iterable_length.__name__)
314
+ else:
315
+ return False
316
+ elif (val > max) or (val < min):
317
+ if raise_error:
318
+ raise InvalidInputError(msg=f"{source} value {val} does not full-fill criterion: min {min}, max{max} ", source=check_iterable_length.__name__)
319
+ else:
320
+ return False
321
+ return True
322
+
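A minimal usage sketch of the length check above, validating that an iterable of length 3 falls within the min/max bounds:
>>> check_iterable_length(source='my_list', val=3, min=1, max=5)
>>> True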
323
+
324
+ def check_instance(source: str, instance: object, accepted_types: Union[Tuple[Any], Any], raise_error: bool = True, warning: bool = True) -> bool:
325
+ """
326
+ Check if an instance is an acceptable type.
327
+
328
+ :param str source: Arbitrary name of the calling instance or method, used for interpretable error messages.
329
+ :param object instance: A data object.
330
+ :param Union[Tuple[object], object] accepted_types: Accepted instance types. E.g., (Polygon, pd.DataFrame) or Polygon.
331
+ :param Optional[bool] raise_error: If True, raises an error if the instance is not of a valid type, else returns bool.
332
+ :param Optional[bool] warning: If True, prints a warning if the instance is not of a valid type, else returns bool.
333
+ """
334
+
335
+ if not isinstance(instance, accepted_types):
336
+ msg = f"{source} requires {accepted_types}, got {type(instance)}"
337
+ if raise_error:
338
+ raise InvalidInputError(msg=msg, source=source)
339
+ else:
340
+ if warning:
341
+ InvalidValueWarning(msg=msg, source=source)
342
+ return False
343
+ return True
344
+
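A minimal usage sketch, confirming that a pandas DataFrame is among the accepted types:
>>> check_instance(source='my_function', instance=pd.DataFrame(), accepted_types=(pd.DataFrame, np.ndarray))
>>> True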
345
+
346
+ def get_fn_ext(filepath: Union[os.PathLike, str]) -> (str, str, str):
347
+ """
348
+ Split file path into three components: (i) directory, (ii) file name, and (iii) file extension.
349
+
350
+ :parameter str filepath: Path to file.
351
+ :return str: File directory name
352
+ :return str: File name
353
+ :return str: File extension
354
+
355
+ :example:
356
+ >>> get_fn_ext(filepath='C:/My_videos/MyVideo.mp4')
357
+ >>> ('My_videos', 'MyVideo', '.mp4')
358
+ """
359
+ file_extension = Path(filepath).suffix
360
+ try:
361
+ file_name = os.path.basename(filepath.rsplit(file_extension, 1)[0])
362
+ except ValueError:
363
+ raise InvalidFilepathError(
364
+ msg=f"{filepath} is not a valid filepath", source=get_fn_ext.__name__
365
+ )
366
+ dir_name = os.path.dirname(filepath)
367
+ return dir_name, file_name, file_extension
368
+
369
+
370
+ def check_if_filepath_list_is_empty(filepaths: List[str], error_msg: str) -> None:
371
+ """
372
+ Check if a list is empty
373
+
374
+ :param List[str] filepaths: List of file-paths.
+ :param str error_msg: Error message to raise if ``filepaths`` is empty.
375
+ :raise NoFilesFoundError: The list is empty.
376
+ """
377
+
378
+ if len(filepaths) == 0:
379
+ raise NoFilesFoundError(
380
+ msg=error_msg, source=check_if_filepath_list_is_empty.__name__
381
+ )
382
+ else:
383
+ pass
384
+
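A minimal usage sketch (the file path is hypothetical); an empty list would instead raise ``NoFilesFoundError`` with the supplied message:
>>> check_if_filepath_list_is_empty(filepaths=['/project/csv/Video_1.csv'], error_msg='No data files found')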
385
+
386
+ def check_all_file_names_are_represented_in_video_log(
387
+ video_info_df: pd.DataFrame, data_paths: List[Union[str, os.PathLike]]
388
+ ) -> None:
389
+ """
390
+ Helper to check that all files are represented in a dataframe of the SimBA `project_folder/logs/video_info.csv`
391
+ file.
392
+
393
+ :param pd.DataFrame video_info_df: Dataframe representation of the project_folder/logs/video_info.csv file.
394
+ :param List[Union[str, os.PathLike]] data_paths: List of file-paths.
395
+ :raise ParametersFileError: If any video in ``data_paths`` is not represented in ``video_info_df``.
396
+ """
397
+
398
+ missing_videos = []
399
+ for file_path in data_paths:
400
+ video_name = get_fn_ext(file_path)[1]
401
+ if video_name not in list(video_info_df["Video"]):
402
+ missing_videos.append(video_name)
403
+ if len(missing_videos) > 0:
404
+ raise ParametersFileError(
405
+ msg=f"SimBA could not find {len(missing_videos)} video(s) in the video_info.csv file. Make sure all videos analyzed are represented in the project_folder/logs/video_info.csv file. MISSING VIDEOS: {missing_videos}"
406
+ )
407
+
408
+
409
+ def check_if_dir_exists(in_dir: Union[str, os.PathLike],
410
+ source: Optional[str] = None,
411
+ create_if_not_exist: Optional[bool] = False,
412
+ raise_error: bool = True) -> Union[None, bool]:
413
+ """
414
+ Check if a directory path exists.
415
+
416
+ :param Union[str, os.PathLike] in_dir: Putative directory path.
417
+ :param Optional[str] source: String source for interpretable error messaging.
418
+ :param Optional[bool] create_if_not_exist: If directory does not exist, then create it. Default False.
419
+ :param Optional[bool] raise_error: If True, raise error if dir does not exist. If False return None. Default True.
420
+ :raise NotDirectoryError: The directory does not exist.
421
+ """
422
+
423
+ if not isinstance(in_dir, (str, Path, os.PathLike)):
424
+ if raise_error:
425
+ raise NotDirectoryError(msg=f"{in_dir} is not a valid directory", source=check_if_dir_exists.__name__)
426
+ else:
427
+ return False
428
+
429
+ elif not os.path.isdir(in_dir):
430
+ if create_if_not_exist:
431
+ try:
432
+ os.makedirs(in_dir)
433
+ except:
434
+ pass
435
+ else:
436
+ if source is None:
437
+ if raise_error:
438
+ raise NotDirectoryError(msg=f"{in_dir} is not a valid directory", source=check_if_dir_exists.__name__)
439
+ else:
440
+ return False
441
+ else:
442
+ if raise_error:
443
+ raise NotDirectoryError(msg=f"{in_dir} is not a valid directory", source=source)
444
+ else:
445
+ return False
446
+ else:
447
+ return True
448
+
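A minimal usage sketch, assuming the hypothetical directory '/project/videos' exists on disk:
>>> check_if_dir_exists(in_dir='/project/videos')
>>> True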
449
+
450
+ def check_that_column_exist(df: pd.DataFrame,
451
+ column_name: Union[str, os.PathLike, List[str]],
452
+ file_name: str,
453
+ raise_error: bool = True) -> Union[None, bool]:
454
+ """
455
+ Check if single named field or a list of fields exist within a dataframe.
456
+
457
+ .. seealso::
458
+ Consider :func:`simba.utils.checks.check_valid_dataframe` instead.
459
+
460
+ :param pd.DataFrame df: The DataFrame to check for column existence.
461
+ :param Union[str, os.PathLike, List[str]] column_name: Name or names of field(s) to check for existence.
462
+ :param str file_name: Path of ``df`` on disk (used for error messages).
463
+ :param bool raise_error: If True, raises ColumnNotFoundError if column doesn't exist. If False, returns bool. Default: True.
464
+ :return: True if all columns exist, False if any column is missing (when raise_error=False), None if raise_error=True and all columns exist.
465
+ :rtype: Union[None, bool]
466
+ :raises ColumnNotFoundError: The ``column_name`` does not exist within ``df``.
467
+
468
+ :example:
469
+ >>> df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]})
470
+ >>> check_that_column_exist(df=df, column_name='A', file_name='test.csv')
471
+ True
472
+ >>> check_that_column_exist(df=df, column_name=['A', 'B'], file_name='test.csv')
473
+ True
474
+ >>> check_that_column_exist(df=df, column_name='C', file_name='test.csv', raise_error=False)
475
+ False
476
+ """
477
+
478
+ if type(column_name) == str:
479
+ column_name = [column_name]
480
+ for column in column_name:
481
+ if column not in df.columns:
482
+ if raise_error:
483
+ raise ColumnNotFoundError(column_name=column, file_name=file_name, source=check_that_column_exist.__name__)
484
+ else:
485
+ return False
486
+ return True
487
+
488
+
489
+ def check_if_valid_input(
490
+ name: str, input: str, options: List[str], raise_error: bool = True
491
+ ) -> (bool, str):
492
+ """
493
+ Check if string variable is valid option.
494
+
495
+ .. seealso::
496
+ Consider :func:`simba.utils.checks.check_str`.
497
+
498
+ :param str name: Arbitrary name of variable.
499
+ :param Any input: Value of variable.
500
+ :param List[str] options: Allowed options of ``input``
501
+ :param Optional[bool] raise_error: If True, then raise error if invalid value. Default: True.
502
+
503
+ :return bool: False if invalid. True if valid.
504
+ :return str: If invalid, then error msg. Else, empty str.
505
+
506
+ :example:
507
+ >>> check_if_valid_input(name='split_eval', input='gini', options=['entropy', 'gini'])
508
+ >>> (True, '')
509
+ """
510
+
511
+ msg = ""
512
+ if input not in options:
513
+ msg = f"{name} is set to {str(input)}, which is an invalid setting. OPTIONS {options}"
514
+ if raise_error:
515
+ raise InvalidInputError(msg=msg, source=check_if_valid_input.__name__)
516
+ else:
517
+ return False, msg
518
+ else:
519
+ return True, msg
520
+
521
+
522
+ def check_minimum_roll_windows(
523
+ roll_windows_values: List[int], minimum_fps: float
524
+ ) -> List[int]:
525
+ """
526
+ Remove any rolling temporal windows that are shorter than a single frame in
527
+ any of the videos within the project.
528
+
529
+ :param List[int] roll_windows_values: Rolling temporal windows represented as frame counts. E.g., [10, 15, 30, 60]
530
+ :param float minimum_fps: The lowest fps of the videos that are to be analyzed. E.g., 10.
531
+
532
+ :return List[int]: roll_windows_values without impassable windows.
533
+ """
534
+
535
+ for win in range(len(roll_windows_values)):
536
+ if minimum_fps < roll_windows_values[win]:
537
+ roll_windows_values[win] = minimum_fps
538
+ else:
539
+ pass
540
+ roll_windows_values = list(set(roll_windows_values))
541
+ return roll_windows_values
542
+
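A minimal usage sketch: windows longer than the lowest fps (here 10) are capped at 10, and duplicates are removed (returned order may vary since a set is used internally):
>>> check_minimum_roll_windows(roll_windows_values=[2, 5, 15, 30], minimum_fps=10)
>>> [2, 5, 10]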
543
+
544
+ def check_same_number_of_rows_in_dfs(dfs: List[pd.DataFrame]) -> bool:
545
+ """
546
+ Helper to check that each dataframe in a list contains an equal number of rows.
547
+
548
+ :param List[pd.DataFrame] dfs: List of dataframes.
549
+ :return bool: True if the dataframes have an equal number of rows. Else False.
550
+
551
+ >>> df_1, df_2 = pd.DataFrame([[1, 2], [1, 2]]), pd.DataFrame([[4, 2], [9, 3], [1, 5]])
552
+ >>> check_same_number_of_rows_in_dfs(dfs=[df_1, df_2])
553
+ >>> False
554
+ >>> df_1, df_2 = pd.DataFrame([[1, 2], [1, 2]]), pd.DataFrame([[4, 2], [9, 3]])
555
+ >>> check_same_number_of_rows_in_dfs(dfs=[df_1, df_2])
+ >>> True
556
+ """
557
+
558
+ row_cnt = None
559
+ for df_cnt, df in enumerate(dfs):
560
+ if df_cnt == 0:
561
+ row_cnt = len(df)
562
+ else:
563
+ if len(df) != row_cnt:
564
+ return False
565
+ return True
566
+
567
+
568
+ def check_if_headers_in_dfs_are_unique(dfs: List[pd.DataFrame]) -> List[str]:
569
+ """
570
+ Helper to check that headers in multiple dataframes are unique.
571
+
572
+ :param List[pd.DataFrame] dfs: List of dataframes.
573
+ :return List[str]: List of columns headers seen in multiple dataframes. Empty if None.
574
+
575
+ :examples:
576
+ >>> df_1, df_2 = pd.DataFrame([[1, 2]], columns=['My_column_1', 'My_column_2']), pd.DataFrame([[4, 2]], columns=['My_column_3', 'My_column_1'])
577
+ >>> check_if_headers_in_dfs_are_unique(dfs=[df_1, df_2])
578
+ >>> ['My_column_1']
579
+ """
580
+ seen_headers = []
581
+ for df_cnt, df in enumerate(dfs):
582
+ seen_headers.extend(list(df.columns))
583
+ duplicates = list(set([x for x in seen_headers if seen_headers.count(x) > 1]))
584
+ return duplicates
585
+
586
+
587
+ def check_if_string_value_is_valid_video_timestamp(value: str, name: str) -> None:
588
+ """
589
+ Helper to check if a string is in a valid HH:MM:SS format
590
+
591
+ :param str value: Timestamp in HH:MM:SS format.
592
+ :param str name: An arbitrary string name of the timestamp.
593
+ :raises InvalidInputError: If the timestamp is in invalid format
594
+
595
+ :example:
596
+ >>> check_if_string_value_is_valid_video_timestamp(value='00:0b:10', name='My time stamp')
597
+ >>> "InvalidInputError: My time stamp is should be in the format XX:XX:XX where X is an integer between 0-9"
598
+ >>> check_if_string_value_is_valid_video_timestamp(value='00:00:10', name='My time stamp')
599
+ """
600
+ r = re.compile(r"^\d{2}:\d{2}:\d{2}(\.\d+)?$")
601
+ if not r.match(value):
602
+ raise InvalidInputError(
603
+ msg=f"{name} should be in the format XX:XX:XX:XXXX or XX:XX:XX where X is an integer between 0-9. Got: {value}",
604
+ source=check_if_string_value_is_valid_video_timestamp.__name__,
605
+ )
606
+ else:
607
+ pass
608
+
609
+
610
+ def check_that_hhmmss_start_is_before_end(
611
+ start_time: str, end_time: str, name: str
612
+ ) -> None:
613
+ """
614
+ Helper to check that a start time in HH:MM:SS or HH:MM:SS:MS format is before an end time in HH:MM:SS or HH:MM:SS:MS format
615
+
616
+ :param str start_time: Period start time in HH:MM:SS format.
617
+ :param str end_time: Period end time in HH:MM:SS format.
618
+ :param int name: Name of the variable
619
+ :raises InvalidInputError: If end time is before the start time.
620
+
621
+ :example:
622
+ >>> check_that_hhmmss_start_is_before_end(start_time='00:00:05', end_time='00:00:01', name='My time period')
623
+ >>> "InvalidInputError: My time period has an end-time which is before the start-time"
624
+ >>> check_that_hhmmss_start_is_before_end(start_time='00:00:01', end_time='00:00:05', name='My time period')
625
+ """
626
+
627
+ if len(start_time.split(":")) != 3:
628
+ raise InvalidInputError(
629
+ f"Invalid time-stamp: ({start_time}). HH:MM:SS or HH:MM:SS.MS format required"
630
+ )
631
+ elif len(end_time.split(":")) != 3:
632
+ raise InvalidInputError(
633
+ f"Invalid time-stamp: ({end_time}). HH:MM:SS or HH:MM:SS.MS format required"
634
+ )
635
+ start_h, start_m, start_s = start_time.split(":")
636
+ end_h, end_m, end_s = end_time.split(":")
637
+ start_val = int(start_h) * 3600 + int(start_m) * 60 + float(start_s)
638
+ end_val = int(end_h) * 3600 + int(end_m) * 60 + float(end_s)
639
+ if end_val < start_val:
640
+ raise InvalidInputError(
641
+ f"{name} has an end-time which is before the start-time.",
642
+ source=check_that_hhmmss_start_is_before_end.__name__,
643
+ )
644
+
645
+
646
+ def check_nvidea_gpu_available(raise_error: bool = False) -> bool:
647
+ """
648
+ Helper to check if an NVIDIA GPU is available via ``nvidia-smi``.
649
+ :returns bool: True if the ``nvidia-smi`` call succeeds. Else False.
650
+ """
651
+ try:
652
+ subprocess.check_output("nvidia-smi")
653
+ return True
654
+ except Exception:
655
+ if raise_error:
656
+ raise SimBAGPUError(msg='No NVIDIA GPU detected on machine (checked by calling "nvidia-smi")', source=check_nvidea_gpu_available.__name__)
657
+ return False
658
+
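A minimal usage sketch; the result depends on whether ``nvidia-smi`` is callable on the host machine:
>>> check_nvidea_gpu_available()
>>> False  # on a machine without an NVIDIA GPU or driver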
659
+
660
+ def check_ffmpeg_available(raise_error: Optional[bool] = False) -> Union[bool, None]:
661
+ """
662
+ Helper to check if FFmpeg is available via subprocess ``ffmpeg``.
663
+
664
+ .. seealso::
665
+ To check which encoders are available in FFMpeg installation, see :func:`simba.utils.lookups.get_ffmpeg_encoders`
666
+
667
+ :param Optional[bool] raise_error: If True, raises ``FFMPEGNotFoundError`` if FFmpeg can't be found. Else return False. Default False.
668
+ :returns bool: True if the ``ffmpeg`` call succeeds. Else False (or raises ``FFMPEGNotFoundError`` if ``raise_error`` is True).
669
+ """
670
+
671
+ try:
672
+ subprocess.call("ffmpeg", stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL)
673
+ return True
674
+ except Exception:
675
+ if raise_error:
676
+ raise FFMPEGNotFoundError(
677
+ msg="FFMpeg could not be found on the instance (as evaluated via subprocess ffmpeg). Please make sure FFMpeg is installed."
678
+ )
679
+ else:
680
+ return False
681
+
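A minimal usage sketch; returns True on a host where the ``ffmpeg`` binary can be invoked:
>>> check_ffmpeg_available()
>>> True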
682
+ def check_if_valid_rgb_str(
683
+ input: str,
684
+ delimiter: str = ",",
685
+ return_cleaned_rgb_tuple: bool = True,
686
+ reverse_returned: bool = True,
687
+ ):
688
+ """
689
+ Helper to check if a string is a valid representation of an RGB color.
690
+
691
+ :param str input: Value to check as string. E.g., '(166, 29, 12)' or '22,32,999'
692
+ :param str delimiter: The delimiter between subsequent values in the rgb input string.
693
+ :param bool return_cleaned_rgb_tuple: If True, and input is a valid rgb, then returns a "clean" rgb tuple: Eg. '166, 29, 12' -> (166, 29, 12). Else, returns None.
694
+ :param bool reverse_returned: If True and return_cleaned_rgb_tuple is True, reverses to returned cleaned rgb tuple (e.g., RGB becomes BGR) before returning it.
695
+
696
+ :example:
697
+ >>> check_if_valid_rgb_str(input='(50, 25, 100)', return_cleaned_rgb_tuple=True, reverse_returned=True)
698
+ >>> (100, 25, 50)
699
+ """
700
+
701
+ input = input.replace(" ", "")
702
+ if input.count(delimiter) != 2:
703
+ raise InvalidInputError(msg=f"{input} in not a valid RGB color")
704
+ values = input.split(",")
705
+ rgb = []
706
+ for value in values:
707
+ val = "".join(c for c in value if c.isdigit())
708
+ check_int(
709
+ name="RGB value", value=val, max_value=255, min_value=0, raise_error=True
710
+ )
711
+ rgb.append(val)
712
+ rgb = tuple([int(x) for x in rgb])
713
+
714
+ if return_cleaned_rgb_tuple:
715
+ if reverse_returned:
716
+ rgb = rgb[::-1]
717
+ return rgb
718
+
719
+
720
+ def check_if_valid_rgb_tuple(data: Tuple[int, int, int],
721
+ raise_error: bool = True,
722
+ source: Optional[str] = None) -> bool:
723
+ check_instance(source=check_if_valid_rgb_tuple.__name__, instance=data, accepted_types=tuple, raise_error=raise_error)
724
+ check_iterable_length(source=check_if_valid_rgb_tuple.__name__, val=len(data), exact_accepted_length=3, raise_error=raise_error)
725
+ for i in range(len(data)):
726
+ if source is None:
727
+ check_int(name="RGB value", value=data[i], max_value=255, min_value=0, raise_error=raise_error)
728
+ else:
729
+ check_int(name=f"RGB value {source}", value=data[i], max_value=255, min_value=0, raise_error=raise_error)
730
+ return True
731
+
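A minimal usage sketch of the tuple check above; a value outside the 0-255 range (e.g., (300, 0, 0)) would raise instead:
>>> check_if_valid_rgb_tuple(data=(255, 0, 0))
>>> True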
732
+
733
+ def check_if_list_contains_values(
734
+ data: List[Union[float, int, str]],
735
+ values: List[Union[float, int, str]],
736
+ name: str,
737
+ raise_error: bool = True,
738
+ ) -> None:
739
+ """
740
+ Helper to check if values are represented in a list. E.g., make sure annotations of behavior absent and present are represented in an annotation column.
741
+
742
+ :param List[Union[float, int, str]] data: List of values. E.g., annotation column represented as list.
743
+ :param List[Union[float, int, str]] values: Values to confirm are present. E.g., [0, 1].
744
+ :param str name: Arbitrary name of the data for more useful error msg.
745
+ :param bool raise_error: If True, raise error of not all values can be found in data. Else, print warning.
746
+
747
+ :example:
748
+ >>> check_if_list_contains_values(data=[1,2, 3, 4, 0], values=[0, 1, 6], name='My_data')
749
+ """
750
+
751
+ data, missing_values = list(set(data)), []
752
+ for value in values:
753
+ if value not in data:
754
+ missing_values.append(value)
755
+
756
+ if len(missing_values) > 0 and raise_error:
757
+ raise NoDataError(
758
+ msg=f"{name} does not contain the following expected values: {missing_values}",
759
+ source=check_if_list_contains_values.__name__,
760
+ )
761
+
762
+ elif len(missing_values) > 0 and not raise_error:
763
+ NoDataFoundWarning(
764
+ msg=f"{name} does not contain the following expected values: {missing_values}",
765
+ source=check_if_list_contains_values.__name__,
766
+ )
767
+
768
+
769
+ def check_valid_hex_color(color_hex: str, raise_error: Optional[bool] = True) -> bool:
770
+ """
771
+ Check if given string represents a valid hexadecimal color code.
772
+
773
+ :param str color_hex: A string representing a hexadecimal color code, either in the format '#RRGGBB' or '#RGB'.
774
+ :param bool raise_error: If True, raise an exception when the color_hex is invalid; if False, return False instead. Default is True.
775
+ :return bool: True if the color_hex is a valid hexadecimal color code; False otherwise (if raise_error is False).
776
+ :raises IntegerError: If the color_hex is an invalid hexadecimal color code and raise_error is True.
777
+ """
778
+
779
+ hex_regex = re.compile(r"^#([0-9a-fA-F]{6}|[0-9a-fA-F]{3})$")
780
+ match = hex_regex.match(color_hex)
781
+ if match is None and raise_error:
782
+ raise IntegerError(
783
+ msg=f"{color_hex} is an invalid hex color",
784
+ source=check_valid_hex_color.__name__,
785
+ )
786
+ elif match is None and not raise_error:
787
+ return False
788
+ else:
789
+ return True
790
+
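A minimal usage sketch with one valid 6-digit hex code and one invalid code:
>>> check_valid_hex_color(color_hex='#1A2B3C')
>>> True
>>> check_valid_hex_color(color_hex='#GG0000', raise_error=False)
>>> False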
791
+ def check_valid_url(url: str, raise_error: bool = False, source: str = '') -> bool:
792
+ """
793
+ Check if a string is a valid URL (http, https, or ftp).
794
+
795
+ :param str url: The string to validate as a URL.
796
+ :param bool raise_error: If True, raises InvalidInputError when the URL is invalid. Default: False.
797
+ :param str source: Source identifier for error messages when raise_error=True. Default: ''.
798
+ :return: True if the string is a valid URL, False otherwise.
799
+ """
800
+ regex = re.compile(
801
+ r'^(https?|ftp)://' # protocol
802
+ r'(\S+(:\S*)?@)?' # user:password (optional)
803
+ r'((\d{1,3}\.){3}\d{1,3}|' # IP address
804
+ r'([a-zA-Z0-9.-]+\.[a-zA-Z]{2,}))' # domain name
805
+ r'(:\d+)?' # port (optional)
806
+ r'(/[\S]*)?$', # path (optional)
807
+ re.IGNORECASE)
808
+ is_valid = re.match(regex, url) is not None
809
+ if not is_valid and raise_error:
810
+ raise InvalidInputError(
811
+ msg=f"Invalid URL: {url}",
812
+ source=source or check_valid_url.__name__
813
+ )
814
+ return is_valid
815
+
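A minimal usage sketch; the second call fails because the string lacks a protocol:
>>> check_valid_url(url='https://example.com/data.csv')
>>> True
>>> check_valid_url(url='not_a_url')
>>> False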
816
+
817
+ def check_if_2d_array_has_min_unique_values(data: np.ndarray, min: int) -> bool:
818
+ """
819
+ Check if a 2D NumPy array has at least a minimum number of unique rows.
820
+
821
+ For example, use when creating shapely Polygons or Linestrings, which typically require at least two or three unique
822
+ body-part coordinates.
823
+
824
+ :param np.ndarray data: Input 2D array to be checked.
825
+ :param int min: Minimum number of unique rows required.
826
+ :return bool: True if the input array has at least the specified minimum number of unique rows, False otherwise.
827
+
828
+ :example:
829
+ >>> data = np.array([[0, 0], [0, 0], [0, 0], [0, 1]])
830
+ >>> check_if_2d_array_has_min_unique_values(data=data, min=2)
831
+ >>> True
832
+ """
833
+
834
+ if len(data.shape) != 2:
835
+ raise CountError(
836
+ msg=f"Requires input array of two dimensions, found {data.size}",
837
+ source=check_if_2d_array_has_min_unique_values.__name__,
838
+ )
839
+ sliced_data = np.unique(data, axis=0)
840
+ if sliced_data.shape[0] < min:
841
+ return False
842
+ else:
843
+ return True
844
+
845
+
846
+ def check_if_module_has_import(parsed_file: ast.Module, import_name: str) -> bool:
847
+ """
848
+ Check if a Python module has a specific import statement. For example, check if module imports `argparse` or circular statistics mixin.
849
+
850
+ Used for e.g., user custom feature extraction classes in ``simba.utils.custom_feature_extractor.CustomFeatureExtractor``.
851
+
852
+ :parameter ast.Module parsed_file: The abstract syntax tree (AST) of the Python module.
853
+ :parameter str import_name: The name of the module or package to check for in the import statements.
854
+ :return bool: True if the specified import is found in the module, False otherwise.
855
+
856
+ :example:
857
+ >>> parsed_file = ast.parse(Path('/simba/misc/piotr.py').read_text())
858
+ >>> check_if_module_has_import(parsed_file=parsed_file, import_name='argparse')
859
+ >>> True
860
+ """
861
+ imports = [
862
+ n for n in parsed_file.body if isinstance(n, (ast.Import, ast.ImportFrom))
863
+ ]
864
+ for i in imports:
865
+ for name in i.names:
866
+ if name.name == import_name:
867
+ return True
868
+ return False
869
+
870
+
871
+ def check_valid_extension(
872
+ path: Union[str, os.PathLike], accepted_extensions: Union[List[str], str]
873
+ ):
874
+ """
875
+ Checks if the file extension of the provided path is in the list of accepted extensions.
876
+
877
+ :param Union[str, os.PathLike] path: The path to the file whose extension needs to be checked.
878
+ :param List[str] accepted_extensions: A list of accepted file extensions. E.g., ['pickle', 'csv'].
879
+ """
880
+ if isinstance(accepted_extensions, (list, tuple)):
881
+ check_valid_lst(data=accepted_extensions, source=f"{check_valid_extension.__name__} accepted_extensions", valid_dtypes=(str,), min_len=1)
882
+ elif isinstance(accepted_extensions, str):
883
+ check_str(name=f"{check_valid_extension.__name__} accepted_extensions", value=accepted_extensions)
884
+ accepted_extensions = [accepted_extensions]
885
+ accepted_extensions = [x.lower() for x in accepted_extensions]
886
+ check_file_exist_and_readable(file_path=path)
887
+ extension = get_fn_ext(filepath=path)[2][1:]
888
+ if extension.lower() not in accepted_extensions:
889
+ raise InvalidFilepathError(msg=f"File extension for file {path} has an invalid extension. Found {extension}, accepted: {accepted_extensions}", source=check_valid_extension.__name__)
890
+
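A minimal usage sketch, assuming a readable CSV file exists at the hypothetical path '/project/data/Video_1.csv'; an unaccepted extension raises ``InvalidFilepathError``:
>>> check_valid_extension(path='/project/data/Video_1.csv', accepted_extensions=['csv', 'parquet'])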
891
+
892
+ def check_if_valid_img(data: np.ndarray,
893
+ source: str = "",
894
+ raise_error: bool = True,
895
+ greyscale: bool = False,
896
+ color: bool = False) -> Union[bool, None]:
897
+ """
898
+ Check if a variable is a valid image.
899
+
900
+ :param str source: Name of the variable and/or class origin for informative error messaging and logging.
901
+ :param np.ndarray data: Data variable to check if a valid image representation.
902
+ :param bool greyscale: Checks that the image is greyscale. Default False.
903
+ :param bool color: Checks that the image is color. Default False.
904
+ :parameter bool raise_error: If True, raise InvalidInputError if invalid image representation. Else, return bool.
905
+ """
906
+
907
+ check_instance(source=check_if_valid_img.__name__, instance=data, accepted_types=(np.ndarray, cp.ndarray))
908
+ if (data.ndim != 2) and (data.ndim != 3):
909
+ if raise_error:
910
+ raise InvalidInputError(msg=f"The {source} data is not a valid image. It has {data.ndim} dimensions", source=check_if_valid_img.__name__)
911
+ else:
912
+ return False
913
+ if data.dtype not in [np.uint8, np.uint16, np.float32, np.float64]:
914
+ if raise_error:
915
+ raise InvalidInputError(msg=f"The {source} data is not a valid image. It is dtype {data.dtype}", source=check_if_valid_img.__name__)
916
+ else:
917
+ return False
918
+ if np.max(data) > 255:
919
+ if raise_error:
920
+ raise InvalidInputError(msg=f"The {source} data is not a valid image. Values found that are above 255: {np.max(data)}", source=check_if_valid_img.__name__)
921
+ if greyscale:
922
+ if (data.ndim != 2):
923
+ if raise_error:
924
+ raise InvalidInputError(msg=f"The {source} image is not a greyscale image. Got {data.ndim} dimensions", source=check_if_valid_img.__name__)
925
+ else:
926
+ return False
927
+ if color:
928
+ if (data.ndim != 3):
929
+ if raise_error:
930
+ raise InvalidInputError(msg=f"The {source} image is not a color image. Got {data.ndim} dimensions", source=check_if_valid_img.__name__)
931
+ else:
932
+ return False
933
+
934
+
935
+
936
+ return True
937
+
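A minimal usage sketch, using a blank 100x100 uint8 array as a greyscale image:
>>> img = np.zeros((100, 100), dtype=np.uint8)
>>> check_if_valid_img(data=img, source='test_img', greyscale=True)
>>> True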
938
+
939
+ def check_that_dir_has_list_of_filenames(
940
+ dir: Union[str, os.PathLike],
941
+ file_name_lst: List[str],
942
+ file_type: Optional[str] = "csv",
943
+ ):
944
+ """
945
+ Check that all file names in a list have an equivalent file in a specified directory. E.g., check if all files in the outlier-corrected folder have an equivalent file in the features_extracted directory.
946
+
947
+ :example:
948
+ >>> file_name_lst = glob.glob('/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/csv/outlier_corrected_movement' + '/*.csv')
949
+ >>> check_that_dir_has_list_of_filenames(dir = '/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/csv/features_extracted', file_name_lst=file_name_lst)
950
+ """
951
+
952
+ files_in_dir = glob.glob(dir + f"/*.{file_type}")
953
+ files_in_dir = [os.path.basename(x) for x in files_in_dir]
954
+ for file_name in file_name_lst:
955
+ if os.path.basename(file_name) not in files_in_dir:
956
+ raise NoFilesFoundError(msg=f"File name {os.path.basename(file_name)} could not be found in the directory {dir}", source=check_that_dir_has_list_of_filenames.__name__)
957
+
958
+
959
+ def check_valid_array(data: np.ndarray,
960
+ source: Optional[str] = "",
961
+ accepted_ndims: Optional[Union[Tuple[int], Any]] = None,
962
+ accepted_sizes: Optional[List[int]] = None,
963
+ accepted_axis_0_shape: Optional[Union[List[int], Tuple[int]]] = None,
964
+ accepted_axis_1_shape: Optional[Union[List[int], Tuple[int]]] = None,
965
+ accepted_dtypes: Optional[Union[List[Union[str, Type]], Tuple[Union[str, Type]], Iterable[Any]]] = None,
966
+ accepted_values: Optional[List[Any]] = None,
967
+ accepted_shapes: Optional[List[Tuple[int]]] = None,
968
+ min_axis_0: Optional[int] = None,
969
+ max_axis_1: Optional[int] = None,
970
+ min_axis_1: Optional[int] = None,
971
+ min_value: Optional[Union[float, int]] = None,
972
+ max_value: Optional[Union[float, int]] = None,
973
+ raise_error: bool = True) -> Union[None, bool]:
974
+ """
975
+ Check if the given array satisfies specified criteria regarding its dimensions, shape, and data type.
976
+
977
+ :param np.ndarray data: The numpy array to be checked.
978
+ :param Optional[str] source: A string identifying the source, name, or purpose of the array for interpretable error messaging.
979
+ :param Optional[Union[Tuple[int], Any]] accepted_ndims: List of tuples representing acceptable dimensions. If provided, checks whether the array's number of dimensions matches any tuple in the list.
980
+ :param Optional[List[int]] accepted_sizes: List of acceptable sizes for the array's shape. If provided, checks whether the length of the array's shape matches any value in the list.
981
+ :param Optional[Union[List[int], Tuple[int]]] accepted_axis_0_shape: List of accepted number of rows of 2-dimensional array. Will also raise error if value passed and input is not a 2-dimensional array.
982
+ :param Optional[Union[List[int], Tuple[int]]] accepted_axis_1_shape: List of accepted number of columns or fields of 2-dimensional array. Will also raise error if value passed and input is not a 2-dimensional array.
983
+ :param Optional[Union[List[Union[str, Type]], Tuple[Union[str, Type]], Iterable[Any]]] accepted_dtypes: List of acceptable data types for the array. If provided, checks whether the array's data type matches any string in the list.
984
+ :param Optional[List[Any]] accepted_values: List of acceptable values that can be present in the array.
985
+ :param Optional[List[Tuple[int]]] accepted_shapes: List of acceptable shapes for the array. If provided, checks whether the array's shape matches any tuple in the list.
986
+ :param Optional[int] min_axis_0: Minimum number of rows required for the array.
987
+ :param Optional[int] max_axis_1: Maximum number of columns allowed for the array.
988
+ :param Optional[int] min_axis_1: Minimum number of columns required for the array.
989
+ :param Optional[Union[float, int]] min_value: Minimum value allowed in the array.
990
+ :param Optional[Union[float, int]] max_value: Maximum value allowed in the array.
991
+ :param bool raise_error: If True, raises ArrayError if validation fails. If False, returns bool. Default: True.
992
+ :return: True if array passes all validation checks, False if validation fails (when raise_error=False), None if raise_error=True and validation passes.
993
+ :rtype: Union[None, bool]
994
+
995
+ :example:
996
+ >>> data = np.array([[1, 2], [3, 4]])
997
+ >>> check_valid_array(data, source="Example", accepted_ndims=(2,), accepted_sizes=[2], accepted_dtypes=[np.int64])
998
+ True
999
+ >>> check_valid_array(data, source="Example", min_axis_0=3, raise_error=False)
1000
+ False
1001
+ """
1002
+
1003
+ check_instance(source=source, instance=data, accepted_types=np.ndarray)
1004
+ if accepted_ndims is not None:
1005
+ if data.ndim not in accepted_ndims:
1006
+ if raise_error:
1007
+ raise ArrayError(msg=f"Array not of acceptable dimensions. Found {data.ndim}, accepted: {accepted_ndims}: {source}", source=check_valid_array.__name__)
1008
+ else:
1009
+ return False
1010
+ if accepted_sizes is not None:
1011
+ if len(data.shape) not in accepted_sizes:
1012
+ if raise_error:
1013
+ raise ArrayError(msg=f"Array not of acceptable size. Found {len(data.shape)}, accepted: {accepted_sizes}: {source}", source=check_valid_array.__name__)
1014
+ else:
1015
+ return False
1016
+ if accepted_dtypes is not None:
1017
+ if data.dtype not in accepted_dtypes:
1018
+ if raise_error:
1019
+ raise ArrayError(msg=f"Array not of acceptable type. Found {data.dtype}, accepted: {accepted_dtypes}: {source}", source=check_valid_array.__name__)
1020
+ else:
1021
+ return False
1022
+ if accepted_shapes is not None:
1023
+ if data.shape not in accepted_shapes:
1024
+ if raise_error:
1025
+ raise ArrayError(msg=f"Array not of acceptable shape. Found {data.shape}, accepted: {accepted_shapes}: {source}", source=check_valid_array.__name__)
1026
+ else:
1027
+ return False
1028
+ if accepted_axis_0_shape is not None:
1029
+ if not isinstance(accepted_axis_0_shape, (tuple, list)):
1030
+ raise InvalidInputError(msg=f"accepted_axis_0_shape is invalid format. Accepted: {'list, tuple'}. Got: {type(accepted_axis_0_shape)}, {source}", source=check_valid_array.__name__)
1031
+ for cnt, i in enumerate(accepted_axis_0_shape):
1032
+ check_int(name=f"{source} {cnt} accepted_axis_0_shape", value=i, min_value=1)
1033
+ # if data.ndim != 2:
1034
+ # raise ArrayError(
1035
+ # msg=f"Array not of acceptable dimension. Found {data.ndim}, accepted: 2, {source}",
1036
+ # source=check_valid_array.__name__,
1037
+ # )
1038
+ if data.shape[0] not in accepted_axis_0_shape:
1039
+ if raise_error:
1040
+ raise ArrayError(msg=f"Array not of acceptable shape. Found {data.shape[0]} rows, accepted: {accepted_axis_0_shape}, {source}", source=check_valid_array.__name__)
1041
+ else:
1042
+ return False
1043
+ if accepted_axis_1_shape is not None:
1044
+ if not isinstance(accepted_axis_1_shape, (tuple, list)):
1045
+ raise InvalidInputError(msg=f"accepted_axis_1_shape is invalid format. Accepted: {'list, tuple'}. Got: {type(accepted_axis_1_shape)}, {source}", source=check_valid_array.__name__)
1046
+ for cnt, i in enumerate(accepted_axis_1_shape):
1047
+ check_int(name=f"{source} {cnt} accepted_axis_1_shape", value=i, min_value=1)
1048
+ if data.ndim != 2:
1049
+ raise ArrayError(msg=f"Array not of acceptable dimension. Found {data.ndim}, accepted: 2, {source}", source=check_valid_array.__name__,)
1050
+ elif data.shape[1] not in accepted_axis_1_shape:
1051
+ if raise_error:
1052
+ raise ArrayError( msg=f"Array not of acceptable shape. Found {data.shape[0]} columns (axis=1), accepted: {accepted_axis_1_shape}, {source}", source=check_valid_array.__name__)
1053
+ else:
1054
+ return False
1055
+ if min_axis_0 is not None:
1056
+ check_int(name=f"{source} min_axis_0", value=min_axis_0)
1057
+ if data.shape[0] < min_axis_0:
1058
+ if raise_error:
1059
+ raise ArrayError(msg=f"Array not of acceptable shape. Found {data.shape[0]} rows, minimum accepted: {min_axis_0}, {source}", source=check_valid_array.__name__)
1060
+ else:
1061
+ return False
1062
+
1063
+ if max_axis_1 is not None and data.ndim > 1:
1064
+ check_int(name=f"{source} max_axis_1", value=max_axis_1)
1065
+ if data.shape[1] > max_axis_1:
1066
+ if raise_error:
1067
+ raise ArrayError(msg=f"Array not of acceptable shape. Found {data.shape[1]} columns, maximum columns accepted: {max_axis_1}, {source}", source=check_valid_array.__name__)
1068
+ else:
1069
+ return False
1070
+ if min_axis_1 is not None and data.ndim > 1:
1071
+ check_int(name=f"{source} min_axis_1", value=min_axis_1)
1072
+ if data.shape[1] < min_axis_1:
1073
+ if raise_error:
1074
+ raise ArrayError(msg=f"Array not of acceptable shape. Found {data.shape[1]} columns, minimum columns accepted: {min_axis_1}, {source}", source=check_valid_array.__name__)
1075
+ else:
1076
+ return False
1077
+ if accepted_values is not None:
1078
+ check_valid_lst(data=accepted_values, source=f"{source} accepted_values")
1079
+ additional_vals = list(set(np.unique(data)) - set(accepted_values))
1080
+ if len(additional_vals) > 0:
1081
+ if raise_error:
1082
+ raise ArrayError(msg=f"Array contains unacceptable values. Found {additional_vals}, accepted: {accepted_values}, {source}", source=check_valid_array.__name__,)
1083
+ return False
1084
+
1085
+ if min_value is not None:
1086
+ check_float(name=f'{source} min_value', value=min_value)
1087
+ if np.min(data) < min_value:
1088
+ if raise_error:
1089
+ raise ArrayError(msg=f"Array contains value below accepted value. Found {np.min(data)}, accepted minimum: {min_value}, {source}", source=check_valid_array.__name__, )
1090
+ else:
1091
+ return False
1092
+
1093
+ if max_value is not None:
1094
+ check_float(name=f'{source} max_value', value=max_value)
1095
+ if np.max(data) > max_value:
1096
+ if raise_error:
1097
+ raise ArrayError(msg=f"Array contains value above accepted maximum value. Found {np.max(data)}, accepted minimum: {max_value}, {source}", source=check_valid_array.__name__, )
1098
+ else:
1099
+ return False
1100
+ else:
1101
+ return True
1102
+
1103
+ def check_valid_lst(data: list,
1104
+ source: Optional[str] = "",
1105
+ valid_dtypes: Optional[Union[Tuple[Any], List[Any], Any]] = None,
1106
+ valid_values: Optional[List[Any]] = None,
1107
+ min_len: Optional[int] = 1,
1108
+ max_len: Optional[int] = None,
1109
+ min_value: Optional[float] = None,
1110
+ exact_len: Optional[int] = None,
1111
+ raise_error: Optional[bool] = True) -> bool:
1112
+ """
1113
+ Check the validity of a list based on passed criteria.
1114
+
1115
+ :param list data: The input list to be validated.
1116
+ :param Optional[str] source: A string indicating the source or context of the data for informative error messaging.
1117
+ :param Optional[Union[Tuple[Any], List[Any], Any]] valid_dtypes: A tuple, list, or single type of accepted data types. If provided, check if all elements in the list have data types in this collection.
1118
+ :param Optional[List[Any]] valid_values: A list of accepted list values. If provided, check if all elements in the list have matching values in this list.
1119
+ :param Optional[int] min_len: The minimum allowed length of the list. Default: 1.
1120
+ :param Optional[int] max_len: The maximum allowed length of the list.
1121
+ :param Optional[float] min_value: The minimum value allowed for numeric elements in the list.
1122
+ :param Optional[int] exact_len: The exact length required for the list. If provided, overrides min_len and max_len.
1123
+ :param Optional[bool] raise_error: If True, raise an InvalidInputError if any validation fails. If False, return False instead of raising an error. Default: True.
1124
+ :return bool: True if all validation criteria are met, False otherwise.
1125
+
1126
+ :example:
1127
+ >>> check_valid_lst(data=[1, 2, 'three'], valid_dtypes=(int, str), min_len=2, max_len=5)
1128
+ True
1129
+ >>> check_valid_lst(data=[1, 2, 3], valid_dtypes=(int,), exact_len=3)
1130
+ True
1131
+ >>> check_valid_lst(data=[1, 2, 3], min_value=0, raise_error=False)
1132
+ True
1133
+ """
1134
+ check_instance(source=source, instance=data, accepted_types=list)
1135
+ if min_len is not None:
1136
+ check_int(
1137
+ name=f"{source} {min_len}",
1138
+ value=min_len,
1139
+ min_value=0,
1140
+ raise_error=raise_error,
1141
+ )
1142
+ if len(data) < min_len:
1143
+ if raise_error:
1144
+ raise InvalidInputError(
1145
+ msg=f"Invalid length of list. Found {len(data)}, minimum accepted: {min_len}",
1146
+ source=source,
1147
+ )
1148
+ else:
1149
+ return False
1150
+
1151
+ check_instance(source=source, instance=data, accepted_types=list)
1152
+ if valid_dtypes is not None:
1153
+ for dtype in set([type(x) for x in data]):
1154
+ if dtype not in valid_dtypes:
1155
+ if raise_error:
1156
+ raise InvalidInputError(msg=f"Invalid data type found in list. Found {dtype}, accepted: {valid_dtypes}", source=source)
1157
+ else:
1158
+ return False
1159
+
1160
+ if max_len is not None:
1161
+ check_int(
1162
+ name=f"{source} {max_len}",
1163
+ value=max_len,
1164
+ min_value=0,
1165
+ raise_error=raise_error,
1166
+ )
1167
+ if len(data) > max_len:
1168
+ if raise_error:
1169
+ raise InvalidInputError(
1170
+ msg=f"Invalid length of list. Found {len(data)}, maximum accepted: {min_len}",
1171
+ source=source,
1172
+ )
1173
+ else:
1174
+ return False
1175
+ if exact_len is not None:
1176
+ check_int(
1177
+ name=f"{source} {exact_len}",
1178
+ value=exact_len,
1179
+ min_value=0,
1180
+ raise_error=raise_error,
1181
+ )
1182
+ if len(data) != exact_len:
1183
+ if raise_error:
1184
+ raise InvalidInputError(
1185
+ msg=f"Invalid length of list. Found {len(data)}, accepted: {exact_len}",
1186
+ source=source,
1187
+ )
1188
+ else:
1189
+ return False
1190
+
1191
+ if valid_values != None:
1192
+ check_valid_lst(
1193
+ data=valid_values, source=check_valid_lst.__name__, min_len=1
1194
+ )
1195
+ invalids = list(set(data) - set(valid_values))
1196
+ if len(invalids):
1197
+ if raise_error:
1198
+ raise InvalidInputError(
1199
+ msg=f"Invalid list entries. Found {invalids}, accepted: {valid_values}",
1200
+ source=source,
1201
+ )
1202
+ else:
1203
+ return False
1204
+
1205
+ if min_value != None:
1206
+ check_float(name=check_valid_lst.__name__, value=min_value)
1207
+ invalids = [x for x in data if x < min_value]
1208
+ if len(invalids) > 0:
1209
+ if raise_error:
1210
+ raise InvalidInputError(msg=f"Invalid list entries. Found {invalids}, minimum accepted value: {min_value}", source=source)
1211
+ else:
1212
+ return False
1213
+
1214
+ return True
1215
+
1216
+
1217
+ def check_if_keys_exist_in_dict(
1218
+ data: dict,
1219
+ key: Union[str, int, tuple, List],
1220
+ name: Optional[str] = "",
1221
+ raise_error: Optional[bool] = True,
1222
+ ) -> bool:
1223
+ """
1224
+ Check if one or more keys exist in a dictionary.
1225
+
1226
+ This function validates that all specified keys are present in the given dictionary.
1227
+ It can check for a single key or multiple keys at once.
1228
+
1229
+ .. seealso::
1230
+ Consider :func:`simba.utils.checks.check_valid_dict`
1231
+
1232
+ :param dict data: The dictionary to check for key existence.
1233
+ :param Union[str, int, tuple, List] key: The key(s) to check for in the dictionary. Can be a single key or a list/tuple of keys.
1234
+ :param Optional[str] name: A string identifying the source or context of the data for informative error messaging. Default: "".
1235
+ :param Optional[bool] raise_error: If True, raises InvalidInputError if any key is missing. If False, returns False instead of raising an error. Default: True.
1236
+ :return bool: True if all keys exist in the dictionary, False if any key is missing (when raise_error=False).
1237
+ :raises InvalidInputError: If any of the specified keys do not exist in the dictionary and raise_error=True.
1238
+
1239
+ :example:
1240
+ >>> data = {'a': 1, 'b': 2, 'c': 3}
1241
+ >>> check_if_keys_exist_in_dict(data=data, key='a')
1242
+ True
1243
+ >>> check_if_keys_exist_in_dict(data=data, key=['a', 'b'])
1244
+ True
1245
+ >>> check_if_keys_exist_in_dict(data=data, key='d', raise_error=False)
1246
+ False
1247
+ """
1248
+
1249
+ check_instance(source=name, instance=data, accepted_types=(dict,))
1250
+ check_instance(
1251
+ source=name,
1252
+ instance=key,
1253
+ accepted_types=(
1254
+ str,
1255
+ int,
1256
+ tuple,
1257
+ List,
1258
+ ),
1259
+ )
1260
+ if not isinstance(key, (list, tuple)):
1261
+ key = [key]
1262
+
1263
+ for k in key:
1264
+ if k not in list(data.keys()):
1265
+ if raise_error:
1266
+ raise InvalidInputError(
1267
+ msg=f"{k} does not exist in object {name}",
1268
+ source=check_if_keys_exist_in_dict.__name__,
1269
+ )
1270
+ else:
1271
+ return False
1272
+ return True
1273
+
1274
+
1275
+ def check_that_directory_is_empty(directory: Union[str, os.PathLike], raise_error: Optional[bool] = True) -> None:
1276
+ """
1277
+ Checks if a directory is empty. If the directory has content, then returns False or raises ``DirectoryNotEmptyError``.
1278
+
1279
+ :param str directory: Directory to check.
1280
+ :raises DirectoryNotEmptyError: If ``directory`` contains files.
1281
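+
+ The example below is an illustrative sketch (hypothetical paths): an empty directory returns True, while a directory holding files returns False when ``raise_error=False``.
+
+ :example:
+ >>> check_that_directory_is_empty(directory='/data/empty_dir')
+ True
+ >>> check_that_directory_is_empty(directory='/data/dir_with_files', raise_error=False)
+ False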
+ """
1282
+
1283
+ check_if_dir_exists(in_dir=directory)
1284
+ try:
1285
+ all_files_in_folder = [
1286
+ f for f in next(os.walk(directory))[2] if not f[0] == "."
1287
+ ]
1288
+ except StopIteration:
1289
+ return 0
1290
+ else:
1291
+ if len(all_files_in_folder) > 0:
1292
+ if raise_error:
1293
+ raise DirectoryNotEmptyError(
1294
+ msg=f"The {directory} is not empty and contains {str(len(all_files_in_folder))} files. Use a directory that is empty.",
1295
+ source=check_that_directory_is_empty.__name__,
1296
+ )
1297
+ else:
1298
+ return False
1299
+ else:
1300
+ return True
1301
+
1302
+
1303
+ def check_umap_hyperparameters(hyper_parameters: Dict[str, Any]) -> None:
1304
+ """
1305
+ Checks if a dictionary of parameters (umap, scaling, etc.) is valid for grid-search UMAP dimensionality reduction.
1306
+
1307
+ :param dict hyper_parameters: Dictionary holding UMAP hyperparameters.
1308
+ :raises InvalidInputError: If any input is invalid
1309
+
1310
+ :example:
1311
+ >>> check_umap_hyperparameters(hyper_parameters={'n_neighbors': [2], 'min_distance': [0.1], 'spread': [1], 'scaler': 'MIN-MAX', 'variance': 0.2})
1312
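+
+ A hedged sketch of the failure mode (keys shown are taken from the example above; the exact required set comes from UMAPParam.HYPERPARAMETERS):
+
+ >>> check_umap_hyperparameters(hyper_parameters={'n_neighbors': [2], 'min_distance': [0.1], 'spread': [1]})  # raises InvalidInputError: missing required entries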
+ """
1313
+ for key in UMAPParam.HYPERPARAMETERS.value:
1314
+ if key not in hyper_parameters.keys():
1315
+ raise InvalidInputError(
1316
+ msg=f"Hyperparameter dictionary is missing {key} entry.",
1317
+ source=check_umap_hyperparameters.__name__,
1318
+ )
1319
+ for key in [
1320
+ UMAPParam.N_NEIGHBORS.value,
1321
+ UMAPParam.MIN_DISTANCE.value,
1322
+ UMAPParam.SPREAD.value,
1323
+ ]:
1324
+ if not isinstance(hyper_parameters[key], list):
1325
+ raise InvalidInputError(
1326
+ msg=f"Hyperparameter dictionary key {key} has to be a list but got {type(hyper_parameters[key])}.",
1327
+ source=check_umap_hyperparameters.__name__,
1328
+ )
1329
+ if len(hyper_parameters[key]) == 0:
1330
+ raise InvalidInputError(
1331
+ msg=f"Hyperparameter dictionary key {key} has 0 entries.",
1332
+ source=check_umap_hyperparameters.__name__,
1333
+ )
1334
+ for value in hyper_parameters[key]:
1335
+ if not isinstance(value, (int, float)):
1336
+ raise InvalidInputError(
1337
+ msg=f"Hyperparameter dictionary key {key} have to have numeric entries but got {type(value)}.",
1338
+ source=check_umap_hyperparameters.__name__,
1339
+ )
1340
+ if hyper_parameters[UMAPParam.SCALER.value] not in Options.SCALER_OPTIONS.value:
1341
+ raise InvalidInputError(
1342
+ msg=f"Scaler {hyper_parameters[UMAPParam.SCALER.value]} not supported. Opitions: {Options.SCALER_OPTIONS.value}",
1343
+ source=check_umap_hyperparameters.__name__,
1344
+ )
1345
+ check_float(
1346
+ "VARIANCE THRESHOLD",
1347
+ value=hyper_parameters[UMAPParam.VARIANCE.value],
1348
+ min_value=0.0,
1349
+ max_value=100.0,
1350
+ )
+
+
1351
+ def check_video_has_rois(roi_dict: Dict[str, pd.DataFrame],
1352
+ roi_names: List[str] = None,
1353
+ video_names: List[str] = None,
1354
+ source: str = 'roi dict',
1355
+ raise_error: bool = True):
1356
+ """
1357
+ Check that specified videos all have user-defined ROIs with specified names.
1358
+
1359
+ This function validates that all specified videos contain the required ROIs (Regions of Interest)
1360
+ with the specified names. It checks across all ROI types: rectangles, circles, and polygons.
1361
+
1362
+ .. note::
1363
+ To get roi dictionary, see :func:`simba.mixins.config_reader.ConfigReader.read_roi_data`.
1364
+
1365
+ :param Dict[str, pd.DataFrame] roi_dict: Dictionary containing ROI dataframes with keys for rectangles, circles, and polygons.
1366
+ :param Optional[List[str]] roi_names: List of ROI names to check for. If None, uses all unique ROI names from the data. Default: None.
1367
+ :param Optional[List[str]] video_names: List of video names to check. If None, uses all unique video names from the data. Default: None.
1368
+ :param str source: A string identifying the source or context for informative error messaging. Default: 'roi dict'.
1369
+ :param bool raise_error: If True, raises NoROIDataError if any videos are missing required ROIs. If False, returns tuple with validation result and missing ROIs. Default: True.
1370
+ :return: True if all videos contain all required ROIs. If any ROIs are missing: raises NoROIDataError when raise_error=True, otherwise returns a tuple (False, dict) mapping each video to its missing ROI names.
1371
+ :rtype: Union[bool, Tuple[bool, Dict[str, List[str]]]]
1372
+ :raises NoROIDataError: If any videos are missing required ROIs and raise_error=True.
1373
+
1374
+ :example:
1375
+ >>> roi_dict = {
1376
+ ... 'rectangles': pd.DataFrame({'Video': ['video1'], 'Name': ['ROI1']}),
1377
+ ... 'circles': pd.DataFrame({'Video': ['video1'], 'Name': ['ROI2']}),
1378
+ ... 'polygons': pd.DataFrame({'Video': ['video1'], 'Name': ['ROI3']})
1379
+ ... }
1380
+ >>> check_video_has_rois(roi_dict=roi_dict, roi_names=['ROI1', 'ROI2'], video_names=['video1'])
1381
+ True
1382
+ >>> check_video_has_rois(roi_dict=roi_dict, roi_names=['ROI1', 'ROI4'], video_names=['video1'], raise_error=False)
1383
+ (False, {'video1': ['ROI4']})
1384
+ """
1385
+
1386
+ check_valid_dict(x=roi_dict, valid_key_dtypes=(str,), valid_values_dtypes=(pd.DataFrame,), required_keys=(Keys.ROI_RECTANGLES.value, Keys.ROI_CIRCLES.value, Keys.ROI_POLYGONS.value,),)
1387
+ check_valid_dataframe(df=roi_dict[Keys.ROI_RECTANGLES.value], source=f'{check_video_has_rois.__name__} {source} roi_dict {Keys.ROI_RECTANGLES.value}', required_fields=['Video', 'Name'])
1388
+ check_valid_dataframe(df=roi_dict[Keys.ROI_CIRCLES.value], source=f'{check_video_has_rois.__name__} {source} roi_dict {Keys.ROI_CIRCLES.value}', required_fields=['Video', 'Name'])
1389
+ check_valid_dataframe(df=roi_dict[Keys.ROI_POLYGONS.value], source=f'{check_video_has_rois.__name__} {source} roi_dict {Keys.ROI_POLYGONS.value}', required_fields=['Video', 'Name'])
1390
+ if roi_names is not None:
1391
+ check_valid_lst(data=roi_names, source=f'{check_video_has_rois.__name__} {source} roi_names', valid_dtypes=(str,), min_len=1)
1392
+ else:
1393
+ roi_names = list(set(list(roi_dict[Keys.ROI_RECTANGLES.value]['Name'].unique()) + list(roi_dict[Keys.ROI_CIRCLES.value]['Name'].unique()) + list(roi_dict[Keys.ROI_POLYGONS.value]['Name'].unique())))
1394
+ if video_names is not None:
1395
+ check_valid_lst(data=video_names, source=f'{check_video_has_rois.__name__} {source} video_names', min_len=1,)
1396
+ else:
1397
+ video_names = list(set(list(roi_dict[Keys.ROI_RECTANGLES.value]['Video'].unique()) + list(roi_dict[Keys.ROI_CIRCLES.value]['Video'].unique()) + list(roi_dict[Keys.ROI_POLYGONS.value]['Video'].unique())))
1398
+ missing_rois = {}
1399
+ rois_missing = False
1400
+ for video_name in video_names:
1401
+ missing_rois[video_name] = []
1402
+ for roi_name in roi_names:
1403
+ rect_filt = roi_dict[Keys.ROI_RECTANGLES.value][(roi_dict[Keys.ROI_RECTANGLES.value]['Video'] == video_name) & (roi_dict[Keys.ROI_RECTANGLES.value]['Name'] == roi_name)]
1404
+ circ_filt = roi_dict[Keys.ROI_CIRCLES.value][(roi_dict[Keys.ROI_CIRCLES.value]['Video'] == video_name) & (roi_dict[Keys.ROI_CIRCLES.value]['Name'] == roi_name)]
1405
+ poly_filt = roi_dict[Keys.ROI_POLYGONS.value][(roi_dict[Keys.ROI_POLYGONS.value]['Video'] == video_name) & (roi_dict[Keys.ROI_POLYGONS.value]['Name'] == roi_name)]
1406
+ if (len(rect_filt) + len(circ_filt) + len(poly_filt)) == 0:
1407
+ missing_rois[video_name].append(roi_name)
+ rois_missing = True
1408
+ if rois_missing and raise_error:
1409
+ raise NoROIDataError(msg=f'Some videos are missing some ROIs: {missing_rois}', source=f'{check_video_has_rois.__name__} {source}')
1410
+ elif rois_missing:
1411
+ return False, missing_rois
1412
+ else:
1413
+ return True
1414
+
1415
+
1416
+ def check_if_df_field_is_boolean(df: pd.DataFrame,
1417
+ field: str,
1418
+ raise_error: bool = True,
1419
+ bool_values: Optional[Tuple[Any]] = (0, 1),
1420
+ df_name: Optional[str] = ''):
1421
+ """
1422
+ Check if a DataFrame field contains only boolean values.
1423
+
1424
+ This function validates that a specified column in a DataFrame contains only
1425
+ the expected boolean values (e.g., 0/1, True/False). It checks for any
1426
+ unexpected values that are not in the allowed boolean values set.
1427
+
1428
+ :param pd.DataFrame df: The DataFrame to check.
1429
+ :param str field: Name of the column to validate for boolean values.
1430
+ :param bool raise_error: If True, raises CountError when non-boolean values are found. If False, returns False. Default: True.
1431
+ :param Optional[Tuple[Any]] bool_values: Tuple of accepted boolean values. Default: (0, 1).
1432
+ :param Optional[str] df_name: Name of the DataFrame for error messaging. Default: ''.
1433
+ :return: True if field contains only boolean values, False if non-boolean values found and raise_error=False.
1434
+ :rtype: bool
1435
+ :raises CountError: If non-boolean values are found in the field and raise_error=True.
1436
+
1437
+ :example:
1438
+ >>> df = pd.DataFrame({'binary_col': [0, 1, 0, 1], 'mixed_col': [0, 1, 2, 0]})
1439
+ >>> check_if_df_field_is_boolean(df=df, field='binary_col')
1440
+ True
1441
+ >>> check_if_df_field_is_boolean(df=df, field='mixed_col', raise_error=False)
1442
+ False
1443
+ >>> check_if_df_field_is_boolean(df=df, field='mixed_col', bool_values=(0, 1, 2))
1444
+ True
1445
+ """
1446
+ check_instance(source=f'{check_if_df_field_is_boolean.__name__} df', instance=df, accepted_types=(pd.DataFrame,))
1447
+ check_str(name=f"{check_if_df_field_is_boolean.__name__} field", value=field)
1448
+ check_that_column_exist(df=df, column_name=field, file_name=check_if_df_field_is_boolean.__name__)
1449
+ additional = list((set(list(df[field])) - set(bool_values)))
1450
+ if len(additional) > 0:
1451
+ if raise_error:
1452
+ raise CountError(msg=f"Field {field} not a boolean in {df_name}. Found values {additional}. Accepted: {bool_values}", source=check_if_df_field_is_boolean.__name__)
1453
+ else:
1454
+ return False
1455
+ return True
1456
+
1457
+
1458
+ def check_valid_dataframe(
1459
+ df: pd.DataFrame,
1460
+ source: Optional[str] = "",
1461
+ valid_dtypes: Optional[Tuple[Any]] = None,
1462
+ required_fields: Optional[List[str]] = None,
1463
+ min_axis_0: Optional[int] = None,
1464
+ min_axis_1: Optional[int] = None,
1465
+ max_axis_0: Optional[int] = None,
1466
+ max_axis_1: Optional[int] = None,
1467
+ allow_duplicate_col_names = True,
1468
+ ):
1469
+ """
1470
+ Validate a DataFrame against various criteria.
1471
+
1472
+ This function performs comprehensive validation of a pandas DataFrame including
1473
+ data types, dimensions, required columns, and duplicate column names. It raises
1474
+ exceptions for any validation failures.
1475
+
1476
+ :param pd.DataFrame df: The DataFrame to validate.
1477
+ :param Optional[str] source: Source identifier for error messages. Default: "".
1478
+ :param Optional[Tuple[Any]] valid_dtypes: Tuple of allowed data types. If None, no dtype validation. Default: None.
1479
+ :param Optional[List[str]] required_fields: List of required column names. If None, no field validation. Default: None.
1480
+ :param Optional[int] min_axis_0: Minimum number of rows required. If None, no minimum row validation. Default: None.
1481
+ :param Optional[int] min_axis_1: Minimum number of columns required. If None, no minimum column validation. Default: None.
1482
+ :param Optional[int] max_axis_0: Maximum number of rows allowed. If None, no maximum row validation. Default: None.
1483
+ :param Optional[int] max_axis_1: Maximum number of columns allowed. If None, no maximum column validation. Default: None.
1484
+ :param bool allow_duplicate_col_names: If False, raises error for duplicate column names. Default: True.
1485
+ :return: None if validation passes.
1486
+ :rtype: None
1487
+ :raises InvalidInputError: If any validation criteria are not met.
1488
+
1489
+ :example:
1490
+ >>> df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]})
1491
+ >>> check_valid_dataframe(df=df, required_fields=['A', 'B'], min_axis_0=1)
1492
+ >>> check_valid_dataframe(df=df, valid_dtypes=(int,), max_axis_1=2)
1493
+ >>> check_valid_dataframe(df=df, allow_duplicate_col_names=False)
1494
+ """
1495
+ check_instance(source=source, instance=df, accepted_types=(pd.DataFrame,))
1496
+ if valid_dtypes is not None:
1497
+ dtypes = list(set(df.dtypes))
1498
+ additional = [x for x in dtypes if x not in valid_dtypes]
1499
+ if len(additional) > 0:
1500
+ raise InvalidInputError(
1501
+ msg=f"The dataframe {source} has invalid data format(s) {additional}. Valid: {valid_dtypes}",
1502
+ source=source,
1503
+ )
1504
+ if min_axis_1 is not None:
1505
+ check_int(name=f"{source} min_axis_1", value=min_axis_1, min_value=1)
1506
+ if len(df.columns) < min_axis_1:
1507
+ raise InvalidInputError(
1508
+ msg=f"The dataframe {source} has less than ({df.columns}) the required minimum number of columns ({min_axis_1}).",
1509
+ source=source,
1510
+ )
1511
+ if min_axis_0 is not None:
1512
+ check_int(name=f"{source} min_axis_0", value=min_axis_0, min_value=1)
1513
+ if len(df) < min_axis_0:
1514
+ raise InvalidInputError(
1515
+ msg=f"The dataframe {source} has less than ({len(df)}) the required minimum number of rows ({min_axis_0}).",
1516
+ source=source,
1517
+ )
1518
+ if max_axis_0 is not None:
1519
+ check_int(name=f"{source} max_axis_0", value=min_axis_0, min_value=1)
1520
+ if len(df) > max_axis_0:
1521
+ raise InvalidInputError(
1522
+ msg=f"The dataframe {source} has more than ({len(df)}) the required maximum number of rows ({max_axis_0}).",
1523
+ source=source,
1524
+ )
1525
+ if max_axis_1 is not None:
1526
+ check_int(name=f"{source} max_axis_1", value=min_axis_1, min_value=1)
1527
+ if len(df.columns) > max_axis_1:
1528
+ raise InvalidInputError(
1529
+ msg=f"The dataframe {source} has more than ({df.columns}) the required maximum number of columns ({max_axis_1}).",
1530
+ source=source,
1531
+ )
1532
+ if required_fields is not None:
1533
+ check_valid_lst(
1534
+ data=required_fields,
1535
+ source=check_valid_dataframe.__name__,
1536
+ valid_dtypes=(str,),
1537
+ )
1538
+ missing = list(set(required_fields) - set(df.columns))
1539
+ if len(missing) > 0:
1540
+ raise InvalidInputError(
1541
+ msg=f"The dataframe {source} are missing required columns {missing}.",
1542
+ source=source,
1543
+ )
1544
+
1545
+ if not allow_duplicate_col_names:
1546
+ col_names = list(df.columns)
1547
+ seen = set()
1548
+ duplicate_col_names = list(set(x for x in col_names if x in seen or seen.add(x)))
1549
+ if len(duplicate_col_names) > 0:
1550
+ raise InvalidInputError(msg=f"The dataframe {source} has duplicate column names {duplicate_col_names}.", source=source)
1551
+
1552
+
1553
+
1554
+
1555
+ def check_valid_boolean(value: Union[Any, List[Any]], source: Optional[str] = '', raise_error: Optional[bool] = True):
1556
+ """
1557
+ Check if a value or list of values contains only valid boolean values.
1558
+
1559
+ This function validates that the input value(s) are valid Python boolean values
1560
+ (True or False). It can handle single values or lists of values, and provides
1561
+ flexible error handling options.
1562
+
1563
+ :param Union[Any, List[Any]] value: Single value or list of values to validate for boolean type.
1564
+ :param Optional[str] source: Source identifier for error messages. Default: ''.
1565
+ :param Optional[bool] raise_error: If True, raises InvalidInputError when non-boolean values are found. If False, returns False. Default: True.
1566
+ :return: True if all values are valid booleans, False if any non-boolean values found and raise_error=False.
1567
+ :rtype: bool
1568
+ :raises InvalidInputError: If non-boolean values are found and raise_error=True.
1569
+
1570
+ :example:
1571
+ >>> check_valid_boolean(True)
1572
+ True
1573
+ >>> check_valid_boolean([True, False, True])
1574
+ True
1575
+ >>> check_valid_boolean([True, 1, False], raise_error=False)
1576
+ False
1577
+ >>> check_valid_boolean('not_bool', raise_error=False)
1578
+ False
1579
+ """
1580
+ if not isinstance(value, list):
1581
+ value = [value]
1582
+ for val in value:
+ if not isinstance(val, bool):
+ if raise_error:
+ raise InvalidInputError(msg=f'{val} is not a valid boolean', source=source)
+ else:
+ return False
+ return True  # only reached once every element has been validated as a bool
1590
+
1591
+ def check_valid_tuple(x: tuple,
1592
+ source: Optional[str] = "",
1593
+ accepted_lengths: Optional[Tuple[int]] = None,
1594
+ valid_dtypes: Optional[Tuple[Any]] = None,
1595
+ minimum_length: Optional[int] = None,
1596
+ accepted_values: Optional[Iterable[Any]] = None,
1597
+ min_integer: Optional[int] = None):
1598
+ """
1599
+ Validate a tuple against various criteria.
1600
+
1601
+ This function performs comprehensive validation of a tuple including
1602
+ length constraints, data types, minimum values, and accepted values.
1603
+ It raises exceptions for any validation failures.
1604
+
1605
+ :param tuple x: The tuple to validate.
1606
+ :param Optional[str] source: Source identifier for error messages. Default: "".
1607
+ :param Optional[Tuple[int]] accepted_lengths: Tuple of accepted lengths. If None, no length validation. Default: None.
1608
+ :param Optional[Tuple[Any]] valid_dtypes: Tuple of allowed data types for tuple elements. If None, no dtype validation. Default: None.
1609
+ :param Optional[int] minimum_length: Minimum length required. If None, no minimum length validation. Default: None.
1610
+ :param Optional[Iterable[Any]] accepted_values: Iterable of accepted values for tuple elements. If None, no value validation. Default: None.
1611
+ :param Optional[int] min_integer: Minimum value for integer elements. If None, no integer validation. Default: None.
1612
+ :return: None if validation passes.
1613
+ :rtype: None
1614
+ :raises InvalidInputError: If any validation criteria are not met.
1615
+
1616
+ :example:
1617
+ >>> check_valid_tuple(x=(1, 2, 3), accepted_lengths=(2, 3), valid_dtypes=(int,))
1618
+ >>> check_valid_tuple(x=('a', 'b'), minimum_length=2, accepted_values=['a', 'b', 'c'])
1619
+ >>> check_valid_tuple(x=(5, 10, 15), min_integer=5)
1620
+ """
1621
+
1622
+ if not isinstance(x, (tuple)):
1623
+ raise InvalidInputError(msg=f"{check_valid_tuple.__name__} {source} is not a valid tuple, got: {type(x)}", source=source,)
1624
+ if accepted_lengths is not None:
1625
+ if len(x) not in accepted_lengths:
1626
+ raise InvalidInputError(
1627
+ msg=f"Tuple is not of valid lengths. Found {len(x)}. Accepted: {accepted_lengths}",
1628
+ source=source,
1629
+ )
1630
+ if valid_dtypes is not None:
1631
+ dtypes = list(set([type(v) for v in x]))
1632
+ additional = [x for x in dtypes if x not in valid_dtypes]
1633
+ if len(additional) > 0:
1634
+ raise InvalidInputError(msg=f"The tuple {source} has invalid data format(s) {additional}. Valid: {valid_dtypes}", source=source)
1635
+
1636
+ if minimum_length is not None:
1637
+ check_int(name=f'{check_valid_tuple.__name__} minimum_length', value=minimum_length, min_value=1)
1638
+ tuple_len = len(x)
1639
+ if tuple_len < minimum_length:
1640
+ raise InvalidInputError(msg=f"The tuple {source} is shorter ({tuple_len}) than the minimum required length ({minimum_length}).", source=source)
1641
+
1642
+ if accepted_values is not None:
1643
+ check_instance(source=f'{check_valid_tuple.__name__} accepted_values', accepted_types=(list, tuple,), instance=accepted_values)
1644
+ for i in x:
1645
+ if i not in accepted_values:
1646
+ raise InvalidInputError(msg=f"The tuple {source} has a value that is NOT accepted: {i}, (accepted: {accepted_values}).", source=source)
1647
+
1648
+ if min_integer is not None:
1649
+ check_int(name=f'{check_valid_tuple.__name__} min_integer', value=min_integer)
1650
+ for i in x:
1651
+ if isinstance(i, int):
1652
+ if i < min_integer:
1653
+ raise InvalidInputError(msg=f"The tuple {source} has an integer value below the minimum allowed integer value: {i}, (minimum: {min_integer}).", source=source)
1654
+
1655
+
1656
+ def check_video_and_data_frm_count_align(video: Union[str, os.PathLike, cv2.VideoCapture],
1657
+ data: Union[str, os.PathLike, pd.DataFrame],
1658
+ name: Optional[str] = "",
1659
+ raise_error: Optional[bool] = True) -> Union[None, bool]:
1660
+ """
1661
+ Check if the frame count of a video matches the row count of a data file.
1662
+
1663
+ :param Union[str, os.PathLike, cv2.VideoCapture] video: Path to the video file or cv2.VideoCapture object.
1664
+ :param Union[str, os.PathLike, pd.DataFrame] data: Path to the data file or DataFrame containing the data.
1665
+ :param Optional[str] name: Name of the video (optional for interpretable error msgs).
1666
+ :param Optional[bool] raise_error: Whether to raise an error if the counts don't align (default is True). If False, prints warning.
1667
+ :return None:
1668
+
1669
+ :example:
1670
+ >>> data_1 = '/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/csv/outlier_corrected_movement_location/SI_DAY3_308_CD1_PRESENT.csv'
1671
+ >>> video_1 = '/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/frames/output/ROI_analysis/SI_DAY3_308_CD1_PRESENT.mp4'
1672
+ >>> check_video_and_data_frm_count_align(video=video_1, data=data_1, raise_error=True)
1673
+ """
1674
+
1675
+ def _count_generator(reader):
1676
+ b = reader(1024 * 1024)
1677
+ while b:
1678
+ yield b
1679
+ b = reader(1024 * 1024)
1680
+
1681
+ check_instance(
1682
+ source=f"{check_video_and_data_frm_count_align.__name__} video",
1683
+ instance=video,
1684
+ accepted_types=(str, cv2.VideoCapture),
1685
+ )
1686
+ check_instance(
1687
+ source=f"{check_video_and_data_frm_count_align.__name__} data",
1688
+ instance=data,
1689
+ accepted_types=(str, pd.DataFrame),
1690
+ )
1691
+ check_str(
1692
+ name=f"{check_video_and_data_frm_count_align.__name__} name",
1693
+ value=name,
1694
+ allow_blank=True,
1695
+ )
1696
+ if isinstance(video, str):
1697
+ check_file_exist_and_readable(file_path=video)
1698
+ video = cv2.VideoCapture(video)
1699
+ video_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
1700
+ if isinstance(data, str):
1701
+ check_file_exist_and_readable(file_path=data)
1702
+ with open(data, "rb") as fp:
1703
+ c_generator = _count_generator(fp.raw.read)
1704
+ data_count = (sum(buffer.count(b"\n") for buffer in c_generator)) - 1
1705
+ else:
1706
+ data_count = len(data)
1707
+ if data_count != video_count:
1708
+ if not raise_error:
1709
+ FrameRangeWarning(msg=f"The video {name} has {video_count} frames, but the associated data file for this video has {data_count} rows", source=check_video_and_data_frm_count_align.__name__)
1710
+ return False
1711
+ else:
1712
+ raise FrameRangeError(msg=f"The video {name} has {video_count} frames, but the associated data file for this video has {data_count} rows", source=check_video_and_data_frm_count_align.__name__,)
1713
+ return True
1714
+
1715
+ def check_if_video_corrupted(video: Union[str, os.PathLike, cv2.VideoCapture],
1716
+ frame_interval: Optional[int] = None,
1717
+ frame_n: Optional[int] = 20,
1718
+ raise_error: Optional[bool] = True) -> None:
1719
+
1720
+ """
1721
+ Check if a video file is corrupted by inspecting a set of its frames.
1722
+
1723
+ .. note::
1724
+ For decent run-time regardless of video length, pass a smaller ``frame_n`` (<100).
1725
+
1726
+ :param Union[str, os.PathLike, cv2.VideoCapture] video: Path to the video file or cv2.VideoCapture OpenCV object.
1727
+ :param Optional[int] frame_interval: Interval between frames to be checked. If None, ``frame_n`` will be used.
1728
+ :param Optional[int] frame_n: Number of frames to be checked, sampled at evenly spaced intervals across the video. If None, ``frame_interval`` will be used.
1729
+ :param Optional[bool] raise_error: Whether to raise an error if corruption is found. If False, prints warning.
1730
+ :return None:
1731
+
1732
+ :example:
1733
+ >>> check_if_video_corrupted(video='/Users/simon/Downloads/NOR ENCODING FExMP8.mp4')
1734
+ """
1735
+ check_instance(source=f'{check_if_video_corrupted.__name__} video', instance=video, accepted_types=(str, cv2.VideoCapture))
1736
+ if isinstance(video, str):
1737
+ check_file_exist_and_readable(file_path=video)
1738
+ cap = cv2.VideoCapture(video)
1739
+ else:
1740
+ cap = video
1741
+ frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
1742
+ if (frame_interval is not None and frame_n is not None) or (frame_interval is None and frame_n is None):
1743
+ raise InvalidInputError(msg='Pass frame_interval OR frame_n', source=check_if_video_corrupted.__name__)
1744
+ if frame_interval is not None:
1745
+ frms_to_check = list(range(0, frame_count, frame_interval))
1746
+ else:
1747
+ frms_to_check = np.array_split(np.arange(0, frame_count), frame_n)
1748
+ frms_to_check = [x[-1] for x in frms_to_check]
1749
+ errors = []
1750
+ for frm_id in frms_to_check:
1751
+ cap.set(1, frm_id)
1752
+ ret, _ = cap.read()
1753
+ if not ret: errors.append(frm_id)
1754
+ if len(errors) > 0:
1755
+ if raise_error:
1756
+ raise CorruptedFileError(msg=f'Found {len(errors)} corrupted frame(s) at indexes {errors} in video {video}', source=check_if_video_corrupted.__name__)
1757
+ else:
1758
+ CorruptedFileWarning(msg=f'Found {len(errors)} corrupted frame(s) at indexes {errors} in video {video}', source=check_if_video_corrupted.__name__)
1759
+ else:
1760
+ pass
1761
+
1762
+
1763
+ def check_valid_dict(x: dict,
1764
+ valid_key_dtypes: Optional[Tuple[Any]] = None,
1765
+ valid_values_dtypes: Optional[Tuple[Any, ...]] = None,
1766
+ valid_keys: Optional[Union[Tuple[Any], List[Any]]] = None,
1767
+ max_len_keys: Optional[int] = None,
1768
+ min_len_keys: Optional[int] = None,
1769
+ required_keys: Optional[Tuple[Any, ...]] = None,
1770
+ max_value: Optional[Union[float, int]] = None,
1771
+ min_value: Optional[Union[float, int]] = None,
1772
+ source: Optional[str] = None):
1773
+ """
1774
+ Validate a dictionary against various criteria.
1775
+
1776
+ This function performs comprehensive validation of a dictionary including
1777
+ key/value data types, key constraints, required keys, and numeric value ranges.
1778
+ It raises exceptions for any validation failures.
1779
+
1780
+ :param dict x: The dictionary to validate.
1781
+ :param Optional[Tuple[Any]] valid_key_dtypes: Tuple of allowed data types for dictionary keys. If None, no key type validation. Default: None.
1782
+ :param Optional[Tuple[Any, ...]] valid_values_dtypes: Tuple of allowed data types for dictionary values. If None, no value type validation. Default: None.
1783
+ :param Optional[Union[Tuple[Any], List[Any]]] valid_keys: Tuple or list of valid key names. If None, no key name validation. Default: None.
1784
+ :param Optional[int] max_len_keys: Maximum number of keys allowed. If None, no maximum key count validation. Default: None.
1785
+ :param Optional[int] min_len_keys: Minimum number of keys required. If None, no minimum key count validation. Default: None.
1786
+ :param Optional[Tuple[Any, ...]] required_keys: Tuple of required key names. If None, no required key validation. Default: None.
1787
+ :param Optional[Union[float, int]] max_value: Maximum numeric value allowed for numeric values. If None, no maximum value validation. Default: None.
1788
+ :param Optional[Union[float, int]] min_value: Minimum numeric value allowed for numeric values. If None, no minimum value validation. Default: None.
1789
+ :param Optional[str] source: Source identifier for error messages. If None, uses function name. Default: None.
1790
+ :return: None if validation passes.
1791
+ :rtype: None
1792
+ :raises InvalidInputError: If any validation criteria are not met.
1793
+
1794
+ :example:
1795
+ >>> check_valid_dict(x={'a': 1, 'b': 2}, valid_key_dtypes=(str,), valid_values_dtypes=(int,))
1796
+ >>> check_valid_dict(x={'key1': 10, 'key2': 20}, required_keys=('key1',), min_value=5, max_value=25)
1797
+ >>> check_valid_dict(x={'x': 1, 'y': 2}, valid_keys=('x', 'y', 'z'), min_len_keys=2)
1798
+ """
1799
+
1800
+
1801
+ source = check_valid_dict.__name__ if source is None else source
1802
+ check_instance(source=check_valid_dict.__name__, instance=x, accepted_types=(dict,))
1803
+ if valid_key_dtypes is not None:
1804
+ for i in list(x.keys()):
1805
+ if not isinstance(i, valid_key_dtypes):
1806
+ raise InvalidInputError(msg=f'{type(i)} is not a valid key DTYPE. Valid: {valid_key_dtypes}', source=source)
1807
+ if valid_values_dtypes is not None:
1808
+ for i in list(x.values()):
1809
+ if not isinstance(i, valid_values_dtypes):
1810
+ raise InvalidInputError(msg=f'{type(i)} is not a valid value DTYPE. Valid: {valid_values_dtypes}', source=source)
1811
+ if max_len_keys is not None:
1812
+ check_int(name=f'{check_valid_dict.__name__} max_len_keys', min_value=1, value=max_len_keys)
1813
+ key_cnt = len(list(x.keys()))
1814
+ if key_cnt > max_len_keys:
1815
+ raise InvalidInputError(msg=f'Dictionary has {key_cnt} keys. Maximum allowed: {max_len_keys}', source=source)
1816
+ if min_len_keys is not None:
1817
+ check_int(name=f'{check_valid_dict.__name__} min_len_keys', min_value=1, value=min_len_keys)
1818
+ key_cnt = len(list(x.keys()))
1819
+ if key_cnt < min_len_keys:
1820
+ raise InvalidInputError(msg=f'Dictionary has {key_cnt} keys. Minimum allowed: {min_len_keys}', source=source)
1821
+ if required_keys is not None:
1822
+ for i in list(required_keys):
1823
+ if i not in list(x.keys()):
1824
+ raise InvalidInputError(msg=f'The required key {i} does not exist in the dictionary. Existing keys: {list(x.keys())}', source=source)
1825
+ if max_value is not None:
1826
+ if not isinstance(max_value, (float, int)):
1827
+ raise InvalidInputError(msg=f'{check_valid_dict.__name__} max_value has to be a float or integer, got {type(max_value)}.')
1828
+ for k, v in x.items():
1829
+ if isinstance(v, (float, int)):
1830
+ if v > max_value:
1831
+ raise InvalidInputError(msg=f'The required key {k} has value {v} which is above the max allowed: {max_value}.', source=source)
1832
+ if min_value is not None:
1833
+ if not isinstance(min_value, (float, int)):
1834
+ raise InvalidInputError(msg=f'{check_valid_dict.__name__} min_value has to be a float or integer, got {type(min_value)}.')
1835
+ for k, v in x.items():
1836
+ if isinstance(v, (float, int)):
1837
+ if v < min_value:
1838
+ raise InvalidInputError(msg=f'The required key {k} has value {v} which is less than the minimum allowed: {min_value}.', source=source)
1839
+ if valid_keys is not None:
1840
+ if not isinstance(valid_keys, (tuple, list)):
1841
+ raise InvalidInputError(msg=f'{check_valid_dict.__name__} valid_keys has to be a tuple or list, got {type(valid_keys)}.')
1842
+ invalid_keys = [i for i in x.keys() if i not in valid_keys]
1843
+ if len(invalid_keys) > 0:
1844
+ raise InvalidInputError(msg=f'The dictionary has keys that are invalid ({invalid_keys}). Accepted, valid keys are: {valid_keys}.', source=source)
1845
+
1846
+
1847
+
1848
+
1849
+ def is_video_color(video: Union[str, os.PathLike, cv2.VideoCapture]) -> bool:
1850
+ """
1851
+ Determines whether a video is in color or greyscale.
1852
+
1853
+ .. seealso::
1854
+ :func:`simba.mixins.image_mixin.ImageMixin.is_video_color`
1855
+
1856
+ :param Union[str, os.PathLike, cv2.VideoCapture] video: The video source, either a cv2.VideoCapture object or a path to a file on disk.
1857
+ :return: Returns `True` if the video is in color (has more than one channel), and `False` if the video is greyscale (single channel).
1858
+ :rtype: bool
1859
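+
+ The example below is an illustrative sketch (hypothetical file path): a standard BGR video with differing channels returns True, while single-channel or channel-identical footage returns False.
+
+ :example:
+ >>> is_video_color(video='/data/videos/example_color.mp4')
+ True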
+ """
1860
+
1861
+ check_instance(source=is_video_color.__name__, instance=video, accepted_types=(str, cv2.VideoCapture))
1862
+
1863
+ # Handle string path vs VideoCapture object
1864
+ should_release = False
1865
+ if isinstance(video, str):
1866
+ check_file_exist_and_readable(file_path=video)
1867
+ video = cv2.VideoCapture(video)
1868
+ should_release = True
1869
+
1870
+ try:
1871
+ video.set(cv2.CAP_PROP_POS_FRAMES, 0)
1872
+ _, frm = video.read()
1873
+
1874
+ # If frame has only 2 dimensions, it's definitely greyscale
1875
+ if frm.ndim == 2:
1876
+ return False
1877
+
1878
+ # If frame has 3 dimensions, check if it's actually greyscale
1879
+ # (some greyscale videos are stored as 3-channel with identical values)
1880
+ if frm.ndim == 3:
1881
+ # Check if all channels are identical (indicating greyscale)
1882
+ if frm.shape[2] == 3: # BGR format
1883
+ # Compare B, G, and R channels
1884
+ if np.array_equal(frm[:, :, 0], frm[:, :, 1]) and np.array_equal(frm[:, :, 1], frm[:, :, 2]):
1885
+ return False # All channels identical = greyscale
1886
+ else:
1887
+ return True # Channels different = color
1888
+ elif frm.shape[2] == 1:
1889
+ return False # Single channel = greyscale
1890
+ else:
1891
+ return True # Other multi-channel formats = color
1892
+
1893
+ # Default case: assume greyscale
1894
+ return False
1895
+
1896
+ finally:
1897
+ # Clean up VideoCapture if we created it
1898
+ if should_release and video.isOpened():
1899
+ video.release()
1900
+
1901
+
1902
+ def check_filepaths_in_iterable_exist(file_paths: Iterable[str],
1903
+ name: Optional[str] = None):
1904
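+ """
+ Check that an iterable of file paths is non-empty and that every entry points to an existing file on disk.
+
+ :param Iterable[str] file_paths: List or tuple of file paths to validate.
+ :param Optional[str] name: Optional identifier included in error messages.
+ :raises NoFilesFoundError: If ``file_paths`` is empty or any entry is not a valid file path.
+
+ :example:
+ >>> # Illustrative sketch with hypothetical paths that are assumed to exist on disk.
+ >>> check_filepaths_in_iterable_exist(file_paths=['/data/video_1.mp4', '/data/video_2.mp4'], name='my videos')
+ """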
+
1905
+ check_instance(source=f'{check_filepaths_in_iterable_exist.__name__} file_paths {name}', instance=file_paths, accepted_types=(list, tuple,))
1906
+ if len(file_paths) == 0:
1907
+ raise NoFilesFoundError(msg=f'{name} {file_paths} is empty')
1908
+ for file_path in file_paths:
1909
+ check_str(name=f'{check_filepaths_in_iterable_exist.__name__} {file_path} {name}', value=file_path)
1910
+ if not os.path.isfile(file_path):
1911
+ raise NoFilesFoundError(msg=f'{name} {file_path} is not a valid file path')
1912
+
1913
+ def check_all_dfs_in_list_has_same_cols(dfs: List[pd.DataFrame], raise_error: bool = True, source: str = '') -> bool:
1914
+ """
1915
+ Check that all DataFrames in a list have the same column names.
1916
+
1917
+ This function validates that all DataFrames in the provided list contain
1918
+ identical column headers. It finds the intersection of all column names
1919
+ and identifies any missing headers that are not present in all DataFrames.
1920
+
1921
+ :param List[pd.DataFrame] dfs: List of DataFrames to validate for consistent column names.
1922
+ :param bool raise_error: If True, raises MissingColumnsError when column names don't match. If False, returns False. Default: True.
1923
+ :param str source: Source identifier for error messages. Default: ''.
1924
+ :return: True if all DataFrames have the same column names, False if they don't match and raise_error=False.
1925
+ :rtype: bool
1926
+ :raises MissingColumnsError: If DataFrames have different column names and raise_error=True.
1927
+
1928
+ :example:
1929
+ >>> df1 = pd.DataFrame({'A': [1, 2], 'B': [3, 4]})
1930
+ >>> df2 = pd.DataFrame({'A': [5, 6], 'B': [7, 8]})
1931
+ >>> check_all_dfs_in_list_has_same_cols(dfs=[df1, df2])
1932
+ True
1933
+ >>> df3 = pd.DataFrame({'A': [1, 2], 'C': [3, 4]})
1934
+ >>> check_all_dfs_in_list_has_same_cols(dfs=[df1, df3], raise_error=False)
1935
+ False
1936
+ """
1937
+ check_valid_lst(data=dfs, source=check_all_dfs_in_list_has_same_cols.__name__, valid_dtypes=(pd.DataFrame,), min_len=1)
1938
+ col_headers = [list(x.columns) for x in dfs]
1939
+ common_headers = set(col_headers[0]).intersection(*col_headers[1:])
1940
+ all_headers = set(item for sublist in col_headers for item in sublist)
1941
+ missing_headers = list(all_headers - common_headers)
1942
+ if len(missing_headers) > 0:
1943
+ if raise_error:
1944
+ raise MissingColumnsError(msg=f"The data in {source} directory do not contain the same headers. Some files are missing the headers: {missing_headers}", source=check_all_dfs_in_list_has_same_cols.__name__)
1945
+ else:
1946
+ return False
1947
+ return True
1948
+
1949
+
1950
+ def is_valid_video_file(file_path: Union[str, os.PathLike], raise_error: bool = True):
1951
+ """
1952
+ Check if a file path is a valid video file.
1953
+
1954
+ This function validates that a file path exists, is readable, and can be
1955
+ opened as a video file using OpenCV. It performs basic video file validation
1956
+ by attempting to open the file with cv2.VideoCapture.
1957
+
1958
+ :param Union[str, os.PathLike] file_path: Path to the video file to validate.
1959
+ :param bool raise_error: If True, raises InvalidFilepathError when file is not a valid video. If False, returns False. Default: True.
1960
+ :return: True if the file is a valid video file, False if it's not valid and raise_error=False.
1961
+ :rtype: bool
1962
+ :raises InvalidFilepathError: If the file is not a valid video file and raise_error=True.
1963
+
1964
+ :example:
1965
+ >>> is_valid_video_file('/path/to/video.mp4')
1966
+ True
1967
+ >>> is_valid_video_file('/path/to/invalid.txt', raise_error=False)
1968
+ False
1969
+ >>> is_valid_video_file('/path/to/corrupted.mp4', raise_error=False)
1970
+ False
1971
+ """
1972
+ check_file_exist_and_readable(file_path=file_path)
1973
+ try:
1974
+ cap = cv2.VideoCapture(file_path)
1975
+ if not cap.isOpened():
1976
+ if not raise_error:
1977
+ return False
1978
+ else:
1979
+ raise InvalidFilepathError(msg=f'The path {file_path} is not a valid video file', source=is_valid_video_file.__name__)
1980
+ return True
1981
+ except Exception:
1982
+ if not raise_error:
1983
+ return False
1984
+ else:
1985
+ raise InvalidFilepathError(msg=f'The path {file_path} is not a valid video file', source=is_valid_video_file.__name__)
1986
+ finally:
1987
+ if 'cap' in locals():
1988
+ if cap.isOpened():
1989
+ cap.release()
1990
+
1991
+
1992
+ def check_valid_polygon(polygon: Union[np.ndarray, Polygon], raise_error: bool = True, name: Optional[str] = None) -> Union[bool, None]:
1993
+ """
1994
+ Validates whether the given polygon is a valid geometric shape.
1995
+
1996
+ :param Union[np.ndarray, Polygon] polygon: The polygon to validate, either as a NumPy array of shape (N, 2) or a shapely Polygon object.
1997
+ :param bool raise_error: If True, raises an InvalidInputError if the polygon is invalid; otherwise, returns False.
1998
+ :param Optional[str] name: An optional name for the polygon to include in error messages.
1999
+ :return: True if the polygon is valid, False if invalid (and raise_error is False), or None if an error is raised.
2000
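+
+ A minimal sketch: a (4, 2) integer array describing a square passes both the array checks and shapely's validity test.
+
+ :example:
+ >>> check_valid_polygon(polygon=np.array([[0, 0], [0, 10], [10, 10], [10, 0]]))
+ True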
+ """
2001
+
2002
+
2003
+ name = '' if name is None else name
2004
+ check_instance(source=f'{check_valid_polygon.__name__} polygon', accepted_types=(np.ndarray, Polygon,), instance=polygon)
2005
+ if isinstance(polygon, np.ndarray):
2006
+ check_valid_array(data=polygon, source=f'{check_valid_polygon.__name__} polygon', accepted_ndims=(2,), accepted_dtypes=Formats.NUMERIC_DTYPES.value, min_axis_0=3, accepted_axis_1_shape=[2,])
2007
+ polygon = Polygon(polygon.astype(np.int32))
2008
+ if not polygon.is_valid:
2009
+ if raise_error:
2010
+ raise InvalidInputError(msg=f'The polygon {name} is invalid', source=check_valid_polygon.__name__)
2011
+ else:
2012
+ return False
2013
+ else:
2014
+ return True
2015
+
2016
+
2017
+ def is_img_bw(img: np.ndarray,
2018
+ raise_error: bool = True,
2019
+ source: Optional[str] = '') -> bool:
2020
+ """
2021
+ Check if an image is binary black and white.
2022
+
2023
+ This function validates that an image contains only two pixel values:
2024
+ 0 (black) and 255 (white). It checks all unique pixel values in the image
2025
+ and ensures they are exactly these two values.
2026
+
2027
+ :param np.ndarray img: The image array to validate for binary black and white format.
2028
+ :param bool raise_error: If True, raises InvalidInputError when image is not binary black and white. If False, returns False. Default: True.
2029
+ :param Optional[str] source: Source identifier for error messages. Default: ''.
2030
+ :return: True if the image is binary black and white, False if it's not and raise_error=False.
2031
+ :rtype: bool
2032
+ :raises InvalidInputError: If the image is not binary black and white and raise_error=True.
2033
+
2034
+ :example:
2035
+ >>> bw_img = np.array([[0, 255], [255, 0]], dtype=np.uint8)
2036
+ >>> is_img_bw(bw_img)
2037
+ True
2038
+ >>> gray_img = np.array([[128, 200], [50, 100]], dtype=np.uint8)
2039
+ >>> is_img_bw(gray_img, raise_error=False)
2040
+ False
2041
+ """
2042
+ check_if_valid_img(data=img, source=is_img_bw.__name__, raise_error=True)
2043
+ px_vals = set(list(np.sort(np.unique(img)).astype(np.int32)))
2044
+ additional = list(px_vals - {0, 255})
2045
+ if len(additional) > 0:
2046
+ if raise_error:
2047
+ raise InvalidInputError(msg=f'The image {source} is not a black-and-white image. Expected: [0, 255], got {additional}', source=is_img_bw.__name__)
2048
+ else:
2049
+ return False
2050
+ else:
2051
+ return True
2052
+
2053
+ def is_img_greyscale(img: np.ndarray,
2054
+ raise_error: bool = True,
2055
+ source: Optional[str] = '') -> bool:
2056
+ """
2057
+ Check if an image is greyscale.
2058
+
2059
+ This function validates that an image is in greyscale format by checking
2060
+ that it has exactly 2 dimensions (height and width). Greyscale images
2061
+ have a single channel and are represented as 2D arrays.
2062
+
2063
+ :param np.ndarray img: The image array to validate for greyscale format.
2064
+ :param bool raise_error: If True, raises InvalidInputError when image is not greyscale. If False, returns False. Default: True.
2065
+ :param Optional[str] source: Source identifier for error messages. Default: ''.
2066
+ :return: True if the image is greyscale, False if it's not and raise_error=False.
2067
+ :rtype: bool
2068
+ :raises InvalidInputError: If the image is not greyscale and raise_error=True.
2069
+
2070
+ :example:
2071
+ >>> gray_img = np.array([[128, 200], [50, 100]], dtype=np.uint8)
2072
+ >>> is_img_greyscale(gray_img)
2073
+ True
2074
+ >>> color_img = np.array([[[128, 200, 50], [100, 150, 75]]], dtype=np.uint8)
2075
+ >>> is_img_greyscale(color_img, raise_error=False)
2076
+ False
2077
+ """
2078
+ check_if_valid_img(data=img, source=is_img_greyscale.__name__, raise_error=False, greyscale=True)
2079
+ if not img.ndim == 2:
2080
+ if raise_error:
2081
+ raise InvalidInputError(msg=f'The image {source} is not a greyscale image. Expected 2 dimensions got: {img.ndim}', source=is_img_greyscale.__name__)
2082
+ else:
2083
+ return False
2084
+ else:
2085
+ return True
2086
+
2087
+
2088
+ def is_wsl() -> bool:
2089
+ """
2090
+ Check if SimBA is running in Microsoft WSL (Windows Subsystem for Linux).
2091
+
2092
+ This function detects whether the current environment is running inside
2093
+ Microsoft WSL by checking the contents of /proc/version for the presence
2094
+ of "microsoft" string, which indicates WSL environment.
2095
+
2096
+ :return: True if running in WSL, False otherwise.
2097
+ :rtype: bool
2098
+
2099
+ :example:
2100
+ >>> is_wsl()
2101
+ False # When running on native Linux
2102
+ >>> is_wsl()
2103
+ True # When running in WSL
2104
+ """
2105
+ try:
2106
+ with open("/proc/version", "r") as f:
2107
+ return "microsoft" in f.read().lower()
2108
+ except FileNotFoundError:
2109
+ return False
2110
+
2111
+ def is_windows_path(value):
2112
+ """
2113
+ Check if the value is a valid Windows path format.
2114
+
2115
+ This function validates that a string follows the Windows path format
2116
+ by checking that it starts with a drive letter followed by a colon
2117
+ (e.g., "C:", "D:", etc.). It performs basic format validation without
2118
+ checking if the path actually exists on the filesystem.
2119
+
2120
+ :param value: The value to check for Windows path format.
2121
+ :return: True if the value is a valid Windows path format, False otherwise.
2122
+ :rtype: bool
2123
+
2124
+ :example:
2125
+ >>> is_windows_path("C:\\Users\\username\\file.txt")
2126
+ True
2127
+ >>> is_windows_path("D:\\data\\folder")
2128
+ True
2129
+ >>> is_windows_path("/home/user/file.txt")
2130
+ False
2131
+ >>> is_windows_path("relative/path")
2132
+ False
2133
+ >>> is_windows_path("")
2134
+ False
2135
+ """
2136
+ return isinstance(value, str) and (len(value) > 1 and value[1] == ':' and value[0].isalpha())
2137
+
2138
+
2139
+ def check_same_files_exist_in_all_directories(dirs: List[Union[str, os.PathLike]], raise_error: bool = False, file_type: str = "csv") -> bool:
2140
+ """
2141
+ Check if the same files of a given type exist in all specified directories.
2142
+
2143
+ :param List[Union[str, os.PathLike]] dirs: List of directory paths to check.
2144
+ :param bool raise_error: If True, raises an error when file names do not match across directories. Defaults to False.
2145
+ :param str file_type: File extension (without the dot) to check for (e.g., 'csv', 'txt'). Defaults to 'csv'.
2146
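+
+ The example below is an illustrative sketch (hypothetical directories assumed to contain identically named CSV files).
+
+ :example:
+ >>> check_same_files_exist_in_all_directories(dirs=['/data/batch_1', '/data/batch_2'], file_type='csv')
+ True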
+ """
2147
+
2148
+ check_valid_lst( data=dirs, source=f"{check_same_files_exist_in_all_directories.__name__} dirs", valid_dtypes=(str, os.PathLike), min_len=2)
2149
+ file_sets = [{os.path.basename(f) for f in glob.glob(os.path.join(dir, f"*.{file_type}"))} for dir in dirs]
2150
+ common_files = set.intersection(*file_sets) if file_sets else set()
2151
+ if not all(files == common_files for files in file_sets):
2152
+ if raise_error:
2153
+ raise NoFilesFoundError( msg=f"Files of type '{file_type}' do not match across directories: {dirs}.", source=check_same_files_exist_in_all_directories.__name__)
2154
+ return False
2155
+ return True
2156
+
2157
+
2158
+
2159
+ def check_valid_img_path(path: Union[str, os.PathLike], raise_error: bool = True):
2160
+ """
2161
+ Check if a file path is a valid image file.
2162
+
2163
+ This function validates that a file path exists, is readable, and can be
2164
+ opened as an image file using OpenCV. It performs basic image file validation
2165
+ by attempting to read the file with cv2.imread.
2166
+
2167
+ :param Union[str, os.PathLike] path: Path to the image file to validate.
2168
+ :param bool raise_error: If True, raises InvalidInputError when file is not a valid image. If False, returns False. Default: True.
2169
+ :return: True if the file is a valid image file, False if it's not valid and raise_error=False.
2170
+ :rtype: bool
2171
+ :raises InvalidInputError: If the file is not a valid image file and raise_error=True.
2172
+
2173
+ :example:
2174
+ >>> check_valid_img_path('/path/to/image.jpg')
2175
+ True
2176
+ >>> check_valid_img_path('/path/to/invalid.txt', raise_error=False)
2177
+ False
2178
+ >>> check_valid_img_path('/path/to/corrupted.png', raise_error=False)
2179
+ False
2180
+ """
2181
+ check_file_exist_and_readable(path)
2182
+ try:
2183
+ img = cv2.imread(path)
+ except Exception as e:
+ if raise_error:
+ print(e.args)
+ raise InvalidInputError(msg=f'{path} could not be read as a valid image file', source=check_valid_img_path.__name__)
+ else:
+ return False
+ if img is None:
+ # cv2.imread returns None rather than raising when the file cannot be decoded as an image
+ if raise_error:
+ raise InvalidInputError(msg=f'{path} could not be read as a valid image file', source=check_valid_img_path.__name__)
+ else:
+ return False
+ return True
2191
+
2192
+
2193
+
2194
+
2195
+ def check_valid_device(device: Union[Literal['cpu'], int], raise_error: bool = True) -> bool:
2196
+ """
2197
+ Validate a compute device specification, ensuring it is either 'cpu' or a valid GPU index.
2198
+
2199
+ This function validates that a device specification is valid for use with
2200
+ PyTorch/CUDA operations. It checks if the device is either 'cpu' for CPU
2201
+ usage or a valid integer representing a CUDA device index.
2202
+
2203
+ :param Union[Literal['cpu'], int] device: The device to validate. Should be the string 'cpu' for CPU usage, or an integer representing a CUDA device index (e.g., 0 for 'cuda:0').
2204
+ :param bool raise_error: If True, raises InvalidInputError or SimBAGPUError when the device is invalid. If False, returns False instead of raising errors. Default: True.
2205
+ :return: True if the device is valid, False if it's invalid and raise_error=False.
2206
+ :rtype: bool
2207
+ :raises InvalidInputError: If the device format is invalid and raise_error=True.
2208
+ :raises SimBAGPUError: If the GPU device is not available or not valid and raise_error=True.
2209
+
2210
+ :example:
2211
+ >>> check_valid_device('cpu')
2212
+ True
2213
+ >>> check_valid_device(0) # GPU 0
2214
+ True
2215
+ >>> check_valid_device(5, raise_error=False) # Non-existent GPU
2216
+ False
2217
+ >>> check_valid_device('gpu', raise_error=False) # Invalid format
2218
+ False
2219
+ """
2220
+ source = check_valid_device.__name__
2221
+ if isinstance(device, str):
2222
+ valid, msg = check_str(name=f'{source} format', value=device.lower(), options=['cpu'], raise_error=False)
2223
+ if not valid:
2224
+ if raise_error:
2225
+ raise InvalidInputError(msg=msg, source=source)
2226
+ return False
2227
+ return True
2228
+
2229
+ valid, msg = check_int(name=f'{source} device', value=device, min_value=0, raise_error=False)
2230
+ if not valid:
2231
+ if raise_error:
2232
+ raise InvalidInputError(msg=msg, source=source)
2233
+ return False
2234
+
2235
+ gpu_available, gpus = _is_cuda_available()
2236
+ if not gpu_available:
2237
+ if raise_error:
2238
+ raise SimBAGPUError(msg=f'No GPU detected but device {device} passed', source=source)
2239
+ return False
2240
+
2241
+ if device not in gpus:
2242
+ if raise_error:
2243
+ raise SimBAGPUError(msg=f'Unaccepted GPU device {device} passed. Accepted: {list(gpus.keys())}', source=source)
2244
+ return False
+
+ return True  # device is an available GPU index
2245
+
2246
+ def is_lxc_container() -> bool:
2247
+ """
2248
+ Helper to check if the current environment is inside a LXC Linux container.
2249
+
2250
+ .. note::
2251
+ See GitHub issue 457 for origin - https://github.com/sgoldenlab/simba/issues/457#issuecomment-3052631284
2252
+ Thanks Heinrich2818 - https://github.com/Heinrich2818
2253
+
2254
+ :return: True if current environment is a LXC linux container, False if not.
2255
+ :rtype: bool
2256
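+
+ :example:
+ >>> # Expected to return False on a typical bare-metal or VM host that is not an LXC container.
+ >>> is_lxc_container()
+ False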
+ """
2257
+
2258
+ try:
2259
+ with open('/proc/1/cgroup') as f:
2260
+ for line in f:
2261
+ if 'lxc' in line:
2262
+ return True
2263
+ except IOError:
2264
+ pass
2265
+ try:
2266
+ with open('/proc/self/mountinfo') as f:
2267
+ for line in f:
2268
+ if ' - cgroup2 ' in line:
2269
+ mount_point = line.strip().split()[-1]
2270
+ if 'lxc' in mount_point:
2271
+ return True
2272
+ except IOError:
2273
+ pass
2274
+ try:
2275
+ with open('/proc/1/environ', 'rb') as f:
2276
+ env = f.read().split(b'\0')
2277
+ for e in env:
2278
+ if e == b'container=lxc':
2279
+ return True
2280
+ except IOError:
2281
+ pass
2282
+ try:
2283
+ import subprocess
2284
+ r = subprocess.run(['systemd-detect-virt', '--container'], stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, check=False, text=True)
2285
+ if r.stdout.strip() == 'lxc':
2286
+ return True
2287
+ except Exception:
2288
+ pass
2289
+ return False
2290
+
2291
+
2292
+
2293
+ def check_valid_cpu_pool(value: Any,
2294
+ source: str = '',
2295
+ max_cores: Optional[int] = None,
2296
+ min_cores: Optional[int] = None,
2297
+ accepted_cores: Optional[Union[List[int], Tuple[int, ...], int]] = None,
2298
+ raise_error: bool = True) -> bool:
2299
+
2300
+ """
2301
+ Validates that a value is a valid multiprocessing.Pool instance and optionally checks core count constraints.
2302
+
2303
+ :param Any value: The value to validate. Must be an instance of multiprocessing.pool.Pool.
2304
+ :param str source: Optional source identifier for error messages. Default is empty string.
2305
+ :param Optional[int] max_cores: Optional maximum number of processes allowed in the pool. If provided, validates that pool._processes <= max_cores.
2306
+ :param Optional[int] min_cores: Optional minimum number of processes required in the pool. If provided, validates that pool._processes >= min_cores.
2307
+ :param Optional[Union[List[int], Tuple[int, ...], int]] accepted_cores: Optional exact or list of acceptable process counts. If an int, validates that pool._processes == accepted_cores. If a list/tuple of ints, validates that pool._processes is in accepted_cores. All values must be positive integers.
2308
+ :param bool raise_error: If True, raises InvalidInputError on validation failure. If False, returns False on failure. Default is True.
2309
+ :return bool: True if validation passes, False if validation fails and raise_error is False.
2310
+ :raises InvalidInputError: If value is not a valid Pool instance, if core count constraints are violated, if accepted_cores contains invalid types, or if raise_error is True.
2311
+
2312
+ :example:
2313
+ >>> import multiprocessing
2314
+ >>> pool = multiprocessing.Pool(processes=4)
2315
+ >>> check_valid_cpu_pool(value=pool, source='test', max_cores=8, min_cores=2)
2316
+ >>> True
2317
+ >>> check_valid_cpu_pool(value=pool, source='test', accepted_cores=[4, 8, 16])
2318
+ >>> True
2319
+ >>> check_valid_cpu_pool(value=pool, source='test', accepted_cores=4)
2320
+ >>> True
2321
+ """
2322
+
2323
+ if not isinstance(value, (multiprocessing.pool.Pool,)):
2324
+ if raise_error:
2325
+ raise InvalidInputError(msg=f'Not a valid CPU pool. Expected {multiprocessing.pool.Pool}, got {type(value)}.', source=source)
2326
+ else:
2327
+ return False
2328
+ if max_cores is not None:
2329
+ check_int(name=f'{source} max_cores', value=max_cores, min_value=1)
2330
+ if value._processes > max_cores:
2331
+ if raise_error: raise InvalidInputError(msg=f'CPU pool has too many processes. Got {value._processes}, max {max_cores}', source=source)
2332
+ else: return False
2333
+ if min_cores is not None:
2334
+ check_int(name=f'{source} min_cores', value=min_cores, min_value=1)
2335
+ if value._processes < min_cores:
2336
+ if raise_error:
2337
+ raise InvalidInputError(msg=f'CPU pool has too few processes. Got {value._processes}, min {min_cores}',source=source)
2338
+ else:
2339
+ return False
2340
+ if accepted_cores is not None:
2341
+ if isinstance(accepted_cores, int):
2342
+ is_valid, _ = check_int(name=f'{source} accepted_cores', value=accepted_cores, min_value=1, raise_error=raise_error)
2343
+ if not is_valid:
2344
+ return False
2345
+ if value._processes != accepted_cores:
2346
+ if raise_error:
2347
+ raise InvalidInputError(msg=f'CPU pool has an unacceptable number of cores. Got {value._processes}, accepted {accepted_cores}', source=source)
2348
+ else:
2349
+ return False
2350
+ elif isinstance(accepted_cores, (tuple, list)):
2351
+ is_valid = check_valid_lst(data=list(accepted_cores), source=f'{source} accepted_cores', valid_dtypes=(int,), min_len=1, min_value=1, raise_error=raise_error)
2352
+ if not is_valid:
2353
+ return False
2354
+ if value._processes not in accepted_cores:
2355
+ if raise_error:
2356
+ raise InvalidInputError(msg=f'CPU pool has an unacceptable number of cores. Got {value._processes}, accepted {accepted_cores}', source=source)
2357
+ else:
2358
+ return False
2359
+ if min_cores is not None:
2360
+ if min(accepted_cores) < min_cores:
2361
+ if raise_error:
2362
+ raise InvalidInputError(msg=f'accepted_cores contains values below min_cores. min_cores={min_cores}, accepted_cores={accepted_cores}', source=source)
2363
+ else:
2364
+ return False
2365
+ if max_cores is not None:
2366
+ if max(accepted_cores) > max_cores:
2367
+ if raise_error:
2368
+ raise InvalidInputError(msg=f'accepted_cores contains values above max_cores. max_cores={max_cores}, accepted_cores={accepted_cores}', source=source)
2369
+ else:
2370
+ return False
2371
+ else:
2372
+ raise InvalidInputError(msg=f'accepted_cores has to be an int, list of ints, or tuple of ints. Got {type(accepted_cores)}', source=source)
2373
+
2374
+ return True
2375
+
2376
+
2377
+ def check_valid_codec(codec: str, raise_error: bool = True, source: str = ''):
2378
+ """
2379
+ Validate that a codec string is available in the current FFmpeg installation.
2380
+
2381
+ Checks if the provided codec name exists in the list of available FFmpeg encoders
2382
+ by querying FFmpeg directly. This ensures the codec can be used for video encoding/decoding.
2383
+
2384
+ .. note::
2385
+ This function requires FFmpeg to be installed and available in the system PATH.
2386
+ The function queries FFmpeg for available encoders at runtime, so it will reflect
2387
+ the actual encoders available in your FFmpeg installation.
2388
+
2389
+ .. seealso::
2390
+ To get a list of all available encoders, see :func:`~simba.utils.lookups.get_ffmpeg_encoders`.
2391
+ To check if FFmpeg is available, see :func:`~simba.utils.checks.check_ffmpeg_available`.
2392
+
2393
+ :param str codec: The codec name to validate (e.g., 'libx264', 'h264_nvenc', 'libvpx-vp9').
2394
+ :param bool raise_error: If True, raises ``InvalidInputError`` when codec is invalid. If False, returns False. Default: True.
2395
+ :param str source: Source identifier for error messages. Used when raising exceptions. Default: ''.
2396
+ :return: True if codec is valid, False if invalid and ``raise_error=False``.
2397
+ :rtype: bool
2398
+ :raises InvalidInputError: If codec is not valid and ``raise_error=True``.
2399
+
2400
+ :example:
2401
+ >>> check_valid_codec(codec='libx264')
2402
+ >>> check_valid_codec(codec='h264_nvenc', source='my_function')
2403
+ >>> is_valid = check_valid_codec(codec='invalid_codec', raise_error=False)
2404
+ """
2405
+ from simba.utils.lookups import get_ffmpeg_encoders
+ encoders = get_ffmpeg_encoders()
2406
+ valid_codec = check_str(name=f'{check_valid_codec.__name__} codec', value=codec, options=encoders, allow_blank=False, raise_error=False)[0]
2407
+ if not valid_codec:
2408
+ if raise_error:
2409
+ raise InvalidInputError(msg=f'The codec {codec} is not a valid codec in the current FFMPEG installation', source=source)
2410
+ else:
2411
+ return False
2412
+ return True
2413
+
2414
+