lattifai 0.2.5__py3-none-any.whl → 0.4.1__py3-none-any.whl

lattifai/io/reader.py CHANGED
@@ -7,6 +7,22 @@ from lhotse.utils import Seconds
 
 @dataclass
 class Supervision(SupervisionSegment):
+    """
+    Extended SupervisionSegment with simplified initialization.
+
+    Note: The `alignment` field is inherited from SupervisionSegment:
+        alignment: Optional[Dict[str, List[AlignmentItem]]] = None
+
+    Structure of alignment when return_details=True:
+        {
+            'word': [
+                AlignmentItem(symbol='hello', start=0.0, duration=0.5, score=0.95),
+                AlignmentItem(symbol='world', start=0.6, duration=0.4, score=0.92),
+                ...
+            ]
+        }
+    """
+
     text: Optional[str] = None
     id: str = ''
     recording_id: str = ''
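A minimal sketch of reading the inherited alignment field after a detailed alignment run; the values mirror the docstring above and are illustrative, not real output:

    from lhotse.supervision import AlignmentItem

    # Hypothetical aligned supervision, shaped like the docstring example
    sup = Supervision(
        text='hello world',
        alignment={
            'word': [
                AlignmentItem(symbol='hello', start=0.0, duration=0.5, score=0.95),
                AlignmentItem(symbol='world', start=0.6, duration=0.4, score=0.92),
            ]
        },
    )
    for item in sup.alignment['word']:
        # AlignmentItem exposes end as start + duration
        print(f'{item.symbol}: {item.start:.2f}-{item.end:.2f} (score={item.score})')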
lattifai/io/utils.py ADDED
@@ -0,0 +1,15 @@
+"""
+Utility constants and helper functions for subtitle I/O operations
+"""
+
+# Supported subtitle formats for reading/writing
+SUBTITLE_FORMATS = ['srt', 'vtt', 'ass', 'ssa', 'sub', 'sbv', 'txt', 'md']
+
+# Input subtitle formats (includes special formats like 'auto' and 'gemini')
+INPUT_SUBTITLE_FORMATS = ['srt', 'vtt', 'ass', 'ssa', 'sub', 'sbv', 'txt', 'auto', 'gemini']
+
+# Output subtitle formats (includes special formats like 'TextGrid' and 'json')
+OUTPUT_SUBTITLE_FORMATS = ['srt', 'vtt', 'ass', 'ssa', 'sub', 'sbv', 'txt', 'TextGrid', 'json']
+
+# All subtitle formats combined (for file detection)
+ALL_SUBTITLE_FORMATS = list(set(SUBTITLE_FORMATS + ['TextGrid', 'json', 'gemini']))
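For illustration, one way these constants could drive suffix-based detection; detect_subtitle_format is a hypothetical helper, not part of the package:

    from pathlib import Path
    from typing import Optional

    def detect_subtitle_format(path: str) -> Optional[str]:
        # Case-insensitive match of the file suffix against the combined list
        suffix = Path(path).suffix.lstrip('.').lower()
        for fmt in ALL_SUBTITLE_FORMATS:
            if suffix == fmt.lower():
                return fmt
        return None

    assert detect_subtitle_format('talk.srt') == 'srt'
    assert detect_subtitle_format('talk.textgrid') == 'TextGrid'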
lattifai/io/writer.py CHANGED
@@ -1,49 +1,90 @@
+import json
 from abc import ABCMeta
-from typing import List
+from typing import Any, List, Optional
 
+import pysubs2
+from lhotse.supervision import AlignmentItem
 from lhotse.utils import Pathlike
 
-from .reader import SubtitleFormat, Supervision
+from .reader import Supervision
 
 
 class SubtitleWriter(ABCMeta):
-    """Class for writing subtitle files."""
+    """Class for writing subtitle files with optional word-level alignment."""
 
     @classmethod
     def write(cls, alignments: List[Supervision], output_path: Pathlike) -> Pathlike:
         if str(output_path)[-4:].lower() == '.txt':
             with open(output_path, 'w', encoding='utf-8') as f:
                 for sup in alignments:
-                    f.write(f'{sup.text}\n')
+                    word_items = parse_alignment_from_supervision(sup)
+                    if word_items:
+                        for item in word_items:
+                            f.write(f'[{item.start:.2f}-{item.end:.2f}] {item.symbol}\n')
+                    else:
+                        text = f'{sup.speaker} {sup.text}' if sup.speaker is not None else sup.text
+                        f.write(f'[{sup.start:.2f}-{sup.end:.2f}] {text}\n')
+
         elif str(output_path)[-5:].lower() == '.json':
             with open(output_path, 'w', encoding='utf-8') as f:
-                import json
-
-                json.dump([sup.to_dict() for sup in alignments], f, ensure_ascii=False, indent=4)
+                # Enhanced JSON export with word-level alignment
+                json_data = []
+                for sup in alignments:
+                    sup_dict = sup.to_dict()
+                    json_data.append(sup_dict)
+                json.dump(json_data, f, ensure_ascii=False, indent=4)
         elif str(output_path).endswith('.TextGrid') or str(output_path).endswith('.textgrid'):
             from tgt import Interval, IntervalTier, TextGrid, write_to_file
 
             tg = TextGrid()
             supervisions, words = [], []
             for supervision in sorted(alignments, key=lambda x: x.start):
-                supervisions.append(Interval(supervision.start, supervision.end, supervision.text or ''))
-                if supervision.alignment and 'word' in supervision.alignment:
-                    for alignment in supervision.alignment['word']:
-                        words.append(Interval(alignment.start, alignment.end, alignment.symbol))
+                text = (
+                    f'{supervision.speaker} {supervision.text}' if supervision.speaker is not None else supervision.text
+                )
+                supervisions.append(Interval(supervision.start, supervision.end, text or ''))
+                # Extract word-level alignment using helper function
+                word_items = parse_alignment_from_supervision(supervision)
+                if word_items:
+                    for item in word_items:
+                        words.append(Interval(item.start, item.end, item.symbol))
 
             tg.add_tier(IntervalTier(name='utterances', objects=supervisions))
             if words:
                 tg.add_tier(IntervalTier(name='words', objects=words))
             write_to_file(tg, output_path, format='long')
         else:
-            import pysubs2
-
             subs = pysubs2.SSAFile()
             for sup in alignments:
-                start = int(sup.start * 1000)
-                end = int(sup.end * 1000)
-                text = sup.text or ''
-                subs.append(pysubs2.SSAEvent(start=start, end=end, text=text))
+                # Emit one subtitle event per word when word-level alignment is available
+                word_items = parse_alignment_from_supervision(sup)
+                if word_items:
+                    for word in word_items:
+                        subs.append(
+                            pysubs2.SSAEvent(start=int(word.start * 1000), end=int(word.end * 1000), text=word.symbol)
+                        )
+                else:
+                    text = f'{sup.speaker} {sup.text}' if sup.speaker is not None else sup.text
+                    subs.append(pysubs2.SSAEvent(start=int(sup.start * 1000), end=int(sup.end * 1000), text=text or ''))
             subs.save(output_path)
 
         return output_path
+
+
+def parse_alignment_from_supervision(supervision: Any) -> Optional[List[AlignmentItem]]:
+    """
+    Extract word-level alignment items from a Supervision object.
+
+    Args:
+        supervision: Supervision object with potential alignment data
+
+    Returns:
+        List of AlignmentItem objects, or None if no alignment data is present
+    """
+    if not hasattr(supervision, 'alignment') or not supervision.alignment:
+        return None
+
+    if 'word' not in supervision.alignment:
+        return None
+
+    return supervision.alignment['word']
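A usage sketch for the rewritten writer; the supervisions are invented, and the two calls exercise the segment-level .txt fallback and the pysubs2 branch (assuming the extended Supervision keeps SupervisionSegment's start/duration fields):

    sups = [
        Supervision(text='hello world', start=0.0, duration=1.0, speaker='A'),
        Supervision(text='goodbye', start=1.2, duration=0.8),
    ]

    # Without word alignments, .txt falls back to segment lines: "[0.00-1.00] A hello world"
    SubtitleWriter.write(sups, 'out.txt')

    # Any other suffix (e.g. .srt) goes through pysubs2: one event per segment, or per word when aligned
    SubtitleWriter.write(sups, 'out.srt')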
lattifai/tokenizer/__init__.py CHANGED
@@ -1,3 +1,3 @@
-from .tokenizer import LatticeTokenizer
+from .tokenizer import AsyncLatticeTokenizer, LatticeTokenizer
 
-__all__ = ['LatticeTokenizer']
+__all__ = ['LatticeTokenizer', 'AsyncLatticeTokenizer']
lattifai/tokenizer/tokenizer.py CHANGED
@@ -1,13 +1,13 @@
 import gzip
+import inspect
 import pickle
 import re
 from collections import defaultdict
-from itertools import chain
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
 
 import torch
 
-from lattifai.base_client import SyncAPIClient
+from lattifai.errors import LATTICE_DECODING_FAILURE_HELP, LatticeDecodingError
 from lattifai.io import Supervision
 from lattifai.tokenizer.phonemizer import G2Phonemizer
 
@@ -21,10 +21,13 @@ GROUPING_SEPARATOR = '✹'
 MAXIMUM_WORD_LENGTH = 40
 
 
+TokenizerT = TypeVar('TokenizerT', bound='LatticeTokenizer')
+
+
 class LatticeTokenizer:
     """Tokenizer for converting Lhotse Cut to LatticeGraph."""
 
-    def __init__(self, client_wrapper: SyncAPIClient):
+    def __init__(self, client_wrapper: Any):
         self.client_wrapper = client_wrapper
         self.words: List[str] = []
         self.g2p_model: Any = None  # Placeholder for G2P model
@@ -99,13 +102,14 @@ class LatticeTokenizer:
         # If no special pattern matches, return the original sentence
         return [sentence]
 
-    @staticmethod
+    @classmethod
     def from_pretrained(
-        client_wrapper: SyncAPIClient,
+        cls: Type[TokenizerT],
+        client_wrapper: Any,
         model_path: str,
         device: str = 'cpu',
         compressed: bool = True,
-    ):
+    ) -> TokenizerT:
         """Load tokenizer from exported binary file"""
         from pathlib import Path
 
@@ -117,7 +121,7 @@ class LatticeTokenizer:
         with open(words_model_path, 'rb') as f:
             data = pickle.load(f)
 
-        tokenizer = LatticeTokenizer(client_wrapper=client_wrapper)
+        tokenizer = cls(client_wrapper=client_wrapper)
         tokenizer.words = data['words']
         tokenizer.dictionaries = defaultdict(list, data['dictionaries'])
         tokenizer.oov_word = data['oov_word']
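The switch from @staticmethod to @classmethod with a bound TypeVar makes from_pretrained construct cls, so a subclass loads as itself instead of as the base class. A generic sketch of the pattern (names here are illustrative, not package API):

    from typing import Type, TypeVar

    T = TypeVar('T', bound='Base')

    class Base:
        @classmethod
        def from_pretrained(cls: Type[T]) -> T:
            return cls()  # instantiates whichever class the call went through

    class Child(Base):
        pass

    assert isinstance(Child.from_pretrained(), Child)  # a Child, not merely a Base

This is what lets AsyncLatticeTokenizer below reuse the loader unchanged.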
@@ -179,53 +183,98 @@ class LatticeTokenizer:
         return {}
 
     def split_sentences(self, supervisions: List[Supervision], strip_whitespace=True) -> List[str]:
+        """Split supervisions into sentences using the sentence splitter.
+
+        Careful about speaker changes.
+        """
         texts, text_len, sidx = [], 0, 0
+        speakers = []
         for s, supervision in enumerate(supervisions):
             text_len += len(supervision.text)
-            if text_len >= 2000 or s == len(supervisions) - 1:
-                text = ' '.join([sup.text for sup in supervisions[sidx : s + 1]])
-                texts.append(text)
-                sidx = s + 1
-                text_len = 0
-        if sidx < len(supervisions):
-            text = ' '.join([sup.text for sup in supervisions[sidx:]])
-            texts.append(text)
+            if supervision.speaker:
+                if sidx < s:
+                    if len(speakers) < len(texts) + 1:
+                        speakers.append(None)
+                    text = ' '.join([sup.text for sup in supervisions[sidx:s]])
+                    texts.append(text)
+                sidx = s
+                text_len = len(supervision.text)
+                speakers.append(supervision.speaker)
+
+            else:
+                if text_len >= 2000 or s == len(supervisions) - 1:
+                    if len(speakers) < len(texts) + 1:
+                        speakers.append(None)
+                    text = ' '.join([sup.text for sup in supervisions[sidx : s + 1]])
+                    texts.append(text)
+                    sidx = s + 1
+                    text_len = 0
+
+        assert len(speakers) == len(texts), f'len(speakers)={len(speakers)} != len(texts)={len(texts)}'
         sentences = self.sentence_splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace)
 
         supervisions, remainder = [], ''
-        for _sentences in sentences:
+        for k, (_speaker, _sentences) in enumerate(zip(speakers, sentences)):
+            # Prepend remainder from previous iteration to the first sentence
+            if _sentences and remainder:
+                _sentences[0] = remainder + _sentences[0]
+                remainder = ''
+
+            if not _sentences:
+                continue
+
             # Process and re-split special sentence types
             processed_sentences = []
             for s, _sentence in enumerate(_sentences):
                 if remainder:
                     _sentence = remainder + _sentence
                     remainder = ''
-
                 # Detect and split special sentence types: e.g., '[APPLAUSE] >> MIRA MURATI:' -> ['[APPLAUSE]', '>> MIRA MURATI:']  # noqa: E501
                 resplit_parts = self._resplit_special_sentence_types(_sentence)
-                if any(resplit_parts[-1].endswith(sp) for sp in [':', ':']):
+                if any(resplit_parts[-1].endswith(sp) for sp in [':', ':', ']']):
                     if s < len(_sentences) - 1:
                         _sentences[s + 1] = resplit_parts[-1] + ' ' + _sentences[s + 1]
                     else:  # last part
-                        remainder = resplit_parts[-1] + ' ' + remainder
+                        remainder = resplit_parts[-1] + ' '
                     processed_sentences.extend(resplit_parts[:-1])
                 else:
                     processed_sentences.extend(resplit_parts)
-
             _sentences = processed_sentences
 
-            if remainder:
-                _sentences[0] = remainder + _sentences[0]
-                remainder = ''
+            if not _sentences:
+                if remainder:
+                    _sentences, remainder = [remainder.strip()], ''
+                else:
+                    continue
 
             if any(_sentences[-1].endswith(ep) for ep in END_PUNCTUATION):
-                supervisions.extend(Supervision(text=s) for s in _sentences)
+                supervisions.extend(
+                    Supervision(text=text, speaker=(_speaker if s == 0 else None)) for s, text in enumerate(_sentences)
+                )
+                _speaker = None  # reset speaker after use
             else:
-                supervisions.extend(Supervision(text=s) for s in _sentences[:-1])
-                remainder += _sentences[-1] + ' '
+                supervisions.extend(
+                    Supervision(text=text, speaker=(_speaker if s == 0 else None))
+                    for s, text in enumerate(_sentences[:-1])
+                )
+                remainder = _sentences[-1] + ' ' + remainder
+                if k < len(speakers) - 1 and speakers[k + 1] is not None:  # next speaker is set
+                    supervisions.append(
+                        Supervision(text=remainder.strip(), speaker=_speaker if len(_sentences) == 1 else None)
+                    )
+                    remainder = ''
+                elif len(_sentences) == 1:
+                    if k == len(speakers) - 1:
+                        pass  # keep _speaker for the last supervision
+                    else:
+                        assert speakers[k + 1] is None
+                        speakers[k + 1] = _speaker
+                else:
+                    assert len(_sentences) > 1
+                    _speaker = None  # reset speaker if sentence not ended
 
         if remainder.strip():
-            supervisions.append(Supervision(text=remainder.strip()))
+            supervisions.append(Supervision(text=remainder.strip(), speaker=_speaker))
 
         return supervisions
 
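A worked example of the new speaker-aware chunking (the first loop above): a supervision that names a speaker opens a new chunk, the previous chunk is flushed, and speakers stays index-aligned with texts so the second loop can attach the speaker to the first sentence of each turn. Assuming an initialized tokenizer, and treating the splitter output as illustrative:

    sups = [
        Supervision(text='Hello there.', speaker='Alice'),
        Supervision(text='General Kenobi.'),
    ]
    # Chunking pass: texts == ['Hello there. General Kenobi.'], speakers == ['Alice']
    # Expected result: the speaker survives only on the turn's first sentence, e.g.
    #   [Supervision(text='Hello there.', speaker='Alice'),
    #    Supervision(text='General Kenobi.', speaker=None)]
    out = tokenizer.split_sentences(sups)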
@@ -246,14 +295,18 @@ class LatticeTokenizer:
             raise Exception(f'Failed to tokenize texts: {response.text}')
         result = response.json()
         lattice_id = result['id']
-        return lattice_id, (result['lattice_graph'], result['final_state'], result.get('acoustic_scale', 1.0))
+        return (
+            supervisions,
+            lattice_id,
+            (result['lattice_graph'], result['final_state'], result.get('acoustic_scale', 1.0)),
+        )
 
     def detokenize(
         self,
         lattice_id: str,
         lattice_results: Tuple[torch.Tensor, Any, Any, float, float],
-        # return_supervisions: bool = True,
-        # return_details: bool = False,
+        supervisions: List[Supervision],
+        return_details: bool = False,
     ) -> List[Supervision]:
         emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
         response = self.client_wrapper.post(
  response = self.client_wrapper.post(
@@ -265,22 +318,157 @@ class LatticeTokenizer:
265
318
  'labels': labels[0],
266
319
  'offset': offset,
267
320
  'channel': channel,
321
+ 'return_details': return_details,
268
322
  'destroy_lattice': True,
269
323
  },
270
324
  )
325
+ if response.status_code == 422:
326
+ raise LatticeDecodingError(
327
+ lattice_id,
328
+ original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
329
+ )
271
330
  if response.status_code != 200:
272
331
  raise Exception(f'Failed to detokenize lattice: {response.text}')
332
+
273
333
  result = response.json()
274
334
  if not result.get('success'):
275
- return Exception('Failed to detokenize the alignment results.')
276
- # if return_details:
277
- # raise NotImplementedError("return_details is not implemented yet")
278
- return [Supervision.from_dict(s) for s in result['supervisions']]
335
+ raise Exception('Failed to detokenize the alignment results.')
336
+
337
+ alignments = [Supervision.from_dict(s) for s in result['supervisions']]
338
+
339
+ if return_details:
340
+ # Add emission confidence scores for segments and word-level alignments
341
+ _add_confidence_scores(alignments, emission, labels[0], frame_shift)
342
+
343
+ alignments = _update_alignments_speaker(supervisions, alignments)
344
+
345
+ return alignments
279
346
 
280
347
 
281
- # Compute average score weighted by the span length
282
- def _score(spans):
283
- if not spans:
284
- return 0.0
285
- # TokenSpan(token=token, start=start, end=end, score=scores[start:end].mean().item())
286
- return round(sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans), ndigits=4)
348
+ class AsyncLatticeTokenizer(LatticeTokenizer):
349
+ async def _post_async(self, endpoint: str, **kwargs):
350
+ response = self.client_wrapper.post(endpoint, **kwargs)
351
+ if inspect.isawaitable(response):
352
+ return await response
353
+ return response
354
+
355
+ async def tokenize(
356
+ self, supervisions: List[Supervision], split_sentence: bool = False
357
+ ) -> Tuple[str, Dict[str, Any]]:
358
+ if split_sentence:
359
+ self.init_sentence_splitter()
360
+ supervisions = self.split_sentences(supervisions)
361
+
362
+ pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions])
363
+ response = await self._post_async(
364
+ 'tokenize',
365
+ json={
366
+ 'supervisions': [s.to_dict() for s in supervisions],
367
+ 'pronunciation_dictionaries': pronunciation_dictionaries,
368
+ },
369
+ )
370
+ if response.status_code != 200:
371
+ raise Exception(f'Failed to tokenize texts: {response.text}')
372
+ result = response.json()
373
+ lattice_id = result['id']
374
+ return (
375
+ supervisions,
376
+ lattice_id,
377
+ (result['lattice_graph'], result['final_state'], result.get('acoustic_scale', 1.0)),
378
+ )
379
+
380
+ async def detokenize(
381
+ self,
382
+ lattice_id: str,
383
+ lattice_results: Tuple[torch.Tensor, Any, Any, float, float],
384
+ supervisions: List[Supervision],
385
+ return_details: bool = False,
386
+ ) -> List[Supervision]:
387
+ emission, results, labels, frame_shift, offset, channel = lattice_results # noqa: F841
388
+ response = await self._post_async(
389
+ 'detokenize',
390
+ json={
391
+ 'lattice_id': lattice_id,
392
+ 'frame_shift': frame_shift,
393
+ 'results': [t.to_dict() for t in results[0]],
394
+ 'labels': labels[0],
395
+ 'offset': offset,
396
+ 'channel': channel,
397
+ 'return_details': return_details,
398
+ 'destroy_lattice': True,
399
+ },
400
+ )
401
+ if response.status_code == 422:
402
+ raise LatticeDecodingError(
403
+ lattice_id,
404
+ original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
405
+ )
406
+ if response.status_code != 200:
407
+ raise Exception(f'Failed to detokenize lattice: {response.text}')
408
+
409
+ result = response.json()
410
+ if not result.get('success'):
411
+ return Exception('Failed to detokenize the alignment results.')
412
+
413
+ alignments = [Supervision.from_dict(s) for s in result['supervisions']]
414
+
415
+ if return_details:
416
+ # Add emission confidence scores for segments and word-level alignments
417
+ _add_confidence_scores(alignments, emission, labels[0], frame_shift)
418
+
419
+ alignments = _update_alignments_speaker(supervisions, alignments)
420
+
421
+ return alignments
422
+
423
+
424
+ def _add_confidence_scores(
425
+ supervisions: List[Supervision],
426
+ emission: torch.Tensor,
427
+ labels: List[int],
428
+ frame_shift: float,
429
+ ) -> None:
430
+ """
431
+ Add confidence scores to supervisions and their word-level alignments.
432
+
433
+ This function modifies supervisions in-place by:
434
+ 1. Computing segment-level confidence scores based on emission probabilities
435
+ 2. Computing word-level confidence scores for each aligned word
436
+
437
+ Args:
438
+ supervisions: List of Supervision objects to add scores to (modified in-place)
439
+ emission: Emission tensor with shape [batch, time, vocab_size]
440
+ labels: Token labels corresponding to aligned tokens
441
+ frame_shift: Frame shift in seconds for converting frames to time
442
+ """
443
+ tokens = torch.tensor(labels, dtype=torch.int64, device=emission.device)
444
+
445
+ for supervision in supervisions:
446
+ start_frame = int(supervision.start / frame_shift)
447
+ end_frame = int(supervision.end / frame_shift)
448
+
449
+ # Compute segment-level confidence
450
+ probabilities = emission[0, start_frame:end_frame].softmax(dim=-1)
451
+ aligned = probabilities[range(0, end_frame - start_frame), tokens[start_frame:end_frame]]
452
+ diffprobs = (probabilities.max(dim=-1).values - aligned).cpu()
453
+ supervision.score = round(1.0 - diffprobs.mean().item(), ndigits=4)
454
+
455
+ # Compute word-level confidence if alignment exists
456
+ if hasattr(supervision, 'alignment') and supervision.alignment:
457
+ words = supervision.alignment.get('word', [])
458
+ for w, item in enumerate(words):
459
+ start = int(item.start / frame_shift) - start_frame
460
+ end = int(item.end / frame_shift) - start_frame
461
+ words[w] = item._replace(score=round(1.0 - diffprobs[start:end].mean().item(), ndigits=4))
462
+
463
+
464
+ def _update_alignments_speaker(supervisions: List[Supervision], alignments: List[Supervision]) -> List[Supervision]:
465
+ """
466
+ Update the speaker attribute for a list of supervisions.
467
+
468
+ Args:
469
+ supervisions: List of Supervision objects to get speaker info from
470
+ alignments: List of aligned Supervision objects to update speaker info to
471
+ """
472
+ for supervision, alignment in zip(supervisions, alignments):
473
+ alignment.speaker = supervision.speaker
474
+ return alignments
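The confidence computed in _add_confidence_scores is one minus the mean gap between each frame's maximum probability and the probability of the aligned token, so a score of 1.0 means the aligned token was the argmax on every frame. A tiny self-contained check of that arithmetic with toy tensors (not real emissions):

    import torch

    emission = torch.log(torch.tensor([[[0.7, 0.2, 0.1],
                                        [0.1, 0.8, 0.1]]]))  # [batch=1, time=2, vocab=3]
    tokens = torch.tensor([0, 1])  # aligned token id per frame

    probs = emission[0].softmax(dim=-1)       # recovers [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]
    aligned = probs[range(2), tokens]         # tensor([0.7, 0.8])
    gap = probs.max(dim=-1).values - aligned  # tensor([0., 0.]): aligned token is the argmax
    score = round(1.0 - gap.mean().item(), 4)  # 1.0, maximal confidence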
lattifai/utils.py ADDED
@@ -0,0 +1,133 @@
+"""Shared utility helpers for the LattifAI SDK."""
+
+import os
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import Any, Optional, Type
+
+from lattifai.errors import ModelLoadError
+from lattifai.tokenizer import LatticeTokenizer
+from lattifai.workers import Lattice1AlphaWorker
+
+
+def _get_cache_marker_path(cache_dir: Path) -> Path:
+    """Get the path for the cache marker file with the current date."""
+    today = datetime.now().strftime('%Y%m%d')
+    return cache_dir / f'.done{today}'
+
+
+def _is_cache_valid(cache_dir: Path) -> bool:
+    """Check if the cached model is valid (exists and is no more than 1 day old)."""
+    if not cache_dir.exists():
+        return False
+
+    # Find any .done* marker files
+    marker_files = list(cache_dir.glob('.done*'))
+    if not marker_files:
+        return False
+
+    # Get the most recent marker file
+    latest_marker = max(marker_files, key=lambda p: p.stat().st_mtime)
+
+    # Extract date from marker filename (format: .doneYYYYMMDD)
+    try:
+        date_str = latest_marker.name.replace('.done', '')
+        marker_date = datetime.strptime(date_str, '%Y%m%d')
+        # Check if marker is older than 1 day
+        if datetime.now() - marker_date > timedelta(days=1):
+            return False
+        return True
+    except (ValueError, IndexError):
+        # Invalid marker file format, treat as invalid cache
+        return False
+
+
+def _create_cache_marker(cache_dir: Path) -> None:
+    """Create a cache marker file with the current date and clean up old markers."""
+    # Remove old marker files
+    for old_marker in cache_dir.glob('.done*'):
+        old_marker.unlink(missing_ok=True)
+
+    # Create new marker file
+    marker_path = _get_cache_marker_path(cache_dir)
+    marker_path.touch()
+
+
+def _resolve_model_path(model_name_or_path: str) -> str:
+    """Resolve model path, downloading from Hugging Face when necessary."""
+    if Path(model_name_or_path).exists():
+        return model_name_or_path
+
+    from huggingface_hub import snapshot_download
+    from huggingface_hub.constants import HF_HUB_CACHE
+    from huggingface_hub.errors import LocalEntryNotFoundError
+
+    # Determine cache directory for this model
+    cache_dir = Path(HF_HUB_CACHE) / f'models--{model_name_or_path.replace("/", "--")}'
+
+    # Check if we have a valid cached version
+    if _is_cache_valid(cache_dir):
+        # Return the snapshot path (latest version)
+        snapshots_dir = cache_dir / 'snapshots'
+        if snapshots_dir.exists():
+            snapshot_dirs = [d for d in snapshots_dir.iterdir() if d.is_dir()]
+            if snapshot_dirs:
+                # Return the most recent snapshot
+                latest_snapshot = max(snapshot_dirs, key=lambda p: p.stat().st_mtime)
+                return str(latest_snapshot)
+
+    try:
+        downloaded_path = snapshot_download(repo_id=model_name_or_path, repo_type='model')
+        _create_cache_marker(cache_dir)
+        return downloaded_path
+    except LocalEntryNotFoundError:
+        try:
+            os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
+            downloaded_path = snapshot_download(repo_id=model_name_or_path, repo_type='model')
+            _create_cache_marker(cache_dir)
+            return downloaded_path
+        except Exception as e:  # pragma: no cover - bubble up for caller context
+            raise ModelLoadError(model_name_or_path, original_error=e)
+    except Exception as e:  # pragma: no cover - unexpected download issue
+        raise ModelLoadError(model_name_or_path, original_error=e)
+
+
+def _select_device(device: Optional[str]) -> str:
+    """Select the best available torch device when not explicitly provided."""
+    if device:
+        return device
+
+    import torch
+
+    detected = 'cpu'
+    if torch.backends.mps.is_available():
+        detected = 'mps'
+    elif torch.cuda.is_available():
+        detected = 'cuda'
+    return detected
+
+
+def _load_tokenizer(
+    client_wrapper: Any,
+    model_path: str,
+    device: str,
+    *,
+    tokenizer_cls: Type[LatticeTokenizer] = LatticeTokenizer,
+) -> LatticeTokenizer:
+    """Instantiate tokenizer with consistent error handling."""
+    try:
+        return tokenizer_cls.from_pretrained(
+            client_wrapper=client_wrapper,
+            model_path=model_path,
+            device=device,
+        )
+    except Exception as e:
+        raise ModelLoadError(f'tokenizer from {model_path}', original_error=e)
+
+
+def _load_worker(model_path: str, device: str) -> Lattice1AlphaWorker:
+    """Instantiate lattice worker with consistent error handling."""
+    try:
+        return Lattice1AlphaWorker(model_path, device=device, num_threads=8)
+    except Exception as e:
+        raise ModelLoadError(f'worker from {model_path}', original_error=e)
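A sketch of how these private helpers compose when a client loads a model; the repo id is a placeholder and client_wrapper stands in for the SDK's HTTP client (not shown in this diff):

    device = _select_device(None)  # 'mps' or 'cuda' when available, else 'cpu'
    model_path = _resolve_model_path('lattifai/<model-repo>')  # local path or HF snapshot
    worker = _load_worker(model_path, device)
    tokenizer = _load_tokenizer(client_wrapper, model_path, device)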