sonusai 1.0.16__cp311-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sonusai/__init__.py +170 -0
  2. sonusai/aawscd_probwrite.py +148 -0
  3. sonusai/audiofe.py +481 -0
  4. sonusai/calc_metric_spenh.py +1136 -0
  5. sonusai/config/__init__.py +0 -0
  6. sonusai/config/asr.py +21 -0
  7. sonusai/config/config.py +65 -0
  8. sonusai/config/config.yml +49 -0
  9. sonusai/config/constants.py +53 -0
  10. sonusai/config/ir.py +124 -0
  11. sonusai/config/ir_delay.py +62 -0
  12. sonusai/config/source.py +275 -0
  13. sonusai/config/spectral_masks.py +15 -0
  14. sonusai/config/truth.py +64 -0
  15. sonusai/constants.py +14 -0
  16. sonusai/data/__init__.py +0 -0
  17. sonusai/data/silero_vad_v5.1.jit +0 -0
  18. sonusai/data/silero_vad_v5.1.onnx +0 -0
  19. sonusai/data/speech_ma01_01.wav +0 -0
  20. sonusai/data/whitenoise.wav +0 -0
  21. sonusai/datatypes.py +383 -0
  22. sonusai/deprecated/gentcst.py +632 -0
  23. sonusai/deprecated/plot.py +519 -0
  24. sonusai/deprecated/tplot.py +365 -0
  25. sonusai/doc.py +52 -0
  26. sonusai/doc_strings/__init__.py +1 -0
  27. sonusai/doc_strings/doc_strings.py +531 -0
  28. sonusai/genft.py +196 -0
  29. sonusai/genmetrics.py +183 -0
  30. sonusai/genmix.py +199 -0
  31. sonusai/genmixdb.py +235 -0
  32. sonusai/ir_metric.py +551 -0
  33. sonusai/lsdb.py +141 -0
  34. sonusai/main.py +134 -0
  35. sonusai/metrics/__init__.py +43 -0
  36. sonusai/metrics/calc_audio_stats.py +42 -0
  37. sonusai/metrics/calc_class_weights.py +90 -0
  38. sonusai/metrics/calc_optimal_thresholds.py +73 -0
  39. sonusai/metrics/calc_pcm.py +45 -0
  40. sonusai/metrics/calc_pesq.py +36 -0
  41. sonusai/metrics/calc_phase_distance.py +43 -0
  42. sonusai/metrics/calc_sa_sdr.py +64 -0
  43. sonusai/metrics/calc_sample_weights.py +25 -0
  44. sonusai/metrics/calc_segsnr_f.py +82 -0
  45. sonusai/metrics/calc_speech.py +382 -0
  46. sonusai/metrics/calc_wer.py +71 -0
  47. sonusai/metrics/calc_wsdr.py +57 -0
  48. sonusai/metrics/calculate_metrics.py +395 -0
  49. sonusai/metrics/class_summary.py +74 -0
  50. sonusai/metrics/confusion_matrix_summary.py +75 -0
  51. sonusai/metrics/one_hot.py +283 -0
  52. sonusai/metrics/snr_summary.py +128 -0
  53. sonusai/metrics_summary.py +314 -0
  54. sonusai/mixture/__init__.py +15 -0
  55. sonusai/mixture/audio.py +187 -0
  56. sonusai/mixture/class_balancing.py +103 -0
  57. sonusai/mixture/constants.py +3 -0
  58. sonusai/mixture/data_io.py +173 -0
  59. sonusai/mixture/db.py +169 -0
  60. sonusai/mixture/db_datatypes.py +92 -0
  61. sonusai/mixture/effects.py +344 -0
  62. sonusai/mixture/feature.py +78 -0
  63. sonusai/mixture/generation.py +1116 -0
  64. sonusai/mixture/helpers.py +351 -0
  65. sonusai/mixture/ir_effects.py +77 -0
  66. sonusai/mixture/log_duration_and_sizes.py +23 -0
  67. sonusai/mixture/mixdb.py +1857 -0
  68. sonusai/mixture/pad_audio.py +35 -0
  69. sonusai/mixture/resample.py +7 -0
  70. sonusai/mixture/sox_effects.py +195 -0
  71. sonusai/mixture/sox_help.py +650 -0
  72. sonusai/mixture/spectral_mask.py +51 -0
  73. sonusai/mixture/truth.py +61 -0
  74. sonusai/mixture/truth_functions/__init__.py +45 -0
  75. sonusai/mixture/truth_functions/crm.py +105 -0
  76. sonusai/mixture/truth_functions/energy.py +222 -0
  77. sonusai/mixture/truth_functions/file.py +48 -0
  78. sonusai/mixture/truth_functions/metadata.py +24 -0
  79. sonusai/mixture/truth_functions/metrics.py +28 -0
  80. sonusai/mixture/truth_functions/phoneme.py +18 -0
  81. sonusai/mixture/truth_functions/sed.py +98 -0
  82. sonusai/mixture/truth_functions/target.py +142 -0
  83. sonusai/mkwav.py +135 -0
  84. sonusai/onnx_predict.py +363 -0
  85. sonusai/parse/__init__.py +0 -0
  86. sonusai/parse/expand.py +156 -0
  87. sonusai/parse/parse_source_directive.py +129 -0
  88. sonusai/parse/rand.py +214 -0
  89. sonusai/py.typed +0 -0
  90. sonusai/queries/__init__.py +0 -0
  91. sonusai/queries/queries.py +239 -0
  92. sonusai/rs.abi3.so +0 -0
  93. sonusai/rs.pyi +1 -0
  94. sonusai/rust/__init__.py +0 -0
  95. sonusai/speech/__init__.py +0 -0
  96. sonusai/speech/l2arctic.py +121 -0
  97. sonusai/speech/librispeech.py +102 -0
  98. sonusai/speech/mcgill.py +71 -0
  99. sonusai/speech/textgrid.py +89 -0
  100. sonusai/speech/timit.py +138 -0
  101. sonusai/speech/types.py +12 -0
  102. sonusai/speech/vctk.py +53 -0
  103. sonusai/speech/voxceleb.py +108 -0
  104. sonusai/utils/__init__.py +3 -0
  105. sonusai/utils/asl_p56.py +130 -0
  106. sonusai/utils/asr.py +91 -0
  107. sonusai/utils/asr_functions/__init__.py +3 -0
  108. sonusai/utils/asr_functions/aaware_whisper.py +69 -0
  109. sonusai/utils/audio_devices.py +50 -0
  110. sonusai/utils/braced_glob.py +50 -0
  111. sonusai/utils/calculate_input_shape.py +26 -0
  112. sonusai/utils/choice.py +51 -0
  113. sonusai/utils/compress.py +25 -0
  114. sonusai/utils/convert_string_to_number.py +6 -0
  115. sonusai/utils/create_timestamp.py +5 -0
  116. sonusai/utils/create_ts_name.py +14 -0
  117. sonusai/utils/dataclass_from_dict.py +27 -0
  118. sonusai/utils/db.py +16 -0
  119. sonusai/utils/docstring.py +53 -0
  120. sonusai/utils/energy_f.py +44 -0
  121. sonusai/utils/engineering_number.py +166 -0
  122. sonusai/utils/evaluate_random_rule.py +15 -0
  123. sonusai/utils/get_frames_per_batch.py +2 -0
  124. sonusai/utils/get_label_names.py +20 -0
  125. sonusai/utils/grouper.py +6 -0
  126. sonusai/utils/human_readable_size.py +7 -0
  127. sonusai/utils/keyboard_interrupt.py +12 -0
  128. sonusai/utils/load_object.py +21 -0
  129. sonusai/utils/max_text_width.py +9 -0
  130. sonusai/utils/model_utils.py +28 -0
  131. sonusai/utils/numeric_conversion.py +11 -0
  132. sonusai/utils/onnx_utils.py +155 -0
  133. sonusai/utils/parallel.py +162 -0
  134. sonusai/utils/path_info.py +7 -0
  135. sonusai/utils/print_mixture_details.py +60 -0
  136. sonusai/utils/rand.py +13 -0
  137. sonusai/utils/ranges.py +43 -0
  138. sonusai/utils/read_predict_data.py +32 -0
  139. sonusai/utils/reshape.py +154 -0
  140. sonusai/utils/seconds_to_hms.py +7 -0
  141. sonusai/utils/stacked_complex.py +82 -0
  142. sonusai/utils/stratified_shuffle_split.py +170 -0
  143. sonusai/utils/tokenized_shell_vars.py +143 -0
  144. sonusai/utils/write_audio.py +26 -0
  145. sonusai/utils/yes_or_no.py +8 -0
  146. sonusai/vars.py +47 -0
  147. sonusai-1.0.16.dist-info/METADATA +56 -0
  148. sonusai-1.0.16.dist-info/RECORD +150 -0
  149. sonusai-1.0.16.dist-info/WHEEL +4 -0
  150. sonusai-1.0.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,71 @@
