sonusai 1.0.16__cp311-abi3-macosx_10_12_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sonusai/__init__.py +170 -0
  2. sonusai/aawscd_probwrite.py +148 -0
  3. sonusai/audiofe.py +481 -0
  4. sonusai/calc_metric_spenh.py +1136 -0
  5. sonusai/config/__init__.py +0 -0
  6. sonusai/config/asr.py +21 -0
  7. sonusai/config/config.py +65 -0
  8. sonusai/config/config.yml +49 -0
  9. sonusai/config/constants.py +53 -0
  10. sonusai/config/ir.py +124 -0
  11. sonusai/config/ir_delay.py +62 -0
  12. sonusai/config/source.py +275 -0
  13. sonusai/config/spectral_masks.py +15 -0
  14. sonusai/config/truth.py +64 -0
  15. sonusai/constants.py +14 -0
  16. sonusai/data/__init__.py +0 -0
  17. sonusai/data/silero_vad_v5.1.jit +0 -0
  18. sonusai/data/silero_vad_v5.1.onnx +0 -0
  19. sonusai/data/speech_ma01_01.wav +0 -0
  20. sonusai/data/whitenoise.wav +0 -0
  21. sonusai/datatypes.py +383 -0
  22. sonusai/deprecated/gentcst.py +632 -0
  23. sonusai/deprecated/plot.py +519 -0
  24. sonusai/deprecated/tplot.py +365 -0
  25. sonusai/doc.py +52 -0
  26. sonusai/doc_strings/__init__.py +1 -0
  27. sonusai/doc_strings/doc_strings.py +531 -0
  28. sonusai/genft.py +196 -0
  29. sonusai/genmetrics.py +183 -0
  30. sonusai/genmix.py +199 -0
  31. sonusai/genmixdb.py +235 -0
  32. sonusai/ir_metric.py +551 -0
  33. sonusai/lsdb.py +141 -0
  34. sonusai/main.py +134 -0
  35. sonusai/metrics/__init__.py +43 -0
  36. sonusai/metrics/calc_audio_stats.py +42 -0
  37. sonusai/metrics/calc_class_weights.py +90 -0
  38. sonusai/metrics/calc_optimal_thresholds.py +73 -0
  39. sonusai/metrics/calc_pcm.py +45 -0
  40. sonusai/metrics/calc_pesq.py +36 -0
  41. sonusai/metrics/calc_phase_distance.py +43 -0
  42. sonusai/metrics/calc_sa_sdr.py +64 -0
  43. sonusai/metrics/calc_sample_weights.py +25 -0
  44. sonusai/metrics/calc_segsnr_f.py +82 -0
  45. sonusai/metrics/calc_speech.py +382 -0
  46. sonusai/metrics/calc_wer.py +71 -0
  47. sonusai/metrics/calc_wsdr.py +57 -0
  48. sonusai/metrics/calculate_metrics.py +395 -0
  49. sonusai/metrics/class_summary.py +74 -0
  50. sonusai/metrics/confusion_matrix_summary.py +75 -0
  51. sonusai/metrics/one_hot.py +283 -0
  52. sonusai/metrics/snr_summary.py +128 -0
  53. sonusai/metrics_summary.py +314 -0
  54. sonusai/mixture/__init__.py +15 -0
  55. sonusai/mixture/audio.py +187 -0
  56. sonusai/mixture/class_balancing.py +103 -0
  57. sonusai/mixture/constants.py +3 -0
  58. sonusai/mixture/data_io.py +173 -0
  59. sonusai/mixture/db.py +169 -0
  60. sonusai/mixture/db_datatypes.py +92 -0
  61. sonusai/mixture/effects.py +344 -0
  62. sonusai/mixture/feature.py +78 -0
  63. sonusai/mixture/generation.py +1116 -0
  64. sonusai/mixture/helpers.py +351 -0
  65. sonusai/mixture/ir_effects.py +77 -0
  66. sonusai/mixture/log_duration_and_sizes.py +23 -0
  67. sonusai/mixture/mixdb.py +1857 -0
  68. sonusai/mixture/pad_audio.py +35 -0
  69. sonusai/mixture/resample.py +7 -0
  70. sonusai/mixture/sox_effects.py +195 -0
  71. sonusai/mixture/sox_help.py +650 -0
  72. sonusai/mixture/spectral_mask.py +51 -0
  73. sonusai/mixture/truth.py +61 -0
  74. sonusai/mixture/truth_functions/__init__.py +45 -0
  75. sonusai/mixture/truth_functions/crm.py +105 -0
  76. sonusai/mixture/truth_functions/energy.py +222 -0
  77. sonusai/mixture/truth_functions/file.py +48 -0
  78. sonusai/mixture/truth_functions/metadata.py +24 -0
  79. sonusai/mixture/truth_functions/metrics.py +28 -0
  80. sonusai/mixture/truth_functions/phoneme.py +18 -0
  81. sonusai/mixture/truth_functions/sed.py +98 -0
  82. sonusai/mixture/truth_functions/target.py +142 -0
  83. sonusai/mkwav.py +135 -0
  84. sonusai/onnx_predict.py +363 -0
  85. sonusai/parse/__init__.py +0 -0
  86. sonusai/parse/expand.py +156 -0
  87. sonusai/parse/parse_source_directive.py +129 -0
  88. sonusai/parse/rand.py +214 -0
  89. sonusai/py.typed +0 -0
  90. sonusai/queries/__init__.py +0 -0
  91. sonusai/queries/queries.py +239 -0
  92. sonusai/rs.abi3.so +0 -0
  93. sonusai/rs.pyi +1 -0
  94. sonusai/rust/__init__.py +0 -0
  95. sonusai/speech/__init__.py +0 -0
  96. sonusai/speech/l2arctic.py +121 -0
  97. sonusai/speech/librispeech.py +102 -0
  98. sonusai/speech/mcgill.py +71 -0
  99. sonusai/speech/textgrid.py +89 -0
  100. sonusai/speech/timit.py +138 -0
  101. sonusai/speech/types.py +12 -0
  102. sonusai/speech/vctk.py +53 -0
  103. sonusai/speech/voxceleb.py +108 -0
  104. sonusai/utils/__init__.py +3 -0
  105. sonusai/utils/asl_p56.py +130 -0
  106. sonusai/utils/asr.py +91 -0
  107. sonusai/utils/asr_functions/__init__.py +3 -0
  108. sonusai/utils/asr_functions/aaware_whisper.py +69 -0
  109. sonusai/utils/audio_devices.py +50 -0
  110. sonusai/utils/braced_glob.py +50 -0
  111. sonusai/utils/calculate_input_shape.py +26 -0
  112. sonusai/utils/choice.py +51 -0
  113. sonusai/utils/compress.py +25 -0
  114. sonusai/utils/convert_string_to_number.py +6 -0
  115. sonusai/utils/create_timestamp.py +5 -0
  116. sonusai/utils/create_ts_name.py +14 -0
  117. sonusai/utils/dataclass_from_dict.py +27 -0
  118. sonusai/utils/db.py +16 -0
  119. sonusai/utils/docstring.py +53 -0
  120. sonusai/utils/energy_f.py +44 -0
  121. sonusai/utils/engineering_number.py +166 -0
  122. sonusai/utils/evaluate_random_rule.py +15 -0
  123. sonusai/utils/get_frames_per_batch.py +2 -0
  124. sonusai/utils/get_label_names.py +20 -0
  125. sonusai/utils/grouper.py +6 -0
  126. sonusai/utils/human_readable_size.py +7 -0
  127. sonusai/utils/keyboard_interrupt.py +12 -0
  128. sonusai/utils/load_object.py +21 -0
  129. sonusai/utils/max_text_width.py +9 -0
  130. sonusai/utils/model_utils.py +28 -0
  131. sonusai/utils/numeric_conversion.py +11 -0
  132. sonusai/utils/onnx_utils.py +155 -0
  133. sonusai/utils/parallel.py +162 -0
  134. sonusai/utils/path_info.py +7 -0
  135. sonusai/utils/print_mixture_details.py +60 -0
  136. sonusai/utils/rand.py +13 -0
  137. sonusai/utils/ranges.py +43 -0
  138. sonusai/utils/read_predict_data.py +32 -0
  139. sonusai/utils/reshape.py +154 -0
  140. sonusai/utils/seconds_to_hms.py +7 -0
  141. sonusai/utils/stacked_complex.py +82 -0
  142. sonusai/utils/stratified_shuffle_split.py +170 -0
  143. sonusai/utils/tokenized_shell_vars.py +143 -0
  144. sonusai/utils/write_audio.py +26 -0
  145. sonusai/utils/yes_or_no.py +8 -0
  146. sonusai/vars.py +47 -0
  147. sonusai-1.0.16.dist-info/METADATA +56 -0
  148. sonusai-1.0.16.dist-info/RECORD +150 -0
  149. sonusai-1.0.16.dist-info/WHEEL +4 -0
  150. sonusai-1.0.16.dist-info/entry_points.txt +3 -0
