sonusai 0.20.3__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. sonusai/__init__.py +16 -3
  2. sonusai/audiofe.py +241 -77
  3. sonusai/calc_metric_spenh.py +71 -73
  4. sonusai/config/__init__.py +3 -0
  5. sonusai/config/config.py +61 -0
  6. sonusai/config/config.yml +20 -0
  7. sonusai/config/constants.py +8 -0
  8. sonusai/constants.py +11 -0
  9. sonusai/data/genmixdb.yml +21 -36
  10. sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
  11. sonusai/deprecated/plot.py +4 -5
  12. sonusai/doc/doc.py +4 -4
  13. sonusai/doc.py +11 -4
  14. sonusai/genft.py +43 -45
  15. sonusai/genmetrics.py +25 -19
  16. sonusai/genmix.py +54 -82
  17. sonusai/genmixdb.py +88 -264
  18. sonusai/ir_metric.py +30 -34
  19. sonusai/lsdb.py +41 -48
  20. sonusai/main.py +15 -22
  21. sonusai/metrics/calc_audio_stats.py +4 -293
  22. sonusai/metrics/calc_class_weights.py +4 -4
  23. sonusai/metrics/calc_optimal_thresholds.py +8 -5
  24. sonusai/metrics/calc_pesq.py +2 -2
  25. sonusai/metrics/calc_segsnr_f.py +4 -4
  26. sonusai/metrics/calc_speech.py +25 -13
  27. sonusai/metrics/class_summary.py +7 -7
  28. sonusai/metrics/confusion_matrix_summary.py +5 -5
  29. sonusai/metrics/one_hot.py +4 -4
  30. sonusai/metrics/snr_summary.py +7 -7
  31. sonusai/metrics_summary.py +38 -45
  32. sonusai/mixture/__init__.py +4 -104
  33. sonusai/mixture/audio.py +10 -39
  34. sonusai/mixture/class_balancing.py +103 -0
  35. sonusai/mixture/config.py +251 -271
  36. sonusai/mixture/constants.py +35 -39
  37. sonusai/mixture/data_io.py +25 -36
  38. sonusai/mixture/db_datatypes.py +58 -22
  39. sonusai/mixture/effects.py +386 -0
  40. sonusai/mixture/feature.py +7 -11
  41. sonusai/mixture/generation.py +478 -628
  42. sonusai/mixture/helpers.py +82 -184
  43. sonusai/mixture/ir_delay.py +3 -4
  44. sonusai/mixture/ir_effects.py +77 -0
  45. sonusai/mixture/log_duration_and_sizes.py +6 -12
  46. sonusai/mixture/mixdb.py +910 -729
  47. sonusai/mixture/pad_audio.py +35 -0
  48. sonusai/mixture/resample.py +7 -0
  49. sonusai/mixture/sox_effects.py +195 -0
  50. sonusai/mixture/sox_help.py +650 -0
  51. sonusai/mixture/spectral_mask.py +2 -2
  52. sonusai/mixture/truth.py +17 -15
  53. sonusai/mixture/truth_functions/crm.py +12 -12
  54. sonusai/mixture/truth_functions/energy.py +22 -22
  55. sonusai/mixture/truth_functions/file.py +5 -5
  56. sonusai/mixture/truth_functions/metadata.py +4 -4
  57. sonusai/mixture/truth_functions/metrics.py +4 -4
  58. sonusai/mixture/truth_functions/phoneme.py +3 -3
  59. sonusai/mixture/truth_functions/sed.py +11 -13
  60. sonusai/mixture/truth_functions/target.py +10 -10
  61. sonusai/mkwav.py +26 -29
  62. sonusai/onnx_predict.py +240 -88
  63. sonusai/queries/__init__.py +2 -2
  64. sonusai/queries/queries.py +38 -34
  65. sonusai/speech/librispeech.py +1 -1
  66. sonusai/speech/mcgill.py +1 -1
  67. sonusai/speech/timit.py +2 -2
  68. sonusai/summarize_metric_spenh.py +10 -17
  69. sonusai/utils/__init__.py +7 -1
  70. sonusai/utils/asl_p56.py +2 -2
  71. sonusai/utils/asr.py +2 -2
  72. sonusai/utils/asr_functions/aaware_whisper.py +4 -5
  73. sonusai/utils/choice.py +31 -0
  74. sonusai/utils/compress.py +1 -1
  75. sonusai/utils/dataclass_from_dict.py +19 -1
  76. sonusai/utils/energy_f.py +3 -3
  77. sonusai/utils/evaluate_random_rule.py +15 -0
  78. sonusai/utils/keyboard_interrupt.py +12 -0
  79. sonusai/utils/onnx_utils.py +3 -17
  80. sonusai/utils/print_mixture_details.py +21 -19
  81. sonusai/utils/{temp_seed.py → rand.py} +3 -3
  82. sonusai/utils/read_predict_data.py +2 -2
  83. sonusai/utils/reshape.py +3 -3
  84. sonusai/utils/stratified_shuffle_split.py +3 -3
  85. sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
  86. sonusai/utils/write_audio.py +2 -2
  87. sonusai/vars.py +11 -4
  88. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/METADATA +4 -2
  89. sonusai-1.0.2.dist-info/RECORD +138 -0
  90. sonusai/mixture/augmentation.py +0 -444
  91. sonusai/mixture/class_count.py +0 -15
  92. sonusai/mixture/eq_rule_is_valid.py +0 -45
  93. sonusai/mixture/target_class_balancing.py +0 -107
  94. sonusai/mixture/targets.py +0 -175
  95. sonusai-0.20.3.dist-info/RECORD +0 -128
  96. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/WHEEL +0 -0
  97. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/entry_points.txt +0 -0
