sonusai 1.0.16__cp311-abi3-macosx_10_12_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +170 -0
- sonusai/aawscd_probwrite.py +148 -0
- sonusai/audiofe.py +481 -0
- sonusai/calc_metric_spenh.py +1136 -0
- sonusai/config/__init__.py +0 -0
- sonusai/config/asr.py +21 -0
- sonusai/config/config.py +65 -0
- sonusai/config/config.yml +49 -0
- sonusai/config/constants.py +53 -0
- sonusai/config/ir.py +124 -0
- sonusai/config/ir_delay.py +62 -0
- sonusai/config/source.py +275 -0
- sonusai/config/spectral_masks.py +15 -0
- sonusai/config/truth.py +64 -0
- sonusai/constants.py +14 -0
- sonusai/data/__init__.py +0 -0
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/data/speech_ma01_01.wav +0 -0
- sonusai/data/whitenoise.wav +0 -0
- sonusai/datatypes.py +383 -0
- sonusai/deprecated/gentcst.py +632 -0
- sonusai/deprecated/plot.py +519 -0
- sonusai/deprecated/tplot.py +365 -0
- sonusai/doc.py +52 -0
- sonusai/doc_strings/__init__.py +1 -0
- sonusai/doc_strings/doc_strings.py +531 -0
- sonusai/genft.py +196 -0
- sonusai/genmetrics.py +183 -0
- sonusai/genmix.py +199 -0
- sonusai/genmixdb.py +235 -0
- sonusai/ir_metric.py +551 -0
- sonusai/lsdb.py +141 -0
- sonusai/main.py +134 -0
- sonusai/metrics/__init__.py +43 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_class_weights.py +90 -0
- sonusai/metrics/calc_optimal_thresholds.py +73 -0
- sonusai/metrics/calc_pcm.py +45 -0
- sonusai/metrics/calc_pesq.py +36 -0
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_sa_sdr.py +64 -0
- sonusai/metrics/calc_sample_weights.py +25 -0
- sonusai/metrics/calc_segsnr_f.py +82 -0
- sonusai/metrics/calc_speech.py +382 -0
- sonusai/metrics/calc_wer.py +71 -0
- sonusai/metrics/calc_wsdr.py +57 -0
- sonusai/metrics/calculate_metrics.py +395 -0
- sonusai/metrics/class_summary.py +74 -0
- sonusai/metrics/confusion_matrix_summary.py +75 -0
- sonusai/metrics/one_hot.py +283 -0
- sonusai/metrics/snr_summary.py +128 -0
- sonusai/metrics_summary.py +314 -0
- sonusai/mixture/__init__.py +15 -0
- sonusai/mixture/audio.py +187 -0
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/constants.py +3 -0
- sonusai/mixture/data_io.py +173 -0
- sonusai/mixture/db.py +169 -0
- sonusai/mixture/db_datatypes.py +92 -0
- sonusai/mixture/effects.py +344 -0
- sonusai/mixture/feature.py +78 -0
- sonusai/mixture/generation.py +1116 -0
- sonusai/mixture/helpers.py +351 -0
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +23 -0
- sonusai/mixture/mixdb.py +1857 -0
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +51 -0
- sonusai/mixture/truth.py +61 -0
- sonusai/mixture/truth_functions/__init__.py +45 -0
- sonusai/mixture/truth_functions/crm.py +105 -0
- sonusai/mixture/truth_functions/energy.py +222 -0
- sonusai/mixture/truth_functions/file.py +48 -0
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +18 -0
- sonusai/mixture/truth_functions/sed.py +98 -0
- sonusai/mixture/truth_functions/target.py +142 -0
- sonusai/mkwav.py +135 -0
- sonusai/onnx_predict.py +363 -0
- sonusai/parse/__init__.py +0 -0
- sonusai/parse/expand.py +156 -0
- sonusai/parse/parse_source_directive.py +129 -0
- sonusai/parse/rand.py +214 -0
- sonusai/py.typed +0 -0
- sonusai/queries/__init__.py +0 -0
- sonusai/queries/queries.py +239 -0
- sonusai/rs.abi3.so +0 -0
- sonusai/rs.pyi +1 -0
- sonusai/rust/__init__.py +0 -0
- sonusai/speech/__init__.py +0 -0
- sonusai/speech/l2arctic.py +121 -0
- sonusai/speech/librispeech.py +102 -0
- sonusai/speech/mcgill.py +71 -0
- sonusai/speech/textgrid.py +89 -0
- sonusai/speech/timit.py +138 -0
- sonusai/speech/types.py +12 -0
- sonusai/speech/vctk.py +53 -0
- sonusai/speech/voxceleb.py +108 -0
- sonusai/utils/__init__.py +3 -0
- sonusai/utils/asl_p56.py +130 -0
- sonusai/utils/asr.py +91 -0
- sonusai/utils/asr_functions/__init__.py +3 -0
- sonusai/utils/asr_functions/aaware_whisper.py +69 -0
- sonusai/utils/audio_devices.py +50 -0
- sonusai/utils/braced_glob.py +50 -0
- sonusai/utils/calculate_input_shape.py +26 -0
- sonusai/utils/choice.py +51 -0
- sonusai/utils/compress.py +25 -0
- sonusai/utils/convert_string_to_number.py +6 -0
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/create_ts_name.py +14 -0
- sonusai/utils/dataclass_from_dict.py +27 -0
- sonusai/utils/db.py +16 -0
- sonusai/utils/docstring.py +53 -0
- sonusai/utils/energy_f.py +44 -0
- sonusai/utils/engineering_number.py +166 -0
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/get_frames_per_batch.py +2 -0
- sonusai/utils/get_label_names.py +20 -0
- sonusai/utils/grouper.py +6 -0
- sonusai/utils/human_readable_size.py +7 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/load_object.py +21 -0
- sonusai/utils/max_text_width.py +9 -0
- sonusai/utils/model_utils.py +28 -0
- sonusai/utils/numeric_conversion.py +11 -0
- sonusai/utils/onnx_utils.py +155 -0
- sonusai/utils/parallel.py +162 -0
- sonusai/utils/path_info.py +7 -0
- sonusai/utils/print_mixture_details.py +60 -0
- sonusai/utils/rand.py +13 -0
- sonusai/utils/ranges.py +43 -0
- sonusai/utils/read_predict_data.py +32 -0
- sonusai/utils/reshape.py +154 -0
- sonusai/utils/seconds_to_hms.py +7 -0
- sonusai/utils/stacked_complex.py +82 -0
- sonusai/utils/stratified_shuffle_split.py +170 -0
- sonusai/utils/tokenized_shell_vars.py +143 -0
- sonusai/utils/write_audio.py +26 -0
- sonusai/utils/yes_or_no.py +8 -0
- sonusai/vars.py +47 -0
- sonusai-1.0.16.dist-info/METADATA +56 -0
- sonusai-1.0.16.dist-info/RECORD +150 -0
- sonusai-1.0.16.dist-info/WHEEL +4 -0
- sonusai-1.0.16.dist-info/entry_points.txt +3 -0
File without changes
|
sonusai/config/asr.py
ADDED
@@ -0,0 +1,21 @@
|
|
1
|
+
def validate_asr_configs(given: dict) -> None:
    """Check every named entry in the 'asr_configs' section of a config

    Each entry must contain all REQUIRED_ASR_CONFIGS_FIELDS. Its 'engine' value and
    all of its remaining keys (as keyword arguments) are then passed to validate_asr.

    :param given: The dictionary of the given config
    :raises AttributeError: If 'asr_configs' or a required field of an entry is missing
    """
    from ..utils.asr import validate_asr
    from .constants import REQUIRED_ASR_CONFIGS_FIELDS

    if "asr_configs" not in given:
        raise AttributeError("config is missing required 'asr_configs'")

    for config_name, settings in given["asr_configs"].items():
        # Report the first missing field in REQUIRED_ASR_CONFIGS_FIELDS order
        missing = [field for field in REQUIRED_ASR_CONFIGS_FIELDS if field not in settings]
        if missing:
            raise AttributeError(f"'{config_name}' in asr_configs is missing required '{missing[0]}'")

        engine = settings["engine"]
        options = {option: settings[option] for option in settings if option != "engine"}
        validate_asr(engine, **options)
