sonusai 1.0.16__cp311-abi3-macosx_10_12_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150) hide show
  1. sonusai/__init__.py +170 -0
  2. sonusai/aawscd_probwrite.py +148 -0
  3. sonusai/audiofe.py +481 -0
  4. sonusai/calc_metric_spenh.py +1136 -0
  5. sonusai/config/__init__.py +0 -0
  6. sonusai/config/asr.py +21 -0
  7. sonusai/config/config.py +65 -0
  8. sonusai/config/config.yml +49 -0
  9. sonusai/config/constants.py +53 -0
  10. sonusai/config/ir.py +124 -0
  11. sonusai/config/ir_delay.py +62 -0
  12. sonusai/config/source.py +275 -0
  13. sonusai/config/spectral_masks.py +15 -0
  14. sonusai/config/truth.py +64 -0
  15. sonusai/constants.py +14 -0
  16. sonusai/data/__init__.py +0 -0
  17. sonusai/data/silero_vad_v5.1.jit +0 -0
  18. sonusai/data/silero_vad_v5.1.onnx +0 -0
  19. sonusai/data/speech_ma01_01.wav +0 -0
  20. sonusai/data/whitenoise.wav +0 -0
  21. sonusai/datatypes.py +383 -0
  22. sonusai/deprecated/gentcst.py +632 -0
  23. sonusai/deprecated/plot.py +519 -0
  24. sonusai/deprecated/tplot.py +365 -0
  25. sonusai/doc.py +52 -0
  26. sonusai/doc_strings/__init__.py +1 -0
  27. sonusai/doc_strings/doc_strings.py +531 -0
  28. sonusai/genft.py +196 -0
  29. sonusai/genmetrics.py +183 -0
  30. sonusai/genmix.py +199 -0
  31. sonusai/genmixdb.py +235 -0
  32. sonusai/ir_metric.py +551 -0
  33. sonusai/lsdb.py +141 -0
  34. sonusai/main.py +134 -0
  35. sonusai/metrics/__init__.py +43 -0
  36. sonusai/metrics/calc_audio_stats.py +42 -0
  37. sonusai/metrics/calc_class_weights.py +90 -0
  38. sonusai/metrics/calc_optimal_thresholds.py +73 -0
  39. sonusai/metrics/calc_pcm.py +45 -0
  40. sonusai/metrics/calc_pesq.py +36 -0
  41. sonusai/metrics/calc_phase_distance.py +43 -0
  42. sonusai/metrics/calc_sa_sdr.py +64 -0
  43. sonusai/metrics/calc_sample_weights.py +25 -0
  44. sonusai/metrics/calc_segsnr_f.py +82 -0
  45. sonusai/metrics/calc_speech.py +382 -0
  46. sonusai/metrics/calc_wer.py +71 -0
  47. sonusai/metrics/calc_wsdr.py +57 -0
  48. sonusai/metrics/calculate_metrics.py +395 -0
  49. sonusai/metrics/class_summary.py +74 -0
  50. sonusai/metrics/confusion_matrix_summary.py +75 -0
  51. sonusai/metrics/one_hot.py +283 -0
  52. sonusai/metrics/snr_summary.py +128 -0
  53. sonusai/metrics_summary.py +314 -0
  54. sonusai/mixture/__init__.py +15 -0
  55. sonusai/mixture/audio.py +187 -0
  56. sonusai/mixture/class_balancing.py +103 -0
  57. sonusai/mixture/constants.py +3 -0
  58. sonusai/mixture/data_io.py +173 -0
  59. sonusai/mixture/db.py +169 -0
  60. sonusai/mixture/db_datatypes.py +92 -0
  61. sonusai/mixture/effects.py +344 -0
  62. sonusai/mixture/feature.py +78 -0
  63. sonusai/mixture/generation.py +1116 -0
  64. sonusai/mixture/helpers.py +351 -0
  65. sonusai/mixture/ir_effects.py +77 -0
  66. sonusai/mixture/log_duration_and_sizes.py +23 -0
  67. sonusai/mixture/mixdb.py +1857 -0
  68. sonusai/mixture/pad_audio.py +35 -0
  69. sonusai/mixture/resample.py +7 -0
  70. sonusai/mixture/sox_effects.py +195 -0
  71. sonusai/mixture/sox_help.py +650 -0
  72. sonusai/mixture/spectral_mask.py +51 -0
  73. sonusai/mixture/truth.py +61 -0
  74. sonusai/mixture/truth_functions/__init__.py +45 -0
  75. sonusai/mixture/truth_functions/crm.py +105 -0
  76. sonusai/mixture/truth_functions/energy.py +222 -0
  77. sonusai/mixture/truth_functions/file.py +48 -0
  78. sonusai/mixture/truth_functions/metadata.py +24 -0
  79. sonusai/mixture/truth_functions/metrics.py +28 -0
  80. sonusai/mixture/truth_functions/phoneme.py +18 -0
  81. sonusai/mixture/truth_functions/sed.py +98 -0
  82. sonusai/mixture/truth_functions/target.py +142 -0
  83. sonusai/mkwav.py +135 -0
  84. sonusai/onnx_predict.py +363 -0
  85. sonusai/parse/__init__.py +0 -0
  86. sonusai/parse/expand.py +156 -0
  87. sonusai/parse/parse_source_directive.py +129 -0
  88. sonusai/parse/rand.py +214 -0
  89. sonusai/py.typed +0 -0
  90. sonusai/queries/__init__.py +0 -0
  91. sonusai/queries/queries.py +239 -0
  92. sonusai/rs.abi3.so +0 -0
  93. sonusai/rs.pyi +1 -0
  94. sonusai/rust/__init__.py +0 -0
  95. sonusai/speech/__init__.py +0 -0
  96. sonusai/speech/l2arctic.py +121 -0
  97. sonusai/speech/librispeech.py +102 -0
  98. sonusai/speech/mcgill.py +71 -0
  99. sonusai/speech/textgrid.py +89 -0
  100. sonusai/speech/timit.py +138 -0
  101. sonusai/speech/types.py +12 -0
  102. sonusai/speech/vctk.py +53 -0
  103. sonusai/speech/voxceleb.py +108 -0
  104. sonusai/utils/__init__.py +3 -0
  105. sonusai/utils/asl_p56.py +130 -0
  106. sonusai/utils/asr.py +91 -0
  107. sonusai/utils/asr_functions/__init__.py +3 -0
  108. sonusai/utils/asr_functions/aaware_whisper.py +69 -0
  109. sonusai/utils/audio_devices.py +50 -0
  110. sonusai/utils/braced_glob.py +50 -0
  111. sonusai/utils/calculate_input_shape.py +26 -0
  112. sonusai/utils/choice.py +51 -0
  113. sonusai/utils/compress.py +25 -0
  114. sonusai/utils/convert_string_to_number.py +6 -0
  115. sonusai/utils/create_timestamp.py +5 -0
  116. sonusai/utils/create_ts_name.py +14 -0
  117. sonusai/utils/dataclass_from_dict.py +27 -0
  118. sonusai/utils/db.py +16 -0
  119. sonusai/utils/docstring.py +53 -0
  120. sonusai/utils/energy_f.py +44 -0
  121. sonusai/utils/engineering_number.py +166 -0
  122. sonusai/utils/evaluate_random_rule.py +15 -0
  123. sonusai/utils/get_frames_per_batch.py +2 -0
  124. sonusai/utils/get_label_names.py +20 -0
  125. sonusai/utils/grouper.py +6 -0
  126. sonusai/utils/human_readable_size.py +7 -0
  127. sonusai/utils/keyboard_interrupt.py +12 -0
  128. sonusai/utils/load_object.py +21 -0
  129. sonusai/utils/max_text_width.py +9 -0
  130. sonusai/utils/model_utils.py +28 -0
  131. sonusai/utils/numeric_conversion.py +11 -0
  132. sonusai/utils/onnx_utils.py +155 -0
  133. sonusai/utils/parallel.py +162 -0
  134. sonusai/utils/path_info.py +7 -0
  135. sonusai/utils/print_mixture_details.py +60 -0
  136. sonusai/utils/rand.py +13 -0
  137. sonusai/utils/ranges.py +43 -0
  138. sonusai/utils/read_predict_data.py +32 -0
  139. sonusai/utils/reshape.py +154 -0
  140. sonusai/utils/seconds_to_hms.py +7 -0
  141. sonusai/utils/stacked_complex.py +82 -0
  142. sonusai/utils/stratified_shuffle_split.py +170 -0
  143. sonusai/utils/tokenized_shell_vars.py +143 -0
  144. sonusai/utils/write_audio.py +26 -0
  145. sonusai/utils/yes_or_no.py +8 -0
  146. sonusai/vars.py +47 -0
  147. sonusai-1.0.16.dist-info/METADATA +56 -0
  148. sonusai-1.0.16.dist-info/RECORD +150 -0
  149. sonusai-1.0.16.dist-info/WHEEL +4 -0
  150. sonusai-1.0.16.dist-info/entry_points.txt +3 -0
