lattifai 0.1.5__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lattifai/__init__.py CHANGED
@@ -22,54 +22,19 @@ def _check_and_install_k2():
     """Check if k2 is installed and attempt to install it if not."""
     try:
         import k2
-
-        return True
     except ImportError:
-        pass
-
-    # k2 not found, try to install it
-    if os.environ.get('SKIP_K2_INSTALL'):
-        warnings.warn(
-            '\n' + '=' * 70 + '\n'
-            ' k2 is not installed and auto-installation is disabled.\n'
-            ' \n'
-            ' To use lattifai, please install k2 by running:\n'
-            ' \n'
-            ' install-k2\n'
-            ' \n' + '=' * 70,
-            RuntimeWarning,
-            stacklevel=2,
-        )
-        return False
-
-    print('\n' + '=' * 70)
-    print(' k2 is not installed. Attempting to install it now...')
-    print(' This is a one-time setup and may take a few minutes.')
-    print('=' * 70 + '\n')
-
-    try:
-        # Import and run the installation script
-        from scripts.install_k2 import install_k2_main
-
-        install_k2_main(dry_run=False)
-
-        print('\n' + '=' * 70)
-        print(' k2 has been installed successfully!')
-        print('=' * 70 + '\n')
-        return True
-    except Exception as e:
-        warnings.warn(
-            '\n' + '=' * 70 + '\n'
-            f' Failed to auto-install k2: {e}\n'
-            ' \n'
-            ' Please install k2 manually by running:\n'
-            ' \n'
-            ' install-k2\n'
-            ' \n' + '=' * 70,
-            RuntimeWarning,
-            stacklevel=2,
-        )
-        return False
+        import subprocess
+
+        print('k2 is not installed. Attempting to install k2...')
+        try:
+            subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'install-k2'])
+            subprocess.check_call([sys.executable, '-m', 'install_k2'])
+            import k2  # Try importing again after installation
+
+            print('k2 installed successfully.')
+        except Exception as e:
+            warnings.warn(f'Failed to install k2 automatically. Please install it manually. Error: {e}')
+    return True


# Auto-install k2 on first import
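For reference, the new fallback is roughly equivalent to running the installer by hand. A minimal sketch, assuming the `install-k2` helper named in the diff exposes a runnable `install_k2` module:

    import subprocess
    import sys

    # Same steps the import hook now performs: fetch the installer package,
    # run it, then retry the import.
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'install-k2'])
    subprocess.check_call([sys.executable, '-m', 'install_k2'])
    import k2  # noqa: F401  # should succeed after installation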
lattifai/bin/align.py CHANGED
@@ -13,6 +13,23 @@ from lattifai.bin.cli_base import cli
     default='auto',
     help='Input Subtitle format.',
 )
+@click.option(
+    '-D',
+    '--device',
+    type=click.Choice(['cpu', 'cuda', 'mps'], case_sensitive=False),
+    default='cpu',
+    help='Device to use for inference.',
+)
+@click.option(
+    '-M', '--model_name_or_path', type=str, default='Lattifai/Lattice-1-Alpha', help='Lattifai model name or path'
+)
+@click.option(
+    '-S',
+    '--split_sentence',
+    is_flag=True,
+    default=False,
+    help='Re-segment subtitles by semantics.',
+)
 @click.argument(
     'input_audio_path',
     type=click.Path(exists=True, dir_okay=False),
@@ -30,13 +47,20 @@ def align(
     input_subtitle_path: Pathlike,
     output_subtitle_path: Pathlike,
     input_format: str = 'auto',
+    device: str = 'cpu',
+    model_name_or_path: str = 'Lattifai/Lattice-1-Alpha',
+    split_sentence: bool = False,
 ):
     """
     Command used to align audio with subtitles
     """
     from lattifai import LattifAI

-    client = LattifAI()
+    client = LattifAI(model_name_or_path=model_name_or_path, device=device)
     client.alignment(
-        input_audio_path, input_subtitle_path, format=input_format, output_subtitle_path=output_subtitle_path
+        input_audio_path,
+        input_subtitle_path,
+        format=input_format.lower(),
+        split_sentence=split_sentence,
+        output_subtitle_path=output_subtitle_path,
     )
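The new options map one-to-one onto the client API. A sketch of the equivalent Python call, with placeholder file paths:

    from lattifai import LattifAI

    client = LattifAI(model_name_or_path='Lattifai/Lattice-1-Alpha', device='cuda')
    client.alignment(
        'episode.wav',                       # input audio (placeholder)
        'episode.srt',                       # input subtitles (placeholder)
        format='auto',
        split_sentence=True,                 # same effect as the -S flag
        output_subtitle_path='aligned.srt',
    )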
lattifai/bin/cli_base.py CHANGED
@@ -8,6 +8,11 @@ def cli():
     """
     The shell entry point to Lattifai, a tool for audio data manipulation.
     """
+    # Load environment variables from .env file
+    from dotenv import load_dotenv
+
+    load_dotenv()
+
     logging.basicConfig(
         format='%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s',
         level=logging.INFO,
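This makes the CLI pick up credentials and settings from a local `.env` file before logging is configured. A sketch, assuming a variable name such as `LATTIFAI_API_KEY` (the exact key is not shown in this diff):

    import os

    from dotenv import load_dotenv

    load_dotenv()  # reads key=value pairs from ./.env into os.environ
    api_key = os.environ.get('LATTIFAI_API_KEY')  # hypothetical variable name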
lattifai/client.py CHANGED
@@ -11,7 +11,7 @@ from lhotse.utils import Pathlike
 
 from lattifai.base_client import AsyncAPIClient, LattifAIError, SyncAPIClient
 from lattifai.io import SubtitleFormat, SubtitleIO
-from lattifai.tokenizers import LatticeTokenizer
+from lattifai.tokenizer import LatticeTokenizer
 from lattifai.workers import Lattice1AlphaWorker
 
 load_dotenv()
@@ -25,9 +25,9 @@ class LattifAI(SyncAPIClient):
         *,
         api_key: Optional[str] = None,
         model_name_or_path: str = 'Lattifai/Lattice-1-Alpha',
-        device: str = 'cpu',
+        device: Optional[str] = None,
         base_url: Optional[str] = None,
-        timeout: Union[float, int] = 60.0,
+        timeout: Union[float, int] = 120.0,
         max_retries: int = 2,
         default_headers: Optional[Dict[str, str]] = None,
     ) -> None:
@@ -55,11 +55,26 @@ class LattifAI(SyncAPIClient):
         # Initialize components
         if not Path(model_name_or_path).exists():
             from huggingface_hub import snapshot_download
+            from huggingface_hub.errors import LocalEntryNotFoundError
 
-            model_path = snapshot_download(repo_id=model_name_or_path, repo_type='model')
+            try:
+                model_path = snapshot_download(repo_id=model_name_or_path, repo_type='model')
+            except LocalEntryNotFoundError:
+                os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
+                model_path = snapshot_download(repo_id=model_name_or_path, repo_type='model')
         else:
             model_path = model_name_or_path
 
+        # device setup
+        if device is None:
+            import torch
+
+            device = 'cpu'
+            if torch.backends.mps.is_available():
+                device = 'mps'
+            elif torch.cuda.is_available():
+                device = 'cuda'
+
         self.tokenizer = LatticeTokenizer.from_pretrained(
             client_wrapper=self,
             model_path=model_path,
@@ -72,6 +87,7 @@ class LattifAI(SyncAPIClient):
         audio: Pathlike,
         subtitle: Pathlike,
         format: Optional[SubtitleFormat] = None,
+        split_sentence: bool = False,
         output_subtitle_path: Optional[Pathlike] = None,
     ) -> str:
         """Perform alignment on audio and subtitle/text.
@@ -87,11 +103,11 @@ class LattifAI(SyncAPIClient):
         # step1: parse text or subtitles
         print(colorful.cyan(f'📖 Step 1: Reading subtitle file from {subtitle}'))
         supervisions = SubtitleIO.read(subtitle, format=format)
-        print(colorful.green(f'   ✓ Parsed {len(supervisions)} supervision segments'))
+        print(colorful.green(f'   ✓ Parsed {len(supervisions)} subtitle segments'))
 
         # step2: make lattice by call Lattifai API
         print(colorful.cyan('🔗 Step 2: Creating lattice graph from text'))
-        lattice_id, lattice_graph = self.tokenizer.tokenize(supervisions)
+        lattice_id, lattice_graph = self.tokenizer.tokenize(supervisions, split_sentence=split_sentence)
         print(colorful.green(f'   ✓ Generated lattice graph with ID: {lattice_id}'))
 
         # step3: align audio with text
@@ -117,13 +133,10 @@ if __name__ == '__main__':
     import sys
 
     if len(sys.argv) == 4:
-        pass
+        audio, subtitle, output = sys.argv[1:]
     else:
         audio = 'tests/data/SA1.wav'
-        text = 'tests/data/SA1.TXT'
-
-    alignments = client.alignment(audio, text)
-    print(alignments)
+        subtitle = 'tests/data/SA1.TXT'
+        output = None
 
-    alignments = client.alignment(audio, 'not paired texttttt', format='txt')
-    print(alignments)
+    alignments = client.alignment(audio, subtitle, output_subtitle_path=output, split_sentence=True)
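Combined, the client now auto-selects a device when none is given (preferring mps, then cuda, then cpu) and falls back to the `hf-mirror.com` endpoint when the Hugging Face Hub is unreachable. A usage sketch based on the updated `__main__` block:

    from lattifai import LattifAI

    client = LattifAI()  # device=None triggers mps > cuda > cpu auto-detection
    alignments = client.alignment(
        'tests/data/SA1.wav',
        'tests/data/SA1.TXT',
        output_subtitle_path=None,
        split_sentence=True,
    )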
lattifai/io/reader.py CHANGED
@@ -58,13 +58,12 @@ class SubtitleReader(ABCMeta):
         subs: pysubs2.SSAFile = pysubs2.load(subtitle, encoding='utf-8')  # auto detect format
 
         supervisions = []
-
         for event in subs.events:
             supervisions.append(
                 Supervision(
                     text=event.text,
                     # "start": event.start / 1000.0 if event.start is not None else None,
-                    # "duration": event.end / 1000.0 if event.end is not None else None,
+                    # "duration": (event.end - event.start) / 1000.0 if event.end is not None else None,
                     # }
                 )
             )
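The corrected comment follows pysubs2 semantics, where `event.start` and `event.end` are absolute times in milliseconds, so a duration would be derived as below (a sketch; the fields remain commented out in the code):

    start = event.start / 1000.0                   # seconds
    duration = (event.end - event.start) / 1000.0  # seconds, not event.end / 1000.0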
@@ -0,0 +1,284 @@
+import gzip
+import pickle
+import re
+from collections import defaultdict
+from itertools import chain
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from lattifai.base_client import SyncAPIClient
+from lattifai.io import Supervision
+from lattifai.tokenizer.phonemizer import G2Phonemizer
+
+PUNCTUATION = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~'
+END_PUNCTUATION = '.!?"]。!?”】'
+PUNCTUATION_SPACE = PUNCTUATION + ' '
+STAR_TOKEN = '※'
+
+GROUPING_SEPARATOR = '✹'
+
+MAXIMUM_WORD_LENGTH = 40
+
+
+class LatticeTokenizer:
+    """Tokenizer for converting Lhotse Cut to LatticeGraph."""
+
+    def __init__(self, client_wrapper: SyncAPIClient):
+        self.client_wrapper = client_wrapper
+        self.words: List[str] = []
+        self.g2p_model: Any = None  # Placeholder for G2P model
+        self.dictionaries = defaultdict(lambda: [])
+        self.oov_word = '<unk>'
+        self.sentence_splitter = None
+        self.device = 'cpu'
+
+    def init_sentence_splitter(self):
+        if self.sentence_splitter is not None:
+            return
+
+        import onnxruntime as ort
+        from wtpsplit import SaT
+
+        providers = []
+        device = self.device
+        if device.startswith('cuda') and ort.get_all_providers().count('CUDAExecutionProvider') > 0:
+            providers.append('CUDAExecutionProvider')
+        elif device.startswith('mps') and ort.get_all_providers().count('MPSExecutionProvider') > 0:
+            providers.append('MPSExecutionProvider')
+
+        sat = SaT(
+            'sat-3l-sm',
+            ort_providers=providers + ['CPUExecutionProvider'],
+        )
+        self.sentence_splitter = sat
+
+    @staticmethod
+    def _resplit_special_sentence_types(sentence: str) -> List[str]:
+        """
+        Re-split special sentence types.
+
+        Examples:
+            '[APPLAUSE] &gt;&gt; MIRA MURATI:' -> ['[APPLAUSE]', '&gt;&gt; MIRA MURATI:']
+            '[MUSIC] &gt;&gt; SPEAKER:' -> ['[MUSIC]', '&gt;&gt; SPEAKER:']
+
+        Special handling patterns:
+        1. Separate special marks at the beginning (e.g., [APPLAUSE], [MUSIC], etc.) from subsequent speaker marks
+        2. Use speaker marks (&gt;&gt; or other separators) as split points
+
+        Args:
+            sentence: Input sentence string
+
+        Returns:
+            List of re-split sentences. If no special marks are found, returns the original sentence in a list
+        """
+        # Detect special mark patterns: [SOMETHING] &gt;&gt; SPEAKER:
+        # or other forms like [SOMETHING] SPEAKER:
+
+        # Pattern 1: [mark] HTML-encoded separator speaker:
+        pattern1 = r'^(\[[^\]]+\])\s+(&gt;&gt;|>>)\s+(.+)$'
+        match1 = re.match(pattern1, sentence.strip())
+        if match1:
+            special_mark = match1.group(1)
+            separator = match1.group(2)
+            speaker_part = match1.group(3)
+            return [special_mark, f'{separator} {speaker_part}']
+
+        # Pattern 2: [mark] speaker:
+        pattern2 = r'^(\[[^\]]+\])\s+([^:]+:)(.*)$'
+        match2 = re.match(pattern2, sentence.strip())
+        if match2:
+            special_mark = match2.group(1)
+            speaker_label = match2.group(2)
+            remaining = match2.group(3).strip()
+            if remaining:
+                return [special_mark, f'{speaker_label} {remaining}']
+            else:
+                return [special_mark, speaker_label]
+
+        # If no special pattern matches, return the original sentence
+        return [sentence]
+
+    @staticmethod
+    def from_pretrained(
+        client_wrapper: SyncAPIClient,
+        model_path: str,
+        device: str = 'cpu',
+        compressed: bool = True,
+    ):
+        """Load tokenizer from exported binary file"""
+        from pathlib import Path
+
+        words_model_path = f'{model_path}/words.bin'
+        if compressed:
+            with gzip.open(words_model_path, 'rb') as f:
+                data = pickle.load(f)
+        else:
+            with open(words_model_path, 'rb') as f:
+                data = pickle.load(f)
+
+        tokenizer = LatticeTokenizer(client_wrapper=client_wrapper)
+        tokenizer.words = data['words']
+        tokenizer.dictionaries = defaultdict(list, data['dictionaries'])
+        tokenizer.oov_word = data['oov_word']
+
+        g2p_model_path = f'{model_path}/g2p.bin' if Path(f'{model_path}/g2p.bin').exists() else None
+        if g2p_model_path:
+            tokenizer.g2p_model = G2Phonemizer(g2p_model_path, device=device)
+
+        tokenizer.device = device
+        tokenizer.add_special_tokens()
+        return tokenizer
+
+    def add_special_tokens(self):
+        tokenizer = self
+        for special_token in ['&gt;&gt;', '&gt;']:
+            if special_token not in tokenizer.dictionaries:
+                tokenizer.dictionaries[special_token] = tokenizer.dictionaries[tokenizer.oov_word]
+        return self
+
+    def prenormalize(self, texts: List[str], language: Optional[str] = None) -> List[str]:
+        if not self.g2p_model:
+            raise ValueError('G2P model is not loaded, cannot prenormalize texts')
+
+        oov_words = []
+        for text in texts:
+            words = text.lower().replace('-', ' ').replace('—', ' ').replace('–', ' ').split()
+            oovs = [w for w in words if w not in self.words]
+            if oovs:
+                oov_words.extend([w for w in oovs if (w not in self.words and len(w) <= MAXIMUM_WORD_LENGTH)])
+
+        oov_words = list(set(oov_words))
+        if oov_words:
+            indexs = []
+            for k, _word in enumerate(oov_words):
+                if any(_word.startswith(p) and _word.endswith(q) for (p, q) in [('(', ')'), ('[', ']')]):
+                    self.dictionaries[_word] = self.dictionaries[self.oov_word]
+                else:
+                    _word = _word.strip(PUNCTUATION_SPACE)
+                    if not _word or _word in self.words:
+                        indexs.append(k)
+            for idx in sorted(indexs, reverse=True):
+                del oov_words[idx]
+
+            g2p_words = [w for w in oov_words if w not in self.dictionaries]
+            if g2p_words:
+                predictions = self.g2p_model(words=g2p_words, lang=language, batch_size=len(g2p_words), num_prons=4)
+                for _word, _predictions in zip(g2p_words, predictions):
+                    for pronuncation in _predictions:
+                        if pronuncation and pronuncation not in self.dictionaries[_word]:
+                            self.dictionaries[_word].append(pronuncation)
+                    if not self.dictionaries[_word]:
+                        self.dictionaries[_word] = self.dictionaries[self.oov_word]
+
+            pronunciation_dictionaries: Dict[str, List[List[str]]] = {
+                w: self.dictionaries[w] for w in oov_words if self.dictionaries[w]
+            }
+            return pronunciation_dictionaries
+
+        return {}
+
+    def split_sentences(self, supervisions: List[Supervision], strip_whitespace=True) -> List[str]:
+        texts, text_len, sidx = [], 0, 0
+        for s, supervision in enumerate(supervisions):
+            text_len += len(supervision.text)
+            if text_len >= 2000 or s == len(supervisions) - 1:
+                text = ' '.join([sup.text for sup in supervisions[sidx : s + 1]])
+                texts.append(text)
+                sidx = s + 1
+                text_len = 0
+        if sidx < len(supervisions):
+            text = ' '.join([sup.text for sup in supervisions[sidx:]])
+            texts.append(text)
+        sentences = self.sentence_splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace)
+
+        supervisions, remainder = [], ''
+        for _sentences in sentences:
+            # Process and re-split special sentence types
+            processed_sentences = []
+            for s, _sentence in enumerate(_sentences):
+                if remainder:
+                    _sentence = remainder + _sentence
+                    remainder = ''
+
+                # Detect and split special sentence types: e.g., '[APPLAUSE] &gt;&gt; MIRA MURATI:' -> ['[APPLAUSE]', '&gt;&gt; MIRA MURATI:']  # noqa: E501
+                resplit_parts = self._resplit_special_sentence_types(_sentence)
+                if any(resplit_parts[-1].endswith(sp) for sp in [':', ':']):
+                    if s < len(_sentences) - 1:
+                        _sentences[s + 1] = resplit_parts[-1] + ' ' + _sentences[s + 1]
+                    else:  # last part
+                        remainder = resplit_parts[-1] + ' ' + remainder
+                    processed_sentences.extend(resplit_parts[:-1])
+                else:
+                    processed_sentences.extend(resplit_parts)
+
+            _sentences = processed_sentences
+
+            if remainder:
+                _sentences[0] = remainder + _sentences[0]
+                remainder = ''
+
+            if any(_sentences[-1].endswith(ep) for ep in END_PUNCTUATION):
+                supervisions.extend(Supervision(text=s) for s in _sentences)
+            else:
+                supervisions.extend(Supervision(text=s) for s in _sentences[:-1])
+                remainder += _sentences[-1] + ' '
+
+        if remainder.strip():
+            supervisions.append(Supervision(text=remainder.strip()))
+
+        return supervisions
+
+    def tokenize(self, supervisions: List[Supervision], split_sentence: bool = False) -> Tuple[str, Dict[str, Any]]:
+        if split_sentence:
+            self.init_sentence_splitter()
+            supervisions = self.split_sentences(supervisions)
+
+        pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions])
+        response = self.client_wrapper.post(
+            'tokenize',
+            json={
+                'supervisions': [s.to_dict() for s in supervisions],
+                'pronunciation_dictionaries': pronunciation_dictionaries,
+            },
+        )
+        if response.status_code != 200:
+            raise Exception(f'Failed to tokenize texts: {response.text}')
+        result = response.json()
+        lattice_id = result['id']
+        return lattice_id, (result['lattice_graph'], result['final_state'], result.get('acoustic_scale', 1.0))
+
+    def detokenize(
+        self,
+        lattice_id: str,
+        lattice_results: Tuple[torch.Tensor, Any, Any, float, float],
+        # return_supervisions: bool = True,
+        # return_details: bool = False,
+    ) -> List[Supervision]:
+        emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
+        response = self.client_wrapper.post(
+            'detokenize',
+            json={
+                'lattice_id': lattice_id,
+                'frame_shift': frame_shift,
+                'results': [t.to_dict() for t in results[0]],
+                'labels': labels[0],
+                'offset': offset,
+                'channel': channel,
+                'destroy_lattice': True,
+            },
+        )
+        if response.status_code != 200:
+            raise Exception(f'Failed to detokenize lattice: {response.text}')
+        result = response.json()
+        # if return_details:
+        #     raise NotImplementedError("return_details is not implemented yet")
+        return [Supervision.from_dict(s) for s in result['supervisions']]
+
+
+# Compute average score weighted by the span length
+def _score(spans):
+    if not spans:
+        return 0.0
+    # TokenSpan(token=token, start=start, end=end, score=scores[start:end].mean().item())
+    return round(sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans), ndigits=4)
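A usage sketch for the new tokenizer, using only names defined in the file above; the client object and model directory are placeholders:

    from lattifai.tokenizer import LatticeTokenizer

    tokenizer = LatticeTokenizer.from_pretrained(
        client_wrapper=client,        # an existing SyncAPIClient (placeholder)
        model_path='/path/to/model',  # must contain words.bin; g2p.bin is optional
        device='cpu',
    )
    lattice_id, (graph, final_state, acoustic_scale) = tokenizer.tokenize(
        supervisions, split_sentence=True  # re-segment by semantics before tokenizing
    )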
@@ -17,8 +17,6 @@ class Lattice1AlphaWorker:
     """Worker for processing audio with LatticeGraph."""
 
     def __init__(self, model_path: Pathlike, device: str = 'cpu', num_threads: int = 8) -> None:
-        if device != 'cpu':
-            raise NotImplementedError(f'Only cpu is supported for now, got device={device}.')
         self.config = json.load(open(f'{model_path}/config.json'))
 
         # SessionOptions
@@ -29,8 +27,11 @@ class Lattice1AlphaWorker:
         sess_options.add_session_config_entry('session.intra_op.allow_spinning', '0')
 
         providers = []
-        if device.startswith('cuda') or ort.get_all_providers().count('CUDAExecutionProvider') > 0:
+        if device.startswith('cuda') and ort.get_all_providers().count('CUDAExecutionProvider') > 0:
             providers.append('CUDAExecutionProvider')
+        elif device.startswith('mps') and ort.get_all_providers().count('MPSExecutionProvider') > 0:
+            providers.append('MPSExecutionProvider')
+
         self.acoustic_ort = ort.InferenceSession(
             f'{model_path}/acoustic_opt.onnx',
             sess_options,
@@ -49,13 +50,29 @@ class Lattice1AlphaWorker:
         _start = time.time()
         # audio -> features -> emission
         features = self.extractor(audio)  # (1, T, D)
-        ort_inputs = {
-            'features': features.cpu().numpy(),
-            'feature_lengths': np.array([features.size(1)], dtype=np.int64),
-        }
-        emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
+        if features.shape[1] > 6000:
+            features_list = torch.split(features, 6000, dim=1)
+            emissions = []
+            for features in features_list:
+                ort_inputs = {
+                    'features': features.cpu().numpy(),
+                    'feature_lengths': np.array([features.size(1)], dtype=np.int64),
+                }
+                emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
+                emissions.append(emission)
+            emission = torch.cat(
+                [torch.from_numpy(emission).to(self.device) for emission in emissions], dim=1
+            )  # (1, T, vocab_size)
+        else:
+            ort_inputs = {
+                'features': features.cpu().numpy(),
+                'feature_lengths': np.array([features.size(1)], dtype=np.int64),
+            }
+            emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
+            emission = torch.from_numpy(emission).to(self.device)
+
         self.timings['emission'] += time.time() - _start
-        return torch.from_numpy(emission).to(self.device)  # (1, T, vocab_size) torch
+        return emission  # (1, T, vocab_size) torch
 
     def load_audio(self, audio: Union[Pathlike, BinaryIO]) -> Tuple[torch.Tensor, int]:
         # load audio
@@ -104,9 +121,14 @@ class Lattice1AlphaWorker:
         self.timings['decoding_graph'] += time.time() - _start
 
         _start = time.time()
+        if self.device.type == 'mps':
+            device = 'cpu'  # k2 does not support mps yet
+        else:
+            device = self.device
+
         results, labels = align_segments(
-            emission.to(self.device) * acoustic_scale,
-            decoding_graph.to(self.device),
+            emission.to(device) * acoustic_scale,
+            decoding_graph.to(device),
             torch.tensor([emission.shape[1]], dtype=torch.int32),
             search_beam=100,
             output_beam=40,
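The long-audio branch above bounds ONNX memory use by splitting the feature sequence into 6000-frame chunks and concatenating the per-chunk emissions. A minimal standalone sketch of the same pattern, where `run_session` stands in for the ONNX Runtime call:

    import torch

    def chunked_emission(features: torch.Tensor, run_session, chunk: int = 6000) -> torch.Tensor:
        # features: (1, T, D); run_session maps a chunk to its (1, t, vocab) emission
        if features.shape[1] <= chunk:
            return run_session(features)
        parts = [run_session(piece) for piece in torch.split(features, chunk, dim=1)]
        return torch.cat(parts, dim=1)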