sonusai 0.18.9__py3-none-any.whl → 0.19.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118) hide show
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +93 -77
  13. sonusai/genmetrics.py +59 -46
  14. sonusai/genmix.py +116 -104
  15. sonusai/genmixdb.py +194 -153
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +15 -17
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +19 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +40 -38
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +669 -477
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +52 -85
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +40 -27
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
  113. sonusai-0.19.5.dist-info/RECORD +125 -0
  114. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.9.dist-info/RECORD +0 -125
  118. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
@@ -1,6 +1,5 @@
1
1
  import os
2
2
  from pathlib import Path
3
- from typing import Optional
4
3
 
5
4
  from .types import TimeAlignedType
6
5
 
@@ -14,16 +13,16 @@ def _get_num_samples(audio: str | os.PathLike[str]) -> int:
14
13
  import soundfile
15
14
  from pydub import AudioSegment
16
15
 
17
- if Path(audio).suffix == '.mp3':
16
+ if Path(audio).suffix == ".mp3":
18
17
  return AudioSegment.from_mp3(audio).frame_count()
19
18
 
20
- if Path(audio).suffix == '.m4a':
19
+ if Path(audio).suffix == ".m4a":
21
20
  return AudioSegment.from_file(audio).frame_count()
22
21
 
23
22
  return soundfile.info(audio).frames
24
23
 
25
24
 
26
- def load_text(audio: str | os.PathLike[str]) -> Optional[TimeAlignedType]:
25
+ def load_text(audio: str | os.PathLike[str]) -> TimeAlignedType | None:
27
26
  """Load text data from a LibriSpeech transcription file given a LibriSpeech audio filename.
28
27
 
29
28
  :param audio: Path to the LibriSpeech audio file.
@@ -35,44 +34,44 @@ def load_text(audio: str | os.PathLike[str]) -> Optional[TimeAlignedType]:
35
34
 
36
35
  path = Path(audio)
37
36
  name = path.stem
38
- transcript_filename = path.parent / f'{path.parent.parent.name}-{path.parent.name}.trans.txt'
37
+ transcript_filename = path.parent / f"{path.parent.parent.name}-{path.parent.name}.trans.txt"
39
38
 
40
39
  if not os.path.exists(transcript_filename):
41
40
  return None
42
41
 
43
- with open(transcript_filename, mode='r', encoding='utf-8') as f:
42
+ with open(transcript_filename, encoding="utf-8") as f:
44
43
  for line in f.readlines():
45
44
  fields = line.strip().split()
46
45
  key = fields[0]
47
46
  if key == name:
48
- text = ' '.join(fields[1:]).lower().translate(str.maketrans('', '', string.punctuation))
47
+ text = " ".join(fields[1:]).lower().translate(str.maketrans("", "", string.punctuation))
49
48
  return TimeAlignedType(0, _get_num_samples(audio) / get_sample_rate(str(audio)), text)
50
49
 
51
50
  return None
52
51
 
53
52
 
54
- def load_words(audio: str | os.PathLike[str]) -> Optional[list[TimeAlignedType]]:
53
+ def load_words(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None:
55
54
  """Load time-aligned word data given a LibriSpeech audio file.
56
55
 
57
56
  :param audio: Path to the Librispeech audio file.
58
57
  :return: A list of TimeAlignedType objects.
59
58
  """
60
- return _load_ta(audio, 'words')
59
+ return _load_ta(audio, "words")
61
60
 
62
61
 
63
- def load_phonemes(audio: str | os.PathLike[str]) -> Optional[list[TimeAlignedType]]:
62
+ def load_phonemes(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None:
64
63
  """Load time-aligned phonemes data given a LibriSpeech audio file.
65
64
 
66
65
  :param audio: Path to the LibriSpeech audio file.
67
66
  :return: A list of TimeAlignedType objects.