File without changes
sonusai/config/asr.py ADDED
@@ -0,0 +1,21 @@
1
def validate_asr_configs(given: dict) -> None:
    """Check every entry of 'asr_configs' in the given config

    Each named entry must contain all of the required ASR config fields. The
    entry's "engine" and its remaining fields are then passed to validate_asr().

    :param given: The dictionary of the given config
    :raises AttributeError: If 'asr_configs' or a required field of an entry is missing
    """
    from ..utils.asr import validate_asr
    from .constants import REQUIRED_ASR_CONFIGS_FIELDS

    if "asr_configs" not in given:
        raise AttributeError("config is missing required 'asr_configs'")

    for name, asr_config in given["asr_configs"].items():
        # Report the first required field that is absent (None is never a field name)
        missing = next((field for field in REQUIRED_ASR_CONFIGS_FIELDS if field not in asr_config), None)
        if missing is not None:
            raise AttributeError(f"'{name}' in asr_configs is missing required '{missing}'")

        # Everything except the engine name is forwarded as keyword arguments
        engine = asr_config["engine"]
        options = {field: asr_config[field] for field in asr_config if field != "engine"}
        validate_asr(engine, **options)
@@ -0,0 +1,65 @@
1
+ from functools import lru_cache
2
+
3
+
4
def load_yaml(name: str) -> dict:
    """Read a YAML file and return its parsed content

    :param name: File name
    :return: Dictionary of config data
    """
    from yaml import safe_load

    with open(file=name) as stream:
        return safe_load(stream)
