sonusai 1.0.16__cp311-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +170 -0
- sonusai/aawscd_probwrite.py +148 -0
- sonusai/audiofe.py +481 -0
- sonusai/calc_metric_spenh.py +1136 -0
- sonusai/config/__init__.py +0 -0
- sonusai/config/asr.py +21 -0
- sonusai/config/config.py +65 -0
- sonusai/config/config.yml +49 -0
- sonusai/config/constants.py +53 -0
- sonusai/config/ir.py +124 -0
- sonusai/config/ir_delay.py +62 -0
- sonusai/config/source.py +275 -0
- sonusai/config/spectral_masks.py +15 -0
- sonusai/config/truth.py +64 -0
- sonusai/constants.py +14 -0
- sonusai/data/__init__.py +0 -0
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/data/speech_ma01_01.wav +0 -0
- sonusai/data/whitenoise.wav +0 -0
- sonusai/datatypes.py +383 -0
- sonusai/deprecated/gentcst.py +632 -0
- sonusai/deprecated/plot.py +519 -0
- sonusai/deprecated/tplot.py +365 -0
- sonusai/doc.py +52 -0
- sonusai/doc_strings/__init__.py +1 -0
- sonusai/doc_strings/doc_strings.py +531 -0
- sonusai/genft.py +196 -0
- sonusai/genmetrics.py +183 -0
- sonusai/genmix.py +199 -0
- sonusai/genmixdb.py +235 -0
- sonusai/ir_metric.py +551 -0
- sonusai/lsdb.py +141 -0
- sonusai/main.py +134 -0
- sonusai/metrics/__init__.py +43 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_class_weights.py +90 -0
- sonusai/metrics/calc_optimal_thresholds.py +73 -0
- sonusai/metrics/calc_pcm.py +45 -0
- sonusai/metrics/calc_pesq.py +36 -0
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_sa_sdr.py +64 -0
- sonusai/metrics/calc_sample_weights.py +25 -0
- sonusai/metrics/calc_segsnr_f.py +82 -0
- sonusai/metrics/calc_speech.py +382 -0
- sonusai/metrics/calc_wer.py +71 -0
- sonusai/metrics/calc_wsdr.py +57 -0
- sonusai/metrics/calculate_metrics.py +395 -0
- sonusai/metrics/class_summary.py +74 -0
- sonusai/metrics/confusion_matrix_summary.py +75 -0
- sonusai/metrics/one_hot.py +283 -0
- sonusai/metrics/snr_summary.py +128 -0
- sonusai/metrics_summary.py +314 -0
- sonusai/mixture/__init__.py +15 -0
- sonusai/mixture/audio.py +187 -0
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/constants.py +3 -0
- sonusai/mixture/data_io.py +173 -0
- sonusai/mixture/db.py +169 -0
- sonusai/mixture/db_datatypes.py +92 -0
- sonusai/mixture/effects.py +344 -0
- sonusai/mixture/feature.py +78 -0
- sonusai/mixture/generation.py +1116 -0
- sonusai/mixture/helpers.py +351 -0
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +23 -0
- sonusai/mixture/mixdb.py +1857 -0
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +51 -0
- sonusai/mixture/truth.py +61 -0
- sonusai/mixture/truth_functions/__init__.py +45 -0
- sonusai/mixture/truth_functions/crm.py +105 -0
- sonusai/mixture/truth_functions/energy.py +222 -0
- sonusai/mixture/truth_functions/file.py +48 -0
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +18 -0
- sonusai/mixture/truth_functions/sed.py +98 -0
- sonusai/mixture/truth_functions/target.py +142 -0
- sonusai/mkwav.py +135 -0
- sonusai/onnx_predict.py +363 -0
- sonusai/parse/__init__.py +0 -0
- sonusai/parse/expand.py +156 -0
- sonusai/parse/parse_source_directive.py +129 -0
- sonusai/parse/rand.py +214 -0
- sonusai/py.typed +0 -0
- sonusai/queries/__init__.py +0 -0
- sonusai/queries/queries.py +239 -0
- sonusai/rs.abi3.so +0 -0
- sonusai/rs.pyi +1 -0
- sonusai/rust/__init__.py +0 -0
- sonusai/speech/__init__.py +0 -0
- sonusai/speech/l2arctic.py +121 -0
- sonusai/speech/librispeech.py +102 -0
- sonusai/speech/mcgill.py +71 -0
- sonusai/speech/textgrid.py +89 -0
- sonusai/speech/timit.py +138 -0
- sonusai/speech/types.py +12 -0
- sonusai/speech/vctk.py +53 -0
- sonusai/speech/voxceleb.py +108 -0
- sonusai/utils/__init__.py +3 -0
- sonusai/utils/asl_p56.py +130 -0
- sonusai/utils/asr.py +91 -0
- sonusai/utils/asr_functions/__init__.py +3 -0
- sonusai/utils/asr_functions/aaware_whisper.py +69 -0
- sonusai/utils/audio_devices.py +50 -0
- sonusai/utils/braced_glob.py +50 -0
- sonusai/utils/calculate_input_shape.py +26 -0
- sonusai/utils/choice.py +51 -0
- sonusai/utils/compress.py +25 -0
- sonusai/utils/convert_string_to_number.py +6 -0
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/create_ts_name.py +14 -0
- sonusai/utils/dataclass_from_dict.py +27 -0
- sonusai/utils/db.py +16 -0
- sonusai/utils/docstring.py +53 -0
- sonusai/utils/energy_f.py +44 -0
- sonusai/utils/engineering_number.py +166 -0
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/get_frames_per_batch.py +2 -0
- sonusai/utils/get_label_names.py +20 -0
- sonusai/utils/grouper.py +6 -0
- sonusai/utils/human_readable_size.py +7 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/load_object.py +21 -0
- sonusai/utils/max_text_width.py +9 -0
- sonusai/utils/model_utils.py +28 -0
- sonusai/utils/numeric_conversion.py +11 -0
- sonusai/utils/onnx_utils.py +155 -0
- sonusai/utils/parallel.py +162 -0
- sonusai/utils/path_info.py +7 -0
- sonusai/utils/print_mixture_details.py +60 -0
- sonusai/utils/rand.py +13 -0
- sonusai/utils/ranges.py +43 -0
- sonusai/utils/read_predict_data.py +32 -0
- sonusai/utils/reshape.py +154 -0
- sonusai/utils/seconds_to_hms.py +7 -0
- sonusai/utils/stacked_complex.py +82 -0
- sonusai/utils/stratified_shuffle_split.py +170 -0
- sonusai/utils/tokenized_shell_vars.py +143 -0
- sonusai/utils/write_audio.py +26 -0
- sonusai/utils/yes_or_no.py +8 -0
- sonusai/vars.py +47 -0
- sonusai-1.0.16.dist-info/METADATA +56 -0
- sonusai-1.0.16.dist-info/RECORD +150 -0
- sonusai-1.0.16.dist-info/WHEEL +4 -0
- sonusai-1.0.16.dist-info/entry_points.txt +3 -0
sonusai/speech/mcgill.py
ADDED
@@ -0,0 +1,71 @@
|
|
1
|
+
import os
|
2
|
+
|
3
|
+
from .types import TimeAlignedType
|
4
|
+
|
5
|
+
|
6
|
+
def load_text(audio: str | os.PathLike[str]) -> TimeAlignedType | None:
    """Load time-aligned text data given a McGill-Speech audio file.

    The transcript is read out of the WAV file itself: the file's "afsp"
    metadata chunk holds NUL-separated records, one of which is a
    'text: "..."' entry. The alignment spans the whole file, from 0 to the
    duration implied by the "data" chunk.

    :param audio: Path to the McGill-Speech audio file.
    :return: A TimeAlignedType object, or None if the file is missing or is
        not a parsable RIFF/WAVE file with the expected chunks.
    """
    import string
    import struct

    from ..mixture.audio import get_sample_rate

    if not os.path.exists(audio):
        return None

    sample_rate = get_sample_rate(str(audio))

    with open(audio, mode="rb") as f:
        content = f.read()

    # RIFF header: 4-byte "RIFF" tag, little-endian size, "WAVE" form type.
    riff_id, file_size, wave_id = struct.unpack("<4si4s", content[:12])
    if riff_id.decode("utf-8") != "RIFF":
        return None

    if wave_id.decode("utf-8") != "WAVE":
        return None

    fmt_id, fmt_size = struct.unpack("<4si", content[12:20])

    if fmt_id.decode("utf-8") != "fmt ":
        return None

    # Only the 16-byte (plain PCM) fmt chunk layout is supported.
    if fmt_size != 16:
        return None

    (
        _wave_format_tag,
        channels,
        _samples_per_sec,
        _avg_bytes_per_sec,
        _block_align,
        bits_per_sample,
    ) = struct.unpack("<hhiihh", content[20:36])

    # Walk the remaining chunks looking for "afsp" (transcript) and "data"
    # (sample count). NOTE(review): file_size is the RIFF size field, which
    # excludes the 8-byte RIFF header, while i is an absolute file offset;
    # the bound is slightly conservative, but the loop normally exits at
    # the "data" chunk first.
    i = 36
    samples = None
    text = None
    while i < file_size:
        chunk_id = struct.unpack("<4s", content[i : i + 4])[0].decode("utf-8")
        chunk_size = struct.unpack("<i", content[i + 4 : i + 8])[0]

        if chunk_id == "data":
            # Sample frames = byte count / channels / bytes per sample.
            samples = chunk_size / channels / (bits_per_sample / 8)
            break

        if chunk_id == "afsp":
            # The afsp payload is a sequence of NUL-separated text records.
            chunks = struct.unpack(f"<{chunk_size}s", content[i + 8 : i + 8 + chunk_size])[0]
            chunks = chunks.decode("utf-8").split("\x00")
            for chunk in chunks:
                if chunk.startswith('text: "'):
                    # Strip the 'text: "' prefix and the closing quote, then
                    # normalize: lowercase, punctuation removed.
                    text = chunk[7:-1].lower().translate(str.maketrans("", "", string.punctuation))
        i += 8 + chunk_size + chunk_size % 2  # chunks are 2-byte aligned

    if text and samples:
        return TimeAlignedType(start=0, end=samples / sample_rate, text=text)

    return None
