sonusai 1.0.16__cp311-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
- sonusai/__init__.py +170 -0
- sonusai/aawscd_probwrite.py +148 -0
- sonusai/audiofe.py +481 -0
- sonusai/calc_metric_spenh.py +1136 -0
- sonusai/config/__init__.py +0 -0
- sonusai/config/asr.py +21 -0
- sonusai/config/config.py +65 -0
- sonusai/config/config.yml +49 -0
- sonusai/config/constants.py +53 -0
- sonusai/config/ir.py +124 -0
- sonusai/config/ir_delay.py +62 -0
- sonusai/config/source.py +275 -0
- sonusai/config/spectral_masks.py +15 -0
- sonusai/config/truth.py +64 -0
- sonusai/constants.py +14 -0
- sonusai/data/__init__.py +0 -0
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/data/speech_ma01_01.wav +0 -0
- sonusai/data/whitenoise.wav +0 -0
- sonusai/datatypes.py +383 -0
- sonusai/deprecated/gentcst.py +632 -0
- sonusai/deprecated/plot.py +519 -0
- sonusai/deprecated/tplot.py +365 -0
- sonusai/doc.py +52 -0
- sonusai/doc_strings/__init__.py +1 -0
- sonusai/doc_strings/doc_strings.py +531 -0
- sonusai/genft.py +196 -0
- sonusai/genmetrics.py +183 -0
- sonusai/genmix.py +199 -0
- sonusai/genmixdb.py +235 -0
- sonusai/ir_metric.py +551 -0
- sonusai/lsdb.py +141 -0
- sonusai/main.py +134 -0
- sonusai/metrics/__init__.py +43 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_class_weights.py +90 -0
- sonusai/metrics/calc_optimal_thresholds.py +73 -0
- sonusai/metrics/calc_pcm.py +45 -0
- sonusai/metrics/calc_pesq.py +36 -0
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_sa_sdr.py +64 -0
- sonusai/metrics/calc_sample_weights.py +25 -0
- sonusai/metrics/calc_segsnr_f.py +82 -0
- sonusai/metrics/calc_speech.py +382 -0
- sonusai/metrics/calc_wer.py +71 -0
- sonusai/metrics/calc_wsdr.py +57 -0
- sonusai/metrics/calculate_metrics.py +395 -0
- sonusai/metrics/class_summary.py +74 -0
- sonusai/metrics/confusion_matrix_summary.py +75 -0
- sonusai/metrics/one_hot.py +283 -0
- sonusai/metrics/snr_summary.py +128 -0
- sonusai/metrics_summary.py +314 -0
- sonusai/mixture/__init__.py +15 -0
- sonusai/mixture/audio.py +187 -0
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/constants.py +3 -0
- sonusai/mixture/data_io.py +173 -0
- sonusai/mixture/db.py +169 -0
- sonusai/mixture/db_datatypes.py +92 -0
- sonusai/mixture/effects.py +344 -0
- sonusai/mixture/feature.py +78 -0
- sonusai/mixture/generation.py +1116 -0
- sonusai/mixture/helpers.py +351 -0
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +23 -0
- sonusai/mixture/mixdb.py +1857 -0
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +51 -0
- sonusai/mixture/truth.py +61 -0
- sonusai/mixture/truth_functions/__init__.py +45 -0
- sonusai/mixture/truth_functions/crm.py +105 -0
- sonusai/mixture/truth_functions/energy.py +222 -0
- sonusai/mixture/truth_functions/file.py +48 -0
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +18 -0
- sonusai/mixture/truth_functions/sed.py +98 -0
- sonusai/mixture/truth_functions/target.py +142 -0
- sonusai/mkwav.py +135 -0
- sonusai/onnx_predict.py +363 -0
- sonusai/parse/__init__.py +0 -0
- sonusai/parse/expand.py +156 -0
- sonusai/parse/parse_source_directive.py +129 -0
- sonusai/parse/rand.py +214 -0
- sonusai/py.typed +0 -0
- sonusai/queries/__init__.py +0 -0
- sonusai/queries/queries.py +239 -0
- sonusai/rs.abi3.so +0 -0
- sonusai/rs.pyi +1 -0
- sonusai/rust/__init__.py +0 -0
- sonusai/speech/__init__.py +0 -0
- sonusai/speech/l2arctic.py +121 -0
- sonusai/speech/librispeech.py +102 -0
- sonusai/speech/mcgill.py +71 -0
- sonusai/speech/textgrid.py +89 -0
- sonusai/speech/timit.py +138 -0
- sonusai/speech/types.py +12 -0
- sonusai/speech/vctk.py +53 -0
- sonusai/speech/voxceleb.py +108 -0
- sonusai/utils/__init__.py +3 -0
- sonusai/utils/asl_p56.py +130 -0
- sonusai/utils/asr.py +91 -0
- sonusai/utils/asr_functions/__init__.py +3 -0
- sonusai/utils/asr_functions/aaware_whisper.py +69 -0
- sonusai/utils/audio_devices.py +50 -0
- sonusai/utils/braced_glob.py +50 -0
- sonusai/utils/calculate_input_shape.py +26 -0
- sonusai/utils/choice.py +51 -0
- sonusai/utils/compress.py +25 -0
- sonusai/utils/convert_string_to_number.py +6 -0
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/create_ts_name.py +14 -0
- sonusai/utils/dataclass_from_dict.py +27 -0
- sonusai/utils/db.py +16 -0
- sonusai/utils/docstring.py +53 -0
- sonusai/utils/energy_f.py +44 -0
- sonusai/utils/engineering_number.py +166 -0
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/get_frames_per_batch.py +2 -0
- sonusai/utils/get_label_names.py +20 -0
- sonusai/utils/grouper.py +6 -0
- sonusai/utils/human_readable_size.py +7 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/load_object.py +21 -0
- sonusai/utils/max_text_width.py +9 -0
- sonusai/utils/model_utils.py +28 -0
- sonusai/utils/numeric_conversion.py +11 -0
- sonusai/utils/onnx_utils.py +155 -0
- sonusai/utils/parallel.py +162 -0
- sonusai/utils/path_info.py +7 -0
- sonusai/utils/print_mixture_details.py +60 -0
- sonusai/utils/rand.py +13 -0
- sonusai/utils/ranges.py +43 -0
- sonusai/utils/read_predict_data.py +32 -0
- sonusai/utils/reshape.py +154 -0
- sonusai/utils/seconds_to_hms.py +7 -0
- sonusai/utils/stacked_complex.py +82 -0
- sonusai/utils/stratified_shuffle_split.py +170 -0
- sonusai/utils/tokenized_shell_vars.py +143 -0
- sonusai/utils/write_audio.py +26 -0
- sonusai/utils/yes_or_no.py +8 -0
- sonusai/vars.py +47 -0
- sonusai-1.0.16.dist-info/METADATA +56 -0
- sonusai-1.0.16.dist-info/RECORD +150 -0
- sonusai-1.0.16.dist-info/WHEEL +4 -0
- sonusai-1.0.16.dist-info/entry_points.txt +3 -0
sonusai/utils/reshape.py
ADDED
@@ -0,0 +1,154 @@
import numpy as np

from ..datatypes import Feature
from ..datatypes import Predict
from ..datatypes import Truth


def get_input_shape(feature: Feature) -> tuple[int, ...]:
    return feature.shape[1:]


def reshape_inputs(
    feature: Feature,
    batch_size: int,
    truth: Truth | None = None,
    timesteps: int = 0,
    flatten: bool = False,
    add1ch: bool = False,
) -> tuple[Feature, Truth | None]:
    """Check SonusAI feature and truth data and reshape feature of size [frames, strides, feature_parameters] into
    one of several options:

    If timesteps > 0 (i.e., for recurrent NNs):
        no-flatten, no-channel:   [sequences, timesteps, strides, feature_parameters]    (4-dim)
        flatten, no-channel:      [sequences, timesteps, strides*feature_parameters]     (3-dim)
        no-flatten, add-1channel: [sequences, timesteps, strides, feature_parameters, 1] (5-dim)
        flatten, add-1channel:    [sequences, timesteps, strides*feature_parameters, 1]  (4-dim)

    If batch_size is None, then do not reshape; just calculate new input shape and return.

    If timesteps == 0, then do not add timesteps dimension.

    The number of samples is trimmed to be a multiple of batch_size (Keras requirement) for
    both feature and truth.
    Channel is added to last/outer dimension for channel_last support in Keras/TF.

    Returns:
        feature: reshaped feature
        truth:   reshaped truth
    """
    frames, strides, feature_parameters = feature.shape
    if truth is not None:
        truth_frames, num_classes = truth.shape
        # Double-check correctness of inputs
        if frames != truth_frames:
            raise ValueError("Frames in feature and truth do not match")
    else:
        num_classes = 0

    if flatten:
        feature = np.reshape(feature, (frames, strides * feature_parameters))

    # Reshape for Keras/TF recurrent models that require timesteps/sequence length dimension
    if timesteps > 0:
        sequences = frames // timesteps

        # Remove frames if remainder exists (not fitting into a multiple of new number of sequences)
        frames_rem = frames % timesteps
        batch_rem = (frames // timesteps) % batch_size
        bf_rem = batch_rem * timesteps
        sequences = sequences - batch_rem
        fr2drop = frames_rem + bf_rem
        if fr2drop:
            if feature.ndim == 2:
                feature = feature[0:-fr2drop,]  # flattened input
            elif feature.ndim == 3:
                feature = feature[0:-fr2drop,]  # un-flattened input

            if truth is not None:
                truth = truth[0:-fr2drop,]

        # Reshape
        if feature.ndim == 2:  # flattened input
            # was [frames, feature_parameters*timesteps]
            feature = np.reshape(feature, (sequences, timesteps, strides * feature_parameters))
            if truth is not None:
                # was [frames, num_classes]
                truth = np.reshape(truth, (sequences, timesteps, num_classes))
        elif feature.ndim == 3:  # un-flattened input
            # was [frames, feature_parameters, timesteps]
            feature = np.reshape(feature, (sequences, timesteps, strides, feature_parameters))
            if truth is not None:
                # was [frames, num_classes]
                truth = np.reshape(truth, (sequences, timesteps, num_classes))
    else:
        # Drop frames if remainder exists (not fitting into a multiple of new number of sequences)
        fr2drop = feature.shape[0] % batch_size
        if fr2drop > 0:
            feature = feature[0:-fr2drop,]
            if truth is not None:
                truth = truth[0:-fr2drop,]

    # Add channel dimension if required for input to model (i.e. for cnn type input)
    if add1ch:
        feature = np.expand_dims(feature, axis=feature.ndim)  # add as last/outermost dim

    return feature, truth


def get_num_classes_from_predict(predict: Predict, timesteps: int = 0) -> int:
    num_dims = predict.ndim
    dims = predict.shape

    if num_dims == 3 or (num_dims == 2 and timesteps > 0):
        # 2D with timesteps - [frames, timesteps]
        if num_dims == 2:
            return 1

        # 3D - [frames, timesteps, num_classes]
        return dims[2]

    # 1D - [frames]
    if num_dims == 1:
        return 1

    # 2D without timesteps - [frames, num_classes]
    return dims[1]


def reshape_outputs(predict: Predict, truth: Truth | None = None, timesteps: int = 0) -> tuple[Predict, Truth | None]:
    """Reshape model output data.

    truth and predict can be either [frames, num_classes], or [frames, timesteps, num_classes]
    In binary case, num_classes dim may not exist; detect this and set num_classes to 1.
    """
    if truth is not None and predict.shape != truth.shape:
        raise ValueError("predict and truth shapes do not match")

    ndim = predict.ndim
    shape = predict.shape

    if not (0 < ndim <= 3):
        raise ValueError(f"do not know how to reshape data with {ndim} dimensions")

    if ndim == 3 or (ndim == 2 and timesteps > 0):
        if ndim == 2:
            # 2D with timesteps - [frames, timesteps]
            num_classes = 1
        else:
            # 3D - [frames, timesteps, num_classes]
            num_classes = shape[2]

        # reshape to remove timestep dimension
        shape = (shape[0] * shape[1], num_classes)
        predict = np.reshape(predict, shape)
        if truth is not None:
            truth = np.reshape(truth, shape)
    elif ndim == 1:
        # convert to 2D - [frames, 1]
        predict = np.expand_dims(predict, 1)
        if truth is not None:
            truth = np.expand_dims(truth, 1)

    return predict, truth
sonusai/utils/stacked_complex.py
ADDED
@@ -0,0 +1,82 @@
import numpy as np


def stack_complex(unstacked: np.ndarray) -> np.ndarray:
    """Stack a complex array

    A stacked array doubles the last dimension and organizes the data as:
        - first half is all the real data
        - second half is all the imaginary data

    :param unstacked: An nD array (n > 1) containing complex data
    :return: A stacked array
    :raises TypeError:
    """
    if not unstacked.ndim > 1:
        raise ValueError("unstacked must have more than 1 dimension")

    shape = list(unstacked.shape)
    shape[-1] = shape[-1] * 2
    stacked = np.empty(shape, dtype=np.float32)
    half = unstacked.shape[-1]
    stacked[..., :half] = np.real(unstacked)
    stacked[..., half:] = np.imag(unstacked)

    return stacked


def unstack_complex(stacked: np.ndarray) -> np.ndarray:
    """Unstack a stacked complex array

    :param stacked: An nD array (n > 1) where the last dimension contains stacked complex data in which the first half
        is all the real data and the second half is all the imaginary data
    :return: An unstacked complex array
    :raises TypeError:
    """
    if not stacked.ndim > 1:
        raise ValueError("stacked must have more than 1 dimension")

    if stacked.shape[-1] % 2 != 0:
        raise ValueError("last dimension of stacked must be a multiple of 2")

    half = stacked.shape[-1] // 2
    unstacked = 1j * stacked[..., half:]
    unstacked += stacked[..., :half]

    return unstacked


def stacked_complex_real(stacked: np.ndarray) -> np.ndarray:
    """Get the real elements from a stacked complex array

    :param stacked: An nD array (n > 1) where the last dimension contains stacked complex data in which the first half
        is all the real data and the second half is all the imaginary data
    :return: The real elements
    :raises TypeError:
    """
    if not stacked.ndim > 1:
        raise ValueError("stacked must have more than 1 dimension")

    if stacked.shape[-1] % 2 != 0:
        raise ValueError("last dimension of stacked must be a multiple of 2")

    half = stacked.shape[-1] // 2
    return stacked[..., :half]


def stacked_complex_imag(stacked: np.ndarray) -> np.ndarray:
    """Get the imaginary elements from a stacked complex array

    :param stacked: An nD array (n > 1) where the last dimension contains stacked complex data in which the first half
        is all the real data and the second half is all the imaginary data
    :return: The imaginary elements
    :raises TypeError:
    """
    if not stacked.ndim > 1:
        raise ValueError("stacked must have more than 1 dimension")

    if stacked.shape[-1] % 2 != 0:
        raise ValueError("last dimension of stacked must be a multiple of 2")

    half = stacked.shape[-1] // 2
    return stacked[..., half:]
sonusai/utils/stratified_shuffle_split.py
ADDED
@@ -0,0 +1,170 @@
import numpy as np

from ..mixture.mixdb import MixtureDatabase


def stratified_shuffle_split_mixid(
    mixdb: MixtureDatabase,
    vsplit: float = 0.2,
    nsplit: int = 0,
    rnd_seed: int | None = 0,
) -> tuple[list[int], list[int], np.ndarray, np.ndarray]:
    """
    Create a training and test/validation list of mixture IDs from all mixtures in a mixture database.
    The test/validation split is specified by vsplit (0.0 to 1.0), default 0.2.
    The mixtures are randomly shuffled by rnd_seed; set to int for repeatability, or None for no shuffle.
    The mixtures are then stratified across all populated classes.

    Inputs:
        mixdb:    Mixture database created by Aaware SonusAI genmixdb.
        vsplit:   Fractional split of mixtures for validation, 1-vsplit for training.
        nsplit:   Number of splits (TBD).
        rnd_seed: Seed integer for reproducible random shuffling (or None for no shuffling).

    Outputs:
        t_mixid:     list of mixture IDs for training
        v_mixid:     list of mixture IDs for validation
        t_num_mixid: list of class counts in t_mixid
        v_num_mixid: list of class counts in v_mixid

    Examples:
        t_mixid, v_mixid, t_num_mixid, v_num_mixid = stratified_shuffle_split_mixid(mixdb, vsplit=vsplit)

    @author: Chris Eddington
    """
    import random
    from copy import deepcopy

    from .. import logger
    from ..mixture.class_count import get_class_count_from_mixids

    if vsplit < 0 or vsplit > 1:
        raise ValueError("vsplit must be between 0 and 1")

    a_class_mixid: dict[int, list[int]] = {i + 1: [] for i in range(mixdb.num_classes)}
    for mixid, mixture in enumerate(mixdb.mixtures()):
        class_count = get_class_count_from_mixids(mixdb, mixid)
        if any(class_count):
            for class_index in mixdb.target_files[mixture.targets[0].file_id].class_indices:
                a_class_mixid[class_index].append(mixid)
        else:
            # no counts and mutex mode means this is all 'other' class
            a_class_mixid[mixdb.num_classes].append(mixid)

    t_class_mixid: list[list[int]] = [[] for _ in range(mixdb.num_classes)]
    v_class_mixid: list[list[int]] = [[] for _ in range(mixdb.num_classes)]

    a_num_mixid = np.zeros(mixdb.num_classes, dtype=np.int32)
    t_num_mixid = np.zeros(mixdb.num_classes, dtype=np.int32)
    v_num_mixid = np.zeros(mixdb.num_classes, dtype=np.int32)

    if rnd_seed is not None:
        random.seed(rnd_seed)

    # For each class pick percentage of shuffled mixids for training, validation
    for ci in range(mixdb.num_classes):
        # total number of mixids for class
        a_num_mixid[ci] = len(a_class_mixid[ci + 1])

        # number of training mixids for class
        t_num_mixid[ci] = int(np.floor(a_num_mixid[ci] * (1 - vsplit)))

        # number of validation mixids for class
        v_num_mixid[ci] = a_num_mixid[ci] - t_num_mixid[ci]

        # indices for all mixids in class
        indices = [*range(a_num_mixid[ci])]
        if rnd_seed is not None:
            # randomize order
            random.shuffle(indices)

        t_class_mixid[ci] = [a_class_mixid[ci + 1][ii] for ii in indices[0 : t_num_mixid[ci]]]
        v_class_mixid[ci] = [a_class_mixid[ci + 1][ii] for ii in indices[t_num_mixid[ci] :]]

    if np.any(~(t_num_mixid > 0)):
        logger.warning(f"Some classes have zero coverage: {np.where(~(t_num_mixid > 0))[0]}")

    # Stratify over non-zero classes
    nz_indices = np.where(t_num_mixid > 0)[0]
    # First stratify pass is min count / 3 times through all classes, one each least populated class count (of non-zero)
    min_class = min(t_num_mixid[nz_indices])
    # number of mixids in each class for stratify by 1
    n0 = int(np.ceil(min_class / 3))
    # 3rd stage for stratify by 1
    n3 = int(n0)
    # 2nd stage stratify by class_count/min(class_count-n3) n2 times
    n2 = int(max(min_class - n0 - n3, 0))

    logger.info(
        f"Stratifying training, x1 cnt {n0}: x(class_count/{n2}): x1 cnt {n3} x1, "
        f"for {len(nz_indices)} populated classes"
    )

    # initialize source list
    tt = deepcopy(t_class_mixid)
    t_num_mixid2 = deepcopy(t_num_mixid)
    t_mixid = []
    for _ in range(n0):
        for ci in range(mixdb.num_classes):
            if t_num_mixid2[ci] > 0:
                # append first
                t_mixid.append(tt[ci][0])
                del tt[ci][0]
                t_num_mixid2[ci] = len(tt[ci])

    # Now extract weighted by how many are left in class minus n3
    # which will leave approx n3 remaining
    if n2 > 0:
        # should always be non-zero
        min_class = int(np.min(t_num_mixid2 - n3))
        class_count = np.floor((t_num_mixid2 - n3) / min_class)
        # class_count = np.maximum(np.floor((t_num_mixid2 - n3) / n2),0)  # Counts per class
        for _ in range(min_class):
            for ci in range(mixdb.num_classes):
                if class_count[ci] > 0:
                    for _ in range(int(class_count[ci])):
                        # append first
                        t_mixid.append(tt[ci][0])
                        del tt[ci][0]
                    t_num_mixid2[ci] = len(tt[ci])

    # Now extract remaining mixids, one each class until empty
    # There should be ~n3 remaining mixids in each
    t_mixid = _extract_remaining_mixids(mixdb, t_mixid, t_num_mixid2, tt)

    if len(t_mixid) != sum(t_num_mixid):
        logger.warning("Final stratified training list length does not match starting list length.")

    if any(t_num_mixid2) or any(tt):
        logger.warning("Remaining training mixid list not empty.")

    # Now stratify the validation list, which is probably not as important, so use simple method
    # initialize source list
    vv = deepcopy(v_class_mixid)
    v_num_mixid2 = deepcopy(v_num_mixid)
    v_mixid = _extract_remaining_mixids(mixdb, [], v_num_mixid2, vv)

    if len(v_mixid) != sum(v_num_mixid):
        logger.warning("Final stratified validation list length does not match starting lists length.")

    if any(v_num_mixid2) or any(vv):
        logger.warning("Remaining validation mixid list not empty.")

    return t_mixid, v_mixid, t_num_mixid, v_num_mixid


def _extract_remaining_mixids(
    mixdb: MixtureDatabase,
    mixid: list[int],
    num_mixid: np.ndarray,
    class_mixid: list[list[int]],
) -> list[int]:
    for _ in range(max(num_mixid)):
        for ci in range(mixdb.num_classes):
            if num_mixid[ci] > 0:
                # append first
                mixid.append(class_mixid[ci][0])
                del class_mixid[ci][0]
                num_mixid[ci] = len(class_mixid[ci])

    return mixid
sonusai/utils/tokenized_shell_vars.py
ADDED
@@ -0,0 +1,143 @@
from pathlib import Path


def tokenized_expand(name: str | bytes | Path) -> tuple[str, dict[str, str]]:
    """Expand shell variables of the forms $var, ${var} and %var%.
    Unknown variables are left unchanged.

    Expand paths containing shell variable substitutions. The following rules apply:
        - no expansion within single quotes
        - '$$' is translated into '$'
        - '%%' is translated into '%' if '%%' are not seen in %var1%%var2%
        - ${var} is accepted.
        - $varname is accepted.
        - %var% is accepted.
        - vars can be made out of letters, digits and the characters '_-'
          (though is not verified in the ${var} and %var% cases)

    :param name: String to expand
    :return: Tuple of (expanded string, dictionary of tokens)
    """
    import os
    import string

    from ..constants import DEFAULT_NOISE

    os.environ["default_noise"] = str(DEFAULT_NOISE)  # noqa: SIM112

    if isinstance(name, bytes):
        name = name.decode("utf-8")

    if isinstance(name, Path):
        name = name.as_posix()

    name = os.fspath(name)
    token_map: dict = {}

    if "$" not in name and "%" not in name:
        return name, token_map

    var_chars = string.ascii_letters + string.digits + "_-"
    quote = "'"
    percent = "%"
    brace = "{"
    rbrace = "}"
    dollar = "$"
    environ = os.environ

    result = name[:0]
    index = 0
    path_len = len(name)
    while index < path_len:
        c = name[index : index + 1]
        if c == quote:  # no expansion within single quotes
            name = name[index + 1 :]
            path_len = len(name)
            try:
                index = name.index(c)
                result += c + name[: index + 1]
            except ValueError:
                result += c + name
                index = path_len - 1
        elif c == percent:  # variable or '%'
            if name[index + 1 : index + 2] == percent:
                result += c
                index += 1
            else:
                name = name[index + 1 :]
                path_len = len(name)
                try:
                    index = name.index(percent)
                except ValueError:
                    result += percent + name
                    index = path_len - 1
                else:
                    var = name[:index]
                    try:
                        if environ is None:
                            value = os.fsencode(os.environ[os.fsdecode(var)]).decode("utf-8")  # type: ignore[unreachable]
                        else:
                            value = environ[var]
                        token_map[var] = value
                    except KeyError:
                        value = percent + var + percent
                    result += value
        elif c == dollar:  # variable or '$$'
            if name[index + 1 : index + 2] == dollar:
                result += c
                index += 1
            elif name[index + 1 : index + 2] == brace:
                name = name[index + 2 :]
                path_len = len(name)
                try:
                    index = name.index(rbrace)
                except ValueError:
                    result += dollar + brace + name
                    index = path_len - 1
                else:
                    var = name[:index]
                    try:
                        if environ is None:
                            value = os.fsencode(os.environ[os.fsdecode(var)]).decode("utf-8")  # type: ignore[unreachable]
                        else:
                            value = environ[var]
                        token_map[var] = value
                    except KeyError:
                        value = dollar + brace + var + rbrace
                    result += value
            else:
                var = name[:0]
                index += 1
                c = name[index : index + 1]
                while c and c in var_chars:
                    var += c
                    index += 1
                    c = name[index : index + 1]
                try:
                    if environ is None:
                        value = os.fsencode(os.environ[os.fsdecode(var)]).decode("utf-8")  # type: ignore[unreachable]
                    else:
                        value = environ[var]
                    token_map[var] = value
                except KeyError:
                    value = dollar + var
                result += value
                if c:
                    index -= 1
        else:
            result += c
        index += 1

    return result, token_map


def tokenized_replace(name: str, tokens: dict[str, str]) -> str:
    """Replace text with shell variables.

    :param name: String to replace
    :param tokens: Dictionary of replacement tokens
    :return: replaced string
    """
    for key, value in tokens.items():
        name = name.replace(value, f"${key}")
    return name
sonusai/utils/write_audio.py
ADDED
@@ -0,0 +1,26 @@
from ..constants import SAMPLE_RATE
from ..datatypes import AudioT


def write_audio(name: str, audio: AudioT, sample_rate: int = SAMPLE_RATE) -> None:
    """Write an audio file.

    To write multiple channels, use a 2D array of shape [channels, samples].
    The bits per sample and PCM/float are determined by the data type.
    """
    import torch
    import torchaudio

    data = torch.tensor(audio)

    if data.dim() == 1:
        data = torch.reshape(data, (1, data.shape[0]))
    if data.dim() != 2:
        raise ValueError("audio must be a 1D or 2D array")

    # Assuming data has more samples than channels, check if array needs to be transposed
    if data.shape[1] < data.shape[0]:
        data = torch.transpose(data, 0, 1)

    torchaudio.save(uri=name, src=data, sample_rate=sample_rate)
sonusai/vars.py
ADDED
@@ -0,0 +1,47 @@
"""sonusai vars

usage: vars [-h]

options:
    -h, --help  Display this help.

List custom SonusAI variables.

"""


def main() -> None:
    from docopt import docopt

    from sonusai import __version__ as sai_version
    from sonusai.utils.docstring import trim_docstring

    docopt(trim_docstring(__doc__), version=sai_version, options_first=True)

    from os import environ
    from os import getenv

    from sonusai.constants import DEFAULT_NOISE

    print("Custom SonusAI variables:")
    print("")
    print(f"${{default_noise}}: {DEFAULT_NOISE}")
    print("")
    print("SonusAI recognized environment variables:")
    print("")
    print(f"DEEPGRAM_API_KEY {getenv('DEEPGRAM_API_KEY')}")
    print(f"GOOGLE_SPEECH_API_KEY {getenv('GOOGLE_SPEECH_API_KEY')}")
    print("")
    items = ["DEEPGRAM_API_KEY", "GOOGLE_SPEECH_API_KEY"]
    items += [item for item in environ if item.upper().startswith("AIXP_WHISPER_")]


if __name__ == "__main__":
    from sonusai import exception_handler
    from sonusai.utils.keyboard_interrupt import register_keyboard_interrupt

    register_keyboard_interrupt()
    try:
        main()
    except Exception as e:
        exception_handler(e)