16
+
17
+
18
@lru_cache
def default_config() -> dict:
    """Load the default SonusAI config

    The result is cached by lru_cache, so every call returns the same
    dictionary object.

    :return: Dictionary of default config data
    :raises OSError: If the default config file cannot be loaded
    """
    from .constants import DEFAULT_CONFIG

    try:
        config = load_yaml(DEFAULT_CONFIG)
    except Exception as e:
        raise OSError(f"Error loading default config: {e}") from e

    return config
30
+
31
+
32
def _update_config_from_file(filename: str, given_config: dict) -> dict:
    """Overlay the config found in a YAML file onto a copy of the given config

    Only keys already present in the given config are taken from the file;
    any other keys in the file are ignored. The given config is not modified.

    :param filename: File name
    :param given_config: Config dictionary to update
    :return: Updated config dictionary
    :raises OSError: If the YAML file cannot be loaded
    """
    from copy import deepcopy

    merged = deepcopy(given_config)

    try:
        overrides = load_yaml(filename)
    except Exception as e:
        raise OSError(f"Error loading config from {filename}: {e}") from e

    # An empty file leaves the copy of the given config untouched
    if overrides:
        merged.update({key: overrides[key] for key in merged if key in overrides})

    return merged
55
+
56
+
57
def load_config(name: str) -> dict:
    """Load the SonusAI default config and overlay 'config.yml' from the given location

    :param name: Directory containing mixture database
    :return: Dictionary of config data
    """
    import os.path

    config_file = os.path.join(name, "config.yml")
    return _update_config_from_file(filename=config_file, given_config=default_config())
@@ -0,0 +1,49 @@
1
+ # Default configuration for sonusai
2
+
3
+ # The values in this file are the defaults used if they are not specified in a
4
+ # local config.
5
+
6
+ feature: ""
7
+
8
+ class_indices: 1
9
+
10
+ num_classes: 1
11
+
12
+ class_labels: [ ]
13
+
14
+ seed: 0
15
+
16
+ class_weights_threshold: 0.5
17
+
18
+ asr_configs: { }
19
+
20
+ class_balancing: false
21
+
22
+ class_balancing_effect:
23
+ - norm -3.5
24
+ - pitch rand(-300, 300)
25
+ - tempo -s rand(0.8, 1.2)
26
+ - equalizer rand(50, 250) rand(0.2, 2.0) rand(-6, 6)
27
+ - equalizer rand(250, 1200) rand(0.2, 2.0) rand(-6, 6)
28
+ - equalizer rand(1200, 6000) rand(0.2, 2.0) rand(-6, 6)
29
+
30
+ spectral_masks:
31
+ - f_max_width: 27
32
+ f_num: 0
33
+ t_max_width: 100
34
+ t_num: 0
35
+ t_max_percent: 100
36
+
37
+ sources:
38
+ primary:
39
+ files: [ ]
40
+ noise:
41
+ files: [ ]
42
+
43
+ level_type: default
44
+
45
+ impulse_responses: [ ]
46
+
47
+ summed_source_effects: [ ]
48
+
49
+ mixture_effects: [ ]
@@ -0,0 +1,53 @@
1
+ from importlib.resources import as_file
2
+ from importlib.resources import files
3
+
4
# Top-level keys that a config must define
REQUIRED_CONFIGS: tuple[str, ...] = (
    "asr_configs",
    "class_balancing",
    "class_balancing_effect",
    "class_indices",
    "class_labels",
    "class_weights_threshold",
    "feature",
    "impulse_responses",
    "level_type",
    "mixture_effects",
    "num_classes",
    "seed",
    "sources",
    "spectral_masks",
    "summed_source_effects",
)
# There are currently no optional top-level keys
OPTIONAL_CONFIGS: tuple[str, ...] = ()
VALID_CONFIGS: tuple[str, ...] = REQUIRED_CONFIGS + OPTIONAL_CONFIGS