sonusai/main.py ADDED
@@ -0,0 +1,134 @@
1
+ """sonusai
2
+
3
+ usage: sonusai [--version] [--help] <command> [<args>...]
4
+
5
+ The sonusai commands are:
6
+ <This information is automatically generated.>
7
+
8
+ Aaware Sound and Voice Machine Learning Framework. See 'sonusai help <command>'
9
+ for more information on a specific command.
10
+
11
+ """
12
+
13
+ import sys
14
+ from importlib import import_module
15
+ from pkgutil import iter_modules
16
+
17
+ from docopt import docopt
18
+
19
+ from sonusai import BASEDIR
20
+ from sonusai import __version__ as sai_version
21
+ from sonusai import commands_list
22
+ from sonusai import logger
23
+ from sonusai.utils.docstring import add_commands_to_docstring
24
+ from sonusai.utils.docstring import trim_docstring
25
+
26
+
27
def discover_plugins():
    """Find installed ``sonusai_*`` plugin modules and collect their commands.

    Every top-level module visible to ``pkgutil.iter_modules()`` whose name starts
    with ``sonusai_`` (but not ``sonusai_asr_``) is imported. Each plugin must expose
    ``commands_doc`` (parsed with ``commands_list``) and ``BASEDIR``.

    A plugin that fails to import or lacks those attributes is logged and skipped,
    so one broken plugin cannot make the base sonusai commands unusable.

    Returns:
        (plugins, plugin_docstrings) where ``plugins`` maps module name to
        ``{"commands": ..., "basedir": ...}`` and ``plugin_docstrings`` lists each
        loaded plugin's ``commands_doc``.
    """
    plugins = {}
    plugin_docstrings = []
    for _, name, _ in iter_modules():
        if not name.startswith("sonusai_") or name.startswith("sonusai_asr_"):
            continue
        try:
            module = import_module(name)
            entry = {
                "commands": commands_list(module.commands_doc),
                "basedir": module.BASEDIR,
            }
            docstring = module.commands_doc
        except Exception as err:
            # Third-party plugin code is outside our control; report it and keep going.
            logger.warning(f"Skipping plugin {name}: {err}")
            continue
        plugins[name] = entry
        plugin_docstrings.append(docstring)
    return plugins, plugin_docstrings
