sonusai 1.0.11__py3-none-any.whl → 1.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/genmetrics.py +4 -6
- sonusai/metrics/__init__.py +1 -0
- sonusai/metrics/calculate_metrics.py +395 -0
- sonusai/mixture/mixdb.py +22 -551
- {sonusai-1.0.11.dist-info → sonusai-1.0.13.dist-info}/METADATA +1 -1
- {sonusai-1.0.11.dist-info → sonusai-1.0.13.dist-info}/RECORD +8 -7
- {sonusai-1.0.11.dist-info → sonusai-1.0.13.dist-info}/WHEEL +0 -0
- {sonusai-1.0.11.dist-info → sonusai-1.0.13.dist-info}/entry_points.txt +0 -0
sonusai/genmetrics.py
CHANGED
@@ -147,12 +147,10 @@ def main() -> None:
|
|
147
147
|
logger.info("")
|
148
148
|
logger.info(f"Found {len(mixids):,} mixtures to process")
|
149
149
|
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
no_par = False
|
155
|
-
num_proc = int(num_proc) # TBD add support for 'auto'
|
150
|
+
no_par = num_proc == 1 or len(mixids) == 1
|
151
|
+
|
152
|
+
if num_proc is not None:
|
153
|
+
num_proc = int(num_proc)
|
156
154
|
|
157
155
|
progress = track(total=len(mixids), desc="genmetrics")
|
158
156
|
results = par_track(
|
sonusai/metrics/__init__.py
CHANGED
@@ -15,6 +15,7 @@ from .calc_segsnr_f import calc_segsnr_f_bin
|
|
15
15
|
from .calc_speech import calc_speech
|
16
16
|
from .calc_wer import calc_wer
|
17
17
|
from .calc_wsdr import calc_wsdr
|
18
|
+
from .calculate_metrics import calculate_metrics
|
18
19
|
from .class_summary import class_summary
|
19
20
|
from .confusion_matrix_summary import confusion_matrix_summary
|
20
21
|
from .one_hot import one_hot
|
sonusai/metrics/calculate_metrics.py
ADDED
@@ -0,0 +1,395 @@
|
|
1
|
+
import functools
|
2
|
+
from typing import Any
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
from pystoi import stoi
|
6
|
+
|
7
|
+
from ..constants import SAMPLE_RATE
|
8
|
+
from ..datatypes import AudioF
|
9
|
+
from ..datatypes import AudioStatsMetrics
|
10
|
+
from ..datatypes import AudioT
|
11
|
+
from ..datatypes import Segsnr
|
12
|
+
from ..datatypes import SpeechMetrics
|
13
|
+
from ..mixture.mixdb import MixtureDatabase
|
14
|
+
from ..utils.asr import calc_asr
|
15
|
+
from ..utils.db import linear_to_db
|
16
|
+
from .calc_audio_stats import calc_audio_stats
|
17
|
+
from .calc_pesq import calc_pesq
|
18
|
+
from .calc_phase_distance import calc_phase_distance
|
19
|
+
from .calc_segsnr_f import calc_segsnr_f
|
20
|
+
from .calc_segsnr_f import calc_segsnr_f_bin
|
21
|
+
from .calc_speech import calc_speech
|
22
|
+
from .calc_wer import calc_wer
|
23
|
+
from .calc_wsdr import calc_wsdr
|
24
|
+
|
25
|
+
|
26
|
+
def calculate_metrics(mixdb: MixtureDatabase, m_id: int, metrics: list[str], force: bool = False) -> dict[str, Any]:
|
27
|
+
"""Get metrics data for the given mixture ID
|
28
|
+
|
29
|
+
:param mixdb: Mixture database object
|
30
|
+
:param m_id: Zero-based mixture ID
|
31
|
+
:param metrics: List of metrics to get
|
32
|
+
:param force: Force computing data from original sources regardless of whether cached data exists
|
33
|
+
:return: Dictionary of metric data
|
34
|
+
"""
|
35
|
+
|
36
|
+
# Define cached functions for expensive operations
|
37
|
+
@functools.lru_cache(maxsize=1)
|
38
|
+
def mixture_sources() -> dict[str, AudioT]:
|
39
|
+
return mixdb.mixture_sources(m_id)
|
40
|
+
|
41
|
+
@functools.lru_cache(maxsize=1)
|
42
|
+
def mixture_source() -> AudioT:
|
43
|
+
return mixdb.mixture_source(m_id)
|
44
|
+
|
45
|
+
@functools.lru_cache(maxsize=1)
|
46
|
+
def mixture_source_f() -> AudioF:
|
47
|
+
return mixdb.mixture_source_f(m_id)
|
48
|
+
|
49
|
+
@functools.lru_cache(maxsize=1)
|
50
|
+
def mixture_noise() -> AudioT:
|
51
|
+
return mixdb.mixture_noise(m_id)
|
52
|
+
|
53
|
+
@functools.lru_cache(maxsize=1)
|
54
|
+
def mixture_noise_f() -> AudioF:
|
55
|
+
return mixdb.mixture_noise_f(m_id)
|
56
|
+
|
57
|
+
@functools.lru_cache(maxsize=1)
|
58
|
+
def mixture_mixture() -> AudioT:
|
59
|
+
return mixdb.mixture_mixture(m_id)
|
60
|
+
|
61
|
+
@functools.lru_cache(maxsize=1)
|
62
|
+
def mixture_mixture_f() -> AudioF:
|
63
|
+
return mixdb.mixture_mixture_f(m_id)
|
64
|
+
|
65
|
+
@functools.lru_cache(maxsize=1)
|
66
|
+
def mixture_segsnr() -> Segsnr:
|
67
|
+
return mixdb.mixture_segsnr(m_id)
|
68
|
+
|
69
|
+
@functools.lru_cache(maxsize=1)
|
70
|
+
def calculate_pesq() -> dict[str, float]:
|
71
|
+
return {category: calc_pesq(mixture_mixture(), audio) for category, audio in mixture_sources().items()}
|
72
|
+
|
73
|
+
@functools.lru_cache(maxsize=1)
|
74
|
+
def calculate_speech() -> dict[str, SpeechMetrics]:
|
75
|
+
return {
|
76
|
+
category: calc_speech(mixture_mixture(), audio, calculate_pesq()[category])
|
77
|
+
for category, audio in mixture_sources().items()
|
78
|
+
}
|
79
|
+
|
80
|
+
@functools.lru_cache(maxsize=1)
|
81
|
+
def mixture_stats() -> AudioStatsMetrics:
|
82
|
+
return calc_audio_stats(mixture_mixture(), mixdb.fg_info.ft_config.length / SAMPLE_RATE)
|
83
|
+
|
84
|
+
@functools.lru_cache(maxsize=1)
|
85
|
+
def sources_stats() -> dict[str, AudioStatsMetrics]:
|
86
|
+
return {
|
87
|
+
category: calc_audio_stats(audio, mixdb.fg_info.ft_config.length / SAMPLE_RATE)
|
88
|
+
for category, audio in mixture_sources().items()
|
89
|
+
}
|
90
|
+
|
91
|
+
@functools.lru_cache(maxsize=1)
|
92
|
+
def source_stats() -> AudioStatsMetrics:
|
93
|
+
return calc_audio_stats(mixture_source(), mixdb.fg_info.ft_config.length / SAMPLE_RATE)
|
94
|
+
|
95
|
+
@functools.lru_cache(maxsize=1)
|
96
|
+
def noise_stats() -> AudioStatsMetrics:
|
97
|
+
return calc_audio_stats(mixture_noise(), mixdb.fg_info.ft_config.length / SAMPLE_RATE)
|
98
|
+
|
99
|
+
# Cache ASR configurations
|
100
|
+
@functools.lru_cache(maxsize=32)
|
101
|
+
def get_asr_config(asr_name: str) -> dict:
|
102
|
+
value = mixdb.asr_configs.get(asr_name, None)
|
103
|
+
if value is None:
|
104
|
+
raise ValueError(f"Unrecognized ASR name: '{asr_name}'")
|
105
|
+
return value
|
106
|
+
|
107
|
+
# Cache ASR results for sources, source and mixture
|
108
|
+
@functools.lru_cache(maxsize=16)
|
109
|
+
def sources_asr(asr_name: str) -> dict[str, str]:
|
110
|
+
return {
|
111
|
+
category: calc_asr(audio, **get_asr_config(asr_name)).text for category, audio in mixture_sources().items()
|
112
|
+
}
|
113
|
+
|
114
|
+
@functools.lru_cache(maxsize=16)
|
115
|
+
def source_asr(asr_name: str) -> str:
|
116
|
+
return calc_asr(mixture_source(), **get_asr_config(asr_name)).text
|
117
|
+
|
118
|
+
@functools.lru_cache(maxsize=16)
|
119
|
+
def mixture_asr(asr_name: str) -> str:
|
120
|
+
return calc_asr(mixture_mixture(), **get_asr_config(asr_name)).text
|
121
|
+
|
122
|
+
def get_asr_name(m: str) -> str:
|
123
|
+
parts = m.split(".")
|
124
|
+
if len(parts) != 2:
|
125
|
+
raise ValueError(f"Unrecognized format: '{m}'; must be of the form: '<metric>.<name>'")
|
126
|
+
asr_name = parts[1]
|
127
|
+
return asr_name
|
128
|
+
|
129
|
+
def calc(m: str) -> Any:
|
130
|
+
if m == "mxsnr":
|
131
|
+
return {category: source.snr for category, source in mixdb.mixture(m_id).all_sources.items()}
|
132
|
+
|
133
|
+
# Get cached data first, if exists
|
134
|
+
if not force:
|
135
|
+
value = mixdb.read_mixture_data(m_id, m)[m]
|
136
|
+
if value is not None:
|
137
|
+
return value
|
138
|
+
|
139
|
+
# Otherwise, generate data as needed
|
140
|
+
if m.startswith("mxwer"):
|
141
|
+
asr_name = get_asr_name(m)
|
142
|
+
|
143
|
+
if mixdb.mixture(m_id).is_noise_only:
|
144
|
+
# noise only, ignore/reset target asr
|
145
|
+
return float("nan")
|
146
|
+
|
147
|
+
if source_asr(asr_name):
|
148
|
+
return calc_wer(mixture_asr(asr_name), source_asr(asr_name)).wer * 100
|
149
|
+
|
150
|
+
# TODO: should this be NaN like above?
|
151
|
+
return float(0)
|
152
|
+
|
153
|
+
if m.startswith("basewer"):
|
154
|
+
asr_name = get_asr_name(m)
|
155
|
+
|
156
|
+
text = mixdb.mixture_speech_metadata(m_id, "text")
|
157
|
+
return {
|
158
|
+
category: calc_wer(source, str(text[category])).wer * 100 if isinstance(text[category], str) else 0
|
159
|
+
for category, source in sources_asr(asr_name).items()
|
160
|
+
}
|
161
|
+
|
162
|
+
if m.startswith("mxasr"):
|
163
|
+
return mixture_asr(get_asr_name(m))
|
164
|
+
|
165
|
+
if m == "mxssnr_avg":
|
166
|
+
return calc_segsnr_f(mixture_segsnr()).avg
|
167
|
+
|
168
|
+
if m == "mxssnr_std":
|
169
|
+
return calc_segsnr_f(mixture_segsnr()).std
|
170
|
+
|
171
|
+
if m == "mxssnr_avg_db":
|
172
|
+
val = calc_segsnr_f(mixture_segsnr()).avg
|
173
|
+
if val is not None:
|
174
|
+
return linear_to_db(val)
|
175
|
+
return None
|
176
|
+
|
177
|
+
if m == "mxssnr_std_db":
|
178
|
+
val = calc_segsnr_f(mixture_segsnr()).std
|
179
|
+
if val is not None:
|
180
|
+
return linear_to_db(val)
|
181
|
+
return None
|
182
|
+
|
183
|
+
if m == "mxssnrdb_avg":
|
184
|
+
return calc_segsnr_f(mixture_segsnr()).db_avg
|
185
|
+
|
186
|
+
if m == "mxssnrdb_std":
|
187
|
+
return calc_segsnr_f(mixture_segsnr()).db_std
|
188
|
+
|
189
|
+
if m == "mxssnrf_avg":
|
190
|
+
return calc_segsnr_f_bin(mixture_source_f(), mixture_noise_f()).avg
|
191
|
+
|
192
|
+
if m == "mxssnrf_std":
|
193
|
+
return calc_segsnr_f_bin(mixture_source_f(), mixture_noise_f()).std
|
194
|
+
|
195
|
+
if m == "mxssnrdbf_avg":
|
196
|
+
return calc_segsnr_f_bin(mixture_source_f(), mixture_noise_f()).db_avg
|
197
|
+
|
198
|
+
if m == "mxssnrdbf_std":
|
199
|
+
return calc_segsnr_f_bin(mixture_source_f(), mixture_noise_f()).db_std
|
200
|
+
|
201
|
+
if m == "mxpesq":
|
202
|
+
if mixdb.mixture(m_id).is_noise_only:
|
203
|
+
return dict.fromkeys(calculate_pesq(), 0)
|
204
|
+
return calculate_pesq()
|
205
|
+
|
206
|
+
if m == "mxcsig":
|
207
|
+
if mixdb.mixture(m_id).is_noise_only:
|
208
|
+
return dict.fromkeys(calculate_speech(), 0)
|
209
|
+
return {category: s.csig for category, s in calculate_speech().items()}
|
210
|
+
|
211
|
+
if m == "mxcbak":
|
212
|
+
if mixdb.mixture(m_id).is_noise_only:
|
213
|
+
return dict.fromkeys(calculate_speech(), 0)
|
214
|
+
return {category: s.cbak for category, s in calculate_speech().items()}
|
215
|
+
|
216
|
+
if m == "mxcovl":
|
217
|
+
if mixdb.mixture(m_id).is_noise_only:
|
218
|
+
return dict.fromkeys(calculate_speech(), 0)
|
219
|
+
return {category: s.covl for category, s in calculate_speech().items()}
|
220
|
+
|
221
|
+
if m == "mxwsdr":
|
222
|
+
mixture = mixture_mixture()[:, np.newaxis]
|
223
|
+
target = mixture_source()[:, np.newaxis]
|
224
|
+
noise = mixture_noise()[:, np.newaxis]
|
225
|
+
return calc_wsdr(
|
226
|
+
hypothesis=np.concatenate((mixture, noise), axis=1),
|
227
|
+
reference=np.concatenate((target, noise), axis=1),
|
228
|
+
with_log=True,
|
229
|
+
)[0]
|
230
|
+
|
231
|
+
if m == "mxpd":
|
232
|
+
return calc_phase_distance(hypothesis=mixture_mixture_f(), reference=mixture_source_f())[0]
|
233
|
+
|
234
|
+
if m == "mxstoi":
|
235
|
+
return stoi(
|
236
|
+
x=mixture_source(),
|
237
|
+
y=mixture_mixture(),
|
238
|
+
fs_sig=SAMPLE_RATE,
|
239
|
+
extended=False,
|
240
|
+
)
|
241
|
+
|
242
|
+
if m == "mxdco":
|
243
|
+
return mixture_stats().dco
|
244
|
+
|
245
|
+
if m == "mxmin":
|
246
|
+
return mixture_stats().min
|
247
|
+
|
248
|
+
if m == "mxmax":
|
249
|
+
return mixture_stats().max
|
250
|
+
|
251
|
+
if m == "mxpkdb":
|
252
|
+
return mixture_stats().pkdb
|
253
|
+
|
254
|
+
if m == "mxlrms":
|
255
|
+
return mixture_stats().lrms
|
256
|
+
|
257
|
+
if m == "mxpkr":
|
258
|
+
return mixture_stats().pkr
|
259
|
+
|
260
|
+
if m == "mxtr":
|
261
|
+
return mixture_stats().tr
|
262
|
+
|
263
|
+
if m == "mxcr":
|
264
|
+
return mixture_stats().cr
|
265
|
+
|
266
|
+
if m == "mxfl":
|
267
|
+
return mixture_stats().fl
|
268
|
+
|
269
|
+
if m == "mxpkc":
|
270
|
+
return mixture_stats().pkc
|
271
|
+
|
272
|
+
if m == "sdco":
|
273
|
+
return {category: s.dco for category, s in sources_stats().items()}
|
274
|
+
|
275
|
+
if m == "smin":
|
276
|
+
return {category: s.min for category, s in sources_stats().items()}
|
277
|
+
|
278
|
+
if m == "smax":
|
279
|
+
return {category: s.max for category, s in sources_stats().items()}
|
280
|
+
|
281
|
+
if m == "spkdb":
|
282
|
+
return {category: s.pkdb for category, s in sources_stats().items()}
|
283
|
+
|
284
|
+
if m == "slrms":
|
285
|
+
return {category: s.lrms for category, s in sources_stats().items()}
|
286
|
+
|
287
|
+
if m == "spkr":
|
288
|
+
return {category: s.pkr for category, s in sources_stats().items()}
|
289
|
+
|
290
|
+
if m == "str":
|
291
|
+
return {category: s.tr for category, s in sources_stats().items()}
|
292
|
+
|
293
|
+
if m == "scr":
|
294
|
+
return {category: s.cr for category, s in sources_stats().items()}
|
295
|
+
|
296
|
+
if m == "sfl":
|
297
|
+
return {category: s.fl for category, s in sources_stats().items()}
|
298
|
+
|
299
|
+
if m == "spkc":
|
300
|
+
return {category: s.pkc for category, s in sources_stats().items()}
|
301
|
+
|
302
|
+
if m == "mxsdco":
|
303
|
+
return source_stats().dco
|
304
|
+
|
305
|
+
if m == "mxsmin":
|
306
|
+
return source_stats().min
|
307
|
+
|
308
|
+
if m == "mxsmax":
|
309
|
+
return source_stats().max
|
310
|
+
|
311
|
+
if m == "mxspkdb":
|
312
|
+
return source_stats().pkdb
|
313
|
+
|
314
|
+
if m == "mxslrms":
|
315
|
+
return source_stats().lrms
|
316
|
+
|
317
|
+
if m == "mxspkr":
|
318
|
+
return source_stats().pkr
|
319
|
+
|
320
|
+
if m == "mxstr":
|
321
|
+
return source_stats().tr
|
322
|
+
|
323
|
+
if m == "mxscr":
|
324
|
+
return source_stats().cr
|
325
|
+
|
326
|
+
if m == "mxsfl":
|
327
|
+
return source_stats().fl
|
328
|
+
|
329
|
+
if m == "mxspkc":
|
330
|
+
return source_stats().pkc
|
331
|
+
|
332
|
+
if m.startswith("sasr"):
|
333
|
+
return sources_asr(get_asr_name(m))
|
334
|
+
|
335
|
+
if m.startswith("mxsasr"):
|
336
|
+
return source_asr(get_asr_name(m))
|
337
|
+
|
338
|
+
if m == "ndco":
|
339
|
+
return noise_stats().dco
|
340
|
+
|
341
|
+
if m == "nmin":
|
342
|
+
return noise_stats().min
|
343
|
+
|
344
|
+
if m == "nmax":
|
345
|
+
return noise_stats().max
|
346
|
+
|
347
|
+
if m == "npkdb":
|
348
|
+
return noise_stats().pkdb
|
349
|
+
|
350
|
+
if m == "nlrms":
|
351
|
+
return noise_stats().lrms
|
352
|
+
|
353
|
+
if m == "npkr":
|
354
|
+
return noise_stats().pkr
|
355
|
+
|
356
|
+
if m == "ntr":
|
357
|
+
return noise_stats().tr
|
358
|
+
|
359
|
+
if m == "ncr":
|
360
|
+
return noise_stats().cr
|
361
|
+
|
362
|
+
if m == "nfl":
|
363
|
+
return noise_stats().fl
|
364
|
+
|
365
|
+
if m == "npkc":
|
366
|
+
return noise_stats().pkc
|
367
|
+
|
368
|
+
if m == "sedavg":
|
369
|
+
return 0
|
370
|
+
|
371
|
+
if m == "sedcnt":
|
372
|
+
return 0
|
373
|
+
|
374
|
+
if m == "sedtop3":
|
375
|
+
return np.zeros(3, dtype=np.float32)
|
376
|
+
|
377
|
+
if m == "sedtopn":
|
378
|
+
return 0
|
379
|
+
|
380
|
+
if m == "ssnr":
|
381
|
+
return mixture_segsnr()
|
382
|
+
|
383
|
+
raise AttributeError(f"Unrecognized metric: '{m}'")
|
384
|
+
|
385
|
+
result: dict[str, Any] = {}
|
386
|
+
for metric in metrics:
|
387
|
+
result[metric] = calc(metric)
|
388
|
+
|
389
|
+
# Check for metrics dependencies and add them even if not explicitly requested.
|
390
|
+
if metric.startswith("mxwer"):
|
391
|
+
dependencies = ("mxasr." + metric[6:], "sasr." + metric[6:])
|
392
|
+
for dependency in dependencies:
|
393
|
+
result[dependency] = calc(dependency)
|
394
|
+
|
395
|
+
return result
|
sonusai/mixture/mixdb.py
CHANGED
@@ -215,16 +215,6 @@ class MixtureDatabase:
|
|
215
215
|
MetricDoc("Mixture Metrics", "mxcr", "Mixture Crest factor"),
|
216
216
|
MetricDoc("Mixture Metrics", "mxfl", "Mixture Flat factor"),
|
217
217
|
MetricDoc("Mixture Metrics", "mxpkc", "Mixture Pk count"),
|
218
|
-
MetricDoc("Mixture Metrics", "mxtdco", "Mixture source DC offset"),
|
219
|
-
MetricDoc("Mixture Metrics", "mxtmin", "Mixture source min level"),
|
220
|
-
MetricDoc("Mixture Metrics", "mxtmax", "Mixture source max levl"),
|
221
|
-
MetricDoc("Mixture Metrics", "mxtpkdb", "Mixture source Pk lev dB"),
|
222
|
-
MetricDoc("Mixture Metrics", "mxtlrms", "Mixture source RMS lev dB"),
|
223
|
-
MetricDoc("Mixture Metrics", "mxtpkr", "Mixture source RMS Pk dB"),
|
224
|
-
MetricDoc("Mixture Metrics", "mxttr", "Mixture source RMS Tr dB"),
|
225
|
-
MetricDoc("Mixture Metrics", "mxtcr", "Mixture source Crest factor"),
|
226
|
-
MetricDoc("Mixture Metrics", "mxtfl", "Mixture source Flat factor"),
|
227
|
-
MetricDoc("Mixture Metrics", "mxtpkc", "Mixture source Pk count"),
|
228
218
|
MetricDoc("Sources Metrics", "sdco", "Sources DC offset"),
|
229
219
|
MetricDoc("Sources Metrics", "smin", "Sources min level"),
|
230
220
|
MetricDoc("Sources Metrics", "smax", "Sources max levl"),
|
@@ -235,6 +225,16 @@ class MixtureDatabase:
|
|
235
225
|
MetricDoc("Sources Metrics", "scr", "Sources Crest factor"),
|
236
226
|
MetricDoc("Sources Metrics", "sfl", "Sources Flat factor"),
|
237
227
|
MetricDoc("Sources Metrics", "spkc", "Sources Pk count"),
|
228
|
+
MetricDoc("Source Metrics", "mxsdco", "Source DC offset"),
|
229
|
+
MetricDoc("Source Metrics", "mxsmin", "Source min level"),
|
230
|
+
MetricDoc("Source Metrics", "mxsmax", "Source max levl"),
|
231
|
+
MetricDoc("Source Metrics", "mxspkdb", "Source Pk lev dB"),
|
232
|
+
MetricDoc("Source Metrics", "mxslrms", "Source RMS lev dB"),
|
233
|
+
MetricDoc("Source Metrics", "mxspkr", "Source RMS Pk dB"),
|
234
|
+
MetricDoc("Source Metrics", "mxstr", "Source RMS Tr dB"),
|
235
|
+
MetricDoc("Source Metrics", "mxscr", "Source Crest factor"),
|
236
|
+
MetricDoc("Source Metrics", "mxsfl", "Source Flat factor"),
|
237
|
+
MetricDoc("Source Metrics", "mxspkc", "Source Pk count"),
|
238
238
|
MetricDoc("Noise Metrics", "ndco", "Noise DC offset"),
|
239
239
|
MetricDoc("Noise Metrics", "nmin", "Noise min level"),
|
240
240
|
MetricDoc("Noise Metrics", "nmax", "Noise max levl"),
|
@@ -272,12 +272,12 @@ class MixtureDatabase:
|
|
272
272
|
MetricDoc(
|
273
273
|
"Source Metrics",
|
274
274
|
f"mxsasr.{name}",
|
275
|
-
f"
|
275
|
+
f"Source ASR text using {name} ASR as defined in mixdb asr_configs parameter",
|
276
276
|
)
|
277
277
|
)
|
278
278
|
metrics.append(
|
279
279
|
MetricDoc(
|
280
|
-
"
|
280
|
+
"Sources Metrics",
|
281
281
|
f"sasr.{name}",
|
282
282
|
f"Sources ASR text using {name} ASR as defined in mixdb asr_configs parameter",
|
283
283
|
)
|
@@ -291,7 +291,7 @@ class MixtureDatabase:
|
|
291
291
|
)
|
292
292
|
metrics.append(
|
293
293
|
MetricDoc(
|
294
|
-
"
|
294
|
+
"Sources Metrics",
|
295
295
|
f"basewer.{name}",
|
296
296
|
f"Word error rate of sasr.{name} vs. speech text metadata for the source",
|
297
297
|
)
|
@@ -1296,17 +1296,15 @@ class MixtureDatabase:
|
|
1296
1296
|
fg = FeatureGenerator(self.fg_config.feature_mode, self.fg_config.truth_parameters)
|
1297
1297
|
|
1298
1298
|
feature, truth_f = fg.execute_all(mixture_f, truth_t)
|
1299
|
-
if truth_f is not None:
|
1300
|
-
truth_configs = self.mixture_truth_configs(m_id)
|
1301
|
-
for category, configs in truth_configs.items():
|
1302
|
-
for name, config in configs.items():
|
1303
|
-
if self.truth_parameters[category][name] is not None:
|
1304
|
-
truth_f[category][name] = truth_stride_reduction(
|
1305
|
-
truth_f[category][name], config.stride_reduction
|
1306
|
-
)
|
1307
|
-
else:
|
1299
|
+
if truth_f is None:
|
1308
1300
|
raise TypeError("Unexpected truth of None from feature generator")
|
1309
1301
|
|
1302
|
+
truth_configs = self.mixture_truth_configs(m_id)
|
1303
|
+
for category, configs in truth_configs.items():
|
1304
|
+
for name, config in configs.items():
|
1305
|
+
if self.truth_parameters[category][name] is not None:
|
1306
|
+
truth_f[category][name] = truth_stride_reduction(truth_f[category][name], config.stride_reduction)
|
1307
|
+
|
1310
1308
|
if cache:
|
1311
1309
|
write_cached_data(
|
1312
1310
|
location=self.location,
|
@@ -1598,536 +1596,9 @@ class MixtureDatabase:
|
|
1598
1596
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1599
1597
|
:return: Dictionary of metric data
|
1600
1598
|
"""
|
1601
|
-
from
|
1602
|
-
|
1603
|
-
import numpy as np
|
1604
|
-
from pystoi import stoi
|
1605
|
-
|
1606
|
-
from ..constants import SAMPLE_RATE
|
1607
|
-
from ..datatypes import AudioStatsMetrics
|
1608
|
-
from ..datatypes import SpeechMetrics
|
1609
|
-
from ..metrics.calc_audio_stats import calc_audio_stats
|
1610
|
-
from ..metrics.calc_pesq import calc_pesq
|
1611
|
-
from ..metrics.calc_phase_distance import calc_phase_distance
|
1612
|
-
from ..metrics.calc_segsnr_f import calc_segsnr_f
|
1613
|
-
from ..metrics.calc_segsnr_f import calc_segsnr_f_bin
|
1614
|
-
from ..metrics.calc_speech import calc_speech
|
1615
|
-
from ..metrics.calc_wer import calc_wer
|
1616
|
-
from ..metrics.calc_wsdr import calc_wsdr
|
1617
|
-
from ..utils.asr import calc_asr
|
1618
|
-
from ..utils.db import linear_to_db
|
1619
|
-
|
1620
|
-
def create_sources_audio() -> Callable[[], dict[str, AudioT]]:
|
1621
|
-
state: dict[str, AudioT] | None = None
|
1622
|
-
|
1623
|
-
def get() -> dict[str, AudioT]:
|
1624
|
-
nonlocal state
|
1625
|
-
if state is None:
|
1626
|
-
state = self.mixture_sources(m_id)
|
1627
|
-
return state
|
1628
|
-
|
1629
|
-
return get
|
1630
|
-
|
1631
|
-
sources_audio = create_sources_audio()
|
1632
|
-
|
1633
|
-
def create_source_audio() -> Callable[[], AudioT]:
|
1634
|
-
state: AudioT | None = None
|
1635
|
-
|
1636
|
-
def get() -> AudioT:
|
1637
|
-
nonlocal state
|
1638
|
-
if state is None:
|
1639
|
-
state = self.mixture_source(m_id)
|
1640
|
-
return state
|
1641
|
-
|
1642
|
-
return get
|
1643
|
-
|
1644
|
-
source_audio = create_source_audio()
|
1645
|
-
|
1646
|
-
def create_source_f() -> Callable[[], AudioF]:
|
1647
|
-
state: AudioF | None = None
|
1648
|
-
|
1649
|
-
def get() -> AudioF:
|
1650
|
-
nonlocal state
|
1651
|
-
if state is None:
|
1652
|
-
state = self.mixture_source_f(m_id)
|
1653
|
-
return state
|
1654
|
-
|
1655
|
-
return get
|
1656
|
-
|
1657
|
-
source_f = create_source_f()
|
1658
|
-
|
1659
|
-
def create_noise_audio() -> Callable[[], AudioT]:
|
1660
|
-
state: AudioT | None = None
|
1661
|
-
|
1662
|
-
def get() -> AudioT:
|
1663
|
-
nonlocal state
|
1664
|
-
if state is None:
|
1665
|
-
state = self.mixture_noise(m_id)
|
1666
|
-
return state
|
1667
|
-
|
1668
|
-
return get
|
1669
|
-
|
1670
|
-
noise_audio = create_noise_audio()
|
1671
|
-
|
1672
|
-
def create_noise_f() -> Callable[[], AudioF]:
|
1673
|
-
state: AudioF | None = None
|
1674
|
-
|
1675
|
-
def get() -> AudioF:
|
1676
|
-
nonlocal state
|
1677
|
-
if state is None:
|
1678
|
-
state = self.mixture_noise_f(m_id)
|
1679
|
-
return state
|
1680
|
-
|
1681
|
-
return get
|
1682
|
-
|
1683
|
-
noise_f = create_noise_f()
|
1684
|
-
|
1685
|
-
def create_mixture_audio() -> Callable[[], AudioT]:
|
1686
|
-
state: AudioT | None = None
|
1687
|
-
|
1688
|
-
def get() -> AudioT:
|
1689
|
-
nonlocal state
|
1690
|
-
if state is None:
|
1691
|
-
state = self.mixture_mixture(m_id)
|
1692
|
-
return state
|
1693
|
-
|
1694
|
-
return get
|
1695
|
-
|
1696
|
-
mixture_audio = create_mixture_audio()
|
1697
|
-
|
1698
|
-
def create_segsnr_f() -> Callable[[], Segsnr]:
|
1699
|
-
state: Segsnr | None = None
|
1700
|
-
|
1701
|
-
def get() -> Segsnr:
|
1702
|
-
nonlocal state
|
1703
|
-
if state is None:
|
1704
|
-
state = self.mixture_segsnr(m_id)
|
1705
|
-
return state
|
1706
|
-
|
1707
|
-
return get
|
1708
|
-
|
1709
|
-
segsnr_f = create_segsnr_f()
|
1710
|
-
|
1711
|
-
def create_pesq() -> Callable[[], dict[str, float]]:
|
1712
|
-
state: dict[str, float] | None = None
|
1713
|
-
|
1714
|
-
def get() -> dict[str, float]:
|
1715
|
-
nonlocal state
|
1716
|
-
if state is None:
|
1717
|
-
state = {category: calc_pesq(mixture_audio(), audio) for category, audio in sources_audio().items()}
|
1718
|
-
return state
|
1719
|
-
|
1720
|
-
return get
|
1721
|
-
|
1722
|
-
pesq = create_pesq()
|
1723
|
-
|
1724
|
-
def create_speech() -> Callable[[], dict[str, SpeechMetrics]]:
|
1725
|
-
state: dict[str, SpeechMetrics] | None = None
|
1726
|
-
|
1727
|
-
def get() -> dict[str, SpeechMetrics]:
|
1728
|
-
nonlocal state
|
1729
|
-
if state is None:
|
1730
|
-
state = {
|
1731
|
-
category: calc_speech(mixture_audio(), audio, pesq()[category])
|
1732
|
-
for category, audio in sources_audio().items()
|
1733
|
-
}
|
1734
|
-
return state
|
1735
|
-
|
1736
|
-
return get
|
1737
|
-
|
1738
|
-
speech = create_speech()
|
1739
|
-
|
1740
|
-
def create_mixture_stats() -> Callable[[], AudioStatsMetrics]:
|
1741
|
-
state: AudioStatsMetrics | None = None
|
1742
|
-
|
1743
|
-
def get() -> AudioStatsMetrics:
|
1744
|
-
nonlocal state
|
1745
|
-
if state is None:
|
1746
|
-
state = calc_audio_stats(mixture_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
|
1747
|
-
return state
|
1748
|
-
|
1749
|
-
return get
|
1750
|
-
|
1751
|
-
mixture_stats = create_mixture_stats()
|
1752
|
-
|
1753
|
-
def create_sources_stats() -> Callable[[], dict[str, AudioStatsMetrics]]:
|
1754
|
-
state: dict[str, AudioStatsMetrics] | None = None
|
1755
|
-
|
1756
|
-
def get() -> dict[str, AudioStatsMetrics]:
|
1757
|
-
nonlocal state
|
1758
|
-
if state is None:
|
1759
|
-
state = {
|
1760
|
-
category: calc_audio_stats(audio, self.fg_info.ft_config.length / SAMPLE_RATE)
|
1761
|
-
for category, audio in sources_audio().items()
|
1762
|
-
}
|
1763
|
-
return state
|
1764
|
-
|
1765
|
-
return get
|
1766
|
-
|
1767
|
-
sources_stats = create_sources_stats()
|
1768
|
-
|
1769
|
-
def create_source_stats() -> Callable[[], AudioStatsMetrics]:
|
1770
|
-
state: AudioStatsMetrics | None = None
|
1771
|
-
|
1772
|
-
def get() -> AudioStatsMetrics:
|
1773
|
-
nonlocal state
|
1774
|
-
if state is None:
|
1775
|
-
state = calc_audio_stats(source_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
|
1776
|
-
return state
|
1777
|
-
|
1778
|
-
return get
|
1779
|
-
|
1780
|
-
source_stats = create_source_stats()
|
1781
|
-
|
1782
|
-
def create_noise_stats() -> Callable[[], AudioStatsMetrics]:
|
1783
|
-
state: AudioStatsMetrics | None = None
|
1784
|
-
|
1785
|
-
def get() -> AudioStatsMetrics:
|
1786
|
-
nonlocal state
|
1787
|
-
if state is None:
|
1788
|
-
state = calc_audio_stats(noise_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
|
1789
|
-
return state
|
1790
|
-
|
1791
|
-
return get
|
1792
|
-
|
1793
|
-
noise_stats = create_noise_stats()
|
1794
|
-
|
1795
|
-
def create_asr_config() -> Callable[[str], dict]:
|
1796
|
-
state: dict[str, dict] = {}
|
1797
|
-
|
1798
|
-
def get(asr_name) -> dict:
|
1799
|
-
nonlocal state
|
1800
|
-
if asr_name not in state:
|
1801
|
-
value = self.asr_configs.get(asr_name, None)
|
1802
|
-
if value is None:
|
1803
|
-
raise ValueError(f"Unrecognized ASR name: '{asr_name}'")
|
1804
|
-
state[asr_name] = value
|
1805
|
-
return state[asr_name]
|
1806
|
-
|
1807
|
-
return get
|
1808
|
-
|
1809
|
-
asr_config = create_asr_config()
|
1810
|
-
|
1811
|
-
def create_sources_asr() -> Callable[[str], dict[str, str]]:
|
1812
|
-
state: dict[str, dict[str, str]] = {}
|
1813
|
-
|
1814
|
-
def get(asr_name) -> dict[str, str]:
|
1815
|
-
nonlocal state
|
1816
|
-
if asr_name not in state:
|
1817
|
-
state[asr_name] = {
|
1818
|
-
category: calc_asr(audio, **asr_config(asr_name)).text
|
1819
|
-
for category, audio in sources_audio().items()
|
1820
|
-
}
|
1821
|
-
return state[asr_name]
|
1822
|
-
|
1823
|
-
return get
|
1824
|
-
|
1825
|
-
sources_asr = create_sources_asr()
|
1826
|
-
|
1827
|
-
def create_source_asr() -> Callable[[str], str]:
|
1828
|
-
state: dict[str, str] = {}
|
1829
|
-
|
1830
|
-
def get(asr_name) -> str:
|
1831
|
-
nonlocal state
|
1832
|
-
if asr_name not in state:
|
1833
|
-
state[asr_name] = calc_asr(source_audio(), **asr_config(asr_name)).text
|
1834
|
-
return state[asr_name]
|
1835
|
-
|
1836
|
-
return get
|
1837
|
-
|
1838
|
-
source_asr = create_source_asr()
|
1839
|
-
|
1840
|
-
def create_mixture_asr() -> Callable[[str], str]:
|
1841
|
-
state: dict[str, str] = {}
|
1842
|
-
|
1843
|
-
def get(asr_name) -> str:
|
1844
|
-
nonlocal state
|
1845
|
-
if asr_name not in state:
|
1846
|
-
state[asr_name] = calc_asr(mixture_audio(), **asr_config(asr_name)).text
|
1847
|
-
return state[asr_name]
|
1848
|
-
|
1849
|
-
return get
|
1850
|
-
|
1851
|
-
mixture_asr = create_mixture_asr()
|
1852
|
-
|
1853
|
-
def get_asr_name(m: str) -> str:
|
1854
|
-
parts = m.split(".")
|
1855
|
-
if len(parts) != 2:
|
1856
|
-
raise ValueError(f"Unrecognized format: '{m}'; must be of the form: '<metric>.<name>'")
|
1857
|
-
asr_name = parts[1]
|
1858
|
-
return asr_name
|
1859
|
-
|
1860
|
-
def calc(m: str) -> Any:
|
1861
|
-
if m == "mxsnr":
|
1862
|
-
return {category: source.snr for category, source in self.mixture(m_id).all_sources.items()}
|
1863
|
-
|
1864
|
-
# Get cached data first, if exists
|
1865
|
-
if not force:
|
1866
|
-
value = self.read_mixture_data(m_id, m)[m]
|
1867
|
-
if value is not None:
|
1868
|
-
return value
|
1869
|
-
|
1870
|
-
# Otherwise, generate data as needed
|
1871
|
-
if m.startswith("mxwer"):
|
1872
|
-
asr_name = get_asr_name(m)
|
1873
|
-
|
1874
|
-
if self.mixture(m_id).is_noise_only:
|
1875
|
-
# noise only, ignore/reset target asr
|
1876
|
-
return float("nan")
|
1877
|
-
|
1878
|
-
if source_asr(asr_name):
|
1879
|
-
return calc_wer(mixture_asr(asr_name), source_asr(asr_name)).wer * 100
|
1880
|
-
|
1881
|
-
# TODO: should this be NaN like above?
|
1882
|
-
return float(0)
|
1883
|
-
|
1884
|
-
if m.startswith("basewer"):
|
1885
|
-
asr_name = get_asr_name(m)
|
1886
|
-
|
1887
|
-
text = self.mixture_speech_metadata(m_id, "text")
|
1888
|
-
base_wer: dict[str, float] = {}
|
1889
|
-
for category, source in sources_asr(asr_name).items():
|
1890
|
-
if isinstance(text[category], str):
|
1891
|
-
base_wer[category] = calc_wer(source, str(text[category])).wer * 100
|
1892
|
-
else:
|
1893
|
-
base_wer[category] = 0
|
1894
|
-
return base_wer
|
1895
|
-
|
1896
|
-
if m.startswith("mxasr"):
|
1897
|
-
return mixture_asr(get_asr_name(m))
|
1898
|
-
|
1899
|
-
if m == "mxssnr_avg":
|
1900
|
-
return calc_segsnr_f(segsnr_f()).avg
|
1901
|
-
|
1902
|
-
if m == "mxssnr_std":
|
1903
|
-
return calc_segsnr_f(segsnr_f()).std
|
1904
|
-
|
1905
|
-
if m == "mxssnr_avg_db":
|
1906
|
-
val = calc_segsnr_f(segsnr_f()).avg
|
1907
|
-
if val is not None:
|
1908
|
-
return linear_to_db(val)
|
1909
|
-
return None
|
1910
|
-
|
1911
|
-
if m == "mxssnr_std_db":
|
1912
|
-
val = calc_segsnr_f(segsnr_f()).std
|
1913
|
-
if val is not None:
|
1914
|
-
return linear_to_db(val)
|
1915
|
-
return None
|
1916
|
-
|
1917
|
-
if m == "mxssnrdb_avg":
|
1918
|
-
return calc_segsnr_f(segsnr_f()).db_avg
|
1919
|
-
|
1920
|
-
if m == "mxssnrdb_std":
|
1921
|
-
return calc_segsnr_f(segsnr_f()).db_std
|
1922
|
-
|
1923
|
-
if m == "mxssnrf_avg":
|
1924
|
-
return calc_segsnr_f_bin(source_f(), noise_f()).avg
|
1925
|
-
|
1926
|
-
if m == "mxssnrf_std":
|
1927
|
-
return calc_segsnr_f_bin(source_f(), noise_f()).std
|
1928
|
-
|
1929
|
-
if m == "mxssnrdbf_avg":
|
1930
|
-
return calc_segsnr_f_bin(source_f(), noise_f()).db_avg
|
1931
|
-
|
1932
|
-
if m == "mxssnrdbf_std":
|
1933
|
-
return calc_segsnr_f_bin(source_f(), noise_f()).db_std
|
1934
|
-
|
1935
|
-
if m == "mxpesq":
|
1936
|
-
if self.mixture(m_id).is_noise_only:
|
1937
|
-
return dict.fromkeys(pesq(), 0)
|
1938
|
-
return pesq()
|
1939
|
-
|
1940
|
-
if m == "mxcsig":
|
1941
|
-
if self.mixture(m_id).is_noise_only:
|
1942
|
-
return dict.fromkeys(speech(), 0)
|
1943
|
-
return {category: s.csig for category, s in speech().items()}
|
1944
|
-
|
1945
|
-
if m == "mxcbak":
|
1946
|
-
if self.mixture(m_id).is_noise_only:
|
1947
|
-
return dict.fromkeys(speech(), 0)
|
1948
|
-
return {category: s.cbak for category, s in speech().items()}
|
1949
|
-
|
1950
|
-
if m == "mxcovl":
|
1951
|
-
if self.mixture(m_id).is_noise_only:
|
1952
|
-
return dict.fromkeys(speech(), 0)
|
1953
|
-
return {category: s.covl for category, s in speech().items()}
|
1954
|
-
|
1955
|
-
if m == "mxwsdr":
|
1956
|
-
mixture = mixture_audio()[:, np.newaxis]
|
1957
|
-
target = source_audio()[:, np.newaxis]
|
1958
|
-
noise = noise_audio()[:, np.newaxis]
|
1959
|
-
return calc_wsdr(
|
1960
|
-
hypothesis=np.concatenate((mixture, noise), axis=1),
|
1961
|
-
reference=np.concatenate((target, noise), axis=1),
|
1962
|
-
with_log=True,
|
1963
|
-
)[0]
|
1964
|
-
|
1965
|
-
if m == "mxpd":
|
1966
|
-
mixture_f = self.mixture_mixture_f(m_id)
|
1967
|
-
return calc_phase_distance(hypothesis=mixture_f, reference=source_f())[0]
|
1968
|
-
|
1969
|
-
if m == "mxstoi":
|
1970
|
-
return stoi(
|
1971
|
-
x=source_audio(),
|
1972
|
-
y=mixture_audio(),
|
1973
|
-
fs_sig=SAMPLE_RATE,
|
1974
|
-
extended=False,
|
1975
|
-
)
|
1976
|
-
|
1977
|
-
if m == "mxdco":
|
1978
|
-
return mixture_stats().dco
|
1979
|
-
|
1980
|
-
if m == "mxmin":
|
1981
|
-
return mixture_stats().min
|
1982
|
-
|
1983
|
-
if m == "mxmax":
|
1984
|
-
return mixture_stats().max
|
1985
|
-
|
1986
|
-
if m == "mxpkdb":
|
1987
|
-
return mixture_stats().pkdb
|
1988
|
-
|
1989
|
-
if m == "mxlrms":
|
1990
|
-
return mixture_stats().lrms
|
1991
|
-
|
1992
|
-
if m == "mxpkr":
|
1993
|
-
return mixture_stats().pkr
|
1994
|
-
|
1995
|
-
if m == "mxtr":
|
1996
|
-
return mixture_stats().tr
|
1997
|
-
|
1998
|
-
if m == "mxcr":
|
1999
|
-
return mixture_stats().cr
|
2000
|
-
|
2001
|
-
if m == "mxfl":
|
2002
|
-
return mixture_stats().fl
|
2003
|
-
|
2004
|
-
if m == "mxpkc":
|
2005
|
-
return mixture_stats().pkc
|
2006
|
-
|
2007
|
-
if m == "mxtdco":
|
2008
|
-
return source_stats().dco
|
2009
|
-
|
2010
|
-
if m == "mxtmin":
|
2011
|
-
return source_stats().min
|
2012
|
-
|
2013
|
-
if m == "mxtmax":
|
2014
|
-
return source_stats().max
|
2015
|
-
|
2016
|
-
if m == "mxtpkdb":
|
2017
|
-
return source_stats().pkdb
|
2018
|
-
|
2019
|
-
if m == "mxtlrms":
|
2020
|
-
return source_stats().lrms
|
2021
|
-
|
2022
|
-
if m == "mxtpkr":
|
2023
|
-
return source_stats().pkr
|
2024
|
-
|
2025
|
-
if m == "mxttr":
|
2026
|
-
return source_stats().tr
|
2027
|
-
|
2028
|
-
if m == "mxtcr":
|
2029
|
-
return source_stats().cr
|
2030
|
-
|
2031
|
-
if m == "mxtfl":
|
2032
|
-
return source_stats().fl
|
2033
|
-
|
2034
|
-
if m == "mxtpkc":
|
2035
|
-
return source_stats().pkc
|
2036
|
-
|
2037
|
-
if m == "sdco":
|
2038
|
-
return {category: s.dco for category, s in sources_stats().items()}
|
2039
|
-
|
2040
|
-
if m == "smin":
|
2041
|
-
return {category: s.min for category, s in sources_stats().items()}
|
2042
|
-
|
2043
|
-
if m == "smax":
|
2044
|
-
return {category: s.max for category, s in sources_stats().items()}
|
2045
|
-
|
2046
|
-
if m == "spkdb":
|
2047
|
-
return {category: s.pkdb for category, s in sources_stats().items()}
|
2048
|
-
|
2049
|
-
if m == "slrms":
|
2050
|
-
return {category: s.lrms for category, s in sources_stats().items()}
|
2051
|
-
|
2052
|
-
if m == "spkr":
|
2053
|
-
return {category: s.pkr for category, s in sources_stats().items()}
|
2054
|
-
|
2055
|
-
if m == "str":
|
2056
|
-
return {category: s.tr for category, s in sources_stats().items()}
|
2057
|
-
|
2058
|
-
if m == "scr":
|
2059
|
-
return {category: s.cr for category, s in sources_stats().items()}
|
2060
|
-
|
2061
|
-
if m == "sfl":
|
2062
|
-
return {category: s.fl for category, s in sources_stats().items()}
|
2063
|
-
|
2064
|
-
if m == "spkc":
|
2065
|
-
return {category: s.pkc for category, s in sources_stats().items()}
|
2066
|
-
|
2067
|
-
if m.startswith("sasr"):
|
2068
|
-
return sources_asr(get_asr_name(m))
|
2069
|
-
|
2070
|
-
if m.startswith("mxsasr"):
|
2071
|
-
return source_asr(get_asr_name(m))
|
2072
|
-
|
2073
|
-
if m == "ndco":
|
2074
|
-
return noise_stats().dco
|
2075
|
-
|
2076
|
-
if m == "nmin":
|
2077
|
-
return noise_stats().min
|
2078
|
-
|
2079
|
-
if m == "nmax":
|
2080
|
-
return noise_stats().max
|
2081
|
-
|
2082
|
-
if m == "npkdb":
|
2083
|
-
return noise_stats().pkdb
|
2084
|
-
|
2085
|
-
if m == "nlrms":
|
2086
|
-
return noise_stats().lrms
|
2087
|
-
|
2088
|
-
if m == "npkr":
|
2089
|
-
return noise_stats().pkr
|
2090
|
-
|
2091
|
-
if m == "ntr":
|
2092
|
-
return noise_stats().tr
|
2093
|
-
|
2094
|
-
if m == "ncr":
|
2095
|
-
return noise_stats().cr
|
2096
|
-
|
2097
|
-
if m == "nfl":
|
2098
|
-
return noise_stats().fl
|
2099
|
-
|
2100
|
-
if m == "npkc":
|
2101
|
-
return noise_stats().pkc
|
2102
|
-
|
2103
|
-
if m == "sedavg":
|
2104
|
-
return 0
|
2105
|
-
|
2106
|
-
if m == "sedcnt":
|
2107
|
-
return 0
|
2108
|
-
|
2109
|
-
if m == "sedtop3":
|
2110
|
-
return np.zeros(3, dtype=np.float32)
|
2111
|
-
|
2112
|
-
if m == "sedtopn":
|
2113
|
-
return 0
|
2114
|
-
|
2115
|
-
if m == "ssnr":
|
2116
|
-
return segsnr_f()
|
2117
|
-
|
2118
|
-
raise AttributeError(f"Unrecognized metric: '{m}'")
|
2119
|
-
|
2120
|
-
result: dict[str, Any] = {}
|
2121
|
-
for metric in metrics:
|
2122
|
-
result[metric] = calc(metric)
|
2123
|
-
|
2124
|
-
# Check for metrics dependencies and add them even if not explicitly requested.
|
2125
|
-
if metric.startswith("mxwer"):
|
2126
|
-
dependencies = ("mxasr." + metric[6:], "sasr." + metric[6:])
|
2127
|
-
for dependency in dependencies:
|
2128
|
-
result[dependency] = calc(dependency)
|
1599
|
+
from ..metrics import calculate_metrics
|
2129
1600
|
|
2130
|
-
return
|
1601
|
+
return calculate_metrics(self, m_id, metrics, force)
|
2131
1602
|
|
2132
1603
|
|
2133
1604
|
def _spectral_mask(db: partial, sm_id: int, use_cache: bool = True) -> SpectralMask:
|
@@ -21,13 +21,13 @@ sonusai/doc/__init__.py,sha256=KyQ26Um0RM8A3GYsb_tbFH64RwpoAw6lja2f_moUWas,33
|
|
21
21
|
sonusai/doc/doc.py,sha256=FURO3pvGKrUCHs5iHf0L2zeNofdePW_jiEwtKQX4pJw,19520
|
22
22
|
sonusai/doc.py,sha256=ZgFSSI56oNDb-yC3xi-RHMClMjryR2VrgGyi3ggX8gM,1098
|
23
23
|
sonusai/genft.py,sha256=yiADvi0J-Fy4kNpNOEB3wVvU9RZowGvOsCTJndQYXFw,5580
|
24
|
-
sonusai/genmetrics.py,sha256=
|
24
|
+
sonusai/genmetrics.py,sha256=C-rp_axsxeKvdXtrtExMGqGnNFJgXHq_7EoKeamUkWA,6116
|
25
25
|
sonusai/genmix.py,sha256=gcmqcPqZ1Vz_TtZMp29L8cGnqTK5jcw0cAOc16NOR9A,5753
|
26
26
|
sonusai/genmixdb.py,sha256=VDQMF6JHcHc-yJAZ1Se3CM3ac8fFKIgnaxv4e5jdE1I,11281
|
27
27
|
sonusai/ir_metric.py,sha256=nxS_mARPSZG5Y0G3L8HysOnkPj4v-RGxAxAVBYe-gJI,19600
|
28
28
|
sonusai/lsdb.py,sha256=-Fhwd7YuL-OIymFqaNcBHtOq8l_8LxzoEE6ztduQCpY,5059
|
29
29
|
sonusai/main.py,sha256=72feJv5XEVJE_CQatmNIL1VD9ca-Mo0QNDbXxLrHrbQ,2619
|
30
|
-
sonusai/metrics/__init__.py,sha256=
|
30
|
+
sonusai/metrics/__init__.py,sha256=0Y0xFHiO3TrH4DRt-htCXEXsc8TLGNRWfD16q16yWEs,927
|
31
31
|
sonusai/metrics/calc_audio_stats.py,sha256=tIfTa40UdYCkj999kUghWafwnFBqFtJxB5yZhVp1YpA,1244
|
32
32
|
sonusai/metrics/calc_class_weights.py,sha256=uF1jeFz73l5nSk6SQ-xkBGbrgvAvX_MKUA_Det2KAEM,3609
|
33
33
|
sonusai/metrics/calc_optimal_thresholds.py,sha256=1bKPoqUYyHpq7lrx7hPnVXrJ5xWIewQjNG632GzKNNU,3502
|
@@ -40,6 +40,7 @@ sonusai/metrics/calc_segsnr_f.py,sha256=yLqUt--8osVgCNAkopbDZsldlVJ6a5AZEggarN8d
|
|
40
40
|
sonusai/metrics/calc_speech.py,sha256=bFiWtKz_Fuu4F1kdWGmZ3qZ_LdoSI3pj0ziXZKxXE3U,14828
|
41
41
|
sonusai/metrics/calc_wer.py,sha256=1MQYMx8ldHeodtJEtGibvDKhvSaGe6DBmZV4L8qOMgg,2362
|
42
42
|
sonusai/metrics/calc_wsdr.py,sha256=vcALY-zuhyThRa1QMz2qW8L9kSBc2v32gV9u8bV7VaM,2556
|
43
|
+
sonusai/metrics/calculate_metrics.py,sha256=jcAyEV6loenu4fU_EvwEkpKxOrP8-K9O3rwQGlE48IU,12475
|
43
44
|
sonusai/metrics/class_summary.py,sha256=mQbMxQ8EtFIN7S2h7A4Dk0X4XF_CIxKk3W8zZMmpfcw,2801
|
44
45
|
sonusai/metrics/confusion_matrix_summary.py,sha256=lhd8TyHVMC03khX85h_D75XElmawx56KkqpX3X2O2gQ,3133
|
45
46
|
sonusai/metrics/one_hot.py,sha256=aKc-xYd4zWIjbmoQikIcQ6BJB1k-68XKTg8eJCacHTU,13906
|
@@ -60,7 +61,7 @@ sonusai/mixture/helpers.py,sha256=dmyHwf1C5dZjYOd11kVV16KI33CaM-dU_fyaxOrrKt8,11
|
|
60
61
|
sonusai/mixture/ir_delay.py,sha256=aiC23HMWQ08-v5wORgMx1_DOJSdh4kunULqiQ-SGuMo,2026
|
61
62
|
sonusai/mixture/ir_effects.py,sha256=PqiqD4PS42-7kD6ESnsZi2a3tnKCFa4E0xqUujRBvGg,2152
|
62
63
|
sonusai/mixture/log_duration_and_sizes.py,sha256=3ekS27IMKlnxIkQAmprzmBnzHOpRjZh3d7maL2VqWQU,927
|
63
|
-
sonusai/mixture/mixdb.py,sha256=
|
64
|
+
sonusai/mixture/mixdb.py,sha256=BzFzVON6ZupJcZ9Bx-OXOirck5szLrRY92bSr3042S8,67874
|
64
65
|
sonusai/mixture/pad_audio.py,sha256=KNxVQAejA0hblLOnMJgLS6lFaeE0n3tWQ5rclaHBnIY,1015
|
65
66
|
sonusai/mixture/parse.py,sha256=nqhjuR-J7_3wlGhVitYFvQwLJ1sclU8WZrVF0SyW2Cw,3700
|
66
67
|
sonusai/mixture/resample.py,sha256=jXqH6FrZ0mlhQ07XqPx88TT9elu3HHVLw7Q0a7Lh5M4,221
|
@@ -133,7 +134,7 @@ sonusai/utils/tokenized_shell_vars.py,sha256=EDrrAgz5lJ0RBAjLcTJt1MeyjhbNZiqXkym
|
|
133
134
|
sonusai/utils/write_audio.py,sha256=IHzrJoFtFcea_J6wo6QSiojRkgnNOzAEcg-z0rFV7nU,810
|
134
135
|
sonusai/utils/yes_or_no.py,sha256=0h1okjXmDNbJp7rZJFR2V-HFU1GJDm3YFTUVmYExkOU,263
|
135
136
|
sonusai/vars.py,sha256=m8pdgfR4A6A9TCGf_rok6jPAT5BgrEsYXTSISIh1nrI,1163
|
136
|
-
sonusai-1.0.
|
137
|
-
sonusai-1.0.
|
138
|
-
sonusai-1.0.
|
139
|
-
sonusai-1.0.
|
137
|
+
sonusai-1.0.13.dist-info/METADATA,sha256=l-tODfpKcDr2Xfqriw3VDvw52-k7YR9S5fEXtutpS1k,2695
|
138
|
+
sonusai-1.0.13.dist-info/WHEEL,sha256=RaoafKOydTQ7I_I3JTrPCg6kUmTgtm4BornzOqyEfJ8,88
|
139
|
+
sonusai-1.0.13.dist-info/entry_points.txt,sha256=zMNjEphEPO6B3cD1GNpit7z-yA9tUU5-j3W2v-UWstU,92
|
140
|
+
sonusai-1.0.13.dist-info/RECORD,,
|
File without changes
|
File without changes
|