# Source categories that must be present under 'sources'
REQUIRED_SOURCES_CATEGORIES: tuple[str, ...] = ("primary", "noise")

# Fields of a single source config; non-primary sources need the additional fields
REQUIRED_SOURCE_CONFIG_FIELDS: tuple[str, ...] = ("effects", "files")
OPTIONAL_SOURCE_CONFIG_FIELDS: tuple[str, ...] = ("truth_configs",)
REQUIRED_NON_PRIMARY_SOURCE_CONFIG_FIELDS: tuple[str, ...] = ("mix_rules", "snrs")
VALID_PRIMARY_SOURCE_CONFIG_FIELDS: tuple[str, ...] = REQUIRED_SOURCE_CONFIG_FIELDS + OPTIONAL_SOURCE_CONFIG_FIELDS
VALID_NON_PRIMARY_SOURCE_CONFIG_FIELDS: tuple[str, ...] = (
    VALID_PRIMARY_SOURCE_CONFIG_FIELDS + REQUIRED_NON_PRIMARY_SOURCE_CONFIG_FIELDS
)

REQUIRED_TRUTH_CONFIGS: tuple[str, ...] = ("function", "stride_reduction")

REQUIRED_ASR_CONFIGS_FIELDS: tuple[str, ...] = ("engine",)

# Same field names as REQUIRED_TRUTH_CONFIGS, kept as a list
REQUIRED_TRUTH_CONFIG_FIELDS = ["function", "stride_reduction"]

# NOTE(review): DEFAULT_CONFIG is used after the as_file() context has exited;
# presumably fine when the package is installed as regular files - confirm for
# zip-style installs.
with as_file(files("sonusai.config").joinpath("config.yml")) as path:
    DEFAULT_CONFIG = str(path)
sonusai/config/ir.py ADDED
@@ -0,0 +1,124 @@
1
+ from sonusai.datatypes import ImpulseResponseFile
2
+
3
+
4
def get_ir_files(config: dict, show_progress: bool = False) -> list[ImpulseResponseFile]:
    """Build the list of impulse response files described by a config

    Every entry of config["impulse_responses"] is expanded with
    append_ir_files(), then each resulting file has its delay resolved by
    _get_ir_delay() through par_track().

    :param config: Config dictionary
    :param show_progress: Show progress bar
    :return: List of impulse response files
    """
    from ..utils.parallel import par_track
    from ..utils.parallel import track

    expanded: list[list[ImpulseResponseFile]] = []
    for entry in config["impulse_responses"]:
        ir_entry = ImpulseResponseFile(
            name=entry["name"],
            tags=entry.get("tags", []),
            delay=entry.get("delay", "auto"),
        )
        expanded.append(append_ir_files(entry=ir_entry))

    ir_files = [ir_file for group in expanded for ir_file in group]

    # Nothing to resolve; skip creating a progress tracker
    if not ir_files:
        return []

    progress = track(total=len(ir_files), disable=not show_progress)
    ir_files = par_track(_get_ir_delay, ir_files, progress=progress)
    progress.close()

    return ir_files