39
+
40
+
41
def execute_command_direct(command: str, argv: list[str], basedir: str) -> None:
    """Import the command module ``command`` from ``basedir`` and run its ``main()``.

    While the command runs, ``basedir`` is at the front of ``sys.path`` and ``sys.argv``
    is ``[command, *argv]``. Both are restored afterwards; ``basedir`` is removed from
    ``sys.path`` only if this call added it.

    Logs an error and exits with status 1 if the module cannot be imported, has no
    ``main()``, or raises an ``Exception`` while importing or running. ``SystemExit``
    raised by the command itself propagates unchanged.
    """
    added_to_path = basedir not in sys.path
    if added_to_path:
        sys.path.insert(0, basedir)

    try:
        try:
            command_module = import_module(command)
        except ImportError as err:
            logger.error(f"Failed to import command module {command}: {err}")
            sys.exit(1)
        except Exception as err:
            # Other failures while importing (e.g. a SyntaxError) are reported like runtime errors.
            logger.error(f"Error executing command {command}: {err}")
            sys.exit(1)

        if not hasattr(command_module, "main"):
            logger.error(f"Command module {command} has no main() function")
            sys.exit(1)

        # The command module reads its arguments from sys.argv.
        original_argv = sys.argv
        sys.argv = [command, *argv]
        try:
            command_module.main()
        except Exception as err:
            logger.error(f"Error executing command {command}: {err}")
            sys.exit(1)
        finally:
            sys.argv = original_argv
    finally:
        # The path change is only meant to last for this command.
        if added_to_path and basedir in sys.path:
            sys.path.remove(basedir)
72
+
73
+
74
def handle_help_command_direct(argv: list[str], base_commands: list[str], plugins: dict) -> None:
    """Implement ``sonusai help [<command>]``.

    With no target, re-enters ``main()`` with ``sys.argv`` set to ``["sonusai", "-h"]``
    to print the top-level usage. Otherwise runs the target command with ``-h``.
    Base commands are resolved from ``BASEDIR`` and plugin commands from the owning
    plugin's ``basedir``. An unknown target is logged and exits with status 1.
    """
    if not argv:
        sys.argv = ["sonusai", "-h"]
        main()
        return

    target = argv[0]

    if target in base_commands:
        execute_command_direct(target, ["-h"], BASEDIR)
        return

    # First plugin (in discovery order) that advertises this command, if any.
    owner = next((info for info in plugins.values() if target in info["commands"]), None)
    if owner is not None:
        execute_command_direct(target, ["-h"], owner["basedir"])
        return

    logger.error(f"{target} is not a SonusAI command. See 'sonusai help'.")
    sys.exit(1)
94
+
95
+
96
def main() -> None:
    """Entry point for the ``sonusai`` CLI.

    Discovers plugins and extends the module usage text with their commands. The
    command line is parsed with docopt (``options_first=True``). ``help`` is handled
    by ``handle_help_command_direct``; any other command is dispatched to the base
    package or to the plugin that provides it. Unknown commands are logged and exit
    with status 1.
    """
    plugins, plugin_docstrings = discover_plugins()
    usage = trim_docstring(add_commands_to_docstring(__doc__, plugin_docstrings))
    args = docopt(usage, version=sai_version, options_first=True)

    command, argv = args["<command>"], args["<args>"]
    base_commands = commands_list()

    if command == "help":
        handle_help_command_direct(argv, base_commands, plugins)
        return

    if command in base_commands:
        basedir = BASEDIR
    else:
        owner = next((info for info in plugins.values() if command in info["commands"]), None)
        if owner is None:
            logger.error(f"{command} is not a SonusAI command. See 'sonusai help'.")
            sys.exit(1)
        basedir = owner["basedir"]

    execute_command_direct(command, argv, basedir)
124
+
125
+
126
if __name__ == "__main__":
    # Script entry: register sonusai's keyboard-interrupt handling, run the CLI, and
    # pass any uncaught Exception to sonusai's exception_handler.
    from sonusai import exception_handler
    from sonusai.utils.keyboard_interrupt import register_keyboard_interrupt

    register_keyboard_interrupt()
    try:
        main()
    except Exception as exc:
        exception_handler(exc)
