sonusai 0.19.6__py3-none-any.whl → 0.19.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +1 -1
- sonusai/aawscd_probwrite.py +1 -1
- sonusai/calc_metric_spenh.py +1 -1
- sonusai/genft.py +29 -14
- sonusai/genmetrics.py +60 -42
- sonusai/genmix.py +41 -29
- sonusai/genmixdb.py +56 -64
- sonusai/metrics/calc_class_weights.py +1 -3
- sonusai/metrics/calc_optimal_thresholds.py +2 -2
- sonusai/metrics/calc_phase_distance.py +1 -1
- sonusai/metrics/calc_speech.py +6 -6
- sonusai/metrics/class_summary.py +6 -15
- sonusai/metrics/confusion_matrix_summary.py +11 -27
- sonusai/metrics/one_hot.py +3 -3
- sonusai/metrics/snr_summary.py +7 -7
- sonusai/mixture/__init__.py +2 -17
- sonusai/mixture/augmentation.py +5 -6
- sonusai/mixture/class_count.py +1 -1
- sonusai/mixture/config.py +36 -46
- sonusai/mixture/data_io.py +30 -1
- sonusai/mixture/datatypes.py +29 -40
- sonusai/mixture/db_datatypes.py +1 -1
- sonusai/mixture/feature.py +3 -23
- sonusai/mixture/generation.py +161 -204
- sonusai/mixture/helpers.py +29 -187
- sonusai/mixture/mixdb.py +386 -159
- sonusai/mixture/soundfile_audio.py +1 -1
- sonusai/mixture/sox_audio.py +4 -4
- sonusai/mixture/sox_augmentation.py +1 -1
- sonusai/mixture/target_class_balancing.py +9 -11
- sonusai/mixture/targets.py +23 -20
- sonusai/mixture/torchaudio_audio.py +18 -7
- sonusai/mixture/torchaudio_augmentation.py +3 -4
- sonusai/mixture/truth.py +21 -34
- sonusai/mixture/truth_functions/__init__.py +6 -0
- sonusai/mixture/truth_functions/crm.py +51 -37
- sonusai/mixture/truth_functions/energy.py +95 -50
- sonusai/mixture/truth_functions/file.py +12 -8
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +4 -5
- sonusai/mixture/truth_functions/sed.py +32 -23
- sonusai/mixture/truth_functions/target.py +62 -29
- sonusai/mkwav.py +20 -19
- sonusai/queries/queries.py +9 -15
- sonusai/speech/l2arctic.py +6 -2
- sonusai/summarize_metric_spenh.py +1 -1
- sonusai/utils/__init__.py +1 -0
- sonusai/utils/asr_functions/aaware_whisper.py +1 -1
- sonusai/utils/audio_devices.py +27 -18
- sonusai/utils/docstring.py +6 -3
- sonusai/utils/energy_f.py +5 -3
- sonusai/utils/human_readable_size.py +6 -6
- sonusai/utils/load_object.py +15 -0
- sonusai/utils/onnx_utils.py +2 -2
- sonusai/utils/print_mixture_details.py +3 -3
- {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/METADATA +2 -2
- {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/RECORD +60 -58
- sonusai/mixture/truth_functions/datatypes.py +0 -37
- {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/WHEEL +0 -0
- {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/entry_points.txt +0 -0
sonusai/genmixdb.py
CHANGED
@@ -1,15 +1,16 @@
 """sonusai genmixdb
 
-usage: genmixdb [-
+usage: genmixdb [-hvmfsdjn] LOC
 
 options:
+    -h, --help
+    -v, --verbose   Be verbose.
+    -m, --mix       Save mixture data. [default: False].
+    -f, --ft        Save feature/truth_f data. [default: False].
+    -s, --segsnr    Save segsnr data. [default: False].
+    -d, --dryrun    Perform a dry run showing the processed config. [default: False].
+    -j, --json      Save JSON version of database. [default: False].
+    -n, --nopar     Do not run in parallel. [default: False].
 
 Create mixture database data for training and evaluation. Optionally, also create mixture audio and feature/truth data.
 
@@ -115,8 +116,6 @@ will find all .wav files in the specified directories and process them as target
 
 import signal
 
-from sonusai.mixture import Mixture
-
 
 def signal_handler(_sig, _frame):
     import sys
@@ -139,6 +138,7 @@ def genmixdb(
     show_progress: bool = False,
     test: bool = False,
     save_json: bool = False,
+    no_par: bool = False,
 ) -> None:
     from functools import partial
     from random import seed
@@ -151,7 +151,6 @@ def genmixdb(
     from sonusai.mixture import AugmentationRule
     from sonusai.mixture import MixtureDatabase
     from sonusai.mixture import balance_targets
-    from sonusai.mixture import generate_mixtures
     from sonusai.mixture import get_all_snrs_from_config
     from sonusai.mixture import get_augmentation_rules
     from sonusai.mixture import get_augmented_targets
@@ -293,7 +292,7 @@ def genmixdb(
         augmented_targets=augmented_targets,
         targets=target_files,
         target_augmentations=target_augmentations,
-        class_balancing_augmentation=class_balancing_augmentation,
+        class_balancing_augmentation=class_balancing_augmentation,  # pyright: ignore [reportArgumentType]
         num_classes=mixdb.num_classes,
         num_ir=mixdb.num_impulse_response_files,
         mixups=mixups,
@@ -317,7 +316,8 @@ def genmixdb(
             f"{seconds_to_hms(seconds=noise_audio_duration)}"
         )
 
-    used_noise_files, used_noise_samples
+    used_noise_files, used_noise_samples = populate_mixture_table(
+        location=location,
         noise_mix_mode=mixdb.noise_mix_mode,
         augmented_targets=augmented_targets,
         target_files=target_files,
@@ -330,16 +330,17 @@ def genmixdb(
         num_classes=mixdb.num_classes,
         feature_step_samples=mixdb.feature_step_samples,
         num_ir=mixdb.num_impulse_response_files,
+        test=test,
     )
 
-    num_mixtures = len(mixtures)
+    num_mixtures = len(mixdb.mixtures)
     update_mixid_width(location, num_mixtures, test)
 
     if logging:
         logger.info("")
         logger.info(f"Found {num_mixtures:,} mixtures to process")
 
-    total_duration = float(sum([mixture.samples for mixture in mixtures])) / SAMPLE_RATE
+    total_duration = float(sum([mixture.samples for mixture in mixdb.mixtures])) / SAMPLE_RATE
 
     if logging:
         log_duration_and_sizes(
@@ -353,7 +354,7 @@ def genmixdb(
         logger.info(
             f"Feature shape: "
             f"{mixdb.fg_stride} x {mixdb.feature_parameters} "
-            f"({mixdb.fg_stride * mixdb.feature_parameters} total
+            f"({mixdb.fg_stride * mixdb.feature_parameters} total parameters)"
         )
         logger.info(f"Feature samples: {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)")
         logger.info(f"Feature step samples: {mixdb.feature_step_samples} samples ({mixdb.feature_step_ms} ms)")
@@ -363,7 +364,7 @@ def genmixdb(
     if logging:
         logger.info("Generating mixtures")
     progress = track(total=num_mixtures, disable=not show_progress)
+    par_track(
         partial(
             _process_mixture,
             location=location,
@@ -372,13 +373,12 @@ def genmixdb(
             save_segsnr=save_segsnr,
             test=test,
         ),
+        range(num_mixtures),
         progress=progress,
+        no_par=no_par,
     )
     progress.close()
 
-    populate_mixture_table(location, mixtures, test)
-
     total_noise_files = len(noise_files)
 
     total_samples = mixdb.total_samples()
@@ -409,70 +409,60 @@ def genmixdb(
 
 
 def _process_mixture(
+    m_id: int,
     location: str,
     save_mix: bool,
     save_ft: bool,
     save_segsnr: bool,
     test: bool,
-) ->
-    from
+) -> None:
+    from functools import partial
 
     from sonusai.mixture import MixtureDatabase
-    from sonusai.mixture import
-    from sonusai.mixture import
-    from sonusai.mixture import get_truth
-    from sonusai.mixture import update_mixture
+    from sonusai.mixture import clear_cached_data
+    from sonusai.mixture import update_mixture_table
     from sonusai.mixture import write_cached_data
     from sonusai.mixture import write_mixture_metadata
 
-    with_data = save_mix or save_ft
+    with_data = save_mix or save_ft or save_segsnr
+
+    genmix_data = update_mixture_table(location, m_id, with_data, test)
+
     mixdb = MixtureDatabase(location, test)
+    mixture = mixdb.mixture(m_id)
 
+    write = partial(write_cached_data, location=location, name="mixture", index=mixture.name)
+    clear = partial(clear_cached_data, location=location, name="mixture", index=mixture.name)
 
     if with_data:
+        write(
+            items=[
+                ("targets", genmix_data.targets),
+                ("target", genmix_data.target),
+                ("noise", genmix_data.noise),
+                ("mixture", genmix_data.mixture),
+            ]
+        )
 
     if save_ft:
-                mixture_audio=genmix_data.mixture,
-            )
-        feature, truth_f = get_ft(
-            mixdb=mixdb,
-            mixture=mixture,
-            mixture_audio=genmix_data.mixture,
-            truth_t=truth_t,
+        clear(items=["feature", "truth_f"])
+        feature, truth_f = mixdb.mixture_ft(m_id)
+        write(
+            items=[
+                ("feature", feature),
+                ("truth_f", truth_f),
+            ]
         )
-        write_data.append(("feature", feature))
-        write_data.append(("truth_f", truth_f))
 
-            mixdb=mixdb,
-            mixture=mixture,
-            target_audio=genmix_data.target,
-            noise=genmix_data.noise,
-        )
-        write_data.append(("segsnr", segsnr))
+    if save_segsnr:
+        clear(items=["segsnr"])
+        segsnr = mixdb.mixture_segsnr(m_id)
+        write(items=[("segsnr", segsnr)])
 
+    if not save_mix:
+        clear(items=["targets", "target", "noise", "mixture"])
 
+    write_mixture_metadata(mixdb, m_id)
 
 
 def main() -> None:
@@ -505,6 +495,7 @@ def main() -> None:
     save_segsnr = args["--segsnr"]
     dryrun = args["--dryrun"]
     save_json = args["--json"]
+    no_par = args["--nopar"]
     location = args["LOC"]
 
     start_time = time.monotonic()
@@ -535,6 +526,7 @@ def main() -> None:
             save_segsnr=save_segsnr,
             show_progress=True,
             save_json=save_json,
+            no_par=no_par,
         )
     except Exception as e:
         logger.debug(e)
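Version 0.19.9 threads the new -n/--nopar flag through genmixdb() as a no_par keyword and routes per-mixture work through par_track. A minimal sketch of calling the updated function from Python, assuming the keyword names visible in the hunks above; the location keyword and the path value are assumptions for illustration, not confirmed by this diff:

from sonusai.genmixdb import genmixdb

genmixdb(
    location="./my_mixdb",  # hypothetical output location
    save_segsnr=True,       # corresponds to -s/--segsnr
    save_json=True,         # corresponds to -j/--json
    show_progress=True,
    no_par=True,            # new in 0.19.9: disable parallel mixture processing (-n/--nopar)
)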
sonusai/metrics/calc_class_weights.py
CHANGED
@@ -54,7 +54,7 @@ def calc_class_weights_from_truth(truth: Truth, other_weight: float | None = Non
 
 def calc_class_weights_from_mixdb(
     mixdb: MixtureDatabase,
-    mixids: GeneralizedIDs
+    mixids: GeneralizedIDs = "*",
     other_weight: float = 1,
     other_index: int = -1,
 ) -> tuple[np.ndarray, np.ndarray]:
@@ -77,8 +77,6 @@ def calc_class_weights_from_mixdb(
     from sonusai.mixture import get_class_count_from_mixids
 
     count = np.ceil(np.array(get_class_count_from_mixids(mixdb=mixdb, mixids=mixids)) / mixdb.feature_step_samples)
-    if mixdb.truth_mutex and other_weight is not None and other_weight > 0:
-        count[other_index] = count[other_index] / np.float32(other_weight)
     total_features = sum(count)
 
     weights = np.empty(mixdb.num_classes, dtype=np.float32)
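calc_class_weights_from_mixdb() now defaults mixids to "*", matching the same change to get_class_count_from_mixids() in sonusai/mixture/class_count.py further below, so callers can cover every mixture without passing the argument. A hedged usage sketch; the database path is made up and the single-argument MixtureDatabase constructor form is an assumption:

from sonusai.metrics.calc_class_weights import calc_class_weights_from_mixdb
from sonusai.mixture import MixtureDatabase

mixdb = MixtureDatabase("./my_mixdb")  # hypothetical location; test flag assumed to default
result = calc_class_weights_from_mixdb(mixdb)  # mixids defaults to "*"; returns tuple[np.ndarray, np.ndarray]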
sonusai/metrics/calc_optimal_thresholds.py
CHANGED
@@ -51,8 +51,8 @@ def calc_optimal_thresholds(
             AUC[nci] = np.NaN
             AP[nci] = np.NaN
         else:
-            AP[nci] = average_precision_score(truth_binary[:, nci], predict[:, nci], average=None)
-            AUC[nci] = roc_auc_score(truth_binary[:, nci], predict[:, nci], average=None)
+            AP[nci] = average_precision_score(truth_binary[:, nci], predict[:, nci], average=None)  # pyright: ignore [reportArgumentType]
+            AUC[nci] = roc_auc_score(truth_binary[:, nci], predict[:, nci], average=None)  # pyright: ignore [reportArgumentType]
 
         # Optimal threshold from PR curve, optimizes f-score
         precision, recall, thrpr = precision_recall_curve(truth_binary[:, nci], predict[:, nci])
sonusai/metrics/calc_phase_distance.py
CHANGED
@@ -26,7 +26,7 @@ def calc_phase_distance(
     # weighted mean over all (scalar)
     reference_mag = np.abs(reference)
     ref_weight = reference_mag / (np.sum(reference_mag) + eps)  # frames x bins
-    err = np.around(np.sum(ref_weight * rh_angle_diff), 3)
+    err = float(np.around(np.sum(ref_weight * rh_angle_diff), 3))
 
     # weighted mean over frames (value per bin)
     err_b = np.zeros(reference.shape[1])
sonusai/metrics/calc_speech.py
CHANGED
@@ -32,16 +32,16 @@ def calc_speech(hypothesis: np.ndarray, reference: np.ndarray, sample_rate: int
     llr_mean = np.mean(ll_rs[:llr_len])
 
     # Segmental SNR
+    _, segsnr_dist = _calc_snr(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
     seg_snr = np.mean(segsnr_dist)
 
     # PESQ
     _pesq = calc_pesq(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
 
     # Now compute the composite measures
-    csig = np.clip(3.093 - 1.029 * llr_mean + 0.603 * _pesq - 0.009 * wss_dist, 1, 5)
-    cbak = np.clip(1.634 + 0.478 * _pesq - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5)
-    covl = np.clip(1.594 + 0.805 * _pesq - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5)
+    csig = float(np.clip(3.093 - 1.029 * llr_mean + 0.603 * _pesq - 0.009 * wss_dist, 1, 5))
+    cbak = float(np.clip(1.634 + 0.478 * _pesq - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5))
+    covl = float(np.clip(1.594 + 0.805 * _pesq - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5))
 
     return SpeechMetrics(_pesq, csig, cbak, covl)
 
@@ -284,8 +284,8 @@ def _calc_log_likelihood_ratio_measure(
         hypothesis_frame = np.multiply(hypothesis_frame, window)
 
         # (2) Get the autocorrelation lags and LPC parameters used to compute the log likelihood ratio measure.
-        r_reference,
+        r_reference, _, a_reference = _lp_coefficients(reference_frame, p)
+        _, _, a_hypothesis = _lp_coefficients(hypothesis_frame, p)
 
         # (3) Compute the log likelihood ratio measure
         numerator = np.dot(np.matmul(a_hypothesis, toeplitz(r_reference)), a_hypothesis)
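The composite measures above are fixed linear combinations of PESQ, LLR, WSS, and segmental SNR clipped to the 1-5 range, now returned as plain floats. A self-contained check of that arithmetic using the coefficients from the hunk and made-up intermediate values:

import numpy as np

# Illustrative intermediate values only, not taken from any real measurement
_pesq, llr_mean, wss_dist, seg_snr = 2.5, 0.5, 40.0, 5.0

csig = float(np.clip(3.093 - 1.029 * llr_mean + 0.603 * _pesq - 0.009 * wss_dist, 1, 5))
cbak = float(np.clip(1.634 + 0.478 * _pesq - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5))
covl = float(np.clip(1.594 + 0.805 * _pesq - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5))

print(csig, cbak, covl)  # approximately 3.73, 2.86, 3.07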
sonusai/metrics/class_summary.py
CHANGED
@@ -38,7 +38,7 @@ def class_summary(
     # TODO: re-work for modern mixdb API
     y_truth_f, y_predict = get_mixids_data(mixdb, mixids, truth_f, predict)  # type: ignore[name-defined]
 
-    if
+    if num_classes > 1:
         if not isinstance(predict_thr, np.ndarray):
             if predict_thr == 0:
                 predict_thr = np.atleast_1d(0.5)
@@ -53,25 +53,16 @@ def class_summary(
     # [ACC, TPR, PPV, TNR, FPR, HITFA, F1, MCC, NT, PT, TP, FP, AP, AUC]
     table_idx = np.array([2, 1, 6, 4, 0, 12, 13, 9])
     col_n = ["PPV", "TPR", "F1", "FPR", "ACC", "AP", "AUC", "Support"]
-    if mixdb.
-            row_n = mixdb.class_labels
-            if len(mixdb.class_labels) == num_classes - 1:  # Other label does not exist, so add it
-                row_n.append("Other")
-        else:
-            row_n = [f"Class {i}" for i in range(1, num_classes)]
-            row_n.append("Other")
+    if len(mixdb.class_labels) == num_classes:
+        row_n = mixdb.class_labels
     else:
-            row_n = mixdb.class_labels
-        else:
-            row_n = [f"Class {i}" for i in range(1, num_classes + 1)]
+        row_n = [f"Class {i}" for i in range(1, num_classes + 1)]
 
-    df = pd.DataFrame(metrics[:, table_idx], columns=col_n, index=row_n)
+    df = pd.DataFrame(metrics[:, table_idx], columns=col_n, index=row_n)  # pyright: ignore [reportArgumentType]
 
     # [miPPV, miTPR, miF1, miFPR, miACC, miAP, miAUC, TPSUM]
     avg_row_n = ["Macro-avg", "Micro-avg", "Weighted-avg"]
-    dfavg = pd.DataFrame(metavg, columns=col_n, index=avg_row_n)
+    dfavg = pd.DataFrame(metavg, columns=col_n, index=avg_row_n)  # pyright: ignore [reportArgumentType]
 
     # dfblank = pd.DataFrame([''])
     # pd.concat([df, dfblank, dfblank, dfavg])
sonusai/metrics/confusion_matrix_summary.py
CHANGED
@@ -37,7 +37,7 @@ def confusion_matrix_summary(
     ytrue, ypred = get_mixids_data(mixdb=mixdb, mixids=mixids, truth_f=truth_f, predict=predict)  # type: ignore[name-defined]
 
     # Check predict_thr array or scalar and return final scalar predict_thr value
-    if
+    if num_classes > 1:
         if not isinstance(predict_thr, np.ndarray):
             if predict_thr == 0:
                 # multi-label predict_thr scalar 0 force to 0.5 default
@@ -61,31 +61,15 @@ def confusion_matrix_summary(
     else:
         class_names = [f"Class {i}" for i in range(1, num_classes + 1)]
 
-        # mux = pd.MultiIndex.from_product([['truth thr = {}'.format(truth_thr)], class_nums])
-        cmdf = pd.DataFrame(cm, index=row_n, columns=class_nums, dtype=np.int32)
-        cmndf = pd.DataFrame(cmn, index=row_n, columns=class_nums, dtype=np.float32)
-
-    else:
-        _, _, cm, cmn, _, _ = one_hot(ytrue[:, class_idx], ypred[:, class_idx], predict_thr, truth_thr, timesteps)
-        cname = class_names[class_idx]
-        row_n = ["TrueN", "TrueP"]
-        col_n = ["N-" + cname, "P-" + cname]
-        cmdf = pd.DataFrame(cm, index=row_n, columns=col_n, dtype=np.int32)
-        cmndf = pd.DataFrame(cmn, index=row_n, columns=col_n, dtype=np.float32)
-        # add thresholds in 3rd row
-        pdnote = pd.DataFrame(np.atleast_2d([predict_thr, truth_thr]), index=["p/t thr:"], columns=col_n)
-        cmdf = pd.concat([cmdf, pdnote])
-        cmndf = pd.concat([cmndf, pdnote])
+    _, _, cm, cmn, _, _ = one_hot(ytrue[:, class_idx], ypred[:, class_idx], predict_thr, truth_thr, timesteps)
+    cname = class_names[class_idx]
+    row_n = ["TrueN", "TrueP"]
+    col_n = ["N-" + cname, "P-" + cname]
+    cmdf = pd.DataFrame(cm, index=row_n, columns=col_n, dtype=np.int32)  # pyright: ignore [reportArgumentType]
+    cmndf = pd.DataFrame(cmn, index=row_n, columns=col_n, dtype=np.float32)  # pyright: ignore [reportArgumentType]
+    # add thresholds in 3rd row
+    pdnote = pd.DataFrame(np.atleast_2d([predict_thr, truth_thr]), index=["p/t thr:"], columns=col_n)  # pyright: ignore [reportArgumentType, reportCallIssue]
+    cmdf = pd.concat([cmdf, pdnote])
+    cmndf = pd.concat([cmndf, pdnote])
 
     return cmdf, cmndf
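The surviving branch always builds a 2x2 per-class confusion matrix and appends a "p/t thr:" row with pd.concat. A standalone sketch of that pandas pattern; the counts, class name, and thresholds below are made up for illustration:

import numpy as np
import pandas as pd

cm = np.array([[90, 10], [5, 95]])  # illustrative confusion counts
col_n = ["N-Class 1", "P-Class 1"]
cmdf = pd.DataFrame(cm, index=["TrueN", "TrueP"], columns=col_n, dtype=np.int32)
pdnote = pd.DataFrame(np.atleast_2d([0.5, 0.5]), index=["p/t thr:"], columns=col_n)
cmdf = pd.concat([cmdf, pdnote])  # thresholds appear as a third row
print(cmdf)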
sonusai/metrics/one_hot.py
CHANGED
@@ -185,11 +185,11 @@ def one_hot(
             AP = np.NaN
             # threshold_optpr[nci] = np.NaN
         else:
-            AP = average_precision_score(truthb[:, nci], predict[:, nci], average=None)
+            AP = average_precision_score(truthb[:, nci], predict[:, nci], average=None)  # pyright: ignore [reportArgumentType]
         if len(np.unique(truthb[:, nci])) < 2:  # if active classes not > 1 AUC must be NaN
             AUC = np.NaN  # i.e. all ones sklearn auc will fail
         else:
-            AUC = roc_auc_score(truthb[:, nci], predict[:, nci], average=None)
+            AUC = roc_auc_score(truthb[:, nci], predict[:, nci], average=None)  # pyright: ignore [reportArgumentType]
         # # Optimal threshold from PR curve, optimizes f-score
         # precision, recall, thresholds = precision_recall_curve(truthb[:, nci], predict[:, nci])
         # fscore = (2 * precision * recall) / (precision + recall)
@@ -263,7 +263,7 @@ def one_hot(
     ]  # specific format, last 3 are unique
 
     # weighted average TBD
-    wp, wr, wf1, _ = precision_recall_fscore_support(truthb, predb, average="weighted", zero_division=0)
+    wp, wr, wf1, _ = precision_recall_fscore_support(truthb, predb, average="weighted", zero_division=0)  # pyright: ignore [reportArgumentType]
     if np.sum(truthb):
         taidx = np.sum(truthb, axis=0) > 0
         wap = average_precision_score(truthb[:, taidx], predict[:, taidx], average="weighted")
sonusai/metrics/snr_summary.py
CHANGED
@@ -48,7 +48,7 @@ def snr_summary(
     snr_mixids = get_mixids_from_snr(mixdb=mixdb, mixids=mixid)
 
     # Check predict_thr array or scalar and return final scalar predict_thr value
-    if
+    if num_classes > 1:
         if not isinstance(predict_thr, np.ndarray):
             if predict_thr == 0:
                 # multi-label predict_thr scalar 0 force to 0.5 default
@@ -84,7 +84,7 @@ def snr_summary(
     for ii, snr in enumerate(snr_mixids):
         # TODO: re-work for modern mixdb API
         y_truth, y_predict = get_mixids_data(mixdb, snr_mixids[snr], truth_f, predict)  # type: ignore[name-defined]
-        _,
+        _, _, _, _, _, mavg = one_hot(y_truth, y_predict, predict_thr, truth_thr, timesteps)
 
         # mavg macro, micro, weighted: [PPV, TPR, F1, FPR, ACC, mAP, mAUC, TPSUM]
         macro_avg[ii, :] = mavg[0, 0:7]
@@ -104,21 +104,21 @@ def snr_summary(
 
     # SNR format: PPV, TPR, F1, FPR, ACC, AP, AUC
     col_n = ["PPV", "TPR", "F1", "FPR", "ACC", "AP", "AUC"]
-    snr_macrodf = pd.DataFrame(macro_avg, index=list(snr_mixids.keys()), columns=col_n)
+    snr_macrodf = pd.DataFrame(macro_avg, index=list(snr_mixids.keys()), columns=col_n)  # pyright: ignore [reportArgumentType]
     snr_macrodf.sort_index(ascending=False, inplace=True)
 
-    snr_microdf = pd.DataFrame(micro_avg, index=list(snr_mixids.keys()), columns=col_n)
+    snr_microdf = pd.DataFrame(micro_avg, index=list(snr_mixids.keys()), columns=col_n)  # pyright: ignore [reportArgumentType]
     snr_microdf.sort_index(ascending=False, inplace=True)
 
-    snr_wghtdf = pd.DataFrame(wghtd_avg, index=list(snr_mixids.keys()), columns=col_n)
+    snr_wghtdf = pd.DataFrame(wghtd_avg, index=list(snr_mixids.keys()), columns=col_n)  # pyright: ignore [reportArgumentType]
     snr_wghtdf.sort_index(ascending=False, inplace=True)
 
     # Add segmental SNR columns if provided
     if segsnr is not None:
         ssnrdf = pd.DataFrame(
             ssnr_stats,
-            index=list(snr_mixids.keys()),
-            columns=["SSNRavg", "SSNR80p", "SSNRmax"],
+            index=list(snr_mixids.keys()),  # pyright: ignore [reportArgumentType]
+            columns=["SSNRavg", "SSNR80p", "SSNRmax"],  # pyright: ignore [reportArgumentType]
         )
         ssnrdf.sort_index(ascending=False, inplace=True)
         snr_macrodf = pd.concat([snr_macrodf, ssnrdf], axis=1)
sonusai/mixture/__init__.py
CHANGED
@@ -46,19 +46,15 @@ from .constants import SAMPLE_RATE
 from .constants import VALID_AUGMENTATIONS
 from .constants import VALID_CONFIGS
 from .constants import VALID_NOISE_MIX_MODES
+from .data_io import clear_cached_data
 from .data_io import read_cached_data
 from .data_io import write_cached_data
 from .datatypes import AudioF
-from .datatypes import AudiosF
-from .datatypes import AudiosT
 from .datatypes import AudioStatsMetrics
 from .datatypes import AudioT
 from .datatypes import Augmentation
 from .datatypes import AugmentationRule
-from .datatypes import AugmentationRules
-from .datatypes import Augmentations
 from .datatypes import AugmentedTarget
-from .datatypes import AugmentedTargets
 from .datatypes import ClassCount
 from .datatypes import EnergyF
 from .datatypes import EnergyT
@@ -70,35 +66,27 @@ from .datatypes import GenFTData
 from .datatypes import GenMixData
 from .datatypes import ImpulseResponseData
 from .datatypes import ImpulseResponseFile
-from .datatypes import ImpulseResponseFiles
-from .datatypes import ListAudiosT
 from .datatypes import MetricDoc
 from .datatypes import MetricDocs
 from .datatypes import Mixture
 from .datatypes import MixtureDatabaseConfig
-from .datatypes import Mixtures
 from .datatypes import NoiseFile
-from .datatypes import NoiseFiles
 from .datatypes import Predict
 from .datatypes import Segsnr
 from .datatypes import SnrFMetrics
 from .datatypes import SpectralMask
-from .datatypes import SpectralMasks
 from .datatypes import SpeechMetadata
 from .datatypes import SpeechMetrics
 from .datatypes import TargetFile
-from .datatypes import TargetFiles
 from .datatypes import TransformConfig
 from .datatypes import Truth
 from .datatypes import TruthConfig
 from .datatypes import TruthConfigs
 from .datatypes import TruthDict
 from .datatypes import TruthParameter
-from .datatypes import TruthParameters
 from .datatypes import UniversalSNR
 from .feature import get_audio_from_feature
 from .feature import get_feature_from_audio
-from .generation import generate_mixtures
 from .generation import get_all_snrs_from_config
 from .generation import initialize_db
 from .generation import populate_class_label_table
@@ -111,17 +99,14 @@ from .generation import populate_target_file_table
 from .generation import populate_top_table
 from .generation import populate_truth_parameters_table
 from .generation import update_mixid_width
-from .generation import
+from .generation import update_mixture_table
 from .helpers import augmented_noise_samples
 from .helpers import augmented_target_samples
 from .helpers import check_audio_files_exist
 from .helpers import forward_transform
 from .helpers import frames_from_samples
 from .helpers import get_audio_from_transform
-from .helpers import get_ft
-from .helpers import get_segsnr
 from .helpers import get_transform_from_audio
-from .helpers import get_truth
 from .helpers import inverse_transform
 from .helpers import mixture_metadata
 from .helpers import write_mixture_metadata
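The plural type aliases (AudiosT, Augmentations, Mixtures, TargetFiles, TruthParameters, and so on) are no longer exported; the 0.19.9 signatures elsewhere in this diff annotate with plain list[...] generics instead. A hypothetical downstream update; the Mixture.samples attribute follows the usage shown in the genmixdb.py hunks, everything else is illustrative:

from sonusai.mixture import Mixture

# 0.19.6 style: from sonusai.mixture import Mixtures
# 0.19.9 style: annotate with list[Mixture] directly
def total_mixture_samples(mixtures: list[Mixture]) -> int:
    return sum(mixture.samples for mixture in mixtures)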
sonusai/mixture/augmentation.py
CHANGED
@@ -1,12 +1,11 @@
 from sonusai.mixture.datatypes import AudioT
 from sonusai.mixture.datatypes import Augmentation
 from sonusai.mixture.datatypes import AugmentationRule
-from sonusai.mixture.datatypes import AugmentationRules
 from sonusai.mixture.datatypes import ImpulseResponseData
 from sonusai.mixture.datatypes import OptionalNumberStr
 
 
-def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) ->
+def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> list[AugmentationRule]:
     """Generate augmentation rules from list of input rules
 
     :param rules: Dictionary of augmentation config rule[s]
@@ -25,7 +24,7 @@ def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> Augment
         rule = _parse_ir(rule, num_ir)
         processed_rules = _expand_rules(expanded_rules=processed_rules, rule=rule)
 
-    return [dataclass_from_dict(AugmentationRule, processed_rule) for processed_rule in processed_rules]
+    return [dataclass_from_dict(AugmentationRule, processed_rule) for processed_rule in processed_rules]  # pyright: ignore [reportReturnType]
 
 
 def _expand_rules(expanded_rules: list[dict], rule: dict) -> list[dict]:
@@ -163,7 +162,7 @@ def estimate_augmented_length_from_length(length: int, tempo: OptionalNumberStr
     return length
 
 
-def get_mixups(augmentations:
+def get_mixups(augmentations: list[AugmentationRule]) -> list[int]:
     """Get a list of mixup values used
 
     :param augmentations: List of augmentations
@@ -172,7 +171,7 @@ def get_mixups(augmentations: AugmentationRules) -> list[int]:
     return sorted({augmentation.mixup for augmentation in augmentations})
 
 
-def get_augmentation_indices_for_mixup(augmentations:
+def get_augmentation_indices_for_mixup(augmentations: list[AugmentationRule], mixup: int) -> list[int]:
     """Get a list of augmentation indices for a given mixup value
 
     :param augmentations: List of augmentations
@@ -327,4 +326,4 @@ def augmentation_from_rule(rule: AugmentationRule, num_ir: int) -> Augmentation:
     if _rule_has_rand(processed_rule):
         processed_rule = _generate_random_rule(processed_rule, num_ir)
 
-    return dataclass_from_dict(Augmentation, processed_rule)
+    return dataclass_from_dict(Augmentation, processed_rule)  # pyright: ignore [reportReturnType]
sonusai/mixture/class_count.py
CHANGED
@@ -3,7 +3,7 @@ from sonusai.mixture.datatypes import GeneralizedIDs
 from sonusai.mixture.mixdb import MixtureDatabase
 
 
-def get_class_count_from_mixids(mixdb: MixtureDatabase, mixids: GeneralizedIDs
+def get_class_count_from_mixids(mixdb: MixtureDatabase, mixids: GeneralizedIDs = "*") -> ClassCount:
     """Sums the class counts for given mixids"""
     total_class_count = [0] * mixdb.num_classes
     m_ids = mixdb.mixids_to_list(mixids)