sonusai 0.19.9__py3-none-any.whl → 0.20.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. sonusai/calc_metric_spenh.py +265 -233
  2. sonusai/data/genmixdb.yml +4 -2
  3. sonusai/data/silero_vad_v5.1.jit +0 -0
  4. sonusai/data/silero_vad_v5.1.onnx +0 -0
  5. sonusai/doc/doc.py +14 -0
  6. sonusai/genft.py +1 -1
  7. sonusai/genmetrics.py +15 -18
  8. sonusai/genmix.py +1 -1
  9. sonusai/genmixdb.py +30 -52
  10. sonusai/ir_metric.py +555 -0
  11. sonusai/metrics_summary.py +322 -0
  12. sonusai/mixture/__init__.py +6 -2
  13. sonusai/mixture/audio.py +139 -15
  14. sonusai/mixture/augmentation.py +199 -84
  15. sonusai/mixture/config.py +9 -4
  16. sonusai/mixture/constants.py +0 -1
  17. sonusai/mixture/datatypes.py +19 -10
  18. sonusai/mixture/generation.py +52 -64
  19. sonusai/mixture/helpers.py +38 -26
  20. sonusai/mixture/ir_delay.py +63 -0
  21. sonusai/mixture/mixdb.py +190 -46
  22. sonusai/mixture/targets.py +3 -6
  23. sonusai/mixture/truth_functions/energy.py +9 -5
  24. sonusai/mixture/truth_functions/metrics.py +1 -1
  25. sonusai/mkwav.py +1 -1
  26. sonusai/onnx_predict.py +1 -1
  27. sonusai/queries/queries.py +1 -1
  28. sonusai/utils/__init__.py +2 -0
  29. sonusai/utils/asr.py +1 -1
  30. sonusai/utils/load_object.py +8 -2
  31. sonusai/utils/stratified_shuffle_split.py +1 -1
  32. sonusai/utils/temp_seed.py +13 -0
  33. {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/METADATA +2 -2
  34. {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/RECORD +36 -35
  35. {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/WHEEL +1 -1
  36. sonusai/mixture/soundfile_audio.py +0 -130
  37. sonusai/mixture/sox_audio.py +0 -476
  38. sonusai/mixture/sox_augmentation.py +0 -136
  39. sonusai/mixture/torchaudio_audio.py +0 -106
  40. sonusai/mixture/torchaudio_augmentation.py +0 -109
  41. {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/entry_points.txt +0 -0
sonusai/metrics_summary.py ADDED
@@ -0,0 +1,322 @@
+ """sonusai metrics_summary
+
+ usage: metrics_summary [-vlh] [-i MIXID] [-n NCPU] LOCATION
+
+ Options:
+     -h, --help
+     -v, --verbose
+     -l, --write-list            Write .csv file list of all mixture metrics
+     -i MIXID, --mixid MIXID     Mixture ID(s) to analyze. [default: *].
+     -n, --num_process NCPU      Number of parallel processes to use [default: auto]
+
+ Summarize mixture metrics across a SonusAI mixture database where metrics have been generated by SonusAI genmetrics.
+
+ Inputs:
+     LOCATION    A SonusAI mixture database directory with mixdb.db and pre-generated metrics from SonusAI genmetrics.
+
+ """
+
+ import signal
+
+ import numpy as np
+ import pandas as pd
+
+
+ def signal_handler(_sig, _frame):
+     import sys
+
+     from sonusai import logger
+
+     logger.info("Canceled due to keyboard interrupt")
+     sys.exit(1)
+
+
+ signal.signal(signal.SIGINT, signal_handler)
+
+ DB_99 = np.power(10, 99 / 10)
+ DB_N99 = np.power(10, -99 / 10)
+
+
+ def _process_mixture(
+     m_id: int,
+     location: str,
+     all_metric_names: list[str],
+     scalar_metric_names: list[str],
+     string_metric_names: list[str],
+     frame_metric_names: list[str],
+     bin_metric_names: list[str],
+     ptab_labels: list[str],
+ ) -> tuple[pd.DataFrame, pd.DataFrame]:
+     from os.path import basename
+
+     from sonusai.metrics import calc_wer
+     from sonusai.mixture import SAMPLE_RATE
+     from sonusai.mixture import MixtureDatabase
+
+     mixdb = MixtureDatabase(location)
+
+     # Process mixture
+     # for mixid in mixids:
+     samples = mixdb.mixture(m_id).samples
+     duration = samples / SAMPLE_RATE
+     tf_frames = mixdb.mixture_transform_frames(m_id)
+     feat_frames = mixdb.mixture_feature_frames(m_id)
+     mxsnr = mixdb.mixture(m_id).snr
+     ti = mixdb.mixture(m_id).targets[0].file_id
+     ni = mixdb.mixture(m_id).noise.file_id
+     t0file = basename(mixdb.target_file(ti).name)
+     nfile = basename(mixdb.noise_file(ni).name)
+
+     all_metrics = mixdb.mixture_metrics(m_id, all_metric_names)
+
+     # replace lists with first value (ignore mixup)
+     scalar_metrics = {
+         key: all_metrics[key][0] if isinstance(all_metrics[key], list) else all_metrics[key]
+         for key in scalar_metric_names
+     }
+     string_metrics = {
+         key: all_metrics[key][0] if isinstance(all_metrics[key], list) else all_metrics[key]
+         for key in string_metric_names
+     }
+
+     # Convert strings into word count
+     for key in string_metrics:
+         string_metrics[key] = calc_wer(string_metrics[key], string_metrics[key]).words
+
+     # Collect pandas table values note: must match given ptab_labels
+     ptab_data: list = [
+         mxsnr,
+         *scalar_metrics.values(),
+         *string_metrics.values(),
+         tf_frames,
+         duration,
+         t0file,
+         nfile,
+     ]
+
+     ptab1 = pd.DataFrame([ptab_data], columns=ptab_labels, index=[m_id])
+
+     # TODO: collect frame metrics and bin metrics
+
+     return ptab1, ptab1
+
+
+ def main() -> None:
+     from docopt import docopt
+
+     from sonusai import __version__ as sonusai_ver
+     from sonusai.utils import trim_docstring
+
+     args = docopt(trim_docstring(__doc__), version=sonusai_ver, options_first=True)
+
+     verbose = args["--verbose"]
+     wrlist = args["--write-list"]
+     mixids = args["--mixid"]
+     location = args["LOCATION"]
+     num_proc = args["--num_process"]
+
+     from functools import partial
+     from os.path import basename
+     from os.path import join
+
+     import psutil
+
+     from sonusai import create_file_handler
+     from sonusai import initial_log_messages
+     from sonusai import logger
+     from sonusai import update_console_handler
+     from sonusai.mixture import MixtureDatabase
+     from sonusai.utils import create_timestamp
+     from sonusai.utils import par_track
+     from sonusai.utils import track
+
+     try:
+         mixdb = MixtureDatabase(location)
+         print(f"Found SonusAI mixture database with {mixdb.num_mixtures} mixtures.")
+     except Exception:
+         print(f"Could not open SonusAI mixture database in {location}, exiting ...")
+         return
+
+     # Only check first and last mixture in order to save time
+     metrics_present = mixdb.cached_metrics([0, mixdb.num_mixtures - 1])
+
+     num_metrics_present = len(metrics_present)
+     if num_metrics_present < 1:
+         print(f"mixdb reports no pre-generated metrics are present. Nothing to summarize in {location}, exiting ...")
+         return
+
+     # Setup logging file
+     timestamp = create_timestamp()  # string good for embedding into filenames
+     mixdb_fname = basename(location)
+     if verbose:
+         create_file_handler(join(location, "metrics_summary.log"))
+         update_console_handler(verbose)
+         initial_log_messages("metrics_summary")
+         logger.info(f"Logging summary of SonusAI mixture database at {location}")
+     else:
+         update_console_handler(verbose)
+
+     logger.info("")
+     mixids = mixdb.mixids_to_list(mixids)
+     if len(mixids) < mixdb.num_mixtures:
+         logger.info(
+             f"Processing a subset of {len(mixids)} out of total mixdb mixtures of {mixdb.num_mixtures}, "
+             f"summary results will not include entire dataset."
+         )
+         fsuffix = f"_s{len(mixids)}t{mixdb.num_mixtures}"
+     else:
+         logger.info(
+             f"Summarizing SonusAI mixture database with {mixdb.num_mixtures} mixtures "
+             f"and {num_metrics_present} pre-generated metrics ..."
+         )
+         fsuffix = ""
+
+     metric_sup = mixdb.supported_metrics
+     ft_bins = mixdb.ft_config.bin_end - mixdb.ft_config.bin_start + 1  # bins of forward transform
+     # Pre-process first mixid to gather metrics into 4 types: scalar, str (scalar word cnt), frame-array, bin-array
+     # Collect list of indices for each
+     scalar_metric_names: list[str] = []
+     string_metric_names: list[str] = []
+     frame_metric_names: list[str] = []
+     bin_metric_names: list[str] = []
+     all_metrics = mixdb.mixture_metrics(mixids[0], metrics_present)
+     tf_frames = mixdb.mixture_transform_frames(mixids[0])
+     for metric in metrics_present:
+         metval = all_metrics[metric]  # get metric value
+         logger.debug(f"First mixid {mixids[0]} metric {metric} = {metval}")
+         if isinstance(metval, list):
+             if len(metval) > 1:
+                 logger.warning(f"Mixid {mixids[0]} metric {metric} has a list with more than 1 element, using first.")
+             metval = metval[0]  # remove any list
+         if isinstance(metval, float):
+             logger.debug("Metric is scalar float, entering in summary table.")
+             scalar_metric_names.append(metric)
+         elif isinstance(metval, str):
+             logger.debug("Metric is string, will summarize with word count.")
+             string_metric_names.append(metric)
+         elif isinstance(metval, np.ndarray):
+             if metval.ndim == 1:
+                 if metval.size == tf_frames:
+                     logger.debug("Metric is frames vector.")
+                     frame_metric_names.append(metric)
+                 elif metval.size == ft_bins:
+                     logger.debug("Metric is bins vector.")
+                     bin_metric_names.append(metric)
+                 else:
+                     logger.warning(f"Mixid {mixids[0]} metric {metric} is a vector of improper size, ignoring.")
+
+     # Setup pandas table for summarizing scalar metrics
+     ptab_labels = [
+         "mxsnr",
+         *scalar_metric_names,
+         *string_metric_names,
+         "fcnt",
+         "duration",
+         "t0file",
+         "nfile",
+     ]
+
+     num_cpu = psutil.cpu_count()
+     cpu_percent = psutil.cpu_percent(interval=1)
+     logger.info("")
+     logger.info(f"#CPUs: {num_cpu}, current CPU utilization: {cpu_percent}%")
+     logger.info(f"Memory utilization: {psutil.virtual_memory().percent}%")
+     if num_proc == "auto":
+         use_cpu = int(num_cpu * (0.9 - cpu_percent / 100))  # default: ~90% of CPUs, reduced by current utilization
+     elif num_proc == "None":
+         use_cpu = None
+     else:
+         use_cpu = min(max(int(num_proc), 1), num_cpu)
+
+     logger.info(f"Summarizing metrics for {len(mixids)} mixtures using {use_cpu} parallel processes")
+
+     # progress = tqdm(total=len(mixids), desc='calc_metric_spenh', mininterval=1)
+     progress = track(total=len(mixids))
+     if use_cpu is None:
+         no_par = True
+         num_cpus = None
+     else:
+         no_par = False
+         num_cpus = use_cpu
+
+     all_metrics_tables = par_track(
+         partial(
+             _process_mixture,
+             location=location,
+             all_metric_names=metrics_present,
+             scalar_metric_names=scalar_metric_names,
+             string_metric_names=string_metric_names,
+             frame_metric_names=frame_metric_names,
+             bin_metric_names=bin_metric_names,
+             ptab_labels=ptab_labels,
+         ),
+         mixids,
+         progress=progress,
+         num_cpus=num_cpus,
+         no_par=no_par,
+     )
+     progress.close()
+
+     # Done with mixtures, write out summary metrics
+     header_args = {
+         "mode": "a",
+         "encoding": "utf-8",
+         "index": False,
+         "header": False,
+     }
+     table_args = {
+         "mode": "a",
+         "encoding": "utf-8",
+     }
+     ptab1 = pd.concat([item[0] for item in all_metrics_tables])
+     if wrlist:
+         wlcsv_name = str(join(location, "metric_summary_list" + fsuffix + ".csv"))
+         pd.DataFrame([["Timestamp", timestamp]]).to_csv(wlcsv_name, header=False, index=False)
+         pd.DataFrame([f"Metric list for {mixdb_fname}:"]).to_csv(wlcsv_name, mode="a", header=False, index=False)
+         ptab1.round(2).to_csv(wlcsv_name, **table_args)
+     ptab1_sorted = ptab1.sort_values(by=["mxsnr", "t0file"])
+
+     # Create metrics table excluding -99 SNR
+     ptab1_nom99 = ptab1_sorted[ptab1_sorted.mxsnr != -99]
+
+     # Create summary by SNR for all scalar metrics, taking mean
+     mtab_snr_summary = None
+     for snri in range(0, len(mixdb.snrs)):
+         tmp = ptab1_sorted.query("mxsnr==" + str(mixdb.snrs[snri])).mean(numeric_only=True).to_frame().T
+         # avoid nan when subset of mixids specified (i.e. no mixtures exist for an SNR)
+         if ~np.isnan(tmp.iloc[0].to_numpy()[0]).any():
+             mtab_snr_summary = pd.concat([mtab_snr_summary, tmp])
+     mtab_snr_summary = mtab_snr_summary.sort_values(by=["mxsnr"], ascending=False)
+
+     # Write summary to .csv
+     snrcsv_name = str(join(location, "metric_summary_snr" + fsuffix + ".csv"))
+     nmix = len(mixids)
+     nmixtot = mixdb.num_mixtures
+     pd.DataFrame([["Timestamp", timestamp]]).to_csv(snrcsv_name, header=False, index=False)
+     pd.DataFrame(['"Metrics avg over each SNR:"']).to_csv(snrcsv_name, **header_args)
+     mtab_snr_summary.round(2).to_csv(snrcsv_name, index=False, **table_args)
+     pd.DataFrame(["--"]).to_csv(snrcsv_name, header=False, index=False, mode="a")
+     pd.DataFrame([f'"Metrics stats over {nmix} mixtures out of {nmixtot} total:"']).to_csv(snrcsv_name, **header_args)
+     ptab1.describe().round(2).T.to_csv(snrcsv_name, index=True, **table_args)
+     pd.DataFrame(["--"]).to_csv(snrcsv_name, header=False, index=False, mode="a")
+     pd.DataFrame([f'"Metrics stats over {len(ptab1_nom99)} non -99db mixtures out of {nmixtot} total:"']).to_csv(
+         snrcsv_name, **header_args
+     )
+     ptab1_nom99.describe().round(2).T.to_csv(snrcsv_name, index=True, **table_args)
+
+     # Write summary to .txt
+     snrtxt_name = str(join(location, "metric_summary_snr" + fsuffix + ".txt"))
+     with open(snrtxt_name, "w") as f:
+         print(f"Timestamp: {timestamp}", file=f)
+         print("Metrics avg over each SNR:", file=f)
+         print(mtab_snr_summary.round(2).to_string(float_format=lambda x: f"{x:.2f}", index=False), file=f)
+         print("", file=f)
+         print(f"Metrics stats over {len(mixids)} mixtures out of {mixdb.num_mixtures} total:", file=f)
+         print(ptab1.describe().round(2).T.to_string(float_format=lambda x: f"{x:.2f}", index=True), file=f)
+         print("", file=f)
+         print(f"Metrics stats over {len(ptab1_nom99)} non -99db mixtures out of {mixdb.num_mixtures} total:", file=f)
+         print(ptab1_nom99.describe().round(2).T.to_string(float_format=lambda x: f"{x:.2f}", index=True), file=f)
+
+
+ if __name__ == "__main__":
+     main()
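
For orientation, here is a minimal sketch of the per-mixture query pattern the new metrics_summary script is built around, using only APIs that appear in this diff (MixtureDatabase, cached_metrics, mixture_metrics); the database path and mixture ID are illustrative and not part of the package.

    # Illustrative only: inspect the cached metrics for one mixture the same way
    # metrics_summary does, assuming genmetrics has already been run on the database.
    from sonusai.mixture import MixtureDatabase

    mixdb = MixtureDatabase("./my_mixdb")  # hypothetical database location
    names = mixdb.cached_metrics([0, mixdb.num_mixtures - 1])  # metrics present in the cache
    values = mixdb.mixture_metrics(0, names)  # dict of metric name -> scalar, string, or array
    for name in names:
        print(name, values[name])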
sonusai/mixture/__init__.py CHANGED
@@ -5,6 +5,7 @@ from .audio import get_duration
  from .audio import get_next_noise
  from .audio import get_num_samples
  from .audio import get_sample_rate
+ from .audio import raw_read_audio
  from .audio import read_audio
  from .audio import read_ir
  from .audio import validate_input_file
@@ -53,7 +54,9 @@ from .datatypes import AudioF
  from .datatypes import AudioStatsMetrics
  from .datatypes import AudioT
  from .datatypes import Augmentation
+ from .datatypes import AugmentationEffects
  from .datatypes import AugmentationRule
+ from .datatypes import AugmentationRuleEffects
  from .datatypes import AugmentedTarget
  from .datatypes import ClassCount
  from .datatypes import EnergyF
@@ -87,6 +90,7 @@ from .datatypes import TruthParameter
  from .datatypes import UniversalSNR
  from .feature import get_audio_from_feature
  from .feature import get_feature_from_audio
+ from .generation import generate_mixtures
  from .generation import get_all_snrs_from_config
  from .generation import initialize_db
  from .generation import populate_class_label_table
@@ -99,7 +103,7 @@ from .generation import populate_target_file_table
  from .generation import populate_top_table
  from .generation import populate_truth_parameters_table
  from .generation import update_mixid_width
- from .generation import update_mixture_table
+ from .generation import update_mixture
  from .helpers import augmented_noise_samples
  from .helpers import augmented_target_samples
  from .helpers import check_audio_files_exist
@@ -110,10 +114,10 @@ from .helpers import get_transform_from_audio
  from .helpers import inverse_transform
  from .helpers import mixture_metadata
  from .helpers import write_mixture_metadata
+ from .ir_delay import get_impulse_response_delay
  from .log_duration_and_sizes import log_duration_and_sizes
  from .mixdb import MixtureDatabase
  from .mixdb import db_file
- from .sox_audio import Transformer
  from .spectral_mask import apply_spectral_mask
  from .target_class_balancing import balance_targets
  from .targets import get_augmented_target_ids_by_class
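
The re-export changes above rename update_mixture_table to update_mixture, add raw_read_audio, generate_mixtures, get_impulse_response_delay, and the AugmentationEffects/AugmentationRuleEffects datatypes, and drop the sox-based Transformer (consistent with the removal of sox_audio.py in the file list). A purely illustrative sketch of the updated imports:

    # Illustrative only; names are taken from the __init__.py diff above.
    from sonusai.mixture import AugmentationEffects
    from sonusai.mixture import AugmentationRuleEffects
    from sonusai.mixture import generate_mixtures
    from sonusai.mixture import get_impulse_response_delay
    from sonusai.mixture import raw_read_audio
    from sonusai.mixture import update_mixture  # formerly update_mixture_table
    # from sonusai.mixture import Transformer   # no longer available; sox_audio.py was removed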
sonusai/mixture/audio.py CHANGED
@@ -44,49 +44,173 @@ def validate_input_file(input_filepath: str | Path) -> None:
          raise OSError(f"This installation cannot process .{ext} files")


- @lru_cache
- def get_sample_rate(name: str | Path) -> int:
+ def get_sample_rate(name: str | Path, use_cache: bool = True) -> int:
      """Get sample rate from audio file

      :param name: File name
+     :param use_cache: If true, use LRU caching
      :return: Sample rate
      """
-     from .soundfile_audio import get_sample_rate
-
-     return get_sample_rate(name)
+     if use_cache:
+         return _get_sample_rate(name)
+     return _get_sample_rate.__wrapped__(name)


  @lru_cache
- def read_audio(name: str | Path) -> AudioT:
+ def _get_sample_rate(name: str | Path) -> int:
+     """Get sample rate from audio file using soundfile
+
+     :param name: File name
+     :return: Sample rate
+     """
+     import soundfile
+     from pydub import AudioSegment
+
+     from .tokenized_shell_vars import tokenized_expand
+
+     expanded_name, _ = tokenized_expand(name)
+
+     try:
+         if expanded_name.endswith(".mp3"):
+             return AudioSegment.from_mp3(expanded_name).frame_rate
+
+         if expanded_name.endswith(".m4a"):
+             return AudioSegment.from_file(expanded_name).frame_rate
+
+         return soundfile.info(expanded_name).samplerate
+     except Exception as e:
+         if name != expanded_name:
+             raise OSError(f"Error reading {name} (expanded: {expanded_name}): {e}") from e
+         else:
+             raise OSError(f"Error reading {name}: {e}") from e
+
+
+ def raw_read_audio(name: str | Path) -> tuple[AudioT, int]:
+     import numpy as np
+     import soundfile
+     from pydub import AudioSegment
+
+     from .tokenized_shell_vars import tokenized_expand
+
+     expanded_name, _ = tokenized_expand(name)
+
+     try:
+         if expanded_name.endswith(".mp3"):
+             sound = AudioSegment.from_mp3(expanded_name)
+             raw = np.array(sound.get_array_of_samples()).astype(np.float32).reshape((-1, sound.channels))
+             raw = raw / 2 ** (sound.sample_width * 8 - 1)
+             sample_rate = sound.frame_rate
+         elif expanded_name.endswith(".m4a"):
+             sound = AudioSegment.from_file(expanded_name)
+             raw = np.array(sound.get_array_of_samples()).astype(np.float32).reshape((-1, sound.channels))
+             raw = raw / 2 ** (sound.sample_width * 8 - 1)
+             sample_rate = sound.frame_rate
+         else:
+             raw, sample_rate = soundfile.read(expanded_name, always_2d=True, dtype="float32")
+     except Exception as e:
+         if name != expanded_name:
+             raise OSError(f"Error reading {name} (expanded: {expanded_name}): {e}") from e
+         else:
+             raise OSError(f"Error reading {name}: {e}") from e
+
+     return np.squeeze(raw[:, 0].astype(np.float32)), sample_rate
+
+
+ def read_audio(name: str | Path, use_cache: bool = True) -> AudioT:
      """Read audio data from a file

      :param name: File name
+     :param use_cache: If true, use LRU caching
      :return: Array of time domain audio data
      """
-     from .soundfile_audio import read_audio
-
-     return read_audio(name)
+     if use_cache:
+         return _read_audio(name)
+     return _read_audio.__wrapped__(name)


  @lru_cache
- def read_ir(name: str | Path) -> ImpulseResponseData:
+ def _read_audio(name: str | Path) -> AudioT:
+     """Read audio data from a file using soundfile
+
+     :param name: File name
+     :return: Array of time domain audio data
+     """
+     import librosa
+
+     from .constants import SAMPLE_RATE
+
+     out, sample_rate = raw_read_audio(name)
+     out = librosa.resample(out, orig_sr=sample_rate, target_sr=SAMPLE_RATE, res_type="soxr_hq")
+
+     return out
+
+
+ def read_ir(name: str | Path, delay: int, use_cache: bool = True) -> ImpulseResponseData:
      """Read impulse response data

      :param name: File name
+     :param delay: Delay in samples
+     :param use_cache: If true, use LRU caching
+     :return: ImpulseResponseData object
+     """
+     if use_cache:
+         return _read_ir(name, delay)
+     return _read_ir.__wrapped__(name, delay)
+
+
+ @lru_cache
+ def _read_ir(name: str | Path, delay: int) -> ImpulseResponseData:
+     """Read impulse response data using soundfile
+
+     :param name: File name
+     :param delay: Delay in samples
      :return: ImpulseResponseData object
      """
-     from .soundfile_audio import read_ir
+     out, sample_rate = raw_read_audio(name)
+
+     return ImpulseResponseData(data=out, sample_rate=sample_rate, delay=delay)
+
+
+ def get_num_samples(name: str | Path, use_cache: bool = True) -> int:
+     """Get the number of samples resampled to the SonusAI sample rate in the given file

-     return read_ir(name)
+     :param name: File name
+     :param use_cache: If true, use LRU caching
+     :return: number of samples in resampled audio
+     """
+     if use_cache:
+         return _get_num_samples(name)
+     return _get_num_samples.__wrapped__(name)


  @lru_cache
- def get_num_samples(name: str | Path) -> int:
+ def _get_num_samples(name: str | Path) -> int:
      """Get the number of samples resampled to the SonusAI sample rate in the given file

      :param name: File name
      :return: number of samples in resampled audio
      """
-     from .soundfile_audio import get_num_samples
+     import math

-     return get_num_samples(name)
+     import soundfile
+     from pydub import AudioSegment
+
+     from .constants import SAMPLE_RATE
+     from .tokenized_shell_vars import tokenized_expand
+
+     expanded_name, _ = tokenized_expand(name)
+
+     if expanded_name.endswith(".mp3"):
+         sound = AudioSegment.from_mp3(expanded_name)
+         samples = sound.frame_count()
+         sample_rate = sound.frame_rate
+     elif expanded_name.endswith(".m4a"):
+         sound = AudioSegment.from_file(expanded_name)
+         samples = sound.frame_count()
+         sample_rate = sound.frame_rate
+     else:
+         info = soundfile.info(name)
+         samples = info.frames
+         sample_rate = info.samplerate
+
+     return math.ceil(SAMPLE_RATE * samples / sample_rate)
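
Taken together, the audio.py changes fold the former soundfile backend into this module, add a use_cache flag to the lru_cache-wrapped readers, and make the impulse-response delay an explicit argument to read_ir. A minimal usage sketch under those assumptions follows; the file paths and literal delay are illustrative, and in practice the delay would come from the new get_impulse_response_delay helper in ir_delay.py.

    # Illustrative usage of the reworked readers; paths and the delay value are made up.
    from sonusai.mixture import raw_read_audio, read_audio, read_ir

    audio = read_audio("speech.wav")                        # resampled to the SonusAI rate, cached
    audio2 = read_audio("speech.wav", use_cache=False)      # same result, bypassing the LRU cache
    raw, sr = raw_read_audio("speech.wav")                  # first channel at the file's native rate
    ir = read_ir("room_ir.wav", delay=42, use_cache=True)   # ImpulseResponseData carrying the given delay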