sonusai 0.18.9__py3-none-any.whl → 0.19.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +20 -29
- sonusai/aawscd_probwrite.py +18 -18
- sonusai/audiofe.py +93 -80
- sonusai/calc_metric_spenh.py +395 -321
- sonusai/data/genmixdb.yml +5 -11
- sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
- sonusai/{plot.py → deprecated/plot.py} +177 -131
- sonusai/{tplot.py → deprecated/tplot.py} +124 -102
- sonusai/doc/__init__.py +1 -1
- sonusai/doc/doc.py +112 -177
- sonusai/doc.py +10 -10
- sonusai/genft.py +93 -77
- sonusai/genmetrics.py +59 -46
- sonusai/genmix.py +116 -104
- sonusai/genmixdb.py +194 -153
- sonusai/lsdb.py +56 -66
- sonusai/main.py +23 -20
- sonusai/metrics/__init__.py +2 -0
- sonusai/metrics/calc_audio_stats.py +29 -24
- sonusai/metrics/calc_class_weights.py +7 -7
- sonusai/metrics/calc_optimal_thresholds.py +5 -7
- sonusai/metrics/calc_pcm.py +3 -3
- sonusai/metrics/calc_pesq.py +10 -7
- sonusai/metrics/calc_phase_distance.py +3 -3
- sonusai/metrics/calc_sa_sdr.py +10 -8
- sonusai/metrics/calc_segsnr_f.py +15 -17
- sonusai/metrics/calc_speech.py +105 -47
- sonusai/metrics/calc_wer.py +35 -32
- sonusai/metrics/calc_wsdr.py +10 -7
- sonusai/metrics/class_summary.py +30 -27
- sonusai/metrics/confusion_matrix_summary.py +25 -22
- sonusai/metrics/one_hot.py +91 -57
- sonusai/metrics/snr_summary.py +53 -46
- sonusai/mixture/__init__.py +19 -14
- sonusai/mixture/audio.py +4 -6
- sonusai/mixture/augmentation.py +37 -43
- sonusai/mixture/class_count.py +5 -14
- sonusai/mixture/config.py +292 -225
- sonusai/mixture/constants.py +41 -30
- sonusai/mixture/data_io.py +155 -0
- sonusai/mixture/datatypes.py +111 -108
- sonusai/mixture/db_datatypes.py +54 -70
- sonusai/mixture/eq_rule_is_valid.py +6 -9
- sonusai/mixture/feature.py +40 -38
- sonusai/mixture/generation.py +522 -389
- sonusai/mixture/helpers.py +217 -272
- sonusai/mixture/log_duration_and_sizes.py +16 -13
- sonusai/mixture/mixdb.py +669 -477
- sonusai/mixture/soundfile_audio.py +12 -17
- sonusai/mixture/sox_audio.py +91 -112
- sonusai/mixture/sox_augmentation.py +8 -9
- sonusai/mixture/spectral_mask.py +4 -6
- sonusai/mixture/target_class_balancing.py +41 -36
- sonusai/mixture/targets.py +69 -67
- sonusai/mixture/tokenized_shell_vars.py +23 -23
- sonusai/mixture/torchaudio_audio.py +14 -15
- sonusai/mixture/torchaudio_augmentation.py +23 -27
- sonusai/mixture/truth.py +48 -26
- sonusai/mixture/truth_functions/__init__.py +26 -0
- sonusai/mixture/truth_functions/crm.py +56 -38
- sonusai/mixture/truth_functions/datatypes.py +37 -0
- sonusai/mixture/truth_functions/energy.py +85 -59
- sonusai/mixture/truth_functions/file.py +30 -30
- sonusai/mixture/truth_functions/phoneme.py +14 -7
- sonusai/mixture/truth_functions/sed.py +71 -45
- sonusai/mixture/truth_functions/target.py +69 -106
- sonusai/mkwav.py +52 -85
- sonusai/onnx_predict.py +46 -43
- sonusai/queries/__init__.py +3 -1
- sonusai/queries/queries.py +100 -59
- sonusai/speech/__init__.py +2 -0
- sonusai/speech/l2arctic.py +24 -23
- sonusai/speech/librispeech.py +16 -17
- sonusai/speech/mcgill.py +22 -21
- sonusai/speech/textgrid.py +32 -25
- sonusai/speech/timit.py +45 -42
- sonusai/speech/vctk.py +14 -13
- sonusai/speech/voxceleb.py +26 -20
- sonusai/summarize_metric_spenh.py +11 -10
- sonusai/utils/__init__.py +4 -3
- sonusai/utils/asl_p56.py +1 -1
- sonusai/utils/asr.py +37 -17
- sonusai/utils/asr_functions/__init__.py +2 -0
- sonusai/utils/asr_functions/aaware_whisper.py +18 -12
- sonusai/utils/audio_devices.py +12 -12
- sonusai/utils/braced_glob.py +6 -8
- sonusai/utils/calculate_input_shape.py +1 -4
- sonusai/utils/compress.py +2 -2
- sonusai/utils/convert_string_to_number.py +1 -3
- sonusai/utils/create_timestamp.py +1 -1
- sonusai/utils/create_ts_name.py +2 -2
- sonusai/utils/dataclass_from_dict.py +1 -1
- sonusai/utils/docstring.py +6 -6
- sonusai/utils/energy_f.py +9 -7
- sonusai/utils/engineering_number.py +56 -54
- sonusai/utils/get_label_names.py +8 -10
- sonusai/utils/human_readable_size.py +2 -2
- sonusai/utils/model_utils.py +3 -5
- sonusai/utils/numeric_conversion.py +2 -4
- sonusai/utils/onnx_utils.py +43 -32
- sonusai/utils/parallel.py +40 -27
- sonusai/utils/print_mixture_details.py +25 -22
- sonusai/utils/ranges.py +12 -12
- sonusai/utils/read_predict_data.py +11 -9
- sonusai/utils/reshape.py +19 -26
- sonusai/utils/seconds_to_hms.py +1 -1
- sonusai/utils/stacked_complex.py +8 -16
- sonusai/utils/stratified_shuffle_split.py +29 -27
- sonusai/utils/write_audio.py +2 -2
- sonusai/utils/yes_or_no.py +3 -3
- sonusai/vars.py +14 -14
- {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
- sonusai-0.19.5.dist-info/RECORD +125 -0
- {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
- sonusai/mixture/truth_functions/data.py +0 -58
- sonusai/utils/read_mixture_data.py +0 -14
- sonusai-0.18.9.dist-info/RECORD +0 -125
- {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
sonusai/lsdb.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1
1
|
"""sonusai lsdb
|
2
2
|
|
3
|
-
usage: lsdb [-hta] [-i MIXID] [-c
|
3
|
+
usage: lsdb [-hta] [-i MIXID] [-c CID] LOC
|
4
4
|
|
5
5
|
Options:
|
6
6
|
-h, --help
|
7
7
|
-i MIXID, --mixid MIXID Mixture ID(s) to analyze. [default: *].
|
8
|
-
-c
|
8
|
+
-c CID, --class_index CID Analyze mixtures that contain this class index.
|
9
9
|
-t, --targets List all target files.
|
10
10
|
-a, --all_class_counts List all class counts.
|
11
11
|
|
@@ -15,6 +15,7 @@ Inputs:
|
|
15
15
|
LOC A SonusAI mixture database directory.
|
16
16
|
|
17
17
|
"""
|
18
|
+
|
18
19
|
import signal
|
19
20
|
|
20
21
|
from sonusai import logger
|
@@ -27,26 +28,22 @@ def signal_handler(_sig, _frame):
|
|
27
28
|
|
28
29
|
from sonusai import logger
|
29
30
|
|
30
|
-
logger.info(
|
31
|
+
logger.info("Canceled due to keyboard interrupt")
|
31
32
|
sys.exit(1)
|
32
33
|
|
33
34
|
|
34
35
|
signal.signal(signal.SIGINT, signal_handler)
|
35
36
|
|
36
37
|
|
37
|
-
def lsdb(
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
from sonusai import SonusAIError
|
46
|
-
from sonusai.metrics import calc_snr_f
|
38
|
+
def lsdb(
|
39
|
+
mixdb: MixtureDatabase,
|
40
|
+
mixids: GeneralizedIDs = "*",
|
41
|
+
class_index: int | None = None,
|
42
|
+
list_targets: bool = False,
|
43
|
+
all_class_counts: bool = False,
|
44
|
+
) -> None:
|
47
45
|
from sonusai.mixture import SAMPLE_RATE
|
48
|
-
from sonusai.
|
49
|
-
from sonusai.queries import get_mixids_from_truth_index
|
46
|
+
from sonusai.queries import get_mixids_from_class_indices
|
50
47
|
from sonusai.utils import consolidate_range
|
51
48
|
from sonusai.utils import max_text_width
|
52
49
|
from sonusai.utils import print_mixture_details
|
@@ -62,38 +59,40 @@ def lsdb(mixdb: MixtureDatabase,
|
|
62
59
|
logger.info(f'{"Targets":{desc_len}} {mixdb.num_target_files}')
|
63
60
|
logger.info(f'{"Noises":{desc_len}} {mixdb.num_noise_files}')
|
64
61
|
logger.info(f'{"Feature":{desc_len}} {mixdb.feature}')
|
65
|
-
logger.info(
|
66
|
-
|
62
|
+
logger.info(
|
63
|
+
f'{"Feature shape":{desc_len}} {mixdb.fg_stride} x {mixdb.feature_parameters} '
|
64
|
+
f'({mixdb.fg_stride * mixdb.feature_parameters} total params)'
|
65
|
+
)
|
67
66
|
logger.info(f'{"Feature samples":{desc_len}} {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)')
|
68
|
-
logger.info(
|
69
|
-
|
67
|
+
logger.info(
|
68
|
+
f'{"Feature step samples":{desc_len}} {mixdb.feature_step_samples} samples ' f'({mixdb.feature_step_ms} ms)'
|
69
|
+
)
|
70
70
|
logger.info(f'{"Feature overlap":{desc_len}} {mixdb.fg_step / mixdb.fg_stride} ({mixdb.feature_step_ms} ms)')
|
71
71
|
logger.info(f'{"SNRs":{desc_len}} {mixdb.snrs}')
|
72
72
|
logger.info(f'{"Random SNRs":{desc_len}} {mixdb.random_snrs}')
|
73
73
|
logger.info(f'{"Classes":{desc_len}} {mixdb.num_classes}')
|
74
|
-
logger.info(f'{"Truth mutex":{desc_len}} {mixdb.truth_mutex}')
|
75
74
|
# TODO: fix class count
|
76
75
|
logger.info(f'{"Class count":{desc_len}} not supported')
|
77
76
|
# print_class_count(class_count=class_count, length=desc_len, print_fn=logger.info)
|
78
77
|
# TODO: add class weight calculations here
|
79
|
-
logger.info(
|
78
|
+
logger.info("")
|
80
79
|
|
81
80
|
if list_targets:
|
82
|
-
logger.info(
|
81
|
+
logger.info("Target details:")
|
83
82
|
idx_len = max_text_width(mixdb.num_target_files)
|
84
83
|
for idx, target in enumerate(mixdb.target_files):
|
85
|
-
desc = f
|
86
|
-
logger.info(f
|
87
|
-
desc = f
|
88
|
-
logger.info(f
|
89
|
-
logger.info(
|
90
|
-
|
91
|
-
if
|
92
|
-
if 0 <=
|
93
|
-
raise
|
94
|
-
ids =
|
95
|
-
logger.info(f
|
96
|
-
logger.info(
|
84
|
+
desc = f" {idx:{idx_len}} Name"
|
85
|
+
logger.info(f"{desc:{desc_len}} {target.name}")
|
86
|
+
desc = f" {idx:{idx_len}} Truth index"
|
87
|
+
logger.info(f"{desc:{desc_len}} {target.class_indices}")
|
88
|
+
logger.info("")
|
89
|
+
|
90
|
+
if class_index is not None:
|
91
|
+
if 0 <= class_index > mixdb.num_classes:
|
92
|
+
raise ValueError(f"Given class_index is outside valid range of 1-{mixdb.num_classes}")
|
93
|
+
ids = get_mixids_from_class_indices(mixdb=mixdb, predicate=lambda x: x in [class_index])[class_index]
|
94
|
+
logger.info(f"Mixtures with class index {class_index}: {ids}")
|
95
|
+
logger.info("")
|
97
96
|
|
98
97
|
mixids = mixdb.mixids_to_list(mixids)
|
99
98
|
|
@@ -101,24 +100,13 @@ def lsdb(mixdb: MixtureDatabase,
|
|
101
100
|
print_mixture_details(mixdb=mixdb, mixid=mixids[0], desc_len=desc_len, print_fn=logger.info)
|
102
101
|
if all_class_counts:
|
103
102
|
# TODO: fix class count
|
104
|
-
logger.info(
|
103
|
+
logger.info("All class count not supported")
|
105
104
|
# print_class_count(class_count=class_count, length=desc_len, print_fn=logger.info, all_class_counts=True)
|
106
105
|
else:
|
107
|
-
logger.info(
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
if mixid == mixids[0]:
|
112
|
-
truth_f = np.array(f['truth_f'])
|
113
|
-
else:
|
114
|
-
truth_f = np.concatenate((truth_f, np.array(f['truth_f'])))
|
115
|
-
|
116
|
-
snr_mean, snr_std, snr_db_mean, snr_db_std = calc_snr_f(truth_f)
|
117
|
-
|
118
|
-
logger.info('Truth')
|
119
|
-
logger.info(f' {"mean":^8s} {"std":^8s} {"db_mean":^8s} {"db_std":^8s}')
|
120
|
-
for c in range(len(snr_mean)):
|
121
|
-
logger.info(f' {snr_mean[c]:8.2f} {snr_std[c]:8.2f} {snr_db_mean[c]:8.2f} {snr_db_std[c]:8.2f}')
|
106
|
+
logger.info(
|
107
|
+
f"Calculating statistics from truth_f files for {len(mixids):,} mixtures" f" ({consolidate_range(mixids)})"
|
108
|
+
)
|
109
|
+
logger.info("Not supported")
|
122
110
|
|
123
111
|
|
124
112
|
def main() -> None:
|
@@ -132,28 +120,30 @@ def main() -> None:
|
|
132
120
|
|
133
121
|
args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
|
134
122
|
|
135
|
-
mixid = args[
|
136
|
-
|
137
|
-
list_targets = args[
|
138
|
-
all_class_counts = args[
|
139
|
-
location = args[
|
123
|
+
mixid = args["--mixid"]
|
124
|
+
class_index = args["--class_index"]
|
125
|
+
list_targets = args["--targets"]
|
126
|
+
all_class_counts = args["--all_class_counts"]
|
127
|
+
location = args["LOC"]
|
140
128
|
|
141
|
-
if
|
142
|
-
|
129
|
+
if class_index is not None:
|
130
|
+
class_index = int(class_index)
|
143
131
|
|
144
|
-
create_file_handler(
|
132
|
+
create_file_handler("lsdb.log")
|
145
133
|
update_console_handler(False)
|
146
|
-
initial_log_messages(
|
134
|
+
initial_log_messages("lsdb")
|
147
135
|
|
148
|
-
logger.info(f
|
136
|
+
logger.info(f"Analyzing {location}")
|
149
137
|
|
150
138
|
mixdb = MixtureDatabase(location)
|
151
|
-
lsdb(
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
139
|
+
lsdb(
|
140
|
+
mixdb=mixdb,
|
141
|
+
mixids=mixid,
|
142
|
+
class_index=class_index,
|
143
|
+
list_targets=list_targets,
|
144
|
+
all_class_counts=all_class_counts,
|
145
|
+
)
|
156
146
|
|
157
147
|
|
158
|
-
if __name__ ==
|
148
|
+
if __name__ == "__main__":
|
159
149
|
main()
|
sonusai/main.py
CHANGED
@@ -9,6 +9,7 @@ Aaware Sound and Voice Machine Learning Framework. See 'sonusai help <command>'
|
|
9
9
|
for more information on a specific command.
|
10
10
|
|
11
11
|
"""
|
12
|
+
|
12
13
|
import signal
|
13
14
|
|
14
15
|
|
@@ -17,7 +18,7 @@ def signal_handler(_sig, _frame):
|
|
17
18
|
|
18
19
|
from sonusai import logger
|
19
20
|
|
20
|
-
logger.info(
|
21
|
+
logger.info("Canceled due to keyboard interrupt")
|
21
22
|
sys.exit(1)
|
22
23
|
|
23
24
|
|
@@ -33,11 +34,11 @@ def main() -> None:
|
|
33
34
|
plugins = {}
|
34
35
|
plugin_docstrings = []
|
35
36
|
for _, name, _ in iter_modules():
|
36
|
-
if name.startswith(
|
37
|
+
if name.startswith("sonusai_") and not name.startswith("sonusai_asr_"):
|
37
38
|
module = import_module(name)
|
38
39
|
plugins[name] = {
|
39
|
-
|
40
|
-
|
40
|
+
"commands": commands_list(module.commands_doc),
|
41
|
+
"basedir": module.BASEDIR,
|
41
42
|
}
|
42
43
|
plugin_docstrings.append(module.commands_doc)
|
43
44
|
|
@@ -47,12 +48,14 @@ def main() -> None:
|
|
47
48
|
from sonusai.utils import add_commands_to_docstring
|
48
49
|
from sonusai.utils import trim_docstring
|
49
50
|
|
50
|
-
args = docopt(
|
51
|
-
|
52
|
-
|
51
|
+
args = docopt(
|
52
|
+
trim_docstring(add_commands_to_docstring(__doc__, plugin_docstrings)),
|
53
|
+
version=__version__,
|
54
|
+
options_first=True,
|
55
|
+
)
|
53
56
|
|
54
|
-
command = args[
|
55
|
-
argv = args[
|
57
|
+
command = args["<command>"]
|
58
|
+
argv = args["<args>"]
|
56
59
|
|
57
60
|
import sys
|
58
61
|
from os.path import join
|
@@ -62,29 +65,29 @@ def main() -> None:
|
|
62
65
|
from sonusai import logger
|
63
66
|
|
64
67
|
base_commands = sonusai.commands_list()
|
65
|
-
if command ==
|
68
|
+
if command == "help":
|
66
69
|
if not argv:
|
67
|
-
exit(call([
|
70
|
+
exit(call(["sonusai", "-h"])) # noqa: S603, S607
|
68
71
|
elif argv[0] in base_commands:
|
69
|
-
exit(call([
|
72
|
+
exit(call(["python", f"{join(sonusai.BASEDIR, argv[0])}.py", "-h"])) # noqa: S603, S607
|
70
73
|
|
71
|
-
for
|
72
|
-
if argv[0] in data[
|
73
|
-
exit(call([
|
74
|
+
for data in plugins.values():
|
75
|
+
if argv[0] in data["commands"]:
|
76
|
+
exit(call(["python", f"{join(data['basedir'], argv[0])}.py", "-h"])) # noqa: S603, S607
|
74
77
|
|
75
78
|
logger.error(f"{argv[0]} is not a SonusAI command. See 'sonusai help'.")
|
76
79
|
sys.exit(1)
|
77
80
|
|
78
81
|
if command in base_commands:
|
79
|
-
exit(call([
|
82
|
+
exit(call(["python", f"{join(sonusai.BASEDIR, command)}.py", *argv])) # noqa: S603, S607
|
80
83
|
|
81
|
-
for
|
82
|
-
if command in data[
|
83
|
-
exit(call([
|
84
|
+
for data in plugins.values():
|
85
|
+
if command in data["commands"]:
|
86
|
+
exit(call(["python", f"{join(data['basedir'], command)}.py", *argv])) # noqa: S603, S607
|
84
87
|
|
85
88
|
logger.error(f"{command} is not a SonusAI command. See 'sonusai help'.")
|
86
89
|
sys.exit(1)
|
87
90
|
|
88
91
|
|
89
|
-
if __name__ ==
|
92
|
+
if __name__ == "__main__":
|
90
93
|
main()
|
sonusai/metrics/__init__.py
CHANGED
sonusai/metrics/calc_audio_stats.py
CHANGED
@@ -3,48 +3,53 @@ from sonusai.mixture.datatypes import AudioT
|
|
3
3
|
|
4
4
|
|
5
5
|
def _convert_str_with_factors_to_int(x: str) -> int:
|
6
|
-
if
|
7
|
-
return int(1000 * float(x.replace(
|
8
|
-
if
|
9
|
-
return int(1000000 * float(x.replace(
|
6
|
+
if "k" in x:
|
7
|
+
return int(1000 * float(x.replace("k", "")))
|
8
|
+
if "M" in x:
|
9
|
+
return int(1000000 * float(x.replace("M", "")))
|
10
10
|
return int(x)
|
11
11
|
|
12
12
|
|
13
|
-
def calc_audio_stats(audio: AudioT, win_len: float = None) -> AudioStatsMetrics:
|
13
|
+
def calc_audio_stats(audio: AudioT, win_len: float | None = None) -> AudioStatsMetrics:
|
14
14
|
from sonusai.mixture import SAMPLE_RATE
|
15
15
|
from sonusai.mixture import Transformer
|
16
16
|
|
17
|
-
args = [
|
17
|
+
args = ["stats"]
|
18
18
|
if win_len is not None:
|
19
|
-
args.extend([
|
19
|
+
args.extend(["-w", str(win_len)])
|
20
20
|
|
21
21
|
tfm = Transformer()
|
22
22
|
|
23
|
-
_, _, out = tfm.build(
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
23
|
+
_, _, out = tfm.build(
|
24
|
+
input_array=audio,
|
25
|
+
sample_rate_in=SAMPLE_RATE,
|
26
|
+
output_filepath="-n",
|
27
|
+
extra_args=args,
|
28
|
+
return_output=True,
|
29
|
+
)
|
30
|
+
|
31
|
+
if out is None:
|
32
|
+
raise SystemError("Call to sox failed")
|
28
33
|
|
29
34
|
stats = {}
|
30
|
-
lines = out.split(
|
35
|
+
lines = out.split("\n")
|
31
36
|
for line in lines:
|
32
37
|
split_line = line.split()
|
33
38
|
if len(split_line) == 0:
|
34
39
|
continue
|
35
40
|
value = split_line[-1]
|
36
|
-
key =
|
41
|
+
key = " ".join(split_line[:-1])
|
37
42
|
stats[key] = value
|
38
43
|
|
39
44
|
return AudioStatsMetrics(
|
40
|
-
dco=float(stats[
|
41
|
-
min=float(stats[
|
42
|
-
max=float(stats[
|
43
|
-
pkdb=float(stats[
|
44
|
-
lrms=float(stats[
|
45
|
-
pkr=float(stats[
|
46
|
-
tr=float(stats[
|
47
|
-
cr=float(stats[
|
48
|
-
fl=float(stats[
|
49
|
-
pkc=_convert_str_with_factors_to_int(stats[
|
45
|
+
dco=float(stats["DC offset"]),
|
46
|
+
min=float(stats["Min level"]),
|
47
|
+
max=float(stats["Max level"]),
|
48
|
+
pkdb=float(stats["Pk lev dB"]),
|
49
|
+
lrms=float(stats["RMS lev dB"]),
|
50
|
+
pkr=float(stats["RMS Pk dB"]),
|
51
|
+
tr=float(stats["RMS Tr dB"]),
|
52
|
+
cr=float(stats["Crest factor"]),
|
53
|
+
fl=float(stats["Flat factor"]),
|
54
|
+
pkc=_convert_str_with_factors_to_int(stats["Pk count"]),
|
50
55
|
)
|
sonusai/metrics/calc_class_weights.py
CHANGED
@@ -5,9 +5,7 @@ from sonusai.mixture.datatypes import Truth
|
|
5
5
|
from sonusai.mixture.mixdb import MixtureDatabase
|
6
6
|
|
7
7
|
|
8
|
-
def calc_class_weights_from_truth(truth: Truth,
|
9
|
-
other_weight: float = None,
|
10
|
-
other_index: int = -1) -> np.ndarray:
|
8
|
+
def calc_class_weights_from_truth(truth: Truth, other_weight: float | None = None, other_index: int = -1) -> np.ndarray:
|
11
9
|
"""Calculate class weights.
|
12
10
|
|
13
11
|
Supports non-existent classes (a problem with sklearn) where non-existent
|
@@ -54,10 +52,12 @@ def calc_class_weights_from_truth(truth: Truth,
|
|
54
52
|
return weights
|
55
53
|
|
56
54
|
|
57
|
-
def calc_class_weights_from_mixdb(
|
58
|
-
|
59
|
-
|
60
|
-
|
55
|
+
def calc_class_weights_from_mixdb(
|
56
|
+
mixdb: MixtureDatabase,
|
57
|
+
mixids: GeneralizedIDs | None = None,
|
58
|
+
other_weight: float = 1,
|
59
|
+
other_index: int = -1,
|
60
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
61
61
|
"""Calculate class weights using estimated feature counts from a mixture database.
|
62
62
|
|
63
63
|
Arguments:
|
sonusai/metrics/calc_optimal_thresholds.py
CHANGED
@@ -4,10 +4,9 @@ from sonusai.mixture.datatypes import Predict
|
|
4
4
|
from sonusai.mixture.datatypes import Truth
|
5
5
|
|
6
6
|
|
7
|
-
def calc_optimal_thresholds(
|
8
|
-
|
9
|
-
|
10
|
-
truth_thr: float = 0.5) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
7
|
+
def calc_optimal_thresholds(
|
8
|
+
truth: Truth, predict: Predict, timesteps: int = 0, truth_thr: float = 0.5
|
9
|
+
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
11
10
|
"""Calculates optimal thresholds for each class from one-hot prediction and truth data where both are
|
12
11
|
one-hot probabilities (or quantized decisions) with size [frames, num_classes] or [frames, timesteps, num_classes].
|
13
12
|
|
@@ -25,14 +24,13 @@ def calc_optimal_thresholds(truth: Truth,
|
|
25
24
|
from sklearn.metrics import roc_auc_score
|
26
25
|
from sklearn.metrics import roc_curve
|
27
26
|
|
28
|
-
from sonusai import SonusAIError
|
29
27
|
from sonusai.utils import get_num_classes_from_predict
|
30
28
|
from sonusai.utils import reshape_outputs
|
31
29
|
|
32
30
|
if truth.shape != predict.shape:
|
33
|
-
raise
|
31
|
+
raise ValueError("truth and predict are not the same shape")
|
34
32
|
|
35
|
-
predict, truth = reshape_outputs(predict=predict, truth=truth, timesteps=timesteps)
|
33
|
+
predict, truth = reshape_outputs(predict=predict, truth=truth, timesteps=timesteps) # type: ignore[assignment]
|
36
34
|
num_classes = get_num_classes_from_predict(predict=predict, timesteps=timesteps)
|
37
35
|
|
38
36
|
# Apply decision to truth input
|
sonusai/metrics/calc_pcm.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1
1
|
import numpy as np
|
2
2
|
|
3
3
|
|
4
|
-
def calc_pcm(
|
5
|
-
|
6
|
-
|
4
|
+
def calc_pcm(
|
5
|
+
hypothesis: np.ndarray, reference: np.ndarray, with_log: bool = False
|
6
|
+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
7
7
|
"""Calculate phase constrained magnitude error
|
8
8
|
|
9
9
|
These must include a noise to make a complete mixture estimate, i.e.,
|
sonusai/metrics/calc_pesq.py
CHANGED
@@ -3,10 +3,12 @@ import numpy as np
|
|
3
3
|
from sonusai.mixture.constants import SAMPLE_RATE
|
4
4
|
|
5
5
|
|
6
|
-
def calc_pesq(
|
7
|
-
|
8
|
-
|
9
|
-
|
6
|
+
def calc_pesq(
|
7
|
+
hypothesis: np.ndarray,
|
8
|
+
reference: np.ndarray,
|
9
|
+
error_value: float = 0.0,
|
10
|
+
sample_rate: int = SAMPLE_RATE,
|
11
|
+
) -> float:
|
10
12
|
"""Computes the PESQ score of hypothesis vs. reference
|
11
13
|
|
12
14
|
Upon error, assigns a value of 0, or user specified value in error_value
|
@@ -22,12 +24,13 @@ def calc_pesq(hypothesis: np.ndarray,
|
|
22
24
|
from pesq import pesq
|
23
25
|
|
24
26
|
from sonusai import logger
|
27
|
+
|
25
28
|
try:
|
26
29
|
with warnings.catch_warnings():
|
27
|
-
warnings.simplefilter(
|
28
|
-
score = pesq(fs=sample_rate, ref=reference, deg=hypothesis, mode=
|
30
|
+
warnings.simplefilter("ignore")
|
31
|
+
score = pesq(fs=sample_rate, ref=reference, deg=hypothesis, mode="wb")
|
29
32
|
except Exception as e:
|
30
|
-
logger.debug(f
|
33
|
+
logger.debug(f"PESQ error {e}")
|
31
34
|
score = error_value
|
32
35
|
|
33
36
|
return score
|
sonusai/metrics/calc_phase_distance.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1
1
|
import numpy as np
|
2
2
|
|
3
3
|
|
4
|
-
def calc_phase_distance(
|
5
|
-
|
6
|
-
|
4
|
+
def calc_phase_distance(
|
5
|
+
reference: np.ndarray, hypothesis: np.ndarray, eps: float = 1e-9
|
6
|
+
) -> tuple[float, np.ndarray, np.ndarray]:
|
7
7
|
"""Calculate weighted phase distance error (weight normalization over bins per frame)
|
8
8
|
|
9
9
|
:param reference: complex [frames, bins]
|
sonusai/metrics/calc_sa_sdr.py
CHANGED
@@ -1,10 +1,12 @@
|
|
1
1
|
import numpy as np
|
2
2
|
|
3
3
|
|
4
|
-
def calc_sa_sdr(
|
5
|
-
|
6
|
-
|
7
|
-
|
4
|
+
def calc_sa_sdr(
|
5
|
+
hypothesis: np.ndarray,
|
6
|
+
reference: np.ndarray,
|
7
|
+
with_scale: bool = False,
|
8
|
+
with_negate: bool = False,
|
9
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
8
10
|
"""Calculate source-aggregated SDR (signal distortion ratio) using all source inputs which are [samples, nsrc].
|
9
11
|
|
10
12
|
These should include a noise to be a complete mixture estimate, i.e.,
|
@@ -30,9 +32,9 @@ def calc_sa_sdr(hypothesis: np.ndarray,
|
|
30
32
|
"""
|
31
33
|
if with_scale:
|
32
34
|
# calc 1 x nsrc scaling factors
|
33
|
-
ref_energy = np.sum(reference
|
35
|
+
ref_energy = np.sum(reference**2, axis=0, keepdims=True)
|
34
36
|
# if ref_energy is zero, just set scaling to 1.0
|
35
|
-
with np.errstate(divide=
|
37
|
+
with np.errstate(divide="ignore", invalid="ignore"):
|
36
38
|
opt_scale = np.sum(reference * hypothesis, axis=0, keepdims=True) / ref_energy
|
37
39
|
opt_scale[opt_scale == np.inf] = 1.0
|
38
40
|
opt_scale = np.nan_to_num(opt_scale, nan=1.0)
|
@@ -46,8 +48,8 @@ def calc_sa_sdr(hypothesis: np.ndarray,
|
|
46
48
|
|
47
49
|
# -10*log10(sumk(||sk||^2) / sumk(||sk - shk||^2)
|
48
50
|
# sum over samples and sources
|
49
|
-
num = np.sum(reference
|
50
|
-
den = np.sum(err
|
51
|
+
num = np.sum(reference**2)
|
52
|
+
den = np.sum(err**2)
|
51
53
|
if num == 0 and den == 0:
|
52
54
|
ratio = np.inf
|
53
55
|
else:
|
sonusai/metrics/calc_segsnr_f.py
CHANGED
@@ -33,10 +33,7 @@ def calc_segsnr_f(segsnr_f: Segsnr) -> SnrFMetrics:
|
|
33
33
|
snr_db_mean = np.mean(tmp, axis=0)
|
34
34
|
snr_db_std = np.std(tmp, axis=0)
|
35
35
|
|
36
|
-
return SnrFMetrics(snr_mean,
|
37
|
-
snr_std,
|
38
|
-
snr_db_mean,
|
39
|
-
snr_db_std)
|
36
|
+
return SnrFMetrics(snr_mean, snr_std, snr_db_mean, snr_db_std)
|
40
37
|
|
41
38
|
|
42
39
|
def calc_segsnr_f_bin(target_f: AudioF, noise_f: AudioF) -> SnrFBinMetrics:
|
@@ -46,25 +43,24 @@ def calc_segsnr_f_bin(target_f: AudioF, noise_f: AudioF) -> SnrFBinMetrics:
|
|
46
43
|
and mean and standard deviation of the dB values.
|
47
44
|
"""
|
48
45
|
if target_f.ndim != 2 and noise_f.ndim != 2:
|
49
|
-
raise ValueError(
|
46
|
+
raise ValueError("target_f and noise_f must have 2 dimensions")
|
50
47
|
|
51
48
|
segsnr_f = (np.abs(target_f) ** 2) / (np.abs(noise_f) ** 2)
|
52
49
|
|
53
50
|
frames, bins = segsnr_f.shape
|
54
51
|
if np.count_nonzero(segsnr_f) == 0:
|
55
52
|
# If all entries are zeros
|
56
|
-
return SnrFBinMetrics(np.zeros(bins),
|
57
|
-
np.zeros(bins),
|
58
|
-
-np.inf * np.ones(bins),
|
59
|
-
np.zeros(bins))
|
53
|
+
return SnrFBinMetrics(np.zeros(bins), np.zeros(bins), -np.inf * np.ones(bins), np.zeros(bins))
|
60
54
|
|
61
55
|
tmp = np.ma.array(segsnr_f, mask=np.logical_not(np.isfinite(segsnr_f)))
|
62
56
|
if np.ma.count_masked(tmp) == np.ma.size(tmp, axis=0):
|
63
57
|
# If all entries are infinite
|
64
|
-
return SnrFBinMetrics(
|
65
|
-
|
66
|
-
|
67
|
-
|
58
|
+
return SnrFBinMetrics(
|
59
|
+
np.inf * np.ones(bins),
|
60
|
+
np.zeros(bins),
|
61
|
+
np.inf * np.ones(bins),
|
62
|
+
np.zeros(bins),
|
63
|
+
)
|
68
64
|
|
69
65
|
snr_mean = np.mean(tmp, axis=0)
|
70
66
|
snr_std = np.std(tmp, axis=0)
|
@@ -78,7 +74,9 @@ def calc_segsnr_f_bin(target_f: AudioF, noise_f: AudioF) -> SnrFBinMetrics:
|
|
78
74
|
snr_db_mean = np.mean(tmp, axis=0)
|
79
75
|
snr_db_std = np.std(tmp, axis=0)
|
80
76
|
|
81
|
-
return SnrFBinMetrics(
|
82
|
-
|
83
|
-
|
84
|
-
|
77
|
+
return SnrFBinMetrics(
|
78
|
+
np.ma.getdata(snr_mean),
|
79
|
+
np.ma.getdata(snr_std),
|
80
|
+
np.ma.getdata(snr_db_mean),
|
81
|
+
np.ma.getdata(snr_db_std),
|
82
|
+
)
|