@@ -1,23 +1,20 @@
1
1
  from pyaaware import ForwardTransform
2
2
  from pyaaware import InverseTransform
3
3
 
4
- from sonusai.mixture.datatypes import AudioF
5
- from sonusai.mixture.datatypes import AudioT
6
- from sonusai.mixture.datatypes import Augmentation
7
- from sonusai.mixture.datatypes import AugmentationRule
8
- from sonusai.mixture.datatypes import EnergyT
9
- from sonusai.mixture.datatypes import FeatureGeneratorConfig
10
- from sonusai.mixture.datatypes import FeatureGeneratorInfo
11
- from sonusai.mixture.datatypes import GeneralizedIDs
12
- from sonusai.mixture.datatypes import Mixture
13
- from sonusai.mixture.datatypes import NoiseFile
14
- from sonusai.mixture.datatypes import SpeechMetadata
15
- from sonusai.mixture.datatypes import Target
16
- from sonusai.mixture.datatypes import TargetFile
17
- from sonusai.mixture.datatypes import TransformConfig
18
- from sonusai.mixture.db_datatypes import MixtureRecord
19
- from sonusai.mixture.db_datatypes import TargetRecord
20
- from sonusai.mixture.mixdb import MixtureDatabase
4
+ from ..datatypes import AudioF
5
+ from ..datatypes import AudioT
6
+ from ..datatypes import EnergyT
7
+ from ..datatypes import FeatureGeneratorConfig
8
+ from ..datatypes import FeatureGeneratorInfo
9
+ from ..datatypes import GeneralizedIDs
10
+ from ..datatypes import Mixture
11
+ from ..datatypes import Source
12
+ from ..datatypes import Sources
13
+ from ..datatypes import SpeechMetadata
14
+ from ..datatypes import TransformConfig
15
+ from .db_datatypes import MixtureRecord
16
+ from .db_datatypes import SourceRecord
17
+ from .mixdb import MixtureDatabase
21
18
 
22
19
 
23
20
  def generic_ids_to_list(num_ids: int, ids: GeneralizedIDs = "*") -> list[int]:
@@ -58,17 +55,12 @@ def generic_ids_to_list(num_ids: int, ids: GeneralizedIDs = "*") -> list[int]:
58
55
  return result
59
56
 
60
57
 