@@ -0,0 +1,43 @@
1
+ # SonusAI metrics utilities for model training and validation
2
+
3
+ from .calc_audio_stats import calc_audio_stats
4
+ from .calc_class_weights import calc_class_weights_from_mixdb
5
+ from .calc_class_weights import calc_class_weights_from_truth
6
+ from .calc_optimal_thresholds import calc_optimal_thresholds
7
+ from .calc_pcm import calc_pcm
8
+ from .calc_pesq import calc_pesq
9
+ from .calc_phase_distance import calc_phase_distance
10
+ from .calc_sa_sdr import calc_sa_sdr
11
+ from .calc_sample_weights import calc_sample_weights
12
+ from .calc_segsnr_f import calc_segsnr_f
13
+ from .calc_segsnr_f import calc_segsnr_f_bin
14
+ from .calc_speech import calc_speech
15
+ from .calc_wer import calc_wer
16
+ from .calc_wsdr import calc_wsdr
17
+ from .calculate_metrics import calculate_metrics
18
+ from .class_summary import class_summary
19
+ from .confusion_matrix_summary import confusion_matrix_summary
20
+ from .one_hot import one_hot
21
+ from .snr_summary import snr_summary
22
+
23
# Public API of sonusai.metrics: re-exports every function imported above, so
# `from sonusai.metrics import *` exposes exactly these names.
__all__ = [
    "calc_audio_stats",
    "calc_class_weights_from_mixdb",
    "calc_class_weights_from_truth",
    "calc_optimal_thresholds",
    "calc_pcm",
    "calc_pesq",
    "calc_phase_distance",
    "calc_sa_sdr",
    "calc_sample_weights",
    "calc_segsnr_f",
    "calc_segsnr_f_bin",
    "calc_speech",
    "calc_wer",
    "calc_wsdr",
    "calculate_metrics",
    "class_summary",
    "confusion_matrix_summary",
    "one_hot",
    "snr_summary",
]
@@ -0,0 +1,42 @@
1
+ from ..datatypes import AudioStatsMetrics
2
+ from ..datatypes import AudioT
3
+
4
+
5
+ def _convert_str_with_factors_to_int(x: str) -> int:
6
+ if "k" in x:
7
+ return int(1000 * float(x.replace("k", "")))
8
+ if "M" in x:
9
+ return int(1000000 * float(x.replace("M", "")))
10
+ return int(x)
11
+
12
+
13
def calc_audio_stats(audio: AudioT, win_len: float | None = None) -> AudioStatsMetrics:
    """Run ``sox_stats`` on ``audio`` and parse its text report into AudioStatsMetrics.

    Each non-blank report line is split on whitespace. The last token is the value
    and the preceding tokens, joined by single spaces, form the key.

    :param audio: audio samples passed to sox_stats
    :param win_len: optional window length passed to sox_stats
    :return: AudioStatsMetrics built from the parsed fields
    :raises SystemError: if sox_stats returns None
    :raises KeyError: if an expected field is missing from the report
    """
    from ..mixture.sox_effects import sox_stats

    report = sox_stats(audio, win_len)
    if report is None:
        raise SystemError("Call to sox failed")

    stats = {}
    for row in report.split("\n"):
        tokens = row.split()
        if tokens:
            *label, value = tokens
            stats[" ".join(label)] = value

    return AudioStatsMetrics(
        dco=float(stats["DC offset"]),
        min=float(stats["Min level"]),
        max=float(stats["Max level"]),
        pkdb=float(stats["Pk lev dB"]),
        lrms=float(stats["RMS lev dB"]),
        pkr=float(stats["RMS Pk dB"]),
        tr=float(stats["RMS Tr dB"]),
        cr=float(stats["Crest factor"]),
        fl=float(stats["Flat factor"]),
        pkc=_convert_str_with_factors_to_int(stats["Pk count"]),
    )
