sonusai 1.0.16__cp311-abi3-macosx_10_12_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150) hide show
  1. sonusai/__init__.py +170 -0
  2. sonusai/aawscd_probwrite.py +148 -0
  3. sonusai/audiofe.py +481 -0
  4. sonusai/calc_metric_spenh.py +1136 -0
  5. sonusai/config/__init__.py +0 -0
  6. sonusai/config/asr.py +21 -0
  7. sonusai/config/config.py +65 -0
  8. sonusai/config/config.yml +49 -0
  9. sonusai/config/constants.py +53 -0
  10. sonusai/config/ir.py +124 -0
  11. sonusai/config/ir_delay.py +62 -0
  12. sonusai/config/source.py +275 -0
  13. sonusai/config/spectral_masks.py +15 -0
  14. sonusai/config/truth.py +64 -0
  15. sonusai/constants.py +14 -0
  16. sonusai/data/__init__.py +0 -0
  17. sonusai/data/silero_vad_v5.1.jit +0 -0
  18. sonusai/data/silero_vad_v5.1.onnx +0 -0
  19. sonusai/data/speech_ma01_01.wav +0 -0
  20. sonusai/data/whitenoise.wav +0 -0
  21. sonusai/datatypes.py +383 -0
  22. sonusai/deprecated/gentcst.py +632 -0
  23. sonusai/deprecated/plot.py +519 -0
  24. sonusai/deprecated/tplot.py +365 -0
  25. sonusai/doc.py +52 -0
  26. sonusai/doc_strings/__init__.py +1 -0
  27. sonusai/doc_strings/doc_strings.py +531 -0
  28. sonusai/genft.py +196 -0
  29. sonusai/genmetrics.py +183 -0
  30. sonusai/genmix.py +199 -0
  31. sonusai/genmixdb.py +235 -0
  32. sonusai/ir_metric.py +551 -0
  33. sonusai/lsdb.py +141 -0
  34. sonusai/main.py +134 -0
  35. sonusai/metrics/__init__.py +43 -0
  36. sonusai/metrics/calc_audio_stats.py +42 -0
  37. sonusai/metrics/calc_class_weights.py +90 -0
  38. sonusai/metrics/calc_optimal_thresholds.py +73 -0
  39. sonusai/metrics/calc_pcm.py +45 -0
  40. sonusai/metrics/calc_pesq.py +36 -0
  41. sonusai/metrics/calc_phase_distance.py +43 -0
  42. sonusai/metrics/calc_sa_sdr.py +64 -0
  43. sonusai/metrics/calc_sample_weights.py +25 -0
  44. sonusai/metrics/calc_segsnr_f.py +82 -0
  45. sonusai/metrics/calc_speech.py +382 -0
  46. sonusai/metrics/calc_wer.py +71 -0
  47. sonusai/metrics/calc_wsdr.py +57 -0
  48. sonusai/metrics/calculate_metrics.py +395 -0
  49. sonusai/metrics/class_summary.py +74 -0
  50. sonusai/metrics/confusion_matrix_summary.py +75 -0
  51. sonusai/metrics/one_hot.py +283 -0
  52. sonusai/metrics/snr_summary.py +128 -0
  53. sonusai/metrics_summary.py +314 -0
  54. sonusai/mixture/__init__.py +15 -0
  55. sonusai/mixture/audio.py +187 -0
  56. sonusai/mixture/class_balancing.py +103 -0
  57. sonusai/mixture/constants.py +3 -0
  58. sonusai/mixture/data_io.py +173 -0
  59. sonusai/mixture/db.py +169 -0
  60. sonusai/mixture/db_datatypes.py +92 -0
  61. sonusai/mixture/effects.py +344 -0
  62. sonusai/mixture/feature.py +78 -0
  63. sonusai/mixture/generation.py +1116 -0
  64. sonusai/mixture/helpers.py +351 -0
  65. sonusai/mixture/ir_effects.py +77 -0
  66. sonusai/mixture/log_duration_and_sizes.py +23 -0
  67. sonusai/mixture/mixdb.py +1857 -0
  68. sonusai/mixture/pad_audio.py +35 -0
  69. sonusai/mixture/resample.py +7 -0
  70. sonusai/mixture/sox_effects.py +195 -0
  71. sonusai/mixture/sox_help.py +650 -0
  72. sonusai/mixture/spectral_mask.py +51 -0
  73. sonusai/mixture/truth.py +61 -0
  74. sonusai/mixture/truth_functions/__init__.py +45 -0
  75. sonusai/mixture/truth_functions/crm.py +105 -0
  76. sonusai/mixture/truth_functions/energy.py +222 -0
  77. sonusai/mixture/truth_functions/file.py +48 -0
  78. sonusai/mixture/truth_functions/metadata.py +24 -0
  79. sonusai/mixture/truth_functions/metrics.py +28 -0
  80. sonusai/mixture/truth_functions/phoneme.py +18 -0
  81. sonusai/mixture/truth_functions/sed.py +98 -0
  82. sonusai/mixture/truth_functions/target.py +142 -0
  83. sonusai/mkwav.py +135 -0
  84. sonusai/onnx_predict.py +363 -0
  85. sonusai/parse/__init__.py +0 -0
  86. sonusai/parse/expand.py +156 -0
  87. sonusai/parse/parse_source_directive.py +129 -0
  88. sonusai/parse/rand.py +214 -0
  89. sonusai/py.typed +0 -0
  90. sonusai/queries/__init__.py +0 -0
  91. sonusai/queries/queries.py +239 -0
  92. sonusai/rs.abi3.so +0 -0
  93. sonusai/rs.pyi +1 -0
  94. sonusai/rust/__init__.py +0 -0
  95. sonusai/speech/__init__.py +0 -0
  96. sonusai/speech/l2arctic.py +121 -0
  97. sonusai/speech/librispeech.py +102 -0
  98. sonusai/speech/mcgill.py +71 -0
  99. sonusai/speech/textgrid.py +89 -0
  100. sonusai/speech/timit.py +138 -0
  101. sonusai/speech/types.py +12 -0
  102. sonusai/speech/vctk.py +53 -0
  103. sonusai/speech/voxceleb.py +108 -0
  104. sonusai/utils/__init__.py +3 -0
  105. sonusai/utils/asl_p56.py +130 -0
  106. sonusai/utils/asr.py +91 -0
  107. sonusai/utils/asr_functions/__init__.py +3 -0
  108. sonusai/utils/asr_functions/aaware_whisper.py +69 -0
  109. sonusai/utils/audio_devices.py +50 -0
  110. sonusai/utils/braced_glob.py +50 -0
  111. sonusai/utils/calculate_input_shape.py +26 -0
  112. sonusai/utils/choice.py +51 -0
  113. sonusai/utils/compress.py +25 -0
  114. sonusai/utils/convert_string_to_number.py +6 -0
  115. sonusai/utils/create_timestamp.py +5 -0
  116. sonusai/utils/create_ts_name.py +14 -0
  117. sonusai/utils/dataclass_from_dict.py +27 -0
  118. sonusai/utils/db.py +16 -0
  119. sonusai/utils/docstring.py +53 -0
  120. sonusai/utils/energy_f.py +44 -0
  121. sonusai/utils/engineering_number.py +166 -0
  122. sonusai/utils/evaluate_random_rule.py +15 -0
  123. sonusai/utils/get_frames_per_batch.py +2 -0
  124. sonusai/utils/get_label_names.py +20 -0
  125. sonusai/utils/grouper.py +6 -0
  126. sonusai/utils/human_readable_size.py +7 -0
  127. sonusai/utils/keyboard_interrupt.py +12 -0
  128. sonusai/utils/load_object.py +21 -0
  129. sonusai/utils/max_text_width.py +9 -0
  130. sonusai/utils/model_utils.py +28 -0
  131. sonusai/utils/numeric_conversion.py +11 -0
  132. sonusai/utils/onnx_utils.py +155 -0
  133. sonusai/utils/parallel.py +162 -0
  134. sonusai/utils/path_info.py +7 -0
  135. sonusai/utils/print_mixture_details.py +60 -0
  136. sonusai/utils/rand.py +13 -0
  137. sonusai/utils/ranges.py +43 -0
  138. sonusai/utils/read_predict_data.py +32 -0
  139. sonusai/utils/reshape.py +154 -0
  140. sonusai/utils/seconds_to_hms.py +7 -0
  141. sonusai/utils/stacked_complex.py +82 -0
  142. sonusai/utils/stratified_shuffle_split.py +170 -0
  143. sonusai/utils/tokenized_shell_vars.py +143 -0
  144. sonusai/utils/write_audio.py +26 -0
  145. sonusai/utils/yes_or_no.py +8 -0
  146. sonusai/vars.py +47 -0
  147. sonusai-1.0.16.dist-info/METADATA +56 -0
  148. sonusai-1.0.16.dist-info/RECORD +150 -0
  149. sonusai-1.0.16.dist-info/WHEEL +4 -0
  150. sonusai-1.0.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,314 @@
