sonusai 0.18.9__py3-none-any.whl → 0.19.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +93 -77
  13. sonusai/genmetrics.py +59 -46
  14. sonusai/genmix.py +116 -104
  15. sonusai/genmixdb.py +194 -153
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +15 -17
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +19 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +40 -38
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +669 -477
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +52 -85
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +40 -27
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
  113. sonusai-0.19.5.dist-info/RECORD +125 -0
  114. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.9.dist-info/RECORD +0 -125
  118. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
sonusai/lsdb.py CHANGED
@@ -1,11 +1,11 @@
1
1
  """sonusai lsdb
2
2
 
3
- usage: lsdb [-hta] [-i MIXID] [-c TID] LOC
3
+ usage: lsdb [-hta] [-i MIXID] [-c CID] LOC
4
4
 
5
5
  Options:
6
6
  -h, --help
7
7
  -i MIXID, --mixid MIXID Mixture ID(s) to analyze. [default: *].
8
- -c TID, --truth_index TID Analyze mixtures that contain this truth index.
8
+ -c CID, --class_index CID Analyze mixtures that contain this class index.
9
9
  -t, --targets List all target files.
10
10
  -a, --all_class_counts List all class counts.
11
11
 
@@ -15,6 +15,7 @@ Inputs:
15
15
  LOC A SonusAI mixture database directory.
16
16
 
17
17
  """
18
+
18
19
  import signal
19
20
 
20
21
  from sonusai import logger
@@ -27,26 +28,22 @@ def signal_handler(_sig, _frame):
27
28
 
28
29
  from sonusai import logger
29
30
 
30
- logger.info('Canceled due to keyboard interrupt')
31
+ logger.info("Canceled due to keyboard interrupt")
31
32
  sys.exit(1)
32
33
 
33
34
 
34
35
  signal.signal(signal.SIGINT, signal_handler)
35
36
 
36
37
 
37
- def lsdb(mixdb: MixtureDatabase,
38
- mixids: GeneralizedIDs = None,
39
- truth_index: int = None,
40
- list_targets: bool = False,
41
- all_class_counts: bool = False) -> None:
42
- import numpy as np
43
- import h5py
44
-
45
- from sonusai import SonusAIError
46
- from sonusai.metrics import calc_snr_f
38
+ def lsdb(
39
+ mixdb: MixtureDatabase,
40
+ mixids: GeneralizedIDs = "*",
41
+ class_index: int | None = None,
42
+ list_targets: bool = False,
43
+ all_class_counts: bool = False,
44
+ ) -> None:
47
45
  from sonusai.mixture import SAMPLE_RATE
48
- from sonusai.mixture import get_truth_indices_for_target
49
- from sonusai.queries import get_mixids_from_truth_index
46
+ from sonusai.queries import get_mixids_from_class_indices
50
47
  from sonusai.utils import consolidate_range
51
48
  from sonusai.utils import max_text_width
52
49
  from sonusai.utils import print_mixture_details
