simba-uw-tf-dev 4.6.4__py3-none-any.whl → 4.6.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. simba/data_processors/cuda/geometry.py +45 -27
  2. simba/data_processors/cuda/image.py +1620 -1600
  3. simba/data_processors/cuda/statistics.py +17 -9
  4. simba/data_processors/egocentric_aligner.py +24 -6
  5. simba/data_processors/kleinberg_calculator.py +6 -2
  6. simba/feature_extractors/feature_subsets.py +12 -5
  7. simba/feature_extractors/straub_tail_analyzer.py +0 -2
  8. simba/mixins/statistics_mixin.py +9 -2
  9. simba/sandbox/analyze_runtimes.py +30 -0
  10. simba/sandbox/cuda/egocentric_rotator.py +374 -374
  11. simba/sandbox/proboscis_to_tip.py +28 -0
  12. simba/sandbox/test_directionality.py +47 -0
  13. simba/sandbox/test_nonstatic_directionality.py +27 -0
  14. simba/sandbox/test_pycharm_cuda.py +51 -0
  15. simba/sandbox/test_simba_install.py +41 -0
  16. simba/sandbox/test_static_directionality.py +26 -0
  17. simba/sandbox/test_static_directionality_2d.py +26 -0
  18. simba/sandbox/verify_env.py +42 -0
  19. simba/ui/pop_ups/fsttc_pop_up.py +27 -25
  20. simba/ui/pop_ups/kleinberg_pop_up.py +3 -2
  21. simba/utils/data.py +0 -1
  22. simba/utils/errors.py +441 -440
  23. simba/utils/lookups.py +1203 -1203
  24. simba/utils/read_write.py +38 -13
  25. simba/video_processors/egocentric_video_rotator.py +41 -36
  26. simba/video_processors/video_processing.py +5247 -5233
  27. simba/video_processors/videos_to_frames.py +41 -31
  28. {simba_uw_tf_dev-4.6.4.dist-info → simba_uw_tf_dev-4.6.6.dist-info}/METADATA +2 -2
  29. {simba_uw_tf_dev-4.6.4.dist-info → simba_uw_tf_dev-4.6.6.dist-info}/RECORD +33 -24
  30. {simba_uw_tf_dev-4.6.4.dist-info → simba_uw_tf_dev-4.6.6.dist-info}/LICENSE +0 -0
  31. {simba_uw_tf_dev-4.6.4.dist-info → simba_uw_tf_dev-4.6.6.dist-info}/WHEEL +0 -0
  32. {simba_uw_tf_dev-4.6.4.dist-info → simba_uw_tf_dev-4.6.6.dist-info}/entry_points.txt +0 -0
  33. {simba_uw_tf_dev-4.6.4.dist-info → simba_uw_tf_dev-4.6.6.dist-info}/top_level.txt +0 -0
simba/utils/lookups.py CHANGED
@@ -1,1204 +1,1204 @@
1
- __author__ = "Simon Nilsson; sronilsson@gmail.com"
2
-
3
- import glob
4
- import os
5
- import platform
6
- import random
7
- import re
8
- import struct
9
- import subprocess
10
- import sys
11
- import tkinter as tk
12
- from copy import copy
13
- from datetime import datetime
14
- from multiprocessing import Lock, Value
15
- from pathlib import Path
16
- from typing import Any, Dict, List, Optional, Tuple, Union
17
-
18
- try:
19
- from typing import Literal
20
- except:
21
- from typing_extensions import Literal
22
-
23
- import matplotlib.font_manager
24
- import pandas as pd
25
- import psutil
26
- import pyglet
27
- from matplotlib import cm
28
- from matplotlib.colors import hsv_to_rgb, rgb2hex
29
- from tabulate import tabulate
30
-
31
- import simba
32
- from simba.utils.checks import (check_ffmpeg_available,
33
- check_file_exist_and_readable, check_float,
34
- check_if_dir_exists, check_if_valid_rgb_tuple,
35
- check_instance, check_int, check_str,
36
- check_valid_dict, check_valid_tuple)
37
- from simba.utils.enums import (OS, UML, Defaults, FontPaths, Formats, Keys,
38
- Methods, Options, Paths)
39
- from simba.utils.errors import (FFMPEGNotFoundError, InvalidInputError,
40
- NoFilesFoundError, SimBAPAckageVersionError)
41
- from simba.utils.printing import stdout_information
42
- from simba.utils.read_write import (fetch_pip_data,
43
- find_files_of_filetypes_in_directory,
44
- get_fn_ext, get_video_meta_data, read_json)
45
- from simba.utils.warnings import NoDataFoundWarning
46
-
47
# On Windows only: overwrite pyglet's COINIT_MULTITHREADED constant at import time of this module.
# NOTE(review): the value written (0x2) is COINIT_APARTMENTTHREADED, per the inline comment, even though
# the attribute is named COINIT_MULTITHREADED. This presumably makes pyglet's COM initialisation
# apartment-threaded so it coexists with other GUI code in the process (e.g. tkinter); confirm.
# It must run before pyglet's win32 code reads the constant.
if platform.system() == OS.WINDOWS.value:
    from pyglet.libs.win32 import constants
    constants.COINIT_MULTITHREADED = 0x2 # 0x2 = COINIT_APARTMENTTHREADED


# Type alias for a colour expressed as three floats.
RGBFloat = Tuple[float, float, float]
53
-
54
class SharedCounter(object):
    """
    Integer counter stored in a ``multiprocessing.Value`` so several processes can update and read it.

    Every read and update is performed while holding the separate ``multiprocessing.Lock`` kept in ``self.lock``.

    :param int initval: Starting value of the counter. Default 0.
    """

    def __init__(self, initval=0):
        # Shared C int (typecode "i") holding the count, plus the lock guarding it.
        self.val = Value("i", initval)
        self.lock = Lock()

    def increment(self):
        """Add one to the shared count under the lock."""
        self.lock.acquire()
        try:
            self.val.value = self.val.value + 1
        finally:
            self.lock.release()

    def value(self):
        """Return the current shared count, read under the lock."""
        self.lock.acquire()
        try:
            return self.val.value
        finally:
            self.lock.release()
68
-
69
-
70
def get_body_part_configurations() -> Dict[str, Dict[str, str]]:
    """
    Return the body-part schematics shipped with the SimBA installation, keyed by schematic name.

    Names are read from the first column of the header-less CSV at ``Paths.PROJECT_POSE_CONFIG_NAMES``.
    Images are the ``*.png`` files in ``Paths.SCHEMATICS`` (both relative to the installed ``simba``
    package), sorted in natural (numeric-aware) order and paired with the names by position.

    .. note::
       Pairing is positional via ``zip``. If the number of names and images differ, the surplus entries on
       the longer side are silently dropped.

    :return: Dict mapping each schematic name to ``{"img_path": <path to png>}``.
    :rtype: Dict[str, Dict[str, str]]
    """
    simba_dir = os.path.dirname(simba.__file__)
    img_dir = os.path.join(simba_dir, Paths.SCHEMATICS.value)
    names_path = os.path.join(simba_dir, Paths.PROJECT_POSE_CONFIG_NAMES.value)
    check_file_exist_and_readable(file_path=names_path)
    check_if_dir_exists(in_dir=img_dir)
    names_lst = list(pd.read_csv(names_path, header=None)[0])
    img_paths = glob.glob(img_dir + "/*.png")
    # Natural sort: digit runs compare as integers (e.g. "10.png" after "9.png"), keeping images aligned with the name order.
    img_paths.sort(key=lambda v: [int(x) if x.isdigit() else x for x in re.findall(r"[^0-9]|[0-9]+", v)])
    return {name: {"img_path": img_path} for name, img_path in zip(names_lst, img_paths)}
93
-
94
-
95
def get_bp_config_codes() -> Dict[str, str]:
    """
    Map human-readable pose-estimation setting names to the codes stored under
    ``[create ensemble settings][pose_estimation_body_parts]`` in a SimBA ``project_config.ini``.

    :return: Dict with setting display names as keys and the corresponding config codes as values.
    """
    pairs = [
        ("1 animal; 4 body-parts", "4"),
        ("1 animal; 7 body-parts", "7"),
        ("1 animal; 8 body-parts", "8"),
        ("1 animal; 9 body-parts", "9"),
        ("2 animals; 8 body-parts", "8"),
        ("2 animals; 14 body-parts", "14"),
        ("2 animals; 16 body-parts", "16"),
        ("MARS", Methods.USER_DEFINED.value),
        # Multi-animal codes are total body-part counts (e.g. 4 per animal x 2 animals -> "8").
        ("Multi-animals; 4 body-parts", "8"),
        ("Multi-animals; 7 body-parts", "14"),
        ("Multi-animals; 8 body-parts", "16"),
        ("3D tracking", "3D_user_defined"),
        ("AMBER", "AMBER"),
        ("SimBA BLOB Tracking", Methods.SIMBA_BLOB.value),
        ("FaceMap", Methods.FACEMAP.value),
        ("SuperAnimal-TopView", Methods.SUPER_ANIMAL_TOPVIEW.value),
    ]
    return dict(pairs)
119
-
120
-
121
def get_bp_config_code_class_pairs() -> Dict[str, object]:
    """
    Map the body-part configuration codes stored in ``project_config.ini``
    ([create ensemble settings][pose_estimation_body_parts]) to the feature-extraction class for that configuration.

    The ``"8"`` code maps to a nested dict keyed by animal count (1 or 2), because 8 body-part projects exist for
    both one and two animals. The extractor classes are imported inside the function body.

    :return: Dict with configuration codes as keys and feature-extractor classes (or the nested dict for "8") as values.
    """
    from simba.feature_extractors.amber_feature_extractor import (
        AmberFeatureExtractor)
    from simba.feature_extractors.feature_extractor_4bp import (
        ExtractFeaturesFrom4bps)
    from simba.feature_extractors.feature_extractor_7bp import (
        ExtractFeaturesFrom7bps)
    from simba.feature_extractors.feature_extractor_8bp import (
        ExtractFeaturesFrom8bps)
    from simba.feature_extractors.feature_extractor_8bps_2_animals import (
        ExtractFeaturesFrom8bps2Animals)
    from simba.feature_extractors.feature_extractor_9bp import (
        ExtractFeaturesFrom9bps)
    from simba.feature_extractors.feature_extractor_14bp import (
        ExtractFeaturesFrom14bps)
    from simba.feature_extractors.feature_extractor_16bp import (
        ExtractFeaturesFrom16bps)
    from simba.feature_extractors.feature_extractor_user_defined import (
        UserDefinedFeatureExtractor)

    eight_bp_by_animal_count = {1: ExtractFeaturesFrom8bps, 2: ExtractFeaturesFrom8bps2Animals}
    pairs = {}
    pairs["16"] = ExtractFeaturesFrom16bps
    pairs["14"] = ExtractFeaturesFrom14bps
    pairs["9"] = ExtractFeaturesFrom9bps
    pairs["8"] = eight_bp_by_animal_count
    pairs["7"] = ExtractFeaturesFrom7bps
    pairs["4"] = ExtractFeaturesFrom4bps
    pairs["user_defined"] = UserDefinedFeatureExtractor
    pairs["AMBER"] = AmberFeatureExtractor
    return pairs
155
-
156
-
157
-
158
def rgb_to_hex(color: Tuple[int, int, int]) -> str:
    """
    Convert a 3-tuple of 0-255 channel values to a hex colour string via matplotlib's ``rgb2hex`` (alpha dropped).

    :param Tuple[int, int, int] color: Colour as (red, green, blue), validated with ``check_if_valid_rgb_tuple``.
    :return: Hex colour string.
    """
    check_if_valid_rgb_tuple(data=color, raise_error=True, source=rgb_to_hex.__name__)
    red, green, blue = color
    # rgb2hex expects channels as floats in [0, 1].
    unit_rgb = (red / 255, green / 255, blue / 255)
    return rgb2hex(unit_rgb, keep_alpha=False)
162
-
163
def get_icons_paths() -> Dict[str, Dict[str, str]]:
    """
    Collect the ``*.png`` icons in the SimBA icon-asset directory (``Paths.ICON_ASSETS``).

    Each icon is keyed by its file name without extension, as returned by ``get_fn_ext``. If two files
    share a stem, the one globbed last wins.

    :return: Dict mapping icon name to ``{"icon_path": <path to png>}``.
    :rtype: Dict[str, Dict[str, str]]
    """
    simba_dir = os.path.dirname(simba.__file__)
    icons_dir = os.path.join(simba_dir, Paths.ICON_ASSETS.value)
    icons = {}
    for icon_path in glob.glob(icons_dir + "/*.png"):
        _, icon_name, _ = get_fn_ext(icon_path)
        icons[icon_name] = {"icon_path": icon_path}
    return icons
178
-
179
def load_simba_fonts():
    """Register each font file listed in ``simba.utils.enums.FontPaths`` (paths relative to the simba package) with pyglet."""
    package_root = os.path.dirname(simba.__file__)
    for font_member in FontPaths:
        pyglet.font.add_file(os.path.join(package_root, font_member.value))
185
-
186
def get_third_party_appender_file_formats() -> Dict[str, str]:
    """
    Map each supported third-party annotation tool to the file extension (no leading dot) of its exported annotations.

    :return: Dict with upper-case tool names as keys and file extensions as values.
    """
    return dict(
        BORIS="csv",
        ETHOVISION="xlsx",
        OBSERVER="xlsx",
        SOLOMON="csv",
        DEEPETHOGRAM="csv",
        BENTO="annot",
    )
199
-
200
-
201
def get_emojis() -> Dict[str, str]:
    """
    Return a dict mapping short names (e.g. ``"thank_you"``, ``"error"``) to emoji strings.

    How each emoji is built depends on the running Python version:

    - 3.6: rebuilt from its two big-endian UTF-16 code units.
    - 3.7: round-tripped through UTF-16 with ``surrogatepass``.
    - 3.9 / 3.10: round-tripped through UTF-8.
    - any other version: the literal is returned unchanged.
    """
    emoji_literals = (
        ("thank_you", "\U0001f64f"),
        ("relaxed", "\U0001F600"),
        ("error", "\U0001F6A8"),
        ("complete", "\U0001F680"),
        ("warning", "\u2757\uFE0F"),
        ("trash", "\U0001F5D1"),
        ("information", "\U0001F4DD"),  # 📝 memo
    )
    major_minor = str(f"{sys.version_info.major}.{sys.version_info.minor}")

    if major_minor == "3.6":
        def encode(text):
            return "".join(chr(x) for x in struct.unpack(">2H", text.encode("utf-16be")))
    elif major_minor in ("3.10", "3.9"):
        # On these versions the dict lists "warning" directly after "relaxed".
        by_name = dict(emoji_literals)
        ordering = ("thank_you", "relaxed", "warning", "error", "complete", "trash", "information")
        emoji_literals = tuple((name, by_name[name]) for name in ordering)

        def encode(text):
            return text.encode("utf-8", "replace").decode()
    elif major_minor == "3.7":
        def encode(text):
            return text.encode("utf16", errors="surrogatepass").decode("utf16")
    else:
        def encode(text):
            return text

    return {name: encode(literal) for name, literal in emoji_literals}