|
@@ -0,0 +1,89 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
|
3
|
+
from praatio import textgrid
|
4
|
+
from praatio.utilities.constants import Interval
|
5
|
+
|
6
|
+
from .types import TimeAlignedType
|
7
|
+
|
8
|
+
|
9
|
+
def create_textgrid(
    prompt: Path,
    output_dir: Path,
    text: TimeAlignedType | None = None,
    words: list[TimeAlignedType] | None = None,
    phonemes: list[TimeAlignedType] | None = None,
) -> None:
    """Write a Praat TextGrid file for the given prompt.

    One interval tier is added per supplied alignment ("text", "words",
    "phonemes"); if none are supplied, no file is written. The output file
    is ``<output_dir>/<prompt stem>.TextGrid``.
    """
    if text is None and words is None and phonemes is None:
        return

    min_t, max_t = _get_min_max({"phonemes": phonemes, "text": text, "words": words})

    tg = textgrid.Textgrid()

    if text is not None:
        text_entries = [Interval(text.start, text.end, text.text)]
        tg.addTier(textgrid.IntervalTier("text", text_entries, min_t, max_t))

    if words is not None:
        word_entries = [Interval(w.start, w.end, w.text) for w in words]
        tg.addTier(textgrid.IntervalTier("words", word_entries, min_t, max_t))

    if phonemes is not None:
        phoneme_entries = [Interval(p.start, p.end, p.text) for p in phonemes]
        tg.addTier(textgrid.IntervalTier("phonemes", phoneme_entries, min_t, max_t))

    output_filename = str(output_dir / prompt.stem) + ".TextGrid"
    tg.save(output_filename, format="long_textgrid", includeBlankSpaces=True)