1
+ import os
2
+
3
+ from .types import TimeAlignedType
4
+
5
+
6
def load_text(audio: str | os.PathLike[str]) -> TimeAlignedType | None:
    """Load time-aligned text data given a McGill-Speech audio file.

    Walks the RIFF chunks of the WAV container manually, looking for the
    'afsp' metadata chunk (which holds the prompt text) and the 'data'
    chunk (whose size yields the duration in samples).

    :param audio: Path to the McGill-Speech audio file.
    :return: A TimeAlignedType object, or None if the file is missing or malformed.
    """
    import string
    import struct

    from ..mixture.audio import get_sample_rate

    if not os.path.exists(audio):
        return None

    sample_rate = get_sample_rate(str(audio))

    with open(audio, mode="rb") as f:
        content = f.read()

    # A valid header needs at least 36 bytes (RIFF header + fmt chunk);
    # guard against truncated files before unpacking.
    if len(content) < 36:
        return None

    riff_id, file_size, wave_id = struct.unpack("<4si4s", content[:12])
    if riff_id.decode("utf-8") != "RIFF":
        return None

    if wave_id.decode("utf-8") != "WAVE":
        return None

    fmt_id, fmt_size = struct.unpack("<4si", content[12:20])

    if fmt_id.decode("utf-8") != "fmt ":
        return None

    # Only the canonical 16-byte PCM fmt chunk is supported.
    if fmt_size != 16:
        return None

    (
        _wave_format_tag,
        channels,
        _samples_per_sec,
        _avg_bytes_per_sec,
        _block_align,
        bits_per_sample,
    ) = struct.unpack("<hhiihh", content[20:36])

    i = 36
    samples = None
    text = None
    # Walk the remaining chunks; stop at 'data' (the audio payload) since the
    # metadata of interest precedes it. The extra bounds check prevents a
    # struct.error on files whose declared size exceeds the actual content.
    while i + 8 <= len(content) and i < file_size:
        chunk_id = struct.unpack("<4s", content[i : i + 4])[0].decode("utf-8")
        chunk_size = struct.unpack("<i", content[i + 4 : i + 8])[0]

        if chunk_id == "data":
            samples = chunk_size / channels / (bits_per_sample / 8)
            break

        if chunk_id == "afsp":
            chunks = struct.unpack(f"<{chunk_size}s", content[i + 8 : i + 8 + chunk_size])[0]
            chunks = chunks.decode("utf-8").split("\x00")
            for chunk in chunks:
                if chunk.startswith('text: "'):
                    # Strip the 'text: "' prefix and trailing quote, then
                    # normalize: lowercase with punctuation removed.
                    text = chunk[7:-1].lower().translate(str.maketrans("", "", string.punctuation))
        # Chunks are word-aligned: odd-sized chunks carry one pad byte.
        i += 8 + chunk_size + chunk_size % 2

    if text and samples:
        return TimeAlignedType(start=0, end=samples / sample_rate, text=text)

    return None
