sonusai 0.19.6__py3-none-any.whl → 0.19.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sonusai/__init__.py +1 -1
  2. sonusai/aawscd_probwrite.py +1 -1
  3. sonusai/calc_metric_spenh.py +1 -1
  4. sonusai/genft.py +29 -14
  5. sonusai/genmetrics.py +60 -42
  6. sonusai/genmix.py +41 -29
  7. sonusai/genmixdb.py +56 -64
  8. sonusai/metrics/calc_class_weights.py +1 -3
  9. sonusai/metrics/calc_optimal_thresholds.py +2 -2
  10. sonusai/metrics/calc_phase_distance.py +1 -1
  11. sonusai/metrics/calc_speech.py +6 -6
  12. sonusai/metrics/class_summary.py +6 -15
  13. sonusai/metrics/confusion_matrix_summary.py +11 -27
  14. sonusai/metrics/one_hot.py +3 -3
  15. sonusai/metrics/snr_summary.py +7 -7
  16. sonusai/mixture/__init__.py +2 -17
  17. sonusai/mixture/augmentation.py +5 -6
  18. sonusai/mixture/class_count.py +1 -1
  19. sonusai/mixture/config.py +36 -46
  20. sonusai/mixture/data_io.py +30 -1
  21. sonusai/mixture/datatypes.py +29 -40
  22. sonusai/mixture/db_datatypes.py +1 -1
  23. sonusai/mixture/feature.py +3 -23
  24. sonusai/mixture/generation.py +161 -204
  25. sonusai/mixture/helpers.py +29 -187
  26. sonusai/mixture/mixdb.py +386 -159
  27. sonusai/mixture/soundfile_audio.py +1 -1
  28. sonusai/mixture/sox_audio.py +4 -4
  29. sonusai/mixture/sox_augmentation.py +1 -1
  30. sonusai/mixture/target_class_balancing.py +9 -11
  31. sonusai/mixture/targets.py +23 -20
  32. sonusai/mixture/torchaudio_audio.py +18 -7
  33. sonusai/mixture/torchaudio_augmentation.py +3 -4
  34. sonusai/mixture/truth.py +21 -34
  35. sonusai/mixture/truth_functions/__init__.py +6 -0
  36. sonusai/mixture/truth_functions/crm.py +51 -37
  37. sonusai/mixture/truth_functions/energy.py +95 -50
  38. sonusai/mixture/truth_functions/file.py +12 -8
  39. sonusai/mixture/truth_functions/metadata.py +24 -0
  40. sonusai/mixture/truth_functions/metrics.py +28 -0
  41. sonusai/mixture/truth_functions/phoneme.py +4 -5
  42. sonusai/mixture/truth_functions/sed.py +32 -23
  43. sonusai/mixture/truth_functions/target.py +62 -29
  44. sonusai/mkwav.py +20 -19
  45. sonusai/queries/queries.py +9 -15
  46. sonusai/speech/l2arctic.py +6 -2
  47. sonusai/summarize_metric_spenh.py +1 -1
  48. sonusai/utils/__init__.py +1 -0
  49. sonusai/utils/asr_functions/aaware_whisper.py +1 -1
  50. sonusai/utils/audio_devices.py +27 -18
  51. sonusai/utils/docstring.py +6 -3
  52. sonusai/utils/energy_f.py +5 -3
  53. sonusai/utils/human_readable_size.py +6 -6
  54. sonusai/utils/load_object.py +15 -0
  55. sonusai/utils/onnx_utils.py +2 -2
  56. sonusai/utils/print_mixture_details.py +3 -3
  57. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/METADATA +2 -2
  58. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/RECORD +60 -58
  59. sonusai/mixture/truth_functions/datatypes.py +0 -37
  60. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/WHEEL +0 -0
  61. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/entry_points.txt +0 -0
sonusai/genmixdb.py CHANGED
@@ -1,15 +1,16 @@
  """sonusai genmixdb
 
- usage: genmixdb [-hvmfsdj] LOC
+ usage: genmixdb [-hvmfsdjn] LOC
 
  options:
- -h, --help
- -v, --verbose Be verbose.
- -m, --mix Save mixture data. [default: False].
- -f, --ft Save feature/truth_f data. [default: False].
- -s, --segsnr Save segsnr data. [default: False].
- -d, --dryrun Perform a dry run showing the processed config. [default: False].
- -j, --json Save JSON version of database. [default: False].
+ -h, --help
+ -v, --verbose Be verbose.
+ -m, --mix Save mixture data. [default: False].
+ -f, --ft Save feature/truth_f data. [default: False].
+ -s, --segsnr Save segsnr data. [default: False].
+ -d, --dryrun Perform a dry run showing the processed config. [default: False].
+ -j, --json Save JSON version of database. [default: False].
+ -n, --nopar Do not run in parallel. [default: False].
 
  Create mixture database data for training and evaluation. Optionally, also create mixture audio and feature/truth data.
 
@@ -115,8 +116,6 @@ will find all .wav files in the specified directories and process them as target
 
  import signal
 
- from sonusai.mixture import Mixture
-
 
  def signal_handler(_sig, _frame):
  import sys
@@ -139,6 +138,7 @@ def genmixdb(
  show_progress: bool = False,
  test: bool = False,
  save_json: bool = False,
+ no_par: bool = False,
  ) -> None:
  from functools import partial
  from random import seed
@@ -151,7 +151,6 @@ def genmixdb(
  from sonusai.mixture import AugmentationRule
  from sonusai.mixture import MixtureDatabase
  from sonusai.mixture import balance_targets
- from sonusai.mixture import generate_mixtures
  from sonusai.mixture import get_all_snrs_from_config
  from sonusai.mixture import get_augmentation_rules
  from sonusai.mixture import get_augmented_targets
@@ -293,7 +292,7 @@ def genmixdb(
  augmented_targets=augmented_targets,
  targets=target_files,
  target_augmentations=target_augmentations,
- class_balancing_augmentation=class_balancing_augmentation,
+ class_balancing_augmentation=class_balancing_augmentation, # pyright: ignore [reportArgumentType]
  num_classes=mixdb.num_classes,
  num_ir=mixdb.num_impulse_response_files,
  mixups=mixups,
@@ -317,7 +316,8 @@ def genmixdb(
  f"{seconds_to_hms(seconds=noise_audio_duration)}"
  )
 
- used_noise_files, used_noise_samples, mixtures = generate_mixtures(
+ used_noise_files, used_noise_samples = populate_mixture_table(
+ location=location,
  noise_mix_mode=mixdb.noise_mix_mode,
  augmented_targets=augmented_targets,
  target_files=target_files,
@@ -330,16 +330,17 @@ def genmixdb(
  num_classes=mixdb.num_classes,
  feature_step_samples=mixdb.feature_step_samples,
  num_ir=mixdb.num_impulse_response_files,
+ test=test,
  )
 
- num_mixtures = len(mixtures)
+ num_mixtures = len(mixdb.mixtures)
  update_mixid_width(location, num_mixtures, test)
 
  if logging:
  logger.info("")
  logger.info(f"Found {num_mixtures:,} mixtures to process")
 
- total_duration = float(sum([mixture.samples for mixture in mixtures])) / SAMPLE_RATE
+ total_duration = float(sum([mixture.samples for mixture in mixdb.mixtures])) / SAMPLE_RATE
 
  if logging:
  log_duration_and_sizes(
@@ -353,7 +354,7 @@ def genmixdb(
  logger.info(
  f"Feature shape: "
  f"{mixdb.fg_stride} x {mixdb.feature_parameters} "
- f"({mixdb.fg_stride * mixdb.feature_parameters} total params)"
+ f"({mixdb.fg_stride * mixdb.feature_parameters} total parameters)"
  )
  logger.info(f"Feature samples: {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)")
  logger.info(f"Feature step samples: {mixdb.feature_step_samples} samples ({mixdb.feature_step_ms} ms)")
@@ -363,7 +364,7 @@ def genmixdb(
  if logging:
  logger.info("Generating mixtures")
  progress = track(total=num_mixtures, disable=not show_progress)
- mixtures = par_track(
+ par_track(
  partial(
  _process_mixture,
  location=location,
@@ -372,13 +373,12 @@ def genmixdb(
  save_segsnr=save_segsnr,
  test=test,
  ),
- mixtures,
+ range(num_mixtures),
  progress=progress,
+ no_par=no_par,
  )
  progress.close()
 
- populate_mixture_table(location, mixtures, test)
-
  total_noise_files = len(noise_files)
 
  total_samples = mixdb.total_samples()
@@ -409,70 +409,60 @@ def genmixdb(
 
 
  def _process_mixture(
- mixture: Mixture,
+ m_id: int,
  location: str,
  save_mix: bool,
  save_ft: bool,
  save_segsnr: bool,
  test: bool,
- ) -> Mixture:
- from typing import Any
+ ) -> None:
+ from functools import partial
 
  from sonusai.mixture import MixtureDatabase
- from sonusai.mixture import get_ft
- from sonusai.mixture import get_segsnr
- from sonusai.mixture import get_truth
- from sonusai.mixture import update_mixture
+ from sonusai.mixture import clear_cached_data
+ from sonusai.mixture import update_mixture_table
  from sonusai.mixture import write_cached_data
  from sonusai.mixture import write_mixture_metadata
 
- with_data = save_mix or save_ft
+ with_data = save_mix or save_ft or save_segsnr
+
+ genmix_data = update_mixture_table(location, m_id, with_data, test)
+
  mixdb = MixtureDatabase(location, test)
+ mixture = mixdb.mixture(m_id)
 
- mixture, genmix_data = update_mixture(mixdb, mixture, with_data)
+ write = partial(write_cached_data, location=location, name="mixture", index=mixture.name)
+ clear = partial(clear_cached_data, location=location, name="mixture", index=mixture.name)
 
  if with_data:
- write_data: list[tuple[str, Any]] = []
-
- if save_mix:
- write_data.append(("targets", genmix_data.targets))
- write_data.append(("noise", genmix_data.noise))
- write_data.append(("mixture", genmix_data.mixture))
+ write(
+ items=[
+ ("targets", genmix_data.targets),
+ ("target", genmix_data.target),
+ ("noise", genmix_data.noise),
+ ("mixture", genmix_data.mixture),
+ ]
+ )
 
  if save_ft:
- if genmix_data.targets is None or genmix_data.noise is None or genmix_data.mixture is None:
- raise RuntimeError("Mixture data was not generated properly")
- truth_t = get_truth(
- mixdb=mixdb,
- mixture=mixture,
- targets_audio=genmix_data.targets,
- noise_audio=genmix_data.noise,
- mixture_audio=genmix_data.mixture,
- )
- feature, truth_f = get_ft(
- mixdb=mixdb,
- mixture=mixture,
- mixture_audio=genmix_data.mixture,
- truth_t=truth_t,
+ clear(items=["feature", "truth_f"])
+ feature, truth_f = mixdb.mixture_ft(m_id)
+ write(
+ items=[
+ ("feature", feature),
+ ("truth_f", truth_f),
+ ]
  )
- write_data.append(("feature", feature))
- write_data.append(("truth_f", truth_f))
 
- if save_segsnr:
- if genmix_data.target is None:
- raise RuntimeError("Target data was not generated properly")
- segsnr = get_segsnr(
- mixdb=mixdb,
- mixture=mixture,
- target_audio=genmix_data.target,
- noise=genmix_data.noise,
- )
- write_data.append(("segsnr", segsnr))
+ if save_segsnr:
+ clear(items=["segsnr"])
+ segsnr = mixdb.mixture_segsnr(m_id)
+ write(items=[("segsnr", segsnr)])
 
- write_cached_data(mixdb.location, "mixture", mixture.name, write_data)
- write_mixture_metadata(mixdb, mixture)
+ if not save_mix:
+ clear(items=["targets", "target", "noise", "mixture"])
 
- return mixture
+ write_mixture_metadata(mixdb, m_id)
 
 
  def main() -> None:
@@ -505,6 +495,7 @@ def main() -> None:
  save_segsnr = args["--segsnr"]
  dryrun = args["--dryrun"]
  save_json = args["--json"]
+ no_par = args["--nopar"]
  location = args["LOC"]
 
  start_time = time.monotonic()
@@ -535,6 +526,7 @@ def main() -> None:
  save_segsnr=save_segsnr,
  show_progress=True,
  save_json=save_json,
+ no_par=no_par,
  )
  except Exception as e:
  logger.debug(e)
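
Note: the new -n, --nopar option (and the matching no_par keyword) disables the parallel par_track loop over mixtures. Below is a minimal usage sketch based only on the docstring and signature changes above; the CLI invocation and the keyword arguments other than no_par, show_progress, and save_json (for example location, save_mix, save_ft) are assumptions, not taken verbatim from the package.

    # CLI (per the updated docstring usage line): generate the database serially
    #   genmixdb -n -m -f LOC

    # Programmatic call mirroring main(); keyword names other than no_par/show_progress/save_json are assumed
    from sonusai.genmixdb import genmixdb

    genmixdb(
        location="path/to/mixdb",  # hypothetical database location
        save_mix=True,             # assumed keyword, mirrors the --mix flag
        save_ft=True,              # assumed keyword, mirrors the --ft flag
        show_progress=True,
        no_par=True,               # new in 0.19.9: process mixtures serially instead of via par_track
    )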

sonusai/metrics/calc_class_weights.py CHANGED
@@ -54,7 +54,7 @@ def calc_class_weights_from_truth(truth: Truth, other_weight: float | None = Non
 
  def calc_class_weights_from_mixdb(
  mixdb: MixtureDatabase,
- mixids: GeneralizedIDs | None = None,
+ mixids: GeneralizedIDs = "*",
  other_weight: float = 1,
  other_index: int = -1,
  ) -> tuple[np.ndarray, np.ndarray]:
@@ -77,8 +77,6 @@ def calc_class_weights_from_mixdb(
  from sonusai.mixture import get_class_count_from_mixids
 
  count = np.ceil(np.array(get_class_count_from_mixids(mixdb=mixdb, mixids=mixids)) / mixdb.feature_step_samples)
- if mixdb.truth_mutex and other_weight is not None and other_weight > 0:
- count[other_index] = count[other_index] / np.float32(other_weight)
  total_features = sum(count)
 
  weights = np.empty(mixdb.num_classes, dtype=np.float32)
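
With mixids now defaulting to "*", callers can compute weights over every mixture without passing mixids explicitly. A brief sketch; the database location is hypothetical, the single-argument MixtureDatabase constructor assumes the test flag has a default, and the contents of the returned tuple are described only by the annotated return type.

    from sonusai.metrics.calc_class_weights import calc_class_weights_from_mixdb
    from sonusai.mixture import MixtureDatabase

    mixdb = MixtureDatabase("path/to/mixdb")  # hypothetical location

    # mixids defaults to "*" (all mixtures); pass a subset to restrict the count
    result = calc_class_weights_from_mixdb(mixdb)  # tuple of two np.ndarray values per the signature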

sonusai/metrics/calc_optimal_thresholds.py CHANGED
@@ -51,8 +51,8 @@ def calc_optimal_thresholds(
  AUC[nci] = np.NaN
  AP[nci] = np.NaN
  else:
- AP[nci] = average_precision_score(truth_binary[:, nci], predict[:, nci], average=None)
- AUC[nci] = roc_auc_score(truth_binary[:, nci], predict[:, nci], average=None)
+ AP[nci] = average_precision_score(truth_binary[:, nci], predict[:, nci], average=None) # pyright: ignore [reportArgumentType]
+ AUC[nci] = roc_auc_score(truth_binary[:, nci], predict[:, nci], average=None) # pyright: ignore [reportArgumentType]
 
  # Optimal threshold from PR curve, optimizes f-score
  precision, recall, thrpr = precision_recall_curve(truth_binary[:, nci], predict[:, nci])

sonusai/metrics/calc_phase_distance.py CHANGED
@@ -26,7 +26,7 @@ def calc_phase_distance(
  # weighted mean over all (scalar)
  reference_mag = np.abs(reference)
  ref_weight = reference_mag / (np.sum(reference_mag) + eps) # frames x bins
- err = np.around(np.sum(ref_weight * rh_angle_diff), 3)
+ err = float(np.around(np.sum(ref_weight * rh_angle_diff), 3))
 
  # weighted mean over frames (value per bin)
  err_b = np.zeros(reference.shape[1])

sonusai/metrics/calc_speech.py CHANGED
@@ -32,16 +32,16 @@ def calc_speech(hypothesis: np.ndarray, reference: np.ndarray, sample_rate: int
  llr_mean = np.mean(ll_rs[:llr_len])
 
  # Segmental SNR
- snr_dist, segsnr_dist = _calc_snr(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
+ _, segsnr_dist = _calc_snr(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
  seg_snr = np.mean(segsnr_dist)
 
  # PESQ
  _pesq = calc_pesq(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
 
  # Now compute the composite measures
- csig = np.clip(3.093 - 1.029 * llr_mean + 0.603 * _pesq - 0.009 * wss_dist, 1, 5)
- cbak = np.clip(1.634 + 0.478 * _pesq - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5)
- covl = np.clip(1.594 + 0.805 * _pesq - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5)
+ csig = float(np.clip(3.093 - 1.029 * llr_mean + 0.603 * _pesq - 0.009 * wss_dist, 1, 5))
+ cbak = float(np.clip(1.634 + 0.478 * _pesq - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5))
+ covl = float(np.clip(1.594 + 0.805 * _pesq - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5))
 
  return SpeechMetrics(_pesq, csig, cbak, covl)
 
@@ -284,8 +284,8 @@ def _calc_log_likelihood_ratio_measure(
  hypothesis_frame = np.multiply(hypothesis_frame, window)
 
  # (2) Get the autocorrelation lags and LPC parameters used to compute the log likelihood ratio measure.
- r_reference, ref_reference, a_reference = _lp_coefficients(reference_frame, p)
- r_hypothesis, ref_hypothesis, a_hypothesis = _lp_coefficients(hypothesis_frame, p)
+ r_reference, _, a_reference = _lp_coefficients(reference_frame, p)
+ _, _, a_hypothesis = _lp_coefficients(hypothesis_frame, p)
 
  # (3) Compute the log likelihood ratio measure
  numerator = np.dot(np.matmul(a_hypothesis, toeplitz(r_reference)), a_hypothesis)

sonusai/metrics/class_summary.py CHANGED
@@ -38,7 +38,7 @@ def class_summary(
  # TODO: re-work for modern mixdb API
  y_truth_f, y_predict = get_mixids_data(mixdb, mixids, truth_f, predict) # type: ignore[name-defined]
 
- if not mixdb.truth_mutex and num_classes > 1:
+ if num_classes > 1:
  if not isinstance(predict_thr, np.ndarray):
  if predict_thr == 0:
  predict_thr = np.atleast_1d(0.5)
@@ -53,25 +53,16 @@ def class_summary(
  # [ACC, TPR, PPV, TNR, FPR, HITFA, F1, MCC, NT, PT, TP, FP, AP, AUC]
  table_idx = np.array([2, 1, 6, 4, 0, 12, 13, 9])
  col_n = ["PPV", "TPR", "F1", "FPR", "ACC", "AP", "AUC", "Support"]
- if mixdb.truth_mutex:
- if len(mixdb.class_labels) >= num_classes - 1: # labels exist with or without Other
- row_n = mixdb.class_labels
- if len(mixdb.class_labels) == num_classes - 1: # Other label does not exist, so add it
- row_n.append("Other")
- else:
- row_n = [f"Class {i}" for i in range(1, num_classes)]
- row_n.append("Other")
+ if len(mixdb.class_labels) == num_classes:
+ row_n = mixdb.class_labels
  else:
- if len(mixdb.class_labels) == num_classes:
- row_n = mixdb.class_labels
- else:
- row_n = [f"Class {i}" for i in range(1, num_classes + 1)]
+ row_n = [f"Class {i}" for i in range(1, num_classes + 1)]
 
- df = pd.DataFrame(metrics[:, table_idx], columns=col_n, index=row_n)
+ df = pd.DataFrame(metrics[:, table_idx], columns=col_n, index=row_n) # pyright: ignore [reportArgumentType]
 
  # [miPPV, miTPR, miF1, miFPR, miACC, miAP, miAUC, TPSUM]
  avg_row_n = ["Macro-avg", "Micro-avg", "Weighted-avg"]
- dfavg = pd.DataFrame(metavg, columns=col_n, index=avg_row_n)
+ dfavg = pd.DataFrame(metavg, columns=col_n, index=avg_row_n) # pyright: ignore [reportArgumentType]
 
  # dfblank = pd.DataFrame([''])
  # pd.concat([df, dfblank, dfblank, dfavg])

sonusai/metrics/confusion_matrix_summary.py CHANGED
@@ -37,7 +37,7 @@ def confusion_matrix_summary(
  ytrue, ypred = get_mixids_data(mixdb=mixdb, mixids=mixids, truth_f=truth_f, predict=predict) # type: ignore[name-defined]
 
  # Check predict_thr array or scalar and return final scalar predict_thr value
- if not mixdb.truth_mutex and num_classes > 1:
+ if num_classes > 1:
  if not isinstance(predict_thr, np.ndarray):
  if predict_thr == 0:
  # multi-label predict_thr scalar 0 force to 0.5 default
@@ -61,31 +61,15 @@ def confusion_matrix_summary(
  else:
  class_names = [f"Class {i}" for i in range(1, num_classes + 1)]
 
- class_nums = [f"{i}" for i in range(1, num_classes + 1)]
-
- if mixdb.truth_mutex:
- # single-label mode force to argmax mode
- predict_thr = np.array(0, dtype=np.float32)
- _, _, cm, cmn, _, _ = one_hot(ytrue, ypred, predict_thr, truth_thr, timesteps)
- row_n = class_names
- row_n[-1] = "Other"
- # mux = pd.MultiIndex.from_product([['Single-label/mutex mode, truth thr = {}'.format(truth_thr)],
- # class_nums])
- # mux = pd.MultiIndex.from_product([['truth thr = {}'.format(truth_thr)], class_nums])
-
- cmdf = pd.DataFrame(cm, index=row_n, columns=class_nums, dtype=np.int32)
- cmndf = pd.DataFrame(cmn, index=row_n, columns=class_nums, dtype=np.float32)
-
- else:
- _, _, cm, cmn, _, _ = one_hot(ytrue[:, class_idx], ypred[:, class_idx], predict_thr, truth_thr, timesteps)
- cname = class_names[class_idx]
- row_n = ["TrueN", "TrueP"]
- col_n = ["N-" + cname, "P-" + cname]
- cmdf = pd.DataFrame(cm, index=row_n, columns=col_n, dtype=np.int32)
- cmndf = pd.DataFrame(cmn, index=row_n, columns=col_n, dtype=np.float32)
- # add thresholds in 3rd row
- pdnote = pd.DataFrame(np.atleast_2d([predict_thr, truth_thr]), index=["p/t thr:"], columns=col_n)
- cmdf = pd.concat([cmdf, pdnote])
- cmndf = pd.concat([cmndf, pdnote])
+ _, _, cm, cmn, _, _ = one_hot(ytrue[:, class_idx], ypred[:, class_idx], predict_thr, truth_thr, timesteps)
+ cname = class_names[class_idx]
+ row_n = ["TrueN", "TrueP"]
+ col_n = ["N-" + cname, "P-" + cname]
+ cmdf = pd.DataFrame(cm, index=row_n, columns=col_n, dtype=np.int32) # pyright: ignore [reportArgumentType]
+ cmndf = pd.DataFrame(cmn, index=row_n, columns=col_n, dtype=np.float32) # pyright: ignore [reportArgumentType]
+ # add thresholds in 3rd row
+ pdnote = pd.DataFrame(np.atleast_2d([predict_thr, truth_thr]), index=["p/t thr:"], columns=col_n) # pyright: ignore [reportArgumentType, reportCallIssue]
+ cmdf = pd.concat([cmdf, pdnote])
+ cmndf = pd.concat([cmndf, pdnote])
 
  return cmdf, cmndf

sonusai/metrics/one_hot.py CHANGED
@@ -185,11 +185,11 @@ def one_hot(
  AP = np.NaN
  # threshold_optpr[nci] = np.NaN
  else:
- AP = average_precision_score(truthb[:, nci], predict[:, nci], average=None)
+ AP = average_precision_score(truthb[:, nci], predict[:, nci], average=None) # pyright: ignore [reportArgumentType]
  if len(np.unique(truthb[:, nci])) < 2: # if active classes not > 1 AUC must be NaN
  AUC = np.NaN # i.e. all ones sklearn auc will fail
  else:
- AUC = roc_auc_score(truthb[:, nci], predict[:, nci], average=None)
+ AUC = roc_auc_score(truthb[:, nci], predict[:, nci], average=None) # pyright: ignore [reportArgumentType]
  # # Optimal threshold from PR curve, optimizes f-score
  # precision, recall, thresholds = precision_recall_curve(truthb[:, nci], predict[:, nci])
  # fscore = (2 * precision * recall) / (precision + recall)
@@ -263,7 +263,7 @@ def one_hot(
  ] # specific format, last 3 are unique
 
  # weighted average TBD
- wp, wr, wf1, _ = precision_recall_fscore_support(truthb, predb, average="weighted", zero_division=0)
+ wp, wr, wf1, _ = precision_recall_fscore_support(truthb, predb, average="weighted", zero_division=0) # pyright: ignore [reportArgumentType]
  if np.sum(truthb):
  taidx = np.sum(truthb, axis=0) > 0
  wap = average_precision_score(truthb[:, taidx], predict[:, taidx], average="weighted")

sonusai/metrics/snr_summary.py CHANGED
@@ -48,7 +48,7 @@ def snr_summary(
  snr_mixids = get_mixids_from_snr(mixdb=mixdb, mixids=mixid)
 
  # Check predict_thr array or scalar and return final scalar predict_thr value
- if not mixdb.truth_mutex and num_classes > 1:
+ if num_classes > 1:
  if not isinstance(predict_thr, np.ndarray):
  if predict_thr == 0:
  # multi-label predict_thr scalar 0 force to 0.5 default
@@ -84,7 +84,7 @@ def snr_summary(
  for ii, snr in enumerate(snr_mixids):
  # TODO: re-work for modern mixdb API
  y_truth, y_predict = get_mixids_data(mixdb, snr_mixids[snr], truth_f, predict) # type: ignore[name-defined]
- _, metrics, _, _, _, mavg = one_hot(y_truth, y_predict, predict_thr, truth_thr, timesteps)
+ _, _, _, _, _, mavg = one_hot(y_truth, y_predict, predict_thr, truth_thr, timesteps)
 
  # mavg macro, micro, weighted: [PPV, TPR, F1, FPR, ACC, mAP, mAUC, TPSUM]
  macro_avg[ii, :] = mavg[0, 0:7]
@@ -104,21 +104,21 @@ def snr_summary(
 
  # SNR format: PPV, TPR, F1, FPR, ACC, AP, AUC
  col_n = ["PPV", "TPR", "F1", "FPR", "ACC", "AP", "AUC"]
- snr_macrodf = pd.DataFrame(macro_avg, index=list(snr_mixids.keys()), columns=col_n)
+ snr_macrodf = pd.DataFrame(macro_avg, index=list(snr_mixids.keys()), columns=col_n) # pyright: ignore [reportArgumentType]
  snr_macrodf.sort_index(ascending=False, inplace=True)
 
- snr_microdf = pd.DataFrame(micro_avg, index=list(snr_mixids.keys()), columns=col_n)
+ snr_microdf = pd.DataFrame(micro_avg, index=list(snr_mixids.keys()), columns=col_n) # pyright: ignore [reportArgumentType]
  snr_microdf.sort_index(ascending=False, inplace=True)
 
- snr_wghtdf = pd.DataFrame(wghtd_avg, index=list(snr_mixids.keys()), columns=col_n)
+ snr_wghtdf = pd.DataFrame(wghtd_avg, index=list(snr_mixids.keys()), columns=col_n) # pyright: ignore [reportArgumentType]
  snr_wghtdf.sort_index(ascending=False, inplace=True)
 
  # Add segmental SNR columns if provided
  if segsnr is not None:
  ssnrdf = pd.DataFrame(
  ssnr_stats,
- index=list(snr_mixids.keys()),
- columns=["SSNRavg", "SSNR80p", "SSNRmax"],
+ index=list(snr_mixids.keys()), # pyright: ignore [reportArgumentType]
+ columns=["SSNRavg", "SSNR80p", "SSNRmax"], # pyright: ignore [reportArgumentType]
  )
  ssnrdf.sort_index(ascending=False, inplace=True)
  snr_macrodf = pd.concat([snr_macrodf, ssnrdf], axis=1)

sonusai/mixture/__init__.py CHANGED
@@ -46,19 +46,15 @@ from .constants import SAMPLE_RATE
  from .constants import VALID_AUGMENTATIONS
  from .constants import VALID_CONFIGS
  from .constants import VALID_NOISE_MIX_MODES
+ from .data_io import clear_cached_data
  from .data_io import read_cached_data
  from .data_io import write_cached_data
  from .datatypes import AudioF
- from .datatypes import AudiosF
- from .datatypes import AudiosT
  from .datatypes import AudioStatsMetrics
  from .datatypes import AudioT
  from .datatypes import Augmentation
  from .datatypes import AugmentationRule
- from .datatypes import AugmentationRules
- from .datatypes import Augmentations
  from .datatypes import AugmentedTarget
- from .datatypes import AugmentedTargets
  from .datatypes import ClassCount
  from .datatypes import EnergyF
  from .datatypes import EnergyT
@@ -70,35 +66,27 @@ from .datatypes import GenFTData
  from .datatypes import GenMixData
  from .datatypes import ImpulseResponseData
  from .datatypes import ImpulseResponseFile
- from .datatypes import ImpulseResponseFiles
- from .datatypes import ListAudiosT
  from .datatypes import MetricDoc
  from .datatypes import MetricDocs
  from .datatypes import Mixture
  from .datatypes import MixtureDatabaseConfig
- from .datatypes import Mixtures
  from .datatypes import NoiseFile
- from .datatypes import NoiseFiles
  from .datatypes import Predict
  from .datatypes import Segsnr
  from .datatypes import SnrFMetrics
  from .datatypes import SpectralMask
- from .datatypes import SpectralMasks
  from .datatypes import SpeechMetadata
  from .datatypes import SpeechMetrics
  from .datatypes import TargetFile
- from .datatypes import TargetFiles
  from .datatypes import TransformConfig
  from .datatypes import Truth
  from .datatypes import TruthConfig
  from .datatypes import TruthConfigs
  from .datatypes import TruthDict
  from .datatypes import TruthParameter
- from .datatypes import TruthParameters
  from .datatypes import UniversalSNR
  from .feature import get_audio_from_feature
  from .feature import get_feature_from_audio
- from .generation import generate_mixtures
  from .generation import get_all_snrs_from_config
  from .generation import initialize_db
  from .generation import populate_class_label_table
@@ -111,17 +99,14 @@ from .generation import populate_target_file_table
  from .generation import populate_top_table
  from .generation import populate_truth_parameters_table
  from .generation import update_mixid_width
- from .generation import update_mixture
+ from .generation import update_mixture_table
  from .helpers import augmented_noise_samples
  from .helpers import augmented_target_samples
  from .helpers import check_audio_files_exist
  from .helpers import forward_transform
  from .helpers import frames_from_samples
  from .helpers import get_audio_from_transform
- from .helpers import get_ft
- from .helpers import get_segsnr
  from .helpers import get_transform_from_audio
- from .helpers import get_truth
  from .helpers import inverse_transform
  from .helpers import mixture_metadata
  from .helpers import write_mixture_metadata
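
A caller-side migration sketch for the removed plural type aliases and the renamed generation helper; the "before" import line is inferred from the removals above, and the variables are illustrative only.

    # 0.19.6 style (no longer available):
    #   from sonusai.mixture import AugmentationRules, TargetFiles, update_mixture

    # 0.19.9 style: singular dataclasses plus built-in list types, and the renamed table helper
    from sonusai.mixture import AugmentationRule, TargetFile, update_mixture_table

    rules: list[AugmentationRule] = []
    targets: list[TargetFile] = []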

sonusai/mixture/augmentation.py CHANGED
@@ -1,12 +1,11 @@
  from sonusai.mixture.datatypes import AudioT
  from sonusai.mixture.datatypes import Augmentation
  from sonusai.mixture.datatypes import AugmentationRule
- from sonusai.mixture.datatypes import AugmentationRules
  from sonusai.mixture.datatypes import ImpulseResponseData
  from sonusai.mixture.datatypes import OptionalNumberStr
 
 
- def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> AugmentationRules:
+ def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> list[AugmentationRule]:
  """Generate augmentation rules from list of input rules
 
  :param rules: Dictionary of augmentation config rule[s]
@@ -25,7 +24,7 @@ def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> Augment
  rule = _parse_ir(rule, num_ir)
  processed_rules = _expand_rules(expanded_rules=processed_rules, rule=rule)
 
- return [dataclass_from_dict(AugmentationRule, processed_rule) for processed_rule in processed_rules]
+ return [dataclass_from_dict(AugmentationRule, processed_rule) for processed_rule in processed_rules] # pyright: ignore [reportReturnType]
 
 
  def _expand_rules(expanded_rules: list[dict], rule: dict) -> list[dict]:
@@ -163,7 +162,7 @@ def estimate_augmented_length_from_length(length: int, tempo: OptionalNumberStr
  return length
 
 
- def get_mixups(augmentations: AugmentationRules) -> list[int]:
+ def get_mixups(augmentations: list[AugmentationRule]) -> list[int]:
  """Get a list of mixup values used
 
  :param augmentations: List of augmentations
@@ -172,7 +171,7 @@ def get_mixups(augmentations: list[AugmentationRule]) -> list[int]:
  return sorted({augmentation.mixup for augmentation in augmentations})
 
 
- def get_augmentation_indices_for_mixup(augmentations: AugmentationRules, mixup: int) -> list[int]:
+ def get_augmentation_indices_for_mixup(augmentations: list[AugmentationRule], mixup: int) -> list[int]:
  """Get a list of augmentation indices for a given mixup value
 
  :param augmentations: List of augmentations
@@ -327,4 +326,4 @@ def augmentation_from_rule(rule: AugmentationRule, num_ir: int) -> Augmentation:
  if _rule_has_rand(processed_rule):
  processed_rule = _generate_random_rule(processed_rule, num_ir)
 
- return dataclass_from_dict(Augmentation, processed_rule)
+ return dataclass_from_dict(Augmentation, processed_rule) # pyright: ignore [reportReturnType]
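
Downstream annotations follow the same pattern; a minimal sketch, assuming a rule dictionary of a form the augmentation config schema accepts (the example dict is made up).

    from sonusai.mixture import AugmentationRule, get_augmentation_rules
    from sonusai.mixture.augmentation import get_mixups

    # hypothetical rule; real keys depend on the sonusai augmentation config schema
    rules: list[AugmentationRule] = get_augmentation_rules(rules={"mixup": 1})
    mixups: list[int] = get_mixups(rules)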

sonusai/mixture/class_count.py CHANGED
@@ -3,7 +3,7 @@ from sonusai.mixture.datatypes import GeneralizedIDs
  from sonusai.mixture.mixdb import MixtureDatabase
 
 
- def get_class_count_from_mixids(mixdb: MixtureDatabase, mixids: GeneralizedIDs | None = None) -> ClassCount:
+ def get_class_count_from_mixids(mixdb: MixtureDatabase, mixids: GeneralizedIDs = "*") -> ClassCount:
  """Sums the class counts for given mixids"""
  total_class_count = [0] * mixdb.num_classes
  m_ids = mixdb.mixids_to_list(mixids)
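
The same default applies here: get_class_count_from_mixids(mixdb) now counts across every mixture. A short sketch; the location is hypothetical, the single-argument constructor assumes the test flag has a default, and the list form of mixids assumes the usual GeneralizedIDs conventions.

    from sonusai.mixture import MixtureDatabase, get_class_count_from_mixids

    mixdb = MixtureDatabase("path/to/mixdb")  # hypothetical location

    all_counts = get_class_count_from_mixids(mixdb)                   # "*" selects every mixture
    some_counts = get_class_count_from_mixids(mixdb, mixids=[0, 1])   # restrict to specific mixtures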