sonusai 0.18.2__py3-none-any.whl → 0.18.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +1 -0
- sonusai/audiofe.py +1 -1
- sonusai/calc_metric_spenh.py +32 -362
- sonusai/data/genmixdb.yml +2 -0
- sonusai/doc/doc.py +45 -4
- sonusai/genmetrics.py +137 -109
- sonusai/lsdb.py +2 -2
- sonusai/metrics/__init__.py +4 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_pesq.py +12 -8
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_snr_f.py +34 -0
- sonusai/metrics/calc_speech.py +312 -0
- sonusai/metrics/calc_wer.py +2 -3
- sonusai/metrics/calc_wsdr.py +0 -59
- sonusai/mixture/__init__.py +3 -2
- sonusai/mixture/audio.py +6 -5
- sonusai/mixture/config.py +13 -0
- sonusai/mixture/constants.py +1 -0
- sonusai/mixture/datatypes.py +33 -0
- sonusai/mixture/generation.py +6 -2
- sonusai/mixture/mixdb.py +261 -122
- sonusai/mixture/soundfile_audio.py +8 -6
- sonusai/mixture/sox_audio.py +16 -13
- sonusai/mixture/torchaudio_audio.py +6 -4
- sonusai/mixture/truth_functions/energy.py +40 -28
- sonusai/mixture/truth_functions/target.py +0 -1
- sonusai/utils/__init__.py +1 -1
- sonusai/utils/asr.py +26 -39
- sonusai/utils/asr_functions/aaware_whisper.py +3 -3
- {sonusai-0.18.2.dist-info → sonusai-0.18.4.dist-info}/METADATA +1 -1
- {sonusai-0.18.2.dist-info → sonusai-0.18.4.dist-info}/RECORD +34 -31
- sonusai/mixture/mapped_snr_f.py +0 -100
- {sonusai-0.18.2.dist-info → sonusai-0.18.4.dist-info}/WHEEL +0 -0
- {sonusai-0.18.2.dist-info → sonusai-0.18.4.dist-info}/entry_points.txt +0 -0
sonusai/genmetrics.py
CHANGED
@@ -1,146 +1,174 @@
|
|
1
|
-
|
1
|
+
"""sonusai genmetrics
|
2
2
|
|
3
|
+
usage: genmetrics [-hvs] [-i MIXID] [-n INCLUDE] [-x EXCLUDE] LOC
|
3
4
|
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
5
|
+
options:
|
6
|
+
-h, --help
|
7
|
+
-v, --verbose Be verbose.
|
8
|
+
-i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
|
9
|
+
-n INCLUDE, --include INCLUDE Metrics to include. [default: all]
|
10
|
+
-x EXCLUDE, --exclude EXCLUDE Metrics to exclude. [default: none]
|
11
|
+
-s, --supported Show list of supported metrics.
|
8
12
|
|
9
|
-
|
10
|
-
def mxssnravg(self):
|
11
|
-
...
|
13
|
+
Calculate speech enhancement metrics of SonusAI mixture data in LOC.
|
12
14
|
|
13
|
-
|
14
|
-
|
15
|
-
|
15
|
+
Inputs:
|
16
|
+
LOC A SonusAI mixture database directory.
|
17
|
+
MIXID A glob of mixture ID(s) to generate.
|
18
|
+
INCLUDE Comma separated list of metrics to include. Can be 'all' or
|
19
|
+
any of the supported metrics.
|
20
|
+
EXCLUDE Comma separated list of metrics to exclude. Can be 'none' or
|
21
|
+
any of the supported metrics.
|
16
22
|
|
17
|
-
|
18
|
-
def mxssnrdavg(self):
|
19
|
-
...
|
23
|
+
Examples:
|
20
24
|
|
21
|
-
|
22
|
-
|
23
|
-
...
|
25
|
+
Generate all available mxwer metrics (as determined by mixdb asr_configs parameter):
|
26
|
+
> sonusai genmetrics -n"mxwer" mixdb_loc
|
24
27
|
|
25
|
-
|
26
|
-
|
27
|
-
...
|
28
|
+
Generate only mxwer.faster metrics:
|
29
|
+
> sonusai genmetrics -n"mxwer.faster" mixdb_loc
|
28
30
|
|
29
|
-
|
30
|
-
|
31
|
-
...
|
31
|
+
Generate all available metrics except for mxwer.faster:
|
32
|
+
> sonusai genmetrics -x"mxwer.faster" mixdb_loc
|
32
33
|
|
33
|
-
|
34
|
-
|
35
|
-
|
34
|
+
"""
|
35
|
+
import signal
|
36
|
+
from dataclasses import dataclass
|
36
37
|
|
37
|
-
|
38
|
-
def mxstoi(self):
|
39
|
-
...
|
38
|
+
from sonusai.mixture import MixtureDatabase
|
40
39
|
|
41
|
-
@property
|
42
|
-
def mxcsig(self):
|
43
|
-
...
|
44
40
|
|
45
|
-
|
46
|
-
|
47
|
-
...
|
41
|
+
def signal_handler(_sig, _frame):
|
42
|
+
import sys
|
48
43
|
|
49
|
-
|
50
|
-
def mxcovl(self):
|
51
|
-
...
|
44
|
+
from sonusai import logger
|
52
45
|
|
53
|
-
|
54
|
-
|
46
|
+
logger.info('Canceled due to keyboard interrupt')
|
47
|
+
sys.exit(1)
|
55
48
|
|
56
|
-
@property
|
57
|
-
def tdco(self):
|
58
|
-
...
|
59
49
|
|
60
|
-
|
61
|
-
def tmin(self):
|
62
|
-
...
|
50
|
+
signal.signal(signal.SIGINT, signal_handler)
|
63
51
|
|
64
|
-
@property
|
65
|
-
def tmax(self):
|
66
|
-
...
|
67
52
|
|
68
|
-
|
69
|
-
|
70
|
-
|
53
|
+
@dataclass
|
54
|
+
class MPGlobal:
|
55
|
+
mixdb: MixtureDatabase = None
|
56
|
+
metrics: set[str] = None
|
71
57
|
|
72
|
-
@property
|
73
|
-
def tlrms(self):
|
74
|
-
...
|
75
58
|
|
76
|
-
|
77
|
-
def tpkr(self):
|
78
|
-
...
|
59
|
+
MP_GLOBAL = MPGlobal()
|
79
60
|
|
80
|
-
@property
|
81
|
-
def ttr(self):
|
82
|
-
...
|
83
61
|
|
84
|
-
|
85
|
-
|
86
|
-
|
62
|
+
def _initializer(location: str, metrics: set[str]) -> None:
|
63
|
+
MP_GLOBAL.mixdb = MixtureDatabase(location)
|
64
|
+
MP_GLOBAL.metrics = metrics
|
87
65
|
|
88
|
-
@property
|
89
|
-
def tfl(self):
|
90
|
-
...
|
91
66
|
|
92
|
-
|
93
|
-
|
94
|
-
...
|
67
|
+
def _process_mixture(mixid: int) -> None:
|
68
|
+
from sonusai.mixture import write_mixture_data
|
95
69
|
|
96
|
-
|
97
|
-
|
98
|
-
...
|
70
|
+
mixdb = MP_GLOBAL.mixdb
|
71
|
+
metrics = list(MP_GLOBAL.metrics)
|
99
72
|
|
100
|
-
|
101
|
-
|
102
|
-
...
|
73
|
+
values = mixdb.mixture_metrics(m_id=mixid, metrics=metrics, force=True)
|
74
|
+
write_data = list(zip(metrics, values))
|
103
75
|
|
104
|
-
|
105
|
-
def nmax(self):
|
106
|
-
...
|
76
|
+
write_mixture_data(mixdb, mixdb.mixture(mixid), write_data)
|
107
77
|
|
108
|
-
@property
|
109
|
-
def npkdb(self):
|
110
|
-
...
|
111
78
|
|
112
|
-
|
113
|
-
|
114
|
-
...
|
79
|
+
def main() -> None:
|
80
|
+
from docopt import docopt
|
115
81
|
|
116
|
-
|
117
|
-
|
118
|
-
...
|
82
|
+
import sonusai
|
83
|
+
from sonusai.utils import trim_docstring
|
119
84
|
|
120
|
-
|
121
|
-
def ntr(self):
|
122
|
-
...
|
85
|
+
args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
|
123
86
|
|
124
|
-
|
125
|
-
|
126
|
-
|
87
|
+
verbose = args['--verbose']
|
88
|
+
mixids = args['--mixid']
|
89
|
+
includes = [x.strip() for x in args['--include'].lower().split(',')]
|
90
|
+
excludes = [x.strip() for x in args['--exclude'].lower().split(',')]
|
91
|
+
show_supported = args['--supported']
|
92
|
+
location = args['LOC']
|
127
93
|
|
128
|
-
|
129
|
-
|
130
|
-
|
94
|
+
import sys
|
95
|
+
import time
|
96
|
+
from os.path import join
|
131
97
|
|
132
|
-
|
133
|
-
|
134
|
-
|
98
|
+
from sonusai import create_file_handler
|
99
|
+
from sonusai import initial_log_messages
|
100
|
+
from sonusai import logger
|
101
|
+
from sonusai import update_console_handler
|
102
|
+
from sonusai.utils import pp_tqdm_imap
|
103
|
+
from sonusai.utils import seconds_to_hms
|
104
|
+
from tqdm import tqdm
|
135
105
|
|
136
|
-
|
137
|
-
def sedavg(self):
|
138
|
-
...
|
106
|
+
start_time = time.monotonic()
|
139
107
|
|
140
|
-
|
141
|
-
|
142
|
-
|
108
|
+
# Setup logging file
|
109
|
+
create_file_handler(join(location, 'genmetrics.log'))
|
110
|
+
update_console_handler(verbose)
|
111
|
+
initial_log_messages('genmetrics')
|
143
112
|
|
144
|
-
|
145
|
-
|
146
|
-
|
113
|
+
logger.info(f'Load mixture database from {location}')
|
114
|
+
|
115
|
+
mixdb = MixtureDatabase(location)
|
116
|
+
supported = mixdb.supported_metrics
|
117
|
+
if show_supported:
|
118
|
+
logger.info(f'\nSupported metrics: {", ".join(sorted(supported))}')
|
119
|
+
sys.exit(0)
|
120
|
+
|
121
|
+
if includes is None or 'all' in includes:
|
122
|
+
metrics = supported
|
123
|
+
else:
|
124
|
+
metrics = set(includes)
|
125
|
+
if 'mxwer' in metrics:
|
126
|
+
metrics.remove('mxwer')
|
127
|
+
for name in mixdb.asr_configs:
|
128
|
+
metrics.add(f'mxwer.{name}')
|
129
|
+
|
130
|
+
diff = metrics.difference(supported)
|
131
|
+
if diff:
|
132
|
+
logger.error(f'Unrecognized metric: {", ".join(diff)}')
|
133
|
+
sys.exit(1)
|
134
|
+
|
135
|
+
if excludes is None or 'none' in excludes:
|
136
|
+
_excludes = set([])
|
137
|
+
else:
|
138
|
+
_excludes = set(excludes)
|
139
|
+
if 'mxwer' in _excludes:
|
140
|
+
_excludes.remove('mxwer')
|
141
|
+
for name in mixdb.asr_configs:
|
142
|
+
_excludes.add(f'mxwer.{name}')
|
143
|
+
|
144
|
+
diff = _excludes.difference(supported)
|
145
|
+
if diff:
|
146
|
+
logger.error(f'Unrecognized metric: {", ".join(diff)}')
|
147
|
+
sys.exit(1)
|
148
|
+
|
149
|
+
for exclude in _excludes:
|
150
|
+
metrics.discard(exclude)
|
151
|
+
|
152
|
+
logger.info(f'Generating metrics: {", ".join(metrics)}')
|
153
|
+
|
154
|
+
mixids = mixdb.mixids_to_list(mixids)
|
155
|
+
logger.info('')
|
156
|
+
logger.info(f'Found {len(mixids):,} mixtures to process')
|
157
|
+
|
158
|
+
progress = tqdm(total=len(mixids), desc='genmetrics')
|
159
|
+
pp_tqdm_imap(_process_mixture, mixids,
|
160
|
+
progress=progress,
|
161
|
+
initializer=_initializer,
|
162
|
+
initargs=(location, metrics))
|
163
|
+
progress.close()
|
164
|
+
|
165
|
+
logger.info(f'Wrote metrics for {len(mixids)} mixtures to {location}')
|
166
|
+
logger.info('')
|
167
|
+
|
168
|
+
end_time = time.monotonic()
|
169
|
+
logger.info(f'Completed in {seconds_to_hms(seconds=end_time - start_time)}')
|
170
|
+
logger.info('')
|
171
|
+
|
172
|
+
|
173
|
+
if __name__ == '__main__':
|
174
|
+
main()
|
sonusai/lsdb.py
CHANGED
@@ -43,7 +43,7 @@ def lsdb(mixdb: MixtureDatabase,
|
|
43
43
|
import h5py
|
44
44
|
|
45
45
|
from sonusai import SonusAIError
|
46
|
-
from sonusai.
|
46
|
+
from sonusai.metrics import calc_snr_f
|
47
47
|
from sonusai.mixture import SAMPLE_RATE
|
48
48
|
from sonusai.mixture import get_truth_indices_for_target
|
49
49
|
from sonusai.queries import get_mixids_from_truth_index
|
@@ -113,7 +113,7 @@ def lsdb(mixdb: MixtureDatabase,
|
|
113
113
|
else:
|
114
114
|
truth_f = np.concatenate((truth_f, np.array(f['truth_f'])))
|
115
115
|
|
116
|
-
snr_mean, snr_std, snr_db_mean, snr_db_std =
|
116
|
+
snr_mean, snr_std, snr_db_mean, snr_db_std = calc_snr_f(truth_f)
|
117
117
|
|
118
118
|
logger.info('Truth')
|
119
119
|
logger.info(f' {"mean":^8s} {"std":^8s} {"db_mean":^8s} {"db_std":^8s}')
|
sonusai/metrics/__init__.py
CHANGED
@@ -1,11 +1,15 @@
|
|
1
1
|
# SonusAI metrics utilities for model training and validation
|
2
|
+
from .calc_audio_stats import calc_audio_stats
|
2
3
|
from .calc_class_weights import calc_class_weights_from_mixdb
|
3
4
|
from .calc_class_weights import calc_class_weights_from_truth
|
4
5
|
from .calc_optimal_thresholds import calc_optimal_thresholds
|
5
6
|
from .calc_pcm import calc_pcm
|
6
7
|
from .calc_pesq import calc_pesq
|
8
|
+
from .calc_phase_distance import calc_phase_distance
|
7
9
|
from .calc_sa_sdr import calc_sa_sdr
|
8
10
|
from .calc_sample_weights import calc_sample_weights
|
11
|
+
from .calc_snr_f import calc_snr_f
|
12
|
+
from .calc_speech import calc_speech
|
9
13
|
from .calc_wer import calc_wer
|
10
14
|
from .calc_wsdr import calc_wsdr
|
11
15
|
from .class_summary import class_summary
|
@@ -0,0 +1,42 @@
|
|
1
|
+
from sonusai.mixture.datatypes import AudioStatsMetrics
|
2
|
+
from sonusai.mixture.datatypes import AudioT
|
3
|
+
|
4
|
+
|
5
|
+
def calc_audio_stats(audio: AudioT, win_len: float = None) -> AudioStatsMetrics:
|
6
|
+
from sonusai.mixture import SAMPLE_RATE
|
7
|
+
from sonusai.mixture import Transformer
|
8
|
+
|
9
|
+
args = ['stats']
|
10
|
+
if win_len is not None:
|
11
|
+
args.extend(['-w', str(win_len)])
|
12
|
+
|
13
|
+
tfm = Transformer()
|
14
|
+
|
15
|
+
_, _, out = tfm.build(input_array=audio,
|
16
|
+
sample_rate_in=SAMPLE_RATE,
|
17
|
+
output_filepath='-n',
|
18
|
+
extra_args=args,
|
19
|
+
return_output=True)
|
20
|
+
|
21
|
+
stats = {}
|
22
|
+
lines = out.split('\n')
|
23
|
+
for line in lines:
|
24
|
+
split_line = line.split()
|
25
|
+
if len(split_line) == 0:
|
26
|
+
continue
|
27
|
+
value = split_line[-1]
|
28
|
+
key = ' '.join(split_line[:-1])
|
29
|
+
stats[key] = value
|
30
|
+
|
31
|
+
return AudioStatsMetrics(
|
32
|
+
dco=float(stats['DC offset']),
|
33
|
+
min=float(stats['Min level']),
|
34
|
+
max=float(stats['Max level']),
|
35
|
+
pkdb=float(stats['Pk lev dB']),
|
36
|
+
lrms=float(stats['RMS lev dB']),
|
37
|
+
pkr=float(stats['RMS Pk dB']),
|
38
|
+
tr=float(stats['RMS Tr dB']),
|
39
|
+
cr=float(stats['Crest factor']),
|
40
|
+
fl=float(stats['Flat factor']),
|
41
|
+
pkc=int(stats['Pk count']),
|
42
|
+
)
|
sonusai/metrics/calc_pesq.py
CHANGED
@@ -1,14 +1,20 @@
|
|
1
1
|
import numpy as np
|
2
2
|
|
3
|
+
from sonusai.mixture.constants import SAMPLE_RATE
|
3
4
|
|
4
|
-
|
5
|
-
|
5
|
+
|
6
|
+
def calc_pesq(hypothesis: np.ndarray,
|
7
|
+
reference: np.ndarray,
|
8
|
+
error_value: float = 0.0,
|
9
|
+
sample_rate: int = SAMPLE_RATE) -> float:
|
10
|
+
"""Computes the PESQ score of hypothesis vs. reference
|
6
11
|
|
7
12
|
Upon error, assigns a value of 0, or user specified value in error_value
|
8
13
|
|
9
|
-
:param hypothesis:
|
10
|
-
:param reference:
|
11
|
-
:param error_value:
|
14
|
+
:param hypothesis: estimated audio
|
15
|
+
:param reference: reference audio
|
16
|
+
:param error_value: value to use if error occurs
|
17
|
+
:param sample_rate: sample rate of audio
|
12
18
|
:return: value between -0.5 to 4.5
|
13
19
|
"""
|
14
20
|
import warnings
|
@@ -16,12 +22,10 @@ def calc_pesq(hypothesis: np.ndarray, reference: np.ndarray, error_value: float
|
|
16
22
|
from pesq import pesq
|
17
23
|
|
18
24
|
from sonusai import logger
|
19
|
-
from sonusai.mixture import SAMPLE_RATE
|
20
|
-
|
21
25
|
try:
|
22
26
|
with warnings.catch_warnings():
|
23
27
|
warnings.simplefilter('ignore')
|
24
|
-
score = pesq(
|
28
|
+
score = pesq(fs=sample_rate, ref=reference, deg=hypothesis, mode='wb')
|
25
29
|
except Exception as e:
|
26
30
|
logger.debug(f'PESQ error {e}')
|
27
31
|
score = error_value
|
@@ -0,0 +1,43 @@
|
|
1
|
+
import numpy as np
|
2
|
+
|
3
|
+
|
4
|
+
def calc_phase_distance(reference: np.ndarray,
|
5
|
+
hypothesis: np.ndarray,
|
6
|
+
eps: float = 1e-9) -> tuple[float, np.ndarray, np.ndarray]:
|
7
|
+
"""Calculate weighted phase distance error (weight normalization over bins per frame)
|
8
|
+
|
9
|
+
:param reference: complex [frames, bins]
|
10
|
+
:param hypothesis: complex [frames, bins]
|
11
|
+
:param eps: epsilon value
|
12
|
+
:return: mean, mean per bin, mean per frame
|
13
|
+
"""
|
14
|
+
ang_diff = np.angle(reference) - np.angle(hypothesis)
|
15
|
+
phd_mod = (ang_diff + np.pi) % (2 * np.pi) - np.pi
|
16
|
+
rh_angle_diff = phd_mod * 180 / np.pi # angle diff in deg
|
17
|
+
|
18
|
+
# Use complex divide to intrinsically keep angle diff +/-180 deg, but avoid div by zero (real hyp)
|
19
|
+
# hyp_real = np.real(hypothesis)
|
20
|
+
# near_zeros = np.real(hyp_real) < eps
|
21
|
+
# hyp_real = hyp_real * (np.logical_not(near_zeros))
|
22
|
+
# hyp_real = hyp_real + (near_zeros * eps)
|
23
|
+
# hypothesis = hyp_real + 1j*np.imag(hypothesis)
|
24
|
+
# rh_angle_diff = np.angle(reference / hypothesis) * 180 / np.pi # angle diff +/-180
|
25
|
+
|
26
|
+
# weighted mean over all (scalar)
|
27
|
+
reference_mag = np.abs(reference)
|
28
|
+
ref_weight = reference_mag / (np.sum(reference_mag) + eps) # frames x bins
|
29
|
+
err = np.around(np.sum(ref_weight * rh_angle_diff), 3)
|
30
|
+
|
31
|
+
# weighted mean over frames (value per bin)
|
32
|
+
err_b = np.zeros(reference.shape[1])
|
33
|
+
for bi in range(reference.shape[1]):
|
34
|
+
ref_weight = reference_mag[:, bi] / (np.sum(reference_mag[:, bi], axis=0) + eps)
|
35
|
+
err_b[bi] = np.around(np.sum(ref_weight * rh_angle_diff[:, bi]), 3)
|
36
|
+
|
37
|
+
# weighted mean over bins (value per frame)
|
38
|
+
err_f = np.zeros(reference.shape[0])
|
39
|
+
for fi in range(reference.shape[0]):
|
40
|
+
ref_weight = reference_mag[fi, :] / (np.sum(reference_mag[fi, :]) + eps)
|
41
|
+
err_f[fi] = np.around(np.sum(ref_weight * rh_angle_diff[fi, :]), 3)
|
42
|
+
|
43
|
+
return err, err_b, err_f
|
@@ -0,0 +1,34 @@
|
|
1
|
+
import numpy as np
|
2
|
+
|
3
|
+
from sonusai.mixture.datatypes import Segsnr
|
4
|
+
from sonusai.mixture.datatypes import SnrFMetrics
|
5
|
+
|
6
|
+
|
7
|
+
def calc_snr_f(segsnr_f: Segsnr) -> SnrFMetrics:
|
8
|
+
"""Calculate metrics of snr_f truth data.
|
9
|
+
|
10
|
+
For now, includes mean and variance of the raw values (usually energy)
|
11
|
+
and mean and standard deviation of the dB values (10 * log10).
|
12
|
+
"""
|
13
|
+
if np.count_nonzero(segsnr_f) == 0:
|
14
|
+
# If all entries are zeros
|
15
|
+
return SnrFMetrics(0, 0, -np.inf, 0)
|
16
|
+
|
17
|
+
tmp = np.ma.array(segsnr_f, mask=np.logical_not(np.isfinite(segsnr_f)), dtype=np.float32)
|
18
|
+
if np.ma.count_masked(tmp) == np.ma.size(tmp, axis=0):
|
19
|
+
# If all entries are infinite
|
20
|
+
return SnrFMetrics(np.inf, 0, np.inf, 0)
|
21
|
+
|
22
|
+
snr_mean = np.mean(tmp, axis=0)
|
23
|
+
snr_var = np.var(tmp, axis=0)
|
24
|
+
|
25
|
+
tmp = 10 * np.ma.log10(tmp)
|
26
|
+
if np.ma.count_masked(tmp) == np.ma.size(tmp, axis=0):
|
27
|
+
# If all entries are masked, special case where all inputs are either 0 or infinite
|
28
|
+
snr_db_mean = -np.inf
|
29
|
+
snr_db_std = np.inf
|
30
|
+
else:
|
31
|
+
snr_db_mean = np.mean(tmp, axis=0)
|
32
|
+
snr_db_std = np.std(tmp, axis=0)
|
33
|
+
|
34
|
+
return SnrFMetrics(snr_mean, snr_var, snr_db_mean, snr_db_std)
|