sonusai 1.0.16__cp311-abi3-macosx_10_12_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150) hide show
  1. sonusai/__init__.py +170 -0
  2. sonusai/aawscd_probwrite.py +148 -0
  3. sonusai/audiofe.py +481 -0
  4. sonusai/calc_metric_spenh.py +1136 -0
  5. sonusai/config/__init__.py +0 -0
  6. sonusai/config/asr.py +21 -0
  7. sonusai/config/config.py +65 -0
  8. sonusai/config/config.yml +49 -0
  9. sonusai/config/constants.py +53 -0
  10. sonusai/config/ir.py +124 -0
  11. sonusai/config/ir_delay.py +62 -0
  12. sonusai/config/source.py +275 -0
  13. sonusai/config/spectral_masks.py +15 -0
  14. sonusai/config/truth.py +64 -0
  15. sonusai/constants.py +14 -0
  16. sonusai/data/__init__.py +0 -0
  17. sonusai/data/silero_vad_v5.1.jit +0 -0
  18. sonusai/data/silero_vad_v5.1.onnx +0 -0
  19. sonusai/data/speech_ma01_01.wav +0 -0
  20. sonusai/data/whitenoise.wav +0 -0
  21. sonusai/datatypes.py +383 -0
  22. sonusai/deprecated/gentcst.py +632 -0
  23. sonusai/deprecated/plot.py +519 -0
  24. sonusai/deprecated/tplot.py +365 -0
  25. sonusai/doc.py +52 -0
  26. sonusai/doc_strings/__init__.py +1 -0
  27. sonusai/doc_strings/doc_strings.py +531 -0
  28. sonusai/genft.py +196 -0
  29. sonusai/genmetrics.py +183 -0
  30. sonusai/genmix.py +199 -0
  31. sonusai/genmixdb.py +235 -0
  32. sonusai/ir_metric.py +551 -0
  33. sonusai/lsdb.py +141 -0
  34. sonusai/main.py +134 -0
  35. sonusai/metrics/__init__.py +43 -0
  36. sonusai/metrics/calc_audio_stats.py +42 -0
  37. sonusai/metrics/calc_class_weights.py +90 -0
  38. sonusai/metrics/calc_optimal_thresholds.py +73 -0
  39. sonusai/metrics/calc_pcm.py +45 -0
  40. sonusai/metrics/calc_pesq.py +36 -0
  41. sonusai/metrics/calc_phase_distance.py +43 -0
  42. sonusai/metrics/calc_sa_sdr.py +64 -0
  43. sonusai/metrics/calc_sample_weights.py +25 -0
  44. sonusai/metrics/calc_segsnr_f.py +82 -0
  45. sonusai/metrics/calc_speech.py +382 -0
  46. sonusai/metrics/calc_wer.py +71 -0
  47. sonusai/metrics/calc_wsdr.py +57 -0
  48. sonusai/metrics/calculate_metrics.py +395 -0
  49. sonusai/metrics/class_summary.py +74 -0
  50. sonusai/metrics/confusion_matrix_summary.py +75 -0
  51. sonusai/metrics/one_hot.py +283 -0
  52. sonusai/metrics/snr_summary.py +128 -0
  53. sonusai/metrics_summary.py +314 -0
  54. sonusai/mixture/__init__.py +15 -0
  55. sonusai/mixture/audio.py +187 -0
  56. sonusai/mixture/class_balancing.py +103 -0
  57. sonusai/mixture/constants.py +3 -0
  58. sonusai/mixture/data_io.py +173 -0
  59. sonusai/mixture/db.py +169 -0
  60. sonusai/mixture/db_datatypes.py +92 -0
  61. sonusai/mixture/effects.py +344 -0
  62. sonusai/mixture/feature.py +78 -0
  63. sonusai/mixture/generation.py +1116 -0
  64. sonusai/mixture/helpers.py +351 -0
  65. sonusai/mixture/ir_effects.py +77 -0
  66. sonusai/mixture/log_duration_and_sizes.py +23 -0
  67. sonusai/mixture/mixdb.py +1857 -0
  68. sonusai/mixture/pad_audio.py +35 -0
  69. sonusai/mixture/resample.py +7 -0
  70. sonusai/mixture/sox_effects.py +195 -0
  71. sonusai/mixture/sox_help.py +650 -0
  72. sonusai/mixture/spectral_mask.py +51 -0
  73. sonusai/mixture/truth.py +61 -0
  74. sonusai/mixture/truth_functions/__init__.py +45 -0
  75. sonusai/mixture/truth_functions/crm.py +105 -0
  76. sonusai/mixture/truth_functions/energy.py +222 -0
  77. sonusai/mixture/truth_functions/file.py +48 -0
  78. sonusai/mixture/truth_functions/metadata.py +24 -0
  79. sonusai/mixture/truth_functions/metrics.py +28 -0
  80. sonusai/mixture/truth_functions/phoneme.py +18 -0
  81. sonusai/mixture/truth_functions/sed.py +98 -0
  82. sonusai/mixture/truth_functions/target.py +142 -0
  83. sonusai/mkwav.py +135 -0
  84. sonusai/onnx_predict.py +363 -0
  85. sonusai/parse/__init__.py +0 -0
  86. sonusai/parse/expand.py +156 -0
  87. sonusai/parse/parse_source_directive.py +129 -0
  88. sonusai/parse/rand.py +214 -0
  89. sonusai/py.typed +0 -0
  90. sonusai/queries/__init__.py +0 -0
  91. sonusai/queries/queries.py +239 -0
  92. sonusai/rs.abi3.so +0 -0
  93. sonusai/rs.pyi +1 -0
  94. sonusai/rust/__init__.py +0 -0
  95. sonusai/speech/__init__.py +0 -0
  96. sonusai/speech/l2arctic.py +121 -0
  97. sonusai/speech/librispeech.py +102 -0
  98. sonusai/speech/mcgill.py +71 -0
  99. sonusai/speech/textgrid.py +89 -0
  100. sonusai/speech/timit.py +138 -0
  101. sonusai/speech/types.py +12 -0
  102. sonusai/speech/vctk.py +53 -0
  103. sonusai/speech/voxceleb.py +108 -0
  104. sonusai/utils/__init__.py +3 -0
  105. sonusai/utils/asl_p56.py +130 -0
  106. sonusai/utils/asr.py +91 -0
  107. sonusai/utils/asr_functions/__init__.py +3 -0
  108. sonusai/utils/asr_functions/aaware_whisper.py +69 -0
  109. sonusai/utils/audio_devices.py +50 -0
  110. sonusai/utils/braced_glob.py +50 -0
  111. sonusai/utils/calculate_input_shape.py +26 -0
  112. sonusai/utils/choice.py +51 -0
  113. sonusai/utils/compress.py +25 -0
  114. sonusai/utils/convert_string_to_number.py +6 -0
  115. sonusai/utils/create_timestamp.py +5 -0
  116. sonusai/utils/create_ts_name.py +14 -0
  117. sonusai/utils/dataclass_from_dict.py +27 -0
  118. sonusai/utils/db.py +16 -0
  119. sonusai/utils/docstring.py +53 -0
  120. sonusai/utils/energy_f.py +44 -0
  121. sonusai/utils/engineering_number.py +166 -0
  122. sonusai/utils/evaluate_random_rule.py +15 -0
  123. sonusai/utils/get_frames_per_batch.py +2 -0
  124. sonusai/utils/get_label_names.py +20 -0
  125. sonusai/utils/grouper.py +6 -0
  126. sonusai/utils/human_readable_size.py +7 -0
  127. sonusai/utils/keyboard_interrupt.py +12 -0
  128. sonusai/utils/load_object.py +21 -0
  129. sonusai/utils/max_text_width.py +9 -0
  130. sonusai/utils/model_utils.py +28 -0
  131. sonusai/utils/numeric_conversion.py +11 -0
  132. sonusai/utils/onnx_utils.py +155 -0
  133. sonusai/utils/parallel.py +162 -0
  134. sonusai/utils/path_info.py +7 -0
  135. sonusai/utils/print_mixture_details.py +60 -0
  136. sonusai/utils/rand.py +13 -0
  137. sonusai/utils/ranges.py +43 -0
  138. sonusai/utils/read_predict_data.py +32 -0
  139. sonusai/utils/reshape.py +154 -0
  140. sonusai/utils/seconds_to_hms.py +7 -0
  141. sonusai/utils/stacked_complex.py +82 -0
  142. sonusai/utils/stratified_shuffle_split.py +170 -0
  143. sonusai/utils/tokenized_shell_vars.py +143 -0
  144. sonusai/utils/write_audio.py +26 -0
  145. sonusai/utils/yes_or_no.py +8 -0
  146. sonusai/vars.py +47 -0
  147. sonusai-1.0.16.dist-info/METADATA +56 -0
  148. sonusai-1.0.16.dist-info/RECORD +150 -0
  149. sonusai-1.0.16.dist-info/WHEEL +4 -0
  150. sonusai-1.0.16.dist-info/entry_points.txt +3 -0
