sonusai 0.19.5__py3-none-any.whl → 0.19.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. sonusai/__init__.py +1 -1
  2. sonusai/aawscd_probwrite.py +1 -1
  3. sonusai/calc_metric_spenh.py +1 -1
  4. sonusai/genft.py +38 -49
  5. sonusai/genmetrics.py +65 -70
  6. sonusai/genmix.py +62 -72
  7. sonusai/genmixdb.py +73 -95
  8. sonusai/metrics/calc_class_weights.py +1 -3
  9. sonusai/metrics/calc_optimal_thresholds.py +2 -2
  10. sonusai/metrics/calc_phase_distance.py +1 -1
  11. sonusai/metrics/calc_segsnr_f.py +1 -1
  12. sonusai/metrics/calc_speech.py +6 -6
  13. sonusai/metrics/class_summary.py +6 -15
  14. sonusai/metrics/confusion_matrix_summary.py +11 -27
  15. sonusai/metrics/one_hot.py +3 -3
  16. sonusai/metrics/snr_summary.py +7 -7
  17. sonusai/mixture/__init__.py +3 -17
  18. sonusai/mixture/augmentation.py +5 -6
  19. sonusai/mixture/class_count.py +1 -1
  20. sonusai/mixture/config.py +36 -46
  21. sonusai/mixture/data_io.py +30 -1
  22. sonusai/mixture/datatypes.py +29 -40
  23. sonusai/mixture/db_datatypes.py +1 -1
  24. sonusai/mixture/feature.py +3 -23
  25. sonusai/mixture/generation.py +202 -235
  26. sonusai/mixture/helpers.py +29 -187
  27. sonusai/mixture/mixdb.py +386 -159
  28. sonusai/mixture/soundfile_audio.py +1 -1
  29. sonusai/mixture/sox_audio.py +4 -4
  30. sonusai/mixture/sox_augmentation.py +1 -1
  31. sonusai/mixture/target_class_balancing.py +9 -11
  32. sonusai/mixture/targets.py +23 -20
  33. sonusai/mixture/truth.py +21 -34
  34. sonusai/mixture/truth_functions/__init__.py +6 -0
  35. sonusai/mixture/truth_functions/crm.py +51 -37
  36. sonusai/mixture/truth_functions/energy.py +95 -50
  37. sonusai/mixture/truth_functions/file.py +12 -8
  38. sonusai/mixture/truth_functions/metadata.py +24 -0
  39. sonusai/mixture/truth_functions/metrics.py +28 -0
  40. sonusai/mixture/truth_functions/phoneme.py +4 -5
  41. sonusai/mixture/truth_functions/sed.py +32 -23
  42. sonusai/mixture/truth_functions/target.py +62 -29
  43. sonusai/mkwav.py +34 -43
  44. sonusai/queries/queries.py +9 -15
  45. sonusai/speech/l2arctic.py +6 -2
  46. sonusai/summarize_metric_spenh.py +1 -1
  47. sonusai/utils/__init__.py +1 -0
  48. sonusai/utils/asr_functions/aaware_whisper.py +1 -1
  49. sonusai/utils/audio_devices.py +27 -18
  50. sonusai/utils/docstring.py +6 -3
  51. sonusai/utils/energy_f.py +5 -3
  52. sonusai/utils/human_readable_size.py +6 -6
  53. sonusai/utils/load_object.py +15 -0
  54. sonusai/utils/onnx_utils.py +2 -2
  55. sonusai/utils/parallel.py +3 -5
  56. sonusai/utils/print_mixture_details.py +3 -3
  57. {sonusai-0.19.5.dist-info → sonusai-0.19.8.dist-info}/METADATA +2 -2
  58. {sonusai-0.19.5.dist-info → sonusai-0.19.8.dist-info}/RECORD +60 -58
  59. sonusai/mixture/truth_functions/datatypes.py +0 -37
  60. {sonusai-0.19.5.dist-info → sonusai-0.19.8.dist-info}/WHEEL +0 -0
  61. {sonusai-0.19.5.dist-info → sonusai-0.19.8.dist-info}/entry_points.txt +0 -0
sonusai/genmixdb.py CHANGED
@@ -1,15 +1,16 @@
1
1
  """sonusai genmixdb
2
2
 
3
- usage: genmixdb [-hvmfsdj] LOC
3
+ usage: genmixdb [-hvmfsdjn] LOC
4
4
 
5
5
  options:
6
- -h, --help
7
- -v, --verbose Be verbose.
8
- -m, --mix Save mixture data. [default: False].
9
- -f, --ft Save feature/truth_f data. [default: False].
10
- -s, --segsnr Save segsnr data. [default: False].
11
- -d, --dryrun Perform a dry run showing the processed config. [default: False].
12
- -j, --json Save JSON version of database. [default: False].
6
+ -h, --help
7
+ -v, --verbose Be verbose.
8
+ -m, --mix Save mixture data. [default: False].
9
+ -f, --ft Save feature/truth_f data. [default: False].
10
+ -s, --segsnr Save segsnr data. [default: False].
11
+ -d, --dryrun Perform a dry run showing the processed config. [default: False].
12
+ -j, --json Save JSON version of database. [default: False].
13
+ -n, --nopar Do not run in parallel. [default: False].
13
14
 
14
15
  Create mixture database data for training and evaluation. Optionally, also create mixture audio and feature/truth data.
15
16
 
@@ -114,10 +115,6 @@ will find all .wav files in the specified directories and process them as target
114
115
  """
