sonusai 1.0.16__cp311-abi3-macosx_10_12_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +170 -0
- sonusai/aawscd_probwrite.py +148 -0
- sonusai/audiofe.py +481 -0
- sonusai/calc_metric_spenh.py +1136 -0
- sonusai/config/__init__.py +0 -0
- sonusai/config/asr.py +21 -0
- sonusai/config/config.py +65 -0
- sonusai/config/config.yml +49 -0
- sonusai/config/constants.py +53 -0
- sonusai/config/ir.py +124 -0
- sonusai/config/ir_delay.py +62 -0
- sonusai/config/source.py +275 -0
- sonusai/config/spectral_masks.py +15 -0
- sonusai/config/truth.py +64 -0
- sonusai/constants.py +14 -0
- sonusai/data/__init__.py +0 -0
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/data/speech_ma01_01.wav +0 -0
- sonusai/data/whitenoise.wav +0 -0
- sonusai/datatypes.py +383 -0
- sonusai/deprecated/gentcst.py +632 -0
- sonusai/deprecated/plot.py +519 -0
- sonusai/deprecated/tplot.py +365 -0
- sonusai/doc.py +52 -0
- sonusai/doc_strings/__init__.py +1 -0
- sonusai/doc_strings/doc_strings.py +531 -0
- sonusai/genft.py +196 -0
- sonusai/genmetrics.py +183 -0
- sonusai/genmix.py +199 -0
- sonusai/genmixdb.py +235 -0
- sonusai/ir_metric.py +551 -0
- sonusai/lsdb.py +141 -0
- sonusai/main.py +134 -0
- sonusai/metrics/__init__.py +43 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_class_weights.py +90 -0
- sonusai/metrics/calc_optimal_thresholds.py +73 -0
- sonusai/metrics/calc_pcm.py +45 -0
- sonusai/metrics/calc_pesq.py +36 -0
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_sa_sdr.py +64 -0
- sonusai/metrics/calc_sample_weights.py +25 -0
- sonusai/metrics/calc_segsnr_f.py +82 -0
- sonusai/metrics/calc_speech.py +382 -0
- sonusai/metrics/calc_wer.py +71 -0
- sonusai/metrics/calc_wsdr.py +57 -0
- sonusai/metrics/calculate_metrics.py +395 -0
- sonusai/metrics/class_summary.py +74 -0
- sonusai/metrics/confusion_matrix_summary.py +75 -0
- sonusai/metrics/one_hot.py +283 -0
- sonusai/metrics/snr_summary.py +128 -0
- sonusai/metrics_summary.py +314 -0
- sonusai/mixture/__init__.py +15 -0
- sonusai/mixture/audio.py +187 -0
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/constants.py +3 -0
- sonusai/mixture/data_io.py +173 -0
- sonusai/mixture/db.py +169 -0
- sonusai/mixture/db_datatypes.py +92 -0
- sonusai/mixture/effects.py +344 -0
- sonusai/mixture/feature.py +78 -0
- sonusai/mixture/generation.py +1116 -0
- sonusai/mixture/helpers.py +351 -0
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +23 -0
- sonusai/mixture/mixdb.py +1857 -0
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +51 -0
- sonusai/mixture/truth.py +61 -0
- sonusai/mixture/truth_functions/__init__.py +45 -0
- sonusai/mixture/truth_functions/crm.py +105 -0
- sonusai/mixture/truth_functions/energy.py +222 -0
- sonusai/mixture/truth_functions/file.py +48 -0
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +18 -0
- sonusai/mixture/truth_functions/sed.py +98 -0
- sonusai/mixture/truth_functions/target.py +142 -0
- sonusai/mkwav.py +135 -0
- sonusai/onnx_predict.py +363 -0
- sonusai/parse/__init__.py +0 -0
- sonusai/parse/expand.py +156 -0
- sonusai/parse/parse_source_directive.py +129 -0
- sonusai/parse/rand.py +214 -0
- sonusai/py.typed +0 -0
- sonusai/queries/__init__.py +0 -0
- sonusai/queries/queries.py +239 -0
- sonusai/rs.abi3.so +0 -0
- sonusai/rs.pyi +1 -0
- sonusai/rust/__init__.py +0 -0
- sonusai/speech/__init__.py +0 -0
- sonusai/speech/l2arctic.py +121 -0
- sonusai/speech/librispeech.py +102 -0
- sonusai/speech/mcgill.py +71 -0
- sonusai/speech/textgrid.py +89 -0
- sonusai/speech/timit.py +138 -0
- sonusai/speech/types.py +12 -0
- sonusai/speech/vctk.py +53 -0
- sonusai/speech/voxceleb.py +108 -0
- sonusai/utils/__init__.py +3 -0
- sonusai/utils/asl_p56.py +130 -0
- sonusai/utils/asr.py +91 -0
- sonusai/utils/asr_functions/__init__.py +3 -0
- sonusai/utils/asr_functions/aaware_whisper.py +69 -0
- sonusai/utils/audio_devices.py +50 -0
- sonusai/utils/braced_glob.py +50 -0
- sonusai/utils/calculate_input_shape.py +26 -0
- sonusai/utils/choice.py +51 -0
- sonusai/utils/compress.py +25 -0
- sonusai/utils/convert_string_to_number.py +6 -0
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/create_ts_name.py +14 -0
- sonusai/utils/dataclass_from_dict.py +27 -0
- sonusai/utils/db.py +16 -0
- sonusai/utils/docstring.py +53 -0
- sonusai/utils/energy_f.py +44 -0
- sonusai/utils/engineering_number.py +166 -0
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/get_frames_per_batch.py +2 -0
- sonusai/utils/get_label_names.py +20 -0
- sonusai/utils/grouper.py +6 -0
- sonusai/utils/human_readable_size.py +7 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/load_object.py +21 -0
- sonusai/utils/max_text_width.py +9 -0
- sonusai/utils/model_utils.py +28 -0
- sonusai/utils/numeric_conversion.py +11 -0
- sonusai/utils/onnx_utils.py +155 -0
- sonusai/utils/parallel.py +162 -0
- sonusai/utils/path_info.py +7 -0
- sonusai/utils/print_mixture_details.py +60 -0
- sonusai/utils/rand.py +13 -0
- sonusai/utils/ranges.py +43 -0
- sonusai/utils/read_predict_data.py +32 -0
- sonusai/utils/reshape.py +154 -0
- sonusai/utils/seconds_to_hms.py +7 -0
- sonusai/utils/stacked_complex.py +82 -0
- sonusai/utils/stratified_shuffle_split.py +170 -0
- sonusai/utils/tokenized_shell_vars.py +143 -0
- sonusai/utils/write_audio.py +26 -0
- sonusai/utils/yes_or_no.py +8 -0
- sonusai/vars.py +47 -0
- sonusai-1.0.16.dist-info/METADATA +56 -0
- sonusai-1.0.16.dist-info/RECORD +150 -0
- sonusai-1.0.16.dist-info/WHEEL +4 -0
- sonusai-1.0.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,395 @@
|
|
1
|
+
import functools
|
2
|
+
from typing import Any
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
from pystoi import stoi
|
6
|
+
|
7
|
+
from ..constants import SAMPLE_RATE
|
8
|
+
from ..datatypes import AudioF
|
9
|
+
from ..datatypes import AudioStatsMetrics
|
10
|
+
from ..datatypes import AudioT
|
11
|
+
from ..datatypes import Segsnr
|
12
|
+
from ..datatypes import SpeechMetrics
|
13
|
+
from ..mixture.mixdb import MixtureDatabase
|
14
|
+
from ..utils.asr import calc_asr
|
15
|
+
from ..utils.db import linear_to_db
|
16
|
+
from .calc_audio_stats import calc_audio_stats
|
17
|
+
from .calc_pesq import calc_pesq
|
18
|
+
from .calc_phase_distance import calc_phase_distance
|
19
|
+
from .calc_segsnr_f import calc_segsnr_f
|
20
|
+
from .calc_segsnr_f import calc_segsnr_f_bin
|
21
|
+
from .calc_speech import calc_speech
|
22
|
+
from .calc_wer import calc_wer
|
23
|
+
from .calc_wsdr import calc_wsdr
|
24
|
+
|
25
|
+
|
26
|
+
def calculate_metrics(mixdb: MixtureDatabase, m_id: int, metrics: list[str], force: bool = False) -> dict[str, Any]:
    """Get metrics data for the given mixture ID

    Apart from ``mxsnr`` (read directly from the mixture record), each requested metric is
    taken from the mixture's cached data when that cached value is not None and ``force`` is
    False; otherwise it is computed from the mixture audio. Requesting ``mxwer.<name>`` also
    adds the ``mxasr.<name>`` and ``sasr.<name>`` metrics it depends on to the result.

    :param mixdb: Mixture database object
    :param m_id: Zero-based mixture ID
    :param metrics: List of metrics to get
    :param force: Force computing data from original sources regardless of whether cached data exists
    :return: Dictionary of metric data, keyed by metric name
    :raises ValueError: If an ASR metric name is not of the form '<metric>.<name>', or names an
        ASR configuration that is not in ``mixdb.asr_configs``
    :raises AttributeError: If a metric name is not recognized
    """

    # Cached accessors: each expensive quantity is computed at most once per call,
    # no matter how many requested metrics depend on it.
    @functools.lru_cache(maxsize=1)
    def mixture_sources() -> dict[str, AudioT]:
        return mixdb.mixture_sources(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_source() -> AudioT:
        return mixdb.mixture_source(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_source_f() -> AudioF:
        return mixdb.mixture_source_f(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_noise() -> AudioT:
        return mixdb.mixture_noise(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_noise_f() -> AudioF:
        return mixdb.mixture_noise_f(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_mixture() -> AudioT:
        return mixdb.mixture_mixture(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_mixture_f() -> AudioF:
        return mixdb.mixture_mixture_f(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_segsnr() -> Segsnr:
        return mixdb.mixture_segsnr(m_id)

    @functools.lru_cache(maxsize=1)
    def calculate_pesq() -> dict[str, float]:
        return {category: calc_pesq(mixture_mixture(), audio) for category, audio in mixture_sources().items()}

    @functools.lru_cache(maxsize=1)
    def calculate_speech() -> dict[str, SpeechMetrics]:
        # Composite speech metrics reuse the per-category PESQ values.
        return {
            category: calc_speech(mixture_mixture(), audio, calculate_pesq()[category])
            for category, audio in mixture_sources().items()
        }

    @functools.lru_cache(maxsize=1)
    def mixture_stats() -> AudioStatsMetrics:
        return calc_audio_stats(mixture_mixture(), mixdb.fg_info.ft_config.length / SAMPLE_RATE)

    @functools.lru_cache(maxsize=1)
    def sources_stats() -> dict[str, AudioStatsMetrics]:
        return {
            category: calc_audio_stats(audio, mixdb.fg_info.ft_config.length / SAMPLE_RATE)
            for category, audio in mixture_sources().items()
        }

    @functools.lru_cache(maxsize=1)
    def source_stats() -> AudioStatsMetrics:
        return calc_audio_stats(mixture_source(), mixdb.fg_info.ft_config.length / SAMPLE_RATE)

    @functools.lru_cache(maxsize=1)
    def noise_stats() -> AudioStatsMetrics:
        return calc_audio_stats(mixture_noise(), mixdb.fg_info.ft_config.length / SAMPLE_RATE)

    # Cache ASR configurations
    @functools.lru_cache(maxsize=32)
    def get_asr_config(asr_name: str) -> dict:
        value = mixdb.asr_configs.get(asr_name, None)
        if value is None:
            raise ValueError(f"Unrecognized ASR name: '{asr_name}'")
        return value

    # Cache ASR results for sources, source and mixture
    @functools.lru_cache(maxsize=16)
    def sources_asr(asr_name: str) -> dict[str, str]:
        return {
            category: calc_asr(audio, **get_asr_config(asr_name)).text for category, audio in mixture_sources().items()
        }

    @functools.lru_cache(maxsize=16)
    def source_asr(asr_name: str) -> str:
        return calc_asr(mixture_source(), **get_asr_config(asr_name)).text

    @functools.lru_cache(maxsize=16)
    def mixture_asr(asr_name: str) -> str:
        return calc_asr(mixture_mixture(), **get_asr_config(asr_name)).text

    def get_asr_name(m: str) -> str:
        """Return the ASR configuration name from a metric of the form '<metric>.<name>'."""
        parts = m.split(".")
        if len(parts) != 2:
            raise ValueError(f"Unrecognized format: '{m}'; must be of the form: '<metric>.<name>'")
        return parts[1]

    # Audio-statistics metrics are named by a prefix plus an AudioStatsMetrics field:
    #   mx<field>  -> whole mixture        mxs<field> -> mixture source
    #   n<field>   -> mixture noise        s<field>   -> per-category sources (dict by category)
    stats_fields = ("dco", "min", "max", "pkdb", "lrms", "pkr", "tr", "cr", "fl", "pkc")
    stats_metrics = {
        prefix + field: (provider, field)
        for prefix, provider in (("mx", mixture_stats), ("mxs", source_stats), ("n", noise_stats))
        for field in stats_fields
    }
    sources_stats_metrics = {"s" + field: field for field in stats_fields}

    # Composite speech-quality metrics and the SpeechMetrics field each one reports.
    speech_metrics = {"mxcsig": "csig", "mxcbak": "cbak", "mxcovl": "covl"}

    def calc(m: str) -> Any:
        if m == "mxsnr":
            return {category: source.snr for category, source in mixdb.mixture(m_id).all_sources.items()}

        # Get cached data first, if exists
        if not force:
            value = mixdb.read_mixture_data(m_id, m)[m]
            if value is not None:
                return value

        # Otherwise, generate data as needed

        # --- ASR-based metrics ('<metric>.<asr name>') ---
        if m.startswith("mxwer"):
            asr_name = get_asr_name(m)

            if mixdb.mixture(m_id).is_noise_only:
                # noise only, ignore/reset target asr
                return float("nan")

            if source_asr(asr_name):
                return calc_wer(mixture_asr(asr_name), source_asr(asr_name)).wer * 100

            # TODO: should this be NaN like above?
            return float(0)

        if m.startswith("basewer"):
            asr_name = get_asr_name(m)

            text = mixdb.mixture_speech_metadata(m_id, "text")
            return {
                category: calc_wer(source, str(text[category])).wer * 100 if isinstance(text[category], str) else 0
                for category, source in sources_asr(asr_name).items()
            }

        if m.startswith("mxasr"):
            return mixture_asr(get_asr_name(m))

        if m.startswith("sasr"):
            return sources_asr(get_asr_name(m))

        if m.startswith("mxsasr"):
            return source_asr(get_asr_name(m))

        # --- Segmental SNR metrics ---
        if m == "mxssnr_avg":
            return calc_segsnr_f(mixture_segsnr()).avg

        if m == "mxssnr_std":
            return calc_segsnr_f(mixture_segsnr()).std

        if m == "mxssnr_avg_db":
            val = calc_segsnr_f(mixture_segsnr()).avg
            if val is not None:
                return linear_to_db(val)
            return None

        if m == "mxssnr_std_db":
            val = calc_segsnr_f(mixture_segsnr()).std
            if val is not None:
                return linear_to_db(val)
            return None

        if m == "mxssnrdb_avg":
            return calc_segsnr_f(mixture_segsnr()).db_avg

        if m == "mxssnrdb_std":
            return calc_segsnr_f(mixture_segsnr()).db_std

        if m == "mxssnrf_avg":
            return calc_segsnr_f_bin(mixture_source_f(), mixture_noise_f()).avg

        if m == "mxssnrf_std":
            return calc_segsnr_f_bin(mixture_source_f(), mixture_noise_f()).std

        if m == "mxssnrdbf_avg":
            return calc_segsnr_f_bin(mixture_source_f(), mixture_noise_f()).db_avg

        if m == "mxssnrdbf_std":
            return calc_segsnr_f_bin(mixture_source_f(), mixture_noise_f()).db_std

        # --- Speech-quality metrics ---
        if m == "mxpesq":
            if mixdb.mixture(m_id).is_noise_only:
                # Report 0 for every source category without running the (expensive)
                # PESQ computation whose values would be discarded anyway.
                return dict.fromkeys(mixture_sources(), 0)
            return calculate_pesq()

        if m in speech_metrics:
            if mixdb.mixture(m_id).is_noise_only:
                # Same keys calculate_speech() would produce, without computing it.
                return dict.fromkeys(mixture_sources(), 0)
            field = speech_metrics[m]
            return {category: getattr(s, field) for category, s in calculate_speech().items()}

        if m == "mxwsdr":
            mixture = mixture_mixture()[:, np.newaxis]
            target = mixture_source()[:, np.newaxis]
            noise = mixture_noise()[:, np.newaxis]
            return calc_wsdr(
                hypothesis=np.concatenate((mixture, noise), axis=1),
                reference=np.concatenate((target, noise), axis=1),
                with_log=True,
            )[0]

        if m == "mxpd":
            return calc_phase_distance(hypothesis=mixture_mixture_f(), reference=mixture_source_f())[0]

        if m == "mxstoi":
            return stoi(
                x=mixture_source(),
                y=mixture_mixture(),
                fs_sig=SAMPLE_RATE,
                extended=False,
            )

        # --- Audio statistics ---
        if m in stats_metrics:
            provider, field = stats_metrics[m]
            return getattr(provider(), field)

        if m in sources_stats_metrics:
            field = sources_stats_metrics[m]
            return {category: getattr(s, field) for category, s in sources_stats().items()}

        # --- Placeholder SED metrics ---
        if m in ("sedavg", "sedcnt", "sedtopn"):
            return 0

        if m == "sedtop3":
            return np.zeros(3, dtype=np.float32)

        if m == "ssnr":
            return mixture_segsnr()

        raise AttributeError(f"Unrecognized metric: '{m}'")

    result: dict[str, Any] = {}
    for metric in metrics:
        result[metric] = calc(metric)

        # Check for metrics dependencies and add them even if not explicitly requested.
        if metric.startswith("mxwer"):
            asr_name = get_asr_name(metric)
            for dependency in (f"mxasr.{asr_name}", f"sasr.{asr_name}"):
                if dependency not in result:
                    result[dependency] = calc(dependency)

    return result
|
@@ -0,0 +1,74 @@
|
|
1
|
+
# ruff: noqa: F821
|
2
|
+
import numpy as np
|
3
|
+
import pandas as pd
|
4
|
+
|
5
|
+
from ..datatypes import GeneralizedIDs
|
6
|
+
from ..datatypes import Predict
|
7
|
+
from ..datatypes import Truth
|
8
|
+
from ..mixture.mixdb import MixtureDatabase
|
9
|
+
|
10
|
+
|
11
|
+
def class_summary(
    mixdb: MixtureDatabase,
    mixids: GeneralizedIDs,
    truth_f: Truth,
    predict: Predict,
    predict_thr: float | np.ndarray = 0,
    truth_thr: float = 0.5,
    timesteps: int = 0,
) -> pd.DataFrame:
    """Calculate a table of per-class metrics, plus averages, for a list of mixtures.

    Truth and prediction data are [features, num_classes].

    :param mixdb: Mixture database object
    :param mixids: Mixture IDs to summarize
    :param truth_f: Truth data
    :param predict: Prediction data
    :param predict_thr: Decision threshold(s) applied to predict data. For multi-class data a
        value of 0 (a scalar 0, or a 1-D array whose first element is 0) selects 0.5.
    :param truth_thr: Decision threshold applied to truth data
    :param timesteps: Passed through to ``one_hot``
    :return: DataFrame with columns PPV, TPR, F1, FPR, ACC, AP, AUC, Support. There is one row per
        class, named from ``mixdb.class_labels`` when it has one label per class, else
        "Class 1" .. "Class N". These are followed by "Macro-avg", "Micro-avg" and "Weighted-avg"
        rows. Support is reported as an integer.

    Example layout::

                      PPV  TPR  F1  FPR  ACC  AP  AUC  Support
        Class 1       ...  ...  ..  ...  ...  ..  ...      ...
        Class 2       ...
        Macro-avg     ...
        Micro-avg     ...
        Weighted-avg  ...
    """
    from ..metrics.one_hot import one_hot

    num_classes = truth_f.shape[1]

    # TODO: re-work for modern mixdb API
    y_truth_f, y_predict = get_mixids_data(mixdb, mixids, truth_f, predict)  # type: ignore[name-defined]

    if num_classes > 1:
        if not isinstance(predict_thr, np.ndarray):
            # Scalar threshold: 0 means "use the 0.5 default".
            predict_thr = np.atleast_1d(0.5 if predict_thr == 0 else predict_thr)
        elif predict_thr.ndim == 1 and predict_thr[0] == 0:
            predict_thr = np.atleast_1d(0.5)

    _, metrics, _, _, _, metavg = one_hot(y_truth_f, y_predict, predict_thr, truth_thr, timesteps)

    # one_hot metric columns: [ACC, TPR, PPV, TNR, FPR, HITFA, F1, MCC, NT, PT, TP, FP, AP, AUC]
    # Select PPV, TPR, F1, FPR, ACC, AP, AUC and PT (reported as Support).
    table_idx = np.array([2, 1, 6, 4, 0, 12, 13, 9])
    col_n = ["PPV", "TPR", "F1", "FPR", "ACC", "AP", "AUC", "Support"]
    if len(mixdb.class_labels) == num_classes:
        row_n = mixdb.class_labels
    else:
        row_n = [f"Class {i}" for i in range(1, num_classes + 1)]

    df = pd.DataFrame(metrics[:, table_idx], columns=col_n, index=row_n)  # pyright: ignore [reportArgumentType]

    # Averages use the same column layout as the per-class table.
    avg_row_n = ["Macro-avg", "Micro-avg", "Weighted-avg"]
    dfavg = pd.DataFrame(metavg, columns=col_n, index=avg_row_n)  # pyright: ignore [reportArgumentType]

    classdf = pd.concat([df, dfavg])
    classdf["Support"] = classdf["Support"].astype(int)

    return classdf
|
@@ -0,0 +1,75 @@
|
|
1
|
+
# ruff: noqa: F821
|
2
|
+
import numpy as np
|
3
|
+
import pandas as pd
|
4
|
+
|
5
|
+
from ..datatypes import GeneralizedIDs
|
6
|
+
from ..datatypes import Predict
|
7
|
+
from ..datatypes import Truth
|
8
|
+
from ..mixture.mixdb import MixtureDatabase
|
9
|
+
|
10
|
+
|
11
|
+
def _resolve_class_threshold(predict_thr: float | np.ndarray, class_idx: int) -> float:
    """Reduce ``predict_thr`` to the scalar decision threshold for one class.

    A single value (scalar or one-element array) applies to every class. A multi-valued
    array is indexed by ``class_idx`` (flattened first). A threshold of 0 selects the 0.5 default.
    """
    thresholds = np.ravel(np.asarray(predict_thr, dtype=np.float64))
    threshold = float(thresholds[0] if thresholds.size == 1 else thresholds[class_idx])
    return 0.5 if threshold == 0 else threshold


def confusion_matrix_summary(
    mixdb: MixtureDatabase,
    mixids: GeneralizedIDs,
    truth_f: Truth,
    predict: Predict,
    class_idx: int,
    predict_thr: float | np.ndarray = 0,
    truth_thr: float = 0.5,
    timesteps: int = 0,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Calculate the confusion matrix for one class, using truth and prediction data [features, num_classes].

    predict_thr sets the decision threshold(s) applied to predict data, thus allowing
    predict to be continuous probabilities. For multi-class data it is reduced to a single
    threshold for ``class_idx``:
    - a scalar or one-element array applies to every class;
    - a longer array is indexed by ``class_idx``;
    - a threshold of 0 selects the 0.5 default.
    For single-class data predict_thr is passed to ``one_hot`` unchanged.

    :param mixdb: Mixture database object
    :param mixids: Mixture IDs to summarize
    :param truth_f: Truth data
    :param predict: Prediction data
    :param class_idx: Zero-based index of the class to summarize
    :param predict_thr: Decision threshold(s) applied to predict data
    :param truth_thr: Decision threshold applied to truth data
    :param timesteps: Passed through to ``one_hot``
    :return: (cmdf, cmndf) — the confusion matrix and the normalized confusion matrix.
        Rows are "TrueN", "TrueP" and a "p/t thr:" row holding the prediction and truth
        thresholds. Columns are "N-<class name>" and "P-<class name>".
    """
    from ..metrics.one_hot import one_hot

    num_classes = truth_f.shape[1]
    # TODO: re-work for modern mixdb API
    ytrue, ypred = get_mixids_data(mixdb=mixdb, mixids=mixids, truth_f=truth_f, predict=predict)  # type: ignore[name-defined]

    if num_classes > 1:
        # Resolve to a plain float for this class. Keeping a 1-element array here would make the
        # thresholds note row below a ragged sequence, which NumPy rejects.
        predict_thr = _resolve_class_threshold(predict_thr, class_idx)

    if len(mixdb.class_labels) == num_classes:
        class_names = mixdb.class_labels
    else:
        class_names = [f"Class {i}" for i in range(1, num_classes + 1)]

    _, _, cm, cmn, _, _ = one_hot(ytrue[:, class_idx], ypred[:, class_idx], predict_thr, truth_thr, timesteps)
    cname = class_names[class_idx]
    row_n = ["TrueN", "TrueP"]
    col_n = ["N-" + cname, "P-" + cname]
    cmdf = pd.DataFrame(cm, index=row_n, columns=col_n, dtype=np.int32)  # pyright: ignore [reportArgumentType]
    cmndf = pd.DataFrame(cmn, index=row_n, columns=col_n, dtype=np.float32)  # pyright: ignore [reportArgumentType]

    # Add thresholds in 3rd row; both entries must be scalars so the row has exactly two columns.
    thr_note = float(np.ravel(predict_thr)[0])
    pdnote = pd.DataFrame(np.atleast_2d([thr_note, truth_thr]), index=["p/t thr:"], columns=col_n)  # pyright: ignore [reportArgumentType, reportCallIssue]
    cmdf = pd.concat([cmdf, pdnote])
    cmndf = pd.concat([cmndf, pdnote])

    return cmdf, cmndf
|