sonusai/ir_metric.py ADDED
@@ -0,0 +1,551 @@
1
+ """sonusai ir_metric
2
+
3
+ usage: ir_metric [-h] [-n NCPU] IRLOC
4
+
5
+ options:
6
+ -h, --help
7
+ -n, --num_process NCPU Number of parallel processes to use [default: auto]
8
+
9
+ Calculate delay and gain metrics of impulse response (IR) files <filename>.wav in IRLOC.
10
+ Metrics include gain and multiple ways to calculate the IR delay:
11
+ - gmax: max abs(fft(ir)), FFT length equal to the IR length
12
+ - dcc: cross-correlation of ir with a white Gaussian noise reference (dccphat: GCC-PHAT variant)
13
+ - dmax: index of max(ir)
14
+ - dagd: average group delay
15
+ - rt20, rt60: reverberation times (Schroeder backward integration)
16
+
17
+ Results are written into IRLOC/ir_metric_summary.txt and IRLOC/ir_metric_list.csv (a single IR file writes <name>-irmetric.txt next to it).
18
+
19
+ IRLOC directory containing impulse response data in audio files (.wav, .flac, etc.). Only first channel is analyzed.
20
+
21
+ """
22
+
23
+ import glob
24
+ from os.path import abspath
25
+ from os.path import basename
26
+ from os.path import commonprefix
27
+ from os.path import dirname
28
+ from os.path import isdir
29
+ from os.path import isfile
30
+ from os.path import join
31
+ from os.path import relpath
32
+ from os.path import splitext
33
+
34
+ import matplotlib.pyplot as plt
35
+ import numpy as np
36
+ import pandas as pd
37
+ import soundfile
38
+ from numpy import fft
39
+
40
+ from sonusai.utils.braced_glob import braced_iglob
41
+
42
+
43
def tdoa(signal, reference, interp=1, phat=False, fs=1, t_max=None):
    """
    Estimate the shift of ``signal`` with respect to ``reference`` using
    generalized cross-correlation (GCC).

    Parameters
    ----------
    signal: array_like
        The array whose tdoa is measured.
    reference: array_like
        The reference array.
    interp: int, optional
        The interpolation factor for the cross-correlation, default 1.
    phat: bool, optional
        Apply the PHAT weighting (default False).
    fs: int or float, optional
        The sampling frequency of the input arrays, default 1 (delay in samples).
    t_max: float, optional
        Accepted for signature compatibility only; it is currently ignored.

    Returns
    -------
    The estimated delay of ``signal`` relative to ``reference``, in units of 1/fs.
    """
    signal = np.asarray(signal)
    reference = np.asarray(reference)

    r_12 = correlate(signal, reference, interp=interp, phat=phat)

    # Index 0 of r_12 is lag -(len(reference) - 1); subtract that offset so the
    # peak index becomes a signed lag.
    lag_offset = reference.shape[0] - 1
    return (np.argmax(np.abs(r_12)) / interp - lag_offset) / fs