263
-
264
def get_meta_data_file_headers() -> List[str]:
    """
    Return the column headers, in output order, of the SimBA classifier metadata file.

    :return: New list of header strings.
    """
    header_text = """
        Classifier_name RF_criterion RF_max_features RF_min_sample_leaf RF_n_estimators
        compute_feature_permutation_importance generate_classification_report
        generate_example_decision_tree generate_features_importance_bar_graph
        generate_features_importance_log generate_precision_recall_curves
        generate_rf_model_meta_data_file generate_sklearn_learning_curves
        learning_curve_data_splits learning_curve_k_splits n_feature_importance_bars
        over_sample_ratio over_sample_setting train_test_size train_test_split_type
        under_sample_ratio under_sample_setting class_weight rf_max_depth
    """
    return header_text.split()
297
-
298
def get_cmaps() -> List[str]:
    """
    Return the names of the matplotlib colormaps offered as palette choices.

    :return: New list of colormap names.
    """
    palette_names = "spring summer autumn cool Wistia Pastel1 Set1 winter afmhot gist_heat copper"
    return palette_names.split()
315
-
316
-
317
def get_categorical_palettes():
    """
    Return the names of the qualitative (categorical) matplotlib colormaps.

    :return: New list of colormap names.
    """
    qualitative = "Pastel1 Pastel2 Paired Accent Dark2 Set1 Set2 Set3 tab10 tab20 tab20b tab20c"
    return qualitative.split()
332
-
333
-
334
def get_color_dict() -> Dict[str, Tuple[int, int, int]]:
    """
    Return a dict of colour names mapped to 3-tuples of 0-255 channel values.

    .. note::
       Most entries are in **BGR** channel order (OpenCV convention). For example, ``"Red"`` is
       ``(0, 0, 255)`` and ``"Yellow"`` is ``(0, 255, 255)``.

       NOTE(review): several trailing entries (``"Darkgoldenrod"``, ``"Seagreen"``, ``"Dodgerblue"``,
       ``"Springgreen"``, ``"Firebrick"``) equal the **RGB** values of the same-named CSS colours. In BGR
       they would render as different colours. Confirm the intended channel order before relying on them.
       The values are left unchanged here to avoid altering existing visual output.

    :return: Dict mapping colour name to a channel-value tuple.
    """
    return {
        "Grey": (220, 200, 200),
        "Red": (0, 0, 255),
        "Dark-red": (0, 0, 139),
        "Maroon": (0, 0, 128),
        "Orange": (0, 165, 255),
        "Dark-orange": (0, 140, 255),
        "Coral": (80, 127, 255),
        "Chocolate": (30, 105, 210),
        "Yellow": (0, 255, 255),
        "Green": (0, 128, 0),
        "Dark-grey": (105, 105, 105),
        "Light-grey": (192, 192, 192),
        "Pink": (178, 102, 255),
        "Lime": (204, 255, 229),
        "Purple": (255, 51, 153),
        "Cyan": (255, 255, 102),
        "White": (255, 255, 255),
        "Black": (0, 0, 0),
        "Darkgoldenrod": (184, 134, 11),
        "Olive": (109, 113, 46),
        "Seagreen": (46, 139, 87),
        "Dodgerblue": (30, 144, 255),
        "Springgreen": (0, 255, 127),
        "Firebrick": (178, 34, 34),
        "Indigo": (63, 15, 183),
    }
366
-
367
-
368
def get_named_colors() -> List[str]:
    """
    Return a list of named matplotlib colours offered as colour choices.

    :return: New list of colour names.
    """
    names = """
        red pink lime gold coral lavender sienna tomato grey azure crimson lightgrey
        aqua plum blue teal maroon green black deeppink darkgoldenrod purple olive
        seagreen dodgerblue springgreen firebrick indigo white
    """
    return names.split()
403
-
404
-
405
def create_color_palettes(no_animals: int, map_size: int) -> List[List[List[float]]]:
    """
    Create one list of BGR colours per animal, each list sampled from a different matplotlib colormap.

    Colormaps are taken in the order of the internal ``cmaps`` list. If ``no_animals`` exceeds the number of
    colormaps, the list is cycled.

    :param int no_animals: Number of palettes (one per animal) to create.
    :param int map_size: Number of colours sampled from each colormap.
    :return List[List[List[float]]]: Per-animal lists of ``[B, G, R]`` colours with channels scaled to 0-255 (floats).

    :example:
    >>> create_color_palettes(no_animals=2, map_size=2)
    >>> [[[255.0, 0.0, 255.0], [0.0, 255.0, 255.0]], [[102.0, 127.5, 0.0], [102.0, 255.0, 255.0]]]
    """
    cmaps = [
        "spring",
        "summer",
        "autumn",
        "cool",
        "Wistia",
        "Pastel1",
        "Set1",
        "winter",
        "afmhot",
        "gist_heat",
        "copper",
        "viridis",
        "Set3",
        "Set2",
        "Paired",
        "seismic",
        "prism",
        "ocean",
    ]

    color_list_of_lists = []
    for animal_idx in range(no_animals):
        # Previously this tested the integer index against cm.cmap_d's string keys, which was never true.
        # Every animal silently fell back to "spring"; look the colormap up by name instead.
        cmap_name = cmaps[animal_idx % len(cmaps)]
        try:
            # matplotlib >= 3.6: registry lookup + resampling (cm.get_cmap was removed in 3.9).
            color_map = matplotlib.colormaps[cmap_name].resampled(map_size)
        except AttributeError:
            # Older matplotlib without the `colormaps` registry or `Colormap.resampled`.
            color_map = cm.get_cmap(cmap_name, map_size)
        palette = []
        for color_idx in range(color_map.N):
            red, green, blue = color_map(color_idx)[:3]
            # Reverse RGB -> BGR and scale from [0, 1] to [0, 255].
            palette.append([blue * 255, green * 255, red * 255])
        color_list_of_lists.append(palette)
    return color_list_of_lists
453
-
454
-
455
-
456
def get_random_color_palette(n_colors: int):
    """
    Return ``n_colors`` random colours, each a tuple of three ints drawn uniformly from 0-255 with ``random.randint``.

    :param int n_colors: Number of colours to generate; must be >= 1.
    """
    check_int(name=f'{get_random_color_palette.__name__} n_colors', value=n_colors, min_value=1, raise_error=True)
    palette = []
    for _ in range(n_colors):
        palette.append(tuple(random.randint(0, 255) for _ in range(3)))
    return palette
460
-
461
def cardinality_to_integer_lookup() -> Dict[str, int]:
    """
    Map the eight compass directions to integers 0-7, clockwise starting at north.

    :example:
    >>> data = ["N", "NE", "E", "SE", "S", "SW", "W", "NW"]
    >>> [cardinality_to_integer_lookup()[d] for d in data]
    >>> [0, 1, 2, 3, 4, 5, 6, 7]
    """
    compass = ("N", "NE", "E", "SE", "S", "SW", "W", "NW")
    return {direction: idx for idx, direction in enumerate(compass)}
472
-
473
-
474
def integer_to_cardinality_lookup():
    """
    Map integers 0-7 to the eight compass directions, clockwise starting at north.
    """
    return dict(enumerate(["N", "NE", "E", "SE", "S", "SW", "W", "NW"]))
