sonusai 0.18.2__py3-none-any.whl → 0.18.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sonusai/genmetrics.py CHANGED
@@ -1,146 +1,174 @@
1
- # Generate mixdb metrics based on metrics listed in config.yml
1
+ """sonusai genmetrics
2
2
 
3
+ usage: genmetrics [-hvs] [-i MIXID] [-n INCLUDE] [-x EXCLUDE] LOC
3
4
 
4
- class MixtureMetrics:
5
- @property
6
- def mxsnr(self):
7
- ...
5
+ options:
6
+ -h, --help
7
+ -v, --verbose Be verbose.
8
+ -i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
9
+ -n INCLUDE, --include INCLUDE Metrics to include. [default: all]
10
+ -x EXCLUDE, --exclude EXCLUDE Metrics to exclude. [default: none]
11
+ -s, --supported Show list of supported metrics.
8
12
 
9
- @property
10
- def mxssnravg(self):
11
- ...
13
+ Calculate speech enhancement metrics of SonusAI mixture data in LOC.
12
14
 
13
- @property
14
- def mxssnrstd(self):
15
- ...
15
+ Inputs:
16
+ LOC A SonusAI mixture database directory.
17
+ MIXID A glob of mixture ID(s) to generate.
18
+ INCLUDE Comma separated list of metrics to include. Can be 'all' or
19
+ any of the supported metrics.
20
+ EXCLUDE Comma separated list of metrics to exclude. Can be 'none' or
21
+ any of the supported metrics.
16
22
 
17
- @property
18
- def mxssnrdavg(self):
19
- ...
23
+ Examples:
20
24
 
21
- @property
22
- def mxssnrdstd(self):
23
- ...
25
+ Generate all available mxwer metrics (as determined by mixdb asr_configs parameter):
26
+ > sonusai genmetrics -n"mxwer" mixdb_loc
24
27
 
25
- @property
26
- def mxpesq(self):
27
- ...
28
+ Generate only mxwer.faster metrics:
29
+ > sonusai genmetrics -n"mxwer.faster" mixdb_loc
28
30
 
29
- @property
30
- def mxwsdr(self):
31
- ...
31
+ Generate all available metrics except for mxwer.faster:
32
+ > sonusai genmetrics -x"mxwer.faster" mixdb_loc
32
33
 
33
- @property
34
- def mxpd(self):
35
- ...
34
+ """
35
+ import signal
36
+ from dataclasses import dataclass
36
37
 
37
- @property
38
- def mxstoi(self):
39
- ...
38
+ from sonusai.mixture import MixtureDatabase
40
39
 
41
- @property
42
- def mxcsig(self):
43
- ...
44
40
 
45
- @property
46
- def mxcbak(self):
47
- ...
41
def signal_handler(_sig, _frame):
    """SIGINT handler: log that the run was canceled, then exit with status 1.

    :param _sig: signal number (unused)
    :param _frame: interrupted stack frame (unused)
    """
    from sonusai import logger

    logger.info('Canceled due to keyboard interrupt')
    # Equivalent to sys.exit(1): both raise SystemExit with code 1
    raise SystemExit(1)


# Registered at import time so a Ctrl-C anywhere in the run is reported cleanly
signal.signal(signal.SIGINT, signal_handler)
63
51
 
64
- @property
65
- def tmax(self):
66
- ...
67
52
 
68
- @property
69
- def tpkdb(self):
70
- ...
53
@dataclass
class MPGlobal:
    """State shared with the worker callables used by ``main``.

    ``_initializer`` (passed as ``initializer=`` to the parallel map in ``main``)
    fills both fields; ``_process_mixture`` reads them.
    """
    # Both fields default to None until _initializer runs, so the honest types are
    # Optional. The annotations are quoted so they are never evaluated at class
    # creation (valid on Python versions without runtime PEP 604 unions).
    mixdb: 'MixtureDatabase | None' = None  # mixture database opened from LOC
    metrics: 'set[str] | None' = None  # names of the metrics to compute per mixture


MP_GLOBAL = MPGlobal()
79
60
 