68
67
  """
69
- return _load_ta(audio, 'phones')
68
+ return _load_ta(audio, "phones")
70
69
 
71
70
 
72
- def _load_ta(audio: str | os.PathLike[str], tier: str) -> Optional[list[TimeAlignedType]]:
71
+ def _load_ta(audio: str | os.PathLike[str], tier: str) -> list[TimeAlignedType] | None:
73
72
  from praatio import textgrid
74
73
 
75
- file = Path(audio).with_suffix('.TextGrid')
74
+ file = Path(audio).with_suffix(".TextGrid")
76
75
  if not os.path.exists(file):
77
76
  return None
78
77
 
@@ -89,11 +88,11 @@ def _load_ta(audio: str | os.PathLike[str], tier: str) -> Optional[list[TimeAlig
89
88
 
90
89
  def load_speakers(input_dir: Path) -> dict:
91
90
  speakers = {}
92
- with open(input_dir / 'SPEAKERS.TXT') as file:
91
+ with open(input_dir / "SPEAKERS.TXT") as file:
93
92
  for line in file:
94
- if not line.startswith(';'):
95
- fields = line.strip().split('|')
93
+ if not line.startswith(";"):
94
+ fields = line.strip().split("|")
96
95
  speaker_id = fields[0].strip()
97
96
  gender = fields[1].strip()
98
- speakers[speaker_id] = {'gender': gender}
97
+ speakers[speaker_id] = {"gender": gender}
99
98
  return speakers
sonusai/speech/mcgill.py CHANGED
@@ -1,10 +1,9 @@
1
1
  import os
2
- from typing import Optional
3
2
 
4
3
  from .types import TimeAlignedType
5
4
 
6
5
 
7
- def load_text(audio: str | os.PathLike[str]) -> Optional[TimeAlignedType]:
6
+ def load_text(audio: str | os.PathLike[str]) -> TimeAlignedType | None:
8
7
  """Load time-aligned text data given a McGill-Speech audio file.
9
8
 
10
9
  :param audio: Path to the McGill-Speech audio file.
@@ -20,48 +19,50 @@ def load_text(audio: str | os.PathLike[str]) -> Optional[TimeAlignedType]:
20
19
 
21
20
  sample_rate = get_sample_rate(str(audio))
22
21
 
23
- with open(audio, mode='rb') as f:
22
+ with open(audio, mode="rb") as f:
24
23
  content = f.read()
25
24
 
26
- riff_id, file_size, wave_id = struct.unpack('<4si4s', content[:12])
27
- if riff_id.decode('utf-8') != 'RIFF':
25
+ riff_id, file_size, wave_id = struct.unpack("<4si4s", content[:12])
26
+ if riff_id.decode("utf-8") != "RIFF":
28
27
  return None
29
28
 
30
- if wave_id.decode('utf-8') != 'WAVE':
29
+ if wave_id.decode("utf-8") != "WAVE":
31
30
  return None
32
31
 
33
- fmt_id, fmt_size = struct.unpack('<4si', content[12:20])
32
+ fmt_id, fmt_size = struct.unpack("<4si", content[12:20])
34
33
 
35
- if fmt_id.decode('utf-8') != 'fmt ':
34
+ if fmt_id.decode("utf-8") != "fmt ":
36
35
  return None
37
36
 
38
37
  if fmt_size != 16:
39
38
  return None
40
39
 
41
- (_wave_format_tag,
42
- channels,
43
- _samples_per_sec,
44
- _avg_bytes_per_sec,
45
- _block_align,
46
- bits_per_sample) = struct.unpack('<hhiihh', content[20:36])
40
+ (
41
+ _wave_format_tag,
42
+ channels,
43
+ _samples_per_sec,
44
+ _avg_bytes_per_sec,
45
+ _block_align,
46
+ bits_per_sample,
47
+ ) = struct.unpack("<hhiihh", content[20:36])
47
48
 
48
49
  i = 36
49
50
  samples = None
50
51
  text = None
51
52
  while i < file_size:
52
- chunk_id = struct.unpack('<4s', content[i:i + 4])[0].decode('utf-8')
53
- chunk_size = struct.unpack('<i', content[i + 4:i + 8])[0]
53
+ chunk_id = struct.unpack("<4s", content[i : i + 4])[0].decode("utf-8")
54
+ chunk_size = struct.unpack("<i", content[i + 4 : i + 8])[0]
54
55
 
55
- if chunk_id == 'data':
56
+ if chunk_id == "data":
56
57
  samples = chunk_size / channels / (bits_per_sample / 8)