39
+
40
+
41
def append_ir_files(entry: ImpulseResponseFile, tokens: dict | None = None) -> list[ImpulseResponseFile]:
    """Process impulse response files list and append as needed

    The entry name is token-expanded and globbed. Each match is handled by kind:

    - directory: every item in it is processed recursively
    - ".txt" file: each non-empty line (after stripping '#' comments) is
      processed recursively; relative paths are taken relative to the .txt file
    - ".yml" file: each record under 'impulse_responses' is processed recursively
    - anything else: validated as an input file and appended

    :param entry: Impulse response file entry to append to the list
    :param tokens: Tokens used for variable expansion (updated in place)
    :return: List of impulse response files
    :raises OSError: If nothing matches the entry name or a match cannot be processed
    """
    from glob import glob
    from os import listdir
    from os.path import dirname
    from os.path import isabs
    from os.path import isdir
    from os.path import join
    from os.path import splitext

    from ..mixture.audio import validate_input_file
    from ..utils.tokenized_shell_vars import tokenized_expand
    from ..utils.tokenized_shell_vars import tokenized_replace
    from .config import load_yaml

    if tokens is None:
        tokens = {}

    in_name, new_tokens = tokenized_expand(entry.name)
    tokens.update(new_tokens)
    names = sorted(glob(in_name))
    if not names:
        raise OSError(f"Could not find {in_name}. Make sure path exists")

    ir_files: list[ImpulseResponseFile] = []
    for name in names:
        if isdir(name):
            # listdir() returns names relative to the directory itself, so join
            # with the directory (not its parent); this also works when the
            # directory was given without a trailing separator.
            for file in listdir(name):
                child = ImpulseResponseFile(join(name, file), entry.tags, entry.delay)
                ir_files.extend(append_ir_files(entry=child, tokens=tokens))
            continue

        ext = splitext(name)[1].lower()
        dir_name = dirname(name)
        try:
            if ext == ".txt":
                with open(file=name) as txt_file:
                    for line in txt_file:
                        # strip comments
                        file = line.partition("#")[0].rstrip()
                        if not file:
                            continue
                        file, new_tokens = tokenized_expand(file)
                        tokens.update(new_tokens)
                        if not isabs(file):
                            file = join(dir_name, file)
                        child = ImpulseResponseFile(file, entry.tags, entry.delay)
                        ir_files.extend(append_ir_files(entry=child, tokens=tokens))
            elif ext == ".yml":
                yml_config = load_yaml(name)

                if "impulse_responses" in yml_config:
                    for record in yml_config["impulse_responses"]:
                        # YAML records are plain dictionaries; build the entry the
                        # same way get_ir_files() does before recursing.
                        if isinstance(record, dict):
                            record = ImpulseResponseFile(
                                name=record["name"],
                                tags=record.get("tags", []),
                                delay=record.get("delay", "auto"),
                            )
                        ir_files.extend(append_ir_files(entry=record, tokens=tokens))
            else:
                validate_input_file(name)
                ir_files.append(ImpulseResponseFile(tokenized_replace(name, tokens), entry.tags, entry.delay))
        except Exception as e:
            raise OSError(f"Error processing {name}: {e}") from e

    return ir_files
111
+
112
+
113
def _get_ir_delay(entry: ImpulseResponseFile) -> ImpulseResponseFile:
    """Resolve the delay of an impulse response file entry in place

    A delay of "auto" is replaced by the value returned from get_ir_delay()
    for the entry's file; any other value is converted to int.

    :param entry: Impulse response file entry
    :return: The same entry with its delay resolved
    :raises ValueError: If the delay cannot be converted to an integer
    """
    from .ir_delay import get_ir_delay

    if entry.delay != "auto":
        try:
            entry.delay = int(entry.delay)
        except ValueError as e:
            raise ValueError(f"Invalid impulse response delay: {entry.delay}") from e
        return entry

    entry.delay = get_ir_delay(entry.name)
    return entry
@@ -0,0 +1,62 @@
1
+ import numpy as np
2
+
3
+
4
def get_ir_delay(file: str) -> int:
    """Estimate the delay of the impulse response stored in a file

    The impulse response is convolved with a seeded white Gaussian noise
    reference of ceil(0.05 * sample_rate) samples, and the shift between the
    convolved signal and the reference is estimated with tdoa() (PHAT
    weighting, interpolation factor 16).

    :param file: Impulse response file name
    :return: Estimated delay rounded to the nearest integer
    """
    from ..mixture.audio import raw_read_audio
    from ..utils.rand import seed_context

    ir, sample_rate = raw_read_audio(file)

    num_samples = int(np.ceil(0.05 * sample_rate))
    # Fixed seed so the reference noise is the same on every call
    with seed_context(42):
        noise = np.random.normal(loc=0, scale=0.2, size=num_samples)
    wgn_ref = noise.astype(np.float32)

    wgn_conv = np.convolve(ir, wgn_ref)

    shift = tdoa(wgn_conv, wgn_ref, interp=16, phat=True)
    return int(np.round(shift))