@@ -0,0 +1,89 @@
1
+ from pathlib import Path
2
+
3
+ from praatio import textgrid
4
+ from praatio.utilities.constants import Interval
5
+
6
+ from .types import TimeAlignedType
7
+
8
+
9
def create_textgrid(
    prompt: Path,
    output_dir: Path,
    text: TimeAlignedType | None = None,
    words: list[TimeAlignedType] | None = None,
    phonemes: list[TimeAlignedType] | None = None,
) -> None:
    """Write a Praat TextGrid for a prompt with optional text/words/phonemes tiers.

    :param prompt: Prompt file whose stem names the output TextGrid.
    :param output_dir: Directory to write the TextGrid into.
    :param text: Optional whole-utterance transcript.
    :param words: Optional time-aligned words.
    :param phonemes: Optional time-aligned phonemes.
    """
    # Nothing to write when no tier data was provided.
    if text is None and words is None and phonemes is None:
        return

    min_t, max_t = _get_min_max({"phonemes": phonemes, "text": text, "words": words})

    tg = textgrid.Textgrid()

    if text is not None:
        text_entries = [Interval(text.start, text.end, text.text)]
        tg.addTier(textgrid.IntervalTier("text", text_entries, min_t, max_t))

    if words is not None:
        word_entries = [Interval(item.start, item.end, item.text) for item in words]
        tg.addTier(textgrid.IntervalTier("words", word_entries, min_t, max_t))

    if phonemes is not None:
        phoneme_entries = [Interval(item.start, item.end, item.text) for item in phonemes]
        tg.addTier(textgrid.IntervalTier("phonemes", phoneme_entries, min_t, max_t))

    output_filename = str(output_dir / prompt.stem) + ".TextGrid"
    tg.save(output_filename, format="long_textgrid", includeBlankSpaces=True)
44
+
45
+
46
def _get_min_max(tiers: dict[str, TimeAlignedType | list[TimeAlignedType] | None]) -> tuple[float, float]:
    """Return the earliest start and latest end across all provided tiers.

    :param tiers: Mapping of tier name to a single entry, a list of entries, or None.
    :return: (min start, max end) over the non-None tiers.
    """
    bounds: list[tuple[float, float]] = []
    for value in tiers.values():
        if value is None:
            continue
        if isinstance(value, TimeAlignedType):
            bounds.append((value.start, value.end))
        else:
            # Lists are assumed time-ordered: first item starts earliest,
            # last item ends latest.
            bounds.append((value[0].start, value[-1].end))

    return min(b[0] for b in bounds), max(b[1] for b in bounds)