57
58
  break
58
59
 
59
- if chunk_id == 'afsp':
60
- chunks = struct.unpack(f'<{chunk_size}s', content[i + 8:i + 8 + chunk_size])[0]
61
- chunks = chunks.decode('utf-8').split('\x00')
60
+ if chunk_id == "afsp":
61
+ chunks = struct.unpack(f"<{chunk_size}s", content[i + 8 : i + 8 + chunk_size])[0]
62
+ chunks = chunks.decode("utf-8").split("\x00")
62
63
  for chunk in chunks:
63
64
  if chunk.startswith('text: "'):
64
- text = chunk[7:-1].lower().translate(str.maketrans('', '', string.punctuation))
65
+ text = chunk[7:-1].lower().translate(str.maketrans("", "", string.punctuation))
65
66
  i += 8 + chunk_size + chunk_size % 2
66
67
 
67
68
  if text and samples:
@@ -6,61 +6,68 @@ from praatio.utilities.constants import Interval
6
6
  from .types import TimeAlignedType
7
7
 
8
8
 
9
- def create_textgrid(prompt: Path,
10
- output_dir: Path,
11
- text: TimeAlignedType = None,
12
- words: list[TimeAlignedType] = None,
13
- phonemes: list[TimeAlignedType] = None) -> None:
9
+ def create_textgrid(
10
+ prompt: Path,
11
+ output_dir: Path,
12
+ text: TimeAlignedType | None = None,
13
+ words: list[TimeAlignedType] | None = None,
14
+ phonemes: list[TimeAlignedType] | None = None,
15
+ ) -> None:
14
16
  if text is None and words is None and phonemes is None:
15
17
  return
16
18
 
17
- min_t, max_t = _get_min_max({'phonemes': phonemes,
18
- 'text': [text],
19
- 'words': words})
19
+ min_t, max_t = _get_min_max({"phonemes": phonemes, "text": text, "words": words})
20
20
 
21
21
  tg = textgrid.Textgrid()
22
22
 
23
23
  if text is not None:
24
24
  entries = [Interval(text.start, text.end, text.text)]
25
- text_tier = textgrid.IntervalTier('text', entries, min_t, max_t)
25
+ text_tier = textgrid.IntervalTier("text", entries, min_t, max_t)
26
26
  tg.addTier(text_tier)
27
27
 
28
28
  if words is not None:
29
29
  entries = []
30
30
  for word in words:
31
31
  entries.append(Interval(word.start, word.end, word.text))
32
- words_tier = textgrid.IntervalTier('words', entries, min_t, max_t)
32
+ words_tier = textgrid.IntervalTier("words", entries, min_t, max_t)
33
33
  tg.addTier(words_tier)
34
34
 
35
35
  if phonemes is not None:
36
36
  entries = []
37
37
  for phoneme in phonemes:
38
38
  entries.append(Interval(phoneme.start, phoneme.end, phoneme.text))
39
- phonemes_tier = textgrid.IntervalTier('phonemes', entries, min_t, max_t)
39
+ phonemes_tier = textgrid.IntervalTier("phonemes", entries, min_t, max_t)
40
40
  tg.addTier(phonemes_tier)
41
41
 
42
- output_filename = str(output_dir / prompt.stem) + '.TextGrid'
43
- tg.save(output_filename, format='long_textgrid', includeBlankSpaces=True)
42
+ output_filename = str(output_dir / prompt.stem) + ".TextGrid"
43
+ tg.save(output_filename, format="long_textgrid", includeBlankSpaces=True)
44
44
 
45
45
 
46
- def _get_min_max(tiers: dict[str, list[TimeAlignedType]]) -> tuple[float, float]:
46
+ def _get_min_max(tiers: dict[str, TimeAlignedType | list[TimeAlignedType] | None]) -> tuple[float, float]:
47
47
  starts = []
48
48
  ends = []
49
49
  for tier in tiers.values():
50
- if tier is not None:
50
+ if tier is None:
51
+ continue
52
+ if isinstance(tier, TimeAlignedType):
53
+ starts.append(tier.start)
54
+ ends.append(tier.end)
55
+ else:
51
56
  starts.append(tier[0].start)