16
+
17
+
18
def tdoa(signal: np.ndarray, reference: np.ndarray, interp: int = 1, phat: bool = False, fs: int | float = 1) -> float:
    """Estimate the shift of signal relative to reference using generalized cross-correlation

    The delay is the position of the largest-magnitude value returned by
    correlate(), converted from interpolated lag index to a lag and divided by fs.

    :param signal: The array whose tdoa is measured
    :param reference: The reference array
    :param interp: Interpolation factor for the output array
    :param phat: Apply the PHAT weighting
    :param fs: The sampling frequency of the input arrays
    :return: The estimated delay between the two arrays
    """
    # correlate() places (len(reference) - 1) negative lags ahead of lag zero
    zero_lag = reference.shape[0] - 1

    xcorr = correlate(signal, reference, interp=interp, phat=phat)
    peak = np.argmax(np.abs(xcorr))

    return float((peak / interp - zero_lag) / fs)
35
+
36
+
37
def correlate(x1: np.ndarray, x2: np.ndarray, interp: int = 1, phat: bool = False) -> np.ndarray:
    """Compute the cross-correlation between x1 and x2

    The correlation is computed in the frequency domain over
    n = len(x1) + len(x2) - 1 points and inverse-transformed to n * interp
    points. The result is ordered with the interp * (len(x2) - 1) negative-lag
    values first, followed by the interp * len(x1) non-negative-lag values.

    :param x1: Input array 1
    :param x2: Input array 2
    :param interp: Interpolation factor for the output array
    :param phat: Apply the PHAT weighting
    :return: The cross-correlation between the two arrays
    """
    n_x1 = x1.shape[0]
    n_x2 = x2.shape[0]

    n = n_x1 + n_x2 - 1

    fft1 = np.fft.rfft(x1, n=n)
    fft2 = np.fft.rfft(x2, n=n)

    if phat:
        # Normalize each spectrum by its magnitude; the small epsilon avoids
        # division by zero in empty bins.
        eps1 = np.mean(np.abs(fft1)) * 1e-10
        fft1 /= np.abs(fft1) + eps1
        eps2 = np.mean(np.abs(fft2)) * 1e-10
        fft2 /= np.abs(fft2) + eps2

    out = np.fft.irfft(fft1 * np.conj(fft2), n=int(n * interp))

    # Negative lags sit at the end of the circular result. Split with an
    # explicit start index instead of a negative slice: when x2 has a single
    # sample there are no negative lags and out[-0:] would select the whole array.
    split = interp * n_x1
    return np.concatenate([out[split:], out[:split]])
@@ -0,0 +1,275 @@
1
+ from sonusai.datatypes import SourceFile
2
+
3
+
4
def update_sources(given: dict) -> dict:
    """Validate and update fields in given 'sources'

    Checks that the required source categories and per-source fields are
    present and that no unknown fields are used, then replaces every 'files'
    value that names another source category with that category's file list.

    :param given: The dictionary of the given config
    :return: The given config with source references resolved (modified in place)
    :raises AttributeError: If a required category or field is missing, or a field is invalid
    :raises TypeError: If 'files' is neither a list nor a reference to another source
    :raises RuntimeError: If references are still unresolved after 100 passes
    """
    from .constants import REQUIRED_NON_PRIMARY_SOURCE_CONFIG_FIELDS
    from .constants import REQUIRED_SOURCE_CONFIG_FIELDS
    from .constants import REQUIRED_SOURCES_CATEGORIES
    from .constants import VALID_NON_PRIMARY_SOURCE_CONFIG_FIELDS
    from .constants import VALID_PRIMARY_SOURCE_CONFIG_FIELDS

    sources = given["sources"]

    for category in REQUIRED_SOURCES_CATEGORIES:
        if category not in sources:
            raise AttributeError(f"config sources is missing required '{category}'")

    for category, source in sources.items():
        for key in REQUIRED_SOURCE_CONFIG_FIELDS:
            if key not in source:
                raise AttributeError(f"config source '{category}' is missing required '{key}'")

        if category == "primary":
            for key in source:
                if key not in VALID_PRIMARY_SOURCE_CONFIG_FIELDS:
                    nice_list = "\n".join([f" {item}" for item in VALID_PRIMARY_SOURCE_CONFIG_FIELDS])
                    raise AttributeError(
                        f"Invalid source '{category}' config parameter: '{key}'.\nValid sources config parameters are:\n{nice_list}"
                    )
        else:
            for key in REQUIRED_NON_PRIMARY_SOURCE_CONFIG_FIELDS:
                if key not in source:
                    raise AttributeError(f"config source '{category}' is missing required '{key}'")

            for key in source:
                if key not in VALID_NON_PRIMARY_SOURCE_CONFIG_FIELDS:
                    nice_list = "\n".join([f" {item}" for item in VALID_NON_PRIMARY_SOURCE_CONFIG_FIELDS])
                    raise AttributeError(
                        f"Invalid source '{category}' config parameter: '{key}'.\nValid source config parameters are:\n{nice_list}"
                    )

        files = source["files"]

        # A string must name a different, existing source category
        if isinstance(files, str) and files in sources and files != category:
            continue

        if isinstance(files, list):
            continue

        raise TypeError(
            f"'files' parameter of config source '{category}' is not a list or a reference to another source"
        )

    def has_unresolved_reference() -> bool:
        return any(isinstance(source["files"], str) for source in sources.values())

    # Resolve references pass by pass; a chain of references may need several passes
    count = 0
    while has_unresolved_reference() and count < 100:
        count += 1
        for category, source in sources.items():
            files = source["files"]
            if isinstance(files, str):
                given["sources"][category]["files"] = sources[files]["files"]

    # Only complain if something is actually still unresolved, not merely
    # because the pass limit was reached.
    if has_unresolved_reference():
        raise RuntimeError("Check config sources for circular references")

    return given