80
- @property
81
- def ttr(self):
82
- ...
83
61
 
84
- @property
85
- def tcr(self):
86
- ...
62
def _initializer(location: str, metrics: set[str]) -> None:
    """Initializer handed to the parallel map in ``main``.

    Opens the mixture database and records which metrics to compute, storing
    both on ``MP_GLOBAL`` for ``_process_mixture`` to read.

    :param location: mixture database directory (LOC)
    :param metrics: names of the metrics to compute
    """
    state = MP_GLOBAL
    state.mixdb = MixtureDatabase(location)
    state.metrics = metrics
87
65
 
88
- @property
89
- def tfl(self):
90
- ...
91
66
 
92
- @property
93
- def tpkc(self):
94
- ...
67
def _process_mixture(mixid: int) -> None:
    """Compute the selected metrics for one mixture and hand them to ``write_mixture_data``.

    Uses the database and metric names stored on ``MP_GLOBAL`` by ``_initializer``.

    :param mixid: ID of the mixture to process
    """
    from sonusai.mixture import write_mixture_data

    db = MP_GLOBAL.mixdb
    # Freeze the names into a list so each returned value pairs with the name it was requested under
    names = list(MP_GLOBAL.metrics)

    results = db.mixture_metrics(m_id=mixid, metrics=names, force=True)
    pairs = [(name, result) for name, result in zip(names, results)]

    write_mixture_data(db, db.mixture(mixid), pairs)
107
77
 
108
- @property
109
- def npkdb(self):
110
- ...
111
78
 
112
- @property
113
- def nlrms(self):
114
- ...
79
def _split_metric_list(text: str) -> list[str]:
    """Split a comma separated CLI metric list into lower-case names, dropping empty entries."""
    return [token for token in (part.strip() for part in text.lower().split(',')) if token]


def _expand_mxwer(names: list[str], asr_configs) -> set[str]:
    """Return ``names`` as a new set with the bare 'mxwer' alias expanded.

    'mxwer' is replaced by one 'mxwer.<name>' entry for every name obtained by
    iterating ``asr_configs`` (the mixture database's ``asr_configs``).
    """
    expanded = set(names)
    if 'mxwer' in expanded:
        expanded.remove('mxwer')
        expanded.update(f'mxwer.{name}' for name in asr_configs)
    return expanded


def main() -> None:
    """Command-line entry point: compute the selected metrics for the selected mixtures in LOC."""
    from docopt import docopt

    import sonusai
    from sonusai.utils import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = args['--verbose']
    mixids = args['--mixid']
    includes = _split_metric_list(args['--include'])
    excludes = _split_metric_list(args['--exclude'])
    show_supported = args['--supported']
    location = args['LOC']

    import sys
    import time
    from os.path import join

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import logger
    from sonusai import update_console_handler
    from sonusai.utils import pp_tqdm_imap
    from sonusai.utils import seconds_to_hms
    from tqdm import tqdm

    start_time = time.monotonic()

    # Setup logging file
    create_file_handler(join(location, 'genmetrics.log'))
    update_console_handler(verbose)
    initial_log_messages('genmetrics')

    logger.info(f'Load mixture database from {location}')

    mixdb = MixtureDatabase(location)
    supported = mixdb.supported_metrics
    if show_supported:
        logger.info(f'\nSupported metrics: {", ".join(sorted(supported))}')
        sys.exit(0)

    if 'all' in includes:
        # Copy: the exclusions below must not mutate the set owned by mixdb
        metrics = set(supported)
    else:
        metrics = _expand_mxwer(includes, mixdb.asr_configs)
        unknown = metrics.difference(supported)
        if unknown:
            logger.error(f'Unrecognized metric: {", ".join(sorted(unknown))}')
            sys.exit(1)

    if 'none' in excludes:
        excluded = set()
    else:
        excluded = _expand_mxwer(excludes, mixdb.asr_configs)
        unknown = excluded.difference(supported)
        if unknown:
            logger.error(f'Unrecognized metric: {", ".join(sorted(unknown))}')
            sys.exit(1)

    metrics -= excluded

    if not metrics:
        # Nothing to compute: avoid iterating over every mixture for no output
        logger.warning('No metrics to generate')
        return

    logger.info(f'Generating metrics: {", ".join(sorted(metrics))}')

    mixids = mixdb.mixids_to_list(mixids)
    logger.info('')
    logger.info(f'Found {len(mixids):,} mixtures to process')

    progress = tqdm(total=len(mixids), desc='genmetrics')
    try:
        pp_tqdm_imap(_process_mixture, mixids,
                     progress=progress,
                     initializer=_initializer,
                     initargs=(location, metrics))
    finally:
        # Close the bar even if processing raises
        progress.close()

    logger.info(f'Wrote metrics for {len(mixids):,} mixtures to {location}')
    logger.info('')

    end_time = time.monotonic()
    logger.info(f'Completed in {seconds_to_hms(seconds=end_time - start_time)}')
    logger.info('')