|
44
|
+
|
45
|
+
|
46
|
+
def _get_min_max(tiers: dict[str, TimeAlignedType | list[TimeAlignedType] | None]) -> tuple[float, float]:
    """Return the overall (earliest start, latest end) across all non-None tiers.

    A tier value may be a single TimeAlignedType or a time-ordered list of
    them; for lists, the first entry's start and the last entry's end are used.
    """
    starts: list[float] = []
    ends: list[float] = []

    for value in tiers.values():
        if value is None:
            continue
        if isinstance(value, TimeAlignedType):
            start, end = value.start, value.end
        else:
            start, end = value[0].start, value[-1].end
        starts.append(start)
        ends.append(end)

    return min(starts), max(ends)
|
60
|
+
|
61
|
+
|
62
|
+
def annotate_textgrid(
    tiers: dict[str, TimeAlignedType | list[TimeAlignedType] | None] | None, prompt: Path, output_dir: Path
) -> None:
    """Add annotation tiers to a prompt's TextGrid file, creating it if needed.

    Each non-None tier value is written as a new interval tier named
    ``annotation_<key>`` (the legacy key "phones" maps to
    "annotation_phonemes"). If the TextGrid already exists, its time range
    is reused; otherwise the range is derived from the tiers themselves.

    :param tiers: Mapping of tier name to alignment data, or None to do nothing.
    :param prompt: Prompt file whose stem names the TextGrid.
    :param output_dir: Directory containing/receiving the TextGrid file.
    """
    import os

    if tiers is None:
        return

    file = Path(output_dir / prompt.stem).with_suffix(".TextGrid")
    if not os.path.exists(file):
        tg = textgrid.Textgrid()
        min_t, max_t = _get_min_max(tiers)
    else:
        tg = textgrid.openTextgrid(str(file), includeEmptyIntervals=False)
        min_t = tg.minTimestamp
        max_t = tg.maxTimestamp

    for k, v in tiers.items():
        if v is None:
            continue
        # A tier value may be a single TimeAlignedType (per the type hint and
        # _get_min_max) or a list; normalize to a list before iterating.
        # The original iterated a bare TimeAlignedType directly, which would
        # raise TypeError since it is not iterable.
        items = [v] if isinstance(v, TimeAlignedType) else v
        entries = [Interval(entry.start, entry.end, entry.text) for entry in items]
        # Legacy tier name "phones" maps to "annotation_phonemes".
        name = "annotation_phonemes" if k == "phones" else "annotation_" + k
        tg.addTier(textgrid.IntervalTier(name, entries, min_t, max_t))

    tg.save(str(file), format="long_textgrid", includeBlankSpaces=True)