@@ -0,0 +1,90 @@
1
+ import numpy as np
2
+
3
+ from ..datatypes import GeneralizedIDs
4
+ from ..datatypes import Truth
5
+ from ..mixture.mixdb import MixtureDatabase
6
+
7
+
8
+ def calc_class_weights_from_truth(truth: Truth, other_weight: float | None = None, other_index: int = -1) -> np.ndarray:
9
+ """Calculate class weights.
10
+
11
+ Supports non-existent classes (a problem with sklearn) where non-existent
12
+ classes get a weight of 0 (instead of inf).
13
+ Includes optional weighting of an "other" class if specified.
14
+
15
+ Reference:
16
+ weights = class_weight.compute_class_weight(class_weight='balanced', classes=clabels, y=labels)
17
+
18
+ Arguments:
19
+ truth: Truth data in one-hot format. Size can be:
20
+ - [frames, timesteps, num_classes]
21
+ - [frames, num_classes]
22
+ other_weight: float or `None`. Weight of the "other" class.
23
+ > 1 = increase weighting/importance relative to the true count
24
+ 0 > `other_weight` < 1 = decrease weighting/importance relative
25
+ < 0 or `None` = disable, use true count (default = `None`)
26
+ other_index: int. Index of the "other" class in one-hot mode. Defaults to -1 (the last).
27
+
28
+ Returns:
29
+ A numpy array containing class weights.
30
+ """
31
+ frames, num_classes = truth.shape
32
+
33
+ if num_classes > 1:
34
+ labels = np.argmax(truth, axis=-1) # [frames, 1 labels] from one-hot, last dim
35
+ count = np.bincount(labels, minlength=num_classes).astype(float)
36
+ else:
37
+ num_classes = 2
38
+ labels = np.array(truth >= 0.5).astype(np.int8)[:, 0] # quantize to binary and shape (frames,) for bincount
39
+ count = np.bincount(labels, minlength=num_classes).astype(float)
40
+
41
+ if other_weight is not None and other_weight > 0:
42
+ count[other_index] = count[other_index] / np.float32(other_weight)
43
+
44
+ weights = np.empty((len(count)), dtype=np.float32)
45
+ for n in range(len(weights)):
46
+ if count[n] == 0:
47
+ # Avoid sklearn problem with absent classes and assign non-existent classes a weight of 0.
48
+ weights[n] = 0
49
+ else:
50
+ weights[n] = frames / (num_classes * count[n])
51
+
52
+ return weights
53
+
54
+
55
def calc_class_weights_from_mixdb(
    mixdb: MixtureDatabase,
    mixids: GeneralizedIDs = "*",
    other_weight: float | None = 1,
    other_index: int = -1,
) -> tuple[np.ndarray, np.ndarray]:
    """Calculate class weights using estimated feature counts from a mixture database.

    Per-class counts from ``get_class_count_from_mixids`` are divided by
    ``mixdb.feature_step_samples`` and rounded up to estimate feature counts.
    Weights use the "balanced" formula ``total_features / (num_classes * count)``.
    Absent classes get weight 0.

    Arguments:
        mixdb: Mixture database.
        mixids: Mixture ID's.
        other_weight: float or `None`. Weight of the "other" class.
                      > 1 = increase weighting/importance relative to the true count
                      0 < `other_weight` < 1 = decrease weighting/importance relative
                      <= 0 or `None` = disable, use true count (the default of 1 is also a no-op)
        other_index: int. Index of the "other" class in one-hot mode. Defaults to -1 (the last).

    Returns:
        count: Estimated count of features in each class (not adjusted by other_weight).
        weights: Float32 class weights, shape [num_classes].
                 Note: for Keras use dict(enumerate(weights))
    """
    # NOTE(review): sonusai/mixture/class_count.py does not appear in this wheel's file
    # list, so this import may fail at runtime. Confirm where get_class_count_from_mixids lives.
    from ..mixture.class_count import get_class_count_from_mixids

    count = np.ceil(np.array(get_class_count_from_mixids(mixdb=mixdb, mixids=mixids)) / mixdb.feature_step_samples)
    total_features = np.sum(count)

    # Apply the "other" re-weighting to a copy so the returned counts remain the estimates.
    weighted_count = np.array(count, dtype=float)
    if other_weight is not None and other_weight > 0:
        weighted_count[other_index] = weighted_count[other_index] / other_weight

    weights = np.empty(mixdb.num_classes, dtype=np.float32)
    for n in range(len(weights)):
        if weighted_count[n] == 0:
            # Avoid sklearn problem with absent classes and assign non-existent classes a weight of 0.
            weights[n] = 0
        else:
            weights[n] = total_features / (mixdb.num_classes * weighted_count[n])

    return count, weights
