sonusai 0.20.2__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +16 -3
- sonusai/audiofe.py +240 -76
- sonusai/calc_metric_spenh.py +71 -73
- sonusai/config/__init__.py +3 -0
- sonusai/config/config.py +61 -0
- sonusai/config/config.yml +20 -0
- sonusai/config/constants.py +8 -0
- sonusai/constants.py +11 -0
- sonusai/data/genmixdb.yml +21 -36
- sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
- sonusai/deprecated/plot.py +4 -5
- sonusai/doc/doc.py +4 -4
- sonusai/doc.py +11 -4
- sonusai/genft.py +43 -45
- sonusai/genmetrics.py +23 -19
- sonusai/genmix.py +54 -82
- sonusai/genmixdb.py +88 -264
- sonusai/ir_metric.py +30 -34
- sonusai/lsdb.py +41 -48
- sonusai/main.py +15 -22
- sonusai/metrics/calc_audio_stats.py +4 -17
- sonusai/metrics/calc_class_weights.py +4 -4
- sonusai/metrics/calc_optimal_thresholds.py +8 -5
- sonusai/metrics/calc_pesq.py +2 -2
- sonusai/metrics/calc_segsnr_f.py +4 -4
- sonusai/metrics/calc_speech.py +25 -13
- sonusai/metrics/class_summary.py +7 -7
- sonusai/metrics/confusion_matrix_summary.py +5 -5
- sonusai/metrics/one_hot.py +4 -4
- sonusai/metrics/snr_summary.py +7 -7
- sonusai/metrics_summary.py +38 -45
- sonusai/mixture/__init__.py +5 -104
- sonusai/mixture/audio.py +10 -39
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/config.py +251 -271
- sonusai/mixture/constants.py +35 -39
- sonusai/mixture/data_io.py +25 -36
- sonusai/mixture/db_datatypes.py +58 -22
- sonusai/mixture/effects.py +386 -0
- sonusai/mixture/feature.py +7 -11
- sonusai/mixture/generation.py +484 -611
- sonusai/mixture/helpers.py +82 -184
- sonusai/mixture/ir_delay.py +3 -4
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +6 -12
- sonusai/mixture/mixdb.py +931 -669
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +2 -2
- sonusai/mixture/truth.py +17 -15
- sonusai/mixture/truth_functions/crm.py +12 -12
- sonusai/mixture/truth_functions/energy.py +22 -22
- sonusai/mixture/truth_functions/file.py +5 -5
- sonusai/mixture/truth_functions/metadata.py +4 -4
- sonusai/mixture/truth_functions/metrics.py +4 -4
- sonusai/mixture/truth_functions/phoneme.py +3 -3
- sonusai/mixture/truth_functions/sed.py +11 -13
- sonusai/mixture/truth_functions/target.py +10 -10
- sonusai/mkwav.py +26 -29
- sonusai/onnx_predict.py +240 -88
- sonusai/queries/__init__.py +2 -2
- sonusai/queries/queries.py +38 -34
- sonusai/speech/librispeech.py +1 -1
- sonusai/speech/mcgill.py +1 -1
- sonusai/speech/timit.py +2 -2
- sonusai/summarize_metric_spenh.py +10 -17
- sonusai/utils/__init__.py +7 -1
- sonusai/utils/asl_p56.py +2 -2
- sonusai/utils/asr.py +2 -2
- sonusai/utils/asr_functions/aaware_whisper.py +4 -5
- sonusai/utils/choice.py +31 -0
- sonusai/utils/compress.py +1 -1
- sonusai/utils/dataclass_from_dict.py +19 -1
- sonusai/utils/energy_f.py +3 -3
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/onnx_utils.py +3 -17
- sonusai/utils/print_mixture_details.py +21 -19
- sonusai/utils/{temp_seed.py → rand.py} +3 -3
- sonusai/utils/read_predict_data.py +2 -2
- sonusai/utils/reshape.py +3 -3
- sonusai/utils/stratified_shuffle_split.py +3 -3
- sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
- sonusai/utils/write_audio.py +2 -2
- sonusai/vars.py +11 -4
- {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/METADATA +4 -2
- sonusai-1.0.1.dist-info/RECORD +138 -0
- sonusai/mixture/augmentation.py +0 -444
- sonusai/mixture/class_count.py +0 -15
- sonusai/mixture/eq_rule_is_valid.py +0 -45
- sonusai/mixture/target_class_balancing.py +0 -107
- sonusai/mixture/targets.py +0 -175
- sonusai-0.20.2.dist-info/RECORD +0 -128
- {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/WHEEL +0 -0
- {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/entry_points.txt +0 -0
sonusai/metrics_summary.py
CHANGED
@@ -16,23 +16,9 @@ Inputs:
|
|
16
16
|
|
17
17
|
"""
|
18
18
|
|
19
|
-
import signal
|
20
|
-
|
21
19
|
import numpy as np
|
22
20
|
import pandas as pd
|
23
21
|
|
24
|
-
|
25
|
-
def signal_handler(_sig, _frame):
|
26
|
-
import sys
|
27
|
-
|
28
|
-
from sonusai import logger
|
29
|
-
|
30
|
-
logger.info("Canceled due to keyboard interrupt")
|
31
|
-
sys.exit(1)
|
32
|
-
|
33
|
-
|
34
|
-
signal.signal(signal.SIGINT, signal_handler)
|
35
|
-
|
36
22
|
DB_99 = np.power(10, 99 / 10)
|
37
23
|
DB_N99 = np.power(10, -99 / 10)
|
38
24
|
|
@@ -49,8 +35,8 @@ def _process_mixture(
|
|
49
35
|
) -> tuple[pd.DataFrame, pd.DataFrame]:
|
50
36
|
from os.path import basename
|
51
37
|
|
38
|
+
from sonusai.constants import SAMPLE_RATE
|
52
39
|
from sonusai.metrics import calc_wer
|
53
|
-
from sonusai.mixture import SAMPLE_RATE
|
54
40
|
from sonusai.mixture import MixtureDatabase
|
55
41
|
|
56
42
|
mixdb = MixtureDatabase(location)
|
@@ -61,11 +47,11 @@ def _process_mixture(
|
|
61
47
|
duration = samples / SAMPLE_RATE
|
62
48
|
tf_frames = mixdb.mixture_transform_frames(m_id)
|
63
49
|
feat_frames = mixdb.mixture_feature_frames(m_id)
|
64
|
-
mxsnr = mixdb.mixture(m_id).snr
|
65
|
-
ti = mixdb.mixture(m_id).
|
50
|
+
mxsnr = mixdb.mixture(m_id).noise.snr
|
51
|
+
ti = mixdb.mixture(m_id).sources["primary"].file_id
|
66
52
|
ni = mixdb.mixture(m_id).noise.file_id
|
67
|
-
t0file = basename(mixdb.
|
68
|
-
nfile = basename(mixdb.
|
53
|
+
t0file = basename(mixdb.source_file(ti).name)
|
54
|
+
nfile = basename(mixdb.source_file(ni).name)
|
69
55
|
|
70
56
|
all_metrics = mixdb.mixture_metrics(m_id, all_metric_names)
|
71
57
|
|
@@ -104,10 +90,10 @@ def _process_mixture(
|
|
104
90
|
def main() -> None:
|
105
91
|
from docopt import docopt
|
106
92
|
|
107
|
-
from
|
108
|
-
from
|
93
|
+
from . import __version__ as sai_version
|
94
|
+
from .utils.docstring import trim_docstring
|
109
95
|
|
110
|
-
args = docopt(trim_docstring(__doc__), version=
|
96
|
+
args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)
|
111
97
|
|
112
98
|
verbose = args["--verbose"]
|
113
99
|
wrlist = args["--write-list"]
|
@@ -121,24 +107,22 @@ def main() -> None:
|
|
121
107
|
|
122
108
|
import psutil
|
123
109
|
|
124
|
-
from
|
125
|
-
from
|
126
|
-
from
|
127
|
-
from
|
128
|
-
from
|
129
|
-
from
|
130
|
-
from
|
131
|
-
from
|
110
|
+
from . import create_file_handler
|
111
|
+
from . import initial_log_messages
|
112
|
+
from . import logger
|
113
|
+
from . import update_console_handler
|
114
|
+
from .mixture.mixdb import MixtureDatabase
|
115
|
+
from .utils.create_timestamp import create_timestamp
|
116
|
+
from .utils.parallel import par_track
|
117
|
+
from .utils.parallel import track
|
132
118
|
|
133
|
-
|
134
|
-
|
135
|
-
print(f"Found SonusAI mixture database with {mixdb.num_mixtures} mixtures.")
|
136
|
-
except:
|
137
|
-
print(f"Could not open SonusAI mixture database in {location}, exiting ...")
|
138
|
-
return
|
119
|
+
mixdb = MixtureDatabase(location)
|
120
|
+
print(f"Found SonusAI mixture database with {mixdb.num_mixtures} mixtures.")
|
139
121
|
|
140
122
|
# Only check first and last mixture in order to save time
|
141
|
-
metrics_present = mixdb.cached_metrics([0, mixdb.num_mixtures - 1])
|
123
|
+
metrics_present = mixdb.cached_metrics([0, mixdb.num_mixtures - 1]) # return pre-generated metrics in mixdb tree
|
124
|
+
if "mxsnr" in metrics_present:
|
125
|
+
metrics_present.remove("mxsnr")
|
142
126
|
|
143
127
|
num_metrics_present = len(metrics_present)
|
144
128
|
if num_metrics_present < 1:
|
@@ -188,8 +172,8 @@ def main() -> None:
|
|
188
172
|
if len(metval) > 1:
|
189
173
|
logger.warning(f"Mixid {mixids[0]} metric {metric} has a list with more than 1 element, using first.")
|
190
174
|
metval = metval[0] # remove any list
|
191
|
-
if isinstance(metval, float):
|
192
|
-
logger.debug("Metric is scalar
|
175
|
+
if isinstance(metval, float | int):
|
176
|
+
logger.debug(f"Metric is scalar {type(metval)}, entering in summary table.")
|
193
177
|
scalar_metric_names.append(metric)
|
194
178
|
elif isinstance(metval, str):
|
195
179
|
logger.debug("Metric is string, will summarize with word count.")
|
@@ -205,7 +189,7 @@ def main() -> None:
|
|
205
189
|
else:
|
206
190
|
logger.warning(f"Mixid {mixids[0]} metric {metric} is a vector of improper size, ignoring.")
|
207
191
|
|
208
|
-
# Setup pandas table for summarizing scalar metrics
|
192
|
+
# Setup pandas table for summarizing scalar metrics, always include mxsnr first
|
209
193
|
ptab_labels = [
|
210
194
|
"mxsnr",
|
211
195
|
*scalar_metric_names,
|
@@ -276,7 +260,7 @@ def main() -> None:
|
|
276
260
|
ptab1.round(2).to_csv(wlcsv_name, **table_args)
|
277
261
|
ptab1_sorted = ptab1.sort_values(by=["mxsnr", "t0file"])
|
278
262
|
|
279
|
-
# Create metrics table except
|
263
|
+
# Create metrics table except -99 SNR
|
280
264
|
ptab1_nom99 = ptab1_sorted[ptab1_sorted.mxsnr != -99]
|
281
265
|
|
282
266
|
# Create summary by SNR for all scalar metrics, taking mean
|
@@ -294,7 +278,7 @@ def main() -> None:
|
|
294
278
|
nmixtot = mixdb.num_mixtures
|
295
279
|
pd.DataFrame([["Timestamp", timestamp]]).to_csv(snrcsv_name, header=False, index=False)
|
296
280
|
pd.DataFrame(['"Metrics avg over each SNR:"']).to_csv(snrcsv_name, **header_args)
|
297
|
-
mtab_snr_summary.round(2).to_csv(snrcsv_name, index=False,
|
281
|
+
mtab_snr_summary.round(2).T.to_csv(snrcsv_name, index=True, header=False, mode="a", encoding="utf-8")
|
298
282
|
pd.DataFrame(["--"]).to_csv(snrcsv_name, header=False, index=False, mode="a")
|
299
283
|
pd.DataFrame([f'"Metrics stats over {nmix} mixtures out of {nmixtot} total:"']).to_csv(snrcsv_name, **header_args)
|
300
284
|
ptab1.describe().round(2).T.to_csv(snrcsv_name, index=True, **table_args)
|
@@ -304,12 +288,14 @@ def main() -> None:
|
|
304
288
|
)
|
305
289
|
ptab1_nom99.describe().round(2).T.to_csv(snrcsv_name, index=True, **table_args)
|
306
290
|
|
307
|
-
# Write summary to
|
291
|
+
# Write summary to text file
|
308
292
|
snrtxt_name = str(join(location, "metric_summary_snr" + fsuffix + ".txt"))
|
309
293
|
with open(snrtxt_name, "w") as f:
|
310
294
|
print(f"Timestamp: {timestamp}", file=f)
|
311
295
|
print("Metrics avg over each SNR:", file=f)
|
312
|
-
print(
|
296
|
+
print(
|
297
|
+
mtab_snr_summary.round(2).T.to_string(float_format=lambda x: f"{x:.2f}", index=True, header=False), file=f
|
298
|
+
)
|
313
299
|
print("", file=f)
|
314
300
|
print(f"Metrics stats over {len(mixids)} mixtures out of {mixdb.num_mixtures} total:", file=f)
|
315
301
|
print(ptab1.describe().round(2).T.to_string(float_format=lambda x: f"{x:.2f}", index=True), file=f)
|
@@ -319,4 +305,11 @@ def main() -> None:
|
|
319
305
|
|
320
306
|
|
321
307
|
if __name__ == "__main__":
|
322
|
-
|
308
|
+
from sonusai import exception_handler
|
309
|
+
from sonusai.utils import register_keyboard_interrupt
|
310
|
+
|
311
|
+
register_keyboard_interrupt()
|
312
|
+
try:
|
313
|
+
main()
|
314
|
+
except Exception as e:
|
315
|
+
exception_handler(e)
|
sonusai/mixture/__init__.py
CHANGED
@@ -1,131 +1,32 @@
|
|
1
1
|
# SonusAI mixture utilities
|
2
2
|
# ruff: noqa: F401
|
3
3
|
|
4
|
-
from
|
5
|
-
from .audio import get_next_noise
|
6
|
-
from .audio import get_num_samples
|
7
|
-
from .audio import get_sample_rate
|
8
|
-
from .audio import raw_read_audio
|
4
|
+
from ..datatypes import AudioT
|
9
5
|
from .audio import read_audio
|
10
|
-
from .
|
11
|
-
from .
|
12
|
-
from .augmentation import apply_augmentation
|
13
|
-
from .augmentation import apply_gain
|
14
|
-
from .augmentation import apply_impulse_response
|
15
|
-
from .augmentation import augmentation_from_rule
|
16
|
-
from .augmentation import estimate_augmented_length_from_length
|
17
|
-
from .augmentation import evaluate_random_rule
|
18
|
-
from .augmentation import get_augmentation_indices_for_mixup
|
19
|
-
from .augmentation import get_augmentation_rules
|
20
|
-
from .augmentation import get_mixups
|
21
|
-
from .augmentation import pad_audio_to_length
|
22
|
-
from .class_count import get_class_count_from_mixids
|
23
|
-
from .config import get_default_config
|
24
|
-
from .config import get_impulse_response_files
|
25
|
-
from .config import get_noise_files
|
26
|
-
from .config import get_spectral_masks
|
27
|
-
from .config import get_target_files
|
28
|
-
from .config import get_truth_parameters
|
6
|
+
from .config import get_ir_files
|
7
|
+
from .config import get_source_files
|
29
8
|
from .config import load_config
|
30
|
-
from .config import raw_load_config
|
31
|
-
from .config import update_config_from_file
|
32
|
-
from .config import update_config_from_hierarchy
|
33
|
-
from .config import validate_truth_configs
|
34
|
-
from .constants import BIT_DEPTH
|
35
|
-
from .constants import CHANNEL_COUNT
|
36
|
-
from .constants import DEFAULT_CONFIG
|
37
|
-
from .constants import DEFAULT_NOISE
|
38
|
-
from .constants import DEFAULT_SPEECH
|
39
|
-
from .constants import ENCODING
|
40
|
-
from .constants import FLOAT_BYTES
|
41
|
-
from .constants import MIXDB_VERSION
|
42
|
-
from .constants import RAND_PATTERN
|
43
|
-
from .constants import REQUIRED_CONFIGS
|
44
|
-
from .constants import REQUIRED_TRUTH_CONFIGS
|
45
|
-
from .constants import SAMPLE_BYTES
|
46
|
-
from .constants import SAMPLE_RATE
|
47
|
-
from .constants import VALID_AUGMENTATIONS
|
48
|
-
from .constants import VALID_CONFIGS
|
49
|
-
from .constants import VALID_NOISE_MIX_MODES
|
50
|
-
from .data_io import clear_cached_data
|
51
9
|
from .data_io import read_cached_data
|
52
10
|
from .data_io import write_cached_data
|
53
|
-
from .
|
54
|
-
from .datatypes import AudioStatsMetrics
|
55
|
-
from .datatypes import AudioT
|
56
|
-
from .datatypes import Augmentation
|
57
|
-
from .datatypes import AugmentationEffects
|
58
|
-
from .datatypes import AugmentationRule
|
59
|
-
from .datatypes import AugmentationRuleEffects
|
60
|
-
from .datatypes import AugmentedTarget
|
61
|
-
from .datatypes import ClassCount
|
62
|
-
from .datatypes import EnergyF
|
63
|
-
from .datatypes import EnergyT
|
64
|
-
from .datatypes import Feature
|
65
|
-
from .datatypes import FeatureGeneratorConfig
|
66
|
-
from .datatypes import FeatureGeneratorInfo
|
67
|
-
from .datatypes import GeneralizedIDs
|
68
|
-
from .datatypes import GenFTData
|
69
|
-
from .datatypes import GenMixData
|
70
|
-
from .datatypes import ImpulseResponseData
|
71
|
-
from .datatypes import ImpulseResponseFile
|
72
|
-
from .datatypes import MetricDoc
|
73
|
-
from .datatypes import MetricDocs
|
74
|
-
from .datatypes import Mixture
|
75
|
-
from .datatypes import MixtureDatabaseConfig
|
76
|
-
from .datatypes import NoiseFile
|
77
|
-
from .datatypes import Predict
|
78
|
-
from .datatypes import Segsnr
|
79
|
-
from .datatypes import SnrFMetrics
|
80
|
-
from .datatypes import SpectralMask
|
81
|
-
from .datatypes import SpeechMetadata
|
82
|
-
from .datatypes import SpeechMetrics
|
83
|
-
from .datatypes import TargetFile
|
84
|
-
from .datatypes import TransformConfig
|
85
|
-
from .datatypes import Truth
|
86
|
-
from .datatypes import TruthConfig
|
87
|
-
from .datatypes import TruthConfigs
|
88
|
-
from .datatypes import TruthDict
|
89
|
-
from .datatypes import TruthParameter
|
90
|
-
from .datatypes import UniversalSNR
|
11
|
+
from .effects import get_effect_rules
|
91
12
|
from .feature import get_audio_from_feature
|
92
13
|
from .feature import get_feature_from_audio
|
93
14
|
from .generation import generate_mixtures
|
94
|
-
from .generation import get_all_snrs_from_config
|
95
15
|
from .generation import initialize_db
|
96
16
|
from .generation import populate_class_label_table
|
97
17
|
from .generation import populate_class_weights_threshold_table
|
98
18
|
from .generation import populate_impulse_response_file_table
|
99
19
|
from .generation import populate_mixture_table
|
100
|
-
from .generation import
|
20
|
+
from .generation import populate_source_file_table
|
101
21
|
from .generation import populate_spectral_mask_table
|
102
|
-
from .generation import populate_target_file_table
|
103
22
|
from .generation import populate_top_table
|
104
23
|
from .generation import populate_truth_parameters_table
|
105
24
|
from .generation import update_mixid_width
|
106
25
|
from .generation import update_mixture
|
107
|
-
from .helpers import augmented_noise_samples
|
108
|
-
from .helpers import augmented_target_samples
|
109
26
|
from .helpers import check_audio_files_exist
|
110
27
|
from .helpers import forward_transform
|
111
|
-
from .helpers import frames_from_samples
|
112
|
-
from .helpers import get_audio_from_transform
|
113
|
-
from .helpers import get_transform_from_audio
|
114
28
|
from .helpers import inverse_transform
|
115
|
-
from .helpers import mixture_metadata
|
116
29
|
from .helpers import write_mixture_metadata
|
117
|
-
from .ir_delay import get_impulse_response_delay
|
118
30
|
from .log_duration_and_sizes import log_duration_and_sizes
|
119
31
|
from .mixdb import MixtureDatabase
|
120
32
|
from .mixdb import db_file
|
121
|
-
from .spectral_mask import apply_spectral_mask
|
122
|
-
from .target_class_balancing import balance_targets
|
123
|
-
from .targets import get_augmented_target_ids_by_class
|
124
|
-
from .targets import get_augmented_target_ids_for_mixup
|
125
|
-
from .targets import get_augmented_targets
|
126
|
-
from .targets import get_target_augmentations_for_mixup
|
127
|
-
from .tokenized_shell_vars import tokenized_expand
|
128
|
-
from .tokenized_shell_vars import tokenized_replace
|
129
|
-
from .truth import get_truth_indices_for_mixid
|
130
|
-
from .truth import truth_function
|
131
|
-
from .truth import truth_stride_reduction
|
sonusai/mixture/audio.py
CHANGED
@@ -1,8 +1,7 @@
|
|
1
1
|
from functools import lru_cache
|
2
2
|
from pathlib import Path
|
3
3
|
|
4
|
-
from
|
5
|
-
from sonusai.mixture.datatypes import ImpulseResponseData
|
4
|
+
from ..datatypes import AudioT
|
6
5
|
|
7
6
|
|
8
7
|
def get_next_noise(audio: AudioT, offset: int, length: int) -> AudioT:
|
@@ -24,7 +23,7 @@ def get_duration(audio: AudioT) -> float:
|
|
24
23
|
:param audio: Time domain data [samples]
|
25
24
|
:return: Duration of audio in seconds
|
26
25
|
"""
|
27
|
-
from
|
26
|
+
from ..constants import SAMPLE_RATE
|
28
27
|
|
29
28
|
return len(audio) / SAMPLE_RATE
|
30
29
|
|
@@ -66,7 +65,7 @@ def _get_sample_rate(name: str | Path) -> int:
|
|
66
65
|
import soundfile
|
67
66
|
from pydub import AudioSegment
|
68
67
|
|
69
|
-
from .tokenized_shell_vars import tokenized_expand
|
68
|
+
from ..utils.tokenized_shell_vars import tokenized_expand
|
70
69
|
|
71
70
|
expanded_name, _ = tokenized_expand(name)
|
72
71
|
|
@@ -90,7 +89,7 @@ def raw_read_audio(name: str | Path) -> tuple[AudioT, int]:
|
|
90
89
|
import soundfile
|
91
90
|
from pydub import AudioSegment
|
92
91
|
|
93
|
-
from .tokenized_shell_vars import tokenized_expand
|
92
|
+
from ..utils.tokenized_shell_vars import tokenized_expand
|
94
93
|
|
95
94
|
expanded_name, _ = tokenized_expand(name)
|
96
95
|
|
@@ -135,40 +134,12 @@ def _read_audio(name: str | Path) -> AudioT:
|
|
135
134
|
:param name: File name
|
136
135
|
:return: Array of time domain audio data
|
137
136
|
"""
|
138
|
-
import
|
137
|
+
from ..constants import SAMPLE_RATE
|
138
|
+
from .resample import resample
|
139
139
|
|
140
|
-
from .constants import SAMPLE_RATE
|
141
|
-
|
142
|
-
out, sample_rate = raw_read_audio(name)
|
143
|
-
out = librosa.resample(out, orig_sr=sample_rate, target_sr=SAMPLE_RATE, res_type="soxr_hq")
|
144
|
-
|
145
|
-
return out
|
146
|
-
|
147
|
-
|
148
|
-
def read_ir(name: str | Path, delay: int, use_cache: bool = True) -> ImpulseResponseData:
|
149
|
-
"""Read impulse response data
|
150
|
-
|
151
|
-
:param name: File name
|
152
|
-
:param delay: Delay in samples
|
153
|
-
:param use_cache: If true, use LRU caching
|
154
|
-
:return: ImpulseResponseData object
|
155
|
-
"""
|
156
|
-
if use_cache:
|
157
|
-
return _read_ir(name, delay)
|
158
|
-
return _read_ir.__wrapped__(name, delay)
|
159
|
-
|
160
|
-
|
161
|
-
@lru_cache
|
162
|
-
def _read_ir(name: str | Path, delay: int) -> ImpulseResponseData:
|
163
|
-
"""Read impulse response data using soundfile
|
164
|
-
|
165
|
-
:param name: File name
|
166
|
-
:param delay: Delay in samples
|
167
|
-
:return: ImpulseResponseData object
|
168
|
-
"""
|
169
140
|
out, sample_rate = raw_read_audio(name)
|
170
141
|
|
171
|
-
return
|
142
|
+
return resample(out, orig_sr=sample_rate, target_sr=SAMPLE_RATE)
|
172
143
|
|
173
144
|
|
174
145
|
def get_num_samples(name: str | Path, use_cache: bool = True) -> int:
|
@@ -195,8 +166,8 @@ def _get_num_samples(name: str | Path) -> int:
|
|
195
166
|
import soundfile
|
196
167
|
from pydub import AudioSegment
|
197
168
|
|
198
|
-
from
|
199
|
-
from .tokenized_shell_vars import tokenized_expand
|
169
|
+
from ..constants import SAMPLE_RATE
|
170
|
+
from ..utils.tokenized_shell_vars import tokenized_expand
|
200
171
|
|
201
172
|
expanded_name, _ = tokenized_expand(name)
|
202
173
|
|
@@ -209,7 +180,7 @@ def _get_num_samples(name: str | Path) -> int:
|
|
209
180
|
samples = sound.frame_count()
|
210
181
|
sample_rate = sound.frame_rate
|
211
182
|
else:
|
212
|
-
info = soundfile.info(
|
183
|
+
info = soundfile.info(expanded_name)
|
213
184
|
samples = info.frames
|
214
185
|
sample_rate = info.samplerate
|
215
186
|
|
@@ -0,0 +1,103 @@
|
|
1
|
+
from ..datatypes import EffectList
|
2
|
+
from ..datatypes import EffectedFile
|
3
|
+
from ..datatypes import File
|
4
|
+
|
5
|
+
|
6
|
+
def balance_sources(
|
7
|
+
effected_sources: list[EffectedFile],
|
8
|
+
files: list[File],
|
9
|
+
effects: list[EffectList],
|
10
|
+
class_balancing_effect: EffectList,
|
11
|
+
num_classes: int,
|
12
|
+
num_ir: int,
|
13
|
+
mixups: list[int] | None = None,
|
14
|
+
) -> tuple[list[EffectedFile], list[EffectList]]:
|
15
|
+
import math
|
16
|
+
|
17
|
+
from .augmentation import get_mixups
|
18
|
+
from .sources import get_augmented_target_ids_by_class
|
19
|
+
|
20
|
+
first_cba_id = len(effects)
|
21
|
+
|
22
|
+
if mixups is None:
|
23
|
+
mixups = get_mixups(effects)
|
24
|
+
|
25
|
+
for mixup in mixups:
|
26
|
+
if mixup == 1:
|
27
|
+
continue
|
28
|
+
|
29
|
+
effected_sources_indices_by_class = get_augmented_target_ids_by_class(
|
30
|
+
augmented_targets=effected_sources,
|
31
|
+
targets=files,
|
32
|
+
target_augmentations=effects,
|
33
|
+
mixup=mixup,
|
34
|
+
num_classes=num_classes,
|
35
|
+
)
|
36
|
+
|
37
|
+
largest = max([len(item) for item in effected_sources_indices_by_class])
|
38
|
+
largest = math.ceil(largest / mixup) * mixup
|
39
|
+
for es_indices in effected_sources_indices_by_class:
|
40
|
+
additional_effects_needed = largest - len(es_indices)
|
41
|
+
file_ids = sorted({effected_sources[at_index].file_id for at_index in es_indices})
|
42
|
+
|
43
|
+
tfi_idx = 0
|
44
|
+
for _ in range(additional_effects_needed):
|
45
|
+
file_id = file_ids[tfi_idx]
|
46
|
+
tfi_idx = (tfi_idx + 1) % len(file_ids)
|
47
|
+
effect_id, effects = _get_unused_balancing_effect(
|
48
|
+
effected_sources=effected_sources,
|
49
|
+
files=files,
|
50
|
+
effects=effects,
|
51
|
+
class_balancing_effect=class_balancing_effect,
|
52
|
+
file_id=file_id,
|
53
|
+
mixup=mixup,
|
54
|
+
num_ir=num_ir,
|
55
|
+
first_cbe_id=first_cba_id,
|
56
|
+
)
|
57
|
+
effected_sources.append(EffectedFile(file_id=file_id, effect_id=effect_id))
|
58
|
+
|
59
|
+
return effected_sources, effects
|
60
|
+
|
61
|
+
|
62
|
+
def _get_unused_balancing_effect(
|
63
|
+
effected_sources: list[EffectedFile],
|
64
|
+
files: list[File],
|
65
|
+
effects: list[EffectList],
|
66
|
+
class_balancing_effect: EffectList,
|
67
|
+
file_id: int,
|
68
|
+
mixup: int,
|
69
|
+
num_ir: int,
|
70
|
+
first_cbe_id: int,
|
71
|
+
) -> tuple[int, list[EffectList]]:
|
72
|
+
"""Get an unused balancing augmentation for a given target file index"""
|
73
|
+
from dataclasses import asdict
|
74
|
+
|
75
|
+
from .augmentation import get_augmentation_rules
|
76
|
+
|
77
|
+
balancing_augmentations = [item for item in range(len(effects)) if item >= first_cbe_id]
|
78
|
+
used_balancing_augmentations = [
|
79
|
+
effected_source.effect_id
|
80
|
+
for effected_source in effected_sources
|
81
|
+
if effected_source.file_id == file_id and effected_source.effect_id in balancing_augmentations
|
82
|
+
]
|
83
|
+
|
84
|
+
augmentation_indices = [
|
85
|
+
item
|
86
|
+
for item in balancing_augmentations
|
87
|
+
if item not in used_balancing_augmentations and effects[item].mixup == mixup
|
88
|
+
]
|
89
|
+
if len(augmentation_indices) > 0:
|
90
|
+
return augmentation_indices[0], effects
|
91
|
+
|
92
|
+
class_balancing_effect = get_class_balancing_effect(file=files[file_id], default_cbe=class_balancing_effect)
|
93
|
+
new_effect = get_augmentation_rules(rules=asdict(class_balancing_effect), num_ir=num_ir)[0]
|
94
|
+
new_effect.mixup = mixup
|
95
|
+
effects.append(new_effect)
|
96
|
+
return len(effects) - 1, effects
|
97
|
+
|
98
|
+
|
99
|
+
def get_class_balancing_effect(file: File, default_cbe: EffectList) -> EffectList:
|
100
|
+
"""Get the class balancing effect rule for the given target"""
|
101
|
+
if file.class_balancing_effect is not None:
|
102
|
+
return file.class_balancing_effect
|
103
|
+
return default_cbe
|