52
57
  ends.append(tier[-1].end)
53
58
 
54
59
  return min(starts), max(ends)
55
60
 
56
61
 
57
- def annotate_textgrid(tiers: dict[str, list[TimeAlignedType]], prompt: Path, output_dir: Path) -> None:
62
+ def annotate_textgrid(
63
+ tiers: dict[str, TimeAlignedType | list[TimeAlignedType] | None] | None, prompt: Path, output_dir: Path
64
+ ) -> None:
58
65
  import os
59
66
 
60
67
  if tiers is None:
61
68
  return
62
69
 
63
- file = Path(output_dir / prompt.stem).with_suffix('.TextGrid')
70
+ file = Path(output_dir / prompt.stem).with_suffix(".TextGrid")
64
71
  if not os.path.exists(file):
65
72
  tg = textgrid.Textgrid()
66
73
  min_t, max_t = _get_min_max(tiers)
@@ -69,14 +76,14 @@ def annotate_textgrid(tiers: dict[str, list[TimeAlignedType]], prompt: Path, out
69
76
  min_t = tg.minTimestamp
70
77
  max_t = tg.maxTimestamp
71
78
 
72
- for tier in tiers.keys():
73
- entries = []
74
- for entry in tiers[tier]:
75
- entries.append(Interval(entry.start, entry.end, entry.text))
76
- if tier == 'phones':
77
- name = 'annotation_phonemes'
79
+ for k, v in tiers.items():
80
+ if v is None:
81
+ continue
82
+ entries = [Interval(entry.start, entry.end, entry.text) for entry in v]
83
+ if k == "phones":
84
+ name = "annotation_phonemes"
78
85
  else:
79
- name = 'annotation_' + tier
86
+ name = "annotation_" + k
80
87
  tg.addTier(textgrid.IntervalTier(name, entries, min_t, max_t))
81
88
 
82
- tg.save(str(file), format='long_textgrid', includeBlankSpaces=True)
89
+ tg.save(str(file), format="long_textgrid", includeBlankSpaces=True)
sonusai/speech/timit.py CHANGED
@@ -1,11 +1,10 @@
1
1
  import os
2
2
  from pathlib import Path
3
- from typing import Optional
4
3
 
5
4
  from .types import TimeAlignedType
6
5
 
7
6
 
8
- def load_text(audio: str | os.PathLike[str]) -> Optional[TimeAlignedType]:
7
+ def load_text(audio: str | os.PathLike[str]) -> TimeAlignedType | None:
9
8
  """Load time-aligned text data given a TIMIT audio file.
10
9
 
11
10
  :param audio: Path to the TIMIT audio file.
@@ -15,52 +14,52 @@ def load_text(audio: str | os.PathLike[str]) -> Optional[TimeAlignedType]:
15
14
 
16
15
  from sonusai.mixture import get_sample_rate
17
16
 
18
- file = Path(audio).with_suffix('.TXT')
17
+ file = Path(audio).with_suffix(".TXT")
19
18
  if not os.path.exists(file):
20
19
  return None
21
20
 
22
21
  sample_rate = get_sample_rate(str(audio))
23
22
 
24
- with open(file, mode='r', encoding='utf-8') as f:
23
+ with open(file, encoding="utf-8") as f:
25
24
  line = f.read()
26
25
 
27
26
  fields = line.strip().split()
28
27
  start = int(fields[0]) / sample_rate
29
28
  end = int(fields[1]) / sample_rate
30
- text = ' '.join(fields[2:]).lower().translate(str.maketrans('', '', string.punctuation))
29
+ text = " ".join(fields[2:]).lower().translate(str.maketrans("", "", string.punctuation))
31
30
 
32
31
  return TimeAlignedType(start, end, text)
33
32
 
34
33
 
35
- def load_words(audio: str | os.PathLike[str]) -> Optional[list[TimeAlignedType]]:
34
+ def load_words(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None:
36
35
  """Load time-aligned word data given a TIMIT audio file.
37
36
 
38
37
  :param audio: Path to the TIMIT audio file.
39
38
  :return: A list of TimeAlignedType objects.
40
39
  """
41
40
 
42
- return _load_ta(audio, 'words')
41
+ return _load_ta(audio, "words")
43
42
 
44
43
 
45
- def load_phonemes(audio: str | os.PathLike[str]) -> Optional[list[TimeAlignedType]]:
44
+ def load_phonemes(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None:
46
45
  """Load time-aligned phonemes data given a TIMIT audio file.
47
46
 
48
47
  :param audio: Path to the TIMIT audio file.
49
48
  :return: A list of TimeAlignedType objects.
50
49
  """
51
50
 
52
- return _load_ta(audio, 'phonemes')
51
+ return _load_ta(audio, "phonemes")
53
52
 
54
53
 
55
- def _load_ta(audio: str | os.PathLike[str], tier: str) -> Optional[list[TimeAlignedType]]:
54
+ def _load_ta(audio: str | os.PathLike[str], tier: str) -> list[TimeAlignedType] | None:
56
55
  from sonusai.mixture import get_sample_rate
57
56
 
58
- if tier == 'words':
59
- file = Path(audio).with_suffix('.WRD')
60
- elif tier == 'phonemes':
61
- file = Path(audio).with_suffix('.PHN')
57
+ if tier == "words":
58
+ file = Path(audio).with_suffix(".WRD")
59
+ elif tier == "phonemes":
60
+ file = Path(audio).with_suffix(".PHN")
62
61
  else:
63
- raise ValueError(f'Unknown tier: {tier}')
62
+ raise ValueError(f"Unknown tier: {tier}")
64
63
 
65
64
  if not os.path.exists(file):
66
65
  return None
@@ -69,18 +68,18 @@ def _load_ta(audio: str | os.PathLike[str], tier: str) -> Optional[list[TimeAlig
69
68
 
70
69
  entries: list[TimeAlignedType] = []
71
70
  first = True
72
- with open(file, mode='r', encoding='utf-8') as f:
71
+ with open(file, encoding="utf-8") as f:
73
72
  for line in f.readlines():
74
73
  fields = line.strip().split()
75
74
  start = int(fields[0]) / sample_rate
76
75
  end = int(fields[1]) / sample_rate
77
- text = ' '.join(fields[2:])
76
+ text = " ".join(fields[2:])
78
77
 
79
78
  if first:
80
79
  first = False
81
80
  else:
82
81
  if start < entries[-1].end:
83
- start = entries[-1].end - (entries[- 1].end - start) // 2
82
+ start = entries[-1].end - (entries[-1].end - start) // 2
84
83
  entries[-1] = TimeAlignedType(text=entries[-1].text, start=entries[-1].start, end=start)
85
84
 
86
85
  if end <= start:
@@ -93,43 +92,47 @@ def _load_ta(audio: str | os.PathLike[str], tier: str) -> Optional[list[TimeAlig
93
92
 
94
93
  def _years_between(record, born):
95
94
  try:
96
- rec_fields = [int(x) for x in record.split('/')]
97
- brn_fields = [int(x) for x in born.split('/')]
95
+ rec_fields = [int(x) for x in record.split("/")]
96
+ brn_fields = [int(x) for x in born.split("/")]
98
97
  return rec_fields[2] - brn_fields[2] - ((rec_fields[1], rec_fields[0]) < (brn_fields[1], brn_fields[0]))
99
98
  except ValueError:
100
- return '??'
99
+ return "??"
101
100
 
102
101
 
103
102
  def _decode_dialect(d: str) -> str:
104
- if d in ['DR1', '1']:
105
- return 'New England'
106
- if d in ['DR2', '2']:
107
- return 'Northern'
108
- if d in ['DR3', '3']:
109
- return 'North Midland'
110
- if d in ['DR4', '4']:
111
- return 'South Midland'
112
- if d in ['DR5', '5']:
113
- return 'Southern'
114
- if d in ['DR6', '6']:
115
- return 'New York City'
116
- if d in ['DR7', '7']:
117
- return 'Western'
118
- if d in ['DR8', '8']:
119
- return 'Army Brat'
120
-
121
- raise ValueError(f'Unrecognized dialect: {d}')
103
+ if d in ["DR1", "1"]:
104
+ return "New England"
105
+ if d in ["DR2", "2"]:
106
+ return "Northern"
107
+ if d in ["DR3", "3"]:
108
+ return "North Midland"
109
+ if d in ["DR4", "4"]:
110
+ return "South Midland"
111
+ if d in ["DR5", "5"]:
112
+ return "Southern"
113
+ if d in ["DR6", "6"]:
114
+ return "New York City"
115
+ if d in ["DR7", "7"]:
116
+ return "Western"
117
+ if d in ["DR8", "8"]:
118
+ return "Army Brat"
119
+
120
+ raise ValueError(f"Unrecognized dialect: {d}")
122
121
 
123
122
 
124
123
  def load_speakers(input_dir: Path) -> dict:
125
124
  speakers = {}
126
- with open(input_dir / 'SPKRINFO.TXT') as file:
125
+ with open(input_dir / "SPKRINFO.TXT") as file:
127
126
  for line in file:
128
- if not line.startswith(';'):
127
+ if not line.startswith(";"):
129
128
  fields = line.strip().split()
130
129
  speaker_id = fields[0]
131
130
  gender = fields[1]
132
131
  dialect = _decode_dialect(fields[2])
133
132
  age = _years_between(fields[4], fields[5])
134
- speakers[speaker_id] = {'gender': gender, 'dialect': dialect, 'age': age}
133
+ speakers[speaker_id] = {
134
+ "gender": gender,
135
+ "dialect": dialect,
136
+ "age": age,
137
+ }
135
138
  return speakers
sonusai/speech/vctk.py CHANGED
@@ -1,6 +1,5 @@
1
1
  import os
2
2
  from pathlib import Path
3
- from typing import Optional
4
3
 
5
4
  from .types import TimeAlignedType
6
5
 
@@ -8,15 +7,13 @@ from .types import TimeAlignedType
8
7
  def _get_duration(name: str) -> float:
9
8
  import soundfile
10
9
 
11
- from sonusai import SonusAIError
12
-
13
10
  try:
14
11
  return soundfile.info(name).duration
15
12
  except Exception as e:
16
- raise SonusAIError(f'Error reading {name}: {e}')
13
+ raise OSError(f"Error reading {name}: {e}") from e
17
14
 
18
15
 
19
- def load_text(audio: str | os.PathLike[str]) -> Optional[TimeAlignedType]:
16
+ def load_text(audio: str | os.PathLike[str]) -> TimeAlignedType | None:
20
17
  """Load time-aligned text data given a VCTK audio file.
21
18
 
22
19
  :param audio: Path to the VCTK audio file.
@@ -24,29 +21,33 @@ def load_text(audio: str | os.PathLike[str]) -> Optional[TimeAlignedType]:
24
21
  """
25
22
  import string
26
23
 
27
- file = Path(audio).parents[2] / 'txt' / Path(audio).parent.name / (Path(audio).stem[:-5] + '.txt')
24
+ file = Path(audio).parents[2] / "txt" / Path(audio).parent.name / (Path(audio).stem[:-5] + ".txt")
28
25
  if not os.path.exists(file):
29
26
  return None
30
27
 
31
- with open(file, mode='r', encoding='utf-8') as f:
28
+ with open(file, encoding="utf-8") as f:
32
29
  line = f.read()
33
30
 
34
31
  start = 0
35
32
  end = _get_duration(str(audio))
36
- text = line.strip().lower().translate(str.maketrans('', '', string.punctuation))
33
+ text = line.strip().lower().translate(str.maketrans("", "", string.punctuation))
37
34
 
38
35
  return TimeAlignedType(start, end, text)
39
36
 
40
37
 
41
38
  def load_speakers(input_dir: Path) -> dict:
42
39
  speakers = {}
43
- with open(input_dir / 'speaker-info.txt') as file:
40
+ with open(input_dir / "speaker-info.txt") as file:
44
41
  for line in file:
45
- if not line.startswith('ID'):
46
- fields = line.strip().split('(', 1)[0].split()
42
+ if not line.startswith("ID"):
43
+ fields = line.strip().split("(", 1)[0].split()
47
44
  speaker_id = fields[0]
48
45
  age = fields[1]
49
46
  gender = fields[2]
50
- dialect = ' '.join([field for field in fields[3:]])
51
- speakers[speaker_id] = {'gender': gender, 'dialect': dialect, 'age': age}
47
+ dialect = " ".join(list(fields[3:]))
48
+ speakers[speaker_id] = {
49
+ "gender": gender,
50
+ "dialect": dialect,
51
+ "age": age,
52
+ }
52
53
  return speakers
@@ -19,26 +19,30 @@ def load_speakers(input_dir: Path) -> dict:
19
19
 
20
20
  # VoxCeleb1
21
21
  first = True
22
- with open(input_dir / 'vox1_meta.csv', newline='') as file:
23
- data = csv.reader(file, delimiter='\t')
22
+ with open(input_dir / "vox1_meta.csv", newline="") as file:
23
+ data = csv.reader(file, delimiter="\t")
24
24
  for row in data:
25
25
  if first:
26
26
  first = False
27
27
  else:
28
- speakers[row[0].strip()] = {'gender': row[2].strip(),
29
- 'dialect': row[3].strip(),
30
- 'category': row[4].strip()}
28
+ speakers[row[0].strip()] = {
29
+ "gender": row[2].strip(),
30
+ "dialect": row[3].strip(),
31
+ "category": row[4].strip(),
32
+ }
31
33
 
32
34
  # VoxCeleb2
33
35
  first = True
34
- with open(input_dir / 'vox2_meta.csv', newline='') as file:
35
- data = csv.reader(file, delimiter='\t')
36
+ with open(input_dir / "vox2_meta.csv", newline="") as file:
37
+ data = csv.reader(file, delimiter="\t")
36
38
  for row in data:
37
39
  if first:
38
40
  first = False
39
41
  else:
40
- speakers[row[1].strip()] = {'gender': row[3].strip(),
41
- 'category': row[4].strip()}
42
+ speakers[row[1].strip()] = {
43
+ "gender": row[3].strip(),
44
+ "category": row[4].strip(),
45
+ }
42
46
 
43
47
  return speakers
44
48
 
@@ -46,18 +50,20 @@ def load_speakers(input_dir: Path) -> dict:
46
50
  def load_segment(path: str | os.PathLike[str]) -> Segment:
47
51
  path = Path(path)
48
52
 
49
- with path.open('r') as file:
53
+ with path.open("r") as file:
50
54
  segment = file.read().strip()
51
55
 
52
- header, frames = segment.split('\n\n')
56
+ header, frames = segment.split("\n\n")
53
57
  header_dict = _parse_header(header)
54
58
  start, stop = _get_segment_boundaries(frames)
55
59
 
56
- return Segment(person=header_dict['Identity'],
57
- video=header_dict['Reference'],
58
- id=path.stem,
59
- start=start,
60
- stop=stop)
60
+ return Segment(
61
+ person=header_dict["Identity"],
62
+ video=header_dict["Reference"],
63
+ id=path.stem,
64
+ start=start,
65
+ stop=stop,
66
+ )
61
67
 
62
68
 
63
69
  def _parse_header(header: str) -> dict:
@@ -73,12 +79,12 @@ def _parse_header(header: str) -> dict:
73
79
  ASD Conf : \t4.465
74
80
 
75
81
  """
76
- k, v = line.split('\t', maxsplit=1)
82
+ k, v = line.split("\t", maxsplit=1)
77
83
  k = k[:-2].strip()
78
84
  v = v.strip()
79
85
  return k, v
80
86
 
81
- return dict(_parse_line(line) for line in header.split('\n'))
87
+ return dict(_parse_line(line) for line in header.split("\n"))
82
88
 
83
89
 
84
90
  def _get_segment_boundaries(frames: str) -> tuple[float, float]:
@@ -94,9 +100,9 @@ def _get_segment_boundaries(frames: str) -> tuple[float, float]:
94
100
  """
95
101
 
96
102
  def _get_frame_seconds(line: str) -> float:
97
- frame = int(line.split('\t')[0])
103
+ frame = int(line.split("\t")[0])
98
104
  # YouTube is 25 FPS
99
105
  return frame / 25
100
106
 
101
- lines = frames.split('\n')
107
+ lines = frames.split("\n")
102
108
  return _get_frame_seconds(lines[1]), _get_frame_seconds(lines[-1])