69
+
70
+
71
def get_source_files(config: dict, show_progress: bool = False) -> list[SourceFile]:
    """Get the list of source files from a config

    Every file entry of every source category is expanded with
    append_source_files(), the sample count of each resulting file is filled
    in by _get_num_samples() through par_track(), and the class indices of
    each file are checked against config["num_classes"].

    :param config: Config dictionary
    :param show_progress: Show progress bar
    :return: List of source files
    :raises TypeError: If 'sources' is not a dictionary of dictionaries
    :raises AttributeError: If 'primary' is missing in 'sources'
    :raises ValueError: If a class index is negative or greater than num_classes
    """
    from ..utils.parallel import par_track
    from ..utils.parallel import track

    sources = config["sources"]
    # Both the container and each of its values must be dictionaries
    if not isinstance(sources, dict) or not all(isinstance(source, dict) for source in sources.values()):
        raise TypeError("'sources' must be a dictionary of dictionaries")

    if "primary" not in sources:
        raise AttributeError("'primary' is missing in 'sources'")

    class_indices = config["class_indices"]
    if not isinstance(class_indices, list):
        class_indices = [class_indices]

    level_type = config["level_type"]

    source_files: list[SourceFile] = []
    for category, source in sources.items():
        truth_configs = source.get("truth_configs", {})
        for entry in source["files"]:
            source_files.extend(
                append_source_files(
                    category=category,
                    entry=entry,
                    class_indices=class_indices,
                    truth_configs=truth_configs,
                    level_type=level_type,
                )
            )

    progress = track(total=len(source_files), disable=not show_progress)
    source_files = par_track(_get_num_samples, source_files, progress=progress)
    progress.close()

    num_classes = config["num_classes"]
    for source_file in source_files:
        if any(class_index < 0 for class_index in source_file.class_indices):
            raise ValueError("class indices must contain only positive elements")

        if any(class_index > num_classes for class_index in source_file.class_indices):
            raise ValueError(f"class index elements must not be greater than {num_classes}")

    return source_files