|
sonusai/config/config.py
ADDED
@@ -0,0 +1,65 @@
|
|
1
|
+
from functools import lru_cache
|
2
|
+
|
3
|
+
|
4
|
+
def load_yaml(name: str) -> dict:
    """Load a YAML file

    The file is read as UTF-8 (the YAML default encoding) regardless of the
    platform locale, and parsed with yaml.safe_load.

    :param name: File name
    :return: Dictionary of config data (an empty dictionary if the file is empty)
    """
    import yaml

    with open(file=name, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # safe_load() returns None for an empty document; honor the declared dict return type
    if config is None:
        return {}

    return config
|
16
|
+
|
17
|
+
|
18
|
+
@lru_cache
def default_config() -> dict:
    """Return the default SonusAI config bundled with the package

    The result is memoized by lru_cache, so every call returns the same dictionary
    object; callers that intend to modify it should copy it first.

    :return: Dictionary of default config data
    :raises OSError: If the bundled default config cannot be loaded
    """
    from .constants import DEFAULT_CONFIG

    try:
        config = load_yaml(DEFAULT_CONFIG)
    except Exception as e:
        raise OSError(f"Error loading default config: {e}") from e

    return config
|
30
|
+
|
31
|
+
|
32
|
+
def _update_config_from_file(filename: str, given_config: dict) -> dict:
    """Return a copy of a config with values overridden from a YAML file

    Only keys that already exist in given_config are taken from the file; any other
    keys in the file are ignored. given_config itself is left unmodified (deep copy).

    :param filename: File name
    :param given_config: Config dictionary used as the base
    :return: Updated config dictionary
    :raises OSError: If the file cannot be loaded
    """
    from copy import deepcopy

    merged = deepcopy(given_config)

    try:
        overrides = load_yaml(filename)
    except Exception as e:
        raise OSError(f"Error loading config from {filename}: {e}") from e

    # An empty file leaves the base config as-is
    if not overrides:
        return merged

    merged.update({key: overrides[key] for key in merged if key in overrides})
    return merged
|
55
|
+
|
56
|
+
|
57
|
+
def load_config(name: str) -> dict:
    """Load the default SonusAI config overlaid with the 'config.yml' in a directory

    :param name: Directory containing mixture database (and its 'config.yml')
    :return: Dictionary of config data
    """
    from os.path import join

    config_file = join(name, "config.yml")
    return _update_config_from_file(filename=config_file, given_config=default_config())
|
@@ -0,0 +1,49 @@
|
|
1
|
+
# Default configuration for sonusai
|
2
|
+
|
3
|
+
# The values in this file are the defaults used if they are not specified in a
|
4
|
+
# local config.
|
5
|
+
|
6
|
+
feature: ""
|
7
|
+
|
8
|
+
class_indices: 1
|
9
|
+
|
10
|
+
num_classes: 1
|
11
|
+
|
12
|
+
class_labels: [ ]
|
13
|
+
|
14
|
+
seed: 0
|
15
|
+
|
16
|
+
class_weights_threshold: 0.5
|
17
|
+
|
18
|
+
asr_configs: { }
|
19
|
+
|
20
|
+
class_balancing: false
|
21
|
+
|
22
|
+
class_balancing_effect:
|
23
|
+
- norm -3.5
|
24
|
+
- pitch rand(-300, 300)
|
25
|
+
- tempo -s rand(0.8, 1.2)
|
26
|
+
- equalizer rand(50, 250) rand(0.2, 2.0) rand(-6, 6)
|
27
|
+
- equalizer rand(250, 1200) rand(0.2, 2.0) rand(-6, 6)
|
28
|
+
- equalizer rand(1200, 6000) rand(0.2, 2.0) rand(-6, 6)
|
29
|
+
|
30
|
+
spectral_masks:
|
31
|
+
- f_max_width: 27
|
32
|
+
f_num: 0
|
33
|
+
t_max_width: 100
|
34
|
+
t_num: 0
|
35
|
+
t_max_percent: 100
|
36
|
+
|
37
|
+
sources:
|
38
|
+
primary:
|
39
|
+
files: [ ]
|
40
|
+
noise:
|
41
|
+
files: [ ]
|
42
|
+
|
43
|
+
level_type: default
|
44
|
+
|
45
|
+
impulse_responses: [ ]
|
46
|
+
|
47
|
+
summed_source_effects: [ ]
|
48
|
+
|
49
|
+
mixture_effects: [ ]
|
@@ -0,0 +1,53 @@
|
|
1
|
+
from importlib.resources import as_file
|
2
|
+
from importlib.resources import files
|
3
|
+
|
4
|
+
# Top-level keys every config must contain
REQUIRED_CONFIGS: tuple[str, ...] = (
    "asr_configs",
    "class_balancing",
    "class_balancing_effect",
    "class_indices",
    "class_labels",
    "class_weights_threshold",
    "feature",
    "impulse_responses",
    "level_type",
    "mixture_effects",
    "num_classes",
    "seed",
    "sources",
    "spectral_masks",
    "summed_source_effects",
)
OPTIONAL_CONFIGS: tuple[str, ...] = ()
VALID_CONFIGS: tuple[str, ...] = REQUIRED_CONFIGS + OPTIONAL_CONFIGS

# Source categories that must appear under 'sources'
REQUIRED_SOURCES_CATEGORIES: tuple[str, ...] = (
    "primary",
    "noise",
)

# Fields allowed/required in each source entry
REQUIRED_SOURCE_CONFIG_FIELDS: tuple[str, ...] = (
    "effects",
    "files",
)
OPTIONAL_SOURCE_CONFIG_FIELDS: tuple[str, ...] = ("truth_configs",)
REQUIRED_NON_PRIMARY_SOURCE_CONFIG_FIELDS: tuple[str, ...] = (
    "mix_rules",
    "snrs",
)
VALID_PRIMARY_SOURCE_CONFIG_FIELDS: tuple[str, ...] = REQUIRED_SOURCE_CONFIG_FIELDS + OPTIONAL_SOURCE_CONFIG_FIELDS
VALID_NON_PRIMARY_SOURCE_CONFIG_FIELDS: tuple[str, ...] = (
    VALID_PRIMARY_SOURCE_CONFIG_FIELDS + REQUIRED_NON_PRIMARY_SOURCE_CONFIG_FIELDS
)

# Fields every truth config must contain
REQUIRED_TRUTH_CONFIGS: tuple[str, ...] = (
    "function",
    "stride_reduction",
)

# Fields every entry in 'asr_configs' must contain
REQUIRED_ASR_CONFIGS_FIELDS: tuple[str, ...] = ("engine",)

# Same field names as REQUIRED_TRUTH_CONFIGS; derived from it so the two cannot drift
# apart, and kept as a list for existing callers.
REQUIRED_TRUTH_CONFIG_FIELDS: list[str] = list(REQUIRED_TRUTH_CONFIGS)

# Path to the bundled default config.yml.
# NOTE(review): as_file() may extract the resource to a temporary file that is removed
# when this 'with' block exits (e.g. for a zip-imported package), in which case this
# path would no longer exist; it is only reliable for a regular on-disk install — confirm.
with as_file(files("sonusai.config").joinpath("config.yml")) as path:
    DEFAULT_CONFIG = str(path)
|
sonusai/config/ir.py
ADDED
@@ -0,0 +1,124 @@
|
|
1
|
+
from sonusai.datatypes import ImpulseResponseFile
|
2
|
+
|
3
|
+
|
4
|
+
def get_ir_files(config: dict, show_progress: bool = False) -> list[ImpulseResponseFile]:
    """Expand the config's 'impulse_responses' into resolved impulse response files

    Each record (with 'name' and optional 'tags' / 'delay', defaulting to [] and
    "auto") is expanded with append_ir_files; the delay of every resulting file is
    then resolved in parallel by _get_ir_delay.

    :param config: Config dictionary
    :param show_progress: Show progress bar
    :return: List of impulse response files
    """
    from ..utils.parallel import par_track
    from ..utils.parallel import track

    ir_files: list[ImpulseResponseFile] = []
    for record in config["impulse_responses"]:
        seed_entry = ImpulseResponseFile(
            name=record["name"],
            tags=record.get("tags", []),
            delay=record.get("delay", "auto"),
        )
        ir_files.extend(append_ir_files(entry=seed_entry))

    if not ir_files:
        return []

    progress = track(total=len(ir_files), disable=not show_progress)
    resolved = par_track(_get_ir_delay, ir_files, progress=progress)
    progress.close()

    return resolved
|
39
|
+
|
40
|
+
|
41
|
+
def append_ir_files(entry: ImpulseResponseFile, tokens: dict | None = None) -> list[ImpulseResponseFile]:
    """Expand an impulse response entry into a flat list of impulse response files

    The entry name is variable-expanded and globbed. Each match is handled as:
      - a directory: every entry inside it is processed recursively, in sorted order
      - a '.txt' file: each non-comment line names another entry; relative names are
        resolved against the directory containing the '.txt' file
      - a '.yml' file: each record in its 'impulse_responses' list is processed;
        missing 'tags' / 'delay' fall back to the values of this entry
      - anything else: validated as an audio input file and emitted

    :param entry: Impulse response file entry to expand
    :param tokens: Tokens used for variable expansion (updated in place)
    :return: List of impulse response files
    :raises OSError: If the name matches nothing, or processing a match fails
    """
    from glob import glob
    from os import listdir
    from os.path import dirname
    from os.path import isabs
    from os.path import isdir
    from os.path import join
    from os.path import splitext

    from ..mixture.audio import validate_input_file
    from ..utils.tokenized_shell_vars import tokenized_expand
    from ..utils.tokenized_shell_vars import tokenized_replace
    from .config import load_yaml

    if tokens is None:
        tokens = {}

    in_name, new_tokens = tokenized_expand(entry.name)
    tokens.update(new_tokens)
    names = sorted(glob(in_name))
    if not names:
        raise OSError(f"Could not find {in_name}. Make sure path exists")

    ir_files: list[ImpulseResponseFile] = []
    for name in names:
        ext = splitext(name)[1].lower()
        dir_name = dirname(name)

        if isdir(name):
            # listdir() returns bare entry names, so join them to the directory itself
            # (not its parent); sort so the resulting order is reproducible.
            for file in sorted(listdir(name)):
                child = ImpulseResponseFile(join(name, file), entry.tags, entry.delay)
                ir_files.extend(append_ir_files(entry=child, tokens=tokens))
            continue

        try:
            if ext == ".txt":
                with open(file=name) as txt_file:
                    for line in txt_file:
                        # strip comments
                        file = line.partition("#")[0].rstrip()
                        if file:
                            file, new_tokens = tokenized_expand(file)
                            tokens.update(new_tokens)
                            if not isabs(file):
                                file = join(dir_name, file)
                            child = ImpulseResponseFile(file, entry.tags, entry.delay)
                            ir_files.extend(append_ir_files(entry=child, tokens=tokens))
            elif ext == ".yml":
                yml_config = load_yaml(name)
                if yml_config and "impulse_responses" in yml_config:
                    for record in yml_config["impulse_responses"]:
                        # YAML records are plain dicts; convert them to the entry type
                        # this function expects.
                        child = ImpulseResponseFile(
                            name=record["name"],
                            tags=record.get("tags", entry.tags),
                            delay=record.get("delay", entry.delay),
                        )
                        ir_files.extend(append_ir_files(entry=child, tokens=tokens))
            else:
                validate_input_file(name)
                ir_files.append(ImpulseResponseFile(tokenized_replace(name, tokens), entry.tags, entry.delay))
        except Exception as e:
            raise OSError(f"Error processing {name}: {e}") from e

    return ir_files
|
111
|
+
|
112
|
+
|
113
|
+
def _get_ir_delay(entry: ImpulseResponseFile) -> ImpulseResponseFile:
    """Resolve entry.delay to an integer in place and return the same entry

    "auto" is replaced by the value estimated from the audio by get_ir_delay. Any
    other value is converted with int().

    :param entry: Impulse response file entry (modified in place)
    :return: The same entry with an integer delay
    :raises ValueError: If a non-"auto" delay cannot be converted to int
    """
    from .ir_delay import get_ir_delay

    if entry.delay != "auto":
        try:
            entry.delay = int(entry.delay)
        except ValueError as e:
            raise ValueError(f"Invalid impulse response delay: {entry.delay}") from e
        return entry

    entry.delay = get_ir_delay(entry.name)
    return entry
|
@@ -0,0 +1,62 @@
|
|
1
|
+
import numpy as np
|
2
|
+
|
3
|
+
|
4
|
+
def get_ir_delay(file: str) -> int:
    """Estimate the delay of an impulse response file

    The impulse response is convolved with a reproducible white Gaussian noise burst
    lasting ceil(0.05 * sample_rate) samples. The shift of the result relative to the
    burst is found with PHAT-weighted, 16x-interpolated cross-correlation (tdoa) and
    rounded to the nearest integer.

    :param file: Impulse response audio file name
    :return: Estimated delay, rounded to an integer
    """
    from ..mixture.audio import raw_read_audio
    from ..utils.rand import seed_context

    ir, sample_rate = raw_read_audio(file)

    burst_length = int(np.ceil(0.05 * sample_rate))
    # seed_context(42) is presumably what makes the burst identical on every call — confirm
    with seed_context(42):
        burst = np.random.normal(loc=0, scale=0.2, size=burst_length).astype(np.float32)

    response = np.convolve(ir, burst)
    estimated = tdoa(response, burst, interp=16, phat=True)

    return int(np.round(estimated))
|
16
|
+
|
17
|
+
|
18
|
+
def tdoa(signal: np.ndarray, reference: np.ndarray, interp: int = 1, phat: bool = False, fs: int | float = 1) -> float:
    """Estimate how far signal is shifted relative to reference

    Uses the (generalized) cross-correlation from correlate(). The peak magnitude
    position is converted from interpolated-sample index to a lag, then divided by fs.

    :param signal: The array whose tdoa is measured
    :param reference: The reference array
    :param interp: Interpolation factor for the output array
    :param phat: Apply the PHAT weighting
    :param fs: The sampling frequency of the input arrays (1 gives the delay in samples)
    :return: The estimated delay between the two arrays
    """
    cross_corr = correlate(signal, reference, interp=interp, phat=phat)

    peak = np.argmax(np.abs(cross_corr))
    # Index 0 of the correlation corresponds to a lag of -(len(reference) - 1)
    lag = peak / interp - (reference.shape[0] - 1)

    return float(lag / fs)
|
35
|
+
|
36
|
+
|
37
|
+
def correlate(x1: np.ndarray, x2: np.ndarray, interp: int = 1, phat: bool = False) -> np.ndarray:
    """Compute the full cross-correlation between x1 and x2 via the FFT

    The result is ordered by lag, from -(len(x2) - 1) up to len(x1) - 1, with
    ``interp`` points per lag step, so it has ``interp * (len(x1) + len(x2) - 1)``
    elements. With ``interp=1`` and ``phat=False`` it equals
    ``np.correlate(x1, x2, mode="full")`` for real inputs.

    :param x1: Input array 1
    :param x2: Input array 2
    :param interp: Interpolation factor for the output array
    :param phat: Apply the PHAT weighting
    :return: The cross-correlation between the two arrays
    """
    n_x1 = x1.shape[0]
    n_x2 = x2.shape[0]

    # Zero-pad to the full linear-correlation length so the circular result does not wrap
    n = n_x1 + n_x2 - 1

    fft1 = np.fft.rfft(x1, n=n)
    fft2 = np.fft.rfft(x2, n=n)

    if phat:
        fft1 = _phat_weight(fft1)
        fft2 = _phat_weight(fft2)

    # A longer inverse transform zero-pads the spectrum, interpolating the correlation
    out = np.fft.irfft(fft1 * np.conj(fft2), n=int(n * interp))

    # Negative lags sit at the end of the circular result. Index from the front, because
    # out[-0:] would select the whole array when there are no negative lags (len(x2) == 1).
    n_negative = interp * (n_x2 - 1)
    return np.concatenate([out[out.shape[0] - n_negative :], out[: (interp * n_x1)]])


def _phat_weight(spectrum: np.ndarray) -> np.ndarray:
    """Normalize a spectrum to (nearly) unit magnitude; all-zero spectra stay zero instead of NaN"""
    magnitude = np.abs(spectrum)
    # Small relative floor as before; it is 0 only when the whole spectrum is 0
    denominator = magnitude + np.mean(magnitude) * 1e-10
    return np.divide(spectrum, denominator, out=np.zeros_like(spectrum), where=denominator > 0)
|
sonusai/config/source.py
ADDED
@@ -0,0 +1,275 @@
|
|
1
|
+
from sonusai.datatypes import SourceFile
|
2
|
+
|
3
|
+
|
4
|
+
def update_sources(given: dict) -> dict:
    """Validate and update fields in given 'sources'

    Every category in REQUIRED_SOURCES_CATEGORIES must be present. Each source must
    contain REQUIRED_SOURCE_CONFIG_FIELDS; non-primary sources must also contain
    REQUIRED_NON_PRIMARY_SOURCE_CONFIG_FIELDS. No source may use fields outside the
    matching VALID_* list. A source's 'files' is either a list or the name of another
    source. Such references are resolved in place, so that afterwards every 'files'
    is a list (the same list object as the referenced source's).

    :param given: The dictionary of the given config (its 'sources' are modified in place)
    :return: The given config
    :raises AttributeError: If a required category/field is missing or a field is invalid
    :raises TypeError: If a source's 'files' is neither a list nor a reference to another source
    :raises RuntimeError: If source references form a cycle
    """
    from .constants import REQUIRED_NON_PRIMARY_SOURCE_CONFIG_FIELDS
    from .constants import REQUIRED_SOURCE_CONFIG_FIELDS
    from .constants import REQUIRED_SOURCES_CATEGORIES
    from .constants import VALID_NON_PRIMARY_SOURCE_CONFIG_FIELDS
    from .constants import VALID_PRIMARY_SOURCE_CONFIG_FIELDS

    sources = given["sources"]

    for category in REQUIRED_SOURCES_CATEGORIES:
        if category not in sources:
            raise AttributeError(f"config sources is missing required '{category}'")

    for category, source in sources.items():
        for key in REQUIRED_SOURCE_CONFIG_FIELDS:
            if key not in source:
                raise AttributeError(f"config source '{category}' is missing required '{key}'")

        if category == "primary":
            for key in source:
                if key not in VALID_PRIMARY_SOURCE_CONFIG_FIELDS:
                    nice_list = "\n".join([f" {item}" for item in VALID_PRIMARY_SOURCE_CONFIG_FIELDS])
                    raise AttributeError(
                        f"Invalid source '{category}' config parameter: '{key}'.\nValid sources config parameters are:\n{nice_list}"
                    )
        else:
            for key in REQUIRED_NON_PRIMARY_SOURCE_CONFIG_FIELDS:
                if key not in source:
                    raise AttributeError(f"config source '{category}' is missing required '{key}'")

            for key in source:
                if key not in VALID_NON_PRIMARY_SOURCE_CONFIG_FIELDS:
                    nice_list = "\n".join([f" {item}" for item in VALID_NON_PRIMARY_SOURCE_CONFIG_FIELDS])
                    raise AttributeError(
                        f"Invalid source '{category}' config parameter: '{key}'.\nValid source config parameters are:\n{nice_list}"
                    )

        files = source["files"]

        if isinstance(files, list):
            continue

        # A string must name a different, existing source whose files are reused
        if isinstance(files, str) and files in sources and files != category:
            continue

        raise TypeError(
            f"'files' parameter of config source '{category}' is not a list or a reference to another source"
        )

    # Follow each reference chain until it reaches a list. Every name on a chain is
    # known to exist (checked above), so revisiting a name is a genuine cycle; no
    # arbitrary iteration cap is needed.
    for category, source in sources.items():
        visited = [category]
        files = source["files"]
        while isinstance(files, str):
            if files in visited:
                raise RuntimeError("Check config sources for circular references")
            visited.append(files)
            files = sources[files]["files"]
        source["files"] = files

    return given
|
69
|
+
|
70
|
+
|
71
|
+
def get_source_files(config: dict, show_progress: bool = False) -> list[SourceFile]:
    """Get the list of source files from a config

    Every entry of every source category is expanded with append_source_files. The
    sample count of each resulting file is filled in (in parallel), and all class
    indices are range-checked against 'num_classes'.

    :param config: Config dictionary
    :param show_progress: Show progress bar
    :return: List of source files, ordered by category and then by entry
    :raises TypeError: If 'sources' is not a dictionary of dictionaries
    :raises AttributeError: If 'primary' is missing from 'sources'
    :raises ValueError: If a class index is negative or greater than 'num_classes'
    """
    from ..utils.parallel import par_track
    from ..utils.parallel import track

    sources = config["sources"]
    # Both the container and every source inside it must be dictionaries
    if not isinstance(sources, dict) or not all(isinstance(source, dict) for source in sources.values()):
        raise TypeError("'sources' must be a dictionary of dictionaries")

    if "primary" not in sources:
        raise AttributeError("'primary' is missing in 'sources'")

    class_indices = config["class_indices"]
    if not isinstance(class_indices, list):
        class_indices = [class_indices]

    level_type = config["level_type"]

    source_files: list[SourceFile] = []
    for category, source in sources.items():
        # truth_configs maps a name to its settings, so the empty default is a dict
        truth_configs = source.get("truth_configs", {})
        for entry in source["files"]:
            source_files.extend(
                append_source_files(
                    category=category,
                    entry=entry,
                    class_indices=class_indices,
                    truth_configs=truth_configs,
                    level_type=level_type,
                )
            )

    progress = track(total=len(source_files), disable=not show_progress)
    source_files = par_track(_get_num_samples, source_files, progress=progress)
    progress.close()

    num_classes = config["num_classes"]
    for source_file in source_files:
        # NOTE(review): 0 passes this check even though the message says "positive";
        # confirm whether class indices are meant to be 1-based.
        if any(class_index < 0 for class_index in source_file.class_indices):
            raise ValueError("class indices must contain only positive elements")

        if any(class_index > num_classes for class_index in source_file.class_indices):
            raise ValueError(f"class index elements must not be greater than {num_classes}")

    return source_files
|
126
|
+
|
127
|
+
|
128
|
+
def append_source_files(
    category: str,
    entry: dict,
    class_indices: list[int],
    truth_configs: dict,
    level_type: str,
    tokens: dict | None = None,
) -> list[SourceFile]:
    """Expand a source file entry into a flat list of source files

    The entry name is variable-expanded and globbed. Each match is handled as:
      - a directory: every entry inside it is processed recursively, in sorted order
      - a '.txt' file: each non-comment line names another entry; relative names are
        resolved against the directory containing the '.txt' file
      - anything else: validated as an audio input file and emitted as a SourceFile

    :param category: Source file category name
    :param entry: Source file entry; must contain 'name' and may override
                  'class_indices', 'truth_configs' and 'level_type'
    :param class_indices: Class indices used when the entry does not override them
    :param truth_configs: Truth configs (name -> settings) for this category
    :param level_type: Level type used when the entry does not override it
    :param tokens: Tokens used for variable expansion (updated in place)
    :return: List of source files
    :raises TypeError: If entry is not a dictionary
    :raises KeyError: If entry has no name
    :raises AttributeError: If a truth config override names an undefined truth config
    :raises OSError: If the name matches nothing, or processing a match fails
    """
    from copy import deepcopy
    from glob import glob
    from os import listdir
    from os.path import dirname
    from os.path import isabs
    from os.path import isdir
    from os.path import join
    from os.path import splitext

    from ..mixture.audio import validate_input_file
    from ..utils.tokenized_shell_vars import tokenized_expand
    from ..utils.tokenized_shell_vars import tokenized_replace

    if tokens is None:
        tokens = {}

    truth_configs_merged = deepcopy(truth_configs)

    if not isinstance(entry, dict):
        raise TypeError("'entry' must be a dictionary")

    in_name = entry.get("name")
    if in_name is None:
        raise KeyError("Source file list contained record without name")

    class_indices = entry.get("class_indices", class_indices)
    if not isinstance(class_indices, list):
        class_indices = [class_indices]

    # Per-entry overrides are merged into copies of top-level truth configs of the same name
    for key, override in entry.get("truth_configs", {}).items():
        if key not in truth_configs:
            raise AttributeError(
                f"Truth config '{key}' override specified for {entry['name']} is not defined at top level"
            )
        truth_configs_merged[key] |= override

    level_type = entry.get("level_type", level_type)

    in_name, new_tokens = tokenized_expand(in_name)
    tokens.update(new_tokens)
    names = sorted(glob(in_name))
    if not names:
        raise OSError(f"Could not find {in_name}. Make sure path exists")

    def _expand_child(child_name: str) -> list[SourceFile]:
        # Children inherit this entry's resolved settings and share the token map
        return append_source_files(
            category=category,
            entry={"name": child_name},
            class_indices=class_indices,
            truth_configs=truth_configs_merged,
            level_type=level_type,
            tokens=tokens,
        )

    source_files: list[SourceFile] = []
    for name in names:
        ext = splitext(name)[1].lower()
        dir_name = dirname(name)

        if isdir(name):
            # listdir() returns bare entry names, so join them to the directory itself
            # (not its parent); sort so the resulting order is reproducible.
            for file in sorted(listdir(name)):
                source_files.extend(_expand_child(join(name, file)))
            continue

        try:
            if ext == ".txt":
                with open(file=name) as txt_file:
                    for line in txt_file:
                        # strip comments
                        child = line.partition("#")[0].rstrip()
                        if child:
                            child, new_tokens = tokenized_expand(child)
                            tokens.update(new_tokens)
                            if not isabs(child):
                                child = join(dir_name, child)
                            source_files.extend(_expand_child(child))
            else:
                validate_input_file(name)
                source_name = tokenized_replace(name, tokens)
                source_files.append(
                    SourceFile(
                        category=category,
                        name=source_name,
                        samples=0,
                        class_indices=class_indices,
                        level_type=level_type,
                        truth_configs=_build_truth_configs(truth_configs_merged, source_name),
                    )
                )
        except Exception as e:
            raise OSError(f"Error processing {name}: {e}") from e

    return source_files


def _build_truth_configs(truth_configs: dict, source_name: str) -> dict:
    """Convert raw truth config dicts into TruthConfig objects for one source file

    Fields in REQUIRED_TRUTH_CONFIG_FIELDS become TruthConfig fields; all other keys go
    into its 'config'. For the "file" truth function, 'file' is set to this source's
    name with a '.h5' extension before the TruthConfig is built, so each source file
    gets its own path. The input dicts are not modified.
    """
    from copy import deepcopy
    from os.path import splitext

    from ..datatypes import TruthConfig
    from ..utils.dataclass_from_dict import dataclass_from_dict
    from .constants import REQUIRED_TRUTH_CONFIG_FIELDS

    result: dict = {}
    for tc_key, tc_value in truth_configs.items():
        config = deepcopy(tc_value)
        if config.get("function") == "file":
            config["file"] = splitext(source_name)[0] + ".h5"
        truth_config: dict = {key: config.pop(key) for key in REQUIRED_TRUTH_CONFIG_FIELDS}
        truth_config["config"] = config
        result[tc_key] = dataclass_from_dict(TruthConfig, truth_config)
    return result
|
267
|
+
|
268
|
+
|
269
|
+
def _get_num_samples(entry: SourceFile) -> SourceFile:
    """Fill in entry.samples from its audio file and return the same (mutated) entry

    :param entry: Source file whose 'samples' is updated in place
    :return: The same entry
    """
    from ..mixture.audio import get_num_samples

    sample_count = get_num_samples(entry.name)
    entry.samples = sample_count
    return entry
|
274
|
+
|
275
|
+
|
@@ -0,0 +1,15 @@
|
|
1
|
+
from sonusai.datatypes import SpectralMask
|
2
|
+
|
3
|
+
|
4
|
+
def get_spectral_masks(config: dict) -> list[SpectralMask]:
    """Build SpectralMask objects from the config's 'spectral_masks' list

    :param config: Config dictionary
    :return: List of spectral masks
    :raises ValueError: If 'spectral_masks' is missing or cannot be converted
    """
    from ..utils.dataclass_from_dict import list_dataclass_from_dict

    # The lookup is inside the try so a missing key is also reported as ValueError
    try:
        masks = list_dataclass_from_dict(list[SpectralMask], config["spectral_masks"])
    except Exception as e:
        raise ValueError(f"Error in spectral_masks: {e}") from e

    return masks
|