sonusai 0.20.3__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +16 -3
- sonusai/audiofe.py +241 -77
- sonusai/calc_metric_spenh.py +71 -73
- sonusai/config/__init__.py +3 -0
- sonusai/config/config.py +61 -0
- sonusai/config/config.yml +20 -0
- sonusai/config/constants.py +8 -0
- sonusai/constants.py +11 -0
- sonusai/data/genmixdb.yml +21 -36
- sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
- sonusai/deprecated/plot.py +4 -5
- sonusai/doc/doc.py +4 -4
- sonusai/doc.py +11 -4
- sonusai/genft.py +43 -45
- sonusai/genmetrics.py +25 -19
- sonusai/genmix.py +54 -82
- sonusai/genmixdb.py +88 -264
- sonusai/ir_metric.py +30 -34
- sonusai/lsdb.py +41 -48
- sonusai/main.py +15 -22
- sonusai/metrics/calc_audio_stats.py +4 -293
- sonusai/metrics/calc_class_weights.py +4 -4
- sonusai/metrics/calc_optimal_thresholds.py +8 -5
- sonusai/metrics/calc_pesq.py +2 -2
- sonusai/metrics/calc_segsnr_f.py +4 -4
- sonusai/metrics/calc_speech.py +25 -13
- sonusai/metrics/class_summary.py +7 -7
- sonusai/metrics/confusion_matrix_summary.py +5 -5
- sonusai/metrics/one_hot.py +4 -4
- sonusai/metrics/snr_summary.py +7 -7
- sonusai/metrics_summary.py +38 -45
- sonusai/mixture/__init__.py +4 -104
- sonusai/mixture/audio.py +10 -39
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/config.py +251 -271
- sonusai/mixture/constants.py +35 -39
- sonusai/mixture/data_io.py +25 -36
- sonusai/mixture/db_datatypes.py +58 -22
- sonusai/mixture/effects.py +386 -0
- sonusai/mixture/feature.py +7 -11
- sonusai/mixture/generation.py +478 -628
- sonusai/mixture/helpers.py +82 -184
- sonusai/mixture/ir_delay.py +3 -4
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +6 -12
- sonusai/mixture/mixdb.py +910 -729
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +2 -2
- sonusai/mixture/truth.py +17 -15
- sonusai/mixture/truth_functions/crm.py +12 -12
- sonusai/mixture/truth_functions/energy.py +22 -22
- sonusai/mixture/truth_functions/file.py +5 -5
- sonusai/mixture/truth_functions/metadata.py +4 -4
- sonusai/mixture/truth_functions/metrics.py +4 -4
- sonusai/mixture/truth_functions/phoneme.py +3 -3
- sonusai/mixture/truth_functions/sed.py +11 -13
- sonusai/mixture/truth_functions/target.py +10 -10
- sonusai/mkwav.py +26 -29
- sonusai/onnx_predict.py +240 -88
- sonusai/queries/__init__.py +2 -2
- sonusai/queries/queries.py +38 -34
- sonusai/speech/librispeech.py +1 -1
- sonusai/speech/mcgill.py +1 -1
- sonusai/speech/timit.py +2 -2
- sonusai/summarize_metric_spenh.py +10 -17
- sonusai/utils/__init__.py +7 -1
- sonusai/utils/asl_p56.py +2 -2
- sonusai/utils/asr.py +2 -2
- sonusai/utils/asr_functions/aaware_whisper.py +4 -5
- sonusai/utils/choice.py +31 -0
- sonusai/utils/compress.py +1 -1
- sonusai/utils/dataclass_from_dict.py +19 -1
- sonusai/utils/energy_f.py +3 -3
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/onnx_utils.py +3 -17
- sonusai/utils/print_mixture_details.py +21 -19
- sonusai/utils/{temp_seed.py → rand.py} +3 -3
- sonusai/utils/read_predict_data.py +2 -2
- sonusai/utils/reshape.py +3 -3
- sonusai/utils/stratified_shuffle_split.py +3 -3
- sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
- sonusai/utils/write_audio.py +2 -2
- sonusai/vars.py +11 -4
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/METADATA +4 -2
- sonusai-1.0.2.dist-info/RECORD +138 -0
- sonusai/mixture/augmentation.py +0 -444
- sonusai/mixture/class_count.py +0 -15
- sonusai/mixture/eq_rule_is_valid.py +0 -45
- sonusai/mixture/target_class_balancing.py +0 -107
- sonusai/mixture/targets.py +0 -175
- sonusai-0.20.3.dist-info/RECORD +0 -128
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/WHEEL +0 -0
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/entry_points.txt +0 -0
@@ -15,6 +15,7 @@ AudioT: TypeAlias = npt.NDArray[np.float32]
|
|
15
15
|
|
16
16
|
Truth: TypeAlias = Any
|
17
17
|
TruthDict: TypeAlias = dict[str, Truth]
|
18
|
+
TruthsDict: TypeAlias = dict[str, TruthDict]
|
18
19
|
Segsnr: TypeAlias = npt.NDArray[np.float32]
|
19
20
|
|
20
21
|
AudioF: TypeAlias = npt.NDArray[np.complex64]
|
@@ -68,63 +69,34 @@ class TruthConfig(DataClassSonusAIMixin):
|
|
68
69
|
|
69
70
|
|
70
71
|
TruthConfigs: TypeAlias = dict[str, TruthConfig]
|
72
|
+
TruthsConfigs: TypeAlias = dict[str, TruthConfigs]
|
73
|
+
|
74
|
+
|
71
75
|
NumberStr: TypeAlias = float | int | str
|
72
76
|
OptionalNumberStr: TypeAlias = NumberStr | None
|
73
77
|
OptionalListNumberStr: TypeAlias = list[NumberStr] | None
|
74
|
-
EQ: TypeAlias = tuple[float | int, float | int, float | int]
|
75
|
-
|
76
|
-
|
77
|
-
@dataclass
|
78
|
-
class AugmentationRuleEffects(DataClassSonusAIMixin):
|
79
|
-
normalize: OptionalNumberStr = None
|
80
|
-
pitch: OptionalNumberStr = None
|
81
|
-
tempo: OptionalNumberStr = None
|
82
|
-
gain: OptionalNumberStr = None
|
83
|
-
eq1: OptionalListNumberStr = None
|
84
|
-
eq2: OptionalListNumberStr = None
|
85
|
-
eq3: OptionalListNumberStr = None
|
86
|
-
lpf: OptionalNumberStr = None
|
87
|
-
ir: OptionalNumberStr = None
|
88
|
-
|
89
|
-
|
90
|
-
@dataclass
|
91
|
-
class AugmentationRule(DataClassSonusAIMixin):
|
92
|
-
pre: AugmentationRuleEffects
|
93
|
-
post: AugmentationRuleEffects | None = None
|
94
|
-
mixup: int = 1
|
95
78
|
|
96
79
|
|
97
|
-
|
98
|
-
class AugmentationEffects(DataClassSonusAIMixin):
|
99
|
-
normalize: float | None = None
|
100
|
-
pitch: float | None = None
|
101
|
-
tempo: float | None = None
|
102
|
-
gain: float | None = None
|
103
|
-
eq1: EQ | None = None
|
104
|
-
eq2: EQ | None = None
|
105
|
-
eq3: EQ | None = None
|
106
|
-
lpf: float | None = None
|
107
|
-
ir: int | None = None
|
80
|
+
EffectList: TypeAlias = list[str]
|
108
81
|
|
109
82
|
|
110
83
|
@dataclass
|
111
|
-
class
|
112
|
-
pre:
|
113
|
-
post:
|
84
|
+
class Effects(DataClassSonusAIMixin):
|
85
|
+
pre: EffectList
|
86
|
+
post: EffectList = field(default_factory=EffectList)
|
114
87
|
|
115
88
|
|
116
|
-
@dataclass(frozen=True)
|
117
89
|
class UniversalSNRGenerator:
|
118
|
-
|
119
|
-
|
90
|
+
def __init__(self, raw_value: float | str) -> None:
|
91
|
+
self._raw_value = str(raw_value)
|
92
|
+
self.is_random = isinstance(raw_value, str) and raw_value.startswith("sai_rand")
|
120
93
|
|
121
94
|
@property
|
122
95
|
def value(self) -> float:
|
123
|
-
|
124
|
-
from .augmentation import evaluate_random_rule
|
125
|
-
|
126
|
-
return float(evaluate_random_rule(str(self._raw_value)))
|
96
|
+
from sonusai.mixture.effects import evaluate_sai_random_float
|
127
97
|
|
98
|
+
if self.is_random:
|
99
|
+
return float(evaluate_sai_random_float(self._raw_value))
|
128
100
|
return float(self._raw_value)
|
129
101
|
|
130
102
|
|
@@ -145,12 +117,14 @@ Speaker: TypeAlias = dict[str, str]
|
|
145
117
|
|
146
118
|
|
147
119
|
@dataclass
|
148
|
-
class
|
120
|
+
class SourceFile(DataClassSonusAIMixin):
|
121
|
+
category: str
|
122
|
+
class_indices: list[int]
|
149
123
|
name: str
|
150
124
|
samples: int
|
151
|
-
class_indices: list[int]
|
152
125
|
truth_configs: TruthConfigs
|
153
|
-
|
126
|
+
class_balancing_effect: EffectList | None = None
|
127
|
+
id: int = -1
|
154
128
|
level_type: str | None = None
|
155
129
|
speaker_id: int | None = None
|
156
130
|
|
@@ -162,21 +136,9 @@ class TargetFile(DataClassSonusAIMixin):
|
|
162
136
|
|
163
137
|
|
164
138
|
@dataclass
|
165
|
-
class
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
@dataclass
|
171
|
-
class NoiseFile(DataClassSonusAIMixin):
|
172
|
-
name: str
|
173
|
-
samples: int
|
174
|
-
|
175
|
-
@property
|
176
|
-
def duration(self) -> float:
|
177
|
-
from .constants import SAMPLE_RATE
|
178
|
-
|
179
|
-
return self.samples / SAMPLE_RATE
|
139
|
+
class EffectedFile(DataClassSonusAIMixin):
|
140
|
+
file_id: int
|
141
|
+
effect_id: int
|
180
142
|
|
181
143
|
|
182
144
|
ClassCount: TypeAlias = list[int]
|
@@ -184,37 +146,6 @@ ClassCount: TypeAlias = list[int]
|
|
184
146
|
GeneralizedIDs: TypeAlias = str | int | list[int] | range
|
185
147
|
|
186
148
|
|
187
|
-
@dataclass
|
188
|
-
class GenMixData:
|
189
|
-
targets: list[AudioT] | None = None
|
190
|
-
target: AudioT | None = None
|
191
|
-
noise: AudioT | None = None
|
192
|
-
mixture: AudioT | None = None
|
193
|
-
truth_t: list[TruthDict] | None = None
|
194
|
-
segsnr_t: Segsnr | None = None
|
195
|
-
|
196
|
-
|
197
|
-
@dataclass
|
198
|
-
class GenFTData:
|
199
|
-
feature: Feature | None = None
|
200
|
-
truth_f: TruthDict | None = None
|
201
|
-
segsnr: Segsnr | None = None
|
202
|
-
|
203
|
-
|
204
|
-
@dataclass
|
205
|
-
class ImpulseResponseData:
|
206
|
-
data: AudioT
|
207
|
-
sample_rate: int
|
208
|
-
delay: int
|
209
|
-
|
210
|
-
|
211
|
-
@dataclass
|
212
|
-
class ImpulseResponseFile:
|
213
|
-
file: str
|
214
|
-
tags: list[str]
|
215
|
-
delay: int
|
216
|
-
|
217
|
-
|
218
149
|
@dataclass(frozen=True)
|
219
150
|
class SpectralMask(DataClassSonusAIMixin):
|
220
151
|
f_max_width: int
|
@@ -226,68 +157,70 @@ class SpectralMask(DataClassSonusAIMixin):
|
|
226
157
|
|
227
158
|
@dataclass(frozen=True)
|
228
159
|
class TruthParameter(DataClassSonusAIMixin):
|
160
|
+
category: str
|
229
161
|
name: str
|
230
162
|
parameters: int | None
|
231
163
|
|
232
164
|
|
233
165
|
@dataclass
|
234
|
-
class
|
166
|
+
class Source(DataClassSonusAIMixin):
|
167
|
+
effects: Effects
|
235
168
|
file_id: int
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
# to its normalized level when calculating truth (if needed).
|
242
|
-
if self.augmentation.pre.gain is None:
|
243
|
-
return 1.0
|
244
|
-
return round(10 ** (self.augmentation.pre.gain / 20), ndigits=5)
|
169
|
+
pre_tempo: float = 1
|
170
|
+
repeat: bool = False
|
171
|
+
snr: UniversalSNR = field(default_factory=lambda: UniversalSNR(0))
|
172
|
+
snr_gain: float = 0
|
173
|
+
start: int = 0
|
245
174
|
|
246
175
|
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
@dataclass
|
251
|
-
class Noise(DataClassSonusAIMixin):
|
252
|
-
file_id: int
|
253
|
-
augmentation: Augmentation
|
176
|
+
Sources: TypeAlias = dict[str, Source]
|
177
|
+
SourcesAudioT: TypeAlias = dict[str, AudioT]
|
178
|
+
SourcesAudioF: TypeAlias = dict[str, AudioF]
|
254
179
|
|
255
180
|
|
256
181
|
@dataclass
|
257
182
|
class Mixture(DataClassSonusAIMixin):
|
258
183
|
name: str
|
259
|
-
targets: list[Target]
|
260
|
-
noise: Noise
|
261
|
-
noise_offset: int
|
262
184
|
samples: int
|
263
|
-
|
185
|
+
all_sources: Sources
|
264
186
|
spectral_mask_id: int
|
265
187
|
spectral_mask_seed: int
|
266
|
-
|
267
|
-
|
188
|
+
|
189
|
+
@property
|
190
|
+
def all_source_ids(self) -> dict[str, int]:
|
191
|
+
return {category: source.file_id for category, source in self.all_sources.items()}
|
192
|
+
|
193
|
+
@property
|
194
|
+
def sources(self) -> Sources:
|
195
|
+
return {category: source for category, source in self.all_sources.items() if category != "noise"}
|
196
|
+
|
197
|
+
@property
|
198
|
+
def source_ids(self) -> dict[str, int]:
|
199
|
+
return {category: source.file_id for category, source in self.sources.items()}
|
200
|
+
|
201
|
+
@property
|
202
|
+
def noise(self) -> Source:
|
203
|
+
return self.all_sources["noise"]
|
268
204
|
|
269
205
|
@property
|
270
206
|
def noise_id(self) -> int:
|
271
207
|
return self.noise.file_id
|
272
208
|
|
273
209
|
@property
|
274
|
-
def
|
275
|
-
return
|
210
|
+
def source_effects(self) -> dict[str, Effects]:
|
211
|
+
return {category: source.effects for category, source in self.sources.items()}
|
276
212
|
|
277
213
|
@property
|
278
|
-
def
|
279
|
-
return
|
214
|
+
def noise_effects(self) -> Effects:
|
215
|
+
return self.noise.effects
|
280
216
|
|
281
217
|
@property
|
282
218
|
def is_noise_only(self) -> bool:
|
283
|
-
return self.snr < -96
|
219
|
+
return self.noise.snr < -96
|
284
220
|
|
285
221
|
@property
|
286
|
-
def
|
287
|
-
return self.snr > 96
|
288
|
-
|
289
|
-
def target_gain(self, target_index: int) -> float:
|
290
|
-
return (self.targets[target_index].gain if not self.is_noise_only else 0) * self.target_snr_gain
|
222
|
+
def is_source_only(self) -> bool:
|
223
|
+
return self.noise.snr > 96
|
291
224
|
|
292
225
|
|
293
226
|
@dataclass(frozen=True)
|
@@ -302,7 +235,7 @@ class TransformConfig:
|
|
302
235
|
@dataclass(frozen=True)
|
303
236
|
class FeatureGeneratorConfig:
|
304
237
|
feature_mode: str
|
305
|
-
truth_parameters: dict[str, int | None]
|
238
|
+
truth_parameters: dict[str, dict[str, int | None]]
|
306
239
|
|
307
240
|
|
308
241
|
@dataclass(frozen=True)
|
@@ -319,6 +252,37 @@ class FeatureGeneratorInfo:
|
|
319
252
|
ASRConfigs: TypeAlias = dict[str, dict[str, Any]]
|
320
253
|
|
321
254
|
|
255
|
+
@dataclass
|
256
|
+
class GenMixData:
|
257
|
+
mixture: AudioT | None = None
|
258
|
+
truth_t: TruthsDict | None = None
|
259
|
+
segsnr_t: Segsnr | None = None
|
260
|
+
sources: SourcesAudioT | None = None
|
261
|
+
source: AudioT | None = None
|
262
|
+
noise: AudioT | None = None
|
263
|
+
|
264
|
+
|
265
|
+
@dataclass
|
266
|
+
class GenFTData:
|
267
|
+
feature: Feature | None = None
|
268
|
+
truth_f: TruthsDict | None = None
|
269
|
+
segsnr: Segsnr | None = None
|
270
|
+
|
271
|
+
|
272
|
+
@dataclass
|
273
|
+
class ImpulseResponseData:
|
274
|
+
data: AudioT
|
275
|
+
sample_rate: int
|
276
|
+
delay: int
|
277
|
+
|
278
|
+
|
279
|
+
@dataclass
|
280
|
+
class ImpulseResponseFile(DataClassSonusAIMixin):
|
281
|
+
name: str
|
282
|
+
tags: list[str]
|
283
|
+
delay: str | int = "auto"
|
284
|
+
|
285
|
+
|
322
286
|
@dataclass
|
323
287
|
class MixtureDatabaseConfig(DataClassSonusAIMixin):
|
324
288
|
asr_configs: ASRConfigs
|
@@ -326,13 +290,11 @@ class MixtureDatabaseConfig(DataClassSonusAIMixin):
|
|
326
290
|
class_labels: list[str]
|
327
291
|
class_weights_threshold: list[float]
|
328
292
|
feature: str
|
329
|
-
|
293
|
+
ir_files: list[ImpulseResponseFile]
|
330
294
|
mixtures: list[Mixture]
|
331
|
-
noise_mix_mode: str
|
332
|
-
noise_files: list[NoiseFile]
|
333
295
|
num_classes: int
|
296
|
+
source_files: dict[str, list[SourceFile]]
|
334
297
|
spectral_masks: list[SpectralMask]
|
335
|
-
target_files: list[TargetFile]
|
336
298
|
|
337
299
|
|
338
300
|
SpeechMetadata: TypeAlias = str | list[Interval] | None
|
@@ -353,7 +315,6 @@ class SnrFBinMetrics(NamedTuple):
|
|
353
315
|
|
354
316
|
|
355
317
|
class SpeechMetrics(NamedTuple):
|
356
|
-
pesq: float | None = None
|
357
318
|
csig: float | None = None
|
358
319
|
cbak: float | None = None
|
359
320
|
covl: float | None = None
|
sonusai/deprecated/plot.py
CHANGED
@@ -45,11 +45,10 @@ import signal
|
|
45
45
|
|
46
46
|
import numpy as np
|
47
47
|
from matplotlib import pyplot as plt
|
48
|
-
|
49
|
-
from sonusai.
|
50
|
-
from sonusai.
|
51
|
-
from sonusai.
|
52
|
-
from sonusai.mixture import Truth
|
48
|
+
from sonusai.datatypes import AudioT
|
49
|
+
from sonusai.datatypes import Feature
|
50
|
+
from sonusai.datatypes import Predict
|
51
|
+
from sonusai.datatypes import Truth
|
53
52
|
|
54
53
|
|
55
54
|
def signal_handler(_sig, _frame):
|
sonusai/doc/doc.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from
|
1
|
+
from ..mixture.config import get_default_config
|
2
2
|
|
3
3
|
|
4
4
|
def doc_seed() -> str:
|
@@ -124,7 +124,7 @@ in this dataset.
|
|
124
124
|
|
125
125
|
|
126
126
|
def get_truth_functions() -> str:
|
127
|
-
from
|
127
|
+
from ..mixture import truth_functions
|
128
128
|
|
129
129
|
functions = [function for function in dir(truth_functions) if not function.startswith("__")]
|
130
130
|
text = "\nSupported truth functions:\n\n"
|
@@ -471,7 +471,7 @@ Rules must specify all the following parameters:
|
|
471
471
|
|
472
472
|
|
473
473
|
def doc_config() -> str:
|
474
|
-
from
|
474
|
+
from ..mixture.constants import VALID_CONFIGS
|
475
475
|
|
476
476
|
text = "\n"
|
477
477
|
text += "The SonusAI database is defined using a config.yml file.\n\n"
|
@@ -482,7 +482,7 @@ def doc_config() -> str:
|
|
482
482
|
|
483
483
|
|
484
484
|
def doc_asr_configs() -> str:
|
485
|
-
from
|
485
|
+
from ..utils.asr import get_available_engines
|
486
486
|
|
487
487
|
default = f"\nDefault value: {get_default_config()['asr_configs']}"
|
488
488
|
engines = get_available_engines()
|
sonusai/doc.py
CHANGED
@@ -13,16 +13,16 @@ Show SonusAI documentation.
|
|
13
13
|
def main() -> None:
|
14
14
|
from docopt import docopt
|
15
15
|
|
16
|
-
import
|
16
|
+
from sonusai import __version__ as sai_version
|
17
17
|
from sonusai.utils import trim_docstring
|
18
18
|
|
19
|
-
args = docopt(trim_docstring(__doc__), version=
|
19
|
+
args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)
|
20
20
|
|
21
21
|
from sonusai import doc
|
22
22
|
|
23
23
|
topic = args["TOPIC"]
|
24
24
|
|
25
|
-
print(f"SonusAI {
|
25
|
+
print(f"SonusAI {sai_version} Documentation")
|
26
26
|
print("")
|
27
27
|
|
28
28
|
topics = sorted([item[4:] for item in dir(doc) if item.startswith("doc_")])
|
@@ -42,4 +42,11 @@ def main() -> None:
|
|
42
42
|
|
43
43
|
|
44
44
|
if __name__ == "__main__":
|
45
|
-
|
45
|
+
from sonusai import exception_handler
|
46
|
+
from sonusai.utils import register_keyboard_interrupt
|
47
|
+
|
48
|
+
register_keyboard_interrupt()
|
49
|
+
try:
|
50
|
+
main()
|
51
|
+
except Exception as e:
|
52
|
+
exception_handler(e)
|
sonusai/genft.py
CHANGED
@@ -25,22 +25,8 @@ Outputs the following to the mixture database directory:
|
|
25
25
|
|
26
26
|
"""
|
27
27
|
|
28
|
-
import
|
29
|
-
|
30
|
-
from sonusai.mixture import GeneralizedIDs
|
31
|
-
from sonusai.mixture import GenFTData
|
32
|
-
|
33
|
-
|
34
|
-
def signal_handler(_sig, _frame):
|
35
|
-
import sys
|
36
|
-
|
37
|
-
from sonusai import logger
|
38
|
-
|
39
|
-
logger.info("Canceled due to keyboard interrupt")
|
40
|
-
sys.exit(1)
|
41
|
-
|
42
|
-
|
43
|
-
signal.signal(signal.SIGINT, signal_handler)
|
28
|
+
from sonusai.datatypes import GeneralizedIDs
|
29
|
+
from sonusai.datatypes import GenFTData
|
44
30
|
|
45
31
|
|
46
32
|
def genft(
|
@@ -48,7 +34,7 @@ def genft(
|
|
48
34
|
mixids: GeneralizedIDs = "*",
|
49
35
|
compute_truth: bool = True,
|
50
36
|
compute_segsnr: bool = False,
|
51
|
-
|
37
|
+
cache: bool = False,
|
52
38
|
show_progress: bool = False,
|
53
39
|
force: bool = True,
|
54
40
|
no_par: bool = False,
|
@@ -69,8 +55,8 @@ def genft(
|
|
69
55
|
location=location,
|
70
56
|
compute_truth=compute_truth,
|
71
57
|
compute_segsnr=compute_segsnr,
|
58
|
+
cache=cache,
|
72
59
|
force=force,
|
73
|
-
write=write,
|
74
60
|
),
|
75
61
|
mixids,
|
76
62
|
progress=progress,
|
@@ -82,33 +68,42 @@ def genft(
|
|
82
68
|
|
83
69
|
|
84
70
|
def _genft_kernel(
|
85
|
-
m_id: int,
|
71
|
+
m_id: int,
|
72
|
+
location: str,
|
73
|
+
compute_truth: bool,
|
74
|
+
compute_segsnr: bool,
|
75
|
+
cache: bool,
|
76
|
+
force: bool,
|
86
77
|
) -> GenFTData:
|
78
|
+
from functools import partial
|
79
|
+
from typing import Any
|
80
|
+
|
87
81
|
from sonusai.mixture import MixtureDatabase
|
88
82
|
from sonusai.mixture import write_cached_data
|
89
83
|
from sonusai.mixture import write_mixture_metadata
|
90
84
|
|
91
85
|
mixdb = MixtureDatabase(location)
|
92
86
|
|
87
|
+
write_func = partial(write_cached_data, location=mixdb.location, name="mixture", index=mixdb.mixture(m_id).name)
|
88
|
+
|
93
89
|
result = GenFTData()
|
94
90
|
|
95
|
-
|
91
|
+
mixture = mixdb.mixture_mixture(m_id)
|
92
|
+
|
93
|
+
feature, truth_f = mixdb.mixture_ft(m_id=m_id, mixture=mixture, force=force)
|
96
94
|
result.feature = feature
|
97
|
-
|
98
|
-
write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("feature", feature)])
|
95
|
+
items: dict[str, Any] = {"feature": feature}
|
99
96
|
|
100
97
|
if compute_truth:
|
101
98
|
result.truth_f = truth_f
|
102
|
-
|
103
|
-
write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("truth_f", truth_f)])
|
99
|
+
items["truth_f"] = truth_f
|
104
100
|
|
105
101
|
if compute_segsnr:
|
106
|
-
|
107
|
-
result.segsnr =
|
108
|
-
if write:
|
109
|
-
write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("segsnr", segsnr)])
|
102
|
+
segsnr_t = mixdb.mixture_segsnr_t(m_id)
|
103
|
+
result.segsnr = mixdb.mixture_segsnr(m_id=m_id, segsnr_t=segsnr_t, force=force, cache=cache)
|
110
104
|
|
111
|
-
if
|
105
|
+
if cache:
|
106
|
+
write_func(items=items)
|
112
107
|
write_mixture_metadata(mixdb, m_id=m_id)
|
113
108
|
|
114
109
|
return result
|
@@ -117,10 +112,10 @@ def _genft_kernel(
|
|
117
112
|
def main() -> None:
|
118
113
|
from docopt import docopt
|
119
114
|
|
120
|
-
import
|
115
|
+
from sonusai import __version__ as sai_version
|
121
116
|
from sonusai.utils import trim_docstring
|
122
117
|
|
123
|
-
args = docopt(trim_docstring(__doc__), version=
|
118
|
+
args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)
|
124
119
|
|
125
120
|
import time
|
126
121
|
from os.path import join
|
@@ -129,7 +124,7 @@ def main() -> None:
|
|
129
124
|
from sonusai import initial_log_messages
|
130
125
|
from sonusai import logger
|
131
126
|
from sonusai import update_console_handler
|
132
|
-
from sonusai.
|
127
|
+
from sonusai.constants import SAMPLE_RATE
|
133
128
|
from sonusai.mixture import MixtureDatabase
|
134
129
|
from sonusai.mixture import check_audio_files_exist
|
135
130
|
from sonusai.utils import human_readable_size
|
@@ -166,18 +161,14 @@ def main() -> None:
|
|
166
161
|
|
167
162
|
check_audio_files_exist(mixdb)
|
168
163
|
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
)
|
178
|
-
except Exception as e:
|
179
|
-
logger.debug(e)
|
180
|
-
raise
|
164
|
+
genft(
|
165
|
+
location=location,
|
166
|
+
mixids=mixids,
|
167
|
+
compute_segsnr=compute_segsnr,
|
168
|
+
cache=True,
|
169
|
+
show_progress=True,
|
170
|
+
no_par=no_par,
|
171
|
+
)
|
181
172
|
|
182
173
|
logger.info(f"Wrote {len(mixids)} mixtures to {location}")
|
183
174
|
logger.info("")
|
@@ -195,4 +186,11 @@ def main() -> None:
|
|
195
186
|
|
196
187
|
|
197
188
|
if __name__ == "__main__":
|
198
|
-
|
189
|
+
from sonusai import exception_handler
|
190
|
+
from sonusai.utils import register_keyboard_interrupt
|
191
|
+
|
192
|
+
register_keyboard_interrupt()
|
193
|
+
try:
|
194
|
+
main()
|
195
|
+
except Exception as e:
|
196
|
+
exception_handler(e)
|
sonusai/genmetrics.py
CHANGED
@@ -1,12 +1,13 @@
|
|
1
1
|
"""sonusai genmetrics
|
2
2
|
|
3
|
-
usage: genmetrics [-hvusd] [-i MIXID] [-n INCLUDE] [-x EXCLUDE] LOC
|
3
|
+
usage: genmetrics [-hvusd] [-i MIXID] [-n INCLUDE] [-p NUMPROC] [-x EXCLUDE] LOC
|
4
4
|
|
5
5
|
options:
|
6
6
|
-h, --help
|
7
7
|
-v, --verbose Be verbose.
|
8
8
|
-i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
|
9
9
|
-n INCLUDE, --include INCLUDE Metrics to include. [default: all]
|
10
|
+
-p NUMPROC, --nproc NUMPROC Number of parallel processes to use. Default single thread.
|
10
11
|
-x EXCLUDE, --exclude EXCLUDE Metrics to exclude. [default: none]
|
11
12
|
-u, --update Update metrics (do not regenerate existing metrics).
|
12
13
|
-s, --supported Show list of supported metrics.
|
@@ -46,28 +47,16 @@ Generate all available metrics except for mxcovl
|
|
46
47
|
|
47
48
|
"""
|
48
49
|
|
49
|
-
import signal
|
50
|
-
|
51
|
-
|
52
|
-
def signal_handler(_sig, _frame):
|
53
|
-
import sys
|
54
|
-
|
55
|
-
from sonusai import logger
|
56
|
-
|
57
|
-
logger.info("Canceled due to keyboard interrupt")
|
58
|
-
sys.exit(1)
|
59
|
-
|
60
|
-
|
61
|
-
signal.signal(signal.SIGINT, signal_handler)
|
62
|
-
|
63
50
|
|
64
51
|
def _process_mixture(mixid: int, location: str, metrics: list[str], update: bool = False) -> set[str]:
|
65
52
|
from sonusai.mixture import MixtureDatabase
|
66
53
|
from sonusai.mixture import write_cached_data
|
54
|
+
from sonusai.mixture import write_mixture_metadata
|
67
55
|
|
68
56
|
mixdb = MixtureDatabase(location)
|
69
57
|
results = mixdb.mixture_metrics(m_id=mixid, metrics=metrics, force=not update)
|
70
|
-
write_cached_data(mixdb.location, "mixture", mixdb.mixture(mixid).name,
|
58
|
+
write_cached_data(mixdb.location, "mixture", mixdb.mixture(mixid).name, results)
|
59
|
+
write_mixture_metadata(mixdb, mixture=mixdb.mixture(mixid))
|
71
60
|
|
72
61
|
return set(results.keys())
|
73
62
|
|
@@ -75,11 +64,11 @@ def _process_mixture(mixid: int, location: str, metrics: list[str], update: bool
|
|
75
64
|
def main() -> None:
|
76
65
|
from docopt import docopt
|
77
66
|
|
78
|
-
import
|
67
|
+
from sonusai import __version__ as sai_version
|
79
68
|
from sonusai.mixture import MixtureDatabase
|
80
69
|
from sonusai.utils import trim_docstring
|
81
70
|
|
82
|
-
args = docopt(trim_docstring(__doc__), version=
|
71
|
+
args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)
|
83
72
|
|
84
73
|
verbose = args["--verbose"]
|
85
74
|
mixids = args["--mixid"]
|
@@ -88,6 +77,7 @@ def main() -> None:
|
|
88
77
|
update = args["--update"]
|
89
78
|
show_supported = args["--supported"]
|
90
79
|
dryrun = args["--dryrun"]
|
80
|
+
num_proc = args["--nproc"]
|
91
81
|
location = args["LOC"]
|
92
82
|
|
93
83
|
import fnmatch
|
@@ -157,11 +147,20 @@ def main() -> None:
|
|
157
147
|
logger.info("")
|
158
148
|
logger.info(f"Found {len(mixids):,} mixtures to process")
|
159
149
|
|
150
|
+
if num_proc is None or len(mixids) == 1:
|
151
|
+
no_par = True
|
152
|
+
num_proc = None
|
153
|
+
else:
|
154
|
+
no_par = False
|
155
|
+
num_proc = int(num_proc) # TBD add support for 'auto'
|
156
|
+
|
160
157
|
progress = track(total=len(mixids), desc="genmetrics")
|
161
158
|
results = par_track(
|
162
159
|
partial(_process_mixture, location=location, metrics=metrics, update=update),
|
163
160
|
mixids,
|
164
161
|
progress=progress,
|
162
|
+
num_cpus=num_proc,
|
163
|
+
no_par=no_par,
|
165
164
|
)
|
166
165
|
progress.close()
|
167
166
|
|
@@ -176,4 +175,11 @@ def main() -> None:
|
|
176
175
|
|
177
176
|
|
178
177
|
if __name__ == "__main__":
|
179
|
-
|
178
|
+
from sonusai import exception_handler
|
179
|
+
from sonusai.utils import register_keyboard_interrupt
|
180
|
+
|
181
|
+
register_keyboard_interrupt()
|
182
|
+
try:
|
183
|
+
main()
|
184
|
+
except Exception as e:
|
185
|
+
exception_handler(e)
|