@@ -62,38 +59,40 @@ def lsdb(mixdb: MixtureDatabase,
62
59
  logger.info(f'{"Targets":{desc_len}} {mixdb.num_target_files}')
63
60
  logger.info(f'{"Noises":{desc_len}} {mixdb.num_noise_files}')
64
61
  logger.info(f'{"Feature":{desc_len}} {mixdb.feature}')
65
- logger.info(f'{"Feature shape":{desc_len}} {mixdb.fg_stride} x {mixdb.feature_parameters} '
66
- f'({mixdb.fg_stride * mixdb.feature_parameters} total params)')
62
+ logger.info(
63
+ f'{"Feature shape":{desc_len}} {mixdb.fg_stride} x {mixdb.feature_parameters} '
64
+ f'({mixdb.fg_stride * mixdb.feature_parameters} total params)'
65
+ )
67
66
  logger.info(f'{"Feature samples":{desc_len}} {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)')
68
- logger.info(f'{"Feature step samples":{desc_len}} {mixdb.feature_step_samples} samples '
69
- f'({mixdb.feature_step_ms} ms)')
67
+ logger.info(
68
+ f'{"Feature step samples":{desc_len}} {mixdb.feature_step_samples} samples ' f'({mixdb.feature_step_ms} ms)'
69
+ )
70
70
  logger.info(f'{"Feature overlap":{desc_len}} {mixdb.fg_step / mixdb.fg_stride} ({mixdb.feature_step_ms} ms)')
71
71
  logger.info(f'{"SNRs":{desc_len}} {mixdb.snrs}')
72
72
  logger.info(f'{"Random SNRs":{desc_len}} {mixdb.random_snrs}')
73
73
  logger.info(f'{"Classes":{desc_len}} {mixdb.num_classes}')
74
- logger.info(f'{"Truth mutex":{desc_len}} {mixdb.truth_mutex}')
75
74
  # TODO: fix class count
76
75
  logger.info(f'{"Class count":{desc_len}} not supported')
77
76
  # print_class_count(class_count=class_count, length=desc_len, print_fn=logger.info)
78
77
  # TODO: add class weight calculations here
79
- logger.info('')
78
+ logger.info("")
80
79
 
81
80
  if list_targets:
82
- logger.info('Target details:')
81
+ logger.info("Target details:")
83
82
  idx_len = max_text_width(mixdb.num_target_files)
84
83
  for idx, target in enumerate(mixdb.target_files):
85
- desc = f' {idx:{idx_len}} Name'
86
- logger.info(f'{desc:{desc_len}} {target.name}')
87
- desc = f' {idx:{idx_len}} Truth index'
88
- logger.info(f'{desc:{desc_len}} {get_truth_indices_for_target(target)}')
89
- logger.info('')
90
-
91
- if truth_index is not None:
92
- if 0 <= truth_index > mixdb.num_classes:
93
- raise SonusAIError(f'Given truth_index is outside valid range of 1-{mixdb.num_classes}')
94
- ids = get_mixids_from_truth_index(mixdb=mixdb, predicate=lambda x: x in [truth_index])[truth_index]
95
- logger.info(f'Mixtures with truth index {truth_index}: {ids}')
96
- logger.info('')
84
+ desc = f" {idx:{idx_len}} Name"
85
+ logger.info(f"{desc:{desc_len}} {target.name}")
86
+ desc = f" {idx:{idx_len}} Truth index"
87
+ logger.info(f"{desc:{desc_len}} {target.class_indices}")
88
+ logger.info("")
89
+
90
+ if class_index is not None:
91
+ if 0 <= class_index > mixdb.num_classes:
92
+ raise ValueError(f"Given class_index is outside valid range of 1-{mixdb.num_classes}")
93
+ ids = get_mixids_from_class_indices(mixdb=mixdb, predicate=lambda x: x in [class_index])[class_index]
94
+ logger.info(f"Mixtures with class index {class_index}: {ids}")
95
+ logger.info("")
97
96
 
98
97
  mixids = mixdb.mixids_to_list(mixids)
99
98
 
@@ -101,24 +100,13 @@ def lsdb(mixdb: MixtureDatabase,
101
100
  print_mixture_details(mixdb=mixdb, mixid=mixids[0], desc_len=desc_len, print_fn=logger.info)
102
101
  if all_class_counts:
103
102
  # TODO: fix class count
104
- logger.info('All class count not supported')
103
+ logger.info("All class count not supported")
105
104
  # print_class_count(class_count=class_count, length=desc_len, print_fn=logger.info, all_class_counts=True)
106
105
  else:
107
- logger.info(f'Calculating statistics from truth_f files for {len(mixids):,} mixtures'
108
- f' ({consolidate_range(mixids)})')
109
- for mixid in mixids:
110
- with h5py.File(mixdb.mixture_filename(mixid), 'r') as f:
111
- if mixid == mixids[0]:
112
- truth_f = np.array(f['truth_f'])
113
- else:
114
- truth_f = np.concatenate((truth_f, np.array(f['truth_f'])))
115
-
116
- snr_mean, snr_std, snr_db_mean, snr_db_std = calc_snr_f(truth_f)
117
-
118
- logger.info('Truth')
119
- logger.info(f' {"mean":^8s} {"std":^8s} {"db_mean":^8s} {"db_std":^8s}')
120
- for c in range(len(snr_mean)):
121
- logger.info(f' {snr_mean[c]:8.2f} {snr_std[c]:8.2f} {snr_db_mean[c]:8.2f} {snr_db_std[c]:8.2f}')
106
+ logger.info(
107
+ f"Calculating statistics from truth_f files for {len(mixids):,} mixtures" f" ({consolidate_range(mixids)})"
108
+ )
109
+ logger.info("Not supported")
122
110
 
123
111
 
124
112
  def main() -> None:
@@ -132,28 +120,30 @@ def main() -> None:
132
120
 
133
121
  args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
134
122
 
135
- mixid = args['--mixid']
136
- truth_index = args['--truth_index']
137
- list_targets = args['--targets']
138
- all_class_counts = args['--all_class_counts']
139
- location = args['LOC']
123
+ mixid = args["--mixid"]
124
+ class_index = args["--class_index"]
125
+ list_targets = args["--targets"]
126
+ all_class_counts = args["--all_class_counts"]
127
+ location = args["LOC"]
140
128
 
141
- if truth_index is not None:
142
- truth_index = int(truth_index)
129
+ if class_index is not None:
130
+ class_index = int(class_index)
143
131
 
144
- create_file_handler('lsdb.log')
132
+ create_file_handler("lsdb.log")
145
133
  update_console_handler(False)
146
- initial_log_messages('lsdb')
134
+ initial_log_messages("lsdb")
147
135
 
148
- logger.info(f'Analyzing {location}')
136
+ logger.info(f"Analyzing {location}")
149
137
 
150
138
  mixdb = MixtureDatabase(location)
151
- lsdb(mixdb=mixdb,
152
- mixids=mixid,
153
- truth_index=truth_index,
154
- list_targets=list_targets,
155
- all_class_counts=all_class_counts)
139
+ lsdb(
140
+ mixdb=mixdb,
141
+ mixids=mixid,
142
+ class_index=class_index,
143
+ list_targets=list_targets,
144
+ all_class_counts=all_class_counts,
145
+ )
156
146
 
157
147
 
158
- if __name__ == '__main__':
148
+ if __name__ == "__main__":
159
149
  main()
sonusai/main.py CHANGED
@@ -9,6 +9,7 @@ Aaware Sound and Voice Machine Learning Framework. See 'sonusai help <command>'
9
9
  for more information on a specific command.
10
10
 
11
11
  """
12
+
12
13
  import signal
13
14
 
14
15
 
@@ -17,7 +18,7 @@ def signal_handler(_sig, _frame):
17
18
 
18
19
  from sonusai import logger
19
20
 
20
- logger.info('Canceled due to keyboard interrupt')
21
+ logger.info("Canceled due to keyboard interrupt")
21
22
  sys.exit(1)
22
23
 
23
24
 
@@ -33,11 +34,11 @@ def main() -> None:
33
34
  plugins = {}
34
35
  plugin_docstrings = []
35
36
  for _, name, _ in iter_modules():
36
- if name.startswith('sonusai_') and not name.startswith('sonusai_asr_'):
37
+ if name.startswith("sonusai_") and not name.startswith("sonusai_asr_"):
37
38
  module = import_module(name)
38
39
  plugins[name] = {
39
- 'commands': commands_list(module.commands_doc),
40
- 'basedir': module.BASEDIR,
40
+ "commands": commands_list(module.commands_doc),
41
+ "basedir": module.BASEDIR,
41
42
  }
42
43
  plugin_docstrings.append(module.commands_doc)
43
44
 
@@ -47,12 +48,14 @@ def main() -> None:
47
48
  from sonusai.utils import add_commands_to_docstring
48
49
  from sonusai.utils import trim_docstring
49
50
 
50
- args = docopt(trim_docstring(add_commands_to_docstring(__doc__, plugin_docstrings)),
51
- version=__version__,
52
- options_first=True)
51
+ args = docopt(
52
+ trim_docstring(add_commands_to_docstring(__doc__, plugin_docstrings)),
53
+ version=__version__,
54
+ options_first=True,
55
+ )
53
56
 
54
- command = args['<command>']
55
- argv = args['<args>']
57
+ command = args["<command>"]
58
+ argv = args["<args>"]
56
59
 
57
60
  import sys
58
61
  from os.path import join
@@ -62,29 +65,29 @@ def main() -> None:
62
65
  from sonusai import logger
63
66
 
64
67
  base_commands = sonusai.commands_list()
65
- if command == 'help':
68
+ if command == "help":
66
69
  if not argv:
67
- exit(call(['sonusai', '-h']))
70
+ exit(call(["sonusai", "-h"])) # noqa: S603, S607
68
71
  elif argv[0] in base_commands:
69
- exit(call(['python', f'{join(sonusai.BASEDIR, argv[0])}.py', '-h']))
72
+ exit(call(["python", f"{join(sonusai.BASEDIR, argv[0])}.py", "-h"])) # noqa: S603, S607
70
73
 
71
- for plugin, data in plugins.items():
72
- if argv[0] in data['commands']:
73
- exit(call(['python', f'{join(data["basedir"], argv[0])}.py', '-h']))
74
+ for data in plugins.values():
75
+ if argv[0] in data["commands"]:
76
+ exit(call(["python", f"{join(data['basedir'], argv[0])}.py", "-h"])) # noqa: S603, S607
74
77
 
75
78
  logger.error(f"{argv[0]} is not a SonusAI command. See 'sonusai help'.")
76
79
  sys.exit(1)
77
80
 
78
81
  if command in base_commands:
79
- exit(call(['python', f'{join(sonusai.BASEDIR, command)}.py'] + argv))
82
+ exit(call(["python", f"{join(sonusai.BASEDIR, command)}.py", *argv])) # noqa: S603, S607
80
83
 
81
- for plugin, data in plugins.items():
82
- if command in data['commands']:
83
- exit(call(['python', f'{join(data["basedir"], command)}.py'] + argv))
84
+ for data in plugins.values():
85
+ if command in data["commands"]:
86
+ exit(call(["python", f"{join(data['basedir'], command)}.py", *argv])) # noqa: S603, S607
84
87
 
85
88
  logger.error(f"{command} is not a SonusAI command. See 'sonusai help'.")
86
89
  sys.exit(1)
87
90
 
88
91
 
89
- if __name__ == '__main__':
92
+ if __name__ == "__main__":
90
93
  main()
sonusai/metrics/__init__.py CHANGED
@@ -1,4 +1,6 @@
1
1
  # SonusAI metrics utilities for model training and validation
2
+ # ruff: noqa: F401
3
+
2
4
  from .calc_audio_stats import calc_audio_stats
3
5
  from .calc_class_weights import calc_class_weights_from_mixdb
4
6
  from .calc_class_weights import calc_class_weights_from_truth
sonusai/metrics/calc_audio_stats.py CHANGED
@@ -3,48 +3,53 @@ from sonusai.mixture.datatypes import AudioT
3
3
 
4
4
 
5
5
  def _convert_str_with_factors_to_int(x: str) -> int:
6
- if 'k' in x:
7
- return int(1000 * float(x.replace('k', '')))
8
- if 'M' in x:
9
- return int(1000000 * float(x.replace('M', '')))
6
+ if "k" in x:
7
+ return int(1000 * float(x.replace("k", "")))
8
+ if "M" in x:
9
+ return int(1000000 * float(x.replace("M", "")))
10
10
  return int(x)
11
11
 
12
12
 
13
- def calc_audio_stats(audio: AudioT, win_len: float = None) -> AudioStatsMetrics:
13
+ def calc_audio_stats(audio: AudioT, win_len: float | None = None) -> AudioStatsMetrics:
14
14
  from sonusai.mixture import SAMPLE_RATE
15
15
  from sonusai.mixture import Transformer
16
16
 
17
- args = ['stats']
17
+ args = ["stats"]
18
18
  if win_len is not None:
19
- args.extend(['-w', str(win_len)])
19
+ args.extend(["-w", str(win_len)])
20
20
 
21
21
  tfm = Transformer()
22
22
 
23
- _, _, out = tfm.build(input_array=audio,
24
- sample_rate_in=SAMPLE_RATE,
25
- output_filepath='-n',
26
- extra_args=args,
27
- return_output=True)
23
+ _, _, out = tfm.build(
24
+ input_array=audio,
25
+ sample_rate_in=SAMPLE_RATE,
26
+ output_filepath="-n",
27
+ extra_args=args,
28
+ return_output=True,
29
+ )
30
+
31
+ if out is None:
32
+ raise SystemError("Call to sox failed")
28
33
 
29
34
  stats = {}
30
- lines = out.split('\n')
35
+ lines = out.split("\n")
31
36
  for line in lines:
32
37
  split_line = line.split()
33
38
  if len(split_line) == 0:
34
39
  continue
35
40
  value = split_line[-1]
36
- key = ' '.join(split_line[:-1])
41
+ key = " ".join(split_line[:-1])
37
42
  stats[key] = value
38
43
 
39
44
  return AudioStatsMetrics(
40
- dco=float(stats['DC offset']),
41
- min=float(stats['Min level']),
42
- max=float(stats['Max level']),
43
- pkdb=float(stats['Pk lev dB']),
44
- lrms=float(stats['RMS lev dB']),
45
- pkr=float(stats['RMS Pk dB']),
46
- tr=float(stats['RMS Tr dB']),
47
- cr=float(stats['Crest factor']),
48
- fl=float(stats['Flat factor']),
49
- pkc=_convert_str_with_factors_to_int(stats['Pk count']),
45
+ dco=float(stats["DC offset"]),
46
+ min=float(stats["Min level"]),
47
+ max=float(stats["Max level"]),
48
+ pkdb=float(stats["Pk lev dB"]),
49
+ lrms=float(stats["RMS lev dB"]),
50
+ pkr=float(stats["RMS Pk dB"]),
51
+ tr=float(stats["RMS Tr dB"]),
52
+ cr=float(stats["Crest factor"]),
53
+ fl=float(stats["Flat factor"]),
54
+ pkc=_convert_str_with_factors_to_int(stats["Pk count"]),
50
55
  )
sonusai/metrics/calc_class_weights.py CHANGED
@@ -5,9 +5,7 @@ from sonusai.mixture.datatypes import Truth
5
5
  from sonusai.mixture.mixdb import MixtureDatabase
6
6
 
7
7
 
8
- def calc_class_weights_from_truth(truth: Truth,
9
- other_weight: float = None,
10
- other_index: int = -1) -> np.ndarray:
8
+ def calc_class_weights_from_truth(truth: Truth, other_weight: float | None = None, other_index: int = -1) -> np.ndarray:
11
9
  """Calculate class weights.
12
10
 
13
11
  Supports non-existent classes (a problem with sklearn) where non-existent
@@ -54,10 +52,12 @@ def calc_class_weights_from_truth(truth: Truth,
54
52
  return weights
55
53
 
56
54
 
57
- def calc_class_weights_from_mixdb(mixdb: MixtureDatabase,
58
- mixids: GeneralizedIDs = None,
59
- other_weight: float = 1,
60
- other_index: int = -1) -> tuple[np.ndarray, np.ndarray]:
55
+ def calc_class_weights_from_mixdb(
56
+ mixdb: MixtureDatabase,
57
+ mixids: GeneralizedIDs | None = None,
58
+ other_weight: float = 1,
59
+ other_index: int = -1,
60
+ ) -> tuple[np.ndarray, np.ndarray]:
61
61
  """Calculate class weights using estimated feature counts from a mixture database.
62
62
 
63
63
  Arguments:
sonusai/metrics/calc_optimal_thresholds.py CHANGED
@@ -4,10 +4,9 @@ from sonusai.mixture.datatypes import Predict
4
4
  from sonusai.mixture.datatypes import Truth
5
5
 
6
6
 
7
- def calc_optimal_thresholds(truth: Truth,
8
- predict: Predict,
9
- timesteps: int = 0,
10
- truth_thr: float = 0.5) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
7
+ def calc_optimal_thresholds(
8
+ truth: Truth, predict: Predict, timesteps: int = 0, truth_thr: float = 0.5
9
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
11
10
  """Calculates optimal thresholds for each class from one-hot prediction and truth data where both are
12
11
  one-hot probabilities (or quantized decisions) with size [frames, num_classes] or [frames, timesteps, num_classes].
13
12
 
@@ -25,14 +24,13 @@ def calc_optimal_thresholds(truth: Truth,
25
24
  from sklearn.metrics import roc_auc_score
26
25
  from sklearn.metrics import roc_curve
27
26
 
28
- from sonusai import SonusAIError
29
27
  from sonusai.utils import get_num_classes_from_predict
30
28
  from sonusai.utils import reshape_outputs
31
29
 
32
30
  if truth.shape != predict.shape:
33
- raise SonusAIError('truth and predict are not the same shape')
31
+ raise ValueError("truth and predict are not the same shape")
34
32
 
35
- predict, truth = reshape_outputs(predict=predict, truth=truth, timesteps=timesteps)
33
+ predict, truth = reshape_outputs(predict=predict, truth=truth, timesteps=timesteps) # type: ignore[assignment]
36
34
  num_classes = get_num_classes_from_predict(predict=predict, timesteps=timesteps)
37
35
 
38
36
  # Apply decision to truth input
sonusai/metrics/calc_pcm.py CHANGED
@@ -1,9 +1,9 @@
1
1
  import numpy as np
2
2
 
3
3
 
4
- def calc_pcm(hypothesis: np.ndarray,
5
- reference: np.ndarray,
6
- with_log: bool = False) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
4
+ def calc_pcm(
5
+ hypothesis: np.ndarray, reference: np.ndarray, with_log: bool = False
6
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
7
7
  """Calculate phase constrained magnitude error
8
8
 
9
9
  These must include a noise to make a complete mixture estimate, i.e.,
sonusai/metrics/calc_pesq.py CHANGED
@@ -3,10 +3,12 @@ import numpy as np
3
3
  from sonusai.mixture.constants import SAMPLE_RATE
4
4
 
5
5
 
6
- def calc_pesq(hypothesis: np.ndarray,
7
- reference: np.ndarray,
8
- error_value: float = 0.0,
9
- sample_rate: int = SAMPLE_RATE) -> float:
6
+ def calc_pesq(
7
+ hypothesis: np.ndarray,
8
+ reference: np.ndarray,
9
+ error_value: float = 0.0,
10
+ sample_rate: int = SAMPLE_RATE,
11
+ ) -> float:
10
12
  """Computes the PESQ score of hypothesis vs. reference
11
13
 
12
14
  Upon error, assigns a value of 0, or user specified value in error_value
@@ -22,12 +24,13 @@ def calc_pesq(hypothesis: np.ndarray,
22
24
  from pesq import pesq
23
25
 
24
26
  from sonusai import logger
27
+
25
28
  try:
26
29
  with warnings.catch_warnings():
27
- warnings.simplefilter('ignore')
28
- score = pesq(fs=sample_rate, ref=reference, deg=hypothesis, mode='wb')
30
+ warnings.simplefilter("ignore")
31
+ score = pesq(fs=sample_rate, ref=reference, deg=hypothesis, mode="wb")
29
32
  except Exception as e:
30
- logger.debug(f'PESQ error {e}')
33
+ logger.debug(f"PESQ error {e}")
31
34
  score = error_value
32
35
 
33
36
  return score
sonusai/metrics/calc_phase_distance.py CHANGED
@@ -1,9 +1,9 @@
1
1
  import numpy as np
2
2
 
3
3
 
4
- def calc_phase_distance(reference: np.ndarray,
5
- hypothesis: np.ndarray,
6
- eps: float = 1e-9) -> tuple[float, np.ndarray, np.ndarray]:
4
+ def calc_phase_distance(
5
+ reference: np.ndarray, hypothesis: np.ndarray, eps: float = 1e-9
6
+ ) -> tuple[float, np.ndarray, np.ndarray]:
7
7
  """Calculate weighted phase distance error (weight normalization over bins per frame)
8
8
 
9
9
  :param reference: complex [frames, bins]
sonusai/metrics/calc_sa_sdr.py CHANGED
@@ -1,10 +1,12 @@
1
1
  import numpy as np
2
2
 
3
3
 
4
- def calc_sa_sdr(hypothesis: np.ndarray,
5
- reference: np.ndarray,
6
- with_scale: bool = False,
7
- with_negate: bool = False) -> tuple[np.ndarray, np.ndarray]:
4
+ def calc_sa_sdr(
5
+ hypothesis: np.ndarray,
6
+ reference: np.ndarray,
7
+ with_scale: bool = False,
8
+ with_negate: bool = False,
9
+ ) -> tuple[np.ndarray, np.ndarray]:
8
10
  """Calculate source-aggregated SDR (signal distortion ratio) using all source inputs which are [samples, nsrc].
9
11
 
10
12
  These should include a noise to be a complete mixture estimate, i.e.,
@@ -30,9 +32,9 @@ def calc_sa_sdr(hypothesis: np.ndarray,
30
32
  """
31
33
  if with_scale:
32
34
  # calc 1 x nsrc scaling factors
33
- ref_energy = np.sum(reference ** 2, axis=0, keepdims=True)
35
+ ref_energy = np.sum(reference**2, axis=0, keepdims=True)
34
36
  # if ref_energy is zero, just set scaling to 1.0
35
- with np.errstate(divide='ignore', invalid='ignore'):
37
+ with np.errstate(divide="ignore", invalid="ignore"):
36
38
  opt_scale = np.sum(reference * hypothesis, axis=0, keepdims=True) / ref_energy
37
39
  opt_scale[opt_scale == np.inf] = 1.0
38
40
  opt_scale = np.nan_to_num(opt_scale, nan=1.0)
@@ -46,8 +48,8 @@ def calc_sa_sdr(hypothesis: np.ndarray,
46
48
 
47
49
  # -10*log10(sumk(||sk||^2) / sumk(||sk - shk||^2)
48
50
  # sum over samples and sources
49
- num = np.sum(reference ** 2)
50
- den = np.sum(err ** 2)
51
+ num = np.sum(reference**2)
52
+ den = np.sum(err**2)
51
53
  if num == 0 and den == 0:
52
54
  ratio = np.inf
53
55
  else:
sonusai/metrics/calc_segsnr_f.py CHANGED
@@ -33,10 +33,7 @@ def calc_segsnr_f(segsnr_f: Segsnr) -> SnrFMetrics:
33
33
  snr_db_mean = np.mean(tmp, axis=0)
34
34
  snr_db_std = np.std(tmp, axis=0)
35
35
 
36
- return SnrFMetrics(snr_mean,
37
- snr_std,
38
- snr_db_mean,
39
- snr_db_std)
36
+ return SnrFMetrics(snr_mean, snr_std, snr_db_mean, snr_db_std)
40
37
 
41
38
 
42
39
  def calc_segsnr_f_bin(target_f: AudioF, noise_f: AudioF) -> SnrFBinMetrics:
@@ -46,25 +43,24 @@ def calc_segsnr_f_bin(target_f: AudioF, noise_f: AudioF) -> SnrFBinMetrics:
46
43
  and mean and standard deviation of the dB values.
47
44
  """
48
45
  if target_f.ndim != 2 and noise_f.ndim != 2:
49
- raise ValueError('target_f and noise_f must have 2 dimensions')
46
+ raise ValueError("target_f and noise_f must have 2 dimensions")
50
47
 
51
48
  segsnr_f = (np.abs(target_f) ** 2) / (np.abs(noise_f) ** 2)
52
49
 
53
50
  frames, bins = segsnr_f.shape
54
51
  if np.count_nonzero(segsnr_f) == 0:
55
52
  # If all entries are zeros
56
- return SnrFBinMetrics(np.zeros(bins),
57
- np.zeros(bins),
58
- -np.inf * np.ones(bins),
59
- np.zeros(bins))
53
+ return SnrFBinMetrics(np.zeros(bins), np.zeros(bins), -np.inf * np.ones(bins), np.zeros(bins))
60
54
 
61
55
  tmp = np.ma.array(segsnr_f, mask=np.logical_not(np.isfinite(segsnr_f)))
62
56
  if np.ma.count_masked(tmp) == np.ma.size(tmp, axis=0):
63
57
  # If all entries are infinite
64
- return SnrFBinMetrics(np.inf * np.ones(bins),
65
- np.zeros(bins),
66
- np.inf * np.ones(bins),
67
- np.zeros(bins))
58
+ return SnrFBinMetrics(
59
+ np.inf * np.ones(bins),
60
+ np.zeros(bins),
61
+ np.inf * np.ones(bins),
62
+ np.zeros(bins),
63
+ )
68
64
 
69
65
  snr_mean = np.mean(tmp, axis=0)
70
66
  snr_std = np.std(tmp, axis=0)
@@ -78,7 +74,9 @@ def calc_segsnr_f_bin(target_f: AudioF, noise_f: AudioF) -> SnrFBinMetrics:
78
74
  snr_db_mean = np.mean(tmp, axis=0)
79
75
  snr_db_std = np.std(tmp, axis=0)
80
76
 
81
- return SnrFBinMetrics(np.ma.getdata(snr_mean),
82
- np.ma.getdata(snr_std),
83
- np.ma.getdata(snr_db_mean),
84
- np.ma.getdata(snr_db_std))
77
+ return SnrFBinMetrics(
78
+ np.ma.getdata(snr_mean),
79
+ np.ma.getdata(snr_std),
80
+ np.ma.getdata(snr_db_mean),
81
+ np.ma.getdata(snr_db_std),
82
+ )