sonusai 0.20.2__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. sonusai/__init__.py +16 -3
  2. sonusai/audiofe.py +240 -76
  3. sonusai/calc_metric_spenh.py +71 -73
  4. sonusai/config/__init__.py +3 -0
  5. sonusai/config/config.py +61 -0
  6. sonusai/config/config.yml +20 -0
  7. sonusai/config/constants.py +8 -0
  8. sonusai/constants.py +11 -0
  9. sonusai/data/genmixdb.yml +21 -36
  10. sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
  11. sonusai/deprecated/plot.py +4 -5
  12. sonusai/doc/doc.py +4 -4
  13. sonusai/doc.py +11 -4
  14. sonusai/genft.py +43 -45
  15. sonusai/genmetrics.py +23 -19
  16. sonusai/genmix.py +54 -82
  17. sonusai/genmixdb.py +88 -264
  18. sonusai/ir_metric.py +30 -34
  19. sonusai/lsdb.py +41 -48
  20. sonusai/main.py +15 -22
  21. sonusai/metrics/calc_audio_stats.py +4 -17
  22. sonusai/metrics/calc_class_weights.py +4 -4
  23. sonusai/metrics/calc_optimal_thresholds.py +8 -5
  24. sonusai/metrics/calc_pesq.py +2 -2
  25. sonusai/metrics/calc_segsnr_f.py +4 -4
  26. sonusai/metrics/calc_speech.py +25 -13
  27. sonusai/metrics/class_summary.py +7 -7
  28. sonusai/metrics/confusion_matrix_summary.py +5 -5
  29. sonusai/metrics/one_hot.py +4 -4
  30. sonusai/metrics/snr_summary.py +7 -7
  31. sonusai/metrics_summary.py +38 -45
  32. sonusai/mixture/__init__.py +5 -104
  33. sonusai/mixture/audio.py +10 -39
  34. sonusai/mixture/class_balancing.py +103 -0
  35. sonusai/mixture/config.py +251 -271
  36. sonusai/mixture/constants.py +35 -39
  37. sonusai/mixture/data_io.py +25 -36
  38. sonusai/mixture/db_datatypes.py +58 -22
  39. sonusai/mixture/effects.py +386 -0
  40. sonusai/mixture/feature.py +7 -11
  41. sonusai/mixture/generation.py +484 -611
  42. sonusai/mixture/helpers.py +82 -184
  43. sonusai/mixture/ir_delay.py +3 -4
  44. sonusai/mixture/ir_effects.py +77 -0
  45. sonusai/mixture/log_duration_and_sizes.py +6 -12
  46. sonusai/mixture/mixdb.py +931 -669
  47. sonusai/mixture/pad_audio.py +35 -0
  48. sonusai/mixture/resample.py +7 -0
  49. sonusai/mixture/sox_effects.py +195 -0
  50. sonusai/mixture/sox_help.py +650 -0
  51. sonusai/mixture/spectral_mask.py +2 -2
  52. sonusai/mixture/truth.py +17 -15
  53. sonusai/mixture/truth_functions/crm.py +12 -12
  54. sonusai/mixture/truth_functions/energy.py +22 -22
  55. sonusai/mixture/truth_functions/file.py +5 -5
  56. sonusai/mixture/truth_functions/metadata.py +4 -4
  57. sonusai/mixture/truth_functions/metrics.py +4 -4
  58. sonusai/mixture/truth_functions/phoneme.py +3 -3
  59. sonusai/mixture/truth_functions/sed.py +11 -13
  60. sonusai/mixture/truth_functions/target.py +10 -10
  61. sonusai/mkwav.py +26 -29
  62. sonusai/onnx_predict.py +240 -88
  63. sonusai/queries/__init__.py +2 -2
  64. sonusai/queries/queries.py +38 -34
  65. sonusai/speech/librispeech.py +1 -1
  66. sonusai/speech/mcgill.py +1 -1
  67. sonusai/speech/timit.py +2 -2
  68. sonusai/summarize_metric_spenh.py +10 -17
  69. sonusai/utils/__init__.py +7 -1
  70. sonusai/utils/asl_p56.py +2 -2
  71. sonusai/utils/asr.py +2 -2
  72. sonusai/utils/asr_functions/aaware_whisper.py +4 -5
  73. sonusai/utils/choice.py +31 -0
  74. sonusai/utils/compress.py +1 -1
  75. sonusai/utils/dataclass_from_dict.py +19 -1
  76. sonusai/utils/energy_f.py +3 -3
  77. sonusai/utils/evaluate_random_rule.py +15 -0
  78. sonusai/utils/keyboard_interrupt.py +12 -0
  79. sonusai/utils/onnx_utils.py +3 -17
  80. sonusai/utils/print_mixture_details.py +21 -19
  81. sonusai/utils/{temp_seed.py → rand.py} +3 -3
  82. sonusai/utils/read_predict_data.py +2 -2
  83. sonusai/utils/reshape.py +3 -3
  84. sonusai/utils/stratified_shuffle_split.py +3 -3
  85. sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
  86. sonusai/utils/write_audio.py +2 -2
  87. sonusai/vars.py +11 -4
  88. {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/METADATA +4 -2
  89. sonusai-1.0.1.dist-info/RECORD +138 -0
  90. sonusai/mixture/augmentation.py +0 -444
  91. sonusai/mixture/class_count.py +0 -15
  92. sonusai/mixture/eq_rule_is_valid.py +0 -45
  93. sonusai/mixture/target_class_balancing.py +0 -107
  94. sonusai/mixture/targets.py +0 -175
  95. sonusai-0.20.2.dist-info/RECORD +0 -128
  96. {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/WHEEL +0 -0
  97. {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/entry_points.txt +0 -0
sonusai/lsdb.py CHANGED
@@ -1,12 +1,12 @@
1
1
  """sonusai lsdb
2
2
 
3
- usage: lsdb [-hta] [-i MIXID] [-c CID] LOC
3
+ usage: lsdb [-hsa] [-i MIXID] [-c CID] LOC
4
4
 
5
5
  Options:
6
6
  -h, --help
7
7
  -i MIXID, --mixid MIXID Mixture ID(s) to analyze. [default: *].
8
8
  -c CID, --class_index CID Analyze mixtures that contain this class index.
9
- -t, --targets List all target files.
9
+ -s, --sources List all source files.
10
10
  -a, --all_class_counts List all class counts.
11
11
 
12
12
  List mixture data information from a SonusAI mixture database.
@@ -16,25 +16,10 @@ Inputs:
16
16
 
17
17
  """
18
18
 
19
- import signal
20
-
21
- from sonusai import logger
22
- from sonusai.mixture import GeneralizedIDs
19
+ from sonusai.datatypes import GeneralizedIDs
23
20
  from sonusai.mixture import MixtureDatabase
24
21
 
25
22
 
26
- def signal_handler(_sig, _frame):
27
- import sys
28
-
29
- from sonusai import logger
30
-
31
- logger.info("Canceled due to keyboard interrupt")
32
- sys.exit(1)
33
-
34
-
35
- signal.signal(signal.SIGINT, signal_handler)
36
-
37
-
38
23
  def lsdb(
39
24
  mixdb: MixtureDatabase,
40
25
  mixids: GeneralizedIDs = "*",
@@ -42,7 +27,8 @@ def lsdb(
42
27
  list_targets: bool = False,
43
28
  all_class_counts: bool = False,
44
29
  ) -> None:
45
- from sonusai.mixture import SAMPLE_RATE
30
+ from sonusai import logger
31
+ from sonusai.constants import SAMPLE_RATE
46
32
  from sonusai.queries import get_mixids_from_class_indices
47
33
  from sonusai.utils import consolidate_range
48
34
  from sonusai.utils import max_text_width
@@ -54,38 +40,36 @@ def lsdb(
54
40
  total_samples = mixdb.total_samples()
55
41
  total_duration = total_samples / SAMPLE_RATE
56
42
 
57
- logger.info(f'{"Mixtures":{desc_len}} {mixdb.num_mixtures}')
58
- logger.info(f'{"Duration":{desc_len}} {seconds_to_hms(seconds=total_duration)}')
59
- logger.info(f'{"Targets":{desc_len}} {mixdb.num_target_files}')
60
- logger.info(f'{"Noises":{desc_len}} {mixdb.num_noise_files}')
61
- logger.info(f'{"Feature":{desc_len}} {mixdb.feature}')
43
+ logger.info(f"{'Mixtures':{desc_len}} {mixdb.num_mixtures}")
44
+ logger.info(f"{'Duration':{desc_len}} {seconds_to_hms(seconds=total_duration)}")
45
+ logger.info(f"{'Sources':{desc_len}} {mixdb.num_source_files}")
46
+ logger.info(f"{'Feature':{desc_len}} {mixdb.feature}")
62
47
  logger.info(
63
- f'{"Feature shape":{desc_len}} {mixdb.fg_stride} x {mixdb.feature_parameters} '
64
- f'({mixdb.fg_stride * mixdb.feature_parameters} total params)'
48
+ f"{'Feature shape':{desc_len}} {mixdb.fg_stride} x {mixdb.feature_parameters} "
49
+ f"({mixdb.fg_stride * mixdb.feature_parameters} total params)"
65
50
  )
66
- logger.info(f'{"Feature samples":{desc_len}} {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)')
51
+ logger.info(f"{'Feature samples':{desc_len}} {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)")
67
52
  logger.info(
68
- f'{"Feature step samples":{desc_len}} {mixdb.feature_step_samples} samples ' f'({mixdb.feature_step_ms} ms)'
53
+ f"{'Feature step samples':{desc_len}} {mixdb.feature_step_samples} samples ({mixdb.feature_step_ms} ms)"
69
54
  )
70
- logger.info(f'{"Feature overlap":{desc_len}} {mixdb.fg_step / mixdb.fg_stride} ({mixdb.feature_step_ms} ms)')
71
- logger.info(f'{"SNRs":{desc_len}} {mixdb.snrs}')
72
- logger.info(f'{"Random SNRs":{desc_len}} {mixdb.random_snrs}')
73
- logger.info(f'{"Classes":{desc_len}} {mixdb.num_classes}')
55
+ logger.info(f"{'Feature overlap':{desc_len}} {mixdb.fg_step / mixdb.fg_stride} ({mixdb.feature_step_ms} ms)")
56
+ logger.info(f"{'SNRs':{desc_len}} {mixdb.snrs}")
57
+ logger.info(f"{'Random SNRs':{desc_len}} {mixdb.random_snrs}")
58
+ logger.info(f"{'Classes':{desc_len}} {mixdb.num_classes}")
74
59
  # TODO: fix class count
75
- logger.info(f'{"Class count":{desc_len}} not supported')
60
+ logger.info(f"{'Class count':{desc_len}} not supported")
76
61
  # print_class_count(class_count=class_count, length=desc_len, print_fn=logger.info)
77
62
  # TODO: add class weight calculations here
78
63
  logger.info("")
79
64
 
80
65
  if list_targets:
81
- logger.info("Target details:")
82
- idx_len = max_text_width(mixdb.num_target_files)
83
- for idx, target in enumerate(mixdb.target_files):
84
- desc = f" {idx:{idx_len}} Name"
85
- logger.info(f"{desc:{desc_len}} {target.name}")
86
- desc = f" {idx:{idx_len}} Truth index"
87
- logger.info(f"{desc:{desc_len}} {target.class_indices}")
88
- logger.info("")
66
+ logger.info("Source details:")
67
+ for category, sources in mixdb.source_files.items():
68
+ print(f" {category}:")
69
+ for source in sources:
70
+ logger.info(f"{' Name':{desc_len}} {source.name}")
71
+ logger.info(f"{' Truth index':{desc_len}} {source.class_indices}")
72
+ logger.info("")
89
73
 
90
74
  if class_index is not None:
91
75
  if 0 <= class_index > mixdb.num_classes:
@@ -104,7 +88,7 @@ def lsdb(
104
88
  # print_class_count(class_count=class_count, length=desc_len, print_fn=logger.info, all_class_counts=True)
105
89
  else:
106
90
  logger.info(
107
- f"Calculating statistics from truth_f files for {len(mixids):,} mixtures" f" ({consolidate_range(mixids)})"
91
+ f"Calculating statistics from truth_f files for {len(mixids):,} mixtures ({consolidate_range(mixids)})"
108
92
  )
109
93
  logger.info("Not supported")
110
94
 
@@ -112,13 +96,10 @@ def lsdb(
112
96
  def main() -> None:
113
97
  from docopt import docopt
114
98
 
115
- import sonusai
116
- from sonusai import create_file_handler
117
- from sonusai import initial_log_messages
118
- from sonusai import update_console_handler
99
+ from sonusai import __version__ as sai_version
119
100
  from sonusai.utils import trim_docstring
120
101
 
121
- args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
102
+ args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)
122
103
 
123
104
  mixid = args["--mixid"]
124
105
  class_index = args["--class_index"]
@@ -126,6 +107,11 @@ def main() -> None:
126
107
  all_class_counts = args["--all_class_counts"]
127
108
  location = args["LOC"]
128
109
 
110
+ from sonusai import create_file_handler
111
+ from sonusai import initial_log_messages
112
+ from sonusai import logger
113
+ from sonusai import update_console_handler
114
+
129
115
  if class_index is not None:
130
116
  class_index = int(class_index)
131
117
 
@@ -146,4 +132,11 @@ def main() -> None:
146
132
 
147
133
 
148
134
  if __name__ == "__main__":
149
- main()
135
+ from sonusai import exception_handler
136
+ from sonusai.utils import register_keyboard_interrupt
137
+
138
+ register_keyboard_interrupt()
139
+ try:
140
+ main()
141
+ except Exception as e:
142
+ exception_handler(e)
sonusai/main.py CHANGED
@@ -10,21 +10,6 @@ for more information on a specific command.
10
10
 
11
11
  """
12
12
 
13
- import signal
14
-
15
-
16
- def signal_handler(_sig, _frame):
17
- import sys
18
-
19
- from sonusai import logger
20
-
21
- logger.info("Canceled due to keyboard interrupt")
22
- sys.exit(1)
23
-
24
-
25
- signal.signal(signal.SIGINT, signal_handler)
26
-
27
-
28
13
  def main() -> None:
29
14
  from importlib import import_module
30
15
  from pkgutil import iter_modules
@@ -44,13 +29,13 @@ def main() -> None:
44
29
 
45
30
  from docopt import docopt
46
31
 
47
- from sonusai import __version__
32
+ from sonusai import __version__ as sai_version
48
33
  from sonusai.utils import add_commands_to_docstring
49
34
  from sonusai.utils import trim_docstring
50
35
 
51
36
  args = docopt(
52
37
  trim_docstring(add_commands_to_docstring(__doc__, plugin_docstrings)),
53
- version=__version__,
38
+ version=sai_version,
54
39
  options_first=True,
55
40
  )
56
41
 
@@ -61,15 +46,16 @@ def main() -> None:
61
46
  from os.path import join
62
47
  from subprocess import call
63
48
 
64
- import sonusai
49
+ from sonusai import BASEDIR
50
+ from sonusai import commands_list
65
51
  from sonusai import logger
66
52
 
67
- base_commands = sonusai.commands_list()
53
+ base_commands = commands_list()
68
54
  if command == "help":
69
55
  if not argv:
70
56
  exit(call(["sonusai", "-h"])) # noqa: S603, S607
71
57
  elif argv[0] in base_commands:
72
- exit(call(["python", f"{join(sonusai.BASEDIR, argv[0])}.py", "-h"])) # noqa: S603, S607
58
+ exit(call(["python", f"{join(BASEDIR, argv[0])}.py", "-h"])) # noqa: S603, S607
73
59
 
74
60
  for data in plugins.values():
75
61
  if argv[0] in data["commands"]:
@@ -79,7 +65,7 @@ def main() -> None:
79
65
  sys.exit(1)
80
66
 
81
67
  if command in base_commands:
82
- exit(call(["python", f"{join(sonusai.BASEDIR, command)}.py", *argv])) # noqa: S603, S607
68
+ exit(call(["python", f"{join(BASEDIR, command)}.py", *argv])) # noqa: S603, S607
83
69
 
84
70
  for data in plugins.values():
85
71
  if command in data["commands"]:
@@ -90,4 +76,11 @@ def main() -> None:
90
76
 
91
77
 
92
78
  if __name__ == "__main__":
93
- main()
79
+ from sonusai import exception_handler
80
+ from sonusai.utils import register_keyboard_interrupt
81
+
82
+ register_keyboard_interrupt()
83
+ try:
84
+ main()
85
+ except Exception as e:
86
+ exception_handler(e)
@@ -1,5 +1,5 @@
1
- from sonusai.mixture.datatypes import AudioStatsMetrics
2
- from sonusai.mixture.datatypes import AudioT
1
+ from ..datatypes import AudioStatsMetrics
2
+ from ..datatypes import AudioT
3
3
 
4
4
 
5
5
  def _convert_str_with_factors_to_int(x: str) -> int:
@@ -11,22 +11,9 @@ def _convert_str_with_factors_to_int(x: str) -> int:
11
11
 
12
12
 
13
13
  def calc_audio_stats(audio: AudioT, win_len: float | None = None) -> AudioStatsMetrics:
14
- from sonusai.mixture import SAMPLE_RATE
15
- from sonusai.mixture import Transformer
14
+ from ..mixture.sox_effects import sox_stats
16
15
 
17
- args = ["stats"]
18
- if win_len is not None:
19
- args.extend(["-w", str(win_len)])
20
-
21
- tfm = Transformer()
22
-
23
- _, _, out = tfm.build(
24
- input_array=audio,
25
- sample_rate_in=SAMPLE_RATE,
26
- output_filepath="-n",
27
- extra_args=args,
28
- return_output=True,
29
- )
16
+ out = sox_stats(audio, win_len)
30
17
 
31
18
  if out is None:
32
19
  raise SystemError("Call to sox failed")
@@ -1,8 +1,8 @@
1
1
  import numpy as np
2
2
 
3
- from sonusai.mixture.datatypes import GeneralizedIDs
4
- from sonusai.mixture.datatypes import Truth
5
- from sonusai.mixture.mixdb import MixtureDatabase
3
+ from ..datatypes import GeneralizedIDs
4
+ from ..datatypes import Truth
5
+ from ..mixture.mixdb import MixtureDatabase
6
6
 
7
7
 
8
8
  def calc_class_weights_from_truth(truth: Truth, other_weight: float | None = None, other_index: int = -1) -> np.ndarray:
@@ -74,7 +74,7 @@ def calc_class_weights_from_mixdb(
74
74
  weights: Class weights. [num_classes, 1]
75
75
  Note: for Keras use dict(enumerate(weights))
76
76
  """
77
- from sonusai.mixture import get_class_count_from_mixids
77
+ from ..mixture.class_count import get_class_count_from_mixids
78
78
 
79
79
  count = np.ceil(np.array(get_class_count_from_mixids(mixdb=mixdb, mixids=mixids)) / mixdb.feature_step_samples)
80
80
  total_features = sum(count)
@@ -1,11 +1,14 @@
1
1
  import numpy as np
2
2
 
3
- from sonusai.mixture.datatypes import Predict
4
- from sonusai.mixture.datatypes import Truth
3
+ from ..datatypes import Predict
4
+ from ..datatypes import Truth
5
5
 
6
6
 
7
7
  def calc_optimal_thresholds(
8
- truth: Truth, predict: Predict, timesteps: int = 0, truth_thr: float = 0.5
8
+ truth: Truth,
9
+ predict: Predict,
10
+ timesteps: int = 0,
11
+ truth_thr: float = 0.5,
9
12
  ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
10
13
  """Calculates optimal thresholds for each class from one-hot prediction and truth data where both are
11
14
  one-hot probabilities (or quantized decisions) with size [frames, num_classes] or [frames, timesteps, num_classes].
@@ -24,8 +27,8 @@ def calc_optimal_thresholds(
24
27
  from sklearn.metrics import roc_auc_score
25
28
  from sklearn.metrics import roc_curve
26
29
 
27
- from sonusai.utils import get_num_classes_from_predict
28
- from sonusai.utils import reshape_outputs
30
+ from ..utils.reshape import get_num_classes_from_predict
31
+ from ..utils.reshape import reshape_outputs
29
32
 
30
33
  if truth.shape != predict.shape:
31
34
  raise ValueError("truth and predict are not the same shape")
@@ -1,6 +1,6 @@
1
1
  import numpy as np
2
2
 
3
- from sonusai.mixture.constants import SAMPLE_RATE
3
+ from ..constants import SAMPLE_RATE
4
4
 
5
5
 
6
6
  def calc_pesq(
@@ -23,7 +23,7 @@ def calc_pesq(
23
23
 
24
24
  from pesq import pesq
25
25
 
26
- from sonusai import logger
26
+ from .. import logger
27
27
 
28
28
  try:
29
29
  with warnings.catch_warnings():
@@ -1,9 +1,9 @@
1
1
  import numpy as np
2
2
 
3
- from sonusai.mixture.datatypes import AudioF
4
- from sonusai.mixture.datatypes import Segsnr
5
- from sonusai.mixture.datatypes import SnrFBinMetrics
6
- from sonusai.mixture.datatypes import SnrFMetrics
3
+ from ..datatypes import AudioF
4
+ from ..datatypes import Segsnr
5
+ from ..datatypes import SnrFBinMetrics
6
+ from ..datatypes import SnrFMetrics
7
7
 
8
8
 
9
9
  def calc_segsnr_f(segsnr_f: Segsnr) -> SnrFMetrics:
@@ -1,18 +1,23 @@
1
1
  import numpy as np
2
2
 
3
- from sonusai.mixture.constants import SAMPLE_RATE
4
- from sonusai.mixture.datatypes import SpeechMetrics
5
-
3
+ from ..constants import SAMPLE_RATE
4
+ from ..datatypes import SpeechMetrics
6
5
  from .calc_pesq import calc_pesq
7
6
 
8
7
 
9
- def calc_speech(hypothesis: np.ndarray, reference: np.ndarray, sample_rate: int = SAMPLE_RATE) -> SpeechMetrics:
10
- """Calculate speech metrics pesq, c_sig, c_bak, and c_ovl.
8
+ def calc_speech(
9
+ hypothesis: np.ndarray,
10
+ reference: np.ndarray,
11
+ pesq: float | None = None,
12
+ sample_rate: int = SAMPLE_RATE,
13
+ ) -> SpeechMetrics:
14
+ """Calculate speech metrics c_sig, c_bak, and c_ovl.
11
15
 
12
16
  These are all related and thus included in one function. Reference: matlab script "compute_metrics.m".
13
17
 
14
18
  :param hypothesis: estimated audio
15
19
  :param reference: reference audio
20
+ :param pesq: pesq
16
21
  :param sample_rate: sample rate of audio
17
22
  :return: SpeechMetrics named tuple
18
23
  """
@@ -36,18 +41,21 @@ def calc_speech(hypothesis: np.ndarray, reference: np.ndarray, sample_rate: int
36
41
  seg_snr = np.mean(segsnr_dist)
37
42
 
38
43
  # PESQ
39
- _pesq = calc_pesq(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
44
+ if pesq is None:
45
+ pesq = calc_pesq(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
40
46
 
41
47
  # Now compute the composite measures
42
- csig = float(np.clip(3.093 - 1.029 * llr_mean + 0.603 * _pesq - 0.009 * wss_dist, 1, 5))
43
- cbak = float(np.clip(1.634 + 0.478 * _pesq - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5))
44
- covl = float(np.clip(1.594 + 0.805 * _pesq - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5))
48
+ csig = float(np.clip(3.093 - 1.029 * llr_mean + 0.603 * pesq - 0.009 * wss_dist, 1, 5))
49
+ cbak = float(np.clip(1.634 + 0.478 * pesq - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5))
50
+ covl = float(np.clip(1.594 + 0.805 * pesq - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5))
45
51
 
46
- return SpeechMetrics(_pesq, csig, cbak, covl)
52
+ return SpeechMetrics(csig, cbak, covl)
47
53
 
48
54
 
49
55
  def _calc_weighted_spectral_slope_measure(
50
- hypothesis: np.ndarray, reference: np.ndarray, sample_rate: int = SAMPLE_RATE
56
+ hypothesis: np.ndarray,
57
+ reference: np.ndarray,
58
+ sample_rate: int = SAMPLE_RATE,
51
59
  ) -> np.ndarray:
52
60
  from scipy.fftpack import fft
53
61
 
@@ -250,7 +258,9 @@ def _calc_weighted_spectral_slope_measure(
250
258
 
251
259
 
252
260
  def _calc_log_likelihood_ratio_measure(
253
- hypothesis: np.ndarray, reference: np.ndarray, sample_rate: int = SAMPLE_RATE
261
+ hypothesis: np.ndarray,
262
+ reference: np.ndarray,
263
+ sample_rate: int = SAMPLE_RATE,
254
264
  ) -> np.ndarray:
255
265
  from scipy.linalg import toeplitz
256
266
 
@@ -296,7 +306,9 @@ def _calc_log_likelihood_ratio_measure(
296
306
 
297
307
 
298
308
  def _calc_snr(
299
- hypothesis: np.ndarray, reference: np.ndarray, sample_rate: int = SAMPLE_RATE
309
+ hypothesis: np.ndarray,
310
+ reference: np.ndarray,
311
+ sample_rate: int = SAMPLE_RATE,
300
312
  ) -> tuple[float, np.ndarray]:
301
313
  # The lengths of the reference and hypothesis must be the same.
302
314
  reference_length = len(reference)
@@ -2,10 +2,10 @@
2
2
  import numpy as np
3
3
  import pandas as pd
4
4
 
5
- from sonusai.mixture import GeneralizedIDs
6
- from sonusai.mixture import MixtureDatabase
7
- from sonusai.mixture import Predict
8
- from sonusai.mixture import Truth
5
+ from ..datatypes import GeneralizedIDs
6
+ from ..datatypes import Predict
7
+ from ..datatypes import Truth
8
+ from ..mixture.mixdb import MixtureDatabase
9
9
 
10
10
 
11
11
  def class_summary(
@@ -31,7 +31,7 @@ def class_summary(
31
31
  macro avg 0.85 0.83 0.84 0.05 0.96 3768
32
32
  micro-avgwo
33
33
  """
34
- from sonusai.metrics import one_hot
34
+ from ..metrics.one_hot import one_hot
35
35
 
36
36
  num_classes = truth_f.shape[1]
37
37
 
@@ -58,11 +58,11 @@ def class_summary(
58
58
  else:
59
59
  row_n = [f"Class {i}" for i in range(1, num_classes + 1)]
60
60
 
61
- df = pd.DataFrame(metrics[:, table_idx], columns=col_n, index=row_n) # pyright: ignore [reportArgumentType]
61
+ df = pd.DataFrame(metrics[:, table_idx], columns=col_n, index=row_n) # pyright: ignore [reportArgumentType]
62
62
 
63
63
  # [miPPV, miTPR, miF1, miFPR, miACC, miAP, miAUC, TPSUM]
64
64
  avg_row_n = ["Macro-avg", "Micro-avg", "Weighted-avg"]
65
- dfavg = pd.DataFrame(metavg, columns=col_n, index=avg_row_n) # pyright: ignore [reportArgumentType]
65
+ dfavg = pd.DataFrame(metavg, columns=col_n, index=avg_row_n) # pyright: ignore [reportArgumentType]
66
66
 
67
67
  # dfblank = pd.DataFrame([''])
68
68
  # pd.concat([df, dfblank, dfblank, dfavg])
@@ -2,10 +2,10 @@
2
2
  import numpy as np
3
3
  import pandas as pd
4
4
 
5
- from sonusai.mixture import GeneralizedIDs
6
- from sonusai.mixture import MixtureDatabase
7
- from sonusai.mixture import Predict
8
- from sonusai.mixture import Truth
5
+ from ..datatypes import GeneralizedIDs
6
+ from ..datatypes import Predict
7
+ from ..datatypes import Truth
8
+ from ..mixture.mixdb import MixtureDatabase
9
9
 
10
10
 
11
11
  def confusion_matrix_summary(
@@ -30,7 +30,7 @@ def confusion_matrix_summary(
30
30
 
31
31
  Returns pandas dataframes of confusion matrix cmdf and normalized confusion matrix cmndf.
32
32
  """
33
- from sonusai.metrics import one_hot
33
+ from ..metrics.one_hot import one_hot
34
34
 
35
35
  num_classes = truth_f.shape[1]
36
36
  # TODO: re-work for modern mixdb API
@@ -1,7 +1,7 @@
1
1
  import numpy as np
2
2
 
3
- from sonusai.mixture.datatypes import Predict
4
- from sonusai.mixture.datatypes import Truth
3
+ from ..datatypes import Predict
4
+ from ..datatypes import Truth
5
5
 
6
6
 
7
7
  def one_hot(
@@ -53,8 +53,8 @@ def one_hot(
53
53
  from sklearn.metrics import precision_recall_fscore_support
54
54
  from sklearn.metrics import roc_auc_score
55
55
 
56
- from sonusai.utils import get_num_classes_from_predict
57
- from sonusai.utils import reshape_outputs
56
+ from ..utils.reshape import get_num_classes_from_predict
57
+ from ..utils.reshape import reshape_outputs
58
58
 
59
59
  if truth.shape != predict.shape:
60
60
  raise ValueError("truth and predict are not the same shape")
@@ -2,11 +2,11 @@
2
2
  import numpy as np
3
3
  import pandas as pd
4
4
 
5
- from sonusai.mixture import GeneralizedIDs
6
- from sonusai.mixture import MixtureDatabase
7
- from sonusai.mixture import Predict
8
- from sonusai.mixture import Segsnr
9
- from sonusai.mixture import Truth
5
+ from ..datatypes import GeneralizedIDs
6
+ from ..datatypes import Predict
7
+ from ..datatypes import Segsnr
8
+ from ..datatypes import Truth
9
+ from ..mixture.mixdb import MixtureDatabase
10
10
 
11
11
 
12
12
  def snr_summary(
@@ -40,8 +40,8 @@ def snr_summary(
40
40
  """
41
41
  import warnings
42
42
 
43
- from sonusai.metrics import one_hot
44
- from sonusai.queries import get_mixids_from_snr
43
+ from ..metrics.one_hot import one_hot
44
+ from ..queries.queries import get_mixids_from_snr
45
45
 
46
46
  num_classes = truth_f.shape[1]
47
47