61
- def get_feature_generator_info(
62
- fg_config: FeatureGeneratorConfig,
63
- ) -> FeatureGeneratorInfo:
64
- from dataclasses import asdict
65
-
58
+ def get_feature_generator_info(fg_config: FeatureGeneratorConfig) -> FeatureGeneratorInfo:
66
59
  from pyaaware import FeatureGenerator
67
60
 
68
- from .datatypes import FeatureGeneratorInfo
69
- from .datatypes import TransformConfig
61
+ from ..datatypes import TransformConfig
70
62
 
71
- fg = FeatureGenerator(**asdict(fg_config))
63
+ fg = FeatureGenerator(feature_mode=fg_config.feature_mode)
72
64
 
73
65
  return FeatureGeneratorInfo(
74
66
  decimation=fg.decimation,
@@ -99,38 +91,35 @@ def get_feature_generator_info(
99
91
  )
100
92
 
101
93
 
102
- def mixture_all_speech_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> list[dict[str, SpeechMetadata]]:
94
+ def mixture_all_speech_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> dict[str, dict[str, SpeechMetadata]]:
103
95
  """Get a list of all speech metadata for the given mixture"""
104
96
  from praatio.utilities.constants import Interval
105
97
 
106
- from .datatypes import SpeechMetadata
98
+ from ..datatypes import SpeechMetadata
107
99
 
108
- results: list[dict[str, SpeechMetadata]] = []
109
- for target in mixture.targets:
100
+ results: dict[str, dict[str, SpeechMetadata]] = {}
101
+ for category, source in mixture.all_sources.items():
110
102
  data: dict[str, SpeechMetadata] = {}
111
103
  for tier in mixdb.speaker_metadata_tiers:
112
- data[tier] = mixdb.speaker(mixdb.target_file(target.file_id).speaker_id, tier)
104
+ data[tier] = mixdb.speaker(mixdb.source_file(source.file_id).speaker_id, tier)
113
105
 
114
106
  for tier in mixdb.textgrid_metadata_tiers:
115
- item = get_textgrid_tier_from_target_file(mixdb.target_file(target.file_id).name, tier)
107
+ item = get_textgrid_tier_from_source_file(mixdb.source_file(source.file_id).name, tier)
116
108
  if isinstance(item, list):
117
- # Check for tempo augmentation and adjust Interval start and end data as needed
109
+ # Check for tempo effect and adjust Interval start and end data as needed
118
110
  entries = []
119
111
  for entry in item:
120
- if target.augmentation.pre.tempo is not None:
121
- entries.append(
122
- Interval(
123
- entry.start / target.augmentation.pre.tempo,
124
- entry.end / target.augmentation.pre.tempo,
125
- entry.label,
126
- )
112
+ entries.append(
113
+ Interval(
114
+ entry.start / source.pre_tempo,
115
+ entry.end / source.pre_tempo,
116
+ entry.label,
127
117
  )
128
- else:
129
- entries.append(entry)
118
+ )
130
119
  data[tier] = entries
131
120
  else:
132
121
  data[tier] = item
133
- results.append(data)
122
+ results[category] = data
134
123
 
135
124
  return results
136
125
 
@@ -151,27 +140,23 @@ def mixture_metadata(mixdb: MixtureDatabase, m_id: int | None = None, mixture: M
151
140
 
152
141
  metadata = ""
153
142
  speech_metadata = mixture_all_speech_metadata(mixdb, mixture)
154
- for mi, target in enumerate(mixture.targets):
155
- target_file = mixdb.target_file(target.file_id)
156
- metadata += f"target {mi} name: {target_file.name}\n"
157
- metadata += f"target {mi} augmentation: {target.augmentation.to_dict()}\n"
158
- metadata += f"target {mi} target_gain: {target.gain if not mixture.is_noise_only else 0}\n"
159
- metadata += f"target {mi} class indices: {target_file.class_indices}\n"
160
- for key in target_file.truth_configs:
161
- metadata += f"target {mi} truth '{key}' function: {target_file.truth_configs[key].function}\n"
162
- metadata += f"target {mi} truth '{key}' config: {target_file.truth_configs[key].config}\n"
163
- for key in speech_metadata[mi]:
164
- metadata += f"target {mi} speech {key}: {speech_metadata[mi][key]}\n"
165
- noise = mixdb.noise_file(mixture.noise.file_id)
166
- noise_augmentation = mixture.noise.augmentation
167
- metadata += f"noise name: {noise.name}\n"
168
- metadata += f"noise augmentation: {noise_augmentation.to_dict()}\n"
169
- metadata += f"noise offset: {mixture.noise_offset}\n"
170
- metadata += f"snr: {mixture.snr}\n"
171
- metadata += f"random_snr: {mixture.snr.is_random}\n"
172
143
  metadata += f"samples: {mixture.samples}\n"
173
- metadata += f"target_snr_gain: {float(mixture.target_snr_gain)}\n"
174
- metadata += f"noise_snr_gain: {float(mixture.noise_snr_gain)}\n"
144
+ for category, source in mixture.all_sources.items():
145
+ source_file = mixdb.source_file(source.file_id)
146
+ metadata += f"{category} name: {source_file.name}\n"
147
+ metadata += f"{category} effects: {source.effects.to_dict()}\n"
148
+ metadata += f"{category} pre_tempo: {source.pre_tempo}\n"
149
+ metadata += f"{category} class indices: {source_file.class_indices}\n"
150
+ metadata += f"{category} start: {source.start}\n"
151
+ metadata += f"{category} repeat: {source.repeat}\n"
152
+ metadata += f"{category} snr: {source.snr}\n"
153
+ metadata += f"{category} random_snr: {source.snr.is_random}\n"
154
+ metadata += f"{category} snr_gain: {source.snr_gain}\n"
155
+ for key in source_file.truth_configs:
156
+ metadata += f"{category} truth '{key}' function: {source_file.truth_configs[key].function}\n"
157
+ metadata += f"{category} truth '{key}' config: {source_file.truth_configs[key].config}\n"
158
+ for key in speech_metadata[category]:
159
+ metadata += f"{category} speech {key}: {speech_metadata[category][key]}\n"
175
160
 
176
161
  return metadata
177
162
 
@@ -197,95 +182,51 @@ def write_mixture_metadata(mixdb: MixtureDatabase, m_id: int | None = None, mixt
197
182
  f.write(mixture_metadata(mixdb, m_id, mixture))
198
183
 
199
184
 
200
- def from_mixture(
201
- mixture: Mixture,
202
- ) -> tuple[str, int, str, int, float, bool, float, int, int, int, float]:
203
- return (
204
- mixture.name,
205
- mixture.noise.file_id,
206
- mixture.noise.augmentation.to_json(),
207
- mixture.noise_offset,
208
- mixture.noise_snr_gain,
209
- mixture.snr.is_random,
210
- mixture.snr,
211
- mixture.samples,
212
- mixture.spectral_mask_id,
213
- mixture.spectral_mask_seed,
214
- mixture.target_snr_gain,
215
- )
216
-
217
-
218
- def to_mixture(entry: MixtureRecord, targets: list[Target]) -> Mixture:
219
- import json
220
-
221
- from sonusai.utils import dataclass_from_dict
185
+ def from_mixture(mixture: Mixture) -> tuple[str, int, int, int]:
186
+ return mixture.name, mixture.samples, mixture.spectral_mask_id, mixture.spectral_mask_seed
222
187
 
223
- from .datatypes import Noise
224
- from .datatypes import UniversalSNR
225
188
 
189
+ def to_mixture(entry: MixtureRecord, sources: Sources) -> Mixture:
226
190
  return Mixture(
227
- targets=targets,
228
191
  name=entry.name,
229
- noise=Noise(
230
- file_id=entry.noise_file_id,
231
- augmentation=dataclass_from_dict(Augmentation, json.loads(entry.noise_augmentation)), # pyright: ignore [reportArgumentType]
232
- ),
233
- noise_offset=entry.noise_offset,
234
- noise_snr_gain=entry.noise_snr_gain,
235
- snr=UniversalSNR(is_random=entry.random_snr, value=entry.snr),
236
192
  samples=entry.samples,
193
+ all_sources=sources,
237
194
  spectral_mask_id=entry.spectral_mask_id,
238
195
  spectral_mask_seed=entry.spectral_mask_seed,
239
- target_snr_gain=entry.target_snr_gain,
240
196
  )
241
197
 
242
198
 
243
- def from_target(target: Target) -> tuple[int, str]:
244
- return target.file_id, target.augmentation.to_json()
199
+ def from_source(source: Source) -> tuple[str, int, float, bool, float, float, bool, int]:
200
+ return (
201
+ source.effects.to_json(),
202
+ source.file_id,
203
+ source.pre_tempo,
204
+ source.repeat,
205
+ source.snr,
206
+ source.snr_gain,
207
+ source.snr.is_random,
208
+ source.start,
209
+ )
245
210
 
246
211
 
247
- def to_target(entry: TargetRecord) -> Target:
212
+ def to_source(entry: SourceRecord) -> Source:
248
213
  import json
249
214
 
250
- from sonusai.utils import dataclass_from_dict
251
-
252
- from .datatypes import Augmentation
215
+ from ..datatypes import Effects
216
+ from ..datatypes import UniversalSNR
217
+ from ..utils.dataclass_from_dict import dataclass_from_dict
253
218
 
254
- return Target(
219
+ return Source(
255
220
  file_id=entry.file_id,
256
- augmentation=dataclass_from_dict(Augmentation, json.loads(entry.augmentation)), # pyright: ignore [reportArgumentType]
221
+ effects=dataclass_from_dict(Effects, json.loads(entry.effects)),
222
+ start=entry.start,
223
+ repeat=entry.repeat,
224
+ snr=UniversalSNR(entry.snr, entry.snr_random),
225
+ snr_gain=entry.snr_gain,
226
+ pre_tempo=entry.pre_tempo,
257
227
  )
258
228
 
259
229
 
260
- def get_target(mixdb: MixtureDatabase, mixture: Mixture, targets_audio: list[AudioT]) -> AudioT:
261
- """Get the augmented target audio data for the given mixture record
262
-
263
- :param mixdb: Mixture database
264
- :param mixture: Mixture record
265
- :param targets_audio: List of augmented target audio data (one per target in the mixup)
266
- :return: Sum of augmented target audio data
267
- """
268
- # Apply post-truth augmentation effects to targets and sum
269
- import numpy as np
270
-
271
- from .augmentation import apply_augmentation
272
-
273
- targets_post = []
274
- for idx, target_audio in enumerate(targets_audio):
275
- target = mixture.targets[idx]
276
- targets_post.append(
277
- apply_augmentation(
278
- mixdb=mixdb,
279
- audio=target_audio,
280
- augmentation=target.augmentation.post,
281
- frame_length=mixdb.feature_step_samples,
282
- )
283
- )
284
-
285
- # Return sum of targets
286
- return np.sum(targets_post, axis=0)
287
-
288
-
289
230
  def get_transform_from_audio(audio: AudioT, transform: ForwardTransform) -> tuple[AudioF, EnergyT]:
290
231
  """Apply forward transform to input audio data to generate transform data
291
232
 
@@ -368,67 +309,24 @@ def check_audio_files_exist(mixdb: MixtureDatabase) -> None:
368
309
  """Walk through all the noise and target audio files in a mixture database ensuring that they exist"""
369
310
  from os.path import exists
370
311
 
371
- from .tokenized_shell_vars import tokenized_expand
372
-
373
- for noise in mixdb.noise_files:
374
- file_name, _ = tokenized_expand(noise.name)
375
- if not exists(file_name):
376
- raise OSError(f"Could not find {file_name}")
377
-
378
- for target in mixdb.target_files:
379
- file_name, _ = tokenized_expand(target.name)
380
- if not exists(file_name):
381
- raise OSError(f"Could not find {file_name}")
382
-
383
-
384
- def augmented_target_samples(
385
- target_files: list[TargetFile],
386
- target_augmentations: list[AugmentationRule],
387
- feature_step_samples: int,
388
- ) -> int:
389
- from itertools import product
390
-
391
- from .augmentation import estimate_augmented_length_from_length
392
-
393
- target_ids = list(range(len(target_files)))
394
- target_augmentation_ids = list(range(len(target_augmentations)))
395
- it = list(product(*[target_ids, target_augmentation_ids]))
396
- return sum(
397
- [
398
- estimate_augmented_length_from_length(
399
- length=target_files[fi].samples,
400
- tempo=target_augmentations[ai].pre.tempo,
401
- frame_length=feature_step_samples,
402
- )
403
- for fi, ai in it
404
- ]
405
- )
406
-
407
-
408
- def augmented_noise_samples(noise_files: list[NoiseFile], noise_augmentations: list[Augmentation]) -> int:
409
- from itertools import product
410
-
411
- noise_ids = list(range(len(noise_files)))
412
- noise_augmentation_ids = list(range(len(noise_augmentations)))
413
- it = list(product(*[noise_ids, noise_augmentation_ids]))
414
- return sum([augmented_noise_length(noise_files[fi], noise_augmentations[ai]) for fi, ai in it])
415
-
416
-
417
- def augmented_noise_length(noise_file: NoiseFile, noise_augmentation: Augmentation) -> int:
418
- from .augmentation import estimate_augmented_length_from_length
312
+ from ..utils.tokenized_shell_vars import tokenized_expand
419
313
 
420
- return estimate_augmented_length_from_length(length=noise_file.samples, tempo=noise_augmentation.pre.tempo)
314
+ for source_files in mixdb.source_files.values():
315
+ for source_file in source_files:
316
+ file_name, _ = tokenized_expand(source_file.name)
317
+ if not exists(file_name):
318
+ raise OSError(f"Could not find {file_name}")
421
319
 
422
320
 
423
- def get_textgrid_tier_from_target_file(target_file: str, tier: str) -> SpeechMetadata | None:
321
+ def get_textgrid_tier_from_source_file(source_file: str, tier: str) -> SpeechMetadata | None:
424
322
  from pathlib import Path
425
323
 
426
324
  from praatio import textgrid
427
325
  from praatio.utilities.constants import Interval
428
326
 
429
- from .tokenized_shell_vars import tokenized_expand
327
+ from ..utils.tokenized_shell_vars import tokenized_expand
430
328
 
431
- textgrid_file = Path(tokenized_expand(target_file)[0]).with_suffix(".TextGrid")
329
+ textgrid_file = Path(tokenized_expand(source_file)[0]).with_suffix(".TextGrid")
432
330
  if not textgrid_file.exists():
433
331
  return None
434
332
 
@@ -1,14 +1,13 @@
1
1
  import numpy as np
2
2
 
3
3
 
4
- def get_impulse_response_delay(file: str) -> int:
5
- from sonusai.utils import temp_seed
6
-
4
+ def get_ir_delay(file: str) -> int:
5
+ from ..utils.rand import seed_context
7
6
  from .audio import raw_read_audio
8
7
 
9
8
  ir, sample_rate = raw_read_audio(file)
10
9
 
11
- with temp_seed(42):
10
+ with seed_context(42):
12
11
  wgn_ref = np.random.normal(loc=0, scale=0.2, size=int(np.ceil(0.05 * sample_rate))).astype(np.float32)
13
12
 
14
13
  wgn_conv = np.convolve(ir, wgn_ref)
@@ -0,0 +1,77 @@
1
+ from functools import lru_cache
2
+ from pathlib import Path
3
+
4
+ from ..datatypes import AudioT
5
+ from ..datatypes import ImpulseResponseData
6
+ from .audio import raw_read_audio
7
+
8
+
9
+ def apply_ir(audio: AudioT, ir: ImpulseResponseData) -> AudioT:
10
+ """Apply impulse response to audio data using scipy
11
+
12
+ :param audio: Audio
13
+ :param ir: Impulse response data
14
+ :return: Effected audio
15
+ """
16
+ import numpy as np
17
+ from librosa import resample
18
+ from scipy.signal import fftconvolve
19
+
20
+ from ..constants import SAMPLE_RATE
21
+
22
+ if not isinstance(audio, np.ndarray):
23
+ raise TypeError("audio must be a numpy array")
24
+
25
+ # Early exit if no ir or if all audio is zero
26
+ if ir is None or not audio.any():
27
+ return audio
28
+
29
+ pk_in = np.max(np.abs(audio))
30
+
31
+ # Convert audio to IR sample rate
32
+ audio_in = resample(audio, orig_sr=SAMPLE_RATE, target_sr=ir.sample_rate, res_type="soxr_hq")
33
+
34
+ # Apply IR
35
+ audio_out = fftconvolve(audio_in, ir.data, mode="full")
36
+
37
+ # Delay compensation
38
+ audio_out = audio_out[ir.delay :]
39
+
40
+ # Convert back to global sample rate
41
+ audio_out = resample(audio_out, orig_sr=ir.sample_rate, target_sr=SAMPLE_RATE, res_type="soxr_hq")
42
+
43
+ # Trim to length
44
+ audio_out = audio_out[: len(audio)]
45
+
46
+ # Gain compensation
47
+ pk_out = np.max(np.abs(audio_out))
48
+ pk_gain = pk_in / pk_out
49
+ audio_out = audio_out * pk_gain
50
+
51
+ return audio_out
52
+
53
+
54
+ def read_ir(name: str | Path, delay: int, use_cache: bool = True) -> ImpulseResponseData:
55
+ """Read impulse response data
56
+
57
+ :param name: File name
58
+ :param delay: Delay in samples
59
+ :param use_cache: If true, use LRU caching
60
+ :return: ImpulseResponseData object
61
+ """
62
+ if use_cache:
63
+ return _read_ir(name, delay)
64
+ return _read_ir.__wrapped__(name, delay)
65
+
66
+
67
+ @lru_cache
68
+ def _read_ir(name: str | Path, delay: int) -> ImpulseResponseData:
69
+ """Read impulse response data using soundfile
70
+
71
+ :param name: File name
72
+ :param delay: Delay in samples
73
+ :return: ImpulseResponseData object
74
+ """
75
+ out, sample_rate = raw_read_audio(name)
76
+
77
+ return ImpulseResponseData(data=out, sample_rate=sample_rate, delay=delay)
@@ -1,29 +1,23 @@
1
1
  def log_duration_and_sizes(
2
2
  total_duration: float,
3
- num_classes: int,
4
3
  feature_step_samples: int,
5
4
  feature_parameters: int,
6
5
  stride: int,
7
6
  desc: str,
8
7
  ) -> None:
9
- from sonusai import logger
10
- from sonusai.utils import human_readable_size
11
- from sonusai.utils import seconds_to_hms
12
-
13
- from .constants import FLOAT_BYTES
14
- from .constants import SAMPLE_BYTES
15
- from .constants import SAMPLE_RATE
8
+ from .. import logger
9
+ from ..constants import FLOAT_BYTES
10
+ from ..constants import SAMPLE_BYTES
11
+ from ..constants import SAMPLE_RATE
12
+ from ..utils.human_readable_size import human_readable_size
13
+ from ..utils.seconds_to_hms import seconds_to_hms
16
14
 
17
15
  total_samples = int(total_duration * SAMPLE_RATE)
18
16
  mixture_bytes = total_samples * SAMPLE_BYTES
19
- truth_t_bytes = total_samples * num_classes * FLOAT_BYTES
20
17
  feature_bytes = total_samples / feature_step_samples * stride * feature_parameters * FLOAT_BYTES
21
- truth_f_bytes = total_samples / feature_step_samples * num_classes * FLOAT_BYTES
22
18
 
23
19
  logger.info("")
24
20
  logger.info(f"{desc} duration: {seconds_to_hms(seconds=total_duration)}")
25
21
  logger.info(f"{desc} sizes:")
26
22
  logger.info(f" mixture: {human_readable_size(mixture_bytes, 1)}")
27
- logger.info(f" truth_t: {human_readable_size(truth_t_bytes, 1)}")
28
23
  logger.info(f" feature: {human_readable_size(feature_bytes, 1)}")
29
- logger.info(f" truth_f: {human_readable_size(truth_f_bytes, 1)}")