sonusai-1.0.16-cp311-abi3-macosx_10_12_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +170 -0
- sonusai/aawscd_probwrite.py +148 -0
- sonusai/audiofe.py +481 -0
- sonusai/calc_metric_spenh.py +1136 -0
- sonusai/config/__init__.py +0 -0
- sonusai/config/asr.py +21 -0
- sonusai/config/config.py +65 -0
- sonusai/config/config.yml +49 -0
- sonusai/config/constants.py +53 -0
- sonusai/config/ir.py +124 -0
- sonusai/config/ir_delay.py +62 -0
- sonusai/config/source.py +275 -0
- sonusai/config/spectral_masks.py +15 -0
- sonusai/config/truth.py +64 -0
- sonusai/constants.py +14 -0
- sonusai/data/__init__.py +0 -0
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/data/speech_ma01_01.wav +0 -0
- sonusai/data/whitenoise.wav +0 -0
- sonusai/datatypes.py +383 -0
- sonusai/deprecated/gentcst.py +632 -0
- sonusai/deprecated/plot.py +519 -0
- sonusai/deprecated/tplot.py +365 -0
- sonusai/doc.py +52 -0
- sonusai/doc_strings/__init__.py +1 -0
- sonusai/doc_strings/doc_strings.py +531 -0
- sonusai/genft.py +196 -0
- sonusai/genmetrics.py +183 -0
- sonusai/genmix.py +199 -0
- sonusai/genmixdb.py +235 -0
- sonusai/ir_metric.py +551 -0
- sonusai/lsdb.py +141 -0
- sonusai/main.py +134 -0
- sonusai/metrics/__init__.py +43 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_class_weights.py +90 -0
- sonusai/metrics/calc_optimal_thresholds.py +73 -0
- sonusai/metrics/calc_pcm.py +45 -0
- sonusai/metrics/calc_pesq.py +36 -0
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_sa_sdr.py +64 -0
- sonusai/metrics/calc_sample_weights.py +25 -0
- sonusai/metrics/calc_segsnr_f.py +82 -0
- sonusai/metrics/calc_speech.py +382 -0
- sonusai/metrics/calc_wer.py +71 -0
- sonusai/metrics/calc_wsdr.py +57 -0
- sonusai/metrics/calculate_metrics.py +395 -0
- sonusai/metrics/class_summary.py +74 -0
- sonusai/metrics/confusion_matrix_summary.py +75 -0
- sonusai/metrics/one_hot.py +283 -0
- sonusai/metrics/snr_summary.py +128 -0
- sonusai/metrics_summary.py +314 -0
- sonusai/mixture/__init__.py +15 -0
- sonusai/mixture/audio.py +187 -0
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/constants.py +3 -0
- sonusai/mixture/data_io.py +173 -0
- sonusai/mixture/db.py +169 -0
- sonusai/mixture/db_datatypes.py +92 -0
- sonusai/mixture/effects.py +344 -0
- sonusai/mixture/feature.py +78 -0
- sonusai/mixture/generation.py +1116 -0
- sonusai/mixture/helpers.py +351 -0
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +23 -0
- sonusai/mixture/mixdb.py +1857 -0
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +51 -0
- sonusai/mixture/truth.py +61 -0
- sonusai/mixture/truth_functions/__init__.py +45 -0
- sonusai/mixture/truth_functions/crm.py +105 -0
- sonusai/mixture/truth_functions/energy.py +222 -0
- sonusai/mixture/truth_functions/file.py +48 -0
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +18 -0
- sonusai/mixture/truth_functions/sed.py +98 -0
- sonusai/mixture/truth_functions/target.py +142 -0
- sonusai/mkwav.py +135 -0
- sonusai/onnx_predict.py +363 -0
- sonusai/parse/__init__.py +0 -0
- sonusai/parse/expand.py +156 -0
- sonusai/parse/parse_source_directive.py +129 -0
- sonusai/parse/rand.py +214 -0
- sonusai/py.typed +0 -0
- sonusai/queries/__init__.py +0 -0
- sonusai/queries/queries.py +239 -0
- sonusai/rs.abi3.so +0 -0
- sonusai/rs.pyi +1 -0
- sonusai/rust/__init__.py +0 -0
- sonusai/speech/__init__.py +0 -0
- sonusai/speech/l2arctic.py +121 -0
- sonusai/speech/librispeech.py +102 -0
- sonusai/speech/mcgill.py +71 -0
- sonusai/speech/textgrid.py +89 -0
- sonusai/speech/timit.py +138 -0
- sonusai/speech/types.py +12 -0
- sonusai/speech/vctk.py +53 -0
- sonusai/speech/voxceleb.py +108 -0
- sonusai/utils/__init__.py +3 -0
- sonusai/utils/asl_p56.py +130 -0
- sonusai/utils/asr.py +91 -0
- sonusai/utils/asr_functions/__init__.py +3 -0
- sonusai/utils/asr_functions/aaware_whisper.py +69 -0
- sonusai/utils/audio_devices.py +50 -0
- sonusai/utils/braced_glob.py +50 -0
- sonusai/utils/calculate_input_shape.py +26 -0
- sonusai/utils/choice.py +51 -0
- sonusai/utils/compress.py +25 -0
- sonusai/utils/convert_string_to_number.py +6 -0
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/create_ts_name.py +14 -0
- sonusai/utils/dataclass_from_dict.py +27 -0
- sonusai/utils/db.py +16 -0
- sonusai/utils/docstring.py +53 -0
- sonusai/utils/energy_f.py +44 -0
- sonusai/utils/engineering_number.py +166 -0
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/get_frames_per_batch.py +2 -0
- sonusai/utils/get_label_names.py +20 -0
- sonusai/utils/grouper.py +6 -0
- sonusai/utils/human_readable_size.py +7 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/load_object.py +21 -0
- sonusai/utils/max_text_width.py +9 -0
- sonusai/utils/model_utils.py +28 -0
- sonusai/utils/numeric_conversion.py +11 -0
- sonusai/utils/onnx_utils.py +155 -0
- sonusai/utils/parallel.py +162 -0
- sonusai/utils/path_info.py +7 -0
- sonusai/utils/print_mixture_details.py +60 -0
- sonusai/utils/rand.py +13 -0
- sonusai/utils/ranges.py +43 -0
- sonusai/utils/read_predict_data.py +32 -0
- sonusai/utils/reshape.py +154 -0
- sonusai/utils/seconds_to_hms.py +7 -0
- sonusai/utils/stacked_complex.py +82 -0
- sonusai/utils/stratified_shuffle_split.py +170 -0
- sonusai/utils/tokenized_shell_vars.py +143 -0
- sonusai/utils/write_audio.py +26 -0
- sonusai/utils/yes_or_no.py +8 -0
- sonusai/vars.py +47 -0
- sonusai-1.0.16.dist-info/METADATA +56 -0
- sonusai-1.0.16.dist-info/RECORD +150 -0
- sonusai-1.0.16.dist-info/WHEEL +4 -0
- sonusai-1.0.16.dist-info/entry_points.txt +3 -0
sonusai/main.py
ADDED
@@ -0,0 +1,134 @@
"""sonusai

usage: sonusai [--version] [--help] <command> [<args>...]

The sonusai commands are:
    <This information is automatically generated.>

Aaware Sound and Voice Machine Learning Framework. See 'sonusai help <command>'
for more information on a specific command.

"""

import sys
from importlib import import_module
from pkgutil import iter_modules

from docopt import docopt

from sonusai import BASEDIR
from sonusai import __version__ as sai_version
from sonusai import commands_list
from sonusai import logger
from sonusai.utils.docstring import add_commands_to_docstring
from sonusai.utils.docstring import trim_docstring


def discover_plugins():
    plugins = {}
    plugin_docstrings = []
    for _, name, _ in iter_modules():
        if name.startswith("sonusai_") and not name.startswith("sonusai_asr_"):
            module = import_module(name)
            plugins[name] = {
                "commands": commands_list(module.commands_doc),
                "basedir": module.BASEDIR,
            }
            plugin_docstrings.append(module.commands_doc)
    return plugins, plugin_docstrings


def execute_command_direct(command: str, argv: list[str], basedir: str) -> None:
    """Execute a command by importing and running it directly."""
    try:
        # Add the command directory to the Python path temporarily
        if basedir not in sys.path:
            sys.path.insert(0, basedir)

        # Import the command module
        command_module = import_module(command)

        # Set up sys.argv as the command module expects it
        original_argv = sys.argv
        sys.argv = [command, *argv]

        try:
            # Execute the main function if it exists
            if hasattr(command_module, "main"):
                command_module.main()
            else:
                logger.error(f"Command module {command} has no main() function")
                sys.exit(1)
        finally:
            # Restore original sys.argv
            sys.argv = original_argv

    except ImportError as err:
        logger.error(f"Failed to import command module {command}: {err}")
        sys.exit(1)
    except Exception as err:
        logger.error(f"Error executing command {command}: {err}")
        sys.exit(1)


def handle_help_command_direct(argv: list[str], base_commands: list[str], plugins: dict) -> None:
    """Handle the help command by executing modules directly."""
    if not argv:
        # Show the main help by re-running with -h
        sys.argv = ["sonusai", "-h"]
        main()
        return

    help_target = argv[0]

    if help_target in base_commands:
        execute_command_direct(help_target, ["-h"], BASEDIR)
    else:
        for data in plugins.values():
            if help_target in data["commands"]:
                execute_command_direct(help_target, ["-h"], data["basedir"])
                return

        logger.error(f"{help_target} is not a SonusAI command. See 'sonusai help'.")
        sys.exit(1)


def main() -> None:
    plugins, plugin_docstrings = discover_plugins()
    updated_docstring = add_commands_to_docstring(__doc__, plugin_docstrings)
    args = docopt(
        trim_docstring(updated_docstring),
        version=sai_version,
        options_first=True,
    )

    command = args["<command>"]
    argv = args["<args>"]
    base_commands = commands_list()

    if command == "help":
        handle_help_command_direct(argv, base_commands, plugins)
        return

    if command in base_commands:
        execute_command_direct(command, argv, BASEDIR)
        return

    for data in plugins.values():
        if command in data["commands"]:
            execute_command_direct(command, argv, data["basedir"])
            return

    logger.error(f"{command} is not a SonusAI command. See 'sonusai help'.")
    sys.exit(1)


if __name__ == "__main__":
    from sonusai import exception_handler
    from sonusai.utils.keyboard_interrupt import register_keyboard_interrupt

    register_keyboard_interrupt()
    try:
        main()
    except Exception as e:
        exception_handler(e)
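
The plugin discovery above keys off top-level modules named `sonusai_*` that expose `commands_doc` and `BASEDIR` attributes. A minimal, hypothetical plugin module illustrating that contract follows; the module name and the exact `commands_doc` format consumed by `commands_list()` are assumptions, since neither is shown in this diff.

    # sonusai_example/__init__.py  (hypothetical plugin; names are illustrative)
    from os.path import dirname

    # Directory from which the plugin's command modules are imported by execute_command_direct()
    BASEDIR = dirname(__file__)

    # Free-form command listing merged into the top-level sonusai docstring;
    # the expected layout is an assumption here.
    commands_doc = """
       example_cmd   Run the example plugin command
    """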

sonusai/metrics/__init__.py
ADDED
@@ -0,0 +1,43 @@
# SonusAI metrics utilities for model training and validation

from .calc_audio_stats import calc_audio_stats
from .calc_class_weights import calc_class_weights_from_mixdb
from .calc_class_weights import calc_class_weights_from_truth
from .calc_optimal_thresholds import calc_optimal_thresholds
from .calc_pcm import calc_pcm
from .calc_pesq import calc_pesq
from .calc_phase_distance import calc_phase_distance
from .calc_sa_sdr import calc_sa_sdr
from .calc_sample_weights import calc_sample_weights
from .calc_segsnr_f import calc_segsnr_f
from .calc_segsnr_f import calc_segsnr_f_bin
from .calc_speech import calc_speech
from .calc_wer import calc_wer
from .calc_wsdr import calc_wsdr
from .calculate_metrics import calculate_metrics
from .class_summary import class_summary
from .confusion_matrix_summary import confusion_matrix_summary
from .one_hot import one_hot
from .snr_summary import snr_summary

__all__ = [
    "calc_audio_stats",
    "calc_class_weights_from_mixdb",
    "calc_class_weights_from_truth",
    "calc_optimal_thresholds",
    "calc_pcm",
    "calc_pesq",
    "calc_phase_distance",
    "calc_sa_sdr",
    "calc_sample_weights",
    "calc_segsnr_f",
    "calc_segsnr_f_bin",
    "calc_speech",
    "calc_wer",
    "calc_wsdr",
    "calculate_metrics",
    "class_summary",
    "confusion_matrix_summary",
    "one_hot",
    "snr_summary",
]

sonusai/metrics/calc_audio_stats.py
ADDED
@@ -0,0 +1,42 @@
from ..datatypes import AudioStatsMetrics
from ..datatypes import AudioT


def _convert_str_with_factors_to_int(x: str) -> int:
    if "k" in x:
        return int(1000 * float(x.replace("k", "")))
    if "M" in x:
        return int(1000000 * float(x.replace("M", "")))
    return int(x)


def calc_audio_stats(audio: AudioT, win_len: float | None = None) -> AudioStatsMetrics:
    from ..mixture.sox_effects import sox_stats

    out = sox_stats(audio, win_len)

    if out is None:
        raise SystemError("Call to sox failed")

    stats = {}
    lines = out.split("\n")
    for line in lines:
        split_line = line.split()
        if len(split_line) == 0:
            continue
        value = split_line[-1]
        key = " ".join(split_line[:-1])
        stats[key] = value

    return AudioStatsMetrics(
        dco=float(stats["DC offset"]),
        min=float(stats["Min level"]),
        max=float(stats["Max level"]),
        pkdb=float(stats["Pk lev dB"]),
        lrms=float(stats["RMS lev dB"]),
        pkr=float(stats["RMS Pk dB"]),
        tr=float(stats["RMS Tr dB"]),
        cr=float(stats["Crest factor"]),
        fl=float(stats["Flat factor"]),
        pkc=_convert_str_with_factors_to_int(stats["Pk count"]),
    )
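
A minimal usage sketch (not part of the package), assuming a working sonusai installation since the function shells out to sox via `sox_stats`, and assuming `AudioStatsMetrics` exposes attributes matching the keyword names used in its construction above:

    import numpy as np

    from sonusai.metrics import calc_audio_stats

    # One second of a quiet 440 Hz tone as float32 samples (AudioT is assumed to be a 1-D float array)
    audio = (0.1 * np.sin(2 * np.pi * 440 * np.arange(16000) / 16000)).astype(np.float32)

    stats = calc_audio_stats(audio)          # win_len=None analyzes the whole signal
    print(stats.pkdb, stats.lrms, stats.cr)  # peak level dB, RMS level dB, crest factor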

sonusai/metrics/calc_class_weights.py
ADDED
@@ -0,0 +1,90 @@
import numpy as np

from ..datatypes import GeneralizedIDs
from ..datatypes import Truth
from ..mixture.mixdb import MixtureDatabase


def calc_class_weights_from_truth(truth: Truth, other_weight: float | None = None, other_index: int = -1) -> np.ndarray:
    """Calculate class weights.

    Supports non-existent classes (a problem with sklearn) where non-existent
    classes get a weight of 0 (instead of inf).
    Includes optional weighting of an "other" class if specified.

    Reference:
        weights = class_weight.compute_class_weight(class_weight='balanced', classes=clabels, y=labels)

    Arguments:
        truth: Truth data in one-hot format. Size can be:
               - [frames, timesteps, num_classes]
               - [frames, num_classes]
        other_weight: float or `None`. Weight of the "other" class.
                      > 1 = increase weighting/importance relative to the true count
                      0 < `other_weight` < 1 = decrease weighting/importance relative
                      < 0 or `None` = disable, use true count (default = `None`)
        other_index: int. Index of the "other" class in one-hot mode. Defaults to -1 (the last).

    Returns:
        A numpy array containing class weights.
    """
    frames, num_classes = truth.shape

    if num_classes > 1:
        labels = np.argmax(truth, axis=-1)  # [frames, 1 labels] from one-hot, last dim
        count = np.bincount(labels, minlength=num_classes).astype(float)
    else:
        num_classes = 2
        labels = np.array(truth >= 0.5).astype(np.int8)[:, 0]  # quantize to binary and shape (frames,) for bincount
        count = np.bincount(labels, minlength=num_classes).astype(float)

    if other_weight is not None and other_weight > 0:
        count[other_index] = count[other_index] / np.float32(other_weight)

    weights = np.empty((len(count)), dtype=np.float32)
    for n in range(len(weights)):
        if count[n] == 0:
            # Avoid sklearn problem with absent classes and assign non-existent classes a weight of 0.
            weights[n] = 0
        else:
            weights[n] = frames / (num_classes * count[n])

    return weights


def calc_class_weights_from_mixdb(
    mixdb: MixtureDatabase,
    mixids: GeneralizedIDs = "*",
    other_weight: float = 1,
    other_index: int = -1,
) -> tuple[np.ndarray, np.ndarray]:
    """Calculate class weights using estimated feature counts from a mixture database.

    Arguments:
        mixdb: Mixture database.
        mixids: Mixture IDs.
        other_weight: float or `None`. Weight of the "other" class.
                      > 1 = increase weighting/importance relative to the true count
                      0 < `other_weight` < 1 = decrease weighting/importance relative
                      < 0 or `None` = disable, use true count
        other_index: int. Index of the "other" class in one-hot mode. Defaults to -1 (the last).

    Returns:
        count: Count of features in each class.
        weights: Class weights. [num_classes, 1]
                 Note: for Keras use dict(enumerate(weights))
    """
    from ..mixture.class_count import get_class_count_from_mixids

    count = np.ceil(np.array(get_class_count_from_mixids(mixdb=mixdb, mixids=mixids)) / mixdb.feature_step_samples)
    total_features = sum(count)

    weights = np.empty(mixdb.num_classes, dtype=np.float32)
    for n in range(len(weights)):
        if count[n] == 0:
            # Avoid sklearn problem with absent classes and assign non-existent classes a weight of 0.
            weights[n] = 0
        else:
            weights[n] = total_features / (mixdb.num_classes * count[n])

    return count, weights
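
A minimal usage sketch (not part of the package) for the truth-based variant, using a synthetic imbalanced one-hot truth array; the shapes follow the docstring above:

    import numpy as np

    from sonusai.metrics import calc_class_weights_from_truth

    # 1000 frames, 3 classes, heavily imbalanced one-hot truth
    truth = np.zeros((1000, 3), dtype=np.float32)
    truth[:900, 0] = 1.0
    truth[900:990, 1] = 1.0
    truth[990:, 2] = 1.0

    weights = calc_class_weights_from_truth(truth)
    # Rare classes get proportionally larger weights (~[0.37, 3.7, 33.3]);
    # a class that never occurs would get weight 0 rather than inf.
    print(weights)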

sonusai/metrics/calc_optimal_thresholds.py
ADDED
@@ -0,0 +1,73 @@
import numpy as np

from ..datatypes import Predict
from ..datatypes import Truth


def calc_optimal_thresholds(
    truth: Truth,
    predict: Predict,
    timesteps: int = 0,
    truth_thr: float = 0.5,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Calculates optimal thresholds for each class from one-hot prediction and truth data where both are
    one-hot probabilities (or quantized decisions) with size [frames, num_classes] or [frames, timesteps, num_classes].

    Returns:
        thresholds_opt_pr  [num_classes, 1] optimal thresholds for PR-curve (F1) performance
        thresholds_opt_roc [num_classes, 1] optimal thresholds for ROC-curve (TPR/FPR) performance
        AP                 [num_classes, 1]
        AUC                [num_classes, 1]

    Optional truth_thr is the decision threshold(s) applied to the truth one-hot input, allowing truth to optionally
    be continuous probabilities. Default is 0.5.
    """
    from sklearn.metrics import average_precision_score
    from sklearn.metrics import precision_recall_curve
    from sklearn.metrics import roc_auc_score
    from sklearn.metrics import roc_curve

    from ..utils.reshape import get_num_classes_from_predict
    from ..utils.reshape import reshape_outputs

    if truth.shape != predict.shape:
        raise ValueError("truth and predict are not the same shape")

    predict, truth = reshape_outputs(predict=predict, truth=truth, timesteps=timesteps)  # type: ignore[assignment]
    num_classes = get_num_classes_from_predict(predict=predict, timesteps=timesteps)

    # Apply decision to truth input
    truth_binary = np.array(truth >= truth_thr).astype(np.int8)

    AP = np.zeros((num_classes, 1))
    AUC = np.zeros((num_classes, 1))
    thresholds_opt_pr = np.zeros((num_classes, 1))
    thresholds_opt_roc = np.zeros((num_classes, 1))
    eps = np.finfo(float).eps
    for nci in range(num_classes):
        # Average Precision (the area under the PR curve, AUCPR) and the area under the ROC curve (AUC),
        # computed from binarized truth and continuous prediction probabilities.
        # sklearn returns nan if a class has no active truth, with an unsuppressible divide-by-zero warning.
        if sum(truth_binary[:, nci]) == 0:  # no active truth must be NaN
            thresholds_opt_pr[nci] = np.NaN
            thresholds_opt_roc[nci] = np.NaN
            AUC[nci] = np.NaN
            AP[nci] = np.NaN
        else:
            AP[nci] = average_precision_score(truth_binary[:, nci], predict[:, nci], average=None)  # pyright: ignore [reportArgumentType]
            AUC[nci] = roc_auc_score(truth_binary[:, nci], predict[:, nci], average=None)  # pyright: ignore [reportArgumentType]

            # Optimal threshold from PR curve, optimizes f-score
            precision, recall, thrpr = precision_recall_curve(truth_binary[:, nci], predict[:, nci])
            fscore = (2 * precision * recall) / (precision + recall + eps)
            ix = np.argmax(fscore)  # index of largest f1 score
            thresholds_opt_pr[nci] = thrpr[ix]

            # Optimal threshold from ROC curve, optimizes J-statistic (TPR-FPR) or gmean
            fpr, tpr, thrroc = roc_curve(truth_binary[:, nci], predict[:, nci])
            # J = tpr - fpr  # J can result in thr > 1
            gmeans = np.sqrt(tpr * (1 - fpr))  # gmean seems better behaved
            ix = np.argmax(gmeans)
            thresholds_opt_roc[nci] = thrroc[ix]

    return thresholds_opt_pr, thresholds_opt_roc, AP, AUC
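
A minimal usage sketch (not part of the package) with synthetic data, assuming `reshape_outputs` passes 2-D [frames, num_classes] arrays through unchanged as the docstring implies:

    import numpy as np

    from sonusai.metrics import calc_optimal_thresholds

    rng = np.random.default_rng(0)
    truth = (rng.random((500, 2)) > 0.7).astype(np.float32)     # [frames, num_classes] one-hot-style truth
    predict = 0.8 * truth + 0.2 * rng.random((500, 2))          # well-separated "probabilities"

    thr_pr, thr_roc, ap, auc = calc_optimal_thresholds(truth, predict)
    # With this separation, AP and AUC should be near 1 and both thresholds fall between 0.2 and 0.8.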

sonusai/metrics/calc_pcm.py
ADDED
@@ -0,0 +1,45 @@
import numpy as np


def calc_pcm(
    hypothesis: np.ndarray, reference: np.ndarray, with_log: bool = False
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Calculate phase constrained magnitude error

    These must include a noise to make a complete mixture estimate, i.e.,
        noise_est = mixture - sum-over-nsrc(s_est(:, nsrc, :))
    should be one of the sources in s_true and s_est.

    Calculates mean-over-srcs(mean-over-tf(| (|Sr(t, f)| + |Si(t, f)|) - (|Shr(t, f)| + |Shi(t, f)|) |))

    Reference:
        Self-attending RNN for Speech Enhancement to Improve Cross-corpus Generalization
        Ashutosh Pandey, Student Member, IEEE and DeLiang Wang, Fellow, IEEE
        https://doi.org/10.48550/arXiv.2105.12831

    :param hypothesis: complex [frames, nsrc, bins]
    :param reference: complex [frames, nsrc, bins]
    :param with_log: enable log
    :return: (error, error per bin, error per frame)
    """
    # LSM = 1/(T*F) * sumtf(| (|Sr(t, f)| + |Si(t, f)|) - (|Shr(t, f)| + |Shi(t, f)|) |)
    # LPCM = 1/2 * LSM(s, sh) + 1/2 * LSM(n, nh)

    # [frames, nsrc, bins]
    hypothesis_abs = np.abs(np.real(hypothesis)) + np.abs(np.imag(hypothesis))
    reference_abs = np.abs(np.real(reference)) + np.abs(np.imag(reference))
    err = np.abs(reference_abs - hypothesis_abs)

    # mean over frames, nsrc for value per bin
    err_b = np.mean(np.mean(err, axis=0), axis=0)
    # mean over bins, nsrc for value per frame
    err_f = np.mean(np.mean(err, axis=2), axis=1)
    # mean over bins and frames, nsrc for scalar value
    err = np.mean(np.mean(err, axis=(0, 2)), axis=0)

    if with_log:
        err_b = np.around(20 * np.log10(err_b + np.finfo(np.float32).eps), 3)
        err_f = np.around(20 * np.log10(err_f + np.finfo(np.float32).eps), 3)
        err = np.around(20 * np.log10(err + np.finfo(np.float32).eps), 3)

    return err, err_b, err_f
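
A minimal usage sketch (not part of the package) with random complex spectra standing in for STFTs; shapes follow the [frames, nsrc, bins] convention in the docstring:

    import numpy as np

    from sonusai.metrics import calc_pcm

    rng = np.random.default_rng(0)
    frames, nsrc, bins = 100, 2, 257
    reference = rng.standard_normal((frames, nsrc, bins)) + 1j * rng.standard_normal((frames, nsrc, bins))
    hypothesis = reference + 0.05 * (rng.standard_normal((frames, nsrc, bins)) + 1j * rng.standard_normal((frames, nsrc, bins)))

    err, err_b, err_f = calc_pcm(hypothesis, reference, with_log=True)
    # err is a scalar dB value; err_b has one value per bin, err_f one per frame.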

sonusai/metrics/calc_pesq.py
ADDED
@@ -0,0 +1,36 @@
import numpy as np

from ..constants import SAMPLE_RATE


def calc_pesq(
    hypothesis: np.ndarray,
    reference: np.ndarray,
    error_value: float = 0.0,
    sample_rate: int = SAMPLE_RATE,
) -> float:
    """Computes the PESQ score of hypothesis vs. reference.

    Upon error, assigns a value of 0, or the user-specified value in error_value.

    :param hypothesis: estimated audio
    :param reference: reference audio
    :param error_value: value to use if an error occurs
    :param sample_rate: sample rate of audio
    :return: value between -0.5 and 4.5
    """
    import warnings

    from pesq import pesq

    from .. import logger

    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            score = pesq(fs=sample_rate, ref=reference, deg=hypothesis, mode="wb")
    except Exception as e:
        logger.debug(f"PESQ error {e}")
        score = error_value

    return score
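
A minimal usage sketch (not part of the package). PESQ is run in wideband mode above, so 16 kHz input is assumed here; on non-speech signals the underlying pesq call may fail, in which case `error_value` is returned:

    import numpy as np

    from sonusai.metrics import calc_pesq

    sr = 16000
    t = np.arange(2 * sr) / sr
    reference = (0.1 * np.sin(2 * np.pi * 220 * t)).astype(np.float32)
    degraded = (reference + 0.01 * np.random.default_rng(0).standard_normal(reference.shape)).astype(np.float32)

    score = calc_pesq(hypothesis=degraded, reference=reference, sample_rate=sr)
    print(score)  # between -0.5 and 4.5, or error_value (0.0) on failure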

sonusai/metrics/calc_phase_distance.py
ADDED
@@ -0,0 +1,43 @@
import numpy as np


def calc_phase_distance(
    reference: np.ndarray, hypothesis: np.ndarray, eps: float = 1e-9
) -> tuple[float, np.ndarray, np.ndarray]:
    """Calculate weighted phase distance error (weight normalization over bins per frame)

    :param reference: complex [frames, bins]
    :param hypothesis: complex [frames, bins]
    :param eps: epsilon value
    :return: mean, mean per bin, mean per frame
    """
    ang_diff = np.angle(reference) - np.angle(hypothesis)
    phd_mod = (ang_diff + np.pi) % (2 * np.pi) - np.pi
    rh_angle_diff = phd_mod * 180 / np.pi  # angle diff in deg

    # Use complex divide to intrinsically keep angle diff +/-180 deg, but avoid div by zero (real hyp)
    # hyp_real = np.real(hypothesis)
    # near_zeros = np.real(hyp_real) < eps
    # hyp_real = hyp_real * (np.logical_not(near_zeros))
    # hyp_real = hyp_real + (near_zeros * eps)
    # hypothesis = hyp_real + 1j * np.imag(hypothesis)
    # rh_angle_diff = np.angle(reference / hypothesis) * 180 / np.pi  # angle diff +/-180

    # weighted mean over all (scalar)
    reference_mag = np.abs(reference)
    ref_weight = reference_mag / (np.sum(reference_mag) + eps)  # frames x bins
    err = float(np.around(np.sum(ref_weight * rh_angle_diff), 3))

    # weighted mean over frames (value per bin)
    err_b = np.zeros(reference.shape[1])
    for bi in range(reference.shape[1]):
        ref_weight = reference_mag[:, bi] / (np.sum(reference_mag[:, bi], axis=0) + eps)
        err_b[bi] = np.around(np.sum(ref_weight * rh_angle_diff[:, bi]), 3)

    # weighted mean over bins (value per frame)
    err_f = np.zeros(reference.shape[0])
    for fi in range(reference.shape[0]):
        ref_weight = reference_mag[fi, :] / (np.sum(reference_mag[fi, :]) + eps)
        err_f[fi] = np.around(np.sum(ref_weight * rh_angle_diff[fi, :]), 3)

    return err, err_b, err_f
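
A minimal usage sketch (not part of the package), applying a known constant phase offset so the expected result is easy to check:

    import numpy as np

    from sonusai.metrics import calc_phase_distance

    rng = np.random.default_rng(0)
    frames, bins = 50, 129
    reference = rng.standard_normal((frames, bins)) + 1j * rng.standard_normal((frames, bins))
    hypothesis = reference * np.exp(1j * 0.1)   # constant +0.1 rad phase shift on the hypothesis

    err, err_b, err_f = calc_phase_distance(reference, hypothesis)
    # err should be close to -0.1 rad expressed in degrees, about -5.7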

sonusai/metrics/calc_sa_sdr.py
ADDED
@@ -0,0 +1,64 @@
import numpy as np


def calc_sa_sdr(
    hypothesis: np.ndarray,
    reference: np.ndarray,
    with_scale: bool = False,
    with_negate: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
    """Calculate source-aggregated SDR (signal-to-distortion ratio) using all source inputs, which are [samples, nsrc].

    These should include a noise source to be a complete mixture estimate, i.e.,
        noise_est = sum-over-all-srcs(s_est(0:nsamples, :)) - sum-over-non-noise-srcs(s_est(0:nsamples, n))
    should be one of the sources in reference (s_true) and hypothesis (s_est).

    Calculates -10*log10(sumn(||sn||^2) / sumn(||sn - shn||^2))
    Note: for the SA method, the sums are done independently on reference and error before the division, vs. SDR and
    SI-SDR where the sum over n is taken after the divide (before the log). This is more stable in noise-only cases
    and also when some sources are poorly estimated.
    TBD: add soft-max option with eps and tau params

    Reference:
        SA-SDR: A Novel Loss Function for Separation of Meeting Style Data
        Thilo von Neumann, Keisuke Kinoshita, Christoph Boeddeker, Marc Delcroix, Reinhold Haeb-Umbach
        https://doi.org/10.48550/arXiv.2110.15581

    :param hypothesis: [samples, nsrc]
    :param reference: [samples, nsrc]
    :param with_scale: enable scaling (scaling is the same as in SI-SDR)
    :param with_negate: enable negation (for use as a loss function)
    :return: (sa_sdr, opt_scale)
    """
    if with_scale:
        # calc 1 x nsrc scaling factors
        ref_energy = np.sum(reference**2, axis=0, keepdims=True)
        # if ref_energy is zero, just set scaling to 1.0
        with np.errstate(divide="ignore", invalid="ignore"):
            opt_scale = np.sum(reference * hypothesis, axis=0, keepdims=True) / ref_energy
            opt_scale[opt_scale == np.inf] = 1.0
            opt_scale = np.nan_to_num(opt_scale, nan=1.0)
        scaled_ref = opt_scale * reference
    else:
        scaled_ref = reference
        opt_scale = np.ones((1, reference.shape[1]), dtype=float)

    # multisrc sa-sdr, inputs must be [samples, nsrc]
    err = scaled_ref - hypothesis

    # -10*log10(sumk(||sk||^2) / sumk(||sk - shk||^2))
    # sum over samples and sources
    num = np.sum(reference**2)
    den = np.sum(err**2)
    if num == 0 and den == 0:
        ratio = np.inf
    else:
        ratio = num / (den + np.finfo(np.float32).eps)

    sa_sdr = 10 * np.log10(ratio)

    if with_negate:
        # for use as a loss function
        sa_sdr = -sa_sdr

    return sa_sdr, opt_scale
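
A minimal usage sketch (not part of the package) with two unit-variance sources and a small additive error, so the expected value is easy to sanity-check:

    import numpy as np

    from sonusai.metrics import calc_sa_sdr

    rng = np.random.default_rng(0)
    samples, nsrc = 16000, 2
    reference = rng.standard_normal((samples, nsrc))
    hypothesis = reference + 0.1 * rng.standard_normal((samples, nsrc))

    sa_sdr, opt_scale = calc_sa_sdr(hypothesis=hypothesis, reference=reference, with_scale=True)
    # Roughly 20 dB for a 0.1-amplitude error on unit-variance sources; opt_scale is close to 1.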

sonusai/metrics/calc_sample_weights.py
ADDED
@@ -0,0 +1,25 @@
import numpy as np


def calc_sample_weights(class_weights: np.ndarray, truth: np.ndarray) -> np.ndarray:
    """Calculate sample weights from class weights and a given truth with 2D or 3D shape.

    Supports one-hot encoded multi-class or binary truth/labels.
    Note: returns the sum of weighted truth over classes, so it should also work for multi-label (TBD).

    Inputs:
        class_weights  [num_classes, 1]  weights for each class
        truth          [frames, timesteps, num_classes] or [frames, num_classes]

    Returns:
        sample_weights [frames, timesteps, 1] or [frames, 1]
    """
    ts = truth.shape
    cs = class_weights.shape

    if ts[-1] == 1 and cs[0] == 2:
        # Binary truth needs 2nd "none" truth dimension
        truth = np.concatenate((truth, 1 - truth), axis=1)

    # broadcast [num_classes, 1] over [frames, num_classes] or [frames, timesteps, num_classes]
    return np.sum(class_weights * truth, axis=-1)
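
A minimal usage sketch (not part of the package), pairing this with `calc_class_weights_from_truth` from the same package; the one-hot layout follows the docstrings above:

    import numpy as np

    from sonusai.metrics import calc_class_weights_from_truth
    from sonusai.metrics import calc_sample_weights

    truth = np.zeros((100, 3), dtype=np.float32)
    truth[:80, 0] = 1.0
    truth[80:, 1] = 1.0                       # class index 2 never occurs

    class_weights = calc_class_weights_from_truth(truth)        # about [0.42, 1.67, 0.0]
    sample_weights = calc_sample_weights(class_weights, truth)  # shape (100,): each frame gets its class's weight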

sonusai/metrics/calc_segsnr_f.py
ADDED
@@ -0,0 +1,82 @@
import numpy as np

from ..datatypes import AudioF
from ..datatypes import Segsnr
from ..datatypes import SnrFBinMetrics
from ..datatypes import SnrFMetrics


def calc_segsnr_f(segsnr_f: Segsnr) -> SnrFMetrics:
    """Calculate metrics of snr_f truth data.

    Includes mean and standard deviation of the linear values (usually energy)
    and mean and standard deviation of the dB values (10 * log10).
    """
    if np.count_nonzero(segsnr_f) == 0:
        # If all entries are zeros
        return SnrFMetrics(0, 0, -np.inf, 0)

    tmp = np.ma.array(segsnr_f, mask=np.logical_not(np.isfinite(segsnr_f)))
    if np.ma.count_masked(tmp) == np.ma.size(tmp, axis=0):
        # If all entries are infinite
        return SnrFMetrics(np.inf, 0, np.inf, 0)

    snr_mean = np.mean(tmp, axis=0)
    snr_std = np.std(tmp, axis=0)

    tmp = 10 * np.ma.log10(tmp)
    if np.ma.count_masked(tmp) == np.ma.size(tmp, axis=0):
        # If all entries are masked, special case where all inputs are either 0 or infinite
        snr_db_mean = -np.inf
        snr_db_std = np.inf
    else:
        snr_db_mean = np.mean(tmp, axis=0)
        snr_db_std = np.std(tmp, axis=0)

    return SnrFMetrics(snr_mean, snr_std, snr_db_mean, snr_db_std)


def calc_segsnr_f_bin(target_f: AudioF, noise_f: AudioF) -> SnrFBinMetrics:
    """Calculate per-bin segmental SNR metrics.

    Includes per-bin mean and standard deviation of the linear values
    and mean and standard deviation of the dB values.
    """
    if target_f.ndim != 2 and noise_f.ndim != 2:
        raise ValueError("target_f and noise_f must have 2 dimensions")

    segsnr_f = (np.abs(target_f) ** 2) / (np.abs(noise_f) ** 2 + np.finfo(np.float32).eps)

    frames, bins = segsnr_f.shape
    if np.count_nonzero(segsnr_f) == 0:
        # If all entries are zeros
        return SnrFBinMetrics(np.zeros(bins), np.zeros(bins), -np.inf * np.ones(bins), np.zeros(bins))

    tmp = np.ma.array(segsnr_f, mask=np.logical_not(np.isfinite(segsnr_f)))
    if np.ma.count_masked(tmp) == np.ma.size(tmp, axis=0):
        # If all entries are infinite
        return SnrFBinMetrics(
            np.inf * np.ones(bins),
            np.zeros(bins),
            np.inf * np.ones(bins),
            np.zeros(bins),
        )

    snr_mean = np.mean(tmp, axis=0)
    snr_std = np.std(tmp, axis=0)

    tmp = 10 * np.ma.log10(tmp)
    if np.ma.count_masked(tmp) == np.ma.size(tmp, axis=0):
        # If all entries are masked, special case where all inputs are either 0 or infinite
        snr_db_mean = -np.inf * np.ones(bins)
        snr_db_std = np.inf * np.ones(bins)
    else:
        snr_db_mean = np.mean(tmp, axis=0)
        snr_db_std = np.std(tmp, axis=0)

    return SnrFBinMetrics(
        np.ma.getdata(snr_mean),
        np.ma.getdata(snr_std),
        np.ma.getdata(snr_db_mean),
        np.ma.getdata(snr_db_std),
    )
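
A minimal usage sketch (not part of the package) for the per-bin variant, using random complex spectra in place of real target/noise transforms; with a noise amplitude of 0.1 the per-bin SNR mean should sit near 20 dB:

    import numpy as np

    from sonusai.metrics import calc_segsnr_f_bin

    rng = np.random.default_rng(0)
    frames, bins = 100, 257
    target_f = rng.standard_normal((frames, bins)) + 1j * rng.standard_normal((frames, bins))
    noise_f = 0.1 * (rng.standard_normal((frames, bins)) + 1j * rng.standard_normal((frames, bins)))

    metrics = calc_segsnr_f_bin(target_f, noise_f)
    # metrics is an SnrFBinMetrics with per-bin linear and dB means/standard deviations.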