sonusai 0.19.5__py3-none-any.whl → 0.19.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +1 -1
- sonusai/aawscd_probwrite.py +1 -1
- sonusai/calc_metric_spenh.py +1 -1
- sonusai/genft.py +38 -49
- sonusai/genmetrics.py +65 -70
- sonusai/genmix.py +62 -72
- sonusai/genmixdb.py +73 -95
- sonusai/metrics/calc_class_weights.py +1 -3
- sonusai/metrics/calc_optimal_thresholds.py +2 -2
- sonusai/metrics/calc_phase_distance.py +1 -1
- sonusai/metrics/calc_segsnr_f.py +1 -1
- sonusai/metrics/calc_speech.py +6 -6
- sonusai/metrics/class_summary.py +6 -15
- sonusai/metrics/confusion_matrix_summary.py +11 -27
- sonusai/metrics/one_hot.py +3 -3
- sonusai/metrics/snr_summary.py +7 -7
- sonusai/mixture/__init__.py +3 -17
- sonusai/mixture/augmentation.py +5 -6
- sonusai/mixture/class_count.py +1 -1
- sonusai/mixture/config.py +36 -46
- sonusai/mixture/data_io.py +30 -1
- sonusai/mixture/datatypes.py +29 -40
- sonusai/mixture/db_datatypes.py +1 -1
- sonusai/mixture/feature.py +3 -23
- sonusai/mixture/generation.py +202 -235
- sonusai/mixture/helpers.py +29 -187
- sonusai/mixture/mixdb.py +386 -159
- sonusai/mixture/soundfile_audio.py +1 -1
- sonusai/mixture/sox_audio.py +4 -4
- sonusai/mixture/sox_augmentation.py +1 -1
- sonusai/mixture/target_class_balancing.py +9 -11
- sonusai/mixture/targets.py +23 -20
- sonusai/mixture/truth.py +21 -34
- sonusai/mixture/truth_functions/__init__.py +6 -0
- sonusai/mixture/truth_functions/crm.py +51 -37
- sonusai/mixture/truth_functions/energy.py +95 -50
- sonusai/mixture/truth_functions/file.py +12 -8
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +4 -5
- sonusai/mixture/truth_functions/sed.py +32 -23
- sonusai/mixture/truth_functions/target.py +62 -29
- sonusai/mkwav.py +34 -43
- sonusai/queries/queries.py +9 -15
- sonusai/speech/l2arctic.py +6 -2
- sonusai/summarize_metric_spenh.py +1 -1
- sonusai/utils/__init__.py +1 -0
- sonusai/utils/asr_functions/aaware_whisper.py +1 -1
- sonusai/utils/audio_devices.py +27 -18
- sonusai/utils/docstring.py +6 -3
- sonusai/utils/energy_f.py +5 -3
- sonusai/utils/human_readable_size.py +6 -6
- sonusai/utils/load_object.py +15 -0
- sonusai/utils/onnx_utils.py +2 -2
- sonusai/utils/parallel.py +3 -5
- sonusai/utils/print_mixture_details.py +3 -3
- {sonusai-0.19.5.dist-info → sonusai-0.19.8.dist-info}/METADATA +2 -2
- {sonusai-0.19.5.dist-info → sonusai-0.19.8.dist-info}/RECORD +60 -58
- sonusai/mixture/truth_functions/datatypes.py +0 -37
- {sonusai-0.19.5.dist-info → sonusai-0.19.8.dist-info}/WHEEL +0 -0
- {sonusai-0.19.5.dist-info → sonusai-0.19.8.dist-info}/entry_points.txt +0 -0
sonusai/genmixdb.py
CHANGED
@@ -1,15 +1,16 @@
|
|
1
1
|
"""sonusai genmixdb
|
2
2
|
|
3
|
-
usage: genmixdb [-
|
3
|
+
usage: genmixdb [-hvmfsdjn] LOC
|
4
4
|
|
5
5
|
options:
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
6
|
+
-h, --help
|
7
|
+
-v, --verbose Be verbose.
|
8
|
+
-m, --mix Save mixture data. [default: False].
|
9
|
+
-f, --ft Save feature/truth_f data. [default: False].
|
10
|
+
-s, --segsnr Save segsnr data. [default: False].
|
11
|
+
-d, --dryrun Perform a dry run showing the processed config. [default: False].
|
12
|
+
-j, --json Save JSON version of database. [default: False].
|
13
|
+
-n, --nopar Do not run in parallel. [default: False].
|
13
14
|
|
14
15
|
Create mixture database data for training and evaluation. Optionally, also create mixture audio and feature/truth data.
|
15
16
|
|
@@ -114,10 +115,6 @@ will find all .wav files in the specified directories and process them as target
|
|
114
115
|
"""
|
115
116
|
|
116
117
|
import signal
|
117
|
-
from dataclasses import dataclass
|
118
|
-
|
119
|
-
from sonusai.mixture import Mixture
|
120
|
-
from sonusai.mixture import MixtureDatabase
|
121
118
|
|
122
119
|
|
123
120
|
def signal_handler(_sig, _frame):
|
@@ -132,17 +129,6 @@ def signal_handler(_sig, _frame):
|
|
132
129
|
signal.signal(signal.SIGINT, signal_handler)
|
133
130
|
|
134
131
|
|
135
|
-
@dataclass
|
136
|
-
class MPGlobal:
|
137
|
-
mixdb: MixtureDatabase
|
138
|
-
save_mix: bool
|
139
|
-
save_ft: bool
|
140
|
-
save_segsnr: bool
|
141
|
-
|
142
|
-
|
143
|
-
MP_GLOBAL: MPGlobal
|
144
|
-
|
145
|
-
|
146
132
|
def genmixdb(
|
147
133
|
location: str,
|
148
134
|
save_mix: bool = False,
|
@@ -152,7 +138,9 @@ def genmixdb(
|
|
152
138
|
show_progress: bool = False,
|
153
139
|
test: bool = False,
|
154
140
|
save_json: bool = False,
|
155
|
-
|
141
|
+
no_par: bool = False,
|
142
|
+
) -> None:
|
143
|
+
from functools import partial
|
156
144
|
from random import seed
|
157
145
|
|
158
146
|
import yaml
|
@@ -163,7 +151,6 @@ def genmixdb(
|
|
163
151
|
from sonusai.mixture import AugmentationRule
|
164
152
|
from sonusai.mixture import MixtureDatabase
|
165
153
|
from sonusai.mixture import balance_targets
|
166
|
-
from sonusai.mixture import generate_mixtures
|
167
154
|
from sonusai.mixture import get_all_snrs_from_config
|
168
155
|
from sonusai.mixture import get_augmentation_rules
|
169
156
|
from sonusai.mixture import get_augmented_targets
|
@@ -329,7 +316,8 @@ def genmixdb(
|
|
329
316
|
f"{seconds_to_hms(seconds=noise_audio_duration)}"
|
330
317
|
)
|
331
318
|
|
332
|
-
used_noise_files, used_noise_samples
|
319
|
+
used_noise_files, used_noise_samples = populate_mixture_table(
|
320
|
+
location=location,
|
333
321
|
noise_mix_mode=mixdb.noise_mix_mode,
|
334
322
|
augmented_targets=augmented_targets,
|
335
323
|
target_files=target_files,
|
@@ -342,16 +330,17 @@ def genmixdb(
|
|
342
330
|
num_classes=mixdb.num_classes,
|
343
331
|
feature_step_samples=mixdb.feature_step_samples,
|
344
332
|
num_ir=mixdb.num_impulse_response_files,
|
333
|
+
test=test,
|
345
334
|
)
|
346
335
|
|
347
|
-
num_mixtures = len(mixtures)
|
336
|
+
num_mixtures = len(mixdb.mixtures)
|
348
337
|
update_mixid_width(location, num_mixtures, test)
|
349
338
|
|
350
339
|
if logging:
|
351
340
|
logger.info("")
|
352
341
|
logger.info(f"Found {num_mixtures:,} mixtures to process")
|
353
342
|
|
354
|
-
total_duration = float(sum([mixture.samples for mixture in mixtures])) / SAMPLE_RATE
|
343
|
+
total_duration = float(sum([mixture.samples for mixture in mixdb.mixtures])) / SAMPLE_RATE
|
355
344
|
|
356
345
|
if logging:
|
357
346
|
log_duration_and_sizes(
|
@@ -375,17 +364,21 @@ def genmixdb(
|
|
375
364
|
if logging:
|
376
365
|
logger.info("Generating mixtures")
|
377
366
|
progress = track(total=num_mixtures, disable=not show_progress)
|
378
|
-
|
379
|
-
|
380
|
-
|
367
|
+
par_track(
|
368
|
+
partial(
|
369
|
+
_process_mixture,
|
370
|
+
location=location,
|
371
|
+
save_mix=save_mix,
|
372
|
+
save_ft=save_ft,
|
373
|
+
save_segsnr=save_segsnr,
|
374
|
+
test=test,
|
375
|
+
),
|
376
|
+
range(num_mixtures),
|
381
377
|
progress=progress,
|
382
|
-
|
383
|
-
initargs=(location, save_mix, save_ft, save_segsnr, test),
|
378
|
+
no_par=no_par,
|
384
379
|
)
|
385
380
|
progress.close()
|
386
381
|
|
387
|
-
populate_mixture_table(location, mixtures, test)
|
388
|
-
|
389
382
|
total_noise_files = len(noise_files)
|
390
383
|
|
391
384
|
total_samples = mixdb.total_samples()
|
@@ -414,79 +407,62 @@ def genmixdb(
|
|
414
407
|
mixdb = MixtureDatabase(location)
|
415
408
|
mixdb.save()
|
416
409
|
|
417
|
-
return mixdb
|
418
|
-
|
419
410
|
|
420
|
-
def
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
def _process_mixture(mixture: Mixture) -> Mixture:
|
432
|
-
from typing import Any
|
411
|
+
def _process_mixture(
|
412
|
+
m_id: int,
|
413
|
+
location: str,
|
414
|
+
save_mix: bool,
|
415
|
+
save_ft: bool,
|
416
|
+
save_segsnr: bool,
|
417
|
+
test: bool,
|
418
|
+
) -> None:
|
419
|
+
from functools import partial
|
433
420
|
|
434
|
-
from sonusai.mixture import
|
435
|
-
from sonusai.mixture import
|
436
|
-
from sonusai.mixture import
|
437
|
-
from sonusai.mixture import update_mixture
|
421
|
+
from sonusai.mixture import MixtureDatabase
|
422
|
+
from sonusai.mixture import clear_cached_data
|
423
|
+
from sonusai.mixture import update_mixture_table
|
438
424
|
from sonusai.mixture import write_cached_data
|
439
425
|
from sonusai.mixture import write_mixture_metadata
|
440
426
|
|
441
|
-
|
427
|
+
with_data = save_mix or save_ft or save_segsnr
|
428
|
+
|
429
|
+
genmix_data = update_mixture_table(location, m_id, with_data, test)
|
442
430
|
|
443
|
-
|
444
|
-
|
431
|
+
mixdb = MixtureDatabase(location, test)
|
432
|
+
mixture = mixdb.mixture(m_id)
|
445
433
|
|
446
|
-
|
434
|
+
write = partial(write_cached_data, location=location, name="mixture", index=mixture.name)
|
435
|
+
clear = partial(clear_cached_data, location=location, name="mixture", index=mixture.name)
|
447
436
|
|
448
437
|
if with_data:
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
feature, truth_f = get_ft(
|
467
|
-
mixdb=mixdb,
|
468
|
-
mixture=mixture,
|
469
|
-
mixture_audio=genmix_data.mixture,
|
470
|
-
truth_t=truth_t,
|
438
|
+
write(
|
439
|
+
items=[
|
440
|
+
("targets", genmix_data.targets),
|
441
|
+
("target", genmix_data.target),
|
442
|
+
("noise", genmix_data.noise),
|
443
|
+
("mixture", genmix_data.mixture),
|
444
|
+
]
|
445
|
+
)
|
446
|
+
|
447
|
+
if save_ft:
|
448
|
+
clear(items=["feature", "truth_f"])
|
449
|
+
feature, truth_f = mixdb.mixture_ft(m_id)
|
450
|
+
write(
|
451
|
+
items=[
|
452
|
+
("feature", feature),
|
453
|
+
("truth_f", truth_f),
|
454
|
+
]
|
471
455
|
)
|
472
|
-
write_data.append(("feature", feature))
|
473
|
-
write_data.append(("truth_f", truth_f))
|
474
456
|
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
mixdb=mixdb,
|
480
|
-
mixture=mixture,
|
481
|
-
target_audio=genmix_data.target,
|
482
|
-
noise=genmix_data.noise,
|
483
|
-
)
|
484
|
-
write_data.append(("segsnr", segsnr))
|
457
|
+
if save_segsnr:
|
458
|
+
clear(items=["segsnr"])
|
459
|
+
segsnr = mixdb.mixture_segsnr(m_id)
|
460
|
+
write(items=[("segsnr", segsnr)])
|
485
461
|
|
486
|
-
|
487
|
-
|
462
|
+
if not save_mix:
|
463
|
+
clear(items=["targets", "target", "noise", "mixture"])
|
488
464
|
|
489
|
-
|
465
|
+
write_mixture_metadata(mixdb, m_id)
|
490
466
|
|
491
467
|
|
492
468
|
def main() -> None:
|
@@ -519,6 +495,7 @@ def main() -> None:
|
|
519
495
|
save_segsnr = args["--segsnr"]
|
520
496
|
dryrun = args["--dryrun"]
|
521
497
|
save_json = args["--json"]
|
498
|
+
no_par = args["--nopar"]
|
522
499
|
location = args["LOC"]
|
523
500
|
|
524
501
|
start_time = time.monotonic()
|
@@ -549,6 +526,7 @@ def main() -> None:
|
|
549
526
|
save_segsnr=save_segsnr,
|
550
527
|
show_progress=True,
|
551
528
|
save_json=save_json,
|
529
|
+
no_par=no_par,
|
552
530
|
)
|
553
531
|
except Exception as e:
|
554
532
|
logger.debug(e)
|
@@ -54,7 +54,7 @@ def calc_class_weights_from_truth(truth: Truth, other_weight: float | None = Non
|
|
54
54
|
|
55
55
|
def calc_class_weights_from_mixdb(
|
56
56
|
mixdb: MixtureDatabase,
|
57
|
-
mixids: GeneralizedIDs
|
57
|
+
mixids: GeneralizedIDs = "*",
|
58
58
|
other_weight: float = 1,
|
59
59
|
other_index: int = -1,
|
60
60
|
) -> tuple[np.ndarray, np.ndarray]:
|
@@ -77,8 +77,6 @@ def calc_class_weights_from_mixdb(
|
|
77
77
|
from sonusai.mixture import get_class_count_from_mixids
|
78
78
|
|
79
79
|
count = np.ceil(np.array(get_class_count_from_mixids(mixdb=mixdb, mixids=mixids)) / mixdb.feature_step_samples)
|
80
|
-
if mixdb.truth_mutex and other_weight is not None and other_weight > 0:
|
81
|
-
count[other_index] = count[other_index] / np.float32(other_weight)
|
82
80
|
total_features = sum(count)
|
83
81
|
|
84
82
|
weights = np.empty(mixdb.num_classes, dtype=np.float32)
|
@@ -51,8 +51,8 @@ def calc_optimal_thresholds(
|
|
51
51
|
AUC[nci] = np.NaN
|
52
52
|
AP[nci] = np.NaN
|
53
53
|
else:
|
54
|
-
AP[nci] = average_precision_score(truth_binary[:, nci], predict[:, nci], average=None)
|
55
|
-
AUC[nci] = roc_auc_score(truth_binary[:, nci], predict[:, nci], average=None)
|
54
|
+
AP[nci] = average_precision_score(truth_binary[:, nci], predict[:, nci], average=None) # pyright: ignore [reportArgumentType]
|
55
|
+
AUC[nci] = roc_auc_score(truth_binary[:, nci], predict[:, nci], average=None) # pyright: ignore [reportArgumentType]
|
56
56
|
|
57
57
|
# Optimal threshold from PR curve, optimizes f-score
|
58
58
|
precision, recall, thrpr = precision_recall_curve(truth_binary[:, nci], predict[:, nci])
|
@@ -26,7 +26,7 @@ def calc_phase_distance(
|
|
26
26
|
# weighted mean over all (scalar)
|
27
27
|
reference_mag = np.abs(reference)
|
28
28
|
ref_weight = reference_mag / (np.sum(reference_mag) + eps) # frames x bins
|
29
|
-
err = np.around(np.sum(ref_weight * rh_angle_diff), 3)
|
29
|
+
err = float(np.around(np.sum(ref_weight * rh_angle_diff), 3))
|
30
30
|
|
31
31
|
# weighted mean over frames (value per bin)
|
32
32
|
err_b = np.zeros(reference.shape[1])
|
sonusai/metrics/calc_segsnr_f.py
CHANGED
@@ -45,7 +45,7 @@ def calc_segsnr_f_bin(target_f: AudioF, noise_f: AudioF) -> SnrFBinMetrics:
|
|
45
45
|
if target_f.ndim != 2 and noise_f.ndim != 2:
|
46
46
|
raise ValueError("target_f and noise_f must have 2 dimensions")
|
47
47
|
|
48
|
-
segsnr_f = (np.abs(target_f) ** 2) / (np.abs(noise_f) ** 2)
|
48
|
+
segsnr_f = (np.abs(target_f) ** 2) / (np.abs(noise_f) ** 2 + np.finfo(np.float32).eps)
|
49
49
|
|
50
50
|
frames, bins = segsnr_f.shape
|
51
51
|
if np.count_nonzero(segsnr_f) == 0:
|
sonusai/metrics/calc_speech.py
CHANGED
@@ -32,16 +32,16 @@ def calc_speech(hypothesis: np.ndarray, reference: np.ndarray, sample_rate: int
|
|
32
32
|
llr_mean = np.mean(ll_rs[:llr_len])
|
33
33
|
|
34
34
|
# Segmental SNR
|
35
|
-
|
35
|
+
_, segsnr_dist = _calc_snr(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
|
36
36
|
seg_snr = np.mean(segsnr_dist)
|
37
37
|
|
38
38
|
# PESQ
|
39
39
|
_pesq = calc_pesq(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
|
40
40
|
|
41
41
|
# Now compute the composite measures
|
42
|
-
csig = np.clip(3.093 - 1.029 * llr_mean + 0.603 * _pesq - 0.009 * wss_dist, 1, 5)
|
43
|
-
cbak = np.clip(1.634 + 0.478 * _pesq - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5)
|
44
|
-
covl = np.clip(1.594 + 0.805 * _pesq - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5)
|
42
|
+
csig = float(np.clip(3.093 - 1.029 * llr_mean + 0.603 * _pesq - 0.009 * wss_dist, 1, 5))
|
43
|
+
cbak = float(np.clip(1.634 + 0.478 * _pesq - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5))
|
44
|
+
covl = float(np.clip(1.594 + 0.805 * _pesq - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5))
|
45
45
|
|
46
46
|
return SpeechMetrics(_pesq, csig, cbak, covl)
|
47
47
|
|
@@ -284,8 +284,8 @@ def _calc_log_likelihood_ratio_measure(
|
|
284
284
|
hypothesis_frame = np.multiply(hypothesis_frame, window)
|
285
285
|
|
286
286
|
# (2) Get the autocorrelation lags and LPC parameters used to compute the log likelihood ratio measure.
|
287
|
-
r_reference,
|
288
|
-
|
287
|
+
r_reference, _, a_reference = _lp_coefficients(reference_frame, p)
|
288
|
+
_, _, a_hypothesis = _lp_coefficients(hypothesis_frame, p)
|
289
289
|
|
290
290
|
# (3) Compute the log likelihood ratio measure
|
291
291
|
numerator = np.dot(np.matmul(a_hypothesis, toeplitz(r_reference)), a_hypothesis)
|
sonusai/metrics/class_summary.py
CHANGED
@@ -38,7 +38,7 @@ def class_summary(
|
|
38
38
|
# TODO: re-work for modern mixdb API
|
39
39
|
y_truth_f, y_predict = get_mixids_data(mixdb, mixids, truth_f, predict) # type: ignore[name-defined]
|
40
40
|
|
41
|
-
if
|
41
|
+
if num_classes > 1:
|
42
42
|
if not isinstance(predict_thr, np.ndarray):
|
43
43
|
if predict_thr == 0:
|
44
44
|
predict_thr = np.atleast_1d(0.5)
|
@@ -53,25 +53,16 @@ def class_summary(
|
|
53
53
|
# [ACC, TPR, PPV, TNR, FPR, HITFA, F1, MCC, NT, PT, TP, FP, AP, AUC]
|
54
54
|
table_idx = np.array([2, 1, 6, 4, 0, 12, 13, 9])
|
55
55
|
col_n = ["PPV", "TPR", "F1", "FPR", "ACC", "AP", "AUC", "Support"]
|
56
|
-
if mixdb.
|
57
|
-
|
58
|
-
row_n = mixdb.class_labels
|
59
|
-
if len(mixdb.class_labels) == num_classes - 1: # Other label does not exist, so add it
|
60
|
-
row_n.append("Other")
|
61
|
-
else:
|
62
|
-
row_n = [f"Class {i}" for i in range(1, num_classes)]
|
63
|
-
row_n.append("Other")
|
56
|
+
if len(mixdb.class_labels) == num_classes:
|
57
|
+
row_n = mixdb.class_labels
|
64
58
|
else:
|
65
|
-
|
66
|
-
row_n = mixdb.class_labels
|
67
|
-
else:
|
68
|
-
row_n = [f"Class {i}" for i in range(1, num_classes + 1)]
|
59
|
+
row_n = [f"Class {i}" for i in range(1, num_classes + 1)]
|
69
60
|
|
70
|
-
df = pd.DataFrame(metrics[:, table_idx], columns=col_n, index=row_n)
|
61
|
+
df = pd.DataFrame(metrics[:, table_idx], columns=col_n, index=row_n) # pyright: ignore [reportArgumentType]
|
71
62
|
|
72
63
|
# [miPPV, miTPR, miF1, miFPR, miACC, miAP, miAUC, TPSUM]
|
73
64
|
avg_row_n = ["Macro-avg", "Micro-avg", "Weighted-avg"]
|
74
|
-
dfavg = pd.DataFrame(metavg, columns=col_n, index=avg_row_n)
|
65
|
+
dfavg = pd.DataFrame(metavg, columns=col_n, index=avg_row_n) # pyright: ignore [reportArgumentType]
|
75
66
|
|
76
67
|
# dfblank = pd.DataFrame([''])
|
77
68
|
# pd.concat([df, dfblank, dfblank, dfavg])
|
@@ -37,7 +37,7 @@ def confusion_matrix_summary(
|
|
37
37
|
ytrue, ypred = get_mixids_data(mixdb=mixdb, mixids=mixids, truth_f=truth_f, predict=predict) # type: ignore[name-defined]
|
38
38
|
|
39
39
|
# Check predict_thr array or scalar and return final scalar predict_thr value
|
40
|
-
if
|
40
|
+
if num_classes > 1:
|
41
41
|
if not isinstance(predict_thr, np.ndarray):
|
42
42
|
if predict_thr == 0:
|
43
43
|
# multi-label predict_thr scalar 0 force to 0.5 default
|
@@ -61,31 +61,15 @@ def confusion_matrix_summary(
|
|
61
61
|
else:
|
62
62
|
class_names = [f"Class {i}" for i in range(1, num_classes + 1)]
|
63
63
|
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
# mux = pd.MultiIndex.from_product([['truth thr = {}'.format(truth_thr)], class_nums])
|
75
|
-
|
76
|
-
cmdf = pd.DataFrame(cm, index=row_n, columns=class_nums, dtype=np.int32)
|
77
|
-
cmndf = pd.DataFrame(cmn, index=row_n, columns=class_nums, dtype=np.float32)
|
78
|
-
|
79
|
-
else:
|
80
|
-
_, _, cm, cmn, _, _ = one_hot(ytrue[:, class_idx], ypred[:, class_idx], predict_thr, truth_thr, timesteps)
|
81
|
-
cname = class_names[class_idx]
|
82
|
-
row_n = ["TrueN", "TrueP"]
|
83
|
-
col_n = ["N-" + cname, "P-" + cname]
|
84
|
-
cmdf = pd.DataFrame(cm, index=row_n, columns=col_n, dtype=np.int32)
|
85
|
-
cmndf = pd.DataFrame(cmn, index=row_n, columns=col_n, dtype=np.float32)
|
86
|
-
# add thresholds in 3rd row
|
87
|
-
pdnote = pd.DataFrame(np.atleast_2d([predict_thr, truth_thr]), index=["p/t thr:"], columns=col_n)
|
88
|
-
cmdf = pd.concat([cmdf, pdnote])
|
89
|
-
cmndf = pd.concat([cmndf, pdnote])
|
64
|
+
_, _, cm, cmn, _, _ = one_hot(ytrue[:, class_idx], ypred[:, class_idx], predict_thr, truth_thr, timesteps)
|
65
|
+
cname = class_names[class_idx]
|
66
|
+
row_n = ["TrueN", "TrueP"]
|
67
|
+
col_n = ["N-" + cname, "P-" + cname]
|
68
|
+
cmdf = pd.DataFrame(cm, index=row_n, columns=col_n, dtype=np.int32) # pyright: ignore [reportArgumentType]
|
69
|
+
cmndf = pd.DataFrame(cmn, index=row_n, columns=col_n, dtype=np.float32) # pyright: ignore [reportArgumentType]
|
70
|
+
# add thresholds in 3rd row
|
71
|
+
pdnote = pd.DataFrame(np.atleast_2d([predict_thr, truth_thr]), index=["p/t thr:"], columns=col_n) # pyright: ignore [reportArgumentType, reportCallIssue]
|
72
|
+
cmdf = pd.concat([cmdf, pdnote])
|
73
|
+
cmndf = pd.concat([cmndf, pdnote])
|
90
74
|
|
91
75
|
return cmdf, cmndf
|
sonusai/metrics/one_hot.py
CHANGED
@@ -185,11 +185,11 @@ def one_hot(
|
|
185
185
|
AP = np.NaN
|
186
186
|
# threshold_optpr[nci] = np.NaN
|
187
187
|
else:
|
188
|
-
AP = average_precision_score(truthb[:, nci], predict[:, nci], average=None)
|
188
|
+
AP = average_precision_score(truthb[:, nci], predict[:, nci], average=None) # pyright: ignore [reportArgumentType]
|
189
189
|
if len(np.unique(truthb[:, nci])) < 2: # if active classes not > 1 AUC must be NaN
|
190
190
|
AUC = np.NaN # i.e. all ones sklearn auc will fail
|
191
191
|
else:
|
192
|
-
AUC = roc_auc_score(truthb[:, nci], predict[:, nci], average=None)
|
192
|
+
AUC = roc_auc_score(truthb[:, nci], predict[:, nci], average=None) # pyright: ignore [reportArgumentType]
|
193
193
|
# # Optimal threshold from PR curve, optimizes f-score
|
194
194
|
# precision, recall, thresholds = precision_recall_curve(truthb[:, nci], predict[:, nci])
|
195
195
|
# fscore = (2 * precision * recall) / (precision + recall)
|
@@ -263,7 +263,7 @@ def one_hot(
|
|
263
263
|
] # specific format, last 3 are unique
|
264
264
|
|
265
265
|
# weighted average TBD
|
266
|
-
wp, wr, wf1, _ = precision_recall_fscore_support(truthb, predb, average="weighted", zero_division=0)
|
266
|
+
wp, wr, wf1, _ = precision_recall_fscore_support(truthb, predb, average="weighted", zero_division=0) # pyright: ignore [reportArgumentType]
|
267
267
|
if np.sum(truthb):
|
268
268
|
taidx = np.sum(truthb, axis=0) > 0
|
269
269
|
wap = average_precision_score(truthb[:, taidx], predict[:, taidx], average="weighted")
|
sonusai/metrics/snr_summary.py
CHANGED
@@ -48,7 +48,7 @@ def snr_summary(
|
|
48
48
|
snr_mixids = get_mixids_from_snr(mixdb=mixdb, mixids=mixid)
|
49
49
|
|
50
50
|
# Check predict_thr array or scalar and return final scalar predict_thr value
|
51
|
-
if
|
51
|
+
if num_classes > 1:
|
52
52
|
if not isinstance(predict_thr, np.ndarray):
|
53
53
|
if predict_thr == 0:
|
54
54
|
# multi-label predict_thr scalar 0 force to 0.5 default
|
@@ -84,7 +84,7 @@ def snr_summary(
|
|
84
84
|
for ii, snr in enumerate(snr_mixids):
|
85
85
|
# TODO: re-work for modern mixdb API
|
86
86
|
y_truth, y_predict = get_mixids_data(mixdb, snr_mixids[snr], truth_f, predict) # type: ignore[name-defined]
|
87
|
-
_,
|
87
|
+
_, _, _, _, _, mavg = one_hot(y_truth, y_predict, predict_thr, truth_thr, timesteps)
|
88
88
|
|
89
89
|
# mavg macro, micro, weighted: [PPV, TPR, F1, FPR, ACC, mAP, mAUC, TPSUM]
|
90
90
|
macro_avg[ii, :] = mavg[0, 0:7]
|
@@ -104,21 +104,21 @@ def snr_summary(
|
|
104
104
|
|
105
105
|
# SNR format: PPV, TPR, F1, FPR, ACC, AP, AUC
|
106
106
|
col_n = ["PPV", "TPR", "F1", "FPR", "ACC", "AP", "AUC"]
|
107
|
-
snr_macrodf = pd.DataFrame(macro_avg, index=list(snr_mixids.keys()), columns=col_n)
|
107
|
+
snr_macrodf = pd.DataFrame(macro_avg, index=list(snr_mixids.keys()), columns=col_n) # pyright: ignore [reportArgumentType]
|
108
108
|
snr_macrodf.sort_index(ascending=False, inplace=True)
|
109
109
|
|
110
|
-
snr_microdf = pd.DataFrame(micro_avg, index=list(snr_mixids.keys()), columns=col_n)
|
110
|
+
snr_microdf = pd.DataFrame(micro_avg, index=list(snr_mixids.keys()), columns=col_n) # pyright: ignore [reportArgumentType]
|
111
111
|
snr_microdf.sort_index(ascending=False, inplace=True)
|
112
112
|
|
113
|
-
snr_wghtdf = pd.DataFrame(wghtd_avg, index=list(snr_mixids.keys()), columns=col_n)
|
113
|
+
snr_wghtdf = pd.DataFrame(wghtd_avg, index=list(snr_mixids.keys()), columns=col_n) # pyright: ignore [reportArgumentType]
|
114
114
|
snr_wghtdf.sort_index(ascending=False, inplace=True)
|
115
115
|
|
116
116
|
# Add segmental SNR columns if provided
|
117
117
|
if segsnr is not None:
|
118
118
|
ssnrdf = pd.DataFrame(
|
119
119
|
ssnr_stats,
|
120
|
-
index=list(snr_mixids.keys()),
|
121
|
-
columns=["SSNRavg", "SSNR80p", "SSNRmax"],
|
120
|
+
index=list(snr_mixids.keys()), # pyright: ignore [reportArgumentType]
|
121
|
+
columns=["SSNRavg", "SSNR80p", "SSNRmax"], # pyright: ignore [reportArgumentType]
|
122
122
|
)
|
123
123
|
ssnrdf.sort_index(ascending=False, inplace=True)
|
124
124
|
snr_macrodf = pd.concat([snr_macrodf, ssnrdf], axis=1)
|
sonusai/mixture/__init__.py
CHANGED
@@ -46,19 +46,15 @@ from .constants import SAMPLE_RATE
|
|
46
46
|
from .constants import VALID_AUGMENTATIONS
|
47
47
|
from .constants import VALID_CONFIGS
|
48
48
|
from .constants import VALID_NOISE_MIX_MODES
|
49
|
+
from .data_io import clear_cached_data
|
49
50
|
from .data_io import read_cached_data
|
50
51
|
from .data_io import write_cached_data
|
51
52
|
from .datatypes import AudioF
|
52
|
-
from .datatypes import AudiosF
|
53
|
-
from .datatypes import AudiosT
|
54
53
|
from .datatypes import AudioStatsMetrics
|
55
54
|
from .datatypes import AudioT
|
56
55
|
from .datatypes import Augmentation
|
57
56
|
from .datatypes import AugmentationRule
|
58
|
-
from .datatypes import AugmentationRules
|
59
|
-
from .datatypes import Augmentations
|
60
57
|
from .datatypes import AugmentedTarget
|
61
|
-
from .datatypes import AugmentedTargets
|
62
58
|
from .datatypes import ClassCount
|
63
59
|
from .datatypes import EnergyF
|
64
60
|
from .datatypes import EnergyT
|
@@ -70,34 +66,27 @@ from .datatypes import GenFTData
|
|
70
66
|
from .datatypes import GenMixData
|
71
67
|
from .datatypes import ImpulseResponseData
|
72
68
|
from .datatypes import ImpulseResponseFile
|
73
|
-
from .datatypes import ImpulseResponseFiles
|
74
|
-
from .datatypes import ListAudiosT
|
75
69
|
from .datatypes import MetricDoc
|
76
70
|
from .datatypes import MetricDocs
|
77
71
|
from .datatypes import Mixture
|
78
72
|
from .datatypes import MixtureDatabaseConfig
|
79
|
-
from .datatypes import Mixtures
|
80
73
|
from .datatypes import NoiseFile
|
81
|
-
from .datatypes import NoiseFiles
|
82
74
|
from .datatypes import Predict
|
83
75
|
from .datatypes import Segsnr
|
84
76
|
from .datatypes import SnrFMetrics
|
85
77
|
from .datatypes import SpectralMask
|
86
|
-
from .datatypes import SpectralMasks
|
87
78
|
from .datatypes import SpeechMetadata
|
88
79
|
from .datatypes import SpeechMetrics
|
89
80
|
from .datatypes import TargetFile
|
90
|
-
from .datatypes import TargetFiles
|
91
81
|
from .datatypes import TransformConfig
|
92
82
|
from .datatypes import Truth
|
93
83
|
from .datatypes import TruthConfig
|
94
84
|
from .datatypes import TruthConfigs
|
85
|
+
from .datatypes import TruthDict
|
95
86
|
from .datatypes import TruthParameter
|
96
|
-
from .datatypes import TruthParameters
|
97
87
|
from .datatypes import UniversalSNR
|
98
88
|
from .feature import get_audio_from_feature
|
99
89
|
from .feature import get_feature_from_audio
|
100
|
-
from .generation import generate_mixtures
|
101
90
|
from .generation import get_all_snrs_from_config
|
102
91
|
from .generation import initialize_db
|
103
92
|
from .generation import populate_class_label_table
|
@@ -110,17 +99,14 @@ from .generation import populate_target_file_table
|
|
110
99
|
from .generation import populate_top_table
|
111
100
|
from .generation import populate_truth_parameters_table
|
112
101
|
from .generation import update_mixid_width
|
113
|
-
from .generation import
|
102
|
+
from .generation import update_mixture_table
|
114
103
|
from .helpers import augmented_noise_samples
|
115
104
|
from .helpers import augmented_target_samples
|
116
105
|
from .helpers import check_audio_files_exist
|
117
106
|
from .helpers import forward_transform
|
118
107
|
from .helpers import frames_from_samples
|
119
108
|
from .helpers import get_audio_from_transform
|
120
|
-
from .helpers import get_ft
|
121
|
-
from .helpers import get_segsnr
|
122
109
|
from .helpers import get_transform_from_audio
|
123
|
-
from .helpers import get_truth
|
124
110
|
from .helpers import inverse_transform
|
125
111
|
from .helpers import mixture_metadata
|
126
112
|
from .helpers import write_mixture_metadata
|
sonusai/mixture/augmentation.py
CHANGED
@@ -1,12 +1,11 @@
|
|
1
1
|
from sonusai.mixture.datatypes import AudioT
|
2
2
|
from sonusai.mixture.datatypes import Augmentation
|
3
3
|
from sonusai.mixture.datatypes import AugmentationRule
|
4
|
-
from sonusai.mixture.datatypes import AugmentationRules
|
5
4
|
from sonusai.mixture.datatypes import ImpulseResponseData
|
6
5
|
from sonusai.mixture.datatypes import OptionalNumberStr
|
7
6
|
|
8
7
|
|
9
|
-
def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) ->
|
8
|
+
def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> list[AugmentationRule]:
|
10
9
|
"""Generate augmentation rules from list of input rules
|
11
10
|
|
12
11
|
:param rules: Dictionary of augmentation config rule[s]
|
@@ -25,7 +24,7 @@ def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> Augment
|
|
25
24
|
rule = _parse_ir(rule, num_ir)
|
26
25
|
processed_rules = _expand_rules(expanded_rules=processed_rules, rule=rule)
|
27
26
|
|
28
|
-
return [dataclass_from_dict(AugmentationRule, processed_rule) for processed_rule in processed_rules]
|
27
|
+
return [dataclass_from_dict(AugmentationRule, processed_rule) for processed_rule in processed_rules] # pyright: ignore [reportReturnType]
|
29
28
|
|
30
29
|
|
31
30
|
def _expand_rules(expanded_rules: list[dict], rule: dict) -> list[dict]:
|
@@ -163,7 +162,7 @@ def estimate_augmented_length_from_length(length: int, tempo: OptionalNumberStr
|
|
163
162
|
return length
|
164
163
|
|
165
164
|
|
166
|
-
def get_mixups(augmentations:
|
165
|
+
def get_mixups(augmentations: list[AugmentationRule]) -> list[int]:
|
167
166
|
"""Get a list of mixup values used
|
168
167
|
|
169
168
|
:param augmentations: List of augmentations
|
@@ -172,7 +171,7 @@ def get_mixups(augmentations: AugmentationRules) -> list[int]:
|
|
172
171
|
return sorted({augmentation.mixup for augmentation in augmentations})
|
173
172
|
|
174
173
|
|
175
|
-
def get_augmentation_indices_for_mixup(augmentations:
|
174
|
+
def get_augmentation_indices_for_mixup(augmentations: list[AugmentationRule], mixup: int) -> list[int]:
|
176
175
|
"""Get a list of augmentation indices for a given mixup value
|
177
176
|
|
178
177
|
:param augmentations: List of augmentations
|
@@ -327,4 +326,4 @@ def augmentation_from_rule(rule: AugmentationRule, num_ir: int) -> Augmentation:
|
|
327
326
|
if _rule_has_rand(processed_rule):
|
328
327
|
processed_rule = _generate_random_rule(processed_rule, num_ir)
|
329
328
|
|
330
|
-
return dataclass_from_dict(Augmentation, processed_rule)
|
329
|
+
return dataclass_from_dict(Augmentation, processed_rule) # pyright: ignore [reportReturnType]
|
sonusai/mixture/class_count.py
CHANGED
@@ -3,7 +3,7 @@ from sonusai.mixture.datatypes import GeneralizedIDs
|
|
3
3
|
from sonusai.mixture.mixdb import MixtureDatabase
|
4
4
|
|
5
5
|
|
6
|
-
def get_class_count_from_mixids(mixdb: MixtureDatabase, mixids: GeneralizedIDs
|
6
|
+
def get_class_count_from_mixids(mixdb: MixtureDatabase, mixids: GeneralizedIDs = "*") -> ClassCount:
|
7
7
|
"""Sums the class counts for given mixids"""
|
8
8
|
total_class_count = [0] * mixdb.num_classes
|
9
9
|
m_ids = mixdb.mixids_to_list(mixids)
|