sonusai 0.20.2__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +16 -3
- sonusai/audiofe.py +240 -76
- sonusai/calc_metric_spenh.py +71 -73
- sonusai/config/__init__.py +3 -0
- sonusai/config/config.py +61 -0
- sonusai/config/config.yml +20 -0
- sonusai/config/constants.py +8 -0
- sonusai/constants.py +11 -0
- sonusai/data/genmixdb.yml +21 -36
- sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
- sonusai/deprecated/plot.py +4 -5
- sonusai/doc/doc.py +4 -4
- sonusai/doc.py +11 -4
- sonusai/genft.py +43 -45
- sonusai/genmetrics.py +23 -19
- sonusai/genmix.py +54 -82
- sonusai/genmixdb.py +88 -264
- sonusai/ir_metric.py +30 -34
- sonusai/lsdb.py +41 -48
- sonusai/main.py +15 -22
- sonusai/metrics/calc_audio_stats.py +4 -17
- sonusai/metrics/calc_class_weights.py +4 -4
- sonusai/metrics/calc_optimal_thresholds.py +8 -5
- sonusai/metrics/calc_pesq.py +2 -2
- sonusai/metrics/calc_segsnr_f.py +4 -4
- sonusai/metrics/calc_speech.py +25 -13
- sonusai/metrics/class_summary.py +7 -7
- sonusai/metrics/confusion_matrix_summary.py +5 -5
- sonusai/metrics/one_hot.py +4 -4
- sonusai/metrics/snr_summary.py +7 -7
- sonusai/metrics_summary.py +38 -45
- sonusai/mixture/__init__.py +5 -104
- sonusai/mixture/audio.py +10 -39
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/config.py +251 -271
- sonusai/mixture/constants.py +35 -39
- sonusai/mixture/data_io.py +25 -36
- sonusai/mixture/db_datatypes.py +58 -22
- sonusai/mixture/effects.py +386 -0
- sonusai/mixture/feature.py +7 -11
- sonusai/mixture/generation.py +484 -611
- sonusai/mixture/helpers.py +82 -184
- sonusai/mixture/ir_delay.py +3 -4
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +6 -12
- sonusai/mixture/mixdb.py +931 -669
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +2 -2
- sonusai/mixture/truth.py +17 -15
- sonusai/mixture/truth_functions/crm.py +12 -12
- sonusai/mixture/truth_functions/energy.py +22 -22
- sonusai/mixture/truth_functions/file.py +5 -5
- sonusai/mixture/truth_functions/metadata.py +4 -4
- sonusai/mixture/truth_functions/metrics.py +4 -4
- sonusai/mixture/truth_functions/phoneme.py +3 -3
- sonusai/mixture/truth_functions/sed.py +11 -13
- sonusai/mixture/truth_functions/target.py +10 -10
- sonusai/mkwav.py +26 -29
- sonusai/onnx_predict.py +240 -88
- sonusai/queries/__init__.py +2 -2
- sonusai/queries/queries.py +38 -34
- sonusai/speech/librispeech.py +1 -1
- sonusai/speech/mcgill.py +1 -1
- sonusai/speech/timit.py +2 -2
- sonusai/summarize_metric_spenh.py +10 -17
- sonusai/utils/__init__.py +7 -1
- sonusai/utils/asl_p56.py +2 -2
- sonusai/utils/asr.py +2 -2
- sonusai/utils/asr_functions/aaware_whisper.py +4 -5
- sonusai/utils/choice.py +31 -0
- sonusai/utils/compress.py +1 -1
- sonusai/utils/dataclass_from_dict.py +19 -1
- sonusai/utils/energy_f.py +3 -3
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/onnx_utils.py +3 -17
- sonusai/utils/print_mixture_details.py +21 -19
- sonusai/utils/{temp_seed.py → rand.py} +3 -3
- sonusai/utils/read_predict_data.py +2 -2
- sonusai/utils/reshape.py +3 -3
- sonusai/utils/stratified_shuffle_split.py +3 -3
- sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
- sonusai/utils/write_audio.py +2 -2
- sonusai/vars.py +11 -4
- {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/METADATA +4 -2
- sonusai-1.0.1.dist-info/RECORD +138 -0
- sonusai/mixture/augmentation.py +0 -444
- sonusai/mixture/class_count.py +0 -15
- sonusai/mixture/eq_rule_is_valid.py +0 -45
- sonusai/mixture/target_class_balancing.py +0 -107
- sonusai/mixture/targets.py +0 -175
- sonusai-0.20.2.dist-info/RECORD +0 -128
- {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/WHEEL +0 -0
- {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/entry_points.txt +0 -0
sonusai/lsdb.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1
1
|
"""sonusai lsdb
|
2
2
|
|
3
|
-
usage: lsdb [-
|
3
|
+
usage: lsdb [-hsa] [-i MIXID] [-c CID] LOC
|
4
4
|
|
5
5
|
Options:
|
6
6
|
-h, --help
|
7
7
|
-i MIXID, --mixid MIXID Mixture ID(s) to analyze. [default: *].
|
8
8
|
-c CID, --class_index CID Analyze mixtures that contain this class index.
|
9
|
-
-
|
9
|
+
-s, --sources List all source files.
|
10
10
|
-a, --all_class_counts List all class counts.
|
11
11
|
|
12
12
|
List mixture data information from a SonusAI mixture database.
|
@@ -16,25 +16,10 @@ Inputs:
|
|
16
16
|
|
17
17
|
"""
|
18
18
|
|
19
|
-
import
|
20
|
-
|
21
|
-
from sonusai import logger
|
22
|
-
from sonusai.mixture import GeneralizedIDs
|
19
|
+
from sonusai.datatypes import GeneralizedIDs
|
23
20
|
from sonusai.mixture import MixtureDatabase
|
24
21
|
|
25
22
|
|
26
|
-
def signal_handler(_sig, _frame):
|
27
|
-
import sys
|
28
|
-
|
29
|
-
from sonusai import logger
|
30
|
-
|
31
|
-
logger.info("Canceled due to keyboard interrupt")
|
32
|
-
sys.exit(1)
|
33
|
-
|
34
|
-
|
35
|
-
signal.signal(signal.SIGINT, signal_handler)
|
36
|
-
|
37
|
-
|
38
23
|
def lsdb(
|
39
24
|
mixdb: MixtureDatabase,
|
40
25
|
mixids: GeneralizedIDs = "*",
|
@@ -42,7 +27,8 @@ def lsdb(
|
|
42
27
|
list_targets: bool = False,
|
43
28
|
all_class_counts: bool = False,
|
44
29
|
) -> None:
|
45
|
-
from sonusai
|
30
|
+
from sonusai import logger
|
31
|
+
from sonusai.constants import SAMPLE_RATE
|
46
32
|
from sonusai.queries import get_mixids_from_class_indices
|
47
33
|
from sonusai.utils import consolidate_range
|
48
34
|
from sonusai.utils import max_text_width
|
@@ -54,38 +40,36 @@ def lsdb(
|
|
54
40
|
total_samples = mixdb.total_samples()
|
55
41
|
total_duration = total_samples / SAMPLE_RATE
|
56
42
|
|
57
|
-
logger.info(f'
|
58
|
-
logger.info(f'
|
59
|
-
logger.info(f'
|
60
|
-
logger.info(f'
|
61
|
-
logger.info(f'{"Feature":{desc_len}} {mixdb.feature}')
|
43
|
+
logger.info(f"{'Mixtures':{desc_len}} {mixdb.num_mixtures}")
|
44
|
+
logger.info(f"{'Duration':{desc_len}} {seconds_to_hms(seconds=total_duration)}")
|
45
|
+
logger.info(f"{'Sources':{desc_len}} {mixdb.num_source_files}")
|
46
|
+
logger.info(f"{'Feature':{desc_len}} {mixdb.feature}")
|
62
47
|
logger.info(
|
63
|
-
f'
|
64
|
-
f
|
48
|
+
f"{'Feature shape':{desc_len}} {mixdb.fg_stride} x {mixdb.feature_parameters} "
|
49
|
+
f"({mixdb.fg_stride * mixdb.feature_parameters} total params)"
|
65
50
|
)
|
66
|
-
logger.info(f'
|
51
|
+
logger.info(f"{'Feature samples':{desc_len}} {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)")
|
67
52
|
logger.info(
|
68
|
-
f'
|
53
|
+
f"{'Feature step samples':{desc_len}} {mixdb.feature_step_samples} samples ({mixdb.feature_step_ms} ms)"
|
69
54
|
)
|
70
|
-
logger.info(f'
|
71
|
-
logger.info(f'
|
72
|
-
logger.info(f'
|
73
|
-
logger.info(f'
|
55
|
+
logger.info(f"{'Feature overlap':{desc_len}} {mixdb.fg_step / mixdb.fg_stride} ({mixdb.feature_step_ms} ms)")
|
56
|
+
logger.info(f"{'SNRs':{desc_len}} {mixdb.snrs}")
|
57
|
+
logger.info(f"{'Random SNRs':{desc_len}} {mixdb.random_snrs}")
|
58
|
+
logger.info(f"{'Classes':{desc_len}} {mixdb.num_classes}")
|
74
59
|
# TODO: fix class count
|
75
|
-
logger.info(f'
|
60
|
+
logger.info(f"{'Class count':{desc_len}} not supported")
|
76
61
|
# print_class_count(class_count=class_count, length=desc_len, print_fn=logger.info)
|
77
62
|
# TODO: add class weight calculations here
|
78
63
|
logger.info("")
|
79
64
|
|
80
65
|
if list_targets:
|
81
|
-
logger.info("
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
logger.info(
|
88
|
-
logger.info("")
|
66
|
+
logger.info("Source details:")
|
67
|
+
for category, sources in mixdb.source_files.items():
|
68
|
+
print(f" {category}:")
|
69
|
+
for source in sources:
|
70
|
+
logger.info(f"{' Name':{desc_len}} {source.name}")
|
71
|
+
logger.info(f"{' Truth index':{desc_len}} {source.class_indices}")
|
72
|
+
logger.info("")
|
89
73
|
|
90
74
|
if class_index is not None:
|
91
75
|
if 0 <= class_index > mixdb.num_classes:
|
@@ -104,7 +88,7 @@ def lsdb(
|
|
104
88
|
# print_class_count(class_count=class_count, length=desc_len, print_fn=logger.info, all_class_counts=True)
|
105
89
|
else:
|
106
90
|
logger.info(
|
107
|
-
f"Calculating statistics from truth_f files for {len(mixids):,} mixtures
|
91
|
+
f"Calculating statistics from truth_f files for {len(mixids):,} mixtures ({consolidate_range(mixids)})"
|
108
92
|
)
|
109
93
|
logger.info("Not supported")
|
110
94
|
|
@@ -112,13 +96,10 @@ def lsdb(
|
|
112
96
|
def main() -> None:
|
113
97
|
from docopt import docopt
|
114
98
|
|
115
|
-
import
|
116
|
-
from sonusai import create_file_handler
|
117
|
-
from sonusai import initial_log_messages
|
118
|
-
from sonusai import update_console_handler
|
99
|
+
from sonusai import __version__ as sai_version
|
119
100
|
from sonusai.utils import trim_docstring
|
120
101
|
|
121
|
-
args = docopt(trim_docstring(__doc__), version=
|
102
|
+
args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)
|
122
103
|
|
123
104
|
mixid = args["--mixid"]
|
124
105
|
class_index = args["--class_index"]
|
@@ -126,6 +107,11 @@ def main() -> None:
|
|
126
107
|
all_class_counts = args["--all_class_counts"]
|
127
108
|
location = args["LOC"]
|
128
109
|
|
110
|
+
from sonusai import create_file_handler
|
111
|
+
from sonusai import initial_log_messages
|
112
|
+
from sonusai import logger
|
113
|
+
from sonusai import update_console_handler
|
114
|
+
|
129
115
|
if class_index is not None:
|
130
116
|
class_index = int(class_index)
|
131
117
|
|
@@ -146,4 +132,11 @@ def main() -> None:
|
|
146
132
|
|
147
133
|
|
148
134
|
if __name__ == "__main__":
|
149
|
-
|
135
|
+
from sonusai import exception_handler
|
136
|
+
from sonusai.utils import register_keyboard_interrupt
|
137
|
+
|
138
|
+
register_keyboard_interrupt()
|
139
|
+
try:
|
140
|
+
main()
|
141
|
+
except Exception as e:
|
142
|
+
exception_handler(e)
|
sonusai/main.py
CHANGED
@@ -10,21 +10,6 @@ for more information on a specific command.
|
|
10
10
|
|
11
11
|
"""
|
12
12
|
|
13
|
-
import signal
|
14
|
-
|
15
|
-
|
16
|
-
def signal_handler(_sig, _frame):
|
17
|
-
import sys
|
18
|
-
|
19
|
-
from sonusai import logger
|
20
|
-
|
21
|
-
logger.info("Canceled due to keyboard interrupt")
|
22
|
-
sys.exit(1)
|
23
|
-
|
24
|
-
|
25
|
-
signal.signal(signal.SIGINT, signal_handler)
|
26
|
-
|
27
|
-
|
28
13
|
def main() -> None:
|
29
14
|
from importlib import import_module
|
30
15
|
from pkgutil import iter_modules
|
@@ -44,13 +29,13 @@ def main() -> None:
|
|
44
29
|
|
45
30
|
from docopt import docopt
|
46
31
|
|
47
|
-
from sonusai import __version__
|
32
|
+
from sonusai import __version__ as sai_version
|
48
33
|
from sonusai.utils import add_commands_to_docstring
|
49
34
|
from sonusai.utils import trim_docstring
|
50
35
|
|
51
36
|
args = docopt(
|
52
37
|
trim_docstring(add_commands_to_docstring(__doc__, plugin_docstrings)),
|
53
|
-
version=
|
38
|
+
version=sai_version,
|
54
39
|
options_first=True,
|
55
40
|
)
|
56
41
|
|
@@ -61,15 +46,16 @@ def main() -> None:
|
|
61
46
|
from os.path import join
|
62
47
|
from subprocess import call
|
63
48
|
|
64
|
-
import
|
49
|
+
from sonusai import BASEDIR
|
50
|
+
from sonusai import commands_list
|
65
51
|
from sonusai import logger
|
66
52
|
|
67
|
-
base_commands =
|
53
|
+
base_commands = commands_list()
|
68
54
|
if command == "help":
|
69
55
|
if not argv:
|
70
56
|
exit(call(["sonusai", "-h"])) # noqa: S603, S607
|
71
57
|
elif argv[0] in base_commands:
|
72
|
-
exit(call(["python", f"{join(
|
58
|
+
exit(call(["python", f"{join(BASEDIR, argv[0])}.py", "-h"])) # noqa: S603, S607
|
73
59
|
|
74
60
|
for data in plugins.values():
|
75
61
|
if argv[0] in data["commands"]:
|
@@ -79,7 +65,7 @@ def main() -> None:
|
|
79
65
|
sys.exit(1)
|
80
66
|
|
81
67
|
if command in base_commands:
|
82
|
-
exit(call(["python", f"{join(
|
68
|
+
exit(call(["python", f"{join(BASEDIR, command)}.py", *argv])) # noqa: S603, S607
|
83
69
|
|
84
70
|
for data in plugins.values():
|
85
71
|
if command in data["commands"]:
|
@@ -90,4 +76,11 @@ def main() -> None:
|
|
90
76
|
|
91
77
|
|
92
78
|
if __name__ == "__main__":
|
93
|
-
|
79
|
+
from sonusai import exception_handler
|
80
|
+
from sonusai.utils import register_keyboard_interrupt
|
81
|
+
|
82
|
+
register_keyboard_interrupt()
|
83
|
+
try:
|
84
|
+
main()
|
85
|
+
except Exception as e:
|
86
|
+
exception_handler(e)
|
@@ -1,5 +1,5 @@
|
|
1
|
-
from
|
2
|
-
from
|
1
|
+
from ..datatypes import AudioStatsMetrics
|
2
|
+
from ..datatypes import AudioT
|
3
3
|
|
4
4
|
|
5
5
|
def _convert_str_with_factors_to_int(x: str) -> int:
|
@@ -11,22 +11,9 @@ def _convert_str_with_factors_to_int(x: str) -> int:
|
|
11
11
|
|
12
12
|
|
13
13
|
def calc_audio_stats(audio: AudioT, win_len: float | None = None) -> AudioStatsMetrics:
|
14
|
-
from
|
15
|
-
from sonusai.mixture import Transformer
|
14
|
+
from ..mixture.sox_effects import sox_stats
|
16
15
|
|
17
|
-
|
18
|
-
if win_len is not None:
|
19
|
-
args.extend(["-w", str(win_len)])
|
20
|
-
|
21
|
-
tfm = Transformer()
|
22
|
-
|
23
|
-
_, _, out = tfm.build(
|
24
|
-
input_array=audio,
|
25
|
-
sample_rate_in=SAMPLE_RATE,
|
26
|
-
output_filepath="-n",
|
27
|
-
extra_args=args,
|
28
|
-
return_output=True,
|
29
|
-
)
|
16
|
+
out = sox_stats(audio, win_len)
|
30
17
|
|
31
18
|
if out is None:
|
32
19
|
raise SystemError("Call to sox failed")
|
@@ -1,8 +1,8 @@
|
|
1
1
|
import numpy as np
|
2
2
|
|
3
|
-
from
|
4
|
-
from
|
5
|
-
from
|
3
|
+
from ..datatypes import GeneralizedIDs
|
4
|
+
from ..datatypes import Truth
|
5
|
+
from ..mixture.mixdb import MixtureDatabase
|
6
6
|
|
7
7
|
|
8
8
|
def calc_class_weights_from_truth(truth: Truth, other_weight: float | None = None, other_index: int = -1) -> np.ndarray:
|
@@ -74,7 +74,7 @@ def calc_class_weights_from_mixdb(
|
|
74
74
|
weights: Class weights. [num_classes, 1]
|
75
75
|
Note: for Keras use dict(enumerate(weights))
|
76
76
|
"""
|
77
|
-
from
|
77
|
+
from ..mixture.class_count import get_class_count_from_mixids
|
78
78
|
|
79
79
|
count = np.ceil(np.array(get_class_count_from_mixids(mixdb=mixdb, mixids=mixids)) / mixdb.feature_step_samples)
|
80
80
|
total_features = sum(count)
|
@@ -1,11 +1,14 @@
|
|
1
1
|
import numpy as np
|
2
2
|
|
3
|
-
from
|
4
|
-
from
|
3
|
+
from ..datatypes import Predict
|
4
|
+
from ..datatypes import Truth
|
5
5
|
|
6
6
|
|
7
7
|
def calc_optimal_thresholds(
|
8
|
-
truth: Truth,
|
8
|
+
truth: Truth,
|
9
|
+
predict: Predict,
|
10
|
+
timesteps: int = 0,
|
11
|
+
truth_thr: float = 0.5,
|
9
12
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
10
13
|
"""Calculates optimal thresholds for each class from one-hot prediction and truth data where both are
|
11
14
|
one-hot probabilities (or quantized decisions) with size [frames, num_classes] or [frames, timesteps, num_classes].
|
@@ -24,8 +27,8 @@ def calc_optimal_thresholds(
|
|
24
27
|
from sklearn.metrics import roc_auc_score
|
25
28
|
from sklearn.metrics import roc_curve
|
26
29
|
|
27
|
-
from
|
28
|
-
from
|
30
|
+
from ..utils.reshape import get_num_classes_from_predict
|
31
|
+
from ..utils.reshape import reshape_outputs
|
29
32
|
|
30
33
|
if truth.shape != predict.shape:
|
31
34
|
raise ValueError("truth and predict are not the same shape")
|
sonusai/metrics/calc_pesq.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
import numpy as np
|
2
2
|
|
3
|
-
from
|
3
|
+
from ..constants import SAMPLE_RATE
|
4
4
|
|
5
5
|
|
6
6
|
def calc_pesq(
|
@@ -23,7 +23,7 @@ def calc_pesq(
|
|
23
23
|
|
24
24
|
from pesq import pesq
|
25
25
|
|
26
|
-
from
|
26
|
+
from .. import logger
|
27
27
|
|
28
28
|
try:
|
29
29
|
with warnings.catch_warnings():
|
sonusai/metrics/calc_segsnr_f.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1
1
|
import numpy as np
|
2
2
|
|
3
|
-
from
|
4
|
-
from
|
5
|
-
from
|
6
|
-
from
|
3
|
+
from ..datatypes import AudioF
|
4
|
+
from ..datatypes import Segsnr
|
5
|
+
from ..datatypes import SnrFBinMetrics
|
6
|
+
from ..datatypes import SnrFMetrics
|
7
7
|
|
8
8
|
|
9
9
|
def calc_segsnr_f(segsnr_f: Segsnr) -> SnrFMetrics:
|
sonusai/metrics/calc_speech.py
CHANGED
@@ -1,18 +1,23 @@
|
|
1
1
|
import numpy as np
|
2
2
|
|
3
|
-
from
|
4
|
-
from
|
5
|
-
|
3
|
+
from ..constants import SAMPLE_RATE
|
4
|
+
from ..datatypes import SpeechMetrics
|
6
5
|
from .calc_pesq import calc_pesq
|
7
6
|
|
8
7
|
|
9
|
-
def calc_speech(
|
10
|
-
|
8
|
+
def calc_speech(
|
9
|
+
hypothesis: np.ndarray,
|
10
|
+
reference: np.ndarray,
|
11
|
+
pesq: float | None = None,
|
12
|
+
sample_rate: int = SAMPLE_RATE,
|
13
|
+
) -> SpeechMetrics:
|
14
|
+
"""Calculate speech metrics c_sig, c_bak, and c_ovl.
|
11
15
|
|
12
16
|
These are all related and thus included in one function. Reference: matlab script "compute_metrics.m".
|
13
17
|
|
14
18
|
:param hypothesis: estimated audio
|
15
19
|
:param reference: reference audio
|
20
|
+
:param pesq: pesq
|
16
21
|
:param sample_rate: sample rate of audio
|
17
22
|
:return: SpeechMetrics named tuple
|
18
23
|
"""
|
@@ -36,18 +41,21 @@ def calc_speech(hypothesis: np.ndarray, reference: np.ndarray, sample_rate: int
|
|
36
41
|
seg_snr = np.mean(segsnr_dist)
|
37
42
|
|
38
43
|
# PESQ
|
39
|
-
|
44
|
+
if pesq is None:
|
45
|
+
pesq = calc_pesq(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
|
40
46
|
|
41
47
|
# Now compute the composite measures
|
42
|
-
csig = float(np.clip(3.093 - 1.029 * llr_mean + 0.603 *
|
43
|
-
cbak = float(np.clip(1.634 + 0.478 *
|
44
|
-
covl = float(np.clip(1.594 + 0.805 *
|
48
|
+
csig = float(np.clip(3.093 - 1.029 * llr_mean + 0.603 * pesq - 0.009 * wss_dist, 1, 5))
|
49
|
+
cbak = float(np.clip(1.634 + 0.478 * pesq - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5))
|
50
|
+
covl = float(np.clip(1.594 + 0.805 * pesq - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5))
|
45
51
|
|
46
|
-
return SpeechMetrics(
|
52
|
+
return SpeechMetrics(csig, cbak, covl)
|
47
53
|
|
48
54
|
|
49
55
|
def _calc_weighted_spectral_slope_measure(
|
50
|
-
hypothesis: np.ndarray,
|
56
|
+
hypothesis: np.ndarray,
|
57
|
+
reference: np.ndarray,
|
58
|
+
sample_rate: int = SAMPLE_RATE,
|
51
59
|
) -> np.ndarray:
|
52
60
|
from scipy.fftpack import fft
|
53
61
|
|
@@ -250,7 +258,9 @@ def _calc_weighted_spectral_slope_measure(
|
|
250
258
|
|
251
259
|
|
252
260
|
def _calc_log_likelihood_ratio_measure(
|
253
|
-
hypothesis: np.ndarray,
|
261
|
+
hypothesis: np.ndarray,
|
262
|
+
reference: np.ndarray,
|
263
|
+
sample_rate: int = SAMPLE_RATE,
|
254
264
|
) -> np.ndarray:
|
255
265
|
from scipy.linalg import toeplitz
|
256
266
|
|
@@ -296,7 +306,9 @@ def _calc_log_likelihood_ratio_measure(
|
|
296
306
|
|
297
307
|
|
298
308
|
def _calc_snr(
|
299
|
-
hypothesis: np.ndarray,
|
309
|
+
hypothesis: np.ndarray,
|
310
|
+
reference: np.ndarray,
|
311
|
+
sample_rate: int = SAMPLE_RATE,
|
300
312
|
) -> tuple[float, np.ndarray]:
|
301
313
|
# The lengths of the reference and hypothesis must be the same.
|
302
314
|
reference_length = len(reference)
|
sonusai/metrics/class_summary.py
CHANGED
@@ -2,10 +2,10 @@
|
|
2
2
|
import numpy as np
|
3
3
|
import pandas as pd
|
4
4
|
|
5
|
-
from
|
6
|
-
from
|
7
|
-
from
|
8
|
-
from
|
5
|
+
from ..datatypes import GeneralizedIDs
|
6
|
+
from ..datatypes import Predict
|
7
|
+
from ..datatypes import Truth
|
8
|
+
from ..mixture.mixdb import MixtureDatabase
|
9
9
|
|
10
10
|
|
11
11
|
def class_summary(
|
@@ -31,7 +31,7 @@ def class_summary(
|
|
31
31
|
macro avg 0.85 0.83 0.84 0.05 0.96 3768
|
32
32
|
micro-avgwo
|
33
33
|
"""
|
34
|
-
from
|
34
|
+
from ..metrics.one_hot import one_hot
|
35
35
|
|
36
36
|
num_classes = truth_f.shape[1]
|
37
37
|
|
@@ -58,11 +58,11 @@ def class_summary(
|
|
58
58
|
else:
|
59
59
|
row_n = [f"Class {i}" for i in range(1, num_classes + 1)]
|
60
60
|
|
61
|
-
df = pd.DataFrame(metrics[:, table_idx], columns=col_n, index=row_n)
|
61
|
+
df = pd.DataFrame(metrics[:, table_idx], columns=col_n, index=row_n) # pyright: ignore [reportArgumentType]
|
62
62
|
|
63
63
|
# [miPPV, miTPR, miF1, miFPR, miACC, miAP, miAUC, TPSUM]
|
64
64
|
avg_row_n = ["Macro-avg", "Micro-avg", "Weighted-avg"]
|
65
|
-
dfavg = pd.DataFrame(metavg, columns=col_n, index=avg_row_n)
|
65
|
+
dfavg = pd.DataFrame(metavg, columns=col_n, index=avg_row_n) # pyright: ignore [reportArgumentType]
|
66
66
|
|
67
67
|
# dfblank = pd.DataFrame([''])
|
68
68
|
# pd.concat([df, dfblank, dfblank, dfavg])
|
@@ -2,10 +2,10 @@
|
|
2
2
|
import numpy as np
|
3
3
|
import pandas as pd
|
4
4
|
|
5
|
-
from
|
6
|
-
from
|
7
|
-
from
|
8
|
-
from
|
5
|
+
from ..datatypes import GeneralizedIDs
|
6
|
+
from ..datatypes import Predict
|
7
|
+
from ..datatypes import Truth
|
8
|
+
from ..mixture.mixdb import MixtureDatabase
|
9
9
|
|
10
10
|
|
11
11
|
def confusion_matrix_summary(
|
@@ -30,7 +30,7 @@ def confusion_matrix_summary(
|
|
30
30
|
|
31
31
|
Returns pandas dataframes of confusion matrix cmdf and normalized confusion matrix cmndf.
|
32
32
|
"""
|
33
|
-
from
|
33
|
+
from ..metrics.one_hot import one_hot
|
34
34
|
|
35
35
|
num_classes = truth_f.shape[1]
|
36
36
|
# TODO: re-work for modern mixdb API
|
sonusai/metrics/one_hot.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
import numpy as np
|
2
2
|
|
3
|
-
from
|
4
|
-
from
|
3
|
+
from ..datatypes import Predict
|
4
|
+
from ..datatypes import Truth
|
5
5
|
|
6
6
|
|
7
7
|
def one_hot(
|
@@ -53,8 +53,8 @@ def one_hot(
|
|
53
53
|
from sklearn.metrics import precision_recall_fscore_support
|
54
54
|
from sklearn.metrics import roc_auc_score
|
55
55
|
|
56
|
-
from
|
57
|
-
from
|
56
|
+
from ..utils.reshape import get_num_classes_from_predict
|
57
|
+
from ..utils.reshape import reshape_outputs
|
58
58
|
|
59
59
|
if truth.shape != predict.shape:
|
60
60
|
raise ValueError("truth and predict are not the same shape")
|
sonusai/metrics/snr_summary.py
CHANGED
@@ -2,11 +2,11 @@
|
|
2
2
|
import numpy as np
|
3
3
|
import pandas as pd
|
4
4
|
|
5
|
-
from
|
6
|
-
from
|
7
|
-
from
|
8
|
-
from
|
9
|
-
from
|
5
|
+
from ..datatypes import GeneralizedIDs
|
6
|
+
from ..datatypes import Predict
|
7
|
+
from ..datatypes import Segsnr
|
8
|
+
from ..datatypes import Truth
|
9
|
+
from ..mixture.mixdb import MixtureDatabase
|
10
10
|
|
11
11
|
|
12
12
|
def snr_summary(
|
@@ -40,8 +40,8 @@ def snr_summary(
|
|
40
40
|
"""
|
41
41
|
import warnings
|
42
42
|
|
43
|
-
from
|
44
|
-
from
|
43
|
+
from ..metrics.one_hot import one_hot
|
44
|
+
from ..queries.queries import get_mixids_from_snr
|
45
45
|
|
46
46
|
num_classes = truth_f.shape[1]
|
47
47
|
|