77
+
78
+
79
def correlate(x1, x2, interp=1, phat=False):
    """
    Compute the full linear cross-correlation between x1 and x2 via the FFT.

    The output ordering matches ``np.correlate(x1, x2, mode="full")``: index 0
    is lag ``-(len(x2) - 1)`` and the last index is lag ``len(x1) - 1``.

    Parameters
    ----------
    x1, x2: array_like
        The 1-D data arrays.
    interp: int, optional
        The interpolation factor for the output array, default 1. The output
        has ``interp * (len(x1) + len(x2) - 1)`` samples.
    phat: bool, optional
        Apply the PHAT weighting (default False).

    Returns
    -------
    The cross-correlation between the two arrays.
    """
    x1 = np.asarray(x1)
    x2 = np.asarray(x2)

    n1 = x1.shape[0]
    n2 = x2.shape[0]
    # Zero-pad to the full linear-correlation length so the circular result does not wrap.
    n = n1 + n2 - 1

    X1 = fft.rfft(x1, n=n)
    X2 = fft.rfft(x2, n=n)

    if phat:
        # PHAT whitening; eps is relative to the mean magnitude to avoid division by zero.
        X1 = X1 / (np.abs(X1) + np.mean(np.abs(X1)) * 1e-10)
        X2 = X2 / (np.abs(X2) + np.mean(np.abs(X2)) * 1e-10)

    out = fft.irfft(X1 * np.conj(X2), n=int(n * interp))

    # Negative lags sit at the end of the circular result. Use an explicit start
    # index: a negative slice out[-0:] would return the whole array when len(x2) == 1.
    neg_start = len(out) - interp * (n2 - 1)
    return np.concatenate([out[neg_start:], out[: interp * n1]])
