lattifai 0.1.5__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package as they appear in their public registry, and is provided for informational purposes only.
- lattifai/__init__.py +12 -47
- lattifai/bin/align.py +26 -2
- lattifai/bin/cli_base.py +5 -0
- lattifai/client.py +26 -13
- lattifai/io/reader.py +1 -2
- lattifai/tokenizer/tokenizer.py +284 -0
- lattifai/workers/lattice1_alpha.py +33 -11
- lattifai-0.2.2.dist-info/METADATA +333 -0
- lattifai-0.2.2.dist-info/RECORD +22 -0
- lattifai/tokenizers/tokenizer.py +0 -147
- lattifai-0.1.5.dist-info/METADATA +0 -444
- lattifai-0.1.5.dist-info/RECORD +0 -24
- scripts/__init__.py +0 -1
- scripts/install_k2.py +0 -520
- /lattifai/{tokenizers → tokenizer}/__init__.py +0 -0
- /lattifai/{tokenizers → tokenizer}/phonemizer.py +0 -0
- {lattifai-0.1.5.dist-info → lattifai-0.2.2.dist-info}/WHEEL +0 -0
- {lattifai-0.1.5.dist-info → lattifai-0.2.2.dist-info}/entry_points.txt +0 -0
- {lattifai-0.1.5.dist-info → lattifai-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {lattifai-0.1.5.dist-info → lattifai-0.2.2.dist-info}/top_level.txt +0 -0
lattifai/__init__.py
CHANGED

@@ -22,54 +22,19 @@ def _check_and_install_k2():
    """Check if k2 is installed and attempt to install it if not."""
    try:
        import k2
-
-        return True
    except ImportError:
-
-
-
-
-
-
-
-
-            '
-
-            '
-
-            RuntimeWarning,
-            stacklevel=2,
-        )
-        return False
-
-        print('\n' + '=' * 70)
-        print(' k2 is not installed. Attempting to install it now...')
-        print(' This is a one-time setup and may take a few minutes.')
-        print('=' * 70 + '\n')
-
-        try:
-            # Import and run the installation script
-            from scripts.install_k2 import install_k2_main
-
-            install_k2_main(dry_run=False)
-
-            print('\n' + '=' * 70)
-            print(' k2 has been installed successfully!')
-            print('=' * 70 + '\n')
-            return True
-        except Exception as e:
-            warnings.warn(
-                '\n' + '=' * 70 + '\n'
-                f' Failed to auto-install k2: {e}\n'
-                ' \n'
-                ' Please install k2 manually by running:\n'
-                ' \n'
-                ' install-k2\n'
-                ' \n' + '=' * 70,
-                RuntimeWarning,
-                stacklevel=2,
-            )
-            return False
+        import subprocess
+
+        print('k2 is not installed. Attempting to install k2...')
+        try:
+            subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'install-k2'])
+            subprocess.check_call([sys.executable, '-m', 'install_k2'])
+            import k2  # Try importing again after installation
+
+            print('k2 installed successfully.')
+        except Exception as e:
+            warnings.warn(f'Failed to install k2 automatically. Please install it manually. Error: {e}')
+        return True


# Auto-install k2 on first import
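The import-time hook now shells out to pip: it installs the install-k2 helper package and then runs it as a module. If you prefer not to trigger that on first import, the same two commands can be run ahead of time; a minimal pre-flight sketch using the package and module names shown in the diff:

import importlib.util
import subprocess
import sys

# Install k2 up front so _check_and_install_k2() succeeds on its first try.
if importlib.util.find_spec('k2') is None:
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'install-k2'])
    subprocess.check_call([sys.executable, '-m', 'install_k2'])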
lattifai/bin/align.py
CHANGED

@@ -13,6 +13,23 @@ from lattifai.bin.cli_base import cli
    default='auto',
    help='Input Subtitle format.',
)
+@click.option(
+    '-D',
+    '--device',
+    type=click.Choice(['cpu', 'cuda', 'mps'], case_sensitive=False),
+    default='cpu',
+    help='Device to use for inference.',
+)
+@click.option(
+    '-M', '--model_name_or_path', type=str, default='Lattifai/Lattice-1-Alpha', help='Lattifai model name or path'
+)
+@click.option(
+    '-S',
+    '--split_sentence',
+    is_flag=True,
+    default=False,
+    help='Re-segment subtitles by semantics.',
+)
@click.argument(
    'input_audio_path',
    type=click.Path(exists=True, dir_okay=False),
@@ -30,13 +47,20 @@ def align(
    input_subtitle_path: Pathlike,
    output_subtitle_path: Pathlike,
    input_format: str = 'auto',
+    device: str = 'cpu',
+    model_name_or_path: str = 'Lattifai/Lattice-1-Alpha',
+    split_sentence: bool = False,
):
    """
    Command used to align audio with subtitles
    """
    from lattifai import LattifAI

-    client = LattifAI()
+    client = LattifAI(model_name_or_path=model_name_or_path, device=device)
    client.alignment(
-        input_audio_path,
+        input_audio_path,
+        input_subtitle_path,
+        format=input_format.lower(),
+        split_sentence=split_sentence,
+        output_subtitle_path=output_subtitle_path,
    )
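A quick way to confirm the new flags are wired in is to render the command's help text through Click's test runner; this assumes the command object is importable as align from lattifai.bin.align, which the hunk above implies but does not show in full:

from click.testing import CliRunner

from lattifai.bin.align import align

# Render the command's --help output; the new -D/--device, -M/--model_name_or_path
# and -S/--split_sentence options should appear alongside the existing ones.
result = CliRunner().invoke(align, ['--help'])
print(result.output)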
lattifai/bin/cli_base.py
CHANGED

@@ -8,6 +8,11 @@ def cli():
    """
    The shell entry point to Lattifai, a tool for audio data manipulation.
    """
+    # Load environment variables from .env file
+    from dotenv import load_dotenv
+
+    load_dotenv()
+
    logging.basicConfig(
        format='%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s',
        level=logging.INFO,
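Calling load_dotenv() at the CLI entry point lets configuration live in a project-local .env file instead of the shell profile. A small sketch of the effect; the LATTIFAI_API_KEY name is illustrative, since the diff does not show which variables the client actually reads:

import os

from dotenv import load_dotenv

# Reads key=value pairs from ./.env into os.environ; by default python-dotenv
# does not override variables that are already set in the environment.
load_dotenv()

api_key = os.environ.get('LATTIFAI_API_KEY')  # hypothetical variable name
print('API key configured:', api_key is not None)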
lattifai/client.py
CHANGED

@@ -11,7 +11,7 @@ from lhotse.utils import Pathlike

from lattifai.base_client import AsyncAPIClient, LattifAIError, SyncAPIClient
from lattifai.io import SubtitleFormat, SubtitleIO
-from lattifai.
+from lattifai.tokenizer import LatticeTokenizer
from lattifai.workers import Lattice1AlphaWorker

load_dotenv()
@@ -25,9 +25,9 @@ class LattifAI(SyncAPIClient):
        *,
        api_key: Optional[str] = None,
        model_name_or_path: str = 'Lattifai/Lattice-1-Alpha',
-        device: str =
+        device: Optional[str] = None,
        base_url: Optional[str] = None,
-        timeout: Union[float, int] =
+        timeout: Union[float, int] = 120.0,
        max_retries: int = 2,
        default_headers: Optional[Dict[str, str]] = None,
    ) -> None:
@@ -55,11 +55,26 @@ class LattifAI(SyncAPIClient):
        # Initialize components
        if not Path(model_name_or_path).exists():
            from huggingface_hub import snapshot_download
+            from huggingface_hub.errors import LocalEntryNotFoundError

-
+            try:
+                model_path = snapshot_download(repo_id=model_name_or_path, repo_type='model')
+            except LocalEntryNotFoundError:
+                os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
+                model_path = snapshot_download(repo_id=model_name_or_path, repo_type='model')
        else:
            model_path = model_name_or_path

+        # device setup
+        if device is None:
+            import torch
+
+            device = 'cpu'
+            if torch.backends.mps.is_available():
+                device = 'mps'
+            elif torch.cuda.is_available():
+                device = 'cuda'
+
        self.tokenizer = LatticeTokenizer.from_pretrained(
            client_wrapper=self,
            model_path=model_path,
@@ -72,6 +87,7 @@ class LattifAI(SyncAPIClient):
        audio: Pathlike,
        subtitle: Pathlike,
        format: Optional[SubtitleFormat] = None,
+        split_sentence: bool = False,
        output_subtitle_path: Optional[Pathlike] = None,
    ) -> str:
        """Perform alignment on audio and subtitle/text.
@@ -87,11 +103,11 @@
        # step1: parse text or subtitles
        print(colorful.cyan(f'📖 Step 1: Reading subtitle file from {subtitle}'))
        supervisions = SubtitleIO.read(subtitle, format=format)
-        print(colorful.green(f' ✓ Parsed {len(supervisions)}
+        print(colorful.green(f' ✓ Parsed {len(supervisions)} subtitle segments'))

        # step2: make lattice by call Lattifai API
        print(colorful.cyan('🔗 Step 2: Creating lattice graph from text'))
-        lattice_id, lattice_graph = self.tokenizer.tokenize(supervisions)
+        lattice_id, lattice_graph = self.tokenizer.tokenize(supervisions, split_sentence=split_sentence)
        print(colorful.green(f' ✓ Generated lattice graph with ID: {lattice_id}'))

        # step3: align audio with text
@@ -117,13 +133,10 @@ if __name__ == '__main__':
    import sys

    if len(sys.argv) == 4:
-
+        audio, subtitle, output = sys.argv[1:]
    else:
        audio = 'tests/data/SA1.wav'
-
-
-        alignments = client.alignment(audio, text)
-        print(alignments)
+        subtitle = 'tests/data/SA1.TXT'
+        output = None

-    alignments = client.alignment(audio,
-    print(alignments)
+    alignments = client.alignment(audio, subtitle, output_subtitle_path=output, split_sentence=True)
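LattifAI now resolves a default device when none is passed, preferring Apple's MPS backend, then CUDA, then CPU, and it retries model downloads through the hf-mirror.com endpoint when the Hugging Face Hub is unreachable. The device fallback, pulled out as a standalone helper for clarity (a sketch, not part of the package's API):

from typing import Optional

import torch


def resolve_device(device: Optional[str] = None) -> str:
    # Mirrors the fallback order used in LattifAI.__init__ when device is None.
    if device is not None:
        return device
    if torch.backends.mps.is_available():
        return 'mps'
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'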
lattifai/io/reader.py
CHANGED

@@ -58,13 +58,12 @@ class SubtitleReader(ABCMeta):
        subs: pysubs2.SSAFile = pysubs2.load(subtitle, encoding='utf-8')  # auto detect format

        supervisions = []
-
        for event in subs.events:
            supervisions.append(
                Supervision(
                    text=event.text,
                    # "start": event.start / 1000.0 if event.start is not None else None,
-                    # "duration": event.end / 1000.0 if event.end is not None else None,
+                    # "duration": (event.end - event.start) / 1000.0 if event.end is not None else None,
                    # }
                )
            )
lattifai/tokenizer/tokenizer.py
ADDED

@@ -0,0 +1,284 @@
+import gzip
+import pickle
+import re
+from collections import defaultdict
+from itertools import chain
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from lattifai.base_client import SyncAPIClient
+from lattifai.io import Supervision
+from lattifai.tokenizer.phonemizer import G2Phonemizer
+
+PUNCTUATION = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~'
+END_PUNCTUATION = '.!?"]。!?”】'
+PUNCTUATION_SPACE = PUNCTUATION + ' '
+STAR_TOKEN = '※'
+
+GROUPING_SEPARATOR = '✹'
+
+MAXIMUM_WORD_LENGTH = 40
+
+
+class LatticeTokenizer:
+    """Tokenizer for converting Lhotse Cut to LatticeGraph."""
+
+    def __init__(self, client_wrapper: SyncAPIClient):
+        self.client_wrapper = client_wrapper
+        self.words: List[str] = []
+        self.g2p_model: Any = None  # Placeholder for G2P model
+        self.dictionaries = defaultdict(lambda: [])
+        self.oov_word = '<unk>'
+        self.sentence_splitter = None
+        self.device = 'cpu'
+
+    def init_sentence_splitter(self):
+        if self.sentence_splitter is not None:
+            return
+
+        import onnxruntime as ort
+        from wtpsplit import SaT
+
+        providers = []
+        device = self.device
+        if device.startswith('cuda') and ort.get_all_providers().count('CUDAExecutionProvider') > 0:
+            providers.append('CUDAExecutionProvider')
+        elif device.startswith('mps') and ort.get_all_providers().count('MPSExecutionProvider') > 0:
+            providers.append('MPSExecutionProvider')
+
+        sat = SaT(
+            'sat-3l-sm',
+            ort_providers=providers + ['CPUExecutionProvider'],
+        )
+        self.sentence_splitter = sat
+
+    @staticmethod
+    def _resplit_special_sentence_types(sentence: str) -> List[str]:
+        """
+        Re-split special sentence types.
+
+        Examples:
+            '[APPLAUSE] >> MIRA MURATI:' -> ['[APPLAUSE]', '>> MIRA MURATI:']
+            '[MUSIC] >> SPEAKER:' -> ['[MUSIC]', '>> SPEAKER:']
+
+        Special handling patterns:
+        1. Separate special marks at the beginning (e.g., [APPLAUSE], [MUSIC], etc.) from subsequent speaker marks
+        2. Use speaker marks (>> or other separators) as split points
+
+        Args:
+            sentence: Input sentence string
+
+        Returns:
+            List of re-split sentences. If no special marks are found, returns the original sentence in a list
+        """
+        # Detect special mark patterns: [SOMETHING] >> SPEAKER:
+        # or other forms like [SOMETHING] SPEAKER:
+
+        # Pattern 1: [mark] HTML-encoded separator speaker:
+        pattern1 = r'^(\[[^\]]+\])\s+(>>|>>)\s+(.+)$'
+        match1 = re.match(pattern1, sentence.strip())
+        if match1:
+            special_mark = match1.group(1)
+            separator = match1.group(2)
+            speaker_part = match1.group(3)
+            return [special_mark, f'{separator} {speaker_part}']
+
+        # Pattern 2: [mark] speaker:
+        pattern2 = r'^(\[[^\]]+\])\s+([^:]+:)(.*)$'
+        match2 = re.match(pattern2, sentence.strip())
+        if match2:
+            special_mark = match2.group(1)
+            speaker_label = match2.group(2)
+            remaining = match2.group(3).strip()
+            if remaining:
+                return [special_mark, f'{speaker_label} {remaining}']
+            else:
+                return [special_mark, speaker_label]
+
+        # If no special pattern matches, return the original sentence
+        return [sentence]
+
+    @staticmethod
+    def from_pretrained(
+        client_wrapper: SyncAPIClient,
+        model_path: str,
+        device: str = 'cpu',
+        compressed: bool = True,
+    ):
+        """Load tokenizer from exported binary file"""
+        from pathlib import Path
+
+        words_model_path = f'{model_path}/words.bin'
+        if compressed:
+            with gzip.open(words_model_path, 'rb') as f:
+                data = pickle.load(f)
+        else:
+            with open(words_model_path, 'rb') as f:
+                data = pickle.load(f)
+
+        tokenizer = LatticeTokenizer(client_wrapper=client_wrapper)
+        tokenizer.words = data['words']
+        tokenizer.dictionaries = defaultdict(list, data['dictionaries'])
+        tokenizer.oov_word = data['oov_word']
+
+        g2p_model_path = f'{model_path}/g2p.bin' if Path(f'{model_path}/g2p.bin').exists() else None
+        if g2p_model_path:
+            tokenizer.g2p_model = G2Phonemizer(g2p_model_path, device=device)
+
+        tokenizer.device = device
+        tokenizer.add_special_tokens()
+        return tokenizer
+
+    def add_special_tokens(self):
+        tokenizer = self
+        for special_token in ['>>', '>']:
+            if special_token not in tokenizer.dictionaries:
+                tokenizer.dictionaries[special_token] = tokenizer.dictionaries[tokenizer.oov_word]
+        return self
+
+    def prenormalize(self, texts: List[str], language: Optional[str] = None) -> List[str]:
+        if not self.g2p_model:
+            raise ValueError('G2P model is not loaded, cannot prenormalize texts')
+
+        oov_words = []
+        for text in texts:
+            words = text.lower().replace('-', ' ').replace('—', ' ').replace('–', ' ').split()
+            oovs = [w for w in words if w not in self.words]
+            if oovs:
+                oov_words.extend([w for w in oovs if (w not in self.words and len(w) <= MAXIMUM_WORD_LENGTH)])
+
+        oov_words = list(set(oov_words))
+        if oov_words:
+            indexs = []
+            for k, _word in enumerate(oov_words):
+                if any(_word.startswith(p) and _word.endswith(q) for (p, q) in [('(', ')'), ('[', ']')]):
+                    self.dictionaries[_word] = self.dictionaries[self.oov_word]
+                else:
+                    _word = _word.strip(PUNCTUATION_SPACE)
+                    if not _word or _word in self.words:
+                        indexs.append(k)
+            for idx in sorted(indexs, reverse=True):
+                del oov_words[idx]
+
+            g2p_words = [w for w in oov_words if w not in self.dictionaries]
+            if g2p_words:
+                predictions = self.g2p_model(words=g2p_words, lang=language, batch_size=len(g2p_words), num_prons=4)
+                for _word, _predictions in zip(g2p_words, predictions):
+                    for pronuncation in _predictions:
+                        if pronuncation and pronuncation not in self.dictionaries[_word]:
+                            self.dictionaries[_word].append(pronuncation)
+                    if not self.dictionaries[_word]:
+                        self.dictionaries[_word] = self.dictionaries[self.oov_word]
+
+            pronunciation_dictionaries: Dict[str, List[List[str]]] = {
+                w: self.dictionaries[w] for w in oov_words if self.dictionaries[w]
+            }
+            return pronunciation_dictionaries
+
+        return {}
+
+    def split_sentences(self, supervisions: List[Supervision], strip_whitespace=True) -> List[str]:
+        texts, text_len, sidx = [], 0, 0
+        for s, supervision in enumerate(supervisions):
+            text_len += len(supervision.text)
+            if text_len >= 2000 or s == len(supervisions) - 1:
+                text = ' '.join([sup.text for sup in supervisions[sidx : s + 1]])
+                texts.append(text)
+                sidx = s + 1
+                text_len = 0
+        if sidx < len(supervisions):
+            text = ' '.join([sup.text for sup in supervisions[sidx:]])
+            texts.append(text)
+        sentences = self.sentence_splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace)
+
+        supervisions, remainder = [], ''
+        for _sentences in sentences:
+            # Process and re-split special sentence types
+            processed_sentences = []
+            for s, _sentence in enumerate(_sentences):
+                if remainder:
+                    _sentence = remainder + _sentence
+                    remainder = ''
+
+                # Detect and split special sentence types: e.g., '[APPLAUSE] >> MIRA MURATI:' -> ['[APPLAUSE]', '>> MIRA MURATI:']  # noqa: E501
+                resplit_parts = self._resplit_special_sentence_types(_sentence)
+                if any(resplit_parts[-1].endswith(sp) for sp in [':', ':']):
+                    if s < len(_sentences) - 1:
+                        _sentences[s + 1] = resplit_parts[-1] + ' ' + _sentences[s + 1]
+                    else:  # last part
+                        remainder = resplit_parts[-1] + ' ' + remainder
+                    processed_sentences.extend(resplit_parts[:-1])
+                else:
+                    processed_sentences.extend(resplit_parts)
+
+            _sentences = processed_sentences
+
+            if remainder:
+                _sentences[0] = remainder + _sentences[0]
+                remainder = ''
+
+            if any(_sentences[-1].endswith(ep) for ep in END_PUNCTUATION):
+                supervisions.extend(Supervision(text=s) for s in _sentences)
+            else:
+                supervisions.extend(Supervision(text=s) for s in _sentences[:-1])
+                remainder += _sentences[-1] + ' '
+
+        if remainder.strip():
+            supervisions.append(Supervision(text=remainder.strip()))
+
+        return supervisions
+
+    def tokenize(self, supervisions: List[Supervision], split_sentence: bool = False) -> Tuple[str, Dict[str, Any]]:
+        if split_sentence:
+            self.init_sentence_splitter()
+            supervisions = self.split_sentences(supervisions)
+
+        pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions])
+        response = self.client_wrapper.post(
+            'tokenize',
+            json={
+                'supervisions': [s.to_dict() for s in supervisions],
+                'pronunciation_dictionaries': pronunciation_dictionaries,
+            },
+        )
+        if response.status_code != 200:
+            raise Exception(f'Failed to tokenize texts: {response.text}')
+        result = response.json()
+        lattice_id = result['id']
+        return lattice_id, (result['lattice_graph'], result['final_state'], result.get('acoustic_scale', 1.0))
+
+    def detokenize(
+        self,
+        lattice_id: str,
+        lattice_results: Tuple[torch.Tensor, Any, Any, float, float],
+        # return_supervisions: bool = True,
+        # return_details: bool = False,
+    ) -> List[Supervision]:
+        emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
+        response = self.client_wrapper.post(
+            'detokenize',
+            json={
+                'lattice_id': lattice_id,
+                'frame_shift': frame_shift,
+                'results': [t.to_dict() for t in results[0]],
+                'labels': labels[0],
+                'offset': offset,
+                'channel': channel,
+                'destroy_lattice': True,
+            },
+        )
+        if response.status_code != 200:
+            raise Exception(f'Failed to detokenize lattice: {response.text}')
+        result = response.json()
+        # if return_details:
+        #     raise NotImplementedError("return_details is not implemented yet")
+        return [Supervision.from_dict(s) for s in result['supervisions']]
+
+
+# Compute average score weighted by the span length
+def _score(spans):
+    if not spans:
+        return 0.0
+    # TokenSpan(token=token, start=start, end=end, score=scores[start:end].mean().item())
+    return round(sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans), ndigits=4)
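The special-sentence re-splitting is easiest to see on the docstring's own example; a quick illustration calling the static helper directly (no client wrapper or model files needed):

from lattifai.tokenizer.tokenizer import LatticeTokenizer

# '[mark] >> SPEAKER: ...' is split into the bracketed mark and the speaker turn.
parts = LatticeTokenizer._resplit_special_sentence_types('[APPLAUSE] >> MIRA MURATI: Thank you.')
print(parts)  # ['[APPLAUSE]', '>> MIRA MURATI: Thank you.']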
lattifai/workers/lattice1_alpha.py
CHANGED

@@ -17,8 +17,6 @@ class Lattice1AlphaWorker:
    """Worker for processing audio with LatticeGraph."""

    def __init__(self, model_path: Pathlike, device: str = 'cpu', num_threads: int = 8) -> None:
-        if device != 'cpu':
-            raise NotImplementedError(f'Only cpu is supported for now, got device={device}.')
        self.config = json.load(open(f'{model_path}/config.json'))

        # SessionOptions
@@ -29,8 +27,11 @@ class Lattice1AlphaWorker:
        sess_options.add_session_config_entry('session.intra_op.allow_spinning', '0')

        providers = []
-        if device.startswith('cuda')
+        if device.startswith('cuda') and ort.get_all_providers().count('CUDAExecutionProvider') > 0:
            providers.append('CUDAExecutionProvider')
+        elif device.startswith('mps') and ort.get_all_providers().count('MPSExecutionProvider') > 0:
+            providers.append('MPSExecutionProvider')
+
        self.acoustic_ort = ort.InferenceSession(
            f'{model_path}/acoustic_opt.onnx',
            sess_options,
@@ -49,13 +50,29 @@
        _start = time.time()
        # audio -> features -> emission
        features = self.extractor(audio)  # (1, T, D)
-
-
-
-
-
+        if features.shape[1] > 6000:
+            features_list = torch.split(features, 6000, dim=1)
+            emissions = []
+            for features in features_list:
+                ort_inputs = {
+                    'features': features.cpu().numpy(),
+                    'feature_lengths': np.array([features.size(1)], dtype=np.int64),
+                }
+                emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
+                emissions.append(emission)
+            emission = torch.cat(
+                [torch.from_numpy(emission).to(self.device) for emission in emissions], dim=1
+            )  # (1, T, vocab_size)
+        else:
+            ort_inputs = {
+                'features': features.cpu().numpy(),
+                'feature_lengths': np.array([features.size(1)], dtype=np.int64),
+            }
+            emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
+            emission = torch.from_numpy(emission).to(self.device)
+
        self.timings['emission'] += time.time() - _start
-        return
+        return emission  # (1, T, vocab_size) torch

    def load_audio(self, audio: Union[Pathlike, BinaryIO]) -> Tuple[torch.Tensor, int]:
        # load audio
@@ -104,9 +121,14 @@ class Lattice1AlphaWorker:
        self.timings['decoding_graph'] += time.time() - _start

        _start = time.time()
+        if self.device.type == 'mps':
+            device = 'cpu'  # k2 does not support mps yet
+        else:
+            device = self.device
+
        results, labels = align_segments(
-            emission.to(
-            decoding_graph.to(
+            emission.to(device) * acoustic_scale,
+            decoding_graph.to(device),
            torch.tensor([emission.shape[1]], dtype=torch.int32),
            search_beam=100,
            output_beam=40,
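The reworked emission step feeds the ONNX acoustic model at most 6000 feature frames at a time and concatenates the per-chunk outputs along the time axis. A standalone sketch of that pattern (the ONNX Runtime session and the (1, T, D) feature tensor are assumed to exist; input names mirror the diff):

import numpy as np
import torch


def run_acoustic_chunked(acoustic_ort, features: torch.Tensor, chunk_size: int = 6000) -> torch.Tensor:
    # Split along time only when the sequence exceeds the chunk size.
    pieces = torch.split(features, chunk_size, dim=1) if features.shape[1] > chunk_size else [features]
    emissions = []
    for piece in pieces:
        ort_inputs = {
            'features': piece.cpu().numpy(),
            'feature_lengths': np.array([piece.size(1)], dtype=np.int64),
        }
        emissions.append(acoustic_ort.run(None, ort_inputs)[0])  # (1, t, vocab_size) numpy
    # Stitch chunk outputs back into a single (1, T, vocab_size) tensor.
    return torch.cat([torch.from_numpy(e) for e in emissions], dim=1)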