60
+
61
+
62
def annotate_textgrid(
    tiers: dict[str, TimeAlignedType | list[TimeAlignedType] | None] | None, prompt: Path, output_dir: Path
) -> None:
    """Add annotation tiers to a prompt's TextGrid file.

    Creates a new TextGrid when none exists yet; otherwise appends
    'annotation_*' tiers to the existing file.

    :param tiers: Mapping of tier name to its time-aligned entries, or None.
    :param prompt: Prompt file whose stem names the TextGrid.
    :param output_dir: Directory containing the TextGrid.
    """
    import os

    if tiers is None:
        return

    file = Path(output_dir / prompt.stem).with_suffix(".TextGrid")
    if not os.path.exists(file):
        tg = textgrid.Textgrid()
        min_t, max_t = _get_min_max(tiers)
    else:
        tg = textgrid.openTextgrid(str(file), includeEmptyIntervals=False)
        min_t = tg.minTimestamp
        max_t = tg.maxTimestamp

    for k, v in tiers.items():
        if v is None:
            continue
        # A tier value may be a single TimeAlignedType (per the declared
        # type); normalize to a list so iteration is always valid.
        items = [v] if isinstance(v, TimeAlignedType) else v
        entries = [Interval(entry.start, entry.end, entry.text) for entry in items]
        # "phones" tiers are stored as "annotation_phonemes" for consistency.
        if k == "phones":
            name = "annotation_phonemes"
        else:
            name = "annotation_" + k
        tg.addTier(textgrid.IntervalTier(name, entries, min_t, max_t))

    tg.save(str(file), format="long_textgrid", includeBlankSpaces=True)
@@ -0,0 +1,138 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ from .types import TimeAlignedType
5
+
6
+
7
def load_text(audio: str | os.PathLike[str]) -> TimeAlignedType | None:
    """Load time-aligned text data given a TIMIT audio file.

    :param audio: Path to the TIMIT audio file.
    :return: A TimeAlignedType object, or None if the .TXT transcript is missing.
    """
    import string

    from ..mixture.audio import get_sample_rate

    file = Path(audio).with_suffix(".TXT")
    if not os.path.exists(file):
        return None

    sample_rate = get_sample_rate(str(audio))

    with open(file, encoding="utf-8") as f:
        fields = f.read().strip().split()

    # TIMIT .TXT format: <start sample> <end sample> <transcript...>
    # Sample offsets are converted to seconds; the transcript is normalized
    # to lowercase with punctuation removed.
    begin = int(fields[0]) / sample_rate
    finish = int(fields[1]) / sample_rate
    transcript = " ".join(fields[2:]).lower().translate(str.maketrans("", "", string.punctuation))

    return TimeAlignedType(begin, finish, transcript)