116
+
117
+
118
def hilbert(u):
    """
    Compute the analytic signal of a real 1-D sequence via the FFT.

    Matches ``scipy.signal.hilbert(u)``:
    - negative-frequency bins are zeroed;
    - strictly positive-frequency bins are doubled;
    - DC (and Nyquist, for even lengths) are left unscaled.

    The real part of the result equals ``u``.

    Parameters
    ----------
    u: array_like
        Real input sequence.

    Returns
    -------
    Complex ndarray of the same length as ``u``.
    """
    n = len(u)
    spectrum = fft.fft(u)
    # Strictly positive frequencies are bins 1 .. ceil(n/2) - 1. For odd n this
    # includes bin n // 2, which has no Nyquist partner and must be doubled too.
    spectrum[1 : (n + 1) // 2] *= 2
    # Bins above n // 2 hold the negative frequencies.
    spectrum[n // 2 + 1 :] = 0
    return fft.ifft(spectrum)
135
+
136
+
137
def measure_rt60(h, fs=1, decay_db=60, energy_thres=1.0, plot=False, rt60_tgt=None):
    """
    RT60 Measurement Routine (taken/modified from Pyroom acoustics.)

    Calculates reverberation time of an impulse response using the Schroeder method [1].

    Returns:
        rt60: Reverberation time to -60db (-5db to -65db); estimated as 3 * rt20 if the
              energy curve never reaches -65db
        edt: Early decay time from 0db to -10db
        rt10: Reverberation time to -10db (-5db to -15db); measured to the end of the
              energy curve if it never reaches -15db
        rt20: Reverberation time to -20db (-5db to -25db); estimated as 2 * rt10 if the
              energy curve never reaches -25db
        floor: 0 if the input is silent, the energy curve is not a decay, or it never
                 reaches -10db (all times returned as 0.0)
               1 if noise floor > -15db: edt measured, rt10 from the entire curve length
               2 if -15db > floor > -25db: rt20 estimated from measured rt10
               3 if -25db > floor > -65db: rt60 estimated from measured rt20
               4 if noise floor < -65db: rt60, edt, rt10, rt20 are all measured
    Optionally plots some useful information.

    Parameters
    ----------
    h: array_like
        The impulse response.
    fs: float or int, optional
        The sampling frequency of h (default to 1, i.e., samples).
    decay_db: float or int, optional
        Currently unused; kept for signature compatibility.
    energy_thres: float
        Fraction of the whole energy used for the fit, in (0.0, 1.0]. Values below
        1.0 subtract the remaining energy, which is useful for long noisy tails.
    plot: bool, optional
        If set to ``True``, the power decay and different estimated values will
        be plotted (default False).
    rt60_tgt: float
        Optional target RT60, drawn on the plot for comparison.

    Raises
    ------
    ValueError
        If energy_thres < 1.0 but is not strictly positive.

    References
    ----------

    [1] M. R. Schroeder, "New Method of Measuring Reverberation Time,"
    J. Acoust. Soc. Am., vol. 37, no. 3, pp. 409-412, Mar. 1968.
    """
    h = np.array(h)
    fs = float(fs)
    h = np.abs(hilbert(h))  # amplitude envelope via the analytic signal

    power = h**2
    # Backward energy integration according to Schroeder (non-increasing in time)
    energy = np.cumsum(power[::-1])[::-1]

    if energy_thres < 1.0:
        if not 0.0 < energy_thres < 1.0:
            raise ValueError(f"energy_thres must be in (0.0, 1.0], got {energy_thres}")
        energy -= energy[0] * (1.0 - energy_thres)
        energy = np.maximum(energy, 0.0)

    # Remove the possibly all-zero tail, keeping the last non-zero sample
    nonzero = np.flatnonzero(energy > 0)
    if nonzero.size == 0:
        return 0.0, 0.0, 0.0, 0.0, 0  # silent input: nothing to measure
    energy = energy[: nonzero[-1] + 1]
    energy_db = 10 * np.log10(energy)
    energy_db -= energy_db[0]  # energy[0] is the maximum of the backward integral

    def _first_below(level_db):
        """Index of the first sample below level_db, or None if it is never reached."""
        idx = np.flatnonzero(energy_db < level_db)
        return int(idx[0]) if idx.size else None

    # -5 dB headroom: start point of the decay-slope measurements
    i_5db = _first_below(-5)
    if i_5db is None:
        return 0.0, 0.0, 0.0, 0.0, 0  # not a decay, or noise floor tail above -5db
    e_5db = energy_db[i_5db]
    t_5db = i_5db / fs

    # EDT (early decay time): 0 dB to -10 dB
    i_10db = _first_below(-10)
    if i_10db is None:
        return 0.0, 0.0, 0.0, 0.0, 0  # not a decay, or noise floor tail above -10db
    edt = i_10db / fs

    # energy_db is non-increasing, so each deeper level is reached only if the
    # shallower ones were; min() keeps the shallowest floor category hit.
    floor = 4

    # RT10: -5 dB to -15 dB; use the entire curve if -15 dB is never reached
    i_decay10db = _first_below(-5 - 10)
    if i_decay10db is None:
        floor = 1
        i_decay10db = len(energy_db)
    rt10 = i_decay10db / fs - t_5db

    # RT20: -5 dB to -25 dB; extrapolate from RT10 if -25 dB is never reached
    i_decay20db = _first_below(-5 - 20)
    if i_decay20db is None:
        floor = min(floor, 2)
        rt20 = 2 * rt10
    else:
        rt20 = i_decay20db / fs - t_5db

    # RT60: -5 dB to -65 dB; extrapolate from RT20 if -65 dB is never reached
    i_decay60db = _first_below(-5 - 60)
    if i_decay60db is None:
        floor = min(floor, 3)
        rt60 = 3 * rt20
    else:
        rt60 = i_decay60db / fs - t_5db

    if plot:
        # Clip power below the minimum energy (for plotting purposes mostly)
        energy_min = energy[-1]
        energy_db_min = energy_db[-1]
        power[power < energy_min] = energy_min
        power_db = 10 * np.log10(power)
        power_db -= np.max(power_db)

        # time vector, zero at the -5 dB point
        def get_time(x, fs):
            return np.arange(x.shape[0]) / fs - i_5db / fs

        T = get_time(power_db, fs)

        # plot energy
        plt.plot(get_time(energy_db, fs), energy_db, label="Energy")

        # now the linear fit
        plt.plot([0, rt60], [e_5db, -65], "--", label="Linear Fit")
        plt.plot(T, np.ones_like(T) * -60, "--", label="-60 dB")
        plt.vlines(rt60, energy_db_min, 0, linestyles="dashed", label="Estimated RT60")

        if rt60_tgt is not None:
            plt.vlines(rt60_tgt, energy_db_min, 0, label="Target RT60")

        plt.legend()

    return rt60, edt, rt10, rt20, floor
284
+
285
+
286
+ def process_path(path: str, extensions: list[str] | None = None) -> tuple[list, str | None]:
287
+ """
288
+ Check path which can be a single file, a subdirectory, or a regex
289
+ return:
290
+ - a list of files with matching extensions to any in extlist provided (i.e. ['.wav', '.mp3', '.acc'])
291
+ - the basedir of the path, if
292
+ """
293
+ if extensions is None:
294
+ extensions = [".wav", ".WAV", ".flac", ".FLAC", ".mp3", ".aac"]
295
+
296
+ # Check if the path is a single file, and return it as a list with the dirname
297
+ if isfile(path):
298
+ if any(path.endswith(ext) for ext in extensions):
299
+ basedir = dirname(path) # base directory
300
+ if not basedir:
301
+ basedir = "./"
302
+ return [path], basedir
303
+
304
+ return [], None
305
+
306
+ # Check if the path is a dir, recursively find all files any of the specified extensions, return file list and dir
307
+ if isdir(path):
308
+ matching_files = []
309
+ for ext in extensions:
310
+ matching_files.extend(glob.glob(join(path, "**/*" + ext), recursive=True))
311
+ return matching_files, path
312
+
313
+ # Process as a regex, return list of filenames and basedir
314
+ apath = abspath(path) # join(abspath(path), "**", "*.{wav,flac,WAV}")
315
+ matching_files = []
316
+ for file in braced_iglob(pathname=apath, recursive=True):
317
+ matching_files.append(file)
318
+
319
+ if matching_files:
320
+ basedir = commonprefix(matching_files) # Find basedir
321
+ return matching_files, basedir
322
+
323
+ return [], None
324
+
325
+
326
def _process_ir(pfile: tuple[int, str], irtab_col: list, basedir: str) -> pd.DataFrame:
    """
    Compute delay, gain, and reverberation metrics for one impulse response file.

    Parameters
    ----------
    pfile: tuple[int, str]
        (row index, path to the IR audio file); the index labels the returned row.
    irtab_col: list
        Column names for the result row. There must be 13, in the order the row is built below.
    basedir: str
        Directory the reported 'irfile' name is made relative to.

    Returns
    -------
    A single-row DataFrame of scalar metrics indexed by ``pfile[0]``.
    """
    row_index, ir_fname = pfile

    # 1) Read IR audio file (first channel only) and basic stats
    irwav, sample_rate = soundfile.read(ir_fname)
    if irwav.ndim == 2:
        irwav = irwav[:, 0]
    duration = len(irwav) / sample_rate
    srk = sample_rate / 1000
    ir_basename = relpath(ir_fname, basedir)

    # 2) Delay via the peak location
    peak_index = np.argmax(irwav)
    peak_value = irwav[peak_index]
    dmax = peak_index

    # 3) Delay via cross-correlation with a 50 ms white Gaussian noise reference.
    # A private fixed-seed generator keeps results reproducible without reseeding
    # NumPy's global RNG. RandomState(42) draws the same values as
    # np.random.seed(42) followed by np.random.normal.
    rng = np.random.RandomState(42)
    wgn_ref = rng.normal(0, 0.2, int(np.ceil(0.05 * sample_rate)))  # (mean, std_dev, length)
    wgn_conv = np.convolve(irwav, wgn_ref)
    wgn_corr = np.correlate(wgn_conv, wgn_ref, mode="full")
    # abs() so a polarity-inverted response still peaks; subtract the 'full'-mode lag offset
    dcc = np.argmax(np.abs(wgn_corr)) - len(wgn_ref) + 1
    # GCC with PHAT weighting (16x interpolated); known to be robust, but it does
    # seem to mismatch dcc and dmax more frequently
    dtdoa = tdoa(wgn_conv, wgn_ref, interp=16, phat=True)
    gdccmax = np.max(np.abs(wgn_conv)) / np.max(np.abs(wgn_ref))  # gain of max value

    # 4) Average group delay in samples: -d(phase)/d(omega).
    # FFT bins are uniformly spaced by 2*pi/N rad/sample around the unit circle.
    # Using np.gradient(fftfreq) instead is wrong at the +/-0.5 wrap point.
    fft_size = len(irwav)
    H = np.fft.fft(irwav, n=fft_size)
    phase = np.unwrap(np.angle(H))
    group_delay = -np.gradient(phase) * fft_size / (2 * np.pi)
    dagd = np.mean(group_delay[np.isfinite(group_delay)])
    gmax = np.max(np.abs(H))

    # 5) Reverberation times (Schroeder method)
    rt60, _edt, _rt10, rt20, _nfloor = measure_rt60(irwav, sample_rate, plot=False)

    # 6) Tabulate metrics as a single row; order must match irtab_col
    metr1 = [dmax, dcc, dtdoa, dagd, gdccmax, rt20, rt60, peak_value, np.min(irwav), gmax, duration, srk, ir_basename]
    return pd.DataFrame([metr1], columns=irtab_col, index=[row_index])
412
+
413
+
414
def main():
    """Command-line entry point: compute IR metrics for IRLOC and write .csv/.txt reports."""
    from functools import partial

    from docopt import docopt

    from . import __version__ as sai_version
    from .utils.docstring import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)

    ir_location = args["IRLOC"]
    num_proc = args["--num_process"]

    import psutil

    from .utils.create_timestamp import create_timestamp
    from .utils.parallel import par_track
    from .utils.parallel import track

    # Find IR files; default extensions are ['.wav', '.WAV', '.flac', '.FLAC', '.mp3', '.aac']
    pfiles, basedir = process_path(ir_location)
    if not pfiles:
        print(f"No IR audio files found in {ir_location}, exiting ...")
        raise SystemExit(1)
    pfiles = sorted(pfiles, key=basename)

    if len(pfiles) == 1:
        print(f"Found single IR audio file {ir_location} , writing to *-irmetric.txt ...")
        fbase, _ = splitext(basename(pfiles[0]))
        wlcsv_name = None
        txt_fname = join(basedir, fbase + "-irmetric.txt")
    else:
        print(f"Found {len(pfiles)} files under {basedir} for impulse response metric calculations")
        wlcsv_name = join(basedir, "ir_metric_list.csv")
        txt_fname = join(basedir, "ir_metric_summary.txt")

    num_cpu = psutil.cpu_count() or 1  # cpu_count() can return None
    cpu_percent = psutil.cpu_percent(interval=1)
    print(f"#CPUs: {num_cpu}, current CPU utilization: {cpu_percent}%")
    print(f"Memory utilization: {psutil.virtual_memory().percent}%")
    if num_proc == "auto":
        # ~90% of the currently idle CPUs, but never fewer than one process
        use_cpu = max(int(num_cpu * (0.9 - cpu_percent / 100)), 1)
    elif num_proc == "None":
        use_cpu = None
    else:
        use_cpu = min(max(int(num_proc), 1), num_cpu)

    timestamp = create_timestamp()
    print(f"Calculating metrics for {len(pfiles)} impulse response files using {use_cpu} parallel processes ...")
    progress = track(total=len(pfiles))
    if use_cpu is None or len(pfiles) == 1:
        no_par = True
        num_cpus = None
    else:
        no_par = False
        num_cpus = use_cpu

    # Columns of the per-file metric row; order must match the row built in _process_ir
    irtab_col = [
        "dmax",
        "dcc",
        "dccphat",
        "dagd",
        "gdccmax",
        "rt20",
        "rt60",
        "max",
        "min",
        "gmax",
        "dur",
        "sr",
        "irfile",
    ]
    # (row index, file) pairs; the index keeps rows in basename-sorted order
    llfiles = list(enumerate(pfiles))

    all_metrics_tables = par_track(
        partial(
            _process_ir,
            irtab_col=irtab_col,
            basedir=basedir,
        ),
        llfiles,
        progress=progress,
        num_cpus=num_cpus,
        no_par=no_par,
    )
    progress.close()

    mtabsort = pd.concat(list(all_metrics_tables)).sort_values(by=["irfile"])

    # Write list to .csv (multi-file runs only)
    if wlcsv_name:
        pd.DataFrame([["Timestamp", timestamp]]).to_csv(wlcsv_name, header=False, index=False)
        pd.DataFrame([f"IR metric list for {ir_location}:"]).to_csv(wlcsv_name, mode="a", header=False, index=False)
        mtabsort.round(2).to_csv(wlcsv_name, mode="a", encoding="utf-8")

    # Write summary stats and the full list to .txt
    with open(txt_fname, "w", encoding="utf-8") as f:
        print(f"Timestamp: {timestamp}", file=f)
        print(f"IR metrics stats over {len(llfiles)} files:", file=f)
        print(mtabsort.describe().round(3).T.to_string(float_format=lambda x: f"{x:.3f}", index=True), file=f)
        print("", file=f)
        print("", file=f)
        print(f"IR metric list for {ir_location}:", file=f)
        print(mtabsort.round(3).to_string(), file=f)
541
+
542
+
543
if __name__ == "__main__":
    # Script entry point: register the keyboard-interrupt handler, then pass any
    # Exception raised by main() to SonusAI's shared exception handler.
    from sonusai import exception_handler
    from sonusai.utils.keyboard_interrupt import register_keyboard_interrupt

    register_keyboard_interrupt()
    try:
        main()
    except Exception as exc:
        exception_handler(exc)
sonusai/lsdb.py ADDED
@@ -0,0 +1,141 @@
1
+ """sonusai lsdb
2
+
3
+ usage: lsdb [-hsa] [-i MIXID] [-c CID] LOC
4
+
5
+ Options:
6
+ -h, --help
7
+ -i MIXID, --mixid MIXID Mixture ID(s) to analyze. [default: *].
8
+ -c CID, --class_index CID Analyze mixtures that contain this class index.
9
+ -s, --sources List all source files.
10
+ -a, --all_class_counts List all class counts.
11
+
12
+ List mixture data information from a SonusAI mixture database.
13
+
14
+ Inputs:
15
+ LOC A SonusAI mixture database directory.
16
+
17
+ """
18
+
19
+ from sonusai.datatypes import GeneralizedIDs
20
+ from sonusai.mixture import MixtureDatabase
21
+
22
+
23
def lsdb(
    mixdb: MixtureDatabase,
    mixids: GeneralizedIDs = "*",
    class_index: int | None = None,
    list_targets: bool = False,
    all_class_counts: bool = False,
) -> None:
    """Log summary information about a SonusAI mixture database.

    Parameters
    ----------
    mixdb: MixtureDatabase
        The database to describe.
    mixids: GeneralizedIDs
        Mixture ID(s) to analyze (default "*").
    class_index: int | None
        If given, also log the mixtures containing this class index.
    list_targets: bool
        If True, log every source file name and its class indices, grouped by category.
    all_class_counts: bool
        Request all class counts for a single mixture (currently reported as not supported).

    Raises
    ------
    ValueError
        If class_index is outside the range 1..mixdb.num_classes.
    """
    from sonusai import logger
    from sonusai.constants import SAMPLE_RATE
    from sonusai.queries.queries import get_mixids_from_class_indices
    from sonusai.utils.print_mixture_details import print_mixture_details
    from sonusai.utils.ranges import consolidate_range
    from sonusai.utils.seconds_to_hms import seconds_to_hms

    desc_len = 24

    total_samples = mixdb.total_samples()
    total_duration = total_samples / SAMPLE_RATE

    logger.info(f"{'Mixtures':{desc_len}} {mixdb.num_mixtures}")
    logger.info(f"{'Duration':{desc_len}} {seconds_to_hms(seconds=total_duration)}")
    logger.info(f"{'Sources':{desc_len}} {mixdb.num_source_files}")
    logger.info(f"{'Feature':{desc_len}} {mixdb.feature}")
    logger.info(
        f"{'Feature shape':{desc_len}} {mixdb.fg_stride} x {mixdb.feature_parameters} "
        f"({mixdb.fg_stride * mixdb.feature_parameters} total params)"
    )
    logger.info(f"{'Feature samples':{desc_len}} {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)")
    logger.info(
        f"{'Feature step samples':{desc_len}} {mixdb.feature_step_samples} samples ({mixdb.feature_step_ms} ms)"
    )
    logger.info(f"{'Feature overlap':{desc_len}} {mixdb.fg_step / mixdb.fg_stride} ({mixdb.feature_step_ms} ms)")
    logger.info(f"{'SNRs':{desc_len}} {mixdb.snrs}")
    logger.info(f"{'Random SNRs':{desc_len}} {mixdb.random_snrs}")
    logger.info(f"{'Classes':{desc_len}} {mixdb.num_classes}")
    # TODO: fix class count
    logger.info(f"{'Class count':{desc_len}} not supported")
    # TODO: add class weight calculations here
    logger.info("")

    if list_targets:
        logger.info("Source details:")
        for category, sources in mixdb.source_files.items():
            # logger (not print) so the category also reaches the log file
            logger.info(f" {category}:")
            for source in sources:
                logger.info(f"{' Name':{desc_len}} {source.name}")
                logger.info(f"{' Truth index':{desc_len}} {source.class_indices}")
        logger.info("")

    if class_index is not None:
        # The original chained test `0 <= class_index > n` only rejected values above n
        if not 1 <= class_index <= mixdb.num_classes:
            raise ValueError(f"Given class_index is outside valid range of 1-{mixdb.num_classes}")
        ids = get_mixids_from_class_indices(mixdb=mixdb, predicate=lambda x: x in [class_index])[class_index]
        logger.info(f"Mixtures with class index {class_index}: {ids}")
        logger.info("")

    mixids = mixdb.mixids_to_list(mixids)

    if len(mixids) == 1:
        print_mixture_details(mixdb=mixdb, mixid=mixids[0], print_fn=logger.info)
        if all_class_counts:
            # TODO: fix class count
            logger.info("All class count not supported")
    else:
        logger.info(
            f"Calculating statistics from truth_f files for {len(mixids):,} mixtures ({consolidate_range(mixids)})"
        )
        logger.info("Not supported")
93
+
94
+
95
def main() -> None:
    """Command-line entry point: parse arguments and list the mixture database at LOC."""
    from docopt import docopt

    from sonusai import __version__ as sai_version
    from sonusai.utils.docstring import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)

    mixid = args["--mixid"]
    class_index = args["--class_index"]
    # The usage string defines '-s, --sources'; docopt has no '--targets' key,
    # so reading args["--targets"] raised KeyError on every run.
    list_targets = args["--sources"]
    all_class_counts = args["--all_class_counts"]
    location = args["LOC"]

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import logger
    from sonusai import update_console_handler

    if class_index is not None:
        class_index = int(class_index)

    create_file_handler("lsdb.log")
    update_console_handler(False)
    initial_log_messages("lsdb")

    logger.info(f"Analyzing {location}")

    mixdb = MixtureDatabase(location)
    lsdb(
        mixdb=mixdb,
        mixids=mixid,
        class_index=class_index,
        list_targets=list_targets,
        all_class_counts=all_class_counts,
    )
131
+
132
+
133
if __name__ == "__main__":
    # Script entry point: register the keyboard-interrupt handler, then pass any
    # Exception raised by main() to SonusAI's shared exception handler.
    from sonusai import exception_handler
    from sonusai.utils.keyboard_interrupt import register_keyboard_interrupt

    register_keyboard_interrupt()
    try:
        main()
    except Exception as exc:
        exception_handler(exc)