@@ -0,0 +1,73 @@
1
+ import numpy as np
2
+
3
+ from ..datatypes import Predict
4
+ from ..datatypes import Truth
5
+
6
+
7
def calc_optimal_thresholds(
    truth: Truth,
    predict: Predict,
    timesteps: int = 0,
    truth_thr: float = 0.5,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Calculate optimal per-class thresholds from one-hot prediction and truth data.

    Both inputs are one-hot probabilities (or quantized decisions) with size
    [frames, num_classes] or [frames, timesteps, num_classes] and must have the same shape.

    Arguments:
        truth: Truth data. Binarized with ``truth >= truth_thr``, so it may be continuous.
        predict: Prediction probabilities.
        timesteps: Passed to ``reshape_outputs`` and ``get_num_classes_from_predict``.
        truth_thr: Decision threshold applied to truth. Default is 0.5.

    Returns:
        thresholds_opt_pr  [num_classes, 1] optimal thresholds for PR-curve (F1) performance
        thresholds_opt_roc [num_classes, 1] optimal thresholds for ROC-curve (gmean of TPR and 1-FPR)
        AP  [num_classes, 1] average precision
        AUC [num_classes, 1] ROC area under curve
        A class with no active truth gets NaN in all four outputs. A class whose truth is
        always active gets NaN for AUC and the ROC threshold, because ROC needs both classes.

    Raises:
        ValueError: if truth and predict are not the same shape.
    """
    from sklearn.metrics import average_precision_score
    from sklearn.metrics import precision_recall_curve
    from sklearn.metrics import roc_auc_score
    from sklearn.metrics import roc_curve

    from ..utils.reshape import get_num_classes_from_predict
    from ..utils.reshape import reshape_outputs

    if truth.shape != predict.shape:
        raise ValueError("truth and predict are not the same shape")

    predict, truth = reshape_outputs(predict=predict, truth=truth, timesteps=timesteps)  # type: ignore[assignment]
    num_classes = get_num_classes_from_predict(predict=predict, timesteps=timesteps)

    # Apply decision to truth input
    truth_binary = np.array(truth >= truth_thr).astype(np.int8)

    AP = np.zeros((num_classes, 1))
    AUC = np.zeros((num_classes, 1))
    thresholds_opt_pr = np.zeros((num_classes, 1))
    thresholds_opt_roc = np.zeros((num_classes, 1))
    eps = np.finfo(float).eps
    for nci in range(num_classes):
        y_true = truth_binary[:, nci]
        y_score = predict[:, nci]

        # sklearn returns nan if a class has no active truth, but emits an un-suppressible
        # div-by-zero warning. np.any also avoids builtin sum() wrapping an int8 accumulator.
        if not np.any(y_true):
            thresholds_opt_pr[nci] = np.nan
            thresholds_opt_roc[nci] = np.nan
            AUC[nci] = np.nan
            AP[nci] = np.nan
            continue

        # Average Precision (area under the PR curve) from binary truth and continuous predictions
        AP[nci] = average_precision_score(y_true, y_score, average=None)  # pyright: ignore [reportArgumentType]

        # Optimal threshold from PR curve, maximizing F1. The final (precision=1, recall=0)
        # point has no threshold, so exclude it to keep ix aligned with thrpr.
        precision, recall, thrpr = precision_recall_curve(y_true, y_score)
        fscore = (2 * precision * recall) / (precision + recall + eps)
        ix = np.argmax(fscore[:-1])
        thresholds_opt_pr[nci] = thrpr[ix]

        if np.all(y_true):
            # ROC is undefined with a single class present; roc_auc_score would raise.
            AUC[nci] = np.nan
            thresholds_opt_roc[nci] = np.nan
            continue

        AUC[nci] = roc_auc_score(y_true, y_score, average=None)  # pyright: ignore [reportArgumentType]

        # Optimal threshold from ROC curve using gmean (J = tpr - fpr can give thr > 1)
        fpr, tpr, thrroc = roc_curve(y_true, y_score)
        gmeans = np.sqrt(tpr * (1 - fpr))
        thresholds_opt_roc[nci] = thrroc[np.argmax(gmeans)]

    return thresholds_opt_pr, thresholds_opt_roc, AP, AUC
@@ -0,0 +1,45 @@
1
+ import numpy as np
2
+
3
+
4
def calc_pcm(
    hypothesis: np.ndarray, reference: np.ndarray, with_log: bool = False
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Calculate phase constrained magnitude (PCM) error.

    The sources should include a noise estimate so they form a complete mixture, i.e.
        noise_est = mixture - sum-over-nsrc(s_est(:, nsrc, :))
    should be one of the sources in both inputs.

    Per element the "magnitude" is |Re| + |Im|. The error is
        mean-over-srcs(mean-over-tf(| (|Sr|+|Si|) - (|Shr|+|Shi|) |))
    (LPCM = 1/2 * LSM(s, sh) + 1/2 * LSM(n, nh)).

    Reference:
        Self-attending RNN for Speech Enhancement to Improve Cross-corpus Generalization
        Ashutosh Pandey, Student Member, IEEE and DeLiang Wang, Fellow, IEEE
        https://doi.org/10.48550/arXiv.2105.12831

    :param hypothesis: complex [frames, nsrc, bins]
    :param reference: complex [frames, nsrc, bins]
    :param with_log: if True, convert each result to 20*log10(x + float32 eps), rounded to 3 decimals
    :return: (error, error per bin, error per frame)
    """

    def _re_im_magnitude(z: np.ndarray) -> np.ndarray:
        return np.abs(np.real(z)) + np.abs(np.imag(z))

    # [frames, nsrc, bins]
    abs_err = np.abs(_re_im_magnitude(reference) - _re_im_magnitude(hypothesis))

    per_bin = np.mean(np.mean(abs_err, axis=0), axis=0)  # frames, then sources -> [bins]
    per_frame = np.mean(np.mean(abs_err, axis=2), axis=1)  # bins, then sources -> [frames]
    total = np.mean(np.mean(abs_err, axis=(0, 2)), axis=0)  # frames+bins, then sources -> scalar

    if with_log:
        floor = np.finfo(np.float32).eps
        per_bin, per_frame, total = (np.around(20 * np.log10(v + floor), 3) for v in (per_bin, per_frame, total))

    return total, per_bin, per_frame
@@ -0,0 +1,36 @@
1
+ import numpy as np
2
+
3
+ from ..constants import SAMPLE_RATE
4
+
5
+
6
def calc_pesq(
    hypothesis: np.ndarray,
    reference: np.ndarray,
    error_value: float = 0.0,
    sample_rate: int = SAMPLE_RATE,
) -> float:
    """Compute the wide-band PESQ score of ``hypothesis`` against ``reference``.

    Warnings from the pesq package are suppressed. Any exception raised by
    ``pesq.pesq`` is logged at debug level, and ``error_value`` is returned instead.

    :param hypothesis: estimated (degraded) audio
    :param reference: reference audio
    :param error_value: value returned if the PESQ computation fails (default 0.0)
    :param sample_rate: sample rate of both signals, passed as ``fs``
    :return: PESQ score from ``pesq.pesq(..., mode="wb")`` or ``error_value``
    """
    # NOTE(review): wide-band PESQ is normally defined for 16 kHz input; other
    # sample rates likely take the error path. Confirm.
    import warnings

    from pesq import pesq

    from .. import logger

    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return pesq(fs=sample_rate, ref=reference, deg=hypothesis, mode="wb")
    except Exception as err:
        logger.debug(f"PESQ error {err}")
        return error_value
@@ -0,0 +1,43 @@
1
+ import numpy as np
2
+
3
+
4
def calc_phase_distance(
    reference: np.ndarray, hypothesis: np.ndarray, eps: float = 1e-9
) -> tuple[float, np.ndarray, np.ndarray]:
    """Calculate magnitude-weighted phase distance error in degrees.

    The element-wise phase difference ``angle(reference) - angle(hypothesis)`` is
    wrapped to [-180, 180) degrees. It is then averaged with weights proportional to
    ``|reference|``, normalized:
      - over all frames and bins (scalar result),
      - over frames within each bin (per-bin result),
      - over bins within each frame (per-frame result).
    ``eps`` is added to each weight normalizer to avoid division by zero. All results
    are rounded to 3 decimals.

    :param reference: complex [frames, bins]
    :param hypothesis: complex [frames, bins]
    :param eps: epsilon added to the weight normalizers
    :return: (mean, mean per bin [bins], mean per frame [frames])
    """
    ang_diff = np.angle(reference) - np.angle(hypothesis)
    phd_mod = (ang_diff + np.pi) % (2 * np.pi) - np.pi  # wrap to [-pi, pi)
    rh_angle_diff = phd_mod * 180 / np.pi  # angle diff in deg

    reference_mag = np.abs(reference)

    # weighted mean over all (scalar)
    weight_all = reference_mag / (np.sum(reference_mag) + eps)
    err = float(np.around(np.sum(weight_all * rh_angle_diff), 3))

    # weighted mean over frames (value per bin)
    weight_b = reference_mag / (np.sum(reference_mag, axis=0, keepdims=True) + eps)
    err_b = np.around(np.sum(weight_b * rh_angle_diff, axis=0), 3).astype(np.float64)

    # weighted mean over bins (value per frame)
    weight_f = reference_mag / (np.sum(reference_mag, axis=1, keepdims=True) + eps)
    err_f = np.around(np.sum(weight_f * rh_angle_diff, axis=1), 3).astype(np.float64)

    return err, err_b, err_f
@@ -0,0 +1,64 @@
1
+ import numpy as np
2
+
3
+
4
def calc_sa_sdr(
    hypothesis: np.ndarray,
    reference: np.ndarray,
    with_scale: bool = False,
    with_negate: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
    """Calculate source-aggregated SDR (signal distortion ratio) using all source inputs which are [samples, nsrc].

    These should include a noise to be a complete mixture estimate, i.e.,
        noise_est = sum-over-all-srcs(s_est(0:nsamples, :) - sum-over-non-noisesrc(s_est(0:nsamples, n))
    should be one of the sources in reference (s_true) and hypothesis (s_est).

    Calculates 10*log10(sumn(||a_n*sn||^2) / sumn(||a_n*sn - shn||^2)), where a_n = 1
    unless ``with_scale`` is set.
    Note: for the SA method, sums are done independently on ref and error before division,
    unlike SDR and SI-SDR, where the sum over n is taken after the divide (before the log).
    This is more stable in noise-only cases and when some sources are poorly estimated.
    With ``with_scale`` the numerator uses the scaled reference, as in SI-SDR, so the
    result does not change if the hypothesis is rescaled.
    TBD: add soft-max option with eps and tau params

    Reference:
        SA-SDR: A Novel Loss Function for Separation of Meeting Style Data
        Thilo von Neumann, Keisuke Kinoshita, Christoph Boeddeker, Marc Delcroix, Reinhold Haeb-Umbach
        https://doi.org/10.48550/arXiv.2110.15581

    :param hypothesis: [samples, nsrc]
    :param reference: [samples, nsrc]
    :param with_scale: enable scaling (scaling is same as in SI-SDR)
    :param with_negate: enable negation (for use as a loss function)
    :return: (sa_sdr, opt_scale) where opt_scale is [1, nsrc]
    """
    if with_scale:
        # Per-source least-squares scale of the reference toward the hypothesis (1 x nsrc).
        ref_energy = np.sum(reference**2, axis=0, keepdims=True)
        with np.errstate(divide="ignore", invalid="ignore"):
            opt_scale = np.sum(reference * hypothesis, axis=0, keepdims=True) / ref_energy
        # A zero-energy reference source gives inf or NaN; fall back to a scale of 1.0.
        opt_scale[opt_scale == np.inf] = 1.0
        opt_scale = np.nan_to_num(opt_scale, nan=1.0)
        scaled_ref = opt_scale * reference
    else:
        scaled_ref = reference
        opt_scale = np.ones((1, reference.shape[1]), dtype=float)

    # multisrc sa-sdr, inputs must be [samples, nsrc]
    err = scaled_ref - hypothesis

    # Sum over samples and sources. The signal energy must come from the same (scaled)
    # target used in the error term, otherwise scaling the hypothesis changes the result.
    num = np.sum(scaled_ref**2)
    den = np.sum(err**2)
    if num == 0 and den == 0:
        ratio = np.inf
    else:
        ratio = num / (den + np.finfo(np.float32).eps)

    sa_sdr = 10 * np.log10(ratio)

    if with_negate:
        # for use as a loss function
        sa_sdr = -sa_sdr

    return sa_sdr, opt_scale
@@ -0,0 +1,25 @@
1
+ import numpy as np
2
+
3
+
4
def calc_sample_weights(class_weights: np.ndarray, truth: np.ndarray) -> np.ndarray:
    """Calculate sample weights from class weights and a given truth with 2D or 3D shape.

    Supports one-hot encoded multi-class or binary truth/labels.
    Note: returns the sum of weighted truth over classes, so it may also work for multi-label (TBD).

    Inputs:
        class_weights [num_classes] or [num_classes, 1] weights for each class
        truth         [frames, timesteps, num_classes] or [frames, num_classes]

    For binary truth (last dimension 1) with 2 class weights, a second column
    ``1 - truth`` is appended along the class axis before weighting.

    Returns:
        sample_weights [frames, timesteps] or [frames] (the class axis is summed out)
    """
    # Flatten to [num_classes] so it broadcasts along the last (class) axis;
    # a [num_classes, 1] column would otherwise broadcast against the frames axis.
    class_weights = np.asarray(class_weights).reshape(-1)

    if truth.shape[-1] == 1 and class_weights.shape[0] == 2:
        # Binary truth needs a 2nd "none" truth dimension, appended on the class axis
        # (axis=-1, so 3-D [frames, timesteps, 1] truth is handled too).
        # NOTE(review): this pairs class_weights[0] with active truth, while
        # calc_class_weights_from_truth's binary mode labels active frames as index 1.
        # Confirm the intended ordering.
        truth = np.concatenate((truth, 1 - truth), axis=-1)

    return np.sum(class_weights * truth, axis=-1)
@@ -0,0 +1,82 @@
1
+ import numpy as np
2
+
3
+ from ..datatypes import AudioF
4
+ from ..datatypes import Segsnr
5
+ from ..datatypes import SnrFBinMetrics
6
+ from ..datatypes import SnrFMetrics
7
+
8
+
9
def calc_segsnr_f(segsnr_f: Segsnr) -> SnrFMetrics:
    """Summarize segmental SNR data.

    Returns SnrFMetrics(mean, std, db_mean, db_std). The first two are over the
    linear values, with non-finite entries masked out. The last two are over
    10*log10 of those values, where zeros are also masked out by the masked log.

    Special cases:
      - every entry zero: (0, 0, -inf, 0)
      - masked-count check says all entries are non-finite: (inf, 0, inf, 0)
      - every dB entry masked: dB mean -inf and dB std inf
    """
    if not np.count_nonzero(segsnr_f):
        return SnrFMetrics(0, 0, -np.inf, 0)

    # NOTE(review): the masked-count checks compare against the size of axis 0, which
    # equals the total size only for 1-D input. Presumably Segsnr is 1-D; confirm.
    linear = np.ma.array(segsnr_f, mask=np.logical_not(np.isfinite(segsnr_f)))
    if np.ma.count_masked(linear) == np.ma.size(linear, axis=0):
        return SnrFMetrics(np.inf, 0, np.inf, 0)

    lin_mean = np.mean(linear, axis=0)
    lin_std = np.std(linear, axis=0)

    in_db = 10 * np.ma.log10(linear)
    if np.ma.count_masked(in_db) == np.ma.size(in_db, axis=0):
        db_mean, db_std = -np.inf, np.inf
    else:
        db_mean, db_std = np.mean(in_db, axis=0), np.std(in_db, axis=0)

    return SnrFMetrics(lin_mean, lin_std, db_mean, db_std)
37
+
38
+
39
def calc_segsnr_f_bin(target_f: AudioF, noise_f: AudioF) -> SnrFBinMetrics:
    """Calculate per-bin segmental SNR metrics.

    The element-wise SNR is |target_f|^2 / (|noise_f|^2 + float32 eps). For each bin
    (column), this returns the mean and standard deviation over frames of the linear
    SNR (non-finite values masked out) and of 10*log10(SNR) (zeros also masked out).

    Special cases (all per-bin arrays of length ``bins``):
      - every entry zero: mean 0, std 0, dB mean -inf, dB std 0
      - every entry non-finite: mean inf, std 0, dB mean inf, dB std 0
      - every dB entry masked: dB mean -inf, dB std inf
    A single bin whose values are all masked gets the same fill values
    (linear: inf/0, dB: -inf/inf) rather than undefined masked data.

    :param target_f: target spectrum, 2-D [frames, bins]
    :param noise_f: noise spectrum, 2-D [frames, bins]
    :return: SnrFBinMetrics(mean, std, db_mean, db_std)
    :raises ValueError: if either input is not 2-D
    """
    if target_f.ndim != 2 or noise_f.ndim != 2:
        raise ValueError("target_f and noise_f must have 2 dimensions")

    segsnr_f = (np.abs(target_f) ** 2) / (np.abs(noise_f) ** 2 + np.finfo(np.float32).eps)

    bins = segsnr_f.shape[1]
    if np.count_nonzero(segsnr_f) == 0:
        # If all entries are zeros
        return SnrFBinMetrics(np.zeros(bins), np.zeros(bins), -np.inf * np.ones(bins), np.zeros(bins))

    linear = np.ma.array(segsnr_f, mask=np.logical_not(np.isfinite(segsnr_f)))
    # Compare against the total element count, not the number of frames.
    if np.ma.count_masked(linear) == linear.size:
        # If all entries are infinite
        return SnrFBinMetrics(
            np.inf * np.ones(bins),
            np.zeros(bins),
            np.inf * np.ones(bins),
            np.zeros(bins),
        )

    snr_mean = np.ma.filled(np.mean(linear, axis=0), np.inf)
    snr_std = np.ma.filled(np.std(linear, axis=0), 0.0)

    in_db = 10 * np.ma.log10(linear)
    if np.ma.count_masked(in_db) == in_db.size:
        # If all entries are masked, special case where all inputs are either 0 or infinite
        snr_db_mean = -np.inf * np.ones(bins)
        snr_db_std = np.inf * np.ones(bins)
    else:
        snr_db_mean = np.ma.filled(np.mean(in_db, axis=0), -np.inf)
        snr_db_std = np.ma.filled(np.std(in_db, axis=0), np.inf)

    return SnrFBinMetrics(
        np.ma.getdata(snr_mean),
        np.ma.getdata(snr_std),
        np.ma.getdata(snr_db_mean),
        np.ma.getdata(snr_db_std),
    )
+ )