126
+
127
+
128
def append_source_files(
    category: str,
    entry: dict,
    class_indices: list[int],
    truth_configs: dict,
    level_type: str,
    tokens: dict | None = None,
) -> list[SourceFile]:
    """Process source files list and append as needed

    The entry name is token-expanded and globbed. Each match is handled by kind:

    - directory: every item in it is processed recursively
    - ".txt" file: each non-empty line (after stripping '#' comments) is
      processed recursively; relative paths are taken relative to the .txt file
    - anything else: validated as an input file and appended as a SourceFile

    The entry may override 'class_indices', 'level_type' and individual
    'truth_configs' that are defined in the given truth configs.

    :param category: Source file category name
    :param entry: Source file entry to append to the list
    :param class_indices: Class indices
    :param truth_configs: Truth configs
    :param level_type: Level type
    :param tokens: Tokens used for variable expansion (updated in place)
    :return: List of source files
    :raises TypeError: If entry is not a dictionary
    :raises KeyError: If entry has no 'name'
    :raises AttributeError: If entry overrides a truth config that is not defined
    :raises OSError: If nothing matches the entry name or a match cannot be processed
    """
    from copy import deepcopy
    from glob import glob
    from os import listdir
    from os.path import dirname
    from os.path import isabs
    from os.path import isdir
    from os.path import join
    from os.path import splitext

    from ..datatypes import TruthConfig
    from ..mixture.audio import validate_input_file
    from ..utils.dataclass_from_dict import dataclass_from_dict
    from ..utils.tokenized_shell_vars import tokenized_expand
    from ..utils.tokenized_shell_vars import tokenized_replace
    from .constants import REQUIRED_TRUTH_CONFIG_FIELDS

    if tokens is None:
        tokens = {}

    # Work on a copy so per-entry overrides do not leak into the caller's configs
    truth_configs_merged = deepcopy(truth_configs)

    if not isinstance(entry, dict):
        raise TypeError("'entry' must be a dictionary")

    in_name = entry.get("name")
    if in_name is None:
        raise KeyError("Source file list contained record without name")

    class_indices = entry.get("class_indices", class_indices)
    if not isinstance(class_indices, list):
        class_indices = [class_indices]

    truth_configs_override = entry.get("truth_configs", {})
    for key in truth_configs_override:
        if key not in truth_configs:
            raise AttributeError(
                f"Truth config '{key}' override specified for {entry['name']} is not defined at top level"
            )
        truth_configs_merged[key] |= truth_configs_override[key]

    level_type = entry.get("level_type", level_type)

    in_name, new_tokens = tokenized_expand(in_name)
    tokens.update(new_tokens)
    names = sorted(glob(in_name))
    if not names:
        raise OSError(f"Could not find {in_name}. Make sure path exists")

    source_files: list[SourceFile] = []
    for name in names:
        if isdir(name):
            # listdir() returns names relative to the directory itself, so join
            # with the directory (not its parent); this also works when the
            # directory was given without a trailing separator.
            for file in listdir(name):
                source_files.extend(
                    append_source_files(
                        category=category,
                        entry={"name": join(name, file)},
                        class_indices=class_indices,
                        truth_configs=truth_configs_merged,
                        level_type=level_type,
                        tokens=tokens,
                    )
                )
            continue

        ext = splitext(name)[1].lower()
        dir_name = dirname(name)
        try:
            if ext == ".txt":
                with open(file=name) as txt_file:
                    for line in txt_file:
                        # strip comments
                        child = line.partition("#")[0].rstrip()
                        if not child:
                            continue
                        child, new_tokens = tokenized_expand(child)
                        tokens.update(new_tokens)
                        if not isabs(child):
                            child = join(dir_name, child)
                        source_files.extend(
                            append_source_files(
                                category=category,
                                entry={"name": child},
                                class_indices=class_indices,
                                truth_configs=truth_configs_merged,
                                level_type=level_type,
                                tokens=tokens,
                            )
                        )
            else:
                validate_input_file(name)
                source_file = SourceFile(
                    category=category,
                    name=tokenized_replace(name, tokens),
                    samples=0,
                    class_indices=class_indices,
                    level_type=level_type,
                    truth_configs={},
                )
                if len(truth_configs_merged) > 0:
                    for tc_key, tc_value in truth_configs_merged.items():
                        # Required fields become TruthConfig fields; everything
                        # else is kept under 'config'
                        config = deepcopy(tc_value)
                        truth_config: dict = {}
                        for key in REQUIRED_TRUTH_CONFIG_FIELDS:
                            truth_config[key] = config[key]
                            del config[key]
                        truth_config["config"] = config
                        source_file.truth_configs[tc_key] = dataclass_from_dict(TruthConfig, truth_config)
                    # NOTE(review): this writes to truth_configs_merged after the
                    # TruthConfig above was built from a copy, so it does not change
                    # this file's truth config; it is visible to later matches in
                    # this call. Kept as-is - confirm the intent.
                    for tc_key in source_file.truth_configs:
                        if (
                            "function" in truth_configs_merged[tc_key]
                            and truth_configs_merged[tc_key]["function"] == "file"
                        ):
                            truth_configs_merged[tc_key]["file"] = splitext(source_file.name)[0] + ".h5"
                source_files.append(source_file)
        except Exception as e:
            raise OSError(f"Error processing {name}: {e}") from e

    return source_files
267
+
268
+
269
def _get_num_samples(entry: SourceFile) -> SourceFile:
    """Fill in the sample count of a source file entry in place

    :param entry: Source file entry
    :return: The same entry with 'samples' set from get_num_samples()
    """
    from ..mixture.audio import get_num_samples as count_samples

    num_samples = count_samples(entry.name)
    entry.samples = num_samples
    return entry
274
+
275
+
@@ -0,0 +1,15 @@
1
+ from sonusai.datatypes import SpectralMask
2
+
3
+
4
def get_spectral_masks(config: dict) -> list[SpectralMask]:
    """Build the list of spectral masks described by a config

    :param config: Config dictionary
    :return: List of spectral masks
    :raises ValueError: If 'spectral_masks' is missing or cannot be converted
    """
    from ..utils.dataclass_from_dict import list_dataclass_from_dict

    try:
        masks = list_dataclass_from_dict(list[SpectralMask], config["spectral_masks"])
    except Exception as e:
        raise ValueError(f"Error in spectral_masks: {e}") from e

    return masks