115
116
 
116
117
  import signal
117
- from dataclasses import dataclass
118
-
119
- from sonusai.mixture import Mixture
120
- from sonusai.mixture import MixtureDatabase
121
118
 
122
119
 
123
120
  def signal_handler(_sig, _frame):
@@ -132,17 +129,6 @@ def signal_handler(_sig, _frame):
132
129
  signal.signal(signal.SIGINT, signal_handler)
133
130
 
134
131
 
135
- @dataclass
136
- class MPGlobal:
137
- mixdb: MixtureDatabase
138
- save_mix: bool
139
- save_ft: bool
140
- save_segsnr: bool
141
-
142
-
143
- MP_GLOBAL: MPGlobal
144
-
145
-
146
132
  def genmixdb(
147
133
  location: str,
148
134
  save_mix: bool = False,
@@ -152,7 +138,9 @@ def genmixdb(
152
138
  show_progress: bool = False,
153
139
  test: bool = False,
154
140
  save_json: bool = False,
155
- ) -> MixtureDatabase:
141
+ no_par: bool = False,
142
+ ) -> None:
143
+ from functools import partial
156
144
  from random import seed
157
145
 
158
146
  import yaml
@@ -163,7 +151,6 @@ def genmixdb(
163
151
  from sonusai.mixture import AugmentationRule
164
152
  from sonusai.mixture import MixtureDatabase
165
153
  from sonusai.mixture import balance_targets
166
- from sonusai.mixture import generate_mixtures
167
154
  from sonusai.mixture import get_all_snrs_from_config
168
155
  from sonusai.mixture import get_augmentation_rules
169
156
  from sonusai.mixture import get_augmented_targets
@@ -329,7 +316,8 @@ def genmixdb(
329
316
  f"{seconds_to_hms(seconds=noise_audio_duration)}"
330
317
  )
331
318
 
332
- used_noise_files, used_noise_samples, mixtures = generate_mixtures(
319
+ used_noise_files, used_noise_samples = populate_mixture_table(
320
+ location=location,
333
321
  noise_mix_mode=mixdb.noise_mix_mode,
334
322
  augmented_targets=augmented_targets,
335
323
  target_files=target_files,
@@ -342,16 +330,17 @@ def genmixdb(
342
330
  num_classes=mixdb.num_classes,
343
331
  feature_step_samples=mixdb.feature_step_samples,
344
332
  num_ir=mixdb.num_impulse_response_files,
333
+ test=test,
345
334
  )
346
335
 
347
- num_mixtures = len(mixtures)
336
+ num_mixtures = len(mixdb.mixtures)
348
337
  update_mixid_width(location, num_mixtures, test)
349
338
 
350
339
  if logging:
351
340
  logger.info("")
352
341
  logger.info(f"Found {num_mixtures:,} mixtures to process")
353
342
 
354
- total_duration = float(sum([mixture.samples for mixture in mixtures])) / SAMPLE_RATE
343
+ total_duration = float(sum([mixture.samples for mixture in mixdb.mixtures])) / SAMPLE_RATE
355
344
 
356
345
  if logging:
357
346
  log_duration_and_sizes(
@@ -375,17 +364,21 @@ def genmixdb(
375
364
  if logging:
376
365
  logger.info("Generating mixtures")
377
366
  progress = track(total=num_mixtures, disable=not show_progress)
378
- mixtures = par_track(
379
- _process_mixture,
380
- mixtures,
367
+ par_track(
368
+ partial(
369
+ _process_mixture,
370
+ location=location,
371
+ save_mix=save_mix,
372
+ save_ft=save_ft,
373
+ save_segsnr=save_segsnr,
374
+ test=test,
375
+ ),
376
+ range(num_mixtures),
381
377
  progress=progress,
382
- initializer=_initializer,
383
- initargs=(location, save_mix, save_ft, save_segsnr, test),
378
+ no_par=no_par,
384
379
  )
385
380
  progress.close()
386
381
 
387
- populate_mixture_table(location, mixtures, test)
388
-
389
382
  total_noise_files = len(noise_files)
390
383
 
391
384
  total_samples = mixdb.total_samples()
@@ -414,79 +407,62 @@ def genmixdb(
414
407
  mixdb = MixtureDatabase(location)
415
408
  mixdb.save()
416
409
 
417
- return mixdb
418
-
419
410
 
420
- def _initializer(location: str, save_mix: bool, save_ft: bool, save_segsnr: bool, test: bool) -> None:
421
- global MP_GLOBAL
422
-
423
- MP_GLOBAL = MPGlobal(
424
- mixdb=MixtureDatabase(location, test),
425
- save_mix=save_mix,
426
- save_ft=save_ft,
427
- save_segsnr=save_segsnr,
428
- )
429
-
430
-
431
- def _process_mixture(mixture: Mixture) -> Mixture:
432
- from typing import Any
411
+ def _process_mixture(
412
+ m_id: int,
413
+ location: str,
414
+ save_mix: bool,
415
+ save_ft: bool,
416
+ save_segsnr: bool,
417
+ test: bool,
418
+ ) -> None:
419
+ from functools import partial
433
420
 
434
- from sonusai.mixture import get_ft
435
- from sonusai.mixture import get_segsnr
436
- from sonusai.mixture import get_truth
437
- from sonusai.mixture import update_mixture
421
+ from sonusai.mixture import MixtureDatabase
422
+ from sonusai.mixture import clear_cached_data
423
+ from sonusai.mixture import update_mixture_table
438
424
  from sonusai.mixture import write_cached_data
439
425
  from sonusai.mixture import write_mixture_metadata
440
426
 
441
- global MP_GLOBAL
427
+ with_data = save_mix or save_ft or save_segsnr
428
+
429
+ genmix_data = update_mixture_table(location, m_id, with_data, test)
442
430
 
443
- with_data = MP_GLOBAL.save_mix or MP_GLOBAL.save_ft
444
- mixdb = MP_GLOBAL.mixdb
431
+ mixdb = MixtureDatabase(location, test)
432
+ mixture = mixdb.mixture(m_id)
445
433
 
446
- mixture, genmix_data = update_mixture(mixdb, mixture, with_data)
434
+ write = partial(write_cached_data, location=location, name="mixture", index=mixture.name)
435
+ clear = partial(clear_cached_data, location=location, name="mixture", index=mixture.name)
447
436
 
448
437
  if with_data:
449
- write_data: list[tuple[str, Any]] = []
450
-
451
- if MP_GLOBAL.save_mix:
452
- write_data.append(("targets", genmix_data.targets))
453
- write_data.append(("noise", genmix_data.noise))
454
- write_data.append(("mixture", genmix_data.mixture))
455
-
456
- if MP_GLOBAL.save_ft:
457
- if genmix_data.targets is None or genmix_data.noise is None or genmix_data.mixture is None:
458
- raise RuntimeError("Mixture data was not generated properly")
459
- truth_t = get_truth(
460
- mixdb=mixdb,
461
- mixture=mixture,
462
- targets_audio=genmix_data.targets,
463
- noise_audio=genmix_data.noise,
464
- mixture_audio=genmix_data.mixture,
465
- )
466
- feature, truth_f = get_ft(
467
- mixdb=mixdb,
468
- mixture=mixture,
469
- mixture_audio=genmix_data.mixture,
470
- truth_t=truth_t,
438
+ write(
439
+ items=[
440
+ ("targets", genmix_data.targets),
441
+ ("target", genmix_data.target),
442
+ ("noise", genmix_data.noise),
443
+ ("mixture", genmix_data.mixture),
444
+ ]
445
+ )
446
+
447
+ if save_ft:
448
+ clear(items=["feature", "truth_f"])
449
+ feature, truth_f = mixdb.mixture_ft(m_id)
450
+ write(
451
+ items=[
452
+ ("feature", feature),
453
+ ("truth_f", truth_f),
454
+ ]
471
455
  )
472
- write_data.append(("feature", feature))
473
- write_data.append(("truth_f", truth_f))
474
456
 
475
- if MP_GLOBAL.save_segsnr:
476
- if genmix_data.target is None:
477
- raise RuntimeError("Target data was not generated properly")
478
- segsnr = get_segsnr(
479
- mixdb=mixdb,
480
- mixture=mixture,
481
- target_audio=genmix_data.target,
482
- noise=genmix_data.noise,
483
- )
484
- write_data.append(("segsnr", segsnr))
457
+ if save_segsnr:
458
+ clear(items=["segsnr"])
459
+ segsnr = mixdb.mixture_segsnr(m_id)
460
+ write(items=[("segsnr", segsnr)])
485
461
 
486
- write_cached_data(mixdb.location, "mixture", mixture.name, write_data)
487
- write_mixture_metadata(mixdb, mixture)
462
+ if not save_mix:
463
+ clear(items=["targets", "target", "noise", "mixture"])
488
464
 
489
- return mixture
465
+ write_mixture_metadata(mixdb, m_id)
490
466
 
491
467
 
492
468
  def main() -> None:
@@ -519,6 +495,7 @@ def main() -> None:
519
495
  save_segsnr = args["--segsnr"]
520
496
  dryrun = args["--dryrun"]
521
497
  save_json = args["--json"]
498
+ no_par = args["--nopar"]
522
499
  location = args["LOC"]
523
500
 
524
501
  start_time = time.monotonic()
@@ -549,6 +526,7 @@ def main() -> None:
549
526
  save_segsnr=save_segsnr,
550
527
  show_progress=True,
551
528
  save_json=save_json,
529
+ no_par=no_par,
552
530
  )
553
531
  except Exception as e:
554
532
  logger.debug(e)
@@ -54,7 +54,7 @@ def calc_class_weights_from_truth(truth: Truth, other_weight: float | None = Non
54
54
 
55
55
  def calc_class_weights_from_mixdb(
56
56
  mixdb: MixtureDatabase,
57
- mixids: GeneralizedIDs | None = None,
57
+ mixids: GeneralizedIDs = "*",
58
58
  other_weight: float = 1,
59
59
  other_index: int = -1,
60
60
  ) -> tuple[np.ndarray, np.ndarray]:
@@ -77,8 +77,6 @@ def calc_class_weights_from_mixdb(
77
77
  from sonusai.mixture import get_class_count_from_mixids
78
78
 
79
79
  count = np.ceil(np.array(get_class_count_from_mixids(mixdb=mixdb, mixids=mixids)) / mixdb.feature_step_samples)
80
- if mixdb.truth_mutex and other_weight is not None and other_weight > 0:
81
- count[other_index] = count[other_index] / np.float32(other_weight)
82
80
  total_features = sum(count)
83
81
 
84
82
  weights = np.empty(mixdb.num_classes, dtype=np.float32)
@@ -51,8 +51,8 @@ def calc_optimal_thresholds(
51
51
  AUC[nci] = np.NaN
52
52
  AP[nci] = np.NaN
53
53
  else:
54
- AP[nci] = average_precision_score(truth_binary[:, nci], predict[:, nci], average=None)
55
- AUC[nci] = roc_auc_score(truth_binary[:, nci], predict[:, nci], average=None)
54
+ AP[nci] = average_precision_score(truth_binary[:, nci], predict[:, nci], average=None) # pyright: ignore [reportArgumentType]
55
+ AUC[nci] = roc_auc_score(truth_binary[:, nci], predict[:, nci], average=None) # pyright: ignore [reportArgumentType]
56
56
 
57
57
  # Optimal threshold from PR curve, optimizes f-score
58
58
  precision, recall, thrpr = precision_recall_curve(truth_binary[:, nci], predict[:, nci])
@@ -26,7 +26,7 @@ def calc_phase_distance(
26
26
  # weighted mean over all (scalar)
27
27
  reference_mag = np.abs(reference)
28
28
  ref_weight = reference_mag / (np.sum(reference_mag) + eps) # frames x bins
29
- err = np.around(np.sum(ref_weight * rh_angle_diff), 3)
29
+ err = float(np.around(np.sum(ref_weight * rh_angle_diff), 3))
30
30
 
31
31
  # weighted mean over frames (value per bin)
32
32
  err_b = np.zeros(reference.shape[1])
@@ -45,7 +45,7 @@ def calc_segsnr_f_bin(target_f: AudioF, noise_f: AudioF) -> SnrFBinMetrics:
45
45
  if target_f.ndim != 2 and noise_f.ndim != 2:
46
46
  raise ValueError("target_f and noise_f must have 2 dimensions")
47
47
 
48
- segsnr_f = (np.abs(target_f) ** 2) / (np.abs(noise_f) ** 2)
48
+ segsnr_f = (np.abs(target_f) ** 2) / (np.abs(noise_f) ** 2 + np.finfo(np.float32).eps)
49
49
 
50
50
  frames, bins = segsnr_f.shape
51
51
  if np.count_nonzero(segsnr_f) == 0:
@@ -32,16 +32,16 @@ def calc_speech(hypothesis: np.ndarray, reference: np.ndarray, sample_rate: int
32
32
  llr_mean = np.mean(ll_rs[:llr_len])
33
33
 
34
34
  # Segmental SNR
35
- snr_dist, segsnr_dist = _calc_snr(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
35
+ _, segsnr_dist = _calc_snr(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
36
36
  seg_snr = np.mean(segsnr_dist)
37
37
 
38
38
  # PESQ
39
39
  _pesq = calc_pesq(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
40
40
 
41
41
  # Now compute the composite measures
42
- csig = np.clip(3.093 - 1.029 * llr_mean + 0.603 * _pesq - 0.009 * wss_dist, 1, 5)
43
- cbak = np.clip(1.634 + 0.478 * _pesq - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5)
44
- covl = np.clip(1.594 + 0.805 * _pesq - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5)
42
+ csig = float(np.clip(3.093 - 1.029 * llr_mean + 0.603 * _pesq - 0.009 * wss_dist, 1, 5))
43
+ cbak = float(np.clip(1.634 + 0.478 * _pesq - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5))
44
+ covl = float(np.clip(1.594 + 0.805 * _pesq - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5))
45
45
 
46
46
  return SpeechMetrics(_pesq, csig, cbak, covl)
47
47
 
@@ -284,8 +284,8 @@ def _calc_log_likelihood_ratio_measure(
284
284
  hypothesis_frame = np.multiply(hypothesis_frame, window)
285
285
 
286
286
  # (2) Get the autocorrelation lags and LPC parameters used to compute the log likelihood ratio measure.
287
- r_reference, ref_reference, a_reference = _lp_coefficients(reference_frame, p)
288
- r_hypothesis, ref_hypothesis, a_hypothesis = _lp_coefficients(hypothesis_frame, p)
287
+ r_reference, _, a_reference = _lp_coefficients(reference_frame, p)
288
+ _, _, a_hypothesis = _lp_coefficients(hypothesis_frame, p)
289
289
 
290
290
  # (3) Compute the log likelihood ratio measure
291
291
  numerator = np.dot(np.matmul(a_hypothesis, toeplitz(r_reference)), a_hypothesis)
@@ -38,7 +38,7 @@ def class_summary(
38
38
  # TODO: re-work for modern mixdb API
39
39
  y_truth_f, y_predict = get_mixids_data(mixdb, mixids, truth_f, predict) # type: ignore[name-defined]
40
40
 
41
- if not mixdb.truth_mutex and num_classes > 1:
41
+ if num_classes > 1:
42
42
  if not isinstance(predict_thr, np.ndarray):
43
43
  if predict_thr == 0:
44
44
  predict_thr = np.atleast_1d(0.5)
@@ -53,25 +53,16 @@ def class_summary(
53
53
  # [ACC, TPR, PPV, TNR, FPR, HITFA, F1, MCC, NT, PT, TP, FP, AP, AUC]
54
54
  table_idx = np.array([2, 1, 6, 4, 0, 12, 13, 9])
55
55
  col_n = ["PPV", "TPR", "F1", "FPR", "ACC", "AP", "AUC", "Support"]
56
- if mixdb.truth_mutex:
57
- if len(mixdb.class_labels) >= num_classes - 1: # labels exist with or without Other
58
- row_n = mixdb.class_labels
59
- if len(mixdb.class_labels) == num_classes - 1: # Other label does not exist, so add it
60
- row_n.append("Other")
61
- else:
62
- row_n = [f"Class {i}" for i in range(1, num_classes)]
63
- row_n.append("Other")
56
+ if len(mixdb.class_labels) == num_classes:
57
+ row_n = mixdb.class_labels
64
58
  else:
65
- if len(mixdb.class_labels) == num_classes:
66
- row_n = mixdb.class_labels
67
- else:
68
- row_n = [f"Class {i}" for i in range(1, num_classes + 1)]
59
+ row_n = [f"Class {i}" for i in range(1, num_classes + 1)]
69
60
 
70
- df = pd.DataFrame(metrics[:, table_idx], columns=col_n, index=row_n)
61
+ df = pd.DataFrame(metrics[:, table_idx], columns=col_n, index=row_n) # pyright: ignore [reportArgumentType]
71
62
 
72
63
  # [miPPV, miTPR, miF1, miFPR, miACC, miAP, miAUC, TPSUM]
73
64
  avg_row_n = ["Macro-avg", "Micro-avg", "Weighted-avg"]
74
- dfavg = pd.DataFrame(metavg, columns=col_n, index=avg_row_n)
65
+ dfavg = pd.DataFrame(metavg, columns=col_n, index=avg_row_n) # pyright: ignore [reportArgumentType]
75
66
 
76
67
  # dfblank = pd.DataFrame([''])
77
68
  # pd.concat([df, dfblank, dfblank, dfavg])
@@ -37,7 +37,7 @@ def confusion_matrix_summary(
37
37
  ytrue, ypred = get_mixids_data(mixdb=mixdb, mixids=mixids, truth_f=truth_f, predict=predict) # type: ignore[name-defined]
38
38
 
39
39
  # Check predict_thr array or scalar and return final scalar predict_thr value
40
- if not mixdb.truth_mutex and num_classes > 1:
40
+ if num_classes > 1:
41
41
  if not isinstance(predict_thr, np.ndarray):
42
42
  if predict_thr == 0:
43
43
  # multi-label predict_thr scalar 0 force to 0.5 default
@@ -61,31 +61,15 @@ def confusion_matrix_summary(
61
61
  else:
62
62
  class_names = [f"Class {i}" for i in range(1, num_classes + 1)]
63
63
 
64
- class_nums = [f"{i}" for i in range(1, num_classes + 1)]
65
-
66
- if mixdb.truth_mutex:
67
- # single-label mode force to argmax mode
68
- predict_thr = np.array(0, dtype=np.float32)
69
- _, _, cm, cmn, _, _ = one_hot(ytrue, ypred, predict_thr, truth_thr, timesteps)
70
- row_n = class_names
71
- row_n[-1] = "Other"
72
- # mux = pd.MultiIndex.from_product([['Single-label/mutex mode, truth thr = {}'.format(truth_thr)],
73
- # class_nums])
74
- # mux = pd.MultiIndex.from_product([['truth thr = {}'.format(truth_thr)], class_nums])
75
-
76
- cmdf = pd.DataFrame(cm, index=row_n, columns=class_nums, dtype=np.int32)
77
- cmndf = pd.DataFrame(cmn, index=row_n, columns=class_nums, dtype=np.float32)
78
-
79
- else:
80
- _, _, cm, cmn, _, _ = one_hot(ytrue[:, class_idx], ypred[:, class_idx], predict_thr, truth_thr, timesteps)
81
- cname = class_names[class_idx]
82
- row_n = ["TrueN", "TrueP"]
83
- col_n = ["N-" + cname, "P-" + cname]
84
- cmdf = pd.DataFrame(cm, index=row_n, columns=col_n, dtype=np.int32)
85
- cmndf = pd.DataFrame(cmn, index=row_n, columns=col_n, dtype=np.float32)
86
- # add thresholds in 3rd row
87
- pdnote = pd.DataFrame(np.atleast_2d([predict_thr, truth_thr]), index=["p/t thr:"], columns=col_n)
88
- cmdf = pd.concat([cmdf, pdnote])
89
- cmndf = pd.concat([cmndf, pdnote])
64
+ _, _, cm, cmn, _, _ = one_hot(ytrue[:, class_idx], ypred[:, class_idx], predict_thr, truth_thr, timesteps)
65
+ cname = class_names[class_idx]
66
+ row_n = ["TrueN", "TrueP"]
67
+ col_n = ["N-" + cname, "P-" + cname]
68
+ cmdf = pd.DataFrame(cm, index=row_n, columns=col_n, dtype=np.int32) # pyright: ignore [reportArgumentType]
69
+ cmndf = pd.DataFrame(cmn, index=row_n, columns=col_n, dtype=np.float32) # pyright: ignore [reportArgumentType]
70
+ # add thresholds in 3rd row
71
+ pdnote = pd.DataFrame(np.atleast_2d([predict_thr, truth_thr]), index=["p/t thr:"], columns=col_n) # pyright: ignore [reportArgumentType, reportCallIssue]
72
+ cmdf = pd.concat([cmdf, pdnote])
73
+ cmndf = pd.concat([cmndf, pdnote])
90
74
 
91
75
  return cmdf, cmndf
@@ -185,11 +185,11 @@ def one_hot(
185
185
  AP = np.NaN
186
186
  # threshold_optpr[nci] = np.NaN
187
187
  else:
188
- AP = average_precision_score(truthb[:, nci], predict[:, nci], average=None)
188
+ AP = average_precision_score(truthb[:, nci], predict[:, nci], average=None) # pyright: ignore [reportArgumentType]
189
189
  if len(np.unique(truthb[:, nci])) < 2: # if active classes not > 1 AUC must be NaN
190
190
  AUC = np.NaN # i.e. all ones sklearn auc will fail
191
191
  else:
192
- AUC = roc_auc_score(truthb[:, nci], predict[:, nci], average=None)
192
+ AUC = roc_auc_score(truthb[:, nci], predict[:, nci], average=None) # pyright: ignore [reportArgumentType]
193
193
  # # Optimal threshold from PR curve, optimizes f-score
194
194
  # precision, recall, thresholds = precision_recall_curve(truthb[:, nci], predict[:, nci])
195
195
  # fscore = (2 * precision * recall) / (precision + recall)
@@ -263,7 +263,7 @@ def one_hot(
263
263
  ] # specific format, last 3 are unique
264
264
 
265
265
  # weighted average TBD
266
- wp, wr, wf1, _ = precision_recall_fscore_support(truthb, predb, average="weighted", zero_division=0)
266
+ wp, wr, wf1, _ = precision_recall_fscore_support(truthb, predb, average="weighted", zero_division=0) # pyright: ignore [reportArgumentType]
267
267
  if np.sum(truthb):
268
268
  taidx = np.sum(truthb, axis=0) > 0
269
269
  wap = average_precision_score(truthb[:, taidx], predict[:, taidx], average="weighted")
@@ -48,7 +48,7 @@ def snr_summary(
48
48
  snr_mixids = get_mixids_from_snr(mixdb=mixdb, mixids=mixid)
49
49
 
50
50
  # Check predict_thr array or scalar and return final scalar predict_thr value
51
- if not mixdb.truth_mutex and num_classes > 1:
51
+ if num_classes > 1:
52
52
  if not isinstance(predict_thr, np.ndarray):
53
53
  if predict_thr == 0:
54
54
  # multi-label predict_thr scalar 0 force to 0.5 default
@@ -84,7 +84,7 @@ def snr_summary(
84
84
  for ii, snr in enumerate(snr_mixids):
85
85
  # TODO: re-work for modern mixdb API
86
86
  y_truth, y_predict = get_mixids_data(mixdb, snr_mixids[snr], truth_f, predict) # type: ignore[name-defined]
87
- _, metrics, _, _, _, mavg = one_hot(y_truth, y_predict, predict_thr, truth_thr, timesteps)
87
+ _, _, _, _, _, mavg = one_hot(y_truth, y_predict, predict_thr, truth_thr, timesteps)
88
88
 
89
89
  # mavg macro, micro, weighted: [PPV, TPR, F1, FPR, ACC, mAP, mAUC, TPSUM]
90
90
  macro_avg[ii, :] = mavg[0, 0:7]
@@ -104,21 +104,21 @@ def snr_summary(
104
104
 
105
105
  # SNR format: PPV, TPR, F1, FPR, ACC, AP, AUC
106
106
  col_n = ["PPV", "TPR", "F1", "FPR", "ACC", "AP", "AUC"]
107
- snr_macrodf = pd.DataFrame(macro_avg, index=list(snr_mixids.keys()), columns=col_n)
107
+ snr_macrodf = pd.DataFrame(macro_avg, index=list(snr_mixids.keys()), columns=col_n) # pyright: ignore [reportArgumentType]
108
108
  snr_macrodf.sort_index(ascending=False, inplace=True)
109
109
 
110
- snr_microdf = pd.DataFrame(micro_avg, index=list(snr_mixids.keys()), columns=col_n)
110
+ snr_microdf = pd.DataFrame(micro_avg, index=list(snr_mixids.keys()), columns=col_n) # pyright: ignore [reportArgumentType]
111
111
  snr_microdf.sort_index(ascending=False, inplace=True)
112
112
 
113
- snr_wghtdf = pd.DataFrame(wghtd_avg, index=list(snr_mixids.keys()), columns=col_n)
113
+ snr_wghtdf = pd.DataFrame(wghtd_avg, index=list(snr_mixids.keys()), columns=col_n) # pyright: ignore [reportArgumentType]
114
114
  snr_wghtdf.sort_index(ascending=False, inplace=True)
115
115
 
116
116
  # Add segmental SNR columns if provided
117
117
  if segsnr is not None:
118
118
  ssnrdf = pd.DataFrame(
119
119
  ssnr_stats,
120
- index=list(snr_mixids.keys()),
121
- columns=["SSNRavg", "SSNR80p", "SSNRmax"],
120
+ index=list(snr_mixids.keys()), # pyright: ignore [reportArgumentType]
121
+ columns=["SSNRavg", "SSNR80p", "SSNRmax"], # pyright: ignore [reportArgumentType]
122
122
  )
123
123
  ssnrdf.sort_index(ascending=False, inplace=True)
124
124
  snr_macrodf = pd.concat([snr_macrodf, ssnrdf], axis=1)
@@ -46,19 +46,15 @@ from .constants import SAMPLE_RATE
46
46
  from .constants import VALID_AUGMENTATIONS
47
47
  from .constants import VALID_CONFIGS
48
48
  from .constants import VALID_NOISE_MIX_MODES
49
+ from .data_io import clear_cached_data
49
50
  from .data_io import read_cached_data
50
51
  from .data_io import write_cached_data
51
52
  from .datatypes import AudioF
52
- from .datatypes import AudiosF
53
- from .datatypes import AudiosT
54
53
  from .datatypes import AudioStatsMetrics
55
54
  from .datatypes import AudioT
56
55
  from .datatypes import Augmentation
57
56
  from .datatypes import AugmentationRule
58
- from .datatypes import AugmentationRules
59
- from .datatypes import Augmentations
60
57
  from .datatypes import AugmentedTarget
61
- from .datatypes import AugmentedTargets
62
58
  from .datatypes import ClassCount
63
59
  from .datatypes import EnergyF
64
60
  from .datatypes import EnergyT
@@ -70,34 +66,27 @@ from .datatypes import GenFTData
70
66
  from .datatypes import GenMixData
71
67
  from .datatypes import ImpulseResponseData
72
68
  from .datatypes import ImpulseResponseFile
73
- from .datatypes import ImpulseResponseFiles
74
- from .datatypes import ListAudiosT
75
69
  from .datatypes import MetricDoc
76
70
  from .datatypes import MetricDocs
77
71
  from .datatypes import Mixture
78
72
  from .datatypes import MixtureDatabaseConfig
79
- from .datatypes import Mixtures
80
73
  from .datatypes import NoiseFile
81
- from .datatypes import NoiseFiles
82
74
  from .datatypes import Predict
83
75
  from .datatypes import Segsnr
84
76
  from .datatypes import SnrFMetrics
85
77
  from .datatypes import SpectralMask
86
- from .datatypes import SpectralMasks
87
78
  from .datatypes import SpeechMetadata
88
79
  from .datatypes import SpeechMetrics
89
80
  from .datatypes import TargetFile
90
- from .datatypes import TargetFiles
91
81
  from .datatypes import TransformConfig
92
82
  from .datatypes import Truth
93
83
  from .datatypes import TruthConfig
94
84
  from .datatypes import TruthConfigs
85
+ from .datatypes import TruthDict
95
86
  from .datatypes import TruthParameter
96
- from .datatypes import TruthParameters
97
87
  from .datatypes import UniversalSNR
98
88
  from .feature import get_audio_from_feature
99
89
  from .feature import get_feature_from_audio
100
- from .generation import generate_mixtures
101
90
  from .generation import get_all_snrs_from_config
102
91
  from .generation import initialize_db
103
92
  from .generation import populate_class_label_table
@@ -110,17 +99,14 @@ from .generation import populate_target_file_table
110
99
  from .generation import populate_top_table
111
100
  from .generation import populate_truth_parameters_table
112
101
  from .generation import update_mixid_width
113
- from .generation import update_mixture
102
+ from .generation import update_mixture_table
114
103
  from .helpers import augmented_noise_samples
115
104
  from .helpers import augmented_target_samples
116
105
  from .helpers import check_audio_files_exist
117
106
  from .helpers import forward_transform
118
107
  from .helpers import frames_from_samples
119
108
  from .helpers import get_audio_from_transform
120
- from .helpers import get_ft
121
- from .helpers import get_segsnr
122
109
  from .helpers import get_transform_from_audio
123
- from .helpers import get_truth
124
110
  from .helpers import inverse_transform
125
111
  from .helpers import mixture_metadata
126
112
  from .helpers import write_mixture_metadata
@@ -1,12 +1,11 @@
1
1
  from sonusai.mixture.datatypes import AudioT
2
2
  from sonusai.mixture.datatypes import Augmentation
3
3
  from sonusai.mixture.datatypes import AugmentationRule
4
- from sonusai.mixture.datatypes import AugmentationRules
5
4
  from sonusai.mixture.datatypes import ImpulseResponseData
6
5
  from sonusai.mixture.datatypes import OptionalNumberStr
7
6
 
8
7
 
9
- def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> AugmentationRules:
8
+ def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> list[AugmentationRule]:
10
9
  """Generate augmentation rules from list of input rules
11
10
 
12
11
  :param rules: Dictionary of augmentation config rule[s]
@@ -25,7 +24,7 @@ def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> Augment
25
24
  rule = _parse_ir(rule, num_ir)
26
25
  processed_rules = _expand_rules(expanded_rules=processed_rules, rule=rule)
27
26
 
28
- return [dataclass_from_dict(AugmentationRule, processed_rule) for processed_rule in processed_rules]
27
+ return [dataclass_from_dict(AugmentationRule, processed_rule) for processed_rule in processed_rules] # pyright: ignore [reportReturnType]
29
28
 
30
29
 
31
30
  def _expand_rules(expanded_rules: list[dict], rule: dict) -> list[dict]:
@@ -163,7 +162,7 @@ def estimate_augmented_length_from_length(length: int, tempo: OptionalNumberStr
163
162
  return length
164
163
 
165
164
 
166
- def get_mixups(augmentations: AugmentationRules) -> list[int]:
165
+ def get_mixups(augmentations: list[AugmentationRule]) -> list[int]:
167
166
  """Get a list of mixup values used
168
167
 
169
168
  :param augmentations: List of augmentations
@@ -172,7 +171,7 @@ def get_mixups(augmentations: AugmentationRules) -> list[int]:
172
171
  return sorted({augmentation.mixup for augmentation in augmentations})
173
172
 
174
173
 
175
- def get_augmentation_indices_for_mixup(augmentations: AugmentationRules, mixup: int) -> list[int]:
174
+ def get_augmentation_indices_for_mixup(augmentations: list[AugmentationRule], mixup: int) -> list[int]:
176
175
  """Get a list of augmentation indices for a given mixup value
177
176
 
178
177
  :param augmentations: List of augmentations
@@ -327,4 +326,4 @@ def augmentation_from_rule(rule: AugmentationRule, num_ir: int) -> Augmentation:
327
326
  if _rule_has_rand(processed_rule):
328
327
  processed_rule = _generate_random_rule(processed_rule, num_ir)
329
328
 
330
- return dataclass_from_dict(Augmentation, processed_rule)
329
+ return dataclass_from_dict(Augmentation, processed_rule) # pyright: ignore [reportReturnType]
@@ -3,7 +3,7 @@ from sonusai.mixture.datatypes import GeneralizedIDs
3
3
  from sonusai.mixture.mixdb import MixtureDatabase
4
4
 
5
5
 
6
- def get_class_count_from_mixids(mixdb: MixtureDatabase, mixids: GeneralizedIDs | None = None) -> ClassCount:
6
+ def get_class_count_from_mixids(mixdb: MixtureDatabase, mixids: GeneralizedIDs = "*") -> ClassCount:
7
7
  """Sums the class counts for given mixids"""
8
8
  total_class_count = [0] * mixdb.num_classes
9
9
  m_ids = mixdb.mixids_to_list(mixids)