if __name__ == '__main__':
    main()
sonusai/lsdb.py CHANGED
@@ -43,7 +43,7 @@ def lsdb(mixdb: MixtureDatabase,
43
43
  import h5py
44
44
 
45
45
  from sonusai import SonusAIError
46
- from sonusai.mixture import calculate_snr_f_statistics
46
+ from sonusai.metrics import calc_snr_f
47
47
  from sonusai.mixture import SAMPLE_RATE
48
48
  from sonusai.mixture import get_truth_indices_for_target
49
49
  from sonusai.queries import get_mixids_from_truth_index
@@ -113,7 +113,7 @@ def lsdb(mixdb: MixtureDatabase,
113
113
  else:
114
114
  truth_f = np.concatenate((truth_f, np.array(f['truth_f'])))
115
115
 
116
- snr_mean, snr_std, snr_db_mean, snr_db_std = calculate_snr_f_statistics(truth_f)
116
+ snr_mean, snr_std, snr_db_mean, snr_db_std = calc_snr_f(truth_f)
117
117
 
118
118
  logger.info('Truth')
119
119
  logger.info(f' {"mean":^8s} {"std":^8s} {"db_mean":^8s} {"db_std":^8s}')
sonusai/metrics/__init__.py CHANGED
@@ -1,11 +1,15 @@
1
1
  # SonusAI metrics utilities for model training and validation
2
+ from .calc_audio_stats import calc_audio_stats
2
3
  from .calc_class_weights import calc_class_weights_from_mixdb
3
4
  from .calc_class_weights import calc_class_weights_from_truth
4
5
  from .calc_optimal_thresholds import calc_optimal_thresholds
5
6
  from .calc_pcm import calc_pcm
6
7
  from .calc_pesq import calc_pesq
8
+ from .calc_phase_distance import calc_phase_distance
7
9
  from .calc_sa_sdr import calc_sa_sdr
8
10
  from .calc_sample_weights import calc_sample_weights
11
+ from .calc_snr_f import calc_snr_f
12
+ from .calc_speech import calc_speech
9
13
  from .calc_wer import calc_wer
10
14
  from .calc_wsdr import calc_wsdr
11
15
  from .class_summary import class_summary
sonusai/metrics/calc_audio_stats.py ADDED
@@ -0,0 +1,42 @@
1
+ from sonusai.mixture.datatypes import AudioStatsMetrics
2
+ from sonusai.mixture.datatypes import AudioT
3
+
4
+
5
def calc_audio_stats(audio: AudioT, win_len: float = None) -> AudioStatsMetrics:
    """Compute level statistics of ``audio`` with the SoX ``stats`` effect.

    :param audio: audio samples; passed to SoX with an input rate of SAMPLE_RATE
    :param win_len: optional value for ``stats -w``; omitted (SoX default) when None
    :return: AudioStatsMetrics parsed from the SoX stats report
    """
    from sonusai.mixture import SAMPLE_RATE
    from sonusai.mixture import Transformer

    args = ['stats']
    if win_len is not None:
        args.extend(['-w', str(win_len)])

    tfm = Transformer()

    # '-n' selects SoX's null output file; only the textual stats report is used
    _, _, out = tfm.build(input_array=audio,
                          sample_rate_in=SAMPLE_RATE,
                          output_filepath='-n',
                          extra_args=args,
                          return_output=True)

    stats = _parse_sox_stats(out)

    return AudioStatsMetrics(
        dco=float(stats['DC offset']),
        min=float(stats['Min level']),
        max=float(stats['Max level']),
        pkdb=float(stats['Pk lev dB']),
        lrms=float(stats['RMS lev dB']),
        pkr=float(stats['RMS Pk dB']),
        tr=float(stats['RMS Tr dB']),
        cr=float(stats['Crest factor']),
        fl=float(stats['Flat factor']),
        pkc=_parse_sox_count(stats['Pk count']),
    )


def _parse_sox_stats(report: str) -> dict[str, str]:
    """Map each non-blank line of a SoX stats report to {label: value}.

    The last whitespace-separated token is the value; the preceding tokens,
    re-joined with single spaces, form the label.
    """
    stats = {}
    for line in report.split('\n'):
        tokens = line.split()
        if tokens:
            stats[' '.join(tokens[:-1])] = tokens[-1]
    return stats


def _parse_sox_count(text: str) -> int:
    """Convert a SoX count field to int, expanding an SI suffix such as '1.05k' or '1.32M'.

    NOTE(review): SoX may abbreviate large counts to three significant figures,
    so an expanded value can be approximate.
    """
    prefixes = 'kMGTPEZY'
    if text and text[-1] in prefixes:
        scale = 1000 ** (prefixes.index(text[-1]) + 1)
        return int(round(float(text[:-1]) * scale))
    return int(round(float(text)))
sonusai/metrics/calc_pesq.py CHANGED
@@ -1,14 +1,20 @@
1
1
  import numpy as np
2
2
 
3
+ from sonusai.mixture.constants import SAMPLE_RATE
3
4
 
4
- def calc_pesq(hypothesis: np.ndarray, reference: np.ndarray, error_value: float = 0.0) -> float:
5
- """Computes the PESQ score of speech estimate audio vs. the clean speech estimate audio
5
+
6
+ def calc_pesq(hypothesis: np.ndarray,
7
+ reference: np.ndarray,
8
+ error_value: float = 0.0,
9
+ sample_rate: int = SAMPLE_RATE) -> float:
10
+ """Computes the PESQ score of hypothesis vs. reference
6
11
 
7
12
  Upon error, assigns a value of 0, or user specified value in error_value
8
13
 
9
- :param hypothesis: speech estimated audio
10
- :param reference: speech reference audio
11
- :param error_value:
14
+ :param hypothesis: estimated audio
15
+ :param reference: reference audio
16
+ :param error_value: value to use if error occurs
17
+ :param sample_rate: sample rate of audio
12
18
  :return: value between -0.5 to 4.5
13
19
  """
14
20
  import warnings
@@ -16,12 +22,10 @@ def calc_pesq(hypothesis: np.ndarray, reference: np.ndarray, error_value: float
16
22
  from pesq import pesq
17
23
 
18
24
  from sonusai import logger
19
- from sonusai.mixture import SAMPLE_RATE
20
-
21
25
  try:
22
26
  with warnings.catch_warnings():
23
27
  warnings.simplefilter('ignore')
24
- score = pesq(SAMPLE_RATE, reference, hypothesis, mode='wb')
28
+ score = pesq(fs=sample_rate, ref=reference, deg=hypothesis, mode='wb')
25
29
  except Exception as e:
26
30
  logger.debug(f'PESQ error {e}')
27
31
  score = error_value
sonusai/metrics/calc_phase_distance.py ADDED
@@ -0,0 +1,43 @@
1
+ import numpy as np
2
+
3
+
4
def calc_phase_distance(reference: np.ndarray,
                        hypothesis: np.ndarray,
                        eps: float = 1e-9) -> tuple[float, np.ndarray, np.ndarray]:
    """Calculate the magnitude-weighted phase distance between two spectra.

    The phase difference ``angle(reference) - angle(hypothesis)`` is wrapped to
    [-180, 180) degrees and averaged with weights proportional to ``abs(reference)``,
    normalized over the reduced axes. Differences are signed, so errors of opposite
    sign can cancel. All results are rounded to 3 decimals.

    :param reference: complex [frames, bins]
    :param hypothesis: complex [frames, bins]
    :param eps: added to each weight normalizer to avoid division by zero
    :return: (weighted mean over everything,
              weighted mean over frames, one value per bin  [bins],
              weighted mean over bins, one value per frame  [frames])
    """
    ang_diff = np.angle(reference) - np.angle(hypothesis)
    # Wrap to [-pi, pi), then convert to degrees
    phd_mod = (ang_diff + np.pi) % (2 * np.pi) - np.pi
    rh_angle_diff = phd_mod * 180 / np.pi

    reference_mag = np.abs(reference)

    # weighted mean over all (scalar)
    ref_weight = reference_mag / (np.sum(reference_mag) + eps)
    err = np.around(np.sum(ref_weight * rh_angle_diff), 3)

    # weighted mean over frames (value per bin): weights normalized per column
    ref_weight_b = reference_mag / (np.sum(reference_mag, axis=0, keepdims=True) + eps)
    err_b = np.around(np.sum(ref_weight_b * rh_angle_diff, axis=0), 3)

    # weighted mean over bins (value per frame): weights normalized per row
    ref_weight_f = reference_mag / (np.sum(reference_mag, axis=1, keepdims=True) + eps)
    err_f = np.around(np.sum(ref_weight_f * rh_angle_diff, axis=1), 3)

    return err, err_b, err_f
sonusai/metrics/calc_snr_f.py ADDED
@@ -0,0 +1,34 @@
1
+ import numpy as np
2
+
3
+ from sonusai.mixture.datatypes import Segsnr
4
+ from sonusai.mixture.datatypes import SnrFMetrics
5
+
6
+
7
def calc_snr_f(segsnr_f: Segsnr) -> SnrFMetrics:
    """Calculate metrics of snr_f truth data.

    Computes the mean and variance of the raw values (usually energy) and the
    mean and standard deviation of their dB values (10 * log10). Non-finite
    entries are masked out of every statistic; non-positive entries are also
    masked out of the dB statistics by ``np.ma.log10``. Reductions run over axis 0.

    Special cases:
        all entries zero                    -> SnrFMetrics(0, 0, -inf, 0)
        all entries non-finite              -> SnrFMetrics(inf, 0, inf, 0)
        no finite, positive entry for dB    -> dB mean -inf, dB std inf

    NOTE(review): the second field is a variance, while lsdb unpacks it as
    ``snr_std`` — confirm which one is intended.

    :param segsnr_f: snr_f truth values
    :return: SnrFMetrics(mean, variance, dB mean, dB standard deviation)
    """
    if np.count_nonzero(segsnr_f) == 0:
        # All entries are zero
        return SnrFMetrics(0, 0, -np.inf, 0)

    tmp = np.ma.array(segsnr_f, mask=np.logical_not(np.isfinite(segsnr_f)), dtype=np.float32)
    # np.ma.count() is the number of unmasked entries over the whole array. The
    # previous test, count_masked(tmp) == size(tmp, axis=0), only meant "everything
    # is masked" for 1-D input; for 2-D input it compared a total against one axis.
    if np.ma.count(tmp) == 0:
        # All entries are non-finite
        return SnrFMetrics(np.inf, 0, np.inf, 0)

    snr_mean = np.mean(tmp, axis=0)
    snr_var = np.var(tmp, axis=0)

    tmp = 10 * np.ma.log10(tmp)
    if np.ma.count(tmp) == 0:
        # Everything is masked: every input was zero, negative or non-finite
        snr_db_mean = -np.inf
        snr_db_std = np.inf
    else:
        snr_db_mean = np.mean(tmp, axis=0)
        snr_db_std = np.std(tmp, axis=0)

    return SnrFMetrics(snr_mean, snr_var, snr_db_mean, snr_db_std)