479
-
480
-
481
def percent_to_crf_lookup() -> Dict[str, int]:
    """
    Map human-readable quality percentages ("10" ... "100") to FFmpeg Constant Rate Factor (CRF) values for CPU codecs.

    Higher CRF means lower quality and smaller files. Each 10 % step lowers the CRF by 3, from 37 at "10" to 10 at "100".
    """
    return {str(percent): 40 - 3 * (percent // 10) for percent in range(10, 101, 10)}
499
-
500
-
501
def gpu_quality_to_cpu_quality_lk():
    """Map GPU preset names ("fast", "medium", "slow") to an integer quality value for CPU encoding."""
    return dict(fast=34, medium=23, slow=13)
505
-
506
def percent_to_qv_lk():
    """
    Map integer quality percentages (100 down to 10) to FFmpeg quality-scale values for CPU codecs.

    Higher FFmpeg quality scores map to smaller, lower quality videos. Used in some AVI codecs such as
    'divx' and 'mjpeg'. Each 10 % step down adds 2, from 3 at 100 to 21 at 10.
    """
    return {percent: 23 - percent // 5 for percent in range(100, 0, -10)}
521
-
522
def get_ffmpeg_crossfade_methods():
    """
    Return the FFmpeg ``xfade`` filter transition names offered for cross-fading between videos.

    Each returned name must be a valid value for ``xfade``'s ``transition`` option. FFmpeg rejects unknown
    names, which is why earlier typos ("sideleft"/"sideright"/"sideup"/"sidedown" for the ``slide*``
    transitions and "cobverup" for "coverup") are corrected here.

    :return: New list of transition names.
    """
    return ['fade', 'fadeblack', 'fadewhite', 'distance',
            'wipeleft', 'wiperight', 'wipeup', 'wipedown',
            'slideleft', 'slideright', 'slideup', 'slidedown',
            'smoothleft', 'smoothright', 'smoothup', 'smoothdown',
            'circlecrop', 'rectcrop', 'circleclose', 'circleopen',
            'horzclose', 'horzopen', 'vertclose', 'vertopen',
            'diagbl', 'diagbr', 'diagtl', 'diagtr',
            'hlslice', 'hrslice', 'vuslice', 'vdslice',
            'dissolve', 'pixelize', 'radial', 'hblur',
            'wipetl', 'wipetr', 'wipebl', 'wipebr',
            'fadegrays', 'squeezev', 'squeezeh', 'zoomin',
            'hlwind', 'hrwind', 'vuwind', 'vdwind',
            'coverleft', 'coverright', 'coverup', 'coverdown',
            'revealleft', 'revealright', 'revealup', 'revealdown']
579
-
580
-
581
def video_quality_to_preset_lookup() -> Dict[str, str]:
    """
    Map human-readable quality settings ("Low", "Medium", "High") to FFmpeg presets for GPU codecs.
    """
    qualities = ("Low", "Medium", "High")
    presets = ("fast", "medium", "slow")
    return dict(zip(qualities, presets))
586
-
587
-
588
def get_labelling_img_kbd_bindings() -> dict:
    """
    Return the tkinter keyboard bindings used by the frame-labelling interface.

    Each action maps to ``{'label': <help text>, 'kbd': <tkinter event sequence>}``.

    .. note::
       Edit the ``kbd`` values to change shortcuts. Valid tkinter sequences include, for example:
       <Key>, <KeyPress>, <KeyRelease>: any key press or release.
       <KeyPress-A>, <Key-a>: the 'a' key press (case sensitive).
       <Up>, <Down>, <Left>, <Right>: the arrow keys.
       <Control-KeyPress-A>, <Control-a>: Ctrl + A or Ctrl + a.
    """
    shortcuts = (
        # (action, help label, tkinter sequence)
        ('frame+1', 'Right Arrow = +1 frame', "<Right>"),  # move forward one frame
        ('frame-1', 'Left Arrow = -1 frame', "<Left>"),  # move back one frame
        ('save', 'Ctrl + s = Save annotations file', "<Control-s>"),  # save current annotations to disk
        ('frame+1_keep_choices', 'Ctrl + a = +1 frame and keep choices', "<Control-a>"),  # forward, keep current selections
        ('frame-1_keep_choices', 'Ctrl + q = -1 frame and keep choices', "<Control-q>"),  # back, keep current selections
        ('print_annotation_statistic', 'Ctrl + p = Show annotation statistics', "<Control-p>"),  # print annotation stats
        ('last_frame', 'Ctrl + l = Last frame', "<Control-l>"),  # jump to last frame
        ('first_frame', 'Ctrl + o = First frame', "<Control-o>"),  # jump to first frame
    )
    return {action: {'label': label, 'kbd': sequence} for action, label, sequence in shortcuts}
627
-
628
def get_labelling_video_kbd_bindings() -> dict:
    """
    Return the OpenCV keyboard bindings used while playing a video during labelling.

    Each action maps to ``{'label': <help text>, 'kbd': <key>}``. A key may be a single-character string
    (e.g. ``'p'``) or an integer key code (e.g. ``32`` for the space bar). OpenCV key codes differ from
    tkinter sequences (see ``get_labelling_img_kbd_bindings``).

    The bindings are validated before being returned. Any ``kbd`` value that ``check_int`` accepts is
    converted to ``int`` in a shallow copy of its config; other configs are returned as-is.

    Example remapping of the space bar to Pause/Play:
        {'Pause/Play': {'label': 'Space = Pause/Play', 'kbd': 32}}
    """
    key_table = (
        ('Pause/Play', 'p = Pause/Play', 'p'),
        ('forward_two_frames', 'o = +2 frames', 'o'),
        ('forward_ten_frames', 'e = +10 frames', 'e'),
        ('forward_one_second', 'w = +1 second', 'w'),
        ('backwards_two_frames', 't = -2 frames', 't'),
        ('backwards_ten_frames', 's = -10 frames', 's'),
        ('backwards_one_second', 'x = -1 second', 'x'),
        ('close_window', 'q = Close video window', 'q'),
    )
    bindings = {action: {'label': label, 'kbd': key} for action, label, key in key_table}

    # Validate the structure of the bindings before handing them out.
    check_valid_dict( x=bindings, valid_key_dtypes=(str,), valid_values_dtypes=(dict,), source=f'{get_labelling_video_kbd_bindings.__name__} bindings')
    cleaned_bindings = {}
    for action, config in bindings.items():
        check_valid_dict(x=config, valid_key_dtypes=(str,), valid_values_dtypes=(str, int), required_keys=('label', 'kbd'))
        key = config['kbd']
        check_str(value=config['label'], allow_blank=False, raise_error=True, name=f'{get_labelling_video_kbd_bindings.__name__} action')
        is_integer_key = check_int(name=f'{action} kbd', value=key, raise_error=False)[0]
        cleaned_bindings[action] = {**config, 'kbd': int(key)} if is_integer_key else config

    return cleaned_bindings
693
-
694
-
695
def get_fonts(sort_alphabetically: bool = False):
    """
    Return the fonts known to ``matplotlib.font_manager.fontManager`` as ``{font name: font file path}``.

    Fonts whose name starts with '.' are skipped. If a name occurs more than once, the last file wins.
    A ``NoDataFoundWarning`` is created when no fonts are found.

    On Windows, the substring ``'C:'`` is removed from each path and the path is converted to forward slashes.
    NOTE(review): presumably to make paths usable inside FFmpeg filter strings. Fonts on other drives keep
    their drive letter — confirm.

    :param bool sort_alphabetically: If True, order entries by font name using a case-insensitive, numeric-aware key.
    """
    fonts = {}
    for font_entry in matplotlib.font_manager.fontManager.ttflist:
        if font_entry.name.startswith('.'):
            continue
        fonts[font_entry.name] = font_entry.fname
    if not fonts:
        NoDataFoundWarning(msg='No fonts found on disk using matplotlib.font_manager', source=get_fonts.__name__)
    if platform.system() == OS.WINDOWS.value:
        fonts = {name: str(Path(path.replace('C:', '')).as_posix()) for name, path in fonts.items()}
    if sort_alphabetically:
        def _natural_key(item):
            return [int(t) if t.isdigit() else t.lower() for t in re.split(r'(\d+)', item[0])]
        fonts = dict(sorted(fonts.items(), key=_natural_key))
    return fonts
705
-
706
def get_log_config():
    """
    Return a ``logging.config.dictConfig``-style dict.

    The root logger logs at INFO to a rotating file handler (append mode, 5 backups, 5,000,000-byte limit).
    The handler has no ``filename``, so it is presumably filled in by the caller — confirm.
    """
    formatter = {
        "format": "%(asctime)s|%(name)s||%(message)s",
        "datefmt": "%Y-%m-%dT%H:%M:%SZ",
        # "class": "pythonjsonlogger.jsonlogger.JsonFormatter",
    }
    file_handler = {
        "class": "logging.handlers.RotatingFileHandler",
        "formatter": "default",
        "mode": "a",
        "backupCount": 5,
        "maxBytes": 5000000,
    }
    root_logger = {"level": "INFO", "handlers": ["file_handler"]}
    return {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {"default": formatter},
        "handlers": {"file_handler": file_handler},
        "loggers": {"": root_logger},
    }
728
-
729
def get_model_names():
    """Return the ``UML.NAMES`` column of the unsupervised-model names parquet file shipped with SimBA, as a list."""
    parquet_path = os.path.join(os.path.dirname(simba.__file__), Paths.UNSUPERVISED_MODEL_NAMES.value)
    names_df = pd.read_parquet(parquet_path)
    return list(names_df[UML.NAMES.value])
732
-
733
def win_to_wsl_path(win_path: Union[str, os.PathLike]) -> str:
    """
    Convert a Windows path to its WSL equivalent by running ``wsl.exe wslpath``.

    :param Union[str, os.PathLike] win_path: Windows path to convert.
    :return: The converted path with surrounding whitespace stripped.
    :raises RuntimeError: If ``wslpath`` exits with a non-zero return code.
    """
    completed = subprocess.run(["wsl.exe", "wslpath", win_path], capture_output=True, text=True)
    if completed.returncode == 0:
        return completed.stdout.strip()
    raise RuntimeError(f"WSL path conversion failed: {completed.stderr}")
739
-
740
-
741
def get_available_ram():
    """
    Return total and currently available system memory, as reported by ``psutil.virtual_memory()``, in several units.

    :return: Dict with keys ``bytes``/``available_bytes`` (bytes), ``megabytes``/``available_mb``
             (bytes / 1024**2) and ``gigabytes``/``available_gb`` (megabytes / 1024).
    """
    # Take a single snapshot so total and available come from the same reading.
    memory = psutil.virtual_memory()
    total = memory.total
    available = memory.available
    total_mb = total / (1024 ** 2)
    available_mb = available / (1024 ** 2)

    return {
        "bytes": total,
        "available_bytes": available,
        "megabytes": total_mb,
        "available_mb": available_mb,
        "gigabytes": total_mb / 1024,
        "available_gb": available_mb / 1024,
    }
756
-
757
-
758
def get_current_time():
    """Return the current local wall-clock time formatted as ``HH:MM:SS``."""
    now = datetime.now()
    return now.strftime("%H:%M:%S")
760
-
761
-
762
def get_display_resolution() -> Tuple[int, int]:
    """
    Return the display resolution as ``(width, height)``, queried from a temporary hidden tkinter root window.

    .. note::
       May return the virtual geometry in multi-display setups. To get the resolution of each available
       monitor, see :func:`simba.utils.lookups.get_monitor_info`.

    :return: Screen width and height in pixels.
    :rtype: Tuple[int, int]
    """
    root = tk.Tk()
    try:
        root.withdraw()
        width = root.winfo_screenwidth()
        height = root.winfo_screenheight()
    finally:
        # Always tear down the temporary Tk instance, even if a query raises.
        root.destroy()
    return (width, height)
776
-
777
-
778
def get_img_resize_info(img_size: Tuple[int, int],
                        display_resolution: Optional[Tuple[int, int]] = None,
                        max_height_ratio: float = 0.5,
                        max_width_ratio: float = 0.5,
                        min_height_ratio: float = 0.0,
                        min_width_ratio: float = 0.0) -> Tuple[int, int, float, float]:
    """
    Calculate the new dimensions and scaling factors needed to resize an image, preserving its aspect ratio,
    so that it fits within a given portion of the display resolution.

    If the image exceeds the maximum width or height, it is downscaled until both fit. Otherwise, if it is
    smaller than the minimum width or height, it is upscaled until both meet or exceed the minimum (which may
    push the other dimension past its maximum). Otherwise it is returned unchanged.

    :param Tuple[int, int] img_size: The original size of the image as (width, height).
    :param Optional[Tuple[int, int]] display_resolution: Resolution of the display as (width, height). If None, the resolution of the main monitor is used.
    :param float max_height_ratio: Maximum allowed image height as a fraction of the display height (default 0.5).
    :param float max_width_ratio: Maximum allowed image width as a fraction of the display width (default 0.5).
    :param float min_height_ratio: Minimum image height as a fraction of the display height (default 0.0, i.e. no minimum).
    :param float min_width_ratio: Minimum image width as a fraction of the display width (default 0.0, i.e. no minimum).
    :returns: Length 4 tuple with resized width, resized height, the scale factor applied to the original size, and its inverse.
    :rtype: Tuple[int, int, float, float]
    """
    if display_resolution is None:
        _, display_resolution = get_monitor_info()
    img_width, img_height = img_size[0], img_size[1]
    max_width = round(display_resolution[0] * max_width_ratio)
    max_height = round(display_resolution[1] * max_height_ratio)
    min_width = round(display_resolution[0] * min_width_ratio)
    min_height = round(display_resolution[1] * min_height_ratio)

    # img_size is (width, height): compare width with width limits and height with height limits
    # (previously the indices were swapped).
    if img_width > max_width or img_height > max_height:
        downscale_factor = min(max_width / img_width, max_height / img_height)
        upscale_factor = 1 / downscale_factor
        new_width = round(img_width * downscale_factor)
        new_height = round(img_height * downscale_factor)
        return new_width, new_height, downscale_factor, upscale_factor

    elif img_width < min_width or img_height < min_height:
        # max() ensures both dimensions meet or exceed their minimum.
        scale = max(min_width / img_width, min_height / img_height)
        new_width = round(img_width * scale)
        new_height = round(img_height * scale)
        return new_width, new_height, scale, 1 / scale

    else:
        return img_size[0], img_size[1], 1, 1
823
-
824
def is_running_in_ide():
    """
    Report whether Python is running interactively.

    Returns True if ``sys.ps1`` is set (an interactive prompt exists); otherwise returns ``sys.flags.interactive``.
    """
    if hasattr(sys, 'ps1'):
        return True
    return sys.flags.interactive
826
-
827
-
828
def get_monitor_info() -> Tuple[Dict[int, Dict[str, Union[int, bool]]], Tuple[int, int]]:
    """
    Return the resolution of each monitor reported by pyglet, plus the main monitor's width and height.

    A monitor is flagged ``primary`` when its origin is at (0, 0). The main monitor is the first primary one,
    falling back to the first monitor listed.

    .. note::
       For the (possibly virtual) geometry of the whole display, see :func:`simba.utils.lookups.get_display_resolution`.

    :return: Tuple of (dict keyed by monitor index with ``width``/``height``/``primary``, (main width, main height)).
    """
    screens = pyglet.canvas.get_display().get_screens()
    results = {}
    for idx, screen in enumerate(screens):
        is_primary = bool(screen.x == 0 and screen.y == 0)
        results[idx] = {'width': screen.width,
                        'height': screen.height,
                        'primary': is_primary}

    # Evaluated eagerly: raises StopIteration if pyglet reports no screens.
    first_monitor = next(iter(results.values()))
    main_monitor = next((info for info in results.values() if info.get('primary')), first_monitor)

    return results, (int(main_monitor['width']), int(main_monitor['height']))
846
-
847
-
848
-
849
def get_table(data: Dict[str, Any],
              headers: Optional[Tuple[str, str]] = ("SETTING", "VALUE"),
              tablefmt: str = "grid") -> str:
    """
    Render a dictionary as a two-column table string with ``tabulate`` (one row per key/value pair).

    :param Dict[str, Any] data: Non-empty dict with string keys. Keys fill the first column, values the second.
    :param Optional[Tuple[str, str]] headers: The two column headers. Default is ("SETTING", "VALUE").
    :param str tablefmt: Table format style; must be one of simba.utils.enums.Formats.VALID_TABLEFMT.
    :return str: The formatted table.

    :example:
    >>> data = {"fps": 30, "width": 1920, "height": 1080, "frame_count": 3000}
    >>> table = get_table(data=data, headers=("PARAMETER", "VALUE"))
    """
    caller = get_table.__name__
    check_valid_dict(x=data, valid_key_dtypes=(str,), min_len_keys=1, source=f'{caller} data')
    check_valid_tuple(x=headers, source=f'{caller} data', accepted_lengths=(2,), valid_dtypes=(str,))
    check_str(name=f'{caller} tablefmt', value=tablefmt, options=Formats.VALID_TABLEFMT.value, raise_error=True)
    rows = [[key, value] for key, value in data.items()]
    return tabulate(rows, headers=headers, tablefmt=tablefmt)
873
-
874
-
875
-
876
-
877
- def print_video_meta_data(data_path: Union[str, os.PathLike]) -> None:
878
- """
879
- Print video metadata as formatted tables to the console.
880
-
881
- This function reads video metadata from either a single video file or all video files
882
- in a directory, then prints the metadata as formatted tables.
883
-
884
- .. seealso::
885
- To get video metadata as a dictionary without printing, use :func:`simba.utils.read_write.get_video_meta_data`.
886
- To get video metadata as a table without printing, use :func:`simba.utils.lookups.get_table`.
887
-
888
- :param Union[str, os.PathLike] data_path: Path to video file or directory containing videos.
889
- :return: None. Video metadata is printed as formatted tables in the main console.
890
- """
891
-
892
- if os.path.isfile(data_path):
893
- video_meta_data = [get_video_meta_data(video_path=data_path)]
894
- elif os.path.isdir(data_path):
895
- video_paths = find_files_of_filetypes_in_directory(directory=data_path, extensions=Options.ALL_VIDEO_FORMAT_OPTIONS.value, raise_error=False)
896
- video_meta_data = [get_video_meta_data(video_path=x) for x in video_paths]
897
- else:
898
- raise NoFilesFoundError(msg=f'{data_path} is not a valid file or directory path', source=print_video_meta_data.__name__)
899
- for video_meta in video_meta_data:
900
- table = get_table(data=video_meta, headers=('VIDEO PARAMETER', 'VALUE'), tablefmt='grid')
901
- print(f"{table} {Defaults.STR_SPLIT_DELIMITER.value}TABLE")
902
-
903
-
904
- def get_ffmpeg_encoders(raise_error: bool = True, alphabetically_sorted: bool = False) -> List[str]:
905
- """
906
- Get a list of all available FFmpeg encoders.
907
-
908
- :param bool raise_error: If True, raises an exception when FFmpeg is not available or the command fails. If False, returns an empty list on error. Default: True.
909
- :return: List of encoder names (e.g., ['libx264', 'aac', 'libvpx', ...]). Returns empty list if FFmpeg is unavailable and raise_error=False.
910
- :rtype: List[str]
911
-
912
- :example:
913
- >>> codecs = get_ffmpeg_encoders()
914
- >>> print(Formats.BATCH_CODEC.value in codecs)
915
- """
916
-
917
- check_ffmpeg_available(raise_error=True)
918
- try:
919
- proc = subprocess.Popen(['ffmpeg', '-encoders'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
920
- stdout, stderr = proc.communicate()
921
- if isinstance(stdout, bytes):
922
- stdout = stdout.decode('utf-8')
923
- except Exception as e:
924
- if raise_error:
925
- raise FFMPEGNotFoundError(msg=str(e.args))
926
- else:
927
- return []
928
- encoders = []
929
- lines = stdout.split('\n')
930
-
931
- for line in lines:
932
- if re.match(r'^\s*[VAS]', line):
933
- parts = line.split()
934
- if len(parts) >= 2:
935
- encoder_name = parts[1]
936
- encoders.append(encoder_name)
937
-
938
- return sorted(encoders) if alphabetically_sorted else encoders
939
-
940
-
941
- def find_closest_string(target: str,
942
- string_list: List[str],
943
- case_sensitive: bool = False,
944
- token_based: bool = True) -> Optional[Tuple[str, Union[int, float]]]:
945
- """
946
- Find the closest string in a list to a target string using hybrid similarity matching.
947
-
948
- This function uses a combination of token-based matching and Levenshtein distance to find
949
- the best match. Token-based matching is particularly useful for strings like body part names
950
- where word order may vary (e.g., "Left_ear" vs "Ear_left").
951
-
952
- :param str target: The target string to match against.
953
- :param List[str] string_list: List of strings to search through.
954
- :param bool case_sensitive: If True, comparison is case-sensitive. If False (default), comparison is case-insensitive.
955
- :param bool token_based: If True (default), uses hybrid token-based and Levenshtein matching which handles word reordering better. If False, uses pure Levenshtein distance only.
956
- :return: Tuple of (closest_string, distance) or None if string_list is empty. When token_based=True, distance is a float score (lower is better). When token_based=False, distance is integer edit distance.
957
- :rtype: Optional[Tuple[str, Union[int, float]]]
958
-
959
- :example:
960
- >>> find_closest_string("cat", ["dog", "car", "bat"])
961
- >>> ('car', 0.33)
962
- >>> find_closest_string("Left_ear", ["Ear_left", "Right_ear", "Nose"])
963
- >>> ('Ear_left', 0.0)
964
- >>> find_closest_string("CAT", ["dog", "car", "bat"], case_sensitive=False)
965
- >>> ('car', 0.33)
966
- >>> find_closest_string("CAT", ["dog", "car", "bat"], case_sensitive=True, token_based=False)
967
- >>> ('car', 3)
968
- """
969
-
970
- check_str(name=f'{find_closest_string.__name__} target', value=target, allow_blank=False, raise_error=True)
971
- check_instance(source=f'{find_closest_string.__name__} string_list', instance=string_list, accepted_types=(list,), raise_error=True)
972
- if len(string_list) == 0:
973
- return None
974
- for i in string_list:
975
- check_str(name=f'{find_closest_string.__name__} string_list entry', value=i, allow_blank=False, raise_error=True)
976
-
977
- def levenshtein(s1: str, s2: str) -> int:
978
- if s1 == s2: return 0
979
- if not s1: return len(s2)
980
- if not s2: return len(s1)
981
- if len(s1) > len(s2): s1, s2 = s2, s1
982
- prev_row = list(range(len(s1) + 1))
983
- for i, c2 in enumerate(s2):
984
- curr_row = [i + 1]
985
- for j, c1 in enumerate(s1):
986
- cost = 0 if c1 == c2 else 1
987
- curr_row.append(min(prev_row[j + 1] + 1, curr_row[j] + 1, prev_row[j] + cost))
988
- prev_row = curr_row
989
- return prev_row[-1]
990
-
991
- def tokenize(s: str) -> List[str]:
992
- """Split string by common delimiters and return sorted tokens"""
993
- tokens = re.split(r'[_\-\s]+', s)
994
- return sorted([t for t in tokens if t])
995
-
996
- def token_sort_similarity(s1: str, s2: str) -> float:
997
- """
998
- Hybrid similarity score combining token matching with character-level Levenshtein.
999
- Returns a score where 0.0 = perfect match, higher = worse match.
1000
- """
1001
- tokens1 = tokenize(s1)
1002
- tokens2 = tokenize(s2)
1003
-
1004
- # Token set matching
1005
- set1, set2 = set(tokens1), set(tokens2)
1006
- intersection = len(set1 & set2)
1007
- union = len(set1 | set2)
1008
-
1009
- if union == 0:
1010
- token_score = 1.0
1011
- else:
1012
- token_score = 1.0 - (intersection / union) # Jaccard distance
1013
-
1014
- sorted_s1 = '_'.join(tokens1)
1015
- sorted_s2 = '_'.join(tokens2)
1016
- max_len = max(len(sorted_s1), len(sorted_s2))
1017
- if max_len == 0:
1018
- lev_score = 0.0
1019
- else:
1020
- lev_score = levenshtein(sorted_s1, sorted_s2) / max_len
1021
-
1022
- # Weighted combination: token matching (70%) + order similarity (30%)
1023
- return token_score * 0.7 + lev_score * 0.3
1024
-
1025
- # Prepare strings for comparison
1026
- if not case_sensitive:
1027
- target_cmp = target.lower()
1028
- string_list_cmp = [s.lower() for s in string_list]
1029
- else:
1030
- target_cmp = target
1031
- string_list_cmp = string_list
1032
-
1033
- # Find closest match
1034
- if token_based:
1035
- scores = [token_sort_similarity(target_cmp, s) for s in string_list_cmp]
1036
- closest_idx = min(range(len(scores)), key=lambda i: scores[i])
1037
- closest = string_list[closest_idx]
1038
- distance = scores[closest_idx]
1039
- else:
1040
- distances = [levenshtein(target_cmp, s) for s in string_list_cmp]
1041
- closest_idx = min(range(len(distances)), key=lambda i: distances[i])
1042
- closest = string_list[closest_idx]
1043
- distance = distances[closest_idx]
1044
-
1045
- return closest, distance
1046
-
1047
-
1048
-
1049
- def create_directionality_cords(bp_dict: dict,
1050
- left_ear_name: str,
1051
- nose_name: str,
1052
- right_ear_name: str) -> dict:
1053
- """
1054
- Helper to create a dictionary mapping animal body-parts (nose, left ear, right ear) to their X and Y coordinate
1055
- column names for directionality analysis.
1056
-
1057
- :param dict bp_dict: Dictionary with animal names as keys and body-part coordinate information as values. Expected to contain 'X_bps' and 'Y_bps' keys with lists of column names.
1058
- :param str left_ear_name: Name of the left ear body-part to search for in coordinate column names.
1059
- :param str nose_name: Name of the nose body-part to search for in coordinate column names.
1060
- :param str right_ear_name: Name of the right ear body-part to search for in coordinate column names.
1061
- :return: Nested dictionary with animal names as keys, body-part types (nose, ear_left, ear_right) as second-level keys, and coordinate types (X_bps, Y_bps) as third-level keys with corresponding column names as values.
1062
- :rtype: dict
1063
- :raises InvalidInputError: If any required body-part or coordinate cannot be found in the input dictionary.
1064
-
1065
- :example:
1066
- >>> bp_dict = {'Animal_1': {'X_bps': ['Animal_1_Nose_x', 'Animal_1_Ear_left_x', 'Animal_1_Ear_right_x'], 'Y_bps': ['Animal_1_Nose_y', 'Animal_1_Ear_left_y', 'Animal_1_Ear_right_y']}}
1067
- >>> create_directionality_cords(bp_dict=bp_dict, left_ear_name='Ear_left', nose_name='Nose', right_ear_name='Ear_right')
1068
- >>> {'Animal_1': {'nose': {'X_bps': 'Animal_1_Nose_x', 'Y_bps': 'Animal_1_Nose_y'}, 'ear_left': {'X_bps': 'Animal_1_Ear_left_x', 'Y_bps': 'Animal_1_Ear_left_y'}, 'ear_right': {'X_bps': 'Animal_1_Ear_right_x', 'Y_bps': 'Animal_1_Ear_right_y'}}}
1069
- """
1070
-
1071
- NOSE, EAR_LEFT, EAR_RIGHT = Keys.NOSE.value, Keys.EAR_LEFT.value, Keys.EAR_RIGHT.value
1072
-
1073
- results = {}
1074
- for animal in bp_dict.keys():
1075
- results[animal] = {NOSE: {}, EAR_LEFT: {}, EAR_RIGHT: {}}
1076
- for dimension in ["X_bps", "Y_bps"]:
1077
- for cord in bp_dict[animal][dimension]:
1078
- if (nose_name.lower() in cord.lower()) and ("x" in cord.lower()):
1079
- results[animal][NOSE]["X_bps"] = cord
1080
- elif (nose_name.lower() in cord.lower()) and ("y" in cord.lower()):
1081
- results[animal][NOSE]["Y_bps"] = cord
1082
- elif (left_ear_name.lower() in cord.lower()) and ("x" in cord.lower()):
1083
- results[animal][EAR_LEFT]["X_bps"] = cord
1084
- elif (left_ear_name.lower() in cord.lower()) and ("y" in cord.lower()):
1085
- results[animal][EAR_LEFT]["Y_bps"] = cord
1086
- elif (right_ear_name.lower() in cord.lower()) and ("x" in cord.lower()):
1087
- results[animal][EAR_RIGHT]["X_bps"] = cord
1088
- elif (right_ear_name.lower() in cord.lower()) and ("y" in cord.lower()):
1089
- results[animal][EAR_RIGHT]["Y_bps"] = cord
1090
-
1091
- for animal_name, animal_bps in results.items():
1092
- for bp_name, bp_values in animal_bps.items():
1093
- if len(bp_values.keys()) == 0:
1094
- raise InvalidInputError(msg=f'Could not detect a body-part for animal {animal_name}, body-part {bp_name} in SimBA project. Make sure the body-part configuration file at {Paths.BP_NAMES.value} lists the appropriate body-parts', source=create_directionality_cords.__name__)
1095
- for cord_key, cord_value in bp_values.items():
1096
- if cord_value == '':
1097
- raise InvalidInputError(msg=f'Could not detect a body-part for animal {animal_name}, body-part {bp_name} and coordinate {cord_key} in SimBA project. MAke sure the body-part configuration file at {Paths.BP_NAMES.value} lists the appropriate body-parts. Passed values: {left_ear_name, nose_name, right_ear_name}', source=create_directionality_cords.__name__)
1098
- return results
1099
-
1100
- def get_tooltips() -> Dict[str, str]:
1101
- simba_dir = os.path.dirname(simba.__file__)
1102
- tool_tips_path = os.path.join(simba_dir, Paths.TOOLTIPS.value)
1103
- if not os.path.isfile(tool_tips_path):
1104
- return {}
1105
- else:
1106
- return read_json(x=tool_tips_path, raise_error=False)
1107
-
1108
- def intermittent_palette(n: int = 10,
1109
- base_light: float = 0.55,
1110
- contrast_delta: float = 0.18,
1111
- seed_hue: Optional[float] = None,
1112
- output: Literal["rgb", "rgb255", "hex"] = "rgb",
1113
- rng: Optional[random.Random] = None) -> Union[List[RGBFloat], List[Tuple[int, int, int]], List[str]]:
1114
- """
1115
- Generate a categorical colour palette with evenly spaced hues and alternating lightness.
1116
-
1117
- .. note::
1118
- Use to get color palette where immediate colors are distinct.
1119
-
1120
- :param int n: Number of colours to generate. Must be greater than or equal to 1.
1121
- :param float base_light: Midpoint HSV value (0-1) used as the baseline lightness. Default ``0.55``.
1122
- :param float contrast_delta: Lightness offset added/subtracted per colour to improve visual separation. Default ``0.18``.
1123
- :param Optional[float] seed_hue: Initial hue (0-1). If ``None``, a random hue is sampled. Default ``None``.
1124
- :param str output: Output colour format. One of ``{"rgb", "rgb255", "hex"}``. Default ``"rgb"``.
1125
- :param Optional[random.Random] rng: Optional pre-seeded RNG for reproducible random starts.
1126
- :return: Colour palette in the requested format (RGB floats, RGB 0-255 integers, or hexadecimal strings).
1127
- :rtype: Union[List[Tuple[float, float, float]], List[Tuple[int, int, int]], List[str]]
1128
-
1129
- :example:
1130
- >>> palette = intermittent_palette(n=6, output="hex")
1131
- >>> palette
1132
- >>> ['#a33f46', '#51a5df', '#b36824', '#4dbd9f', '#c749b4', '#7a9a3e']
1133
- """
1134
-
1135
- fn_name = intermittent_palette.__name__
1136
- check_int(name=f"{fn_name} n", value=n, min_value=1)
1137
- check_float(name=f"{fn_name} base_light", value=base_light, min_value=0.0, max_value=1.0)
1138
- check_float(name=f"{fn_name} contrast_delta", value=contrast_delta, min_value=0.0, max_value=1.0)
1139
- if seed_hue is not None:
1140
- check_float(name=f"{fn_name} seed_hue", value=seed_hue, min_value=0.0, max_value=1.0)
1141
- check_str(name=f"{fn_name} output", value=output, options={"rgb", "rgb255", "hex"}, raise_error=True)
1142
-
1143
- if rng is not None and not isinstance(rng, random.Random):
1144
- raise InvalidInputError(msg="rng must be an instance of random.Random.", source=fn_name)
1145
-
1146
- golden_ratio = 0.618033988749895
1147
- rnd = rng or random.Random()
1148
- hue = seed_hue % 1.0 if seed_hue is not None else rnd.random()
1149
- colours: List[RGBFloat] = []
1150
-
1151
- for idx in range(n):
1152
- hue = (hue + golden_ratio) % 1.0
1153
- sat = 0.72 if idx % 3 else 0.85
1154
- light = base_light + (contrast_delta if idx % 2 else -contrast_delta)
1155
- light = min(max(light, 0.25), 0.85)
1156
- colours.append(tuple(hsv_to_rgb((hue, sat, light))))
1157
-
1158
- fmt = output.lower()
1159
- if fmt == "rgb":
1160
- return colours
1161
- elif fmt == "rgb255":
1162
- return [tuple(int(round(c * 255)) for c in colour) for colour in colours]
1163
- else:
1164
- return [rgb2hex(colour) for colour in colours]
1165
-
1166
- def quality_pct_to_crf(pct: int) -> int:
1167
- check_int(name=f'{quality_pct_to_crf.__name__} pct', min_value=1, max_value=100, raise_error=True, value=pct)
1168
- quality_lk = {int(k):v for k, v in percent_to_crf_lookup().items()}
1169
- closest_key = min(quality_lk, key=lambda k: abs(k - pct))
1170
- return quality_lk[closest_key]
1171
-
1172
-
1173
- def check_for_updates(time_out: int = 2):
1174
- """
1175
- Check for SimBA package updates by querying PyPI and comparing with the installed version.
1176
-
1177
- Fetches the latest SimBA version from PyPI and compares it with the currently installed
1178
- version. Prints an informational message indicating whether an update is available or if
1179
- the installation is up-to-date. Requires an active internet connection to query PyPI.
1180
-
1181
- :parameter int time_out: Timeout in seconds for the PyPI API request. Default is 2 seconds.
1182
- Must be at least 1 second.
1183
- :returns: None. Prints update information to stdout via stdout_information.
1184
- :raises SimBAPAckageVersionError: If the latest version cannot be fetched from PyPI, or if
1185
- the local SimBA version cannot be determined.
1186
-
1187
- :example:
1188
- >>> check_for_updates()
1189
- >>> # Prints: "UP-TO-DATE. You have the latest SimBA version (1.0.0)."
1190
- >>> # or: "NEW SimBA VERSION AVAILABLE. You have SimBA version 1.0.0. The latest version is 1.1.0..."
1191
- """
1192
- check_int(name=f'{fetch_pip_data.__name__} time_out', value=time_out, min_value=1)
1193
- _, latest_simba_version = fetch_pip_data(pip_url=r'https://pypi.org/pypi/simba-uw-tf-dev/json', time_out=time_out)
1194
- env_simba_version = OS.SIMBA_VERSION.value
1195
- if latest_simba_version is None:
1196
- raise SimBAPAckageVersionError(msg='Could not fetch latest SimBA version.', source=check_for_updates.__name__)
1197
- elif env_simba_version is None:
1198
- raise SimBAPAckageVersionError(msg='Could not get local SimBA version.', source=check_for_updates.__name__)
1199
- if latest_simba_version == env_simba_version:
1200
- msg = f'UP-TO-DATE. \nYou have the latest SimBA version ({env_simba_version}).'
1201
- else:
1202
- msg = (f'NEW SimBA VERSION AVAILABLE. \nYou have SimBA version {env_simba_version}. \nThe latest version is {latest_simba_version}. '
1203
- f'\nYou can update using "pip install simba-uw-tf-dev --upgrade"')
1
+ __author__ = "Simon Nilsson; sronilsson@gmail.com"
2
+
3
+ import glob
4
+ import os
5
+ import platform
6
+ import random
7
+ import re
8
+ import struct
9
+ import subprocess
10
+ import sys
11
+ import tkinter as tk
12
+ from copy import copy
13
+ from datetime import datetime
14
+ from multiprocessing import Lock, Value
15
+ from pathlib import Path
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ try:
19
+ from typing import Literal
20
+ except:
21
+ from typing_extensions import Literal
22
+
23
+ import matplotlib.font_manager
24
+ import pandas as pd
25
+ import psutil
26
+ import pyglet
27
+ from matplotlib import cm
28
+ from matplotlib.colors import hsv_to_rgb, rgb2hex
29
+ from tabulate import tabulate
30
+
31
+ import simba
32
+ from simba.utils.checks import (check_ffmpeg_available,
33
+ check_file_exist_and_readable, check_float,
34
+ check_if_dir_exists, check_if_valid_rgb_tuple,
35
+ check_instance, check_int, check_str,
36
+ check_valid_dict, check_valid_tuple)
37
+ from simba.utils.enums import (OS, UML, Defaults, FontPaths, Formats, Keys,
38
+ Methods, Options, Paths)
39
+ from simba.utils.errors import (FFMPEGNotFoundError, InvalidInputError,
40
+ NoFilesFoundError, SimBAPAckageVersionError)
41
+ from simba.utils.printing import stdout_information
42
+ from simba.utils.read_write import (fetch_pip_data,
43
+ find_files_of_filetypes_in_directory,
44
+ get_fn_ext, get_video_meta_data, read_json)
45
+ from simba.utils.warnings import NoDataFoundWarning
46
+
47
+ if platform.system() == OS.WINDOWS.value:
48
+ from pyglet.libs.win32 import constants
49
+ constants.COINIT_MULTITHREADED = 0x2 # 0x2 = COINIT_APARTMENTTHREADED
50
+
51
+
52
+ RGBFloat = Tuple[float, float, float]
53
+
54
+ class SharedCounter(object):
55
+ """Counter that can be shared across processes on different cores"""
56
+
57
+ def __init__(self, initval=0):
58
+ self.val = Value("i", initval)
59
+ self.lock = Lock()
60
+
61
+ def increment(self):
62
+ with self.lock:
63
+ self.val.value += 1
64
+
65
+ def value(self):
66
+ with self.lock:
67
+ return self.val.value
68
+
69
+
70
+ def get_body_part_configurations() -> Dict[str, Union[str, os.PathLike]]:
71
+ """
72
+ Return dict with named body-part schematics of pose-estimation schemas in SimBA installation as keys,
73
+ and paths to the images representing those body-part schematics as values.
74
+ """
75
+
76
+ lookup = {}
77
+ simba_dir = os.path.dirname(simba.__file__)
78
+ img_dir = os.path.join(simba_dir, Paths.SCHEMATICS.value)
79
+ names_path = os.path.join(simba_dir, Paths.PROJECT_POSE_CONFIG_NAMES.value)
80
+ check_file_exist_and_readable(file_path=names_path)
81
+ check_if_dir_exists(in_dir=img_dir)
82
+ names_lst = list(pd.read_csv(names_path, header=None)[0])
83
+ img_paths = glob.glob(img_dir + "/*.png")
84
+ img_paths.sort(
85
+ key=lambda v: [
86
+ int(x) if x.isdigit() else x for x in re.findall(r"[^0-9]|[0-9]+", v)
87
+ ]
88
+ )
89
+ for name, img_path in zip(names_lst, img_paths):
90
+ lookup[name] = {}
91
+ lookup[name]["img_path"] = img_path
92
+ return lookup
93
+
94
+
95
+ def get_bp_config_codes() -> Dict[str, str]:
96
+ """
97
+ Helper to match SimBA project_config.ini [create ensemble settings][pose_estimation_body_parts] to string names.
98
+
99
+ """
100
+
101
+ return {
102
+ "1 animal; 4 body-parts": "4",
103
+ "1 animal; 7 body-parts": "7",
104
+ "1 animal; 8 body-parts": "8",
105
+ "1 animal; 9 body-parts": "9",
106
+ "2 animals; 8 body-parts": "8",
107
+ "2 animals; 14 body-parts": "14",
108
+ "2 animals; 16 body-parts": "16",
109
+ "MARS": Methods.USER_DEFINED.value,
110
+ "Multi-animals; 4 body-parts": "8",
111
+ "Multi-animals; 7 body-parts": "14",
112
+ "Multi-animals; 8 body-parts": "16",
113
+ "3D tracking": "3D_user_defined",
114
+ "AMBER": "AMBER",
115
+ "SimBA BLOB Tracking": Methods.SIMBA_BLOB.value,
116
+ "FaceMap": Methods.FACEMAP.value,
117
+ "SuperAnimal-TopView": Methods.SUPER_ANIMAL_TOPVIEW.value
118
+ }
119
+
120
+
121
+ def get_bp_config_code_class_pairs() -> Dict[str, object]:
122
+ """
123
+ Helper to match SimBA project_config.ini [create ensemble settings][pose_estimation_body_parts] setting to feature extraction module class.
124
+ """
125
+
126
+ from simba.feature_extractors.amber_feature_extractor import \
127
+ AmberFeatureExtractor
128
+ from simba.feature_extractors.feature_extractor_4bp import \
129
+ ExtractFeaturesFrom4bps
130
+ from simba.feature_extractors.feature_extractor_7bp import \
131
+ ExtractFeaturesFrom7bps
132
+ from simba.feature_extractors.feature_extractor_8bp import \
133
+ ExtractFeaturesFrom8bps
134
+ from simba.feature_extractors.feature_extractor_8bps_2_animals import \
135
+ ExtractFeaturesFrom8bps2Animals
136
+ from simba.feature_extractors.feature_extractor_9bp import \
137
+ ExtractFeaturesFrom9bps
138
+ from simba.feature_extractors.feature_extractor_14bp import \
139
+ ExtractFeaturesFrom14bps
140
+ from simba.feature_extractors.feature_extractor_16bp import \
141
+ ExtractFeaturesFrom16bps
142
+ from simba.feature_extractors.feature_extractor_user_defined import \
143
+ UserDefinedFeatureExtractor
144
+
145
+ return {
146
+ "16": ExtractFeaturesFrom16bps,
147
+ "14": ExtractFeaturesFrom14bps,
148
+ "9": ExtractFeaturesFrom9bps,
149
+ "8": {1: ExtractFeaturesFrom8bps, 2: ExtractFeaturesFrom8bps2Animals},
150
+ "7": ExtractFeaturesFrom7bps,
151
+ "4": ExtractFeaturesFrom4bps,
152
+ "user_defined": UserDefinedFeatureExtractor,
153
+ "AMBER": AmberFeatureExtractor,
154
+ }
155
+
156
+
157
+
158
+ def rgb_to_hex(color: Tuple[int, int, int]) -> str:
159
+ check_if_valid_rgb_tuple(data=color, raise_error=True, source=rgb_to_hex.__name__)
160
+ r, g, b = color
161
+ return rgb2hex((r/255, g/255, b/255), keep_alpha=False)
162
+
163
+ def get_icons_paths() -> Dict[str, Union[str, os.PathLike]]:
164
+ """
165
+ Helper to get dictionary with icons with the icon names as keys (grabbed from file-name) and their
166
+ file paths as values.
167
+ """
168
+
169
+ simba_dir = os.path.dirname(simba.__file__)
170
+ icons_dir = os.path.join(simba_dir, Paths.ICON_ASSETS.value)
171
+ icon_paths = glob.glob(icons_dir + "/*.png")
172
+ icons = {}
173
+ for icon_path in icon_paths:
174
+ _, icon_name, _ = get_fn_ext(icon_path)
175
+ icons[icon_name] = {}
176
+ icons[icon_name]["icon_path"] = icon_path
177
+ return icons
178
+
179
+ def load_simba_fonts():
180
+ """ Load fonts defined in simba.utils.enums.FontPaths into memory"""
181
+ simba_dir = os.path.dirname(simba.__file__)
182
+ font_enum = {i.name: i.value for i in FontPaths}
183
+ for k, v in font_enum.items():
184
+ pyglet.font.add_file(os.path.join(simba_dir, v))
185
+
186
+ def get_third_party_appender_file_formats() -> Dict[str, str]:
187
+ """
188
+ Helper to get dictionary that maps different third-party annotation tools with different file formats.
189
+ """
190
+
191
+ return {
192
+ "BORIS": "csv",
193
+ "ETHOVISION": "xlsx",
194
+ "OBSERVER": "xlsx",
195
+ "SOLOMON": "csv",
196
+ "DEEPETHOGRAM": "csv",
197
+ "BENTO": "annot",
198
+ }
199
+
200
+
201
+ def get_emojis() -> Dict[str, str]:
202
+ """
203
+ Helper to get dictionary of emojis with names as keys and emojis as values. Note, the same emojis are
204
+ represented differently in different python versions.
205
+ """
206
+ python_version = str(f"{sys.version_info.major}.{sys.version_info.minor}")
207
+ if python_version == "3.6":
208
+ return {"thank_you": "".join(chr(x) for x in struct.unpack(">2H", "\U0001f64f".encode("utf-16be"))),
209
+ "relaxed": "".join(chr(x) for x in struct.unpack(">2H", "\U0001F600".encode("utf-16be"))),
210
+ "error": "".join(chr(x) for x in struct.unpack(">2H", "\U0001F6A8".encode("utf-16be"))),
211
+ "complete": "".join(chr(x) for x in struct.unpack(">2H", "\U0001F680".encode("utf-16be"))),
212
+ "warning": "".join(chr(x) for x in struct.unpack(">2H", "\u2757\uFE0F".encode("utf-16be"))),
213
+ "trash": "".join(chr(x) for x in struct.unpack(">2H", "\U0001F5D1".encode("utf-16be"))),
214
+ "information": "".join(chr(x) for x in struct.unpack(">2H", "\U0001F4DD".encode("utf-16be")))} # 📝 memo
215
+
216
+ elif python_version == "3.10" or python_version == "3.9":
217
+ return {
218
+ "thank_you": "\U0001f64f".encode("utf-8", "replace").decode(),
219
+ "relaxed": "\U0001F600".encode("utf-8", "replace").decode(),
220
+ "warning": "\u2757\uFE0F".encode("utf-8", "replace").decode(),
221
+ "error": "\U0001F6A8".encode("utf-8", "replace").decode(),
222
+ "complete": "\U0001F680".encode("utf-8", "replace").decode(),
223
+ "trash": "\U0001F5D1".encode("utf-8", "replace").decode(),
224
+ "information": "\U0001F4DD".encode("utf-8", "replace").decode(), # 📝 memo
225
+ }
226
+
227
+ elif python_version == "3.7":
228
+ return {
229
+ "thank_you": "\U0001f64f".encode("utf16", errors="surrogatepass").decode(
230
+ "utf16"
231
+ ),
232
+ "relaxed": "\U0001F600".encode("utf16", errors="surrogatepass").decode(
233
+ "utf16"
234
+ ),
235
+ "error": "\U0001F6A8".encode("utf16", errors="surrogatepass").decode(
236
+ "utf16"
237
+ ),
238
+ "complete": "\U0001F680".encode("utf16", errors="surrogatepass").decode(
239
+ "utf16"
240
+ ),
241
+ "warning": "\u2757\uFE0F".encode("utf16", errors="surrogatepass").decode(
242
+ "utf16"
243
+ ),
244
+ "trash": "\U0001F5D1".encode("utf16", errors="surrogatepass").decode(
245
+ "utf16"
246
+ ),
247
+ "information": "\U0001F4DD".encode("utf16", errors="surrogatepass").decode(
248
+ "utf16"
249
+ ), # 📝 memo
250
+ }
251
+
252
+ else:
253
+ return {
254
+ "thank_you": "\U0001f64f",
255
+ "relaxed": "\U0001F600",
256
+ "error": "\U0001F6A8",
257
+ "complete": "\U0001F680",
258
+ "warning": "\u2757\uFE0F",
259
+ "trash": "\U0001F5D1",
260
+ "information": "\U0001F4DD",
261
+ # 📝 memo
262
+ }
263
+
264
+ def get_meta_data_file_headers() -> List[str]:
265
+ """
266
+ Get List of headers for SimBA classifier metadata output.
267
+
268
+ :return List[str]:
269
+ """
270
+
271
+ return [
272
+ "Classifier_name",
273
+ "RF_criterion",
274
+ "RF_max_features",
275
+ "RF_min_sample_leaf",
276
+ "RF_n_estimators",
277
+ "compute_feature_permutation_importance",
278
+ "generate_classification_report",
279
+ "generate_example_decision_tree",
280
+ "generate_features_importance_bar_graph",
281
+ "generate_features_importance_log",
282
+ "generate_precision_recall_curves",
283
+ "generate_rf_model_meta_data_file",
284
+ "generate_sklearn_learning_curves",
285
+ "learning_curve_data_splits",
286
+ "learning_curve_k_splits",
287
+ "n_feature_importance_bars",
288
+ "over_sample_ratio",
289
+ "over_sample_setting",
290
+ "train_test_size",
291
+ "train_test_split_type",
292
+ "under_sample_ratio",
293
+ "under_sample_setting",
294
+ "class_weight",
295
+ 'rf_max_depth',
296
+ ]
297
+
298
+ def get_cmaps() -> List[str]:
299
+ """
300
+ Get list of named matplotlib color palettes.
301
+ """
302
+ return [
303
+ "spring",
304
+ "summer",
305
+ "autumn",
306
+ "cool",
307
+ "Wistia",
308
+ "Pastel1",
309
+ "Set1",
310
+ "winter",
311
+ "afmhot",
312
+ "gist_heat",
313
+ "copper",
314
+ ]
315
+
316
+
317
+ def get_categorical_palettes():
318
+ return [
319
+ "Pastel1",
320
+ "Pastel2",
321
+ "Paired",
322
+ "Accent",
323
+ "Dark2",
324
+ "Set1",
325
+ "Set2",
326
+ "Set3",
327
+ "tab10",
328
+ "tab20",
329
+ "tab20b",
330
+ "tab20c",
331
+ ]
332
+
333
+
334
+ def get_color_dict() -> Dict[str, Tuple[int, int, int]]:
335
+ """
336
+ Get dict of color names as keys and RGB tuples as values
337
+ """
338
+
339
+ return {
340
+ "Grey": (220, 200, 200),
341
+ "Red": (0, 0, 255),
342
+ "Dark-red": (0, 0, 139),
343
+ "Maroon": (0, 0, 128),
344
+ "Orange": (0, 165, 255),
345
+ "Dark-orange": (0, 140, 255),
346
+ "Coral": (80, 127, 255),
347
+ "Chocolate": (30, 105, 210),
348
+ "Yellow": (0, 255, 255),
349
+ "Green": (0, 128, 0),
350
+ "Dark-grey": (105, 105, 105),
351
+ "Light-grey": (192, 192, 192),
352
+ "Pink": (178, 102, 255),
353
+ "Lime": (204, 255, 229),
354
+ "Purple": (255, 51, 153),
355
+ "Cyan": (255, 255, 102),
356
+ "White": (255, 255, 255),
357
+ "Black": (0, 0, 0),
358
+ "Darkgoldenrod": (184, 134, 11),
359
+ "Olive": (109, 113, 46),
360
+ "Seagreen": (46, 139, 87),
361
+ "Dodgerblue": (30, 144, 255),
362
+ "Springgreen": (0, 255, 127),
363
+ "Firebrick": (178, 34, 34),
364
+ "Indigo": (63, 15, 183),
365
+ }
366
+
367
+
368
+ def get_named_colors() -> List[str]:
369
+ """
370
+ Get list of named matplotlib colors.
371
+ """
372
+ return [
373
+ "red",
374
+ "pink",
375
+ "lime",
376
+ "gold",
377
+ "coral",
378
+ "lavender",
379
+ "sienna",
380
+ "tomato",
381
+ "grey",
382
+ "azure",
383
+ "crimson",
384
+ "lightgrey",
385
+ "aqua",
386
+ "plum",
387
+ "blue",
388
+ "teal",
389
+ "maroon",
390
+ "green",
391
+ "black",
392
+ "deeppink",
393
+ "darkgoldenrod",
394
+ "purple",
395
+ "olive",
396
+ "seagreen",
397
+ "dodgerblue",
398
+ "springgreen",
399
+ "firebrick",
400
+ "indigo",
401
+ "white",
402
+ ]
403
+
404
+
405
+ def create_color_palettes(no_animals: int, map_size: int) -> List[List[int]]:
406
+ """
407
+ Create list of lists of bgr colors, one for each animal. Each list is pulled from a different palette
408
+ matplotlib color map.
409
+
410
+ :param int no_animals: Number of different palette lists
411
+ :param int map_size: Number of colors in each created palette.
412
+ :return List[List[int]]: BGR colors
413
+
414
+ :example:
415
+ >>> create_color_palettes(no_animals=2, map_size=2)
416
+ >>> [[[255.0, 0.0, 255.0], [0.0, 255.0, 255.0]], [[102.0, 127.5, 0.0], [102.0, 255.0, 255.0]]]
417
+ """
418
+ colorListofList = []
419
+ cmaps = [
420
+ "spring",
421
+ "summer",
422
+ "autumn",
423
+ "cool",
424
+ "Wistia",
425
+ "Pastel1",
426
+ "Set1",
427
+ "winter",
428
+ "afmhot",
429
+ "gist_heat",
430
+ "copper",
431
+ "viridis",
432
+ "Set3",
433
+ "Set2",
434
+ "Paired",
435
+ "seismic",
436
+ "prism",
437
+ "ocean",
438
+ ]
439
+
440
+ for colormap in range(no_animals):
441
+ if hasattr(cm, "cmap_d") and colormap in cm.cmap_d:
442
+ currColorMap = cm.get_cmap(cmaps[colormap], map_size)
443
+ else:
444
+ currColorMap = cm.get_cmap("spring", map_size)
445
+ currColorList = []
446
+ for i in range(currColorMap.N):
447
+ rgb = list((currColorMap(i)[:3]))
448
+ rgb = [i * 255 for i in rgb]
449
+ rgb.reverse()
450
+ currColorList.append(rgb)
451
+ colorListofList.append(currColorList)
452
+ return colorListofList
453
+
454
+
455
+
456
+ def get_random_color_palette(n_colors: int):
457
+ """ Get a random color palette with N random colors."""
458
+ check_int(name=f'{get_random_color_palette.__name__} n_colors', value=n_colors, min_value=1, raise_error=True)
459
+ return [tuple(random.randint(0, 255) for _ in range(3)) for _ in range(n_colors)]
460
+
461
def cardinality_to_integer_lookup() -> Dict[str, int]:
    """
    Create dictionary that maps the eight compass directions to integers, counting clockwise from north (N = 0).

    :example:
    >>> data = ["N", "NE", "E", "SE", "S", "SW", "W", "NW"]
    >>> [cardinality_to_integer_lookup()[d] for d in data]
    >>> [0, 1, 2, 3, 4, 5, 6, 7]
    """
    directions = ("N", "NE", "E", "SE", "S", "SW", "W", "NW")
    return {direction: position for position, direction in enumerate(directions)}
472
+
473
+
474
def integer_to_cardinality_lookup():
    """
    Create dictionary that maps integers 0-7 to compass directions, counting clockwise from north (0 = "N").
    """
    return dict(enumerate(("N", "NE", "E", "SE", "S", "SW", "W", "NW")))
479
+
480
+
481
def percent_to_crf_lookup() -> Dict[str, int]:
    """
    Create dictionary that matches human-readable percent values to FFmpeg Constant Rate Factor (CRF)
    values that regulate video quality in CPU codecs. Higher CRF values translate to lower video quality and reduced
    file sizes.

    Keys are the strings "10" ... "100" in steps of 10. Each 10-point step up in quality lowers the CRF by 3,
    from CRF 37 at "10" down to CRF 10 at "100".
    """
    return {str(pct): 40 - 3 * (pct // 10) for pct in range(10, 101, 10)}
499
+
500
+
501
def gpu_quality_to_cpu_quality_lk():
    """
    Map the GPU preset names 'fast', 'medium' and 'slow' to integer CPU-codec quality values.

    Slower presets map to lower numbers. NOTE(review): the values look like FFmpeg CRF values, where lower means
    higher quality; confirm against callers.
    """
    return dict(fast=34, medium=23, slow=13)
505
+
506
def percent_to_qv_lk():
    """
    Create dictionary that maps human-readable quality percentages to FFmpeg quality-scale values for CPU codecs.

    Higher FFmpeg quality-scale values produce smaller, lower-quality videos. These values are used by some AVI
    codecs such as 'divx' and 'mjpeg'. Keys run from 100 down to 10 in steps of 10; each 10-point drop in
    percentage raises the value by 2, from 3 at 100% to 21 at 10%.
    """
    return {pct: 23 - 2 * (pct // 10) for pct in range(100, 0, -10)}
521
+
522
def get_ffmpeg_crossfade_methods():
    """
    Return the names of the FFmpeg ``xfade`` filter transitions offered for cross-fading videos.

    .. note::
       Each name must match a value of the ``transition`` option of FFmpeg's ``xfade`` filter exactly;
       FFmpeg rejects unknown transition names.

    :return: List of transition names.
    :rtype: List[str]
    """
    return ['fade',
            'fadeblack',
            'fadewhite',
            'distance',
            'wipeleft',
            'wiperight',
            'wipeup',
            'wipedown',
            'slideleft',
            'slideright',
            'slideup',
            'slidedown',
            'smoothleft',
            'smoothright',
            'smoothup',
            'smoothdown',
            'circlecrop',
            'rectcrop',
            'circleclose',
            'circleopen',
            'horzclose',
            'horzopen',
            'vertclose',
            'vertopen',
            'diagbl',
            'diagbr',
            'diagtl',
            'diagtr',
            'hlslice',
            'hrslice',
            'vuslice',
            'vdslice',
            'dissolve',
            'pixelize',
            'radial',
            'hblur',
            'wipetl',
            'wipetr',
            'wipebl',
            'wipebr',
            'fadegrays',
            'squeezev',
            'squeezeh',
            'zoomin',
            'hlwind',
            'hrwind',
            'vuwind',
            'vdwind',
            'coverleft',
            'coverright',
            'coverup',
            'coverdown',
            'revealleft',
            'revealright',
            'revealup',
            'revealdown']
579
+
580
+
581
def video_quality_to_preset_lookup() -> Dict[str, str]:
    """
    Map human-readable video quality settings ("Low", "Medium", "High") to FFmpeg GPU-codec presets
    ("fast", "medium", "slow").
    """
    qualities = ("Low", "Medium", "High")
    presets = ("fast", "medium", "slow")
    return dict(zip(qualities, presets))
586
+
587
+
588
def get_labelling_img_kbd_bindings() -> dict:
    """
    Return the Tkinter keyboard bindings used when labelling frames.

    Each action name maps to a dict with a human-readable ``label`` and a Tkinter event string ``kbd``.

    .. note::
       Change the ``kbd`` values to change keyboard shortcuts. Some possible Tkinter event strings:
         <Key>, <KeyPress>, <KeyRelease>: Binds to any key press or release.
         <KeyPress-A>, <Key-a>: Binds to the 'a' key press (case sensitive).
         <Up>, <Down>, <Left>, <Right>: Binds to the arrow keys.
         <Control-KeyPress-A>, <Control-a>: Binds to Ctrl + A or Ctrl + a
    """
    # Rows of (action, label, Tkinter event).
    spec = (
        ('frame+1', 'Right Arrow = +1 frame', "<Right>"),  # move forward one frame
        ('frame-1', 'Left Arrow = -1 frame', "<Left>"),  # move back one frame
        ('save', 'Ctrl + s = Save annotations file', "<Control-s>"),  # save current annotations to disk
        ('frame+1_keep_choices', 'Ctrl + a = +1 frame and keep choices', "<Control-a>"),  # forward, keep current selections
        ('frame-1_keep_choices', 'Ctrl + q = -1 frame and keep choices', "<Control-q>"),  # back, keep current selections
        ('print_annotation_statistic', 'Ctrl + p = Show annotation statistics', "<Control-p>"),
        ('last_frame', 'Ctrl + l = Last frame', "<Control-l>"),
        ('first_frame', 'Ctrl + o = First frame', "<Control-o>"),
    )
    return {action: {'label': label, 'kbd': kbd} for action, label, kbd in spec}
627
+
628
def get_labelling_video_kbd_bindings() -> dict:
    """
    Returns a dictionary of OpenCV-compatible keyboard bindings for video labeling.

    Notes:
        - Change the `kbd` values to customize keyboard shortcuts.
        - OpenCV key codes differ from Tkinter bindings (see `get_labelling_img_kbd_bindings`).
        - Use either single-character strings (e.g. 'p') or integer ASCII codes (e.g. 32 for space bar).
        - Every binding is validated before it is returned. Any ``kbd`` value that passes ``check_int`` is converted
          to ``int`` in a copy of its config; other configs are returned as-is.

    Examples:
        Remap space bar to Pause/Play:
        {'Pause/Play': {'label': 'Space = Pause/Play', 'kbd': 32}}

    :return: Dict mapping action name to ``{'label': str, 'kbd': str | int}``.
    :rtype: dict
    """

    bindings = {
        'Pause/Play': {
            'label': 'p = Pause/Play',
            'kbd': 'p'
        },
        'forward_two_frames': {
            'label': 'o = +2 frames',
            'kbd': 'o'
        },
        'forward_ten_frames': {
            'label': 'e = +10 frames',
            'kbd': 'e'
        },
        'forward_one_second': {
            'label': 'w = +1 second',
            'kbd': 'w'
        },
        'backwards_two_frames': {
            'label': 't = -2 frames',
            'kbd': 't'
        },
        'backwards_ten_frames': {
            'label': 's = -10 frames',
            'kbd': 's'
        },
        'backwards_one_second': {
            'label': 'x = -1 second',
            'kbd': 'x'
        },
        'close_window': {
            'label': 'q = Close video window',
            'kbd': 'q'
        },
    }

    # PERFORM CHECKS THAT BINDINGS ARE DEFINED CORRECTLY.
    fn_name = get_labelling_video_kbd_bindings.__name__
    check_valid_dict(x=bindings, valid_key_dtypes=(str,), valid_values_dtypes=(dict,), source=f'{fn_name} bindings')
    cleaned_bindings = {}
    for action, config in bindings.items():
        # Label each check with the action so a bad edit points at the offending binding.
        check_valid_dict(x=config, valid_key_dtypes=(str,), valid_values_dtypes=(str, int), required_keys=('label', 'kbd'), source=f'{fn_name} {action}')
        check_str(value=config['label'], allow_blank=False, raise_error=True, name=f'{fn_name} {action} label')
        kbd_val = config['kbd']
        if check_int(name=f'{action} kbd', value=kbd_val, raise_error=False)[0]:
            # Integer-like key codes (e.g. 32 for space bar) are normalised to int in a copy of the config.
            new_config = copy(config)
            new_config['kbd'] = int(kbd_val)
            cleaned_bindings[action] = new_config
        else:
            cleaned_bindings[action] = config

    return cleaned_bindings
693
+
694
+
695
def get_fonts(sort_alphabetically: bool = False):
    """
    Return a dictionary of the fonts matplotlib finds on the OS, mapping font name to font file path.

    Fonts whose name starts with '.' are skipped. If several files share a name, the last one reported by
    matplotlib wins. On Windows the 'C:' drive prefix is removed and paths are converted to forward slashes.

    :param bool sort_alphabetically: If True, order fonts by name using a case-insensitive natural sort
                                     (digit runs compared numerically).
    :return: Dict of font name -> font path.
    """
    font_dict = {}
    for font in matplotlib.font_manager.fontManager.ttflist:
        if font.name.startswith('.'):
            continue
        font_dict[font.name] = font.fname
    if not font_dict:
        NoDataFoundWarning(msg='No fonts found on disk using matplotlib.font_manager', source=get_fonts.__name__)
    if platform.system() == OS.WINDOWS.value:
        font_dict = {name: str(Path(path.replace('C:', '')).as_posix()) for name, path in font_dict.items()}
    if sort_alphabetically:
        def _natural_key(item):
            return [int(tok) if tok.isdigit() else tok.lower() for tok in re.split(r'(\d+)', item[0])]
        font_dict = dict(sorted(font_dict.items(), key=_natural_key))
    return font_dict
705
+
706
def get_log_config():
    """
    Return a logging configuration dictionary in ``logging.config.dictConfig`` schema (version 1).

    Records at INFO and above from the root logger go to a single rotating file handler. The file is opened in
    append mode, holds up to 5,000,000 bytes and keeps 5 backups. Existing loggers are not disabled.

    NOTE(review): the file handler has no ``filename`` entry; callers presumably add one before passing this to
    ``dictConfig``. The trailing 'Z' in ``datefmt`` is a literal character.

    :return: A new configuration dict on every call.
    """
    formatters = {
        "default": {
            "format": "%(asctime)s|%(name)s||%(message)s",
            "datefmt": "%Y-%m-%dT%H:%M:%SZ",
        }
    }
    handlers = {
        "file_handler": {
            "class": "logging.handlers.RotatingFileHandler",
            "formatter": "default",
            "mode": "a",
            "backupCount": 5,
            "maxBytes": 5000000,
        }
    }
    loggers = {"": {"level": "INFO", "handlers": ["file_handler"]}}
    return {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": formatters,
        "handlers": handlers,
        "loggers": loggers,
    }
728
+
729
def get_model_names():
    """
    Return the values of the ``UML.NAMES`` column of the unsupervised-model names parquet file bundled with SimBA.
    """
    simba_root = os.path.dirname(simba.__file__)
    names_path = os.path.join(simba_root, Paths.UNSUPERVISED_MODEL_NAMES.value)
    names_df = pd.read_parquet(names_path)
    return list(names_df[UML.NAMES.value])
732
+
733
def win_to_wsl_path(win_path: Union[str, os.PathLike]) -> str:
    """
    Convert a Windows path to its WSL equivalent by running ``wsl.exe wslpath``.

    :param Union[str, os.PathLike] win_path: Windows path to convert.
    :return: The converted path with surrounding whitespace stripped.
    :raises RuntimeError: If ``wslpath`` exits with a non-zero return code.
    """
    completed = subprocess.run(["wsl.exe", "wslpath", win_path], capture_output=True, text=True)
    if completed.returncode == 0:
        return completed.stdout.strip()
    raise RuntimeError(f"WSL path conversion failed: {completed.stderr}")
739
+
740
+
741
def get_available_ram():
    """
    Return total and currently available system RAM.

    :return: Dict with keys ``bytes``, ``available_bytes``, ``megabytes``, ``available_mb``, ``gigabytes`` and
             ``available_gb``. Megabytes and gigabytes are binary units (1 MB = 1024 ** 2 bytes, 1 GB = 1024 MB).
    :rtype: dict
    """
    # One snapshot: both numbers come from the same reading and the system is queried only once.
    memory = psutil.virtual_memory()
    total, available = memory.total, memory.available
    total_mb = total / (1024 ** 2)
    available_mb = available / (1024 ** 2)
    return {"bytes": total,
            "available_bytes": available,
            "megabytes": total_mb,
            "available_mb": available_mb,
            "gigabytes": total_mb / 1024,
            "available_gb": available_mb / 1024}
756
+
757
+
758
def get_current_time():
    """Return the current local wall-clock time formatted as ``HH:MM:SS`` (24-hour)."""
    now = datetime.now()
    return now.strftime("%H:%M:%S")
760
+
761
+
762
def get_display_resolution() -> Tuple[int, int]:
    """
    Helper to get main monitor / display resolution using a temporary, hidden Tk root window.

    .. note::
       May return the virtual geometry in multi-display setups. To return the resolution of each available monitor in mosaic, see :func:`simba.utils.lookups.get_monitor_info`.

    :return: Tuple of (width, height) in pixels as reported by Tk.
    """
    hidden_root = tk.Tk()
    hidden_root.withdraw()  # keep the temporary window from being shown
    screen_size = (hidden_root.winfo_screenwidth(), hidden_root.winfo_screenheight())
    hidden_root.destroy()
    return screen_size
776
+
777
+
778
def get_img_resize_info(img_size: Tuple[int, int],
                        display_resolution: Optional[Tuple[int, int]] = None,
                        max_height_ratio: float = 0.5,
                        max_width_ratio: float = 0.5,
                        min_height_ratio: float = 0.0,
                        min_width_ratio: float = 0.0) -> Tuple[int, int, float, float]:
    """
    Calculates the new dimensions and scaling factors needed to resize an image while preserving its
    aspect ratio so that it fits within a given portion of the display resolution.

    If the image is larger than the max bounds in either dimension, it is shrunk until both dimensions fit. Otherwise,
    if it is smaller than the min bounds in either dimension, it is enlarged until both dimensions meet or exceed them.
    Otherwise the size is returned unchanged.

    :param Tuple[int, int] img_size: The original size of the image as (width, height).
    :param Optional[Tuple[int, int]] display_resolution: Optional resolution of the display as (width, height). If None, then grabs the resolution of the main monitor.
    :param float max_height_ratio: The maximum allowed height of the image as a fraction of the display height (default is 0.5).
    :param float max_width_ratio: The maximum allowed width of the image as a fraction of the display width (default is 0.5).
    :param float min_height_ratio: The minimum desired height of the image as a fraction of the display height (default is 0.0, i.e. no minimum).
    :param float min_width_ratio: The minimum desired width of the image as a fraction of the display width (default is 0.0, i.e. no minimum).
    :returns: Length 4 tuple with resized width, resized height, the scale factor applied, and its inverse (to map back to the original size). When no resize is needed, returns the original size and ``1, 1``.
    :rtype: Tuple[int, int, float, float]

    :example:
    >>> get_img_resize_info(img_size=(600, 100), display_resolution=(1000, 2000))
    >>> (500, 83, 0.8333333333333334, 1.2)
    """
    if display_resolution is None:
        _, display_resolution = get_monitor_info()
    img_width, img_height = img_size[0], img_size[1]
    max_width = round(display_resolution[0] * max_width_ratio)
    max_height = round(display_resolution[1] * max_height_ratio)
    min_width = round(display_resolution[0] * min_width_ratio)
    min_height = round(display_resolution[1] * min_height_ratio)

    # Too large in either dimension: width is compared with the width bound and height with the height bound.
    if img_width > max_width or img_height > max_height:
        downscale_factor = min(max_width / img_width, max_height / img_height)
        new_width = round(img_width * downscale_factor)
        new_height = round(img_height * downscale_factor)
        return new_width, new_height, downscale_factor, 1 / downscale_factor

    # Too small in either dimension: the larger ratio ensures both dimensions meet or exceed their minimum.
    elif img_width < min_width or img_height < min_height:
        scale = max(min_width / img_width, min_height / img_height)
        return round(img_width * scale), round(img_height * scale), scale, 1 / scale

    return img_width, img_height, 1, 1
823
+
824
def is_running_in_ide():
    """
    Return a truthy value when Python is running interactively.

    Returns True when ``sys.ps1`` is defined (an interactive prompt, e.g. a REPL or IDE console). Otherwise returns
    ``sys.flags.interactive``, which is non-zero when Python was started with ``-i``.
    """
    has_prompt = hasattr(sys, 'ps1')
    return has_prompt or sys.flags.interactive
826
+
827
+
828
def get_monitor_info() -> Tuple[Dict[int, Dict[str, Union[int, bool]]], Tuple[int, int]]:
    """
    Helper to get the resolution of each monitor, plus the main monitor's resolution, as reported by pyglet.

    A monitor is flagged ``primary`` when its origin is at (0, 0). The main monitor is the first primary monitor,
    or the first listed monitor if none is flagged.

    .. note::
       To get the virtual geometry instead, see :func:`simba.utils.lookups.get_display_resolution`.

    :return: Tuple of (dict mapping monitor index to ``{'width', 'height', 'primary'}``, (main width, main height)).
    """
    screens = pyglet.canvas.get_display().get_screens()
    results = {idx: {'width': screen.width,
                     'height': screen.height,
                     'primary': screen.x == 0 and screen.y == 0}
               for idx, screen in enumerate(screens)}

    main_monitor = next((mon for mon in results.values() if mon.get('primary')), None)
    if main_monitor is None:
        main_monitor = next(iter(results.values()))

    return results, (int(main_monitor['width']), int(main_monitor['height']))
846
+
847
+
848
+
849
def get_table(data: Dict[str, Any],
              headers: Optional[Tuple[str, str]] = ("SETTING", "VALUE"),
              tablefmt: str = "grid") -> str:
    """
    Create a formatted table string from dictionary data using the tabulate library.

    Converts a dictionary into a formatted table string suitable for display
    or printing. Each key-value pair in the dictionary becomes a row in the table.

    :param Dict[str, Any] data: Dictionary containing the data to be formatted as a table. Keys become the first column, values become the second column. Must contain at least one key.
    :param Optional[Tuple[str, str]] headers: Tuple of two strings representing the column headers. Default is ("SETTING", "VALUE"). NOTE(review): despite the Optional hint, the tuple validation presumably rejects None.
    :param Literal["grid"] tablefmt: Table format style. For options, see simba.utils.enums.Formats.VALID_TABLEFMT
    :return str: Formatted table string ready for display or printing.

    :example:
    >>> data = {"fps": 30, "width": 1920, "height": 1080, "frame_count": 3000}
    >>> table = get_table(data=data, headers=("PARAMETER", "VALUE"))
    """
    fn_name = get_table.__name__
    check_valid_dict(x=data, valid_key_dtypes=(str,), min_len_keys=1, source=f'{fn_name} data')
    check_valid_tuple(x=headers, source=f'{fn_name} headers', accepted_lengths=(2,), valid_dtypes=(str,))
    check_str(name=f'{fn_name} tablefmt', value=tablefmt, options=Formats.VALID_TABLEFMT.value, raise_error=True)
    rows = [[key, value] for key, value in data.items()]
    return tabulate(rows, headers=headers, tablefmt=tablefmt)
873
+
874
+
875
+
876
+
877
def print_video_meta_data(data_path: Union[str, os.PathLike]) -> None:
    """
    Print video metadata as formatted tables to the console.

    If ``data_path`` is a file, its metadata is printed. If it is a directory, every file with an extension in
    ``Options.ALL_VIDEO_FORMAT_OPTIONS`` is read first and then each table is printed in turn.

    .. seealso::
       To get video metadata as a dictionary without printing, use :func:`simba.utils.read_write.get_video_meta_data`.
       To get video metadata as a table without printing, use :func:`simba.utils.lookups.get_table`.

    :param Union[str, os.PathLike] data_path: Path to video file or directory containing videos.
    :return: None. Video metadata is printed as formatted tables in the main console.
    :raises NoFilesFoundError: If ``data_path`` is neither an existing file nor an existing directory.
    """
    if os.path.isfile(data_path):
        video_paths = [data_path]
    elif os.path.isdir(data_path):
        video_paths = find_files_of_filetypes_in_directory(directory=data_path, extensions=Options.ALL_VIDEO_FORMAT_OPTIONS.value, raise_error=False)
    else:
        raise NoFilesFoundError(msg=f'{data_path} is not a valid file or directory path', source=print_video_meta_data.__name__)

    all_meta = [get_video_meta_data(video_path=path) for path in video_paths]
    for meta in all_meta:
        table = get_table(data=meta, headers=('VIDEO PARAMETER', 'VALUE'), tablefmt='grid')
        print(f"{table} {Defaults.STR_SPLIT_DELIMITER.value}TABLE")
902
+
903
+
904
def _parse_ffmpeg_encoders(ffmpeg_output: str) -> List[str]:
    """Extract encoder names, in listed order, from the text printed by ``ffmpeg -encoders``."""
    encoders = []
    for line in ffmpeg_output.splitlines():
        parts = line.split()
        # Encoder rows start with a flag column whose first character is V (video), A (audio) or S (subtitle).
        if len(parts) < 2 or not re.match(r'^[VAS]', parts[0]):
            continue
        # The legend at the top of the listing (e.g. " V..... = Video") shares that flag column; its second token
        # is "=", which is not an encoder name.
        if parts[1] == '=':
            continue
        encoders.append(parts[1])
    return encoders


def get_ffmpeg_encoders(raise_error: bool = True, alphabetically_sorted: bool = False) -> List[str]:
    """
    Get a list of all available FFmpeg encoders.

    :param bool raise_error: If True, raises an exception when FFmpeg is not available or the command fails. If False, returns an empty list on error. Default: True.
    :param bool alphabetically_sorted: If True, return the encoder names sorted alphabetically; otherwise in FFmpeg's listing order. Default: False.
    :return: List of encoder names (e.g., ['libx264', 'aac', 'libvpx', ...]). Returns empty list if FFmpeg is unavailable and raise_error=False.
    :rtype: List[str]

    :example:
    >>> codecs = get_ffmpeg_encoders()
    >>> print(Formats.BATCH_CODEC.value in codecs)
    """
    # NOTE(review): assumes check_ffmpeg_available returns False (rather than raising) when FFmpeg is missing and
    # raise_error=False; with raise_error=True it raises as before.
    if check_ffmpeg_available(raise_error=raise_error) is False:
        return []
    try:
        proc = subprocess.run(['ffmpeg', '-encoders'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout = proc.stdout
        if isinstance(stdout, bytes):
            stdout = stdout.decode('utf-8', errors='replace')
    except Exception as e:
        if raise_error:
            raise FFMPEGNotFoundError(msg=str(e.args))
        return []

    encoders = _parse_ffmpeg_encoders(stdout)
    return sorted(encoders) if alphabetically_sorted else encoders
939
+
940
+
941
def find_closest_string(target: str,
                        string_list: List[str],
                        case_sensitive: bool = False,
                        token_based: bool = True) -> Optional[Tuple[str, Union[int, float]]]:
    """
    Find the closest string in a list to a target string using hybrid similarity matching.

    This function uses a combination of token-based matching and Levenshtein distance to find
    the best match. Token-based matching is particularly useful for strings like body part names
    where word order may vary (e.g., "Left_ear" vs "Ear_left"). Strings are split into tokens on
    underscores, hyphens and whitespace. When several strings score equally, the earliest one in
    ``string_list`` is returned.

    :param str target: The target string to match against. Must be a non-blank string.
    :param List[str] string_list: List of strings to search through. Every entry must be a non-blank string.
    :param bool case_sensitive: If True, comparison is case-sensitive. If False (default), comparison is case-insensitive.
    :param bool token_based: If True (default), uses hybrid token-based and Levenshtein matching which handles word reordering better. If False, uses pure Levenshtein distance only.
    :return: Tuple of (closest_string, distance) or None if string_list is empty. When token_based=True, distance is a float score in [0, 1] (0.7 * token Jaccard distance + 0.3 * normalised Levenshtein distance of the sorted tokens; lower is better). When token_based=False, distance is integer edit distance.
    :rtype: Optional[Tuple[str, Union[int, float]]]

    :example:
    >>> find_closest_string("cat", ["dog", "car", "bat"])
    >>> ('car', 0.8)  # score approximately 0.8; 'car' and 'bat' tie and the earlier entry wins
    >>> find_closest_string("Left_ear", ["Ear_left", "Right_ear", "Nose"])
    >>> ('Ear_left', 0.0)
    >>> find_closest_string("cat", ["dog", "car", "bat"], token_based=False)
    >>> ('car', 1)
    >>> find_closest_string("CAT", ["dog", "car", "bat"], case_sensitive=True, token_based=False)
    >>> ('dog', 3)  # every candidate is 3 edits away when case matters; the earliest entry wins
    """

    check_str(name=f'{find_closest_string.__name__} target', value=target, allow_blank=False, raise_error=True)
    check_instance(source=f'{find_closest_string.__name__} string_list', instance=string_list, accepted_types=(list,), raise_error=True)
    if len(string_list) == 0:
        return None
    for i in string_list:
        check_str(name=f'{find_closest_string.__name__} string_list entry', value=i, allow_blank=False, raise_error=True)

    def levenshtein(s1: str, s2: str) -> int:
        """Classic two-row dynamic-programming edit distance (insert / delete / substitute, each cost 1)."""
        if s1 == s2: return 0
        if not s1: return len(s2)
        if not s2: return len(s1)
        if len(s1) > len(s2): s1, s2 = s2, s1
        prev_row = list(range(len(s1) + 1))
        for i, c2 in enumerate(s2):
            curr_row = [i + 1]
            for j, c1 in enumerate(s1):
                cost = 0 if c1 == c2 else 1
                curr_row.append(min(prev_row[j + 1] + 1, curr_row[j] + 1, prev_row[j] + cost))
            prev_row = curr_row
        return prev_row[-1]

    def tokenize(s: str) -> List[str]:
        """Split string by common delimiters and return sorted tokens"""
        tokens = re.split(r'[_\-\s]+', s)
        return sorted([t for t in tokens if t])

    def token_sort_similarity(s1: str, s2: str) -> float:
        """
        Hybrid similarity score combining token matching with character-level Levenshtein.
        Returns a score where 0.0 = perfect match, higher = worse match.
        """
        tokens1 = tokenize(s1)
        tokens2 = tokenize(s2)

        # Token set matching
        set1, set2 = set(tokens1), set(tokens2)
        intersection = len(set1 & set2)
        union = len(set1 | set2)

        if union == 0:
            token_score = 1.0
        else:
            token_score = 1.0 - (intersection / union)  # Jaccard distance

        # Compare the sorted-token strings so that word order does not affect the character-level score.
        sorted_s1 = '_'.join(tokens1)
        sorted_s2 = '_'.join(tokens2)
        max_len = max(len(sorted_s1), len(sorted_s2))
        if max_len == 0:
            lev_score = 0.0
        else:
            lev_score = levenshtein(sorted_s1, sorted_s2) / max_len

        # Weighted combination: token matching (70%) + order similarity (30%)
        return token_score * 0.7 + lev_score * 0.3

    # Prepare strings for comparison
    if not case_sensitive:
        target_cmp = target.lower()
        string_list_cmp = [s.lower() for s in string_list]
    else:
        target_cmp = target
        string_list_cmp = string_list

    # Find closest match; min() returns the first index on ties.
    if token_based:
        scores = [token_sort_similarity(target_cmp, s) for s in string_list_cmp]
    else:
        scores = [levenshtein(target_cmp, s) for s in string_list_cmp]
    closest_idx = min(range(len(scores)), key=scores.__getitem__)

    return string_list[closest_idx], scores[closest_idx]
1046
+
1047
+
1048
+
1049
def create_directionality_cords(bp_dict: dict,
                                left_ear_name: str,
                                nose_name: str,
                                right_ear_name: str) -> dict:
    """
    Helper to create a dictionary mapping animal body-parts (nose, left ear, right ear) to their X and Y coordinate
    column names for directionality analysis.

    A column is assigned to a body-part when the (case-insensitive) body-part name is a substring of the column name.
    Nose is tried first, then left ear, then right ear. The coordinate type is taken from the list the column was
    read from (``X_bps`` or ``Y_bps``), not from the letters in the column name.

    :param dict bp_dict: Dictionary with animal names as keys and body-part coordinate information as values. Expected to contain 'X_bps' and 'Y_bps' keys with lists of column names.
    :param str left_ear_name: Name of the left ear body-part to search for in coordinate column names.
    :param str nose_name: Name of the nose body-part to search for in coordinate column names.
    :param str right_ear_name: Name of the right ear body-part to search for in coordinate column names.
    :return: Nested dictionary with animal names as keys, body-part types (nose, ear_left, ear_right) as second-level keys, and coordinate types (X_bps, Y_bps) as third-level keys with corresponding column names as values.
    :rtype: dict
    :raises InvalidInputError: If any required body-part or coordinate cannot be found in the input dictionary.

    :example:
    >>> bp_dict = {'Animal_1': {'X_bps': ['Animal_1_Nose_x', 'Animal_1_Ear_left_x', 'Animal_1_Ear_right_x'], 'Y_bps': ['Animal_1_Nose_y', 'Animal_1_Ear_left_y', 'Animal_1_Ear_right_y']}}
    >>> create_directionality_cords(bp_dict=bp_dict, left_ear_name='Ear_left', nose_name='Nose', right_ear_name='Ear_right')
    >>> {'Animal_1': {'nose': {'X_bps': 'Animal_1_Nose_x', 'Y_bps': 'Animal_1_Nose_y'}, 'ear_left': {'X_bps': 'Animal_1_Ear_left_x', 'Y_bps': 'Animal_1_Ear_left_y'}, 'ear_right': {'X_bps': 'Animal_1_Ear_right_x', 'Y_bps': 'Animal_1_Ear_right_y'}}}
    """

    NOSE, EAR_LEFT, EAR_RIGHT = Keys.NOSE.value, Keys.EAR_LEFT.value, Keys.EAR_RIGHT.value
    DIMENSIONS = ("X_bps", "Y_bps")
    # Search order matters: the first body-part whose name occurs in a column name claims that column.
    search_order = ((NOSE, nose_name.lower()), (EAR_LEFT, left_ear_name.lower()), (EAR_RIGHT, right_ear_name.lower()))

    results = {}
    for animal, animal_cords in bp_dict.items():
        results[animal] = {NOSE: {}, EAR_LEFT: {}, EAR_RIGHT: {}}
        for dimension in DIMENSIONS:
            for cord in animal_cords[dimension]:
                cord_lower = cord.lower()
                for bp_key, search_name in search_order:
                    if search_name in cord_lower:
                        # Searching the column name for "x"/"y" would misfire when animal or body-part names contain
                        # those letters; the list being iterated already states the coordinate type.
                        results[animal][bp_key][dimension] = cord
                        break

    for animal_name, animal_bps in results.items():
        for bp_name, bp_values in animal_bps.items():
            if len(bp_values.keys()) == 0:
                raise InvalidInputError(msg=f'Could not detect a body-part for animal {animal_name}, body-part {bp_name} in SimBA project. Make sure the body-part configuration file at {Paths.BP_NAMES.value} lists the appropriate body-parts', source=create_directionality_cords.__name__)
            # Both coordinates are required; a body-part found in only one of X_bps / Y_bps is an error.
            for cord_key in DIMENSIONS:
                if bp_values.get(cord_key, '') == '':
                    raise InvalidInputError(msg=f'Could not detect a body-part for animal {animal_name}, body-part {bp_name} and coordinate {cord_key} in SimBA project. MAke sure the body-part configuration file at {Paths.BP_NAMES.value} lists the appropriate body-parts. Passed values: {left_ear_name, nose_name, right_ear_name}', source=create_directionality_cords.__name__)
    return results
1099
+
1100
def get_tooltips() -> Dict[str, str]:
    """
    Return the tooltips JSON bundled with SimBA (``Paths.TOOLTIPS`` under the package directory), read with
    ``raise_error=False``. Returns an empty dict if the file does not exist.
    """
    tool_tips_path = os.path.join(os.path.dirname(simba.__file__), Paths.TOOLTIPS.value)
    if os.path.isfile(tool_tips_path):
        return read_json(x=tool_tips_path, raise_error=False)
    return {}
1107
+
1108
def intermittent_palette(n: int = 10,
                         base_light: float = 0.55,
                         contrast_delta: float = 0.18,
                         seed_hue: Optional[float] = None,
                         output: Literal["rgb", "rgb255", "hex"] = "rgb",
                         rng: Optional[random.Random] = None) -> Union[List[RGBFloat], List[Tuple[int, int, int]], List[str]]:
    """
    Generate a categorical colour palette with evenly spaced hues and alternating lightness.

    .. note::
       Use to get color palette where immediate colors are distinct.

    Each colour advances the hue by the golden-ratio conjugate (mod 1). Saturation is 0.85 for every third colour
    (indices 0, 3, 6, ...) and 0.72 otherwise. The HSV value is ``base_light - contrast_delta`` at even indices and
    ``base_light + contrast_delta`` at odd indices, clamped to [0.25, 0.85].

    :param int n: Number of colours to generate. Must be greater than or equal to 1.
    :param float base_light: Midpoint HSV value (0-1) used as the baseline lightness. Default ``0.55``.
    :param float contrast_delta: Lightness offset added/subtracted per colour to improve visual separation. Default ``0.18``.
    :param Optional[float] seed_hue: Initial hue (0-1). If ``None``, a random hue is sampled. Default ``None``.
    :param str output: Output colour format. One of ``{"rgb", "rgb255", "hex"}``. Default ``"rgb"``.
    :param Optional[random.Random] rng: Optional pre-seeded RNG for reproducible random starts. Only consulted when ``seed_hue`` is None.
    :return: Colour palette in the requested format (RGB floats, RGB 0-255 integers, or hexadecimal strings).
    :rtype: Union[List[Tuple[float, float, float]], List[Tuple[int, int, int]], List[str]]
    :raises InvalidInputError: If ``rng`` is given but is not a ``random.Random`` instance.

    :example:
    >>> palette = intermittent_palette(n=6, output="hex")
    >>> palette
    >>> ['#a33f46', '#51a5df', '#b36824', '#4dbd9f', '#c749b4', '#7a9a3e']
    """

    fn_name = intermittent_palette.__name__
    check_int(name=f"{fn_name} n", value=n, min_value=1)
    check_float(name=f"{fn_name} base_light", value=base_light, min_value=0.0, max_value=1.0)
    check_float(name=f"{fn_name} contrast_delta", value=contrast_delta, min_value=0.0, max_value=1.0)
    if seed_hue is not None:
        check_float(name=f"{fn_name} seed_hue", value=seed_hue, min_value=0.0, max_value=1.0)
    check_str(name=f"{fn_name} output", value=output, options={"rgb", "rgb255", "hex"}, raise_error=True)

    if rng is not None and not isinstance(rng, random.Random):
        raise InvalidInputError(msg="rng must be an instance of random.Random.", source=fn_name)

    golden_ratio = 0.618033988749895
    generator = rng or random.Random()
    hue = generator.random() if seed_hue is None else seed_hue % 1.0
    colours: List[RGBFloat] = []

    for position in range(n):
        hue = (hue + golden_ratio) % 1.0
        saturation = 0.85 if position % 3 == 0 else 0.72
        offset = contrast_delta if position % 2 else -contrast_delta
        value = min(max(base_light + offset, 0.25), 0.85)
        colours.append(tuple(hsv_to_rgb((hue, saturation, value))))

    fmt = output.lower()
    if fmt == "rgb255":
        return [tuple(int(round(channel * 255)) for channel in colour) for colour in colours]
    if fmt != "rgb":
        return [rgb2hex(colour) for colour in colours]
    return colours
1165
+
1166
def quality_pct_to_crf(pct: int) -> int:
    """
    Convert a quality percentage (1-100) to the FFmpeg CRF of the nearest percentage in ``percent_to_crf_lookup``.

    Ties (e.g. 15, equidistant from 10 and 20) resolve to the lower percentage.

    :param int pct: Quality percentage in [1, 100].
    :return: CRF value.
    """
    check_int(name=f'{quality_pct_to_crf.__name__} pct', min_value=1, max_value=100, raise_error=True, value=pct)
    candidates = ((int(k), crf) for k, crf in percent_to_crf_lookup().items())
    _, best_crf = min(candidates, key=lambda item: abs(item[0] - pct))
    return best_crf
1171
+
1172
+
1173
def check_for_updates(time_out: int = 2):
    """
    Check for SimBA package updates by querying PyPI and comparing with the installed version.

    Fetches the latest SimBA version from PyPI and compares it with the currently installed
    version. Prints an informational message indicating whether an update is available or if
    the installation is up-to-date. Requires an active internet connection to query PyPI.

    .. note::
       Versions are compared as plain strings for equality, so any difference, including a local version newer
       than the one on PyPI, is reported as "new version available".

    :param int time_out: Timeout in seconds for the PyPI API request. Default is 2 seconds.
                         Must be at least 1 second.
    :returns: None. Prints update information to stdout via stdout_information.
    :raises SimBAPAckageVersionError: If the latest version cannot be fetched from PyPI, or if
                                      the local SimBA version cannot be determined.

    :example:
    >>> check_for_updates()
    >>> # Prints: "UP-TO-DATE. You have the latest SimBA version (1.0.0)."
    >>> # or: "NEW SimBA VERSION AVAILABLE. You have SimBA version 1.0.0. The latest version is 1.1.0..."
    """
    check_int(name=f'{check_for_updates.__name__} time_out', value=time_out, min_value=1)
    _, latest_simba_version = fetch_pip_data(pip_url=r'https://pypi.org/pypi/simba-uw-tf-dev/json', time_out=time_out)
    env_simba_version = OS.SIMBA_VERSION.value
    if latest_simba_version is None:
        raise SimBAPAckageVersionError(msg='Could not fetch latest SimBA version.', source=check_for_updates.__name__)
    if env_simba_version is None:
        raise SimBAPAckageVersionError(msg='Could not get local SimBA version.', source=check_for_updates.__name__)
    if latest_simba_version == env_simba_version:
        msg = f'UP-TO-DATE. \nYou have the latest SimBA version ({env_simba_version}).'
    else:
        msg = (f'NEW SimBA VERSION AVAILABLE. \nYou have SimBA version {env_simba_version}. \nThe latest version is {latest_simba_version}. '
               f'\nYou can update using "pip install simba-uw-tf-dev --upgrade"')
    stdout_information(msg=msg, source=check_for_updates.__name__)