1
"""sonusai metrics_summary

usage: metrics_summary [-vlh] [-i MIXID] [-n NCPU] LOCATION

Options:
    -h, --help                      Show this help message and exit.
    -v, --verbose                   Be verbose.
    -l, --write-list                Write .csv file list of all mixture metrics
    -i MIXID, --mixid MIXID         Mixture ID(s) to analyze. [default: *].
    -n, --num_process NCPU          Number of parallel processes to use; 'None' runs serially [default: auto]

Summarize mixture metrics across a SonusAI mixture database where metrics have been generated by SonusAI genmetrics.

Inputs:
    LOCATION        A SonusAI mixture database directory with mixdb.db and pre-generated metrics from SonusAI genmetrics.

"""

import numpy as np
import pandas as pd

# Linear power ratios for +99 dB and -99 dB, i.e. 10 ** (dB / 10)
DB_99 = np.power(10, 99 / 10)
DB_N99 = np.power(10, -99 / 10)
24
+
25
+
26
def _process_mixture(
    m_id: int,
    location: str,
    all_metric_names: list[str],
    scalar_metric_names: list[str],
    string_metric_names: list[str],
    frame_metric_names: list[str],
    bin_metric_names: list[str],
    ptab_labels: list[str],
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Build the one-row summary table for a single mixture.

    The mixture database is opened here (from ``location``) rather than passed in, so the function can be
    dispatched to worker processes via ``functools.partial``.

    :param m_id: Mixture ID to summarize
    :param location: Mixture database directory
    :param all_metric_names: Names of all pre-generated metrics to read for this mixture
    :param scalar_metric_names: Metrics reported directly as numbers
    :param string_metric_names: String metrics; reported as their word count
    :param frame_metric_names: Per-frame vector metrics (not summarized yet)
    :param bin_metric_names: Per-bin vector metrics (not summarized yet)
    :param ptab_labels: Column labels; must match the order of the values assembled below
    :return: Tuple of tables indexed by m_id; the second element currently repeats the first
             (frame/bin metric summaries are still TODO)
    """
    from os.path import basename

    from sonusai.constants import SAMPLE_RATE
    from sonusai.metrics import calc_wer
    from sonusai.mixture import MixtureDatabase

    mixdb = MixtureDatabase(location)
    mixture = mixdb.mixture(m_id)

    duration = mixture.samples / SAMPLE_RATE
    tf_frames = mixdb.mixture_transform_frames(m_id)
    mxsnr = mixture.noise.snr
    t0file = basename(mixdb.source_file(mixture.sources["primary"].file_id).name)
    nfile = basename(mixdb.source_file(mixture.noise.file_id).name)

    all_metrics = mixdb.mixture_metrics(m_id, all_metric_names)

    def _primary(key: str) -> object:
        # Dict-valued metrics are reduced to their 'primary' entry (mixup is ignored)
        value = all_metrics[key]
        return value["primary"] if isinstance(value, dict) else value

    scalar_metrics = {key: _primary(key) for key in scalar_metric_names}

    # Convert each string metric into its word count
    string_metrics = {}
    for key in string_metric_names:
        text = _primary(key)
        string_metrics[key] = calc_wer(text, text).words

    # Values must line up with ptab_labels: mxsnr, scalars, string word counts, fcnt, duration, t0file, nfile
    ptab_data: list = [
        mxsnr,
        *scalar_metrics.values(),
        *string_metrics.values(),
        tf_frames,
        duration,
        t0file,
        nfile,
    ]

    ptab1 = pd.DataFrame([ptab_data], columns=ptab_labels, index=[m_id])

    # TODO: collect frame metrics and bin metrics; until then the same table is returned twice
    return ptab1, ptab1
88
+
89
+
90
def main() -> None:
    """Summarize pre-generated metrics across a SonusAI mixture database.

    Reads the cached metrics of the selected mixtures and builds one row per mixture. Each row holds mxsnr,
    the scalar metrics, string-metric word counts, frame count, duration, and source/noise file names.
    Per-SNR means and overall statistics are written to ``metric_summary_snr<suffix>.csv`` and ``.txt`` in
    LOCATION. With ``--write-list`` the per-mixture table is also written to ``metric_summary_list<suffix>.csv``.
    """
    from docopt import docopt

    from sonusai import __version__ as sai_version
    from sonusai.utils.docstring import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)

    verbose = args["--verbose"]
    wrlist = args["--write-list"]
    mixids = args["--mixid"]
    location = args["LOCATION"]
    num_proc = args["--num_process"]

    from functools import partial
    from os.path import basename
    from os.path import join
    from os.path import normpath

    import psutil

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import logger
    from sonusai import update_console_handler
    from sonusai.mixture import MixtureDatabase
    from sonusai.utils.create_timestamp import create_timestamp
    from sonusai.utils.parallel import par_track
    from sonusai.utils.parallel import track

    mixdb = MixtureDatabase(location)
    print(f"Found SonusAI mixture database with {mixdb.num_mixtures} mixtures.")
    if mixdb.num_mixtures < 1:
        print(f"Mixture database is empty. Nothing to summarize in {location}, exiting ...")
        return

    # Only check first and last mixture in order to save time
    metrics_present = mixdb.cached_metrics([0, mixdb.num_mixtures - 1])  # return pre-generated metrics in mixdb tree
    if "mxsnr" in metrics_present:
        # mxsnr is always added as the first summary column, so it is not treated as a regular metric
        metrics_present.remove("mxsnr")

    num_metrics_present = len(metrics_present)
    if num_metrics_present < 1:
        print(f"mixdb reports no pre-generated metrics are present. Nothing to summarize in {location}, exiting ...")
        return

    # Setup logging file
    timestamp = create_timestamp()  # string good for embedding into filenames
    # normpath() drops a trailing separator so basename() does not return an empty string
    mixdb_fname = basename(normpath(location))
    if verbose:
        create_file_handler(join(location, "metrics_summary.log"), verbose)
        update_console_handler(verbose)
        initial_log_messages("metrics_summary")
        logger.info(f"Logging summary of SonusAI mixture database at {location}")
    else:
        update_console_handler(verbose)

    logger.info("")
    mixids = mixdb.mixids_to_list(mixids)
    if not mixids:
        print(f"No mixtures selected. Nothing to summarize in {location}, exiting ...")
        return
    if len(mixids) < mixdb.num_mixtures:
        logger.info(
            f"Processing a subset of {len(mixids)} out of total mixdb mixtures of {mixdb.num_mixtures}, "
            f"summary results will not include entire dataset."
        )
        fsuffix = f"_s{len(mixids)}t{mixdb.num_mixtures}"
    else:
        logger.info(
            f"Summarizing SonusAI mixture database with {mixdb.num_mixtures} mixtures "
            f"and {num_metrics_present} pre-generated metrics ..."
        )
        fsuffix = ""

    ft_bins = mixdb.ft_config.bin_end - mixdb.ft_config.bin_start + 1  # bins of forward transform
    # Pre-process first mixid to gather metrics into 4 types: scalar, str (scalar word cnt), frame-array, bin-array
    scalar_metric_names: list[str] = []
    string_metric_names: list[str] = []
    frame_metric_names: list[str] = []
    bin_metric_names: list[str] = []
    first_mixid = mixids[0]
    all_metrics = mixdb.mixture_metrics(first_mixid, metrics_present)
    tf_frames = mixdb.mixture_transform_frames(first_mixid)
    for metric in metrics_present:
        metval = all_metrics[metric]  # get metric value
        logger.debug(f"First mixid {first_mixid} metric {metric} = {metval}")
        if isinstance(metval, dict):
            logger.warning(f"Mixid {first_mixid} metric {metric} is a dict, using 'primary'.")
            metval = metval["primary"]  # remove any dict
        if isinstance(metval, float | int):
            logger.debug(f"Metric is scalar {type(metval)}, entering in summary table.")
            scalar_metric_names.append(metric)
        elif isinstance(metval, str):
            logger.debug("Metric is string, will summarize with word count.")
            string_metric_names.append(metric)
        elif isinstance(metval, np.ndarray):
            if metval.ndim != 1:
                logger.warning(f"Mixid {first_mixid} metric {metric} is a {metval.ndim}-D array, ignoring.")
            elif metval.size == tf_frames:
                logger.debug("Metric is frames vector.")
                frame_metric_names.append(metric)
            elif metval.size == ft_bins:
                logger.debug("Metric is bins vector.")
                bin_metric_names.append(metric)
            else:
                logger.warning(f"Mixid {first_mixid} metric {metric} is a vector of improper size, ignoring.")
        else:
            logger.warning(f"Mixid {first_mixid} metric {metric} has unsupported type {type(metval)}, ignoring.")

    # Setup pandas table for summarizing scalar metrics, always include mxsnr first
    ptab_labels = [
        "mxsnr",
        *scalar_metric_names,
        *string_metric_names,
        "fcnt",
        "duration",
        "t0file",
        "nfile",
    ]

    num_cpu = psutil.cpu_count() or 1  # cpu_count() returns None when undetermined
    cpu_percent = psutil.cpu_percent(interval=1)
    logger.info("")
    logger.info(f"#CPUs: {num_cpu}, current CPU utilization: {cpu_percent}%")
    logger.info(f"Memory utilization: {psutil.virtual_memory().percent}%")
    if num_proc == "auto":
        # Use the currently idle share of CPUs minus 10% headroom, but always at least one process
        use_cpu = max(int(num_cpu * (0.9 - cpu_percent / 100)), 1)
    elif num_proc == "None":
        use_cpu = None
    else:
        use_cpu = min(max(int(num_proc), 1), num_cpu)

    logger.info(f"Summarizing metrics for {len(mixids)} mixtures using {use_cpu} parallel processes")

    progress = track(total=len(mixids))
    # use_cpu of None disables parallel processing
    all_metrics_tables = par_track(
        partial(
            _process_mixture,
            location=location,
            all_metric_names=metrics_present,
            scalar_metric_names=scalar_metric_names,
            string_metric_names=string_metric_names,
            frame_metric_names=frame_metric_names,
            bin_metric_names=bin_metric_names,
            ptab_labels=ptab_labels,
        ),
        mixids,
        progress=progress,
        num_cpus=use_cpu,
        no_par=use_cpu is None,
    )
    progress.close()

    # Done with mixtures, write out summary metrics
    header_args = {
        "mode": "a",
        "encoding": "utf-8",
        "index": False,
        "header": False,
    }
    table_args = {
        "mode": "a",
        "encoding": "utf-8",
    }
    ptab1 = pd.concat([item[0] for item in all_metrics_tables])
    if wrlist:
        wlcsv_name = str(join(location, "metric_summary_list" + fsuffix + ".csv"))
        pd.DataFrame([["Timestamp", timestamp]]).to_csv(wlcsv_name, header=False, index=False)
        pd.DataFrame([f"Metric list for {mixdb_fname}:"]).to_csv(wlcsv_name, mode="a", header=False, index=False)
        ptab1.round(2).to_csv(wlcsv_name, **table_args)
    ptab1_sorted = ptab1.sort_values(by=["mxsnr", "t0file"])

    # Create metrics table except -99 SNR
    ptab1_nom99 = ptab1_sorted[ptab1_sorted.mxsnr != -99]

    # Create summary by SNR for all scalar metrics, taking mean
    snr_means = []
    for snr in mixdb.snrs:
        snr_rows = ptab1_sorted.query("mxsnr==" + str(snr))
        # Skip SNRs with no mixtures (possible when only a subset of mixids is summarized)
        if not snr_rows.empty:
            snr_means.append(snr_rows.mean(numeric_only=True).to_frame().T)
    if snr_means:
        mtab_snr_summary = pd.concat(snr_means).sort_values(by=["mxsnr"], ascending=False)
    else:
        logger.warning("No summarized mixtures match any mixdb SNR; the per-SNR summary will be empty.")
        mtab_snr_summary = pd.DataFrame(columns=["mxsnr"])

    def _fmt2(x: float) -> str:
        return f"{x:.2f}"

    nmix = len(mixids)
    nmixtot = mixdb.num_mixtures

    # Write summary to .csv
    snrcsv_name = str(join(location, "metric_summary_snr" + fsuffix + ".csv"))
    pd.DataFrame([["Timestamp", timestamp]]).to_csv(snrcsv_name, header=False, index=False)
    pd.DataFrame(['"Metrics avg over each SNR:"']).to_csv(snrcsv_name, **header_args)
    mtab_snr_summary.round(2).T.to_csv(snrcsv_name, index=True, header=False, mode="a", encoding="utf-8")
    pd.DataFrame(["--"]).to_csv(snrcsv_name, header=False, index=False, mode="a")
    pd.DataFrame([f'"Metrics stats over {nmix} mixtures out of {nmixtot} total:"']).to_csv(snrcsv_name, **header_args)
    ptab1.describe().round(2).T.to_csv(snrcsv_name, index=True, **table_args)
    pd.DataFrame(["--"]).to_csv(snrcsv_name, header=False, index=False, mode="a")
    pd.DataFrame([f'"Metrics stats over {len(ptab1_nom99)} non -99db mixtures out of {nmixtot} total:"']).to_csv(
        snrcsv_name, **header_args
    )
    ptab1_nom99.describe().round(2).T.to_csv(snrcsv_name, index=True, **table_args)

    # Write summary to text file
    snrtxt_name = str(join(location, "metric_summary_snr" + fsuffix + ".txt"))
    with open(snrtxt_name, "w", encoding="utf-8") as f:
        print(f"Timestamp: {timestamp}", file=f)
        print("Metrics avg over each SNR:", file=f)
        print(mtab_snr_summary.round(2).T.to_string(float_format=_fmt2, index=True, header=False), file=f)
        print("", file=f)
        print(f"Metrics stats over {nmix} mixtures out of {nmixtot} total:", file=f)
        print(ptab1.describe().round(2).T.to_string(float_format=_fmt2, index=True), file=f)
        print("", file=f)
        print(f"Metrics stats over {len(ptab1_nom99)} non -99db mixtures out of {nmixtot} total:", file=f)
        print(ptab1_nom99.describe().round(2).T.to_string(float_format=_fmt2, index=True), file=f)
304
+
305
+
306
if __name__ == "__main__":
    from sonusai import exception_handler
    from sonusai.utils.keyboard_interrupt import register_keyboard_interrupt

    # Register the keyboard-interrupt handler first, then hand any failure from main() to sonusai's handler
    register_keyboard_interrupt()
    try:
        main()
    except Exception as err:
        exception_handler(err)
@@ -0,0 +1,15 @@
1
+ # SonusAI mixture utilities
2
+
3
+ from .feature import get_audio_from_feature
4
+ from .feature import get_feature_from_audio
5
+ from .helpers import forward_transform
6
+ from .helpers import inverse_transform
7
+ from .mixdb import MixtureDatabase
8
+
9
# Public names exported by ``from sonusai.mixture import *`` (alphabetical, uppercase first)
__all__ = [
    "MixtureDatabase",
    "forward_transform",
    "get_audio_from_feature",
    "get_feature_from_audio",
    "inverse_transform",
]
@@ -0,0 +1,187 @@
1
+ from functools import lru_cache
2
+ from pathlib import Path
3
+
4
+ from ..datatypes import AudioT
5
+
6
+
7
def get_next_noise(audio: AudioT, offset: int, length: int) -> AudioT:
    """Get the next sequence of noise data from noise audio

    Indices wrap around the end of ``audio`` (``np.take(..., mode="wrap")``), so the noise is treated as
    a repeating loop.

    :param audio: Overall noise audio (entire file's worth of data)
    :param offset: Starting sample
    :param length: Number of samples to get
    :return: Sequence of noise audio data with the same dtype as audio
    """
    import numpy as np

    # np.arange builds the index vector natively instead of converting a Python range element by element
    return np.take(audio, np.arange(offset, offset + length), mode="wrap")
18
+
19
+
20
def get_duration(audio: AudioT) -> float:
    """Return how long ``audio`` lasts, in seconds, assuming the SonusAI sample rate

    :param audio: Time domain data [samples]
    :return: Duration of audio in seconds
    """
    from ..constants import SAMPLE_RATE

    num_samples = len(audio)
    return num_samples / SAMPLE_RATE
29
+
30
+
31
+ def validate_input_file(input_filepath: str | Path) -> None:
32
+ from os.path import exists
33
+ from os.path import splitext
34
+
35
+ from soundfile import available_formats
36
+
37
+ if not exists(input_filepath):
38
+ raise OSError(f"input_filepath {input_filepath} does not exist.")
39
+
40
+ ext = splitext(input_filepath)[1][1:].lower()
41
+ read_formats = [item.lower() for item in available_formats()]
42
+ if ext not in read_formats:
43
+ raise OSError(f"This installation cannot process .{ext} files")
44
+
45
+
46
def get_sample_rate(name: str | Path, use_cache: bool = True) -> int:
    """Return the native sample rate of an audio file

    :param name: File name
    :param use_cache: If true, use LRU caching
    :return: Sample rate
    """
    reader = _get_sample_rate if use_cache else _get_sample_rate.__wrapped__
    return reader(name)
56
+
57
+
58
@lru_cache
def _get_sample_rate(name: str | Path) -> int:
    """Get sample rate from audio file (cached implementation behind get_sample_rate)

    Names ending in .mp3 or .m4a are probed with pydub; all others with soundfile.

    :param name: File name; tokens in it are expanded with tokenized_expand
    :return: Sample rate
    :raises OSError: If the file cannot be read
    """
    import soundfile
    from pydub import AudioSegment

    from ..utils.tokenized_shell_vars import tokenized_expand

    expanded_name, _ = tokenized_expand(name)

    try:
        if expanded_name.endswith(".mp3"):
            return AudioSegment.from_mp3(expanded_name).frame_rate

        if expanded_name.endswith(".m4a"):
            return AudioSegment.from_file(expanded_name).frame_rate

        return soundfile.info(expanded_name).samplerate
    except Exception as e:
        # Compare as strings: a Path never equals the expanded str, which would always add a redundant note
        if str(name) != expanded_name:
            raise OSError(f"Error reading {name} (expanded: {expanded_name}): {e}") from e
        raise OSError(f"Error reading {name}: {e}") from e
85
+
86
+
87
def raw_read_audio(name: str | Path) -> tuple[AudioT, int]:
    """Read the first channel of an audio file at its native sample rate

    Names ending in .mp3 or .m4a are decoded with pydub; all others with soundfile.

    :param name: File name; tokens in it are expanded with tokenized_expand
    :return: Tuple of (float32 samples of channel 0, native sample rate)
    :raises OSError: If the file cannot be read
    """
    import numpy as np
    import soundfile
    from pydub import AudioSegment

    from ..utils.tokenized_shell_vars import tokenized_expand

    expanded_name, _ = tokenized_expand(name)

    try:
        if expanded_name.endswith((".mp3", ".m4a")):
            if expanded_name.endswith(".mp3"):
                sound = AudioSegment.from_mp3(expanded_name)
            else:
                sound = AudioSegment.from_file(expanded_name)
            # Reshape the flat integer samples to (frames, channels) and normalize by the full-scale value
            raw = np.array(sound.get_array_of_samples()).astype(np.float32).reshape((-1, sound.channels))
            raw = raw / 2 ** (sound.sample_width * 8 - 1)
            sample_rate = sound.frame_rate
        else:
            raw, sample_rate = soundfile.read(expanded_name, always_2d=True, dtype="float32")
    except Exception as e:
        # Compare as strings: a Path never equals the expanded str, which would always add a redundant note
        if str(name) != expanded_name:
            raise OSError(f"Error reading {name} (expanded: {expanded_name}): {e}") from e
        raise OSError(f"Error reading {name}: {e}") from e

    # raw is 2-D on every path, so raw[:, 0] is already 1-D; np.squeeze would turn a
    # single-sample file into a 0-d array
    return raw[:, 0].astype(np.float32), sample_rate
116
+
117
+
118
def read_audio(name: str | Path, use_cache: bool = True) -> AudioT:
    """Load an audio file as time-domain data at the SonusAI sample rate

    :param name: File name
    :param use_cache: If true, use LRU caching
    :return: Array of time domain audio data
    """
    loader = _read_audio if use_cache else _read_audio.__wrapped__
    return loader(name)
128
+
129
+
130
@lru_cache
def _read_audio(name: str | Path) -> AudioT:
    """Read a file with raw_read_audio and resample it to SAMPLE_RATE (cached implementation behind read_audio)

    :param name: File name
    :return: Array of time domain audio data
    """
    from ..constants import SAMPLE_RATE
    from .resample import resample

    samples, native_rate = raw_read_audio(name)
    return resample(samples, orig_sr=native_rate, target_sr=SAMPLE_RATE)
143
+
144
+
145
def get_num_samples(name: str | Path, use_cache: bool = True) -> int:
    """Return how many samples the file holds once resampled to the SonusAI sample rate

    :param name: File name
    :param use_cache: If true, use LRU caching
    :return: number of samples in resampled audio
    """
    counter = _get_num_samples if use_cache else _get_num_samples.__wrapped__
    return counter(name)
155
+
156
+
157
@lru_cache
def _get_num_samples(name: str | Path) -> int:
    """Count samples in a file after resampling to SAMPLE_RATE (cached implementation behind get_num_samples)

    Names ending in .mp3 or .m4a are decoded with pydub; all others are probed with soundfile.info.

    :param name: File name; tokens in it are expanded with tokenized_expand
    :return: number of samples in resampled audio (rounded up)
    """
    import math

    import soundfile
    from pydub import AudioSegment

    from ..constants import SAMPLE_RATE
    from ..utils.tokenized_shell_vars import tokenized_expand

    expanded_name, _ = tokenized_expand(name)

    if expanded_name.endswith((".mp3", ".m4a")):
        loader = AudioSegment.from_mp3 if expanded_name.endswith(".mp3") else AudioSegment.from_file
        segment = loader(expanded_name)
        frame_total, native_rate = segment.frame_count(), segment.frame_rate
    else:
        details = soundfile.info(expanded_name)
        frame_total, native_rate = details.frames, details.samplerate

    return math.ceil(SAMPLE_RATE * frame_total / native_rate)
@@ -0,0 +1,103 @@
1
+ from ..datatypes import EffectList
2
+ from ..datatypes import EffectedFile
3
+ from ..datatypes import File
4
+
5
+
6
def balance_sources(
    effected_sources: list[EffectedFile],
    files: list[File],
    effects: list[EffectList],
    class_balancing_effect: EffectList,
    num_classes: int,
    num_ir: int,
    mixups: list[int] | None = None,
) -> tuple[list[EffectedFile], list[EffectList]]:
    """Add effected sources so that every class has the same number of sources for each mixup value

    For each mixup value other than 1, each class is padded up to the largest class size, rounded up to a
    multiple of mixup. Padding appends EffectedFile entries that cycle round-robin through the class's
    file IDs, each paired with a balancing effect chosen by _get_unused_balancing_effect.

    NOTE: effected_sources and effects are extended in place and also returned.
    NOTE(review): .augmentation and .sources are not among the sonusai.mixture modules shipped in this
    release; confirm these imports still resolve.

    :param effected_sources: Existing effected sources (appended to)
    :param files: Source files, indexed by file_id
    :param effects: Effect lists, indexed by effect_id (appended to)
    :param class_balancing_effect: Default class balancing effect rule
    :param num_classes: Number of classes
    :param num_ir: Number of impulse responses
    :param mixups: Mixup values to balance; computed with get_mixups(effects) when None
    :return: Tuple of (effected_sources, effects)
    """
    import math

    from .augmentation import get_mixups
    from .sources import get_augmented_target_ids_by_class

    # Effects appended from this index on are class balancing effects
    first_cbe_id = len(effects)

    if mixups is None:
        mixups = get_mixups(effects)

    for mixup in mixups:
        if mixup == 1:
            continue

        indices_by_class = get_augmented_target_ids_by_class(
            augmented_targets=effected_sources,
            targets=files,
            target_augmentations=effects,
            mixup=mixup,
            num_classes=num_classes,
        )

        # Target count per class: the largest class, rounded up to a whole number of mixup groups
        largest = max((len(item) for item in indices_by_class), default=0)
        largest = math.ceil(largest / mixup) * mixup
        for es_indices in indices_by_class:
            file_ids = sorted({effected_sources[es_index].file_id for es_index in es_indices})
            if not file_ids:
                # A class with no sources has nothing to augment; skip it instead of failing on file_ids[0]
                continue

            for count in range(largest - len(es_indices)):
                file_id = file_ids[count % len(file_ids)]
                effect_id, effects = _get_unused_balancing_effect(
                    effected_sources=effected_sources,
                    files=files,
                    effects=effects,
                    class_balancing_effect=class_balancing_effect,
                    file_id=file_id,
                    mixup=mixup,
                    num_ir=num_ir,
                    first_cbe_id=first_cbe_id,
                )
                effected_sources.append(EffectedFile(file_id=file_id, effect_id=effect_id))

    return effected_sources, effects
60
+
61
+
62
+ def _get_unused_balancing_effect(
63
+ effected_sources: list[EffectedFile],
64
+ files: list[File],
65
+ effects: list[EffectList],
66
+ class_balancing_effect: EffectList,
67
+ file_id: int,
68
+ mixup: int,
69
+ num_ir: int,
70
+ first_cbe_id: int,
71
+ ) -> tuple[int, list[EffectList]]:
72
+ """Get an unused balancing augmentation for a given target file index"""
73
+ from dataclasses import asdict
74
+
75
+ from .augmentation import get_augmentation_rules
76
+
77
+ balancing_augmentations = [item for item in range(len(effects)) if item >= first_cbe_id]
78
+ used_balancing_augmentations = [
79
+ effected_source.effect_id
80
+ for effected_source in effected_sources
81
+ if effected_source.file_id == file_id and effected_source.effect_id in balancing_augmentations
82
+ ]
83
+
84
+ augmentation_indices = [
85
+ item
86
+ for item in balancing_augmentations
87
+ if item not in used_balancing_augmentations and effects[item].mixup == mixup
88
+ ]
89
+ if len(augmentation_indices) > 0:
90
+ return augmentation_indices[0], effects
91
+
92
+ class_balancing_effect = get_class_balancing_effect(file=files[file_id], default_cbe=class_balancing_effect)
93
+ new_effect = get_augmentation_rules(rules=asdict(class_balancing_effect), num_ir=num_ir)[0]
94
+ new_effect.mixup = mixup
95
+ effects.append(new_effect)
96
+ return len(effects) - 1, effects
97
+
98
+
99
def get_class_balancing_effect(file: File, default_cbe: EffectList) -> EffectList:
    """Return the file's own class balancing effect rule, or default_cbe when the file has none"""
    own_rule = file.class_balancing_effect
    return default_cbe if own_rule is None else own_rule
@@ -0,0 +1,3 @@
1
# Mixture database file names
MIXDB_NAME = "mixdb.db"
TEST_MIXDB_NAME = "mixdb_test.db"

# Mixture database version number
MIXDB_VERSION = 3