32
+
33
+
34
def load_words(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None:
    """Load time-aligned word data given a TIMIT audio file.

    :param audio: Path to the TIMIT audio file.
    :return: A list of TimeAlignedType objects, or None if no .WRD file exists.
    """
    return _load_ta(audio, "words")
42
+
43
+
44
def load_phonemes(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None:
    """Load time-aligned phonemes data given a TIMIT audio file.

    :param audio: Path to the TIMIT audio file.
    :return: A list of TimeAlignedType objects, or None if no .PHN file exists.
    """
    return _load_ta(audio, "phonemes")
52
+
53
+
54
def _load_ta(audio: str | os.PathLike[str], tier: str) -> list[TimeAlignedType] | None:
    """Load a time-aligned tier (.WRD words or .PHN phonemes) for a TIMIT audio file.

    :param audio: Path to the TIMIT audio file.
    :param tier: Either "words" or "phonemes".
    :return: A list of TimeAlignedType objects, or None if the tier file is missing.
    :raises ValueError: If the tier name is not recognized.
    """
    from ..mixture.audio import get_sample_rate

    if tier == "words":
        file = Path(audio).with_suffix(".WRD")
    elif tier == "phonemes":
        file = Path(audio).with_suffix(".PHN")
    else:
        raise ValueError(f"Unknown tier: {tier}")

    if not os.path.exists(file):
        return None

    sample_rate = get_sample_rate(str(audio))

    results: list[TimeAlignedType] = []
    with open(file, encoding="utf-8") as f:
        for line in f:
            # Each line: <start sample> <end sample> <label...>
            fields = line.strip().split()
            begin = int(fields[0]) / sample_rate
            finish = int(fields[1]) / sample_rate
            label = " ".join(fields[2:])

            # Entries may overlap slightly; split the overlap between the
            # previous entry and this one, truncating the previous entry.
            if results and begin < results[-1].end:
                begin = results[-1].end - (results[-1].end - begin) // 2
                results[-1] = TimeAlignedType(text=results[-1].text, start=results[-1].start, end=begin)

            # Guarantee a strictly positive duration (at least one sample).
            if finish <= begin:
                finish = begin + 1 / sample_rate

            results.append(TimeAlignedType(text=label, start=begin, end=finish))

    return results
91
+
92
+
93
+ def _years_between(record, born):
94
+ try:
95
+ rec_fields = [int(x) for x in record.split("/")]
96
+ brn_fields = [int(x) for x in born.split("/")]
97
+ return rec_fields[2] - brn_fields[2] - ((rec_fields[1], rec_fields[0]) < (brn_fields[1], brn_fields[0]))
98
+ except ValueError:
99
+ return "??"
100
+
101
+
102
+ def _decode_dialect(d: str) -> str:
103
+ if d in ["DR1", "1"]:
104
+ return "New England"
105
+ if d in ["DR2", "2"]:
106
+ return "Northern"
107
+ if d in ["DR3", "3"]:
108
+ return "North Midland"
109
+ if d in ["DR4", "4"]:
110
+ return "South Midland"
111
+ if d in ["DR5", "5"]:
112
+ return "Southern"
113
+ if d in ["DR6", "6"]:
114
+ return "New York City"
115
+ if d in ["DR7", "7"]:
116
+ return "Western"
117
+ if d in ["DR8", "8"]:
118
+ return "Army Brat"
119
+
120
+ raise ValueError(f"Unrecognized dialect: {d}")
121
+
122
+
123
def load_speakers(input_dir: Path) -> dict:
    """Parse TIMIT SPKRINFO.TXT into a speaker-metadata dictionary.

    :param input_dir: Directory containing SPKRINFO.TXT.
    :return: Mapping of speaker id to {'gender', 'dialect', 'age'}.
    """
    speakers = {}
    with open(input_dir / "SPKRINFO.TXT") as file:
        for line in file:
            # Lines starting with ';' are comments in SPKRINFO.TXT.
            if line.startswith(";"):
                continue
            fields = line.strip().split()
            speakers[fields[0]] = {
                "gender": fields[1],
                "dialect": _decode_dialect(fields[2]),
                # Age is derived from recording date vs. birth date.
                "age": _years_between(fields[4], fields[5]),
            }
    return speakers
@@ -0,0 +1,12 @@
1
+ from dataclasses import dataclass
2
+
3
+
4
@dataclass(frozen=True)
class TimeAlignedType:
    """A span of time with its associated text (transcript, word, or phoneme)."""

    # Start time in seconds.
    start: float
    # End time in seconds.
    end: float
    # Text associated with the [start, end] interval.
    text: str

    @property
    def duration(self) -> float:
        """Length of the span in seconds."""
        return self.end - self.start
sonusai/speech/vctk.py ADDED
@@ -0,0 +1,53 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ from .types import TimeAlignedType
5
+
6
+
7
def _get_duration(name: str) -> float:
    """Return the duration in seconds of an audio file via soundfile.

    :param name: Path to the audio file.
    :raises OSError: If the file cannot be read.
    """
    import soundfile

    try:
        info = soundfile.info(name)
    except Exception as e:
        # Normalize any soundfile failure into an OSError with context.
        raise OSError(f"Error reading {name}: {e}") from e
    return info.duration
14
+
15
+
16
def load_text(audio: str | os.PathLike[str]) -> TimeAlignedType | None:
    """Load time-aligned text data given a VCTK audio file.

    :param audio: Path to the VCTK audio file.
    :return: A TimeAlignedType object, or None if no transcript file exists.
    """
    import string

    path = Path(audio)
    # Transcript lives at <root>/txt/<speaker>/<utt>.txt; the audio stem's
    # last 5 characters (presumably a '_micN' suffix — verify against the
    # dataset layout) are dropped to recover the utterance name.
    file = path.parents[2] / "txt" / path.parent.name / (path.stem[:-5] + ".txt")
    if not os.path.exists(file):
        return None

    with open(file, encoding="utf-8") as f:
        raw = f.read()

    begin = 0
    finish = _get_duration(str(audio))
    # Normalize: lowercase with punctuation removed.
    transcript = raw.strip().lower().translate(str.maketrans("", "", string.punctuation))

    return TimeAlignedType(begin, finish, transcript)
36
+
37
+
38
def load_speakers(input_dir: Path) -> dict:
    """Parse VCTK speaker-info.txt into a speaker-metadata dictionary.

    :param input_dir: Directory containing speaker-info.txt.
    :return: Mapping of speaker id to {'gender', 'dialect', 'age'}.
    """
    speakers = {}
    with open(input_dir / "speaker-info.txt") as file:
        for line in file:
            # The header row starts with 'ID'; everything else is a speaker.
            if line.startswith("ID"):
                continue
            # Drop any parenthesized remark before splitting into fields.
            fields = line.strip().split("(", 1)[0].split()
            speakers[fields[0]] = {
                "gender": fields[2],
                # Remaining fields form the accent/region description.
                "dialect": " ".join(fields[3:]),
                "age": fields[1],
            }
    return speakers
@@ -0,0 +1,108 @@
1
+ import os
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+
5
+
6
@dataclass(frozen=True)
class Segment:
    """A single VoxCeleb utterance segment within a source video."""

    # Speaker identity (e.g. 'id00017').
    person: str
    # Source video reference the segment was clipped from.
    video: str
    # Segment identifier (stem of the segment's metadata file).
    id: str
    # Segment start time in seconds.
    start: float
    # Segment stop time in seconds.
    stop: float
13
+
14
+
15
def load_speakers(input_dir: Path) -> dict:
    """Load VoxCeleb1 and VoxCeleb2 speaker metadata into one dictionary.

    :param input_dir: Directory containing vox1_meta.csv and vox2_meta.csv
        (tab-delimited, each with a single header row).
    :return: Mapping of speaker id to metadata dict.
    """
    import csv

    speakers = {}

    # VoxCeleb1: speaker id in column 0; columns 2-4 feed gender/dialect/category.
    with open(input_dir / "vox1_meta.csv", newline="") as file:
        rows = csv.reader(file, delimiter="\t")
        next(rows, None)  # skip header row
        for row in rows:
            speakers[row[0].strip()] = {
                "gender": row[2].strip(),
                "dialect": row[3].strip(),
                "category": row[4].strip(),
            }

    # VoxCeleb2: speaker id in column 1; no dialect column in this metadata.
    with open(input_dir / "vox2_meta.csv", newline="") as file:
        rows = csv.reader(file, delimiter="\t")
        next(rows, None)  # skip header row
        for row in rows:
            speakers[row[1].strip()] = {
                "gender": row[3].strip(),
                "category": row[4].strip(),
            }

    return speakers
48
+
49
+
50
def load_segment(path: str | os.PathLike[str]) -> Segment:
    """Parse a VoxCeleb segment metadata file into a Segment.

    :param path: Path to the segment metadata text file.
    :return: A Segment with identity, video reference, and time boundaries.
    """
    segment_file = Path(path)

    with segment_file.open("r") as file:
        text = file.read().strip()

    # The file holds a key/value header and a frame table, separated by one
    # blank line.
    header, frames = text.split("\n\n")
    info = _parse_header(header)
    start, stop = _get_segment_boundaries(frames)

    return Segment(
        person=info["Identity"],
        video=info["Reference"],
        id=segment_file.stem,
        start=start,
        stop=stop,
    )
67
+
68
+
69
+ def _parse_header(header: str) -> dict:
70
+ def _parse_line(line: str) -> tuple[str, str]:
71
+ """Parse a line of header text into a dictionary.
72
+
73
+ Header text has the following format:
74
+
75
+ Identity : \tid00017
76
+ Reference : \t7t6lfzvVaTM
77
+ Offset : \t1
78
+ FV Conf : \t16.647\t(1)
79
+ ASD Conf : \t4.465
80
+
81
+ """
82
+ k, v = line.split("\t", maxsplit=1)
83
+ k = k[:-2].strip()
84
+ v = v.strip()
85
+ return k, v
86
+
87
+ return dict(_parse_line(line) for line in header.split("\n"))
88
+
89
+
90
+ def _get_segment_boundaries(frames: str) -> tuple[float, float]:
91
+ """Get the start and stop points of the segment.
92
+
93
+ Frames text has the following format:
94
+
95
+ FRAME X Y W H
96
+ 000245 0.392 0.223 0.253 0.451
97
+ ...
98
+ 000470 0.359 0.207 0.260 0.463
99
+
100
+ """
101
+
102
+ def _get_frame_seconds(line: str) -> float:
103
+ frame = int(line.split("\t")[0])
104
+ # YouTube is 25 FPS
105
+ return frame / 25
106
+
107
+ lines = frames.split("\n")
108
+ return _get_frame_seconds(lines[1]), _get_frame_seconds(lines[-1])
@@ -0,0 +1,3 @@
1
+ from .asr import ASRResult
2
+
3
+ __all__ = ["ASRResult"]
@@ -0,0 +1,130 @@
1
+ from ..datatypes import AudioT
2
+
3
+
4
def asl_p56(audio: AudioT) -> float:
    """Implement ITU-T P.56 method B.

    Measures the active speech level: the envelope of the signal is compared
    against a geometric series of thresholds, and the level is found where
    the gap between the long-term level and the threshold equals the margin.

    :param audio: audio for which to calculate active speech level
    :return: Active speech level mean square energy (0 when no activity is
        detected above any threshold)
    """
    import numpy as np
    import scipy.signal as signal

    from ..constants import SAMPLE_RATE

    eps = np.finfo(np.float32).eps

    # Time constant of smoothing in seconds
    T = 0.03

    # Coefficient of smoothing
    g = np.exp(-1 / (SAMPLE_RATE * T))

    # Hangover time in seconds
    H = 0.2
    # Rounded up to next integer
    H_samples = np.ceil(H * SAMPLE_RATE)

    # Margin in dB, difference between threshold and active speech level
    M = 15.9

    # Number of thresholds
    thresh_num = 15

    # Series of fixed threshold voltages to apply to the envelope. These are spaced
    # in geometric progression, at intervals of not more than 2:1 (6.02 dB), from a
    # value equal to about half the maximum code down to a value equal to one
    # quantizing interval or lower.
    c = 2 ** np.arange(-15, thresh_num - 15, dtype=np.float32)

    # Activity counter for each threshold
    a = np.full(thresh_num, -1)

    # Hangover counter for each threshold
    h = np.full(thresh_num, H_samples)

    # Long-term level square energy of audio
    sq = sum(np.square(audio))

    # Use a 2nd order IIR filter to detect the envelope q
    p = signal.lfilter([1 - g, 0], [1, -g], abs(audio))
    # q is the envelope, obtained from moving average of abs(audio) (with slight "hangover").
    q = signal.lfilter([1 - g, 0], [1, -g], p)

    # Count "active" samples per threshold: a sample counts while the envelope
    # exceeds the threshold, or during the hangover window just after it drops.
    for k in range(len(audio)):
        for j in range(thresh_num):
            if q[k] >= c[j]:
                a[j] = a[j] + 1
                h[j] = 0
            elif h[j] < H_samples:
                a[j] = a[j] + 1
                h[j] = h[j] + 1
            else:
                # Thresholds ascend; once one fails (hangover expired), all
                # higher thresholds fail too.
                break
    asl_msq = 0
    if a[0] == -1:
        # The envelope never crossed even the lowest threshold: no activity.
        return asl_msq

    # NOTE(review): counters start at -1 and are then offset by +2 before use
    # — presumably an off-by-one adjustment from the reference algorithm;
    # confirm against the ITU-T P.56 specification.
    a += 2
    A_db1 = 10 * np.log10(sq / a[0] + eps)
    C_db1 = 20 * np.log10(c[0] + eps)
    if A_db1 - C_db1 < M:
        # Even the lowest threshold already sits within the margin.
        return asl_msq

    A_db = np.zeros(thresh_num)
    C_db = np.zeros(thresh_num)
    delta = np.zeros(thresh_num)
    A_db[0] = A_db1
    C_db[0] = C_db1
    delta[0] = A_db1 - C_db1

    # Activity level (dB) and threshold level (dB) for each threshold.
    for j in range(1, thresh_num):
        A_db[j] = 10 * np.log10(sq / (a[j] + eps) + eps)
        C_db[j] = 20 * np.log10(c[j] + eps)

    # Find the first threshold whose gap drops to the margin and interpolate
    # between it and its predecessor.
    for j in range(1, thresh_num):
        if a[j] != 0:
            delta[j] = A_db[j] - C_db[j]
            if delta[j] <= M:
                # Interpolate to find the asl_ms_log
                asl_ms_log = _bin_interp(A_db[j], A_db[j - 1], C_db[j], C_db[j - 1], M, 0.5)
                # This is the mean square value NOT the RMS
                asl_msq = 10.0 ** (asl_ms_log / 10)
                break

    return asl_msq
95
+
96
+
97
+ def _bin_interp(u_cnt: float, l_cnt: float, u_thr: float, l_thr: float, margin: float, tol: float) -> float:
98
+ tol = abs(tol)
99
+
100
+ # Check if extreme counts are not already the true active value
101
+ iter_num = 1
102
+ if abs(u_cnt - u_thr - margin) < tol:
103
+ return u_cnt
104
+
105
+ if abs(l_cnt - l_thr - margin) < tol:
106
+ return l_cnt
107
+
108
+ # Initialize first middle for given (initial) bounds
109
+ m_cnt = (u_cnt + l_cnt) / 2.0
110
+ m_thr = (u_thr + l_thr) / 2.0
111
+
112
+ while True:
113
+ # Loop until diff falls inside the tolerance (-tol<=diff<=tol)
114
+ diff = m_cnt - m_thr - margin
115
+ if abs(diff) <= tol:
116
+ break
117
+
118
+ # If tolerance is not met up to 20 iterations, then relax the tolerance by 10
119
+ iter_num += 1
120
+ if iter_num > 20:
121
+ tol = tol * 1.1
122
+
123
+ if diff > tol:
124
+ m_cnt = (u_cnt + m_cnt) / 2.0
125
+ m_thr = (u_thr + m_thr) / 2.0
126
+ elif diff < -tol:
127
+ m_cnt = (m_cnt + l_cnt) / 2.0
128
+ m_thr = (m_thr + l_thr) / 2.0
129
+
130
+ return m_cnt
sonusai/utils/asr.py ADDED
@@ -0,0 +1,91 @@
1
+ from collections.abc import Callable
2
+ from dataclasses import dataclass
3
+
4
+ from ..datatypes import AudioT
5
+
6
+
7
@dataclass(frozen=True)
class ASRResult:
    """Result of running an ASR engine over audio."""

    # Recognized transcript text.
    text: str
    # Engine-reported confidence, if the engine provides one.
    confidence: float | None = None
    # Detected language code, if reported.
    lang: str | None = None
    # Probability of the detected language, if reported.
    lang_prob: float | None = None
    # Audio duration in seconds, if reported.
    duration: float | None = None
    # Number of recognized segments, if reported.
    num_segments: int | None = None
    # CPU time spent in ASR in seconds, if reported.
    asr_cpu_time: float | None = None
16
+
17
+
18
def get_available_engines() -> list[str]:
    """List available ASR engine names.

    Collects the public names of sonusai.utils.asr_functions plus those of
    any installed 'sonusai_asr_*' plugin packages.

    :return: List of engine names.
    """
    from importlib import import_module
    from pkgutil import iter_modules

    base = import_module("sonusai.utils.asr_functions")
    engines = [attr for attr in dir(base) if not attr.startswith("_")]

    # Plugin packages named 'sonusai_asr_*' contribute additional engines.
    for _, name, _ in iter_modules():
        if name.startswith("sonusai_asr_"):
            plugin = import_module(f"{name}.asr_functions")
            engines.extend(attr for attr in dir(plugin) if not attr.startswith("_"))

    return engines
32
+
33
+
34
def _asr_fn(engine: str) -> Callable[..., ASRResult]:
    """Resolve an ASR engine name to its callable.

    Looks in sonusai.utils.asr_functions first, then in any installed
    'sonusai_asr_*' plugin packages.

    :param engine: ASR engine name.
    :return: The engine's callable.
    :raises ValueError: If the engine is not found.
    """
    from importlib import import_module
    from pkgutil import iter_modules

    # Direct attribute lookup replaces the linear scan over dir(module).
    module = import_module("sonusai.utils.asr_functions")
    if hasattr(module, engine):
        return getattr(module, engine)

    for _, name, _ in iter_modules():
        if name.startswith("sonusai_asr_"):
            module = import_module(f"{name}.asr_functions")
            if hasattr(module, engine):
                return getattr(module, engine)

    raise ValueError(f"engine {engine} not supported")
51
+
52
+
53
def calc_asr(audio: AudioT | str, engine: str, **config) -> ASRResult:
    """Run ASR on audio

    :param audio: Numpy array of audio samples or location of an audio file
    :param engine: ASR engine to use
    :param config: kwargs configuration parameters
    :return: ASRResult object containing text and confidence
    """
    from copy import copy

    import numpy as np

    from ..mixture.audio import read_audio

    # A path argument is loaded into samples first; the copy keeps any cached
    # audio (use_cache defaults to True) from being handed out directly.
    samples = audio if isinstance(audio, np.ndarray) else copy(read_audio(audio, config.get("use_cache", True)))

    return _asr_fn(engine)(samples, **config)
71
+
72
+
73
def validate_asr(engine: str, **config) -> None:
    """Validate configuration for an ASR engine via its '<engine>_validate' hook.

    Looks in sonusai.utils.asr_functions first, then in any installed
    'sonusai_asr_*' plugin packages — consistent with _asr_fn.

    :param engine: ASR engine name.
    :param config: kwargs configuration parameters passed to the validator.
    :raises ValueError: If the engine is not found.
    """
    from importlib import import_module
    from pkgutil import iter_modules

    # Direct attribute lookup replaces the linear scan over dir(module).
    module = import_module("sonusai.utils.asr_functions")
    if hasattr(module, engine):
        getattr(module, engine + "_validate")(**config)
        return

    for _, name, _ in iter_modules():
        if name.startswith("sonusai_asr_"):
            module = import_module(f"{name}.asr_functions")
            if hasattr(module, engine):
                getattr(module, engine + "_validate")(**config)
                return

    raise ValueError(f"engine {engine} not supported")
@@ -0,0 +1,3 @@
1
+ from .aaware_whisper import aaware_whisper
2
+
3
+ __all__ = ["aaware_whisper"]