|
sonusai/speech/timit.py
ADDED
@@ -0,0 +1,138 @@
|
|
1
|
+
import os
|
2
|
+
from pathlib import Path
|
3
|
+
|
4
|
+
from .types import TimeAlignedType
|
5
|
+
|
6
|
+
|
7
|
+
def load_text(audio: str | os.PathLike[str]) -> TimeAlignedType | None:
    """Load time-aligned text data given a TIMIT audio file.

    Reads the companion .TXT transcript, which holds "<start> <end> <text>"
    in sample units, and normalizes the text to lowercase with punctuation
    removed.

    :param audio: Path to the TIMIT audio file.
    :return: A TimeAlignedType object, or None if no .TXT file exists.
    """
    import string

    from ..mixture.audio import get_sample_rate

    txt_file = Path(audio).with_suffix(".TXT")
    if not os.path.exists(txt_file):
        return None

    sample_rate = get_sample_rate(str(audio))

    line = txt_file.read_text(encoding="utf-8")
    parts = line.strip().split()

    begin = int(parts[0]) / sample_rate
    finish = int(parts[1]) / sample_rate
    transcript = " ".join(parts[2:]).lower().translate(str.maketrans("", "", string.punctuation))

    return TimeAlignedType(begin, finish, transcript)
|
32
|
+
|
33
|
+
|
34
|
+
def load_words(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None:
    """Load time-aligned word data given a TIMIT audio file.

    Thin wrapper over :func:`_load_ta` for the .WRD companion file.

    :param audio: Path to the TIMIT audio file.
    :return: A list of TimeAlignedType objects, or None if no .WRD file exists.
    """
    return _load_ta(audio, "words")
|
42
|
+
|
43
|
+
|
44
|
+
def load_phonemes(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None:
    """Load time-aligned phoneme data given a TIMIT audio file.

    Thin wrapper over :func:`_load_ta` for the .PHN companion file.

    :param audio: Path to the TIMIT audio file.
    :return: A list of TimeAlignedType objects, or None if no .PHN file exists.
    """
    return _load_ta(audio, "phonemes")
|
52
|
+
|
53
|
+
|
54
|
+
def _load_ta(audio: str | os.PathLike[str], tier: str) -> list[TimeAlignedType] | None:
    """Load a time-aligned tier (.WRD words or .PHN phonemes) for a TIMIT audio file.

    Each line of the companion file is "<start> <end> <text>" in sample units;
    values are converted to seconds. Overlapping consecutive entries are
    repaired so the result is strictly non-overlapping with non-empty intervals.

    :param audio: Path to the TIMIT audio file.
    :param tier: Either "words" or "phonemes".
    :return: A list of TimeAlignedType objects, or None if the companion file
        does not exist.
    :raises ValueError: If *tier* is not "words" or "phonemes".
    """
    from ..mixture.audio import get_sample_rate

    if tier == "words":
        file = Path(audio).with_suffix(".WRD")
    elif tier == "phonemes":
        file = Path(audio).with_suffix(".PHN")
    else:
        raise ValueError(f"Unknown tier: {tier}")

    if not os.path.exists(file):
        return None

    sample_rate = get_sample_rate(str(audio))

    entries: list[TimeAlignedType] = []
    first = True
    with open(file, encoding="utf-8") as f:
        for line in f.readlines():
            fields = line.strip().split()
            start = int(fields[0]) / sample_rate
            end = int(fields[1]) / sample_rate
            text = " ".join(fields[2:])

            if first:
                first = False
            else:
                # Entries are expected in time order; if this entry starts
                # before the previous one ends, move the boundary and shrink
                # the previous entry to end there.
                # NOTE(review): the floor division (//) on float seconds
                # snaps the split offset to a whole second — confirm whether
                # true midpoint averaging (/ 2) was intended.
                if start < entries[-1].end:
                    start = entries[-1].end - (entries[-1].end - start) // 2
                    entries[-1] = TimeAlignedType(text=entries[-1].text, start=entries[-1].start, end=start)

                # Guarantee a non-empty interval (at least one sample long).
                if end <= start:
                    end = start + 1 / sample_rate

            entries.append(TimeAlignedType(text=text, start=start, end=end))

    return entries
|
91
|
+
|
92
|
+
|
93
|
+
def _years_between(record, born):
|
94
|
+
try:
|
95
|
+
rec_fields = [int(x) for x in record.split("/")]
|
96
|
+
brn_fields = [int(x) for x in born.split("/")]
|
97
|
+
return rec_fields[2] - brn_fields[2] - ((rec_fields[1], rec_fields[0]) < (brn_fields[1], brn_fields[0]))
|
98
|
+
except ValueError:
|
99
|
+
return "??"
|
100
|
+
|
101
|
+
|
102
|
+
def _decode_dialect(d: str) -> str:
|
103
|
+
if d in ["DR1", "1"]:
|
104
|
+
return "New England"
|
105
|
+
if d in ["DR2", "2"]:
|
106
|
+
return "Northern"
|
107
|
+
if d in ["DR3", "3"]:
|
108
|
+
return "North Midland"
|
109
|
+
if d in ["DR4", "4"]:
|
110
|
+
return "South Midland"
|
111
|
+
if d in ["DR5", "5"]:
|
112
|
+
return "Southern"
|
113
|
+
if d in ["DR6", "6"]:
|
114
|
+
return "New York City"
|
115
|
+
if d in ["DR7", "7"]:
|
116
|
+
return "Western"
|
117
|
+
if d in ["DR8", "8"]:
|
118
|
+
return "Army Brat"
|
119
|
+
|
120
|
+
raise ValueError(f"Unrecognized dialect: {d}")
|
121
|
+
|
122
|
+
|
123
|
+
def load_speakers(input_dir: Path) -> dict:
    """Load TIMIT speaker metadata from SPKRINFO.TXT.

    Lines beginning with ";" are comments and are skipped. Each remaining
    line yields a speaker record keyed by speaker ID with gender, decoded
    dialect region, and age at recording time.

    :param input_dir: Directory containing SPKRINFO.TXT.
    :return: Mapping of speaker ID to {"gender", "dialect", "age"}.
    """
    speakers = {}
    with open(input_dir / "SPKRINFO.TXT") as file:
        for line in file:
            if line.startswith(";"):
                continue  # comment line
            fields = line.strip().split()
            speakers[fields[0]] = {
                "gender": fields[1],
                "dialect": _decode_dialect(fields[2]),
                "age": _years_between(fields[4], fields[5]),
            }
    return speakers
|
sonusai/speech/types.py
ADDED
sonusai/speech/vctk.py
ADDED
@@ -0,0 +1,53 @@
|
|
1
|
+
import os
|
2
|
+
from pathlib import Path
|
3
|
+
|
4
|
+
from .types import TimeAlignedType
|
5
|
+
|
6
|
+
|
7
|
+
def _get_duration(name: str) -> float:
    """Return the duration of an audio file in seconds.

    :param name: Path to the audio file.
    :return: Duration in seconds as reported by soundfile.
    :raises OSError: If the file cannot be read or parsed by soundfile.
    """
    import soundfile

    try:
        return soundfile.info(name).duration
    except Exception as e:
        # Re-raise any soundfile failure as OSError with the file name attached.
        raise OSError(f"Error reading {name}: {e}") from e
|
14
|
+
|
15
|
+
|
16
|
+
def load_text(audio: str | os.PathLike[str]) -> TimeAlignedType | None:
    """Load time-aligned text data given a VCTK audio file.

    The transcript lives in a sibling "txt" tree: two directory levels up
    from the audio file, under the same speaker directory, with the audio
    stem minus its last five characters plus ".txt". The alignment spans
    the whole file (0 to its duration).

    :param audio: Path to the VCTK audio file.
    :return: A TimeAlignedType object, or None if no transcript file exists.
    """
    import string

    audio_path = Path(audio)
    file = audio_path.parents[2] / "txt" / audio_path.parent.name / (audio_path.stem[:-5] + ".txt")
    if not os.path.exists(file):
        return None

    with open(file, encoding="utf-8") as f:
        line = f.read()

    transcript = line.strip().lower().translate(str.maketrans("", "", string.punctuation))
    return TimeAlignedType(0, _get_duration(str(audio)), transcript)
|
36
|
+
|
37
|
+
|
38
|
+
def load_speakers(input_dir: Path) -> dict:
    """Load VCTK speaker metadata from speaker-info.txt.

    The header line (starting with "ID") is skipped; any parenthesized
    trailing comment on a line is discarded before splitting into fields.

    :param input_dir: Directory containing speaker-info.txt.
    :return: Mapping of speaker ID to {"gender", "dialect", "age"}.
    """
    speakers = {}
    with open(input_dir / "speaker-info.txt") as file:
        for line in file:
            if line.startswith("ID"):
                continue  # header row
            # Drop any "(...)" comment, then split whitespace-separated fields.
            fields = line.strip().split("(", 1)[0].split()
            speakers[fields[0]] = {
                "gender": fields[2],
                "dialect": " ".join(fields[3:]),
                "age": fields[1],
            }
    return speakers
|
@@ -0,0 +1,108 @@
|
|
1
|
+
import os
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from pathlib import Path
|
4
|
+
|
5
|
+
|
6
|
+
@dataclass(frozen=True)
class Segment:
    """A VoxCeleb utterance segment parsed from a track description file.

    Start/stop times are in seconds (frame numbers divided by 25 FPS —
    see the segment-boundary parsing in this module).
    """

    # Speaker identity (the "Identity" header field, e.g. "id00017")
    person: str
    # Source YouTube video reference (the "Reference" header field)
    video: str
    # Segment identifier: the stem of the track file name
    id: str
    # Segment start time in seconds
    start: float
    # Segment stop time in seconds
    stop: float
|
13
|
+
|
14
|
+
|
15
|
+
def load_speakers(input_dir: Path) -> dict:
    """Load speaker metadata from the VoxCeleb1 and VoxCeleb2 meta CSV files.

    Both files are tab-delimited with a header row. VoxCeleb1 records are
    keyed by column 0 and include gender, nationality ("dialect"), and set
    ("category"); VoxCeleb2 records are keyed by column 1 and have no
    nationality column.

    :param input_dir: Directory containing vox1_meta.csv and vox2_meta.csv.
    :return: Mapping of speaker ID to metadata dict.
    """
    import csv

    speakers = {}

    # VoxCeleb1
    with open(input_dir / "vox1_meta.csv", newline="") as file:
        reader = csv.reader(file, delimiter="\t")
        next(reader, None)  # skip header row
        for row in reader:
            speakers[row[0].strip()] = {
                "gender": row[2].strip(),
                "dialect": row[3].strip(),
                "category": row[4].strip(),
            }

    # VoxCeleb2
    with open(input_dir / "vox2_meta.csv", newline="") as file:
        reader = csv.reader(file, delimiter="\t")
        next(reader, None)  # skip header row
        for row in reader:
            speakers[row[1].strip()] = {
                "gender": row[3].strip(),
                "category": row[4].strip(),
            }

    return speakers
|
48
|
+
|
49
|
+
|
50
|
+
def load_segment(path: str | os.PathLike[str]) -> Segment:
    """Read a VoxCeleb track description file into a Segment.

    The file consists of a header block and a frame table separated by a
    blank line; the header supplies identity/video, the frame table the
    segment's time boundaries.

    :param path: Path to the track description file.
    :return: Parsed Segment.
    """
    seg_path = Path(path)

    with seg_path.open("r") as file:
        contents = file.read().strip()

    header, frames = contents.split("\n\n")
    meta = _parse_header(header)
    start, stop = _get_segment_boundaries(frames)

    return Segment(
        person=meta["Identity"],
        video=meta["Reference"],
        id=seg_path.stem,
        start=start,
        stop=stop,
    )
|
67
|
+
|
68
|
+
|
69
|
+
def _parse_header(header: str) -> dict:
|
70
|
+
def _parse_line(line: str) -> tuple[str, str]:
|
71
|
+
"""Parse a line of header text into a dictionary.
|
72
|
+
|
73
|
+
Header text has the following format:
|
74
|
+
|
75
|
+
Identity : \tid00017
|
76
|
+
Reference : \t7t6lfzvVaTM
|
77
|
+
Offset : \t1
|
78
|
+
FV Conf : \t16.647\t(1)
|
79
|
+
ASD Conf : \t4.465
|
80
|
+
|
81
|
+
"""
|
82
|
+
k, v = line.split("\t", maxsplit=1)
|
83
|
+
k = k[:-2].strip()
|
84
|
+
v = v.strip()
|
85
|
+
return k, v
|
86
|
+
|
87
|
+
return dict(_parse_line(line) for line in header.split("\n"))
|
88
|
+
|
89
|
+
|
90
|
+
def _get_segment_boundaries(frames: str) -> tuple[float, float]:
|
91
|
+
"""Get the start and stop points of the segment.
|
92
|
+
|
93
|
+
Frames text has the following format:
|
94
|
+
|
95
|
+
FRAME X Y W H
|
96
|
+
000245 0.392 0.223 0.253 0.451
|
97
|
+
...
|
98
|
+
000470 0.359 0.207 0.260 0.463
|
99
|
+
|
100
|
+
"""
|
101
|
+
|
102
|
+
def _get_frame_seconds(line: str) -> float:
|
103
|
+
frame = int(line.split("\t")[0])
|
104
|
+
# YouTube is 25 FPS
|
105
|
+
return frame / 25
|
106
|
+
|
107
|
+
lines = frames.split("\n")
|
108
|
+
return _get_frame_seconds(lines[1]), _get_frame_seconds(lines[-1])
|
sonusai/utils/asl_p56.py
ADDED
@@ -0,0 +1,130 @@
|
|
1
|
+
from ..datatypes import AudioT
|
2
|
+
|
3
|
+
|
4
|
+
def asl_p56(audio: AudioT) -> float:
    """Implement ITU-T P.56 method B

    Tracks the signal envelope with a smoothing filter and hangover, counts
    active samples against a geometric series of thresholds, and finds (by
    interpolation) the level that sits M dB above its own threshold.

    :param audio: audio for which to calculate active speech level
        (assumed normalized to roughly +/-1.0 full scale — TODO confirm)
    :return: Active speech level mean square energy (0 if no activity found)
    """
    import numpy as np
    import scipy.signal as signal

    from ..constants import SAMPLE_RATE

    eps = np.finfo(np.float32).eps

    # Time constant of smoothing in seconds
    T = 0.03

    # Coefficient of smoothing
    g = np.exp(-1 / (SAMPLE_RATE * T))

    # Hangover time in seconds
    H = 0.2
    # Rounded up to next integer
    H_samples = np.ceil(H * SAMPLE_RATE)

    # Margin in dB, difference between threshold and active speech level
    M = 15.9

    # Number of thresholds
    thresh_num = 15

    # Series of fixed threshold voltages to apply to the envelope. These are spaced
    # in geometric progression, at intervals of not more than 2:1 (6.02 dB), from a
    # value equal to about half the maximum code down to a value equal to one
    # quantizing interval or lower. (Ascending: 2^-15 up to 2^-1.)
    c = 2 ** np.arange(-15, thresh_num - 15, dtype=np.float32)

    # Activity counter for each threshold
    a = np.full(thresh_num, -1)

    # Hangover counter for each threshold
    h = np.full(thresh_num, H_samples)

    # Long-term level square energy of audio
    sq = sum(np.square(audio))

    # Use a 2nd order IIR filter to detect the envelope q
    p = signal.lfilter([1 - g, 0], [1, -g], abs(audio))
    # q is the envelope, obtained from moving average of abs(audio) (with slight "hangover").
    q = signal.lfilter([1 - g, 0], [1, -g], p)

    # For each sample, count it as active for every threshold the envelope
    # exceeds, plus up to H_samples of hangover after each exceedance.
    # Thresholds ascend, so the inner loop breaks at the first threshold
    # that is neither exceeded nor within its hangover window.
    for k in range(len(audio)):
        for j in range(thresh_num):
            if q[k] >= c[j]:
                a[j] = a[j] + 1
                h[j] = 0
            elif h[j] < H_samples:
                a[j] = a[j] + 1
                h[j] = h[j] + 1
            else:
                break
    asl_msq = 0
    # Lowest threshold never crossed: no active speech anywhere.
    if a[0] == -1:
        return asl_msq

    a += 2
    # NOTE(review): eps placement here (sq / a[0] + eps) differs from the
    # j >= 1 computation below (sq / (a[j] + eps) + eps) — confirm intended.
    A_db1 = 10 * np.log10(sq / a[0] + eps)
    C_db1 = 20 * np.log10(c[0] + eps)
    # Even the lowest threshold is within the margin: no level found.
    if A_db1 - C_db1 < M:
        return asl_msq

    A_db = np.zeros(thresh_num)
    C_db = np.zeros(thresh_num)
    delta = np.zeros(thresh_num)
    A_db[0] = A_db1
    C_db[0] = C_db1
    delta[0] = A_db1 - C_db1

    # Activity level (dB) and threshold level (dB) per threshold.
    for j in range(1, thresh_num):
        A_db[j] = 10 * np.log10(sq / (a[j] + eps) + eps)
        C_db[j] = 20 * np.log10(c[j] + eps)

    # Find the first threshold whose activity-to-threshold difference falls
    # to the margin, then interpolate between it and its lower neighbor.
    for j in range(1, thresh_num):
        if a[j] != 0:
            delta[j] = A_db[j] - C_db[j]
            if delta[j] <= M:
                # Interpolate to find the asl_ms_log
                asl_ms_log = _bin_interp(A_db[j], A_db[j - 1], C_db[j], C_db[j - 1], M, 0.5)
                # This is the mean square value NOT the RMS
                asl_msq = 10.0 ** (asl_ms_log / 10)
                break

    return asl_msq
|
95
|
+
|
96
|
+
|
97
|
+
def _bin_interp(u_cnt: float, l_cnt: float, u_thr: float, l_thr: float, margin: float, tol: float) -> float:
|
98
|
+
tol = abs(tol)
|
99
|
+
|
100
|
+
# Check if extreme counts are not already the true active value
|
101
|
+
iter_num = 1
|
102
|
+
if abs(u_cnt - u_thr - margin) < tol:
|
103
|
+
return u_cnt
|
104
|
+
|
105
|
+
if abs(l_cnt - l_thr - margin) < tol:
|
106
|
+
return l_cnt
|
107
|
+
|
108
|
+
# Initialize first middle for given (initial) bounds
|
109
|
+
m_cnt = (u_cnt + l_cnt) / 2.0
|
110
|
+
m_thr = (u_thr + l_thr) / 2.0
|
111
|
+
|
112
|
+
while True:
|
113
|
+
# Loop until diff falls inside the tolerance (-tol<=diff<=tol)
|
114
|
+
diff = m_cnt - m_thr - margin
|
115
|
+
if abs(diff) <= tol:
|
116
|
+
break
|
117
|
+
|
118
|
+
# If tolerance is not met up to 20 iterations, then relax the tolerance by 10
|
119
|
+
iter_num += 1
|
120
|
+
if iter_num > 20:
|
121
|
+
tol = tol * 1.1
|
122
|
+
|
123
|
+
if diff > tol:
|
124
|
+
m_cnt = (u_cnt + m_cnt) / 2.0
|
125
|
+
m_thr = (u_thr + m_thr) / 2.0
|
126
|
+
elif diff < -tol:
|
127
|
+
m_cnt = (m_cnt + l_cnt) / 2.0
|
128
|
+
m_thr = (m_thr + l_thr) / 2.0
|
129
|
+
|
130
|
+
return m_cnt
|
sonusai/utils/asr.py
ADDED
@@ -0,0 +1,91 @@
|
|
1
|
+
from collections.abc import Callable
|
2
|
+
from dataclasses import dataclass
|
3
|
+
|
4
|
+
from ..datatypes import AudioT
|
5
|
+
|
6
|
+
|
7
|
+
@dataclass(frozen=True)
class ASRResult:
    """Result of an automatic-speech-recognition run.

    Only ``text`` is required; the remaining fields are optional metadata
    that individual ASR engines may or may not populate.
    """

    # Recognized transcript
    text: str
    # Engine-reported confidence, if available
    confidence: float | None = None
    # Detected language code, if the engine reports one
    lang: str | None = None
    # Probability/confidence of the detected language
    lang_prob: float | None = None
    # Duration in seconds — presumably of the input audio; confirm per engine
    duration: float | None = None
    # Number of segments the engine split the audio into
    num_segments: int | None = None
    # CPU time consumed by the ASR engine, in seconds
    asr_cpu_time: float | None = None
|
16
|
+
|
17
|
+
|
18
|
+
def get_available_engines() -> list[str]:
    """List the names of all available ASR engines.

    Collects every public attribute of ``sonusai.utils.asr_functions``, plus
    the public attributes of the ``asr_functions`` module of any installed
    plugin package named ``sonusai_asr_*``.

    :return: List of engine names.
    """
    from importlib import import_module
    from pkgutil import iter_modules

    base = import_module("sonusai.utils.asr_functions")
    engines = [attr for attr in dir(base) if not attr.startswith("_")]

    # Plugin packages named sonusai_asr_* contribute additional engines.
    for _, mod_name, _ in iter_modules():
        if mod_name.startswith("sonusai_asr_"):
            plugin = import_module(f"{mod_name}.asr_functions")
            engines.extend(attr for attr in dir(plugin) if not attr.startswith("_"))

    return engines
|
32
|
+
|
33
|
+
|
34
|
+
def _asr_fn(engine: str) -> Callable[..., ASRResult]:
    """Resolve an ASR engine name to its callable.

    Searches ``sonusai.utils.asr_functions`` first, then the
    ``asr_functions`` module of any installed ``sonusai_asr_*`` plugin.

    :param engine: Engine name to look up.
    :return: The engine's ASR function.
    :raises ValueError: If no engine with that name is found.
    """
    from importlib import import_module
    from pkgutil import iter_modules

    base = import_module("sonusai.utils.asr_functions")
    if engine in dir(base):
        return getattr(base, engine)

    for _, mod_name, _ in iter_modules():
        if mod_name.startswith("sonusai_asr_"):
            plugin = import_module(f"{mod_name}.asr_functions")
            if engine in dir(plugin):
                return getattr(plugin, engine)

    raise ValueError(f"engine {engine} not supported")
|
51
|
+
|
52
|
+
|
53
|
+
def calc_asr(audio: AudioT | str, engine: str, **config) -> ASRResult:
    """Run ASR on audio

    :param audio: Numpy array of audio samples or location of an audio file
    :param engine: ASR engine to use
    :param config: kwargs configuration parameters
    :return: ASRResult object containing text and confidence
    """
    from copy import copy

    import numpy as np

    from ..mixture.audio import read_audio

    # Non-array input is treated as a file location; the copy guards the
    # (possibly cached) audio buffer against downstream mutation.
    if not isinstance(audio, np.ndarray):
        audio = copy(read_audio(audio, config.get("use_cache", True)))

    asr = _asr_fn(engine)
    return asr(audio, **config)
|
71
|
+
|
72
|
+
|
73
|
+
def validate_asr(engine: str, **config) -> None:
    """Validate configuration for an ASR engine.

    Locates the engine (in ``sonusai.utils.asr_functions`` or a
    ``sonusai_asr_*`` plugin) and invokes its ``<engine>_validate``
    companion with the given configuration.

    :param engine: Engine name to validate.
    :param config: kwargs configuration parameters.
    :raises ValueError: If no engine with that name is found.
    """
    from importlib import import_module
    from pkgutil import iter_modules

    validator = engine + "_validate"

    base = import_module("sonusai.utils.asr_functions")
    if engine in dir(base):
        getattr(base, validator)(**config)
        return

    for _, mod_name, _ in iter_modules():
        if mod_name.startswith("sonusai_asr_"):
            plugin = import_module(f"{mod_name}.asr_functions")
            if engine in dir(plugin):
                getattr(plugin, validator)(**config)
                return

    raise ValueError(f"engine {engine} not supported")
|