sonusai-1.0.16-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sonusai/__init__.py +170 -0
  2. sonusai/aawscd_probwrite.py +148 -0
  3. sonusai/audiofe.py +481 -0
  4. sonusai/calc_metric_spenh.py +1136 -0
  5. sonusai/config/__init__.py +0 -0
  6. sonusai/config/asr.py +21 -0
  7. sonusai/config/config.py +65 -0
  8. sonusai/config/config.yml +49 -0
  9. sonusai/config/constants.py +53 -0
  10. sonusai/config/ir.py +124 -0
  11. sonusai/config/ir_delay.py +62 -0
  12. sonusai/config/source.py +275 -0
  13. sonusai/config/spectral_masks.py +15 -0
  14. sonusai/config/truth.py +64 -0
  15. sonusai/constants.py +14 -0
  16. sonusai/data/__init__.py +0 -0
  17. sonusai/data/silero_vad_v5.1.jit +0 -0
  18. sonusai/data/silero_vad_v5.1.onnx +0 -0
  19. sonusai/data/speech_ma01_01.wav +0 -0
  20. sonusai/data/whitenoise.wav +0 -0
  21. sonusai/datatypes.py +383 -0
  22. sonusai/deprecated/gentcst.py +632 -0
  23. sonusai/deprecated/plot.py +519 -0
  24. sonusai/deprecated/tplot.py +365 -0
  25. sonusai/doc.py +52 -0
  26. sonusai/doc_strings/__init__.py +1 -0
  27. sonusai/doc_strings/doc_strings.py +531 -0
  28. sonusai/genft.py +196 -0
  29. sonusai/genmetrics.py +183 -0
  30. sonusai/genmix.py +199 -0
  31. sonusai/genmixdb.py +235 -0
  32. sonusai/ir_metric.py +551 -0
  33. sonusai/lsdb.py +141 -0
  34. sonusai/main.py +134 -0
  35. sonusai/metrics/__init__.py +43 -0
  36. sonusai/metrics/calc_audio_stats.py +42 -0
  37. sonusai/metrics/calc_class_weights.py +90 -0
  38. sonusai/metrics/calc_optimal_thresholds.py +73 -0
  39. sonusai/metrics/calc_pcm.py +45 -0
  40. sonusai/metrics/calc_pesq.py +36 -0
  41. sonusai/metrics/calc_phase_distance.py +43 -0
  42. sonusai/metrics/calc_sa_sdr.py +64 -0
  43. sonusai/metrics/calc_sample_weights.py +25 -0
  44. sonusai/metrics/calc_segsnr_f.py +82 -0
  45. sonusai/metrics/calc_speech.py +382 -0
  46. sonusai/metrics/calc_wer.py +71 -0
  47. sonusai/metrics/calc_wsdr.py +57 -0
  48. sonusai/metrics/calculate_metrics.py +395 -0
  49. sonusai/metrics/class_summary.py +74 -0
  50. sonusai/metrics/confusion_matrix_summary.py +75 -0
  51. sonusai/metrics/one_hot.py +283 -0
  52. sonusai/metrics/snr_summary.py +128 -0
  53. sonusai/metrics_summary.py +314 -0
  54. sonusai/mixture/__init__.py +15 -0
  55. sonusai/mixture/audio.py +187 -0
  56. sonusai/mixture/class_balancing.py +103 -0
  57. sonusai/mixture/constants.py +3 -0
  58. sonusai/mixture/data_io.py +173 -0
  59. sonusai/mixture/db.py +169 -0
  60. sonusai/mixture/db_datatypes.py +92 -0
  61. sonusai/mixture/effects.py +344 -0
  62. sonusai/mixture/feature.py +78 -0
  63. sonusai/mixture/generation.py +1116 -0
  64. sonusai/mixture/helpers.py +351 -0
  65. sonusai/mixture/ir_effects.py +77 -0
  66. sonusai/mixture/log_duration_and_sizes.py +23 -0
  67. sonusai/mixture/mixdb.py +1857 -0
  68. sonusai/mixture/pad_audio.py +35 -0
  69. sonusai/mixture/resample.py +7 -0
  70. sonusai/mixture/sox_effects.py +195 -0
  71. sonusai/mixture/sox_help.py +650 -0
  72. sonusai/mixture/spectral_mask.py +51 -0
  73. sonusai/mixture/truth.py +61 -0
  74. sonusai/mixture/truth_functions/__init__.py +45 -0
  75. sonusai/mixture/truth_functions/crm.py +105 -0
  76. sonusai/mixture/truth_functions/energy.py +222 -0
  77. sonusai/mixture/truth_functions/file.py +48 -0
  78. sonusai/mixture/truth_functions/metadata.py +24 -0
  79. sonusai/mixture/truth_functions/metrics.py +28 -0
  80. sonusai/mixture/truth_functions/phoneme.py +18 -0
  81. sonusai/mixture/truth_functions/sed.py +98 -0
  82. sonusai/mixture/truth_functions/target.py +142 -0
  83. sonusai/mkwav.py +135 -0
  84. sonusai/onnx_predict.py +363 -0
  85. sonusai/parse/__init__.py +0 -0
  86. sonusai/parse/expand.py +156 -0
  87. sonusai/parse/parse_source_directive.py +129 -0
  88. sonusai/parse/rand.py +214 -0
  89. sonusai/py.typed +0 -0
  90. sonusai/queries/__init__.py +0 -0
  91. sonusai/queries/queries.py +239 -0
  92. sonusai/rs.abi3.so +0 -0
  93. sonusai/rs.pyi +1 -0
  94. sonusai/rust/__init__.py +0 -0
  95. sonusai/speech/__init__.py +0 -0
  96. sonusai/speech/l2arctic.py +121 -0
  97. sonusai/speech/librispeech.py +102 -0
  98. sonusai/speech/mcgill.py +71 -0
  99. sonusai/speech/textgrid.py +89 -0
  100. sonusai/speech/timit.py +138 -0
  101. sonusai/speech/types.py +12 -0
  102. sonusai/speech/vctk.py +53 -0
  103. sonusai/speech/voxceleb.py +108 -0
  104. sonusai/utils/__init__.py +3 -0
  105. sonusai/utils/asl_p56.py +130 -0
  106. sonusai/utils/asr.py +91 -0
  107. sonusai/utils/asr_functions/__init__.py +3 -0
  108. sonusai/utils/asr_functions/aaware_whisper.py +69 -0
  109. sonusai/utils/audio_devices.py +50 -0
  110. sonusai/utils/braced_glob.py +50 -0
  111. sonusai/utils/calculate_input_shape.py +26 -0
  112. sonusai/utils/choice.py +51 -0
  113. sonusai/utils/compress.py +25 -0
  114. sonusai/utils/convert_string_to_number.py +6 -0
  115. sonusai/utils/create_timestamp.py +5 -0
  116. sonusai/utils/create_ts_name.py +14 -0
  117. sonusai/utils/dataclass_from_dict.py +27 -0
  118. sonusai/utils/db.py +16 -0
  119. sonusai/utils/docstring.py +53 -0
  120. sonusai/utils/energy_f.py +44 -0
  121. sonusai/utils/engineering_number.py +166 -0
  122. sonusai/utils/evaluate_random_rule.py +15 -0
  123. sonusai/utils/get_frames_per_batch.py +2 -0
  124. sonusai/utils/get_label_names.py +20 -0
  125. sonusai/utils/grouper.py +6 -0
  126. sonusai/utils/human_readable_size.py +7 -0
  127. sonusai/utils/keyboard_interrupt.py +12 -0
  128. sonusai/utils/load_object.py +21 -0
  129. sonusai/utils/max_text_width.py +9 -0
  130. sonusai/utils/model_utils.py +28 -0
  131. sonusai/utils/numeric_conversion.py +11 -0
  132. sonusai/utils/onnx_utils.py +155 -0
  133. sonusai/utils/parallel.py +162 -0
  134. sonusai/utils/path_info.py +7 -0
  135. sonusai/utils/print_mixture_details.py +60 -0
  136. sonusai/utils/rand.py +13 -0
  137. sonusai/utils/ranges.py +43 -0
  138. sonusai/utils/read_predict_data.py +32 -0
  139. sonusai/utils/reshape.py +154 -0
  140. sonusai/utils/seconds_to_hms.py +7 -0
  141. sonusai/utils/stacked_complex.py +82 -0
  142. sonusai/utils/stratified_shuffle_split.py +170 -0
  143. sonusai/utils/tokenized_shell_vars.py +143 -0
  144. sonusai/utils/write_audio.py +26 -0
  145. sonusai/utils/yes_or_no.py +8 -0
  146. sonusai/vars.py +47 -0
  147. sonusai-1.0.16.dist-info/METADATA +56 -0
  148. sonusai-1.0.16.dist-info/RECORD +150 -0
  149. sonusai-1.0.16.dist-info/WHEEL +4 -0
  150. sonusai-1.0.16.dist-info/entry_points.txt +3 -0
sonusai/utils/reshape.py ADDED
@@ -0,0 +1,154 @@
+ import numpy as np
+
+ from ..datatypes import Feature
+ from ..datatypes import Predict
+ from ..datatypes import Truth
+
+
+ def get_input_shape(feature: Feature) -> tuple[int, ...]:
+     return feature.shape[1:]
+
+
+ def reshape_inputs(
+     feature: Feature,
+     batch_size: int,
+     truth: Truth | None = None,
+     timesteps: int = 0,
+     flatten: bool = False,
+     add1ch: bool = False,
+ ) -> tuple[Feature, Truth | None]:
+     """Check SonusAI feature and truth data and reshape feature of size [frames, strides, feature_parameters] into
+     one of several options:
+
+     If timesteps > 0 (i.e., for recurrent NNs):
+       no-flatten, no-channel:   [sequences, timesteps, strides, feature_parameters]    (4-dim)
+       flatten, no-channel:      [sequences, timesteps, strides*feature_parameters]     (3-dim)
+       no-flatten, add-1channel: [sequences, timesteps, strides, feature_parameters, 1] (5-dim)
+       flatten, add-1channel:    [sequences, timesteps, strides*feature_parameters, 1]  (4-dim)
+
+     If batch_size is None, then do not reshape; just calculate new input shape and return.
+
+     If timesteps == 0, then do not add timesteps dimension.
+
+     The number of samples is trimmed to be a multiple of batch_size (Keras requirement) for
+     both feature and truth.
+     Channel is added to last/outer dimension for channel_last support in Keras/TF.
+
+     Returns:
+       feature    reshaped feature
+       truth      reshaped truth
+     """
+     frames, strides, feature_parameters = feature.shape
+     if truth is not None:
+         truth_frames, num_classes = truth.shape
+         # Double-check correctness of inputs
+         if frames != truth_frames:
+             raise ValueError("Frames in feature and truth do not match")
+     else:
+         num_classes = 0
+
+     if flatten:
+         feature = np.reshape(feature, (frames, strides * feature_parameters))
+
+     # Reshape for Keras/TF recurrent models that require timesteps/sequence length dimension
+     if timesteps > 0:
+         sequences = frames // timesteps
+
+         # Remove frames if remainder exists (not fitting into a multiple of new number of sequences)
+         frames_rem = frames % timesteps
+         batch_rem = (frames // timesteps) % batch_size
+         bf_rem = batch_rem * timesteps
+         sequences = sequences - batch_rem
+         fr2drop = frames_rem + bf_rem
+         if fr2drop:
+             if feature.ndim == 2:
+                 feature = feature[0:-fr2drop,]  # flattened input
+             elif feature.ndim == 3:
+                 feature = feature[0:-fr2drop,]  # un-flattened input
+
+             if truth is not None:
+                 truth = truth[0:-fr2drop,]
+
+         # Reshape
+         if feature.ndim == 2:  # flattened input
+             # was [frames, feature_parameters*timesteps]
+             feature = np.reshape(feature, (sequences, timesteps, strides * feature_parameters))
+             if truth is not None:
+                 # was [frames, num_classes]
+                 truth = np.reshape(truth, (sequences, timesteps, num_classes))
+         elif feature.ndim == 3:  # un-flattened input
+             # was [frames, feature_parameters, timesteps]
+             feature = np.reshape(feature, (sequences, timesteps, strides, feature_parameters))
+             if truth is not None:
+                 # was [frames, num_classes]
+                 truth = np.reshape(truth, (sequences, timesteps, num_classes))
+     else:
+         # Drop frames if remainder exists (not fitting into a multiple of new number of sequences)
+         fr2drop = feature.shape[0] % batch_size
+         if fr2drop > 0:
+             feature = feature[0:-fr2drop,]
+             if truth is not None:
+                 truth = truth[0:-fr2drop,]
+
+     # Add channel dimension if required for input to model (i.e. for cnn type input)
+     if add1ch:
+         feature = np.expand_dims(feature, axis=feature.ndim)  # add as last/outermost dim
+
+     return feature, truth
+
+
+ def get_num_classes_from_predict(predict: Predict, timesteps: int = 0) -> int:
+     num_dims = predict.ndim
+     dims = predict.shape
+
+     if num_dims == 3 or (num_dims == 2 and timesteps > 0):
+         # 2D with timesteps - [frames, timesteps]
+         if num_dims == 2:
+             return 1
+
+         # 3D - [frames, timesteps, num_classes]
+         return dims[2]
+
+     # 1D - [frames]
+     if num_dims == 1:
+         return 1
+
+     # 2D without timesteps - [frames, num_classes]
+     return dims[1]
+
+
+ def reshape_outputs(predict: Predict, truth: Truth | None = None, timesteps: int = 0) -> tuple[Predict, Truth | None]:
+     """Reshape model output data.
+
+     truth and predict can be either [frames, num_classes], or [frames, timesteps, num_classes]
+     In binary case, num_classes dim may not exist; detect this and set num_classes to 1.
+     """
+     if truth is not None and predict.shape != truth.shape:
+         raise ValueError("predict and truth shapes do not match")
+
+     ndim = predict.ndim
+     shape = predict.shape
+
+     if not (0 < ndim <= 3):
+         raise ValueError(f"do not know how to reshape data with {ndim} dimensions")
+
+     if ndim == 3 or (ndim == 2 and timesteps > 0):
+         if ndim == 2:
+             # 2D with timesteps - [frames, timesteps]
+             num_classes = 1
+         else:
+             # 3D - [frames, timesteps, num_classes]
+             num_classes = shape[2]
+
+         # reshape to remove timestep dimension
+         shape = (shape[0] * shape[1], num_classes)
+         predict = np.reshape(predict, shape)
+         if truth is not None:
+             truth = np.reshape(truth, shape)
+     elif ndim == 1:
+         # convert to 2D - [frames, 1]
+         predict = np.expand_dims(predict, 1)
+         if truth is not None:
+             truth = np.expand_dims(truth, 1)
+
+     return predict, truth
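For reference, a minimal usage sketch of the reshaping helpers above. The array sizes are invented for illustration, and Feature/Truth/Predict are assumed to behave as plain numpy arrays, as the code above implies:

import numpy as np

# Hypothetical sizes: 100 feature frames, 6 strides, 64 feature parameters, 3 classes.
feature = np.random.rand(100, 6, 64).astype(np.float32)
truth = np.random.rand(100, 3).astype(np.float32)

# Recurrent-style reshape: flatten strides*feature_parameters, group frames into timesteps,
# and trim so the sequence count is a multiple of batch_size.
f, t = reshape_inputs(feature, batch_size=2, truth=truth, timesteps=10, flatten=True)
# f.shape == (10, 10, 384), t.shape == (10, 10, 3)

# Collapse the timestep dimension of model output back to [frames, num_classes].
predict = np.random.rand(10, 10, 3).astype(np.float32)
p, _ = reshape_outputs(predict, timesteps=10)
# p.shape == (100, 3)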
sonusai/utils/seconds_to_hms.py ADDED
@@ -0,0 +1,7 @@
+ def seconds_to_hms(seconds: float) -> str:
+     """Convert given seconds into string of H:MM:SS"""
+     h = int(seconds / 3600)
+     s = seconds - h * 3600
+     m = int(s / 60)
+     s = int(seconds - h * 3600 - m * 60)
+     return f"{h:d}:{m:02d}:{s:02d} (H:MM:SS)"
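A quick sanity check of the formatting above (values worked out by hand, not taken from package tests):

print(seconds_to_hms(3725.3))  # -> 1:02:05 (H:MM:SS)
print(seconds_to_hms(59.9))    # -> 0:00:59 (H:MM:SS); fractional seconds are truncated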
sonusai/utils/stacked_complex.py ADDED
@@ -0,0 +1,82 @@
+ import numpy as np
+
+
+ def stack_complex(unstacked: np.ndarray) -> np.ndarray:
+     """Stack a complex array
+
+     A stacked array doubles the last dimension and organizes the data as:
+     - first half is all the real data
+     - second half is all the imaginary data
+
+     :param unstacked: An nD array (n > 1) containing complex data
+     :return: A stacked array
+     :raises TypeError:
+     """
+     if not unstacked.ndim > 1:
+         raise ValueError("unstacked must have more than 1 dimension")
+
+     shape = list(unstacked.shape)
+     shape[-1] = shape[-1] * 2
+     stacked = np.empty(shape, dtype=np.float32)
+     half = unstacked.shape[-1]
+     stacked[..., :half] = np.real(unstacked)
+     stacked[..., half:] = np.imag(unstacked)
+
+     return stacked
+
+
+ def unstack_complex(stacked: np.ndarray) -> np.ndarray:
+     """Unstack a stacked complex array
+
+     :param stacked: An nD array (n > 1) where the last dimension contains stacked complex data in which the first half
+         is all the real data and the second half is all the imaginary data
+     :return: An unstacked complex array
+     :raises TypeError:
+     """
+     if not stacked.ndim > 1:
+         raise ValueError("stacked must have more than 1 dimension")
+
+     if stacked.shape[-1] % 2 != 0:
+         raise ValueError("last dimension of stacked must be a multiple of 2")
+
+     half = stacked.shape[-1] // 2
+     unstacked = 1j * stacked[..., half:]
+     unstacked += stacked[..., :half]
+
+     return unstacked
+
+
+ def stacked_complex_real(stacked: np.ndarray) -> np.ndarray:
+     """Get the real elements from a stacked complex array
+
+     :param stacked: An nD array (n > 1) where the last dimension contains stacked complex data in which the first half
+         is all the real data and the second half is all the imaginary data
+     :return: The real elements
+     :raises TypeError:
+     """
+     if not stacked.ndim > 1:
+         raise ValueError("stacked must have more than 1 dimension")
+
+     if stacked.shape[-1] % 2 != 0:
+         raise ValueError("last dimension of stacked must be a multiple of 2")
+
+     half = stacked.shape[-1] // 2
+     return stacked[..., :half]
+
+
+ def stacked_complex_imag(stacked: np.ndarray) -> np.ndarray:
+     """Get the imaginary elements from a stacked complex array
+
+     :param stacked: An nD array (n > 1) where the last dimension contains stacked complex data in which the first half
+         is all the real data and the second half is all the imaginary data
+     :return: The imaginary elements
+     :raises TypeError:
+     """
+     if not stacked.ndim > 1:
+         raise ValueError("stacked must have more than 1 dimension")
+
+     if stacked.shape[-1] % 2 != 0:
+         raise ValueError("last dimension of stacked must be a multiple of 2")
+
+     half = stacked.shape[-1] // 2
+     return stacked[..., half:]
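A short round-trip sketch of the stacking convention above; the array values are arbitrary:

import numpy as np

spectrum = np.array([[1 + 2j, 3 - 4j],
                     [5 + 0j, -1 + 1j]], dtype=np.complex64)

stacked = stack_complex(spectrum)  # shape (2, 4): per row, real parts then imaginary parts
assert stacked.shape == (2, 4)
assert np.allclose(stacked_complex_real(stacked), spectrum.real)
assert np.allclose(stacked_complex_imag(stacked), spectrum.imag)
assert np.allclose(unstack_complex(stacked), spectrum)  # round trip recovers the input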
sonusai/utils/stratified_shuffle_split.py ADDED
@@ -0,0 +1,170 @@
+ import numpy as np
+
+ from ..mixture.mixdb import MixtureDatabase
+
+
+ def stratified_shuffle_split_mixid(
+     mixdb: MixtureDatabase,
+     vsplit: float = 0.2,
+     nsplit: int = 0,
+     rnd_seed: int | None = 0,
+ ) -> tuple[list[int], list[int], np.ndarray, np.ndarray]:
+     """
+     Create a training and test/validation list of mixture IDs from all mixtures in a mixture database.
+     The test/validation split is specified by vsplit (0.0 to 1.0), default 0.2.
+     The mixtures are randomly shuffled by rnd_seed; set to int for repeatability, or None for no shuffle.
+     The mixtures are then stratified across all populated classes.
+
+     Inputs:
+       mixdb:    Mixture database created by Aaware SonusAI genmixdb.
+       vsplit:   Fractional split of mixtures for validation, 1-vsplit for training.
+       nsplit:   Number of splits (TBD).
+       rnd_seed: Seed integer for reproducible random shuffling (or None for no shuffling).
+
+     Outputs:
+       t_mixid:     list of mixture IDs for training
+       v_mixid:     list of mixture IDs for validation
+       t_num_mixid: list of class counts in t_mixid
+       v_num_mixid: list of class counts in v_mixid
+
+     Examples:
+       t_mixid, v_mixid, t_num_mixid, v_num_mixid = stratified_shuffle_split_mixid(mixdb, vsplit=vsplit)
+
+     @author: Chris Eddington
+     """
+     import random
+     from copy import deepcopy
+
+     from .. import logger
+     from ..mixture.class_count import get_class_count_from_mixids
+
+     if vsplit < 0 or vsplit > 1:
+         raise ValueError("vsplit must be between 0 and 1")
+
+     a_class_mixid: dict[int, list[int]] = {i + 1: [] for i in range(mixdb.num_classes)}
+     for mixid, mixture in enumerate(mixdb.mixtures()):
+         class_count = get_class_count_from_mixids(mixdb, mixid)
+         if any(class_count):
+             for class_index in mixdb.target_files[mixture.targets[0].file_id].class_indices:
+                 a_class_mixid[class_index].append(mixid)
+         else:
+             # no counts and mutex mode means this is all 'other' class
+             a_class_mixid[mixdb.num_classes].append(mixid)
+
+     t_class_mixid: list[list[int]] = [[] for _ in range(mixdb.num_classes)]
+     v_class_mixid: list[list[int]] = [[] for _ in range(mixdb.num_classes)]
+
+     a_num_mixid = np.zeros(mixdb.num_classes, dtype=np.int32)
+     t_num_mixid = np.zeros(mixdb.num_classes, dtype=np.int32)
+     v_num_mixid = np.zeros(mixdb.num_classes, dtype=np.int32)
+
+     if rnd_seed is not None:
+         random.seed(rnd_seed)
+
+     # For each class pick percentage of shuffled mixids for training, validation
+     for ci in range(mixdb.num_classes):
+         # total number of mixids for class
+         a_num_mixid[ci] = len(a_class_mixid[ci + 1])
+
+         # number of training mixids for class
+         t_num_mixid[ci] = int(np.floor(a_num_mixid[ci] * (1 - vsplit)))
+
+         # number of validation mixids for class
+         v_num_mixid[ci] = a_num_mixid[ci] - t_num_mixid[ci]
+
+         # indices for all mixids in class
+         indices = [*range(a_num_mixid[ci])]
+         if rnd_seed is not None:
+             # randomize order
+             random.shuffle(indices)
+
+         t_class_mixid[ci] = [a_class_mixid[ci + 1][ii] for ii in indices[0 : t_num_mixid[ci]]]
+         v_class_mixid[ci] = [a_class_mixid[ci + 1][ii] for ii in indices[t_num_mixid[ci] :]]
+
+     if np.any(~(t_num_mixid > 0)):
+         logger.warning(f"Some classes have zero coverage: {np.where(~(t_num_mixid > 0))[0]}")
+
+     # Stratify over non-zero classes
+     nz_indices = np.where(t_num_mixid > 0)[0]
+     # First stratify pass is min count / 3 times through all classes, one each least populated class count (of non-zero)
+     min_class = min(t_num_mixid[nz_indices])
+     # number of mixids in each class for stratify by 1
+     n0 = int(np.ceil(min_class / 3))
+     # 3rd stage for stratify by 1
+     n3 = int(n0)
+     # 2nd stage stratify by class_count/min(class_count-n3) n2 times
+     n2 = int(max(min_class - n0 - n3, 0))
+
+     logger.info(
+         f"Stratifying training, x1 cnt {n0}: x(class_count/{n2}): x1 cnt {n3} x1, "
+         f"for {len(nz_indices)} populated classes"
+     )
+
+     # initialize source list
+     tt = deepcopy(t_class_mixid)
+     t_num_mixid2 = deepcopy(t_num_mixid)
+     t_mixid = []
+     for _ in range(n0):
+         for ci in range(mixdb.num_classes):
+             if t_num_mixid2[ci] > 0:
+                 # append first
+                 t_mixid.append(tt[ci][0])
+                 del tt[ci][0]
+                 t_num_mixid2[ci] = len(tt[ci])
+
+     # Now extract weighted by how many are left in class minus n3
+     # which will leave approx n3 remaining
+     if n2 > 0:
+         # should always be non-zero
+         min_class = int(np.min(t_num_mixid2 - n3))
+         class_count = np.floor((t_num_mixid2 - n3) / min_class)
+         # class_count = np.maximum(np.floor((t_num_mixid2 - n3) / n2),0) # Counts per class
+         for _ in range(min_class):
+             for ci in range(mixdb.num_classes):
+                 if class_count[ci] > 0:
+                     for _ in range(int(class_count[ci])):
+                         # append first
+                         t_mixid.append(tt[ci][0])
+                         del tt[ci][0]
+                         t_num_mixid2[ci] = len(tt[ci])
+
+     # Now extract remaining mixids, one each class until empty
+     # There should be ~n3 remaining mixids in each
+     t_mixid = _extract_remaining_mixids(mixdb, t_mixid, t_num_mixid2, tt)
+
+     if len(t_mixid) != sum(t_num_mixid):
+         logger.warning("Final stratified training list length does not match starting list length.")
+
+     if any(t_num_mixid2) or any(tt):
+         logger.warning("Remaining training mixid list not empty.")
+
+     # Now stratify the validation list, which is probably not as important, so use simple method
+     # initialize source list
+     vv = deepcopy(v_class_mixid)
+     v_num_mixid2 = deepcopy(v_num_mixid)
+     v_mixid = _extract_remaining_mixids(mixdb, [], v_num_mixid2, vv)
+
+     if len(v_mixid) != sum(v_num_mixid):
+         logger.warning("Final stratified validation list length does not match starting lists length.")
+
+     if any(v_num_mixid2) or any(vv):
+         logger.warning("Remaining validation mixid list not empty.")
+
+     return t_mixid, v_mixid, t_num_mixid, v_num_mixid
+
+
+ def _extract_remaining_mixids(
+     mixdb: MixtureDatabase,
+     mixid: list[int],
+     num_mixid: np.ndarray,
+     class_mixid: list[list[int]],
+ ) -> list[int]:
+     for _ in range(max(num_mixid)):
+         for ci in range(mixdb.num_classes):
+             if num_mixid[ci] > 0:
+                 # append first
+                 mixid.append(class_mixid[ci][0])
+                 del class_mixid[ci][0]
+                 num_mixid[ci] = len(class_mixid[ci])
+
+     return mixid
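The three-stage stratification above is easier to see on the count arithmetic alone. The sketch below reproduces just the n0/n2/n3 computation on a made-up per-class training count, with no MixtureDatabase involved:

import numpy as np

# Hypothetical per-class training counts after the vsplit (a zero marks a class with no coverage).
t_num_mixid = np.array([40, 12, 25, 0, 18])

nz_indices = np.where(t_num_mixid > 0)[0]
min_class = min(t_num_mixid[nz_indices])  # 12, the least-populated non-zero class

n0 = int(np.ceil(min_class / 3))       # stage 1: n0 one-per-class round-robin passes
n3 = int(n0)                           # stage 3: final one-per-class passes over the leftovers
n2 = int(max(min_class - n0 - n3, 0))  # stage 2: weighted draws with what remains

print(n0, n2, n3)  # -> 4 4 4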
sonusai/utils/tokenized_shell_vars.py ADDED
@@ -0,0 +1,143 @@
+ from pathlib import Path
+
+
+ def tokenized_expand(name: str | bytes | Path) -> tuple[str, dict[str, str]]:
+     """Expand shell variables of the forms $var, ${var} and %var%.
+     Unknown variables are left unchanged.
+
+     Expand paths containing shell variable substitutions. The following rules apply:
+         - no expansion within single quotes
+         - '$$' is translated into '$'
+         - '%%' is translated into '%' if '%%' are not seen in %var1%%var2%
+         - ${var} is accepted.
+         - $varname is accepted.
+         - %var% is accepted.
+         - vars can be made out of letters, digits and the characters '_-'
+           (though is not verified in the ${var} and %var% cases)
+
+     :param name: String to expand
+     :return: Tuple of (expanded string, dictionary of tokens)
+     """
+     import os
+     import string
+
+     from ..constants import DEFAULT_NOISE
+
+     os.environ["default_noise"] = str(DEFAULT_NOISE)  # noqa: SIM112
+
+     if isinstance(name, bytes):
+         name = name.decode("utf-8")
+
+     if isinstance(name, Path):
+         name = name.as_posix()
+
+     name = os.fspath(name)
+     token_map: dict = {}
+
+     if "$" not in name and "%" not in name:
+         return name, token_map
+
+     var_chars = string.ascii_letters + string.digits + "_-"
+     quote = "'"
+     percent = "%"
+     brace = "{"
+     rbrace = "}"
+     dollar = "$"
+     environ = os.environ
+
+     result = name[:0]
+     index = 0
+     path_len = len(name)
+     while index < path_len:
+         c = name[index : index + 1]
+         if c == quote:  # no expansion within single quotes
+             name = name[index + 1 :]
+             path_len = len(name)
+             try:
+                 index = name.index(c)
+                 result += c + name[: index + 1]
+             except ValueError:
+                 result += c + name
+                 index = path_len - 1
+         elif c == percent:  # variable or '%'
+             if name[index + 1 : index + 2] == percent:
+                 result += c
+                 index += 1
+             else:
+                 name = name[index + 1 :]
+                 path_len = len(name)
+                 try:
+                     index = name.index(percent)
+                 except ValueError:
+                     result += percent + name
+                     index = path_len - 1
+                 else:
+                     var = name[:index]
+                     try:
+                         if environ is None:
+                             value = os.fsencode(os.environ[os.fsdecode(var)]).decode("utf-8")  # type: ignore[unreachable]
+                         else:
+                             value = environ[var]
+                         token_map[var] = value
+                     except KeyError:
+                         value = percent + var + percent
+                     result += value
+         elif c == dollar:  # variable or '$$'
+             if name[index + 1 : index + 2] == dollar:
+                 result += c
+                 index += 1
+             elif name[index + 1 : index + 2] == brace:
+                 name = name[index + 2 :]
+                 path_len = len(name)
+                 try:
+                     index = name.index(rbrace)
+                 except ValueError:
+                     result += dollar + brace + name
+                     index = path_len - 1
+                 else:
+                     var = name[:index]
+                     try:
+                         if environ is None:
+                             value = os.fsencode(os.environ[os.fsdecode(var)]).decode("utf-8")  # type: ignore[unreachable]
+                         else:
+                             value = environ[var]
+                         token_map[var] = value
+                     except KeyError:
+                         value = dollar + brace + var + rbrace
+                     result += value
+             else:
+                 var = name[:0]
+                 index += 1
+                 c = name[index : index + 1]
+                 while c and c in var_chars:
+                     var += c
+                     index += 1
+                     c = name[index : index + 1]
+                 try:
+                     if environ is None:
+                         value = os.fsencode(os.environ[os.fsdecode(var)]).decode("utf-8")  # type: ignore[unreachable]
+                     else:
+                         value = environ[var]
+                     token_map[var] = value
+                 except KeyError:
+                     value = dollar + var
+                 result += value
+                 if c:
+                     index -= 1
+         else:
+             result += c
+         index += 1
+
+     return result, token_map
+
+
+ def tokenized_replace(name: str, tokens: dict[str, str]) -> str:
+     """Replace text with shell variables.
+
+     :param name: String to replace
+     :param tokens: Dictionary of replacement tokens
+     :return: replaced string
+     """
+     for key, value in tokens.items():
+         name = name.replace(value, f"${key}")
+     return name
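A small usage sketch of the expand/replace pair above, assuming the sonusai package is importable (tokenized_expand pulls DEFAULT_NOISE from sonusai.constants); the environment variable name and its value are hypothetical:

import os

os.environ["MY_DATA"] = "/data/corpora"

expanded, tokens = tokenized_expand("$MY_DATA/speech/train.wav")
# expanded == "/data/corpora/speech/train.wav"
# tokens == {"MY_DATA": "/data/corpora"}

# tokenized_replace substitutes the recorded values back with their $var names.
assert tokenized_replace(expanded, tokens) == "$MY_DATA/speech/train.wav"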
sonusai/utils/write_audio.py ADDED
@@ -0,0 +1,26 @@
+ from ..constants import SAMPLE_RATE
+ from ..datatypes import AudioT
+
+
+ def write_audio(name: str, audio: AudioT, sample_rate: int = SAMPLE_RATE) -> None:
+     """Write an audio file.
+
+     To write multiple channels, use a 2D array of shape [channels, samples].
+     The bits per sample and PCM/float are determined by the data type.
+
+     """
+     import torch
+     import torchaudio
+
+     data = torch.tensor(audio)
+
+     if data.dim() == 1:
+         data = torch.reshape(data, (1, data.shape[0]))
+     if data.dim() != 2:
+         raise ValueError("audio must be a 1D or 2D array")
+
+     # Assuming data has more samples than channels, check if array needs to be transposed
+     if data.shape[1] < data.shape[0]:
+         data = torch.transpose(data, 0, 1)
+
+     torchaudio.save(uri=name, src=data, sample_rate=sample_rate)
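For illustration, writing one second of a 440 Hz tone with the helper above. The import paths are inferred from the file listing and the relative imports in the code, and torch/torchaudio must be installed (the function itself requires them):

import numpy as np

from sonusai.constants import SAMPLE_RATE
from sonusai.utils.write_audio import write_audio

t = np.arange(SAMPLE_RATE) / SAMPLE_RATE
tone = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)  # 1D array: a single channel

write_audio("tone_440hz.wav", tone)  # dtype determines PCM vs. float encoding, per the docstring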
sonusai/utils/yes_or_no.py ADDED
@@ -0,0 +1,8 @@
+ def yes_or_no(question: str) -> bool:
+     """Wait for yes or no input"""
+     while True:
+         reply = str(input(question + " (y/n)?: ")).lower().strip()
+         if reply[:1] == "y":
+             return True
+         if reply[:1] == "n":
+             return False
sonusai/vars.py ADDED
@@ -0,0 +1,47 @@
+ """sonusai vars
+
+ usage: vars [-h]
+
+ options:
+     -h, --help  Display this help.
+
+ List custom SonusAI variables.
+
+ """
+
+
+ def main() -> None:
+     from docopt import docopt
+
+     from sonusai import __version__ as sai_version
+     from sonusai.utils.docstring import trim_docstring
+
+     docopt(trim_docstring(__doc__), version=sai_version, options_first=True)
+
+     from os import environ
+     from os import getenv
+
+     from sonusai.constants import DEFAULT_NOISE
+
+     print("Custom SonusAI variables:")
+     print("")
+     print(f"${{default_noise}}: {DEFAULT_NOISE}")
+     print("")
+     print("SonusAI recognized environment variables:")
+     print("")
+     print(f"DEEPGRAM_API_KEY      {getenv('DEEPGRAM_API_KEY')}")
+     print(f"GOOGLE_SPEECH_API_KEY {getenv('GOOGLE_SPEECH_API_KEY')}")
+     print("")
+     items = ["DEEPGRAM_API_KEY", "GOOGLE_SPEECH_API_KEY"]
+     items += [item for item in environ if item.upper().startswith("AIXP_WHISPER_")]
+
+
+ if __name__ == "__main__":
+     from sonusai import exception_handler
+     from sonusai.utils.keyboard_interrupt import register_keyboard_interrupt
+
+     register_keyboard_interrupt()
+     try:
+         main()
+     except Exception as e:
+         exception_handler(e)