lattifai 0.2.4__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lattifai/io/writer.py CHANGED
@@ -1,49 +1,90 @@
+import json
 from abc import ABCMeta
-from typing import List
+from typing import Any, List, Optional
 
+import pysubs2
+from lhotse.supervision import AlignmentItem
 from lhotse.utils import Pathlike
 
-from .reader import SubtitleFormat, Supervision
+from .reader import Supervision
 
 
 class SubtitleWriter(ABCMeta):
-    """Class for writing subtitle files."""
+    """Class for writing subtitle files with optional word-level alignment."""
 
     @classmethod
     def write(cls, alignments: List[Supervision], output_path: Pathlike) -> Pathlike:
         if str(output_path)[-4:].lower() == '.txt':
             with open(output_path, 'w', encoding='utf-8') as f:
                 for sup in alignments:
-                    f.write(f'{sup.text}\n')
+                    word_items = parse_alignment_from_supervision(sup)
+                    if word_items:
+                        for item in word_items:
+                            f.write(f'[{item.start:.2f}-{item.end:.2f}] {item.symbol}\n')
+                    else:
+                        text = f'{sup.speaker} {sup.text}' if sup.speaker is not None else sup.text
+                        f.write(f'[{sup.start:.2f}-{sup.end:.2f}] {text}\n')
+
         elif str(output_path)[-5:].lower() == '.json':
             with open(output_path, 'w', encoding='utf-8') as f:
-                import json
-
-                json.dump([sup.to_dict() for sup in alignments], f, ensure_ascii=False, indent=4)
+                # Enhanced JSON export with word-level alignment
+                json_data = []
+                for sup in alignments:
+                    sup_dict = sup.to_dict()
+                    json_data.append(sup_dict)
+                json.dump(json_data, f, ensure_ascii=False, indent=4)
         elif str(output_path).endswith('.TextGrid') or str(output_path).endswith('.textgrid'):
             from tgt import Interval, IntervalTier, TextGrid, write_to_file
 
             tg = TextGrid()
             supervisions, words = [], []
             for supervision in sorted(alignments, key=lambda x: x.start):
-                supervisions.append(Interval(supervision.start, supervision.end, supervision.text or ''))
-                if supervision.alignment and 'word' in supervision.alignment:
-                    for alignment in supervision.alignment['word']:
-                        words.append(Interval(alignment.start, alignment.end, alignment.symbol))
+                text = (
+                    f'{supervision.speaker} {supervision.text}' if supervision.speaker is not None else supervision.text
+                )
+                supervisions.append(Interval(supervision.start, supervision.end, text or ''))
+                # Extract word-level alignment using helper function
+                word_items = parse_alignment_from_supervision(supervision)
+                if word_items:
+                    for item in word_items:
+                        words.append(Interval(item.start, item.end, item.symbol))
 
             tg.add_tier(IntervalTier(name='utterances', objects=supervisions))
             if words:
                 tg.add_tier(IntervalTier(name='words', objects=words))
             write_to_file(tg, output_path, format='long')
         else:
-            import pysubs2
-
             subs = pysubs2.SSAFile()
             for sup in alignments:
-                start = int(sup.start * 1000)
-                end = int(sup.end * 1000)
-                text = sup.text or ''
-                subs.append(pysubs2.SSAEvent(start=start, end=end, text=text))
+                # Add word-level timing as metadata in the subtitle text
+                word_items = parse_alignment_from_supervision(sup)
+                if word_items:
+                    for word in word_items:
+                        subs.append(
+                            pysubs2.SSAEvent(start=int(word.start * 1000), end=int(word.end * 1000), text=word.symbol)
+                        )
+                else:
+                    text = f'{sup.speaker} {sup.text}' if sup.speaker is not None else sup.text
+                    subs.append(pysubs2.SSAEvent(start=int(sup.start * 1000), end=int(sup.end * 1000), text=text or ''))
            subs.save(output_path)
 
        return output_path
+
+
+def parse_alignment_from_supervision(supervision: Any) -> Optional[List[AlignmentItem]]:
+    """
+    Extract word-level alignment items from a Supervision object.
+
+    Args:
+        supervision: Supervision object with potential alignment data
+
+    Returns:
+        List of AlignmentItem objects, or None if no alignment data is present
+    """
+    if not hasattr(supervision, 'alignment') or not supervision.alignment:
+        return None
+
+    if 'word' not in supervision.alignment:
+        return None
+
+    return supervision.alignment['word']
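
Note: the rewritten writer emits one timestamped line per aligned word when word-level alignment is present, and falls back to speaker-prefixed segment lines otherwise. A minimal sketch of the .txt path, assuming lattifai.io's Supervision accepts text/start/end/speaker/alignment keyword arguments (lhotse's AlignmentItem stores symbol, start, and duration; its end is derived):

from lhotse.supervision import AlignmentItem
from lattifai.io import Supervision
from lattifai.io.writer import SubtitleWriter

sup = Supervision(  # hypothetical example data
    text='hello world', start=0.0, end=1.2, speaker='A',
    alignment={'word': [
        AlignmentItem(symbol='hello', start=0.0, duration=0.5),
        AlignmentItem(symbol='world', start=0.6, duration=0.6),
    ]},
)
SubtitleWriter.write([sup], 'out.txt')
# out.txt: one '[0.00-0.50] hello'-style line per word; without an
# alignment it would fall back to '[0.00-1.20] A hello world'.
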
lattifai/tokenizer/__init__.py CHANGED
@@ -1,3 +1,3 @@
-from .tokenizer import LatticeTokenizer
+from .tokenizer import AsyncLatticeTokenizer, LatticeTokenizer
 
-__all__ = ['LatticeTokenizer']
+__all__ = ['LatticeTokenizer', 'AsyncLatticeTokenizer']
lattifai/tokenizer/tokenizer.py CHANGED
@@ -1,13 +1,13 @@
 import gzip
+import inspect
 import pickle
 import re
 from collections import defaultdict
-from itertools import chain
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
 
 import torch
 
-from lattifai.base_client import SyncAPIClient
+from lattifai.errors import LATTICE_DECODING_FAILURE_HELP, LatticeDecodingError
 from lattifai.io import Supervision
 from lattifai.tokenizer.phonemizer import G2Phonemizer
 
@@ -21,10 +21,13 @@ GROUPING_SEPARATOR = '✹'
 MAXIMUM_WORD_LENGTH = 40
 
 
+TokenizerT = TypeVar('TokenizerT', bound='LatticeTokenizer')
+
+
 class LatticeTokenizer:
     """Tokenizer for converting Lhotse Cut to LatticeGraph."""
 
-    def __init__(self, client_wrapper: SyncAPIClient):
+    def __init__(self, client_wrapper: Any):
         self.client_wrapper = client_wrapper
         self.words: List[str] = []
         self.g2p_model: Any = None  # Placeholder for G2P model
@@ -99,13 +102,14 @@ class LatticeTokenizer:
         # If no special pattern matches, return the original sentence
         return [sentence]
 
-    @staticmethod
+    @classmethod
     def from_pretrained(
-        client_wrapper: SyncAPIClient,
+        cls: Type[TokenizerT],
+        client_wrapper: Any,
         model_path: str,
         device: str = 'cpu',
         compressed: bool = True,
-    ):
+    ) -> TokenizerT:
         """Load tokenizer from exported binary file"""
         from pathlib import Path
 
@@ -117,7 +121,7 @@ class LatticeTokenizer:
         with open(words_model_path, 'rb') as f:
             data = pickle.load(f)
 
-        tokenizer = LatticeTokenizer(client_wrapper=client_wrapper)
+        tokenizer = cls(client_wrapper=client_wrapper)
         tokenizer.words = data['words']
         tokenizer.dictionaries = defaultdict(list, data['dictionaries'])
         tokenizer.oov_word = data['oov_word']
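
The move from @staticmethod to @classmethod with a TokenizerT TypeVar makes from_pretrained construct and return an instance of whichever class it is called on, so the AsyncLatticeTokenizer subclass added below inherits a correctly typed loader for free. The pattern in isolation (a generic sketch, not lattifai code):

from typing import Type, TypeVar

T = TypeVar('T', bound='Base')

class Base:
    @classmethod
    def from_pretrained(cls: Type[T], path: str) -> T:
        obj = cls()  # instantiates the calling class, not Base
        obj.path = path
        return obj

class Child(Base):
    pass

assert isinstance(Child.from_pretrained('model'), Child)
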
@@ -179,53 +183,89 @@ class LatticeTokenizer:
         return {}
 
     def split_sentences(self, supervisions: List[Supervision], strip_whitespace=True) -> List[Supervision]:
+        """Split supervisions into sentences using the sentence splitter.
+
+        Be careful about speaker changes.
+        """
         texts, text_len, sidx = [], 0, 0
+        speakers = []
         for s, supervision in enumerate(supervisions):
             text_len += len(supervision.text)
-            if text_len >= 2000 or s == len(supervisions) - 1:
-                text = ' '.join([sup.text for sup in supervisions[sidx : s + 1]])
-                texts.append(text)
-                sidx = s + 1
-                text_len = 0
-        if sidx < len(supervisions):
-            text = ' '.join([sup.text for sup in supervisions[sidx:]])
-            texts.append(text)
+            if supervision.speaker:
+                speakers.append(supervision.speaker)
+                if sidx < s:
+                    text = ' '.join([sup.text for sup in supervisions[sidx:s]])
+                    texts.append(text)
+                sidx = s
+                text_len = len(supervision.text)
+            else:
+                if text_len >= 2000 or s == len(supervisions) - 1:
+                    if len(speakers) < len(texts) + 1:
+                        speakers.append(None)
+                    text = ' '.join([sup.text for sup in supervisions[sidx : s + 1]])
+                    texts.append(text)
+                    sidx = s + 1
+                    text_len = 0
+
+        assert len(speakers) == len(texts), f'len(speakers)={len(speakers)} != len(texts)={len(texts)}'
         sentences = self.sentence_splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace)
 
         supervisions, remainder = [], ''
-        for _sentences in sentences:
+        for k, (_speaker, _sentences) in enumerate(zip(speakers, sentences)):
+            # Prepend remainder from previous iteration to the first sentence
+            if _sentences and remainder:
+                _sentences[0] = remainder + _sentences[0]
+                remainder = ''
+
+            if not _sentences:
+                continue
+
             # Process and re-split special sentence types
             processed_sentences = []
             for s, _sentence in enumerate(_sentences):
                 if remainder:
                     _sentence = remainder + _sentence
                     remainder = ''
-
                 # Detect and split special sentence types: e.g., '[APPLAUSE] >> MIRA MURATI:' -> ['[APPLAUSE]', '>> MIRA MURATI:']  # noqa: E501
                 resplit_parts = self._resplit_special_sentence_types(_sentence)
                 if any(resplit_parts[-1].endswith(sp) for sp in [':', '：']):
                     if s < len(_sentences) - 1:
                         _sentences[s + 1] = resplit_parts[-1] + ' ' + _sentences[s + 1]
                     else:  # last part
-                        remainder = resplit_parts[-1] + ' ' + remainder
+                        remainder = resplit_parts[-1] + ' '
                     processed_sentences.extend(resplit_parts[:-1])
                 else:
                     processed_sentences.extend(resplit_parts)
-
             _sentences = processed_sentences
 
-            if remainder:
-                _sentences[0] = remainder + _sentences[0]
-                remainder = ''
-
             if any(_sentences[-1].endswith(ep) for ep in END_PUNCTUATION):
-                supervisions.extend(Supervision(text=s) for s in _sentences)
+                supervisions.extend(
+                    Supervision(text=text, speaker=(_speaker if s == 0 else None)) for s, text in enumerate(_sentences)
+                )
+                _speaker = None  # reset speaker after use
             else:
-                supervisions.extend(Supervision(text=s) for s in _sentences[:-1])
-                remainder += _sentences[-1] + ' '
+                supervisions.extend(
+                    Supervision(text=text, speaker=(_speaker if s == 0 else None))
+                    for s, text in enumerate(_sentences[:-1])
+                )
+                remainder = _sentences[-1] + ' ' + remainder
+                if k < len(speakers) - 1 and speakers[k + 1] is not None:  # next speaker is set
+                    supervisions.append(
+                        Supervision(text=remainder.strip(), speaker=_speaker if len(_sentences) == 1 else None)
+                    )
+                    remainder = ''
+                elif len(_sentences) == 1:
+                    if k == len(speakers) - 1:
+                        pass  # keep _speaker for the last supervision
+                    else:
+                        assert speakers[k + 1] is None
+                        speakers[k + 1] = _speaker
+                else:
+                    assert len(_sentences) > 1
+                    _speaker = None  # reset speaker if sentence not ended
 
         if remainder.strip():
-            supervisions.append(Supervision(text=remainder.strip()))
+            supervisions.append(Supervision(text=remainder.strip(), speaker=_speaker))
 
         return supervisions
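
The reworked split_sentences threads speaker labels through re-segmentation: a supervision that carries a speaker starts a new text chunk, the chunks are sentence-split, and each speaker is attached only to the first sentence of its chunk. An illustrative trace of the chunking stage (hypothetical inputs; assumes the Supervision constructor shown above):

sups = [
    Supervision(text='Hello everyone.', speaker='MIRA'),
    Supervision(text='Good to be here.', speaker='SAM'),
    Supervision(text='Thanks for coming.', speaker=None),
]
# chunk boundaries fall at speaker changes before sentence splitting:
#   texts    == ['Hello everyone.', 'Good to be here. Thanks for coming.']
#   speakers == ['MIRA', 'SAM']
# after splitting, 'MIRA' and 'SAM' tag only the first sentence of
# their chunks; subsequent sentences get speaker=None.
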
 
@@ -246,14 +286,18 @@ class LatticeTokenizer:
             raise Exception(f'Failed to tokenize texts: {response.text}')
         result = response.json()
         lattice_id = result['id']
-        return lattice_id, (result['lattice_graph'], result['final_state'], result.get('acoustic_scale', 1.0))
+        return (
+            supervisions,
+            lattice_id,
+            (result['lattice_graph'], result['final_state'], result.get('acoustic_scale', 1.0)),
+        )
 
     def detokenize(
         self,
         lattice_id: str,
         lattice_results: Tuple[torch.Tensor, Any, Any, float, float],
-        # return_supervisions: bool = True,
-        # return_details: bool = False,
+        supervisions: List[Supervision],
+        return_details: bool = False,
     ) -> List[Supervision]:
         emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
         response = self.client_wrapper.post(
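
tokenize now returns the (possibly re-split) supervisions along with the lattice id and graph, and detokenize takes those supervisions back so speaker labels survive the round trip (via _update_alignments_speaker below). A sketch of the updated call contract; client, raw_supervisions, and lattice_results are placeholders, and the acoustic decoding step between the two calls is elided:

tokenizer = LatticeTokenizer.from_pretrained(client, model_path='...', device='cpu')
supervisions, lattice_id, (graph, final_state, acoustic_scale) = tokenizer.tokenize(
    raw_supervisions, split_sentence=True
)
# ... run acoustic decoding over the lattice graph to obtain lattice_results ...
alignments = tokenizer.detokenize(
    lattice_id, lattice_results, supervisions=supervisions, return_details=True
)
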
@@ -265,20 +309,157 @@ class LatticeTokenizer:
                 'labels': labels[0],
                 'offset': offset,
                 'channel': channel,
+                'return_details': return_details,
                 'destroy_lattice': True,
             },
         )
+        if response.status_code == 422:
+            raise LatticeDecodingError(
+                lattice_id,
+                original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
+            )
         if response.status_code != 200:
             raise Exception(f'Failed to detokenize lattice: {response.text}')
+
         result = response.json()
-        # if return_details:
-        #     raise NotImplementedError("return_details is not implemented yet")
-        return [Supervision.from_dict(s) for s in result['supervisions']]
+        if not result.get('success'):
+            raise Exception('Failed to detokenize the alignment results.')
+
+        alignments = [Supervision.from_dict(s) for s in result['supervisions']]
+
+        if return_details:
+            # Add emission confidence scores for segments and word-level alignments
+            _add_confidence_scores(alignments, emission, labels[0], frame_shift)
+
+        alignments = _update_alignments_speaker(supervisions, alignments)
+
+        return alignments
+
 
+class AsyncLatticeTokenizer(LatticeTokenizer):
+    async def _post_async(self, endpoint: str, **kwargs):
+        response = self.client_wrapper.post(endpoint, **kwargs)
+        if inspect.isawaitable(response):
+            return await response
+        return response
 
-# Compute average score weighted by the span length
-def _score(spans):
-    if not spans:
-        return 0.0
-    # TokenSpan(token=token, start=start, end=end, score=scores[start:end].mean().item())
-    return round(sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans), ndigits=4)
+    async def tokenize(
+        self, supervisions: List[Supervision], split_sentence: bool = False
+    ) -> Tuple[List[Supervision], str, Tuple[Any, Any, float]]:
+        if split_sentence:
+            self.init_sentence_splitter()
+            supervisions = self.split_sentences(supervisions)
+
+        pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions])
+        response = await self._post_async(
+            'tokenize',
+            json={
+                'supervisions': [s.to_dict() for s in supervisions],
+                'pronunciation_dictionaries': pronunciation_dictionaries,
+            },
+        )
+        if response.status_code != 200:
+            raise Exception(f'Failed to tokenize texts: {response.text}')
+        result = response.json()
+        lattice_id = result['id']
+        return (
+            supervisions,
+            lattice_id,
+            (result['lattice_graph'], result['final_state'], result.get('acoustic_scale', 1.0)),
+        )
+
+    async def detokenize(
+        self,
+        lattice_id: str,
+        lattice_results: Tuple[torch.Tensor, Any, Any, float, float],
+        supervisions: List[Supervision],
+        return_details: bool = False,
+    ) -> List[Supervision]:
+        emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
+        response = await self._post_async(
+            'detokenize',
+            json={
+                'lattice_id': lattice_id,
+                'frame_shift': frame_shift,
+                'results': [t.to_dict() for t in results[0]],
+                'labels': labels[0],
+                'offset': offset,
+                'channel': channel,
+                'return_details': return_details,
+                'destroy_lattice': True,
+            },
+        )
+        if response.status_code == 422:
+            raise LatticeDecodingError(
+                lattice_id,
+                original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
+            )
+        if response.status_code != 200:
+            raise Exception(f'Failed to detokenize lattice: {response.text}')
+
+        result = response.json()
+        if not result.get('success'):
+            raise Exception('Failed to detokenize the alignment results.')
+
+        alignments = [Supervision.from_dict(s) for s in result['supervisions']]
+
+        if return_details:
+            # Add emission confidence scores for segments and word-level alignments
+            _add_confidence_scores(alignments, emission, labels[0], frame_shift)
+
+        alignments = _update_alignments_speaker(supervisions, alignments)
+
+        return alignments
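
AsyncLatticeTokenizer overrides only the transport: _post_async awaits the client's response when it is awaitable and passes it through otherwise, so the same subclass works with both sync and async client wrappers. Hypothetical usage (client and the decoding step are placeholders):

import asyncio

async def align(client, raw_supervisions, lattice_results):
    tokenizer = AsyncLatticeTokenizer.from_pretrained(client, model_path='...', device='cpu')
    supervisions, lattice_id, graph = await tokenizer.tokenize(raw_supervisions, split_sentence=True)
    return await tokenizer.detokenize(lattice_id, lattice_results, supervisions=supervisions)

# asyncio.run(align(client, raw_supervisions, lattice_results))
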
+
+
+def _add_confidence_scores(
+    supervisions: List[Supervision],
+    emission: torch.Tensor,
+    labels: List[int],
+    frame_shift: float,
+) -> None:
+    """
+    Add confidence scores to supervisions and their word-level alignments.
+
+    This function modifies supervisions in-place by:
+    1. Computing segment-level confidence scores based on emission probabilities
+    2. Computing word-level confidence scores for each aligned word
+
+    Args:
+        supervisions: List of Supervision objects to add scores to (modified in-place)
+        emission: Emission tensor with shape [batch, time, vocab_size]
+        labels: Token labels corresponding to aligned tokens
+        frame_shift: Frame shift in seconds for converting frames to time
+    """
+    tokens = torch.tensor(labels, dtype=torch.int64, device=emission.device)
+
+    for supervision in supervisions:
+        start_frame = int(supervision.start / frame_shift)
+        end_frame = int(supervision.end / frame_shift)
+
+        # Compute segment-level confidence
+        probabilities = emission[0, start_frame:end_frame].softmax(dim=-1)
+        aligned = probabilities[range(0, end_frame - start_frame), tokens[start_frame:end_frame]]
+        diffprobs = (probabilities.max(dim=-1).values - aligned).cpu()
+        supervision.score = round(1.0 - diffprobs.mean().item(), ndigits=4)
+
+        # Compute word-level confidence if alignment exists
+        if hasattr(supervision, 'alignment') and supervision.alignment:
+            words = supervision.alignment.get('word', [])
+            for w, item in enumerate(words):
+                start = int(item.start / frame_shift) - start_frame
+                end = int(item.end / frame_shift) - start_frame
+                words[w] = item._replace(score=round(1.0 - diffprobs[start:end].mean().item(), ndigits=4))
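
The confidence heuristic compares, frame by frame, the probability of the aligned token with that of the most probable token: score = 1 - mean(max_prob - aligned_prob), so a segment scores 1.0 exactly when the aligned token is the argmax on every frame. A self-contained numeric sketch with toy tensors (not real worker emissions):

import torch

emission = torch.tensor([[[2.0, 0.5, 0.1],
                          [0.2, 1.5, 0.3]]])  # [batch=1, time=2, vocab=3]
tokens = torch.tensor([0, 1])  # aligned token per frame
probs = emission[0].softmax(dim=-1)
aligned = probs[range(2), tokens]
diff = probs.max(dim=-1).values - aligned  # zero where aligned token is the argmax
score = round(1.0 - diff.mean().item(), 4)  # 1.0 here
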
+
+
+def _update_alignments_speaker(supervisions: List[Supervision], alignments: List[Supervision]) -> List[Supervision]:
+    """
+    Update the speaker attribute for a list of supervisions.
+
+    Args:
+        supervisions: List of Supervision objects to get speaker info from
+        alignments: List of aligned Supervision objects to copy speaker info onto
+    """
+    for supervision, alignment in zip(supervisions, alignments):
+        alignment.speaker = supervision.speaker
+    return alignments
lattifai/utils.py ADDED
@@ -0,0 +1,133 @@
+"""Shared utility helpers for the LattifAI SDK."""
+
+import os
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import Any, Optional, Type
+
+from lattifai.errors import ModelLoadError
+from lattifai.tokenizer import LatticeTokenizer
+from lattifai.workers import Lattice1AlphaWorker
+
+
+def _get_cache_marker_path(cache_dir: Path) -> Path:
+    """Get the path for the cache marker file with current date."""
+    today = datetime.now().strftime('%Y%m%d')
+    return cache_dir / f'.done{today}'
+
+
+def _is_cache_valid(cache_dir: Path) -> bool:
+    """Check if cached model is valid (exists and not older than 1 day)."""
+    if not cache_dir.exists():
+        return False
+
+    # Find any .done* marker files
+    marker_files = list(cache_dir.glob('.done*'))
+    if not marker_files:
+        return False
+
+    # Get the most recent marker file
+    latest_marker = max(marker_files, key=lambda p: p.stat().st_mtime)
+
+    # Extract date from marker filename (format: .doneYYYYMMDD)
+    try:
+        date_str = latest_marker.name.replace('.done', '')
+        marker_date = datetime.strptime(date_str, '%Y%m%d')
+        # Check if marker is older than 1 day
+        if datetime.now() - marker_date > timedelta(days=1):
+            return False
+        return True
+    except (ValueError, IndexError):
+        # Invalid marker file format, treat as invalid cache
+        return False
+
+
+def _create_cache_marker(cache_dir: Path) -> None:
+    """Create a cache marker file with current date and clean old markers."""
+    # Remove old marker files
+    for old_marker in cache_dir.glob('.done*'):
+        old_marker.unlink(missing_ok=True)
+
+    # Create new marker file
+    marker_path = _get_cache_marker_path(cache_dir)
+    marker_path.touch()
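
The marker scheme stores the check date in the filename itself (.doneYYYYMMDD), so validity is a filename comparison with no file contents to read. A quick round-trip sketch using the helpers above:

import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    cache_dir = Path(tmp)
    assert not _is_cache_valid(cache_dir)  # no .doneYYYYMMDD marker yet
    _create_cache_marker(cache_dir)  # writes .done<today>, removing older markers
    assert _is_cache_valid(cache_dir)  # a fresh marker stays valid for one day
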
+
+
+def _resolve_model_path(model_name_or_path: str) -> str:
+    """Resolve model path, downloading from Hugging Face when necessary."""
+    if Path(model_name_or_path).exists():
+        return model_name_or_path
+
+    from huggingface_hub import snapshot_download
+    from huggingface_hub.constants import HF_HUB_CACHE
+    from huggingface_hub.errors import LocalEntryNotFoundError
+
+    # Determine cache directory for this model
+    cache_dir = Path(HF_HUB_CACHE) / f'models--{model_name_or_path.replace("/", "--")}'
+
+    # Check if we have a valid cached version
+    if _is_cache_valid(cache_dir):
+        # Return the snapshot path (latest version)
+        snapshots_dir = cache_dir / 'snapshots'
+        if snapshots_dir.exists():
+            snapshot_dirs = [d for d in snapshots_dir.iterdir() if d.is_dir()]
+            if snapshot_dirs:
+                # Return the most recent snapshot
+                latest_snapshot = max(snapshot_dirs, key=lambda p: p.stat().st_mtime)
+                return str(latest_snapshot)
+
+    try:
+        downloaded_path = snapshot_download(repo_id=model_name_or_path, repo_type='model')
+        _create_cache_marker(cache_dir)
+        return downloaded_path
+    except LocalEntryNotFoundError:
+        try:
+            os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
+            downloaded_path = snapshot_download(repo_id=model_name_or_path, repo_type='model')
+            _create_cache_marker(cache_dir)
+            return downloaded_path
+        except Exception as e:  # pragma: no cover - bubble up for caller context
+            raise ModelLoadError(model_name_or_path, original_error=e)
+    except Exception as e:  # pragma: no cover - unexpected download issue
+        raise ModelLoadError(model_name_or_path, original_error=e)
+
+
+def _select_device(device: Optional[str]) -> str:
+    """Select best available torch device when not explicitly provided."""
+    if device:
+        return device
+
+    import torch
+
+    detected = 'cpu'
+    if torch.backends.mps.is_available():
+        detected = 'mps'
+    elif torch.cuda.is_available():
+        detected = 'cuda'
+    return detected
+
+
+def _load_tokenizer(
+    client_wrapper: Any,
+    model_path: str,
+    device: str,
+    *,
+    tokenizer_cls: Type[LatticeTokenizer] = LatticeTokenizer,
+) -> LatticeTokenizer:
+    """Instantiate tokenizer with consistent error handling."""
+    try:
+        return tokenizer_cls.from_pretrained(
+            client_wrapper=client_wrapper,
+            model_path=model_path,
+            device=device,
+        )
+    except Exception as e:
+        raise ModelLoadError(f'tokenizer from {model_path}', original_error=e)
+
+
+def _load_worker(model_path: str, device: str) -> Lattice1AlphaWorker:
+    """Instantiate lattice worker with consistent error handling."""
+    try:
+        return Lattice1AlphaWorker(model_path, device=device, num_threads=8)
+    except Exception as e:
+        raise ModelLoadError(f'worker from {model_path}', original_error=e)
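
These helpers are presumably composed by the SDK's client entry points when loading models; a hedged sketch of the intended flow, where client_wrapper and the repo id are placeholders:

model_path = _resolve_model_path('org/model-repo')  # local path or Hugging Face repo id
device = _select_device(None)  # picks 'mps' over 'cuda' over 'cpu'
tokenizer = _load_tokenizer(client_wrapper, model_path, device)
worker = _load_worker(model_path, device)

Note the preference order in _select_device: the MPS backend is checked before CUDA, which matters on machines where both report as available.
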