lattifai 0.4.5__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/__init__.py +26 -27
- lattifai/base_client.py +7 -7
- lattifai/bin/agent.py +90 -91
- lattifai/bin/align.py +110 -111
- lattifai/bin/cli_base.py +3 -3
- lattifai/bin/subtitle.py +45 -45
- lattifai/client.py +56 -56
- lattifai/errors.py +73 -73
- lattifai/io/__init__.py +12 -11
- lattifai/io/gemini_reader.py +30 -30
- lattifai/io/gemini_writer.py +17 -17
- lattifai/io/reader.py +13 -12
- lattifai/io/supervision.py +3 -3
- lattifai/io/text_parser.py +43 -16
- lattifai/io/utils.py +4 -4
- lattifai/io/writer.py +31 -19
- lattifai/tokenizer/__init__.py +1 -1
- lattifai/tokenizer/phonemizer.py +3 -3
- lattifai/tokenizer/tokenizer.py +83 -82
- lattifai/utils.py +15 -15
- lattifai/workers/__init__.py +1 -1
- lattifai/workers/lattice1_alpha.py +46 -46
- lattifai/workflows/__init__.py +11 -11
- lattifai/workflows/agents.py +2 -0
- lattifai/workflows/base.py +22 -22
- lattifai/workflows/file_manager.py +182 -182
- lattifai/workflows/gemini.py +29 -29
- lattifai/workflows/prompts/__init__.py +4 -4
- lattifai/workflows/youtube.py +233 -233
- {lattifai-0.4.5.dist-info → lattifai-0.4.6.dist-info}/METADATA +7 -9
- lattifai-0.4.6.dist-info/RECORD +39 -0
- {lattifai-0.4.5.dist-info → lattifai-0.4.6.dist-info}/licenses/LICENSE +1 -1
- lattifai-0.4.5.dist-info/RECORD +0 -39
- {lattifai-0.4.5.dist-info → lattifai-0.4.6.dist-info}/WHEEL +0 -0
- {lattifai-0.4.5.dist-info → lattifai-0.4.6.dist-info}/entry_points.txt +0 -0
- {lattifai-0.4.5.dist-info → lattifai-0.4.6.dist-info}/top_level.txt +0 -0
lattifai/tokenizer/tokenizer.py
CHANGED
@@ -3,25 +3,25 @@ import inspect
 import pickle
 import re
 from collections import defaultdict
-from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar

 import torch

 from lattifai.errors import LATTICE_DECODING_FAILURE_HELP, LatticeDecodingError
-from lattifai.io import Supervision
+from lattifai.io import Supervision, normalize_html_text
 from lattifai.tokenizer.phonemizer import G2Phonemizer

 PUNCTUATION = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~'
 END_PUNCTUATION = '.!?"]。!?”】'
-PUNCTUATION_SPACE = PUNCTUATION + ' '
-STAR_TOKEN = '※'
+PUNCTUATION_SPACE = PUNCTUATION + " "
+STAR_TOKEN = "※"

-GROUPING_SEPARATOR = '✹'
+GROUPING_SEPARATOR = "✹"

 MAXIMUM_WORD_LENGTH = 40


-TokenizerT = TypeVar('TokenizerT', bound='LatticeTokenizer')
+TokenizerT = TypeVar("TokenizerT", bound="LatticeTokenizer")


 class LatticeTokenizer:
@@ -32,9 +32,9 @@ class LatticeTokenizer:
         self.words: List[str] = []
         self.g2p_model: Any = None  # Placeholder for G2P model
         self.dictionaries = defaultdict(lambda: [])
-        self.oov_word = '<unk>'
+        self.oov_word = "<unk>"
         self.sentence_splitter = None
-        self.device = 'cpu'
+        self.device = "cpu"

     def init_sentence_splitter(self):
         if self.sentence_splitter is not None:
@@ -45,14 +45,14 @@

         providers = []
         device = self.device
-        if device.startswith('cuda') and ort.get_all_providers().count('CUDAExecutionProvider') > 0:
-            providers.append('CUDAExecutionProvider')
-        elif device.startswith('mps') and ort.get_all_providers().count('MPSExecutionProvider') > 0:
-            providers.append('MPSExecutionProvider')
+        if device.startswith("cuda") and ort.get_all_providers().count("CUDAExecutionProvider") > 0:
+            providers.append("CUDAExecutionProvider")
+        elif device.startswith("mps") and ort.get_all_providers().count("MPSExecutionProvider") > 0:
+            providers.append("MPSExecutionProvider")

         sat = SaT(
-            'sat-3l-sm',
-            ort_providers=providers + ['CPUExecutionProvider'],
+            "sat-3l-sm",
+            ort_providers=providers + ["CPUExecutionProvider"],
         )
         self.sentence_splitter = sat

@@ -79,23 +79,23 @@
         # or other forms like [SOMETHING] SPEAKER:

         # Pattern 1: [mark] HTML-encoded separator speaker:
-        pattern1 = r'^(\[[^\]]+\])\s+(>>|&gt;&gt;)\s+(.+)$'
+        pattern1 = r"^(\[[^\]]+\])\s+(>>|&gt;&gt;)\s+(.+)$"
         match1 = re.match(pattern1, sentence.strip())
         if match1:
             special_mark = match1.group(1)
             separator = match1.group(2)
             speaker_part = match1.group(3)
-            return [special_mark, f'{separator} {speaker_part}']
+            return [special_mark, f"{separator} {speaker_part}"]

         # Pattern 2: [mark] speaker:
-        pattern2 = r'^(\[[^\]]+\])\s+([^:]+:)(.*)$'
+        pattern2 = r"^(\[[^\]]+\])\s+([^:]+:)(.*)$"
         match2 = re.match(pattern2, sentence.strip())
         if match2:
             special_mark = match2.group(1)
             speaker_label = match2.group(2)
             remaining = match2.group(3).strip()
             if remaining:
-                return [special_mark, f'{speaker_label} {remaining}']
+                return [special_mark, f"{speaker_label} {remaining}"]
             else:
                 return [special_mark, speaker_label]

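Read together, the two patterns peel a bracketed event mark (e.g. [APPLAUSE]) off the front of a caption line before sentence assembly. A quick illustration of what they match, reusing the regexes from the hunk above (the inputs and printed groups are illustrative, not package output):

import re

pattern1 = r"^(\[[^\]]+\])\s+(>>|&gt;&gt;)\s+(.+)$"  # [mark] separator SPEAKER: ...
pattern2 = r"^(\[[^\]]+\])\s+([^:]+:)(.*)$"          # [mark] SPEAKER: ...

print(re.match(pattern1, "[APPLAUSE] >> MIRA MURATI: Welcome.").groups())
# ('[APPLAUSE]', '>>', 'MIRA MURATI: Welcome.')
print(re.match(pattern2, "[MUSIC] Host: Hello there.").groups())
# ('[MUSIC]', 'Host:', ' Hello there.')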
@@ -107,26 +107,26 @@
         cls: Type[TokenizerT],
         client_wrapper: Any,
         model_path: str,
-        device: str = 'cpu',
+        device: str = "cpu",
         compressed: bool = True,
     ) -> TokenizerT:
         """Load tokenizer from exported binary file"""
         from pathlib import Path

-        words_model_path = f'{model_path}/words.bin'
+        words_model_path = f"{model_path}/words.bin"
         if compressed:
-            with gzip.open(words_model_path, 'rb') as f:
+            with gzip.open(words_model_path, "rb") as f:
                 data = pickle.load(f)
         else:
-            with open(words_model_path, 'rb') as f:
+            with open(words_model_path, "rb") as f:
                 data = pickle.load(f)

         tokenizer = cls(client_wrapper=client_wrapper)
-        tokenizer.words = data['words']
-        tokenizer.dictionaries = defaultdict(list, data['dictionaries'])
-        tokenizer.oov_word = data['oov_word']
+        tokenizer.words = data["words"]
+        tokenizer.dictionaries = defaultdict(list, data["dictionaries"])
+        tokenizer.oov_word = data["oov_word"]

-        g2p_model_path = f'{model_path}/g2p.bin' if Path(f'{model_path}/g2p.bin').exists() else None
+        g2p_model_path = f"{model_path}/g2p.bin" if Path(f"{model_path}/g2p.bin").exists() else None
         if g2p_model_path:
             tokenizer.g2p_model = G2Phonemizer(g2p_model_path, device=device)

@@ -136,18 +136,19 @@

     def add_special_tokens(self):
         tokenizer = self
-        for special_token in ['>>', '>']:
+        for special_token in [">>", ">"]:
             if special_token not in tokenizer.dictionaries:
                 tokenizer.dictionaries[special_token] = tokenizer.dictionaries[tokenizer.oov_word]
         return self

     def prenormalize(self, texts: List[str], language: Optional[str] = None) -> List[str]:
         if not self.g2p_model:
-            raise ValueError('G2P model is not loaded, cannot prenormalize texts')
+            raise ValueError("G2P model is not loaded, cannot prenormalize texts")

         oov_words = []
         for text in texts:
-            words = text.lower().replace('-', ' ').replace('—', ' ').replace('–', ' ').split()
+            text = normalize_html_text(text)
+            words = text.lower().replace("-", " ").replace("—", " ").replace("–", " ").split()
             oovs = [w.strip(PUNCTUATION) for w in words if w not in self.words]
             if oovs:
                 oov_words.extend([w for w in oovs if (w not in self.words and len(w) <= MAXIMUM_WORD_LENGTH)])
@@ -156,7 +157,7 @@
         if oov_words:
             indexs = []
             for k, _word in enumerate(oov_words):
-                if any(_word.startswith(p) and _word.endswith(q) for (p, q) in [('(', ')'), ('[', ']')]):
+                if any(_word.startswith(p) and _word.endswith(q) for (p, q) in [("(", ")"), ("[", "]")]):
                     self.dictionaries[_word] = self.dictionaries[self.oov_word]
                 else:
                     _word = _word.strip(PUNCTUATION_SPACE)
@@ -195,7 +196,7 @@
             if sidx < s:
                 if len(speakers) < len(texts) + 1:
                     speakers.append(None)
-                text = ' '.join([sup.text for sup in supervisions[sidx:s]])
+                text = " ".join([sup.text for sup in supervisions[sidx:s]])
                 texts.append(text)
                 sidx = s
             text_len = len(supervision.text)
@@ -205,20 +206,20 @@
             if text_len >= 2000 or s == len(supervisions) - 1:
                 if len(speakers) < len(texts) + 1:
                     speakers.append(None)
-                text = ' '.join([sup.text for sup in supervisions[sidx : s + 1]])
+                text = " ".join([sup.text for sup in supervisions[sidx : s + 1]])
                 texts.append(text)
                 sidx = s + 1
                 text_len = 0

-        assert len(speakers) == len(texts), f'len(speakers)={len(speakers)} != len(texts)={len(texts)}'
+        assert len(speakers) == len(texts), f"len(speakers)={len(speakers)} != len(texts)={len(texts)}"
         sentences = self.sentence_splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace)

-        supervisions, remainder = [], ''
+        supervisions, remainder = [], ""
         for k, (_speaker, _sentences) in enumerate(zip(speakers, sentences)):
             # Prepend remainder from previous iteration to the first sentence
             if _sentences and remainder:
                 _sentences[0] = remainder + _sentences[0]
-                remainder = ''
+                remainder = ""

             if not _sentences:
                 continue
@@ -228,14 +229,14 @@
             for s, _sentence in enumerate(_sentences):
                 if remainder:
                     _sentence = remainder + _sentence
-                    remainder = ''
+                    remainder = ""
                 # Detect and split special sentence types: e.g., '[APPLAUSE] >> MIRA MURATI:' -> ['[APPLAUSE]', '>> MIRA MURATI:']  # noqa: E501
                 resplit_parts = self._resplit_special_sentence_types(_sentence)
-                if any(resplit_parts[-1].endswith(sp) for sp in [':', '：']):
+                if any(resplit_parts[-1].endswith(sp) for sp in [":", "："]):
                     if s < len(_sentences) - 1:
-                        _sentences[s + 1] = resplit_parts[-1] + ' ' + _sentences[s + 1]
+                        _sentences[s + 1] = resplit_parts[-1] + " " + _sentences[s + 1]
                     else:  # last part
-                        remainder = resplit_parts[-1] + ' '
+                        remainder = resplit_parts[-1] + " "
                     processed_sentences.extend(resplit_parts[:-1])
                 else:
                     processed_sentences.extend(resplit_parts)
@@ -243,7 +244,7 @@

             if not _sentences:
                 if remainder:
-                    _sentences, remainder = [remainder.strip()], ''
+                    _sentences, remainder = [remainder.strip()], ""
                 else:
                     continue

@@ -257,12 +258,12 @@
                     Supervision(text=text, speaker=(_speaker if s == 0 else None))
                     for s, text in enumerate(_sentences[:-1])
                 )
-                remainder = _sentences[-1] + ' ' + remainder
+                remainder = _sentences[-1] + " " + remainder
                 if k < len(speakers) - 1 and speakers[k + 1] is not None:  # next speaker is set
                     supervisions.append(
                         Supervision(text=remainder.strip(), speaker=_speaker if len(_sentences) == 1 else None)
                     )
-                    remainder = ''
+                    remainder = ""
             elif len(_sentences) == 1:
                 if k == len(speakers) - 1:
                     pass  # keep _speaker for the last supervision
@@ -285,20 +286,20 @@

         pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions])
         response = self.client_wrapper.post(
-            'tokenize',
+            "tokenize",
             json={
-                'supervisions': [s.to_dict() for s in supervisions],
-                'pronunciation_dictionaries': pronunciation_dictionaries,
+                "supervisions": [s.to_dict() for s in supervisions],
+                "pronunciation_dictionaries": pronunciation_dictionaries,
             },
         )
         if response.status_code != 200:
-            raise Exception(f'Failed to tokenize texts: {response.text}')
+            raise Exception(f"Failed to tokenize texts: {response.text}")
         result = response.json()
-        lattice_id = result['id']
+        lattice_id = result["id"]
         return (
             supervisions,
             lattice_id,
-            (result['lattice_graph'], result['final_state'], result.get('acoustic_scale', 1.0)),
+            (result["lattice_graph"], result["final_state"], result.get("acoustic_scale", 1.0)),
         )

     def detokenize(
@@ -310,16 +311,16 @@
     ) -> List[Supervision]:
         emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
         response = self.client_wrapper.post(
-            'detokenize',
+            "detokenize",
             json={
-                'lattice_id': lattice_id,
-                'frame_shift': frame_shift,
-                'results': [t.to_dict() for t in results[0]],
-                'labels': labels[0],
-                'offset': offset,
-                'channel': channel,
-                'return_details': return_details,
-                'destroy_lattice': True,
+                "lattice_id": lattice_id,
+                "frame_shift": frame_shift,
+                "results": [t.to_dict() for t in results[0]],
+                "labels": labels[0],
+                "offset": offset,
+                "channel": channel,
+                "return_details": return_details,
+                "destroy_lattice": True,
             },
         )
         if response.status_code == 422:
@@ -328,13 +329,13 @@
                 original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
             )
         if response.status_code != 200:
-            raise Exception(f'Failed to detokenize lattice: {response.text}')
+            raise Exception(f"Failed to detokenize lattice: {response.text}")

         result = response.json()
-        if not result.get('success'):
-            raise Exception('Failed to detokenize the alignment results.')
+        if not result.get("success"):
+            raise Exception("Failed to detokenize the alignment results.")

-        alignments = [Supervision.from_dict(s) for s in result['supervisions']]
+        alignments = [Supervision.from_dict(s) for s in result["supervisions"]]

         if return_details:
             # Add emission confidence scores for segments and word-level alignments
@@ -361,20 +362,20 @@ class AsyncLatticeTokenizer(LatticeTokenizer):

         pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions])
         response = await self._post_async(
-            'tokenize',
+            "tokenize",
             json={
-                'supervisions': [s.to_dict() for s in supervisions],
-                'pronunciation_dictionaries': pronunciation_dictionaries,
+                "supervisions": [s.to_dict() for s in supervisions],
+                "pronunciation_dictionaries": pronunciation_dictionaries,
             },
         )
         if response.status_code != 200:
-            raise Exception(f'Failed to tokenize texts: {response.text}')
+            raise Exception(f"Failed to tokenize texts: {response.text}")
         result = response.json()
-        lattice_id = result['id']
+        lattice_id = result["id"]
         return (
             supervisions,
             lattice_id,
-            (result['lattice_graph'], result['final_state'], result.get('acoustic_scale', 1.0)),
+            (result["lattice_graph"], result["final_state"], result.get("acoustic_scale", 1.0)),
         )

     async def detokenize(
@@ -386,16 +387,16 @@ class AsyncLatticeTokenizer(LatticeTokenizer):
     ) -> List[Supervision]:
         emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
         response = await self._post_async(
-            'detokenize',
+            "detokenize",
             json={
-                'lattice_id': lattice_id,
-                'frame_shift': frame_shift,
-                'results': [t.to_dict() for t in results[0]],
-                'labels': labels[0],
-                'offset': offset,
-                'channel': channel,
-                'return_details': return_details,
-                'destroy_lattice': True,
+                "lattice_id": lattice_id,
+                "frame_shift": frame_shift,
+                "results": [t.to_dict() for t in results[0]],
+                "labels": labels[0],
+                "offset": offset,
+                "channel": channel,
+                "return_details": return_details,
+                "destroy_lattice": True,
             },
         )
         if response.status_code == 422:
@@ -404,13 +405,13 @@ class AsyncLatticeTokenizer(LatticeTokenizer):
                 original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
             )
         if response.status_code != 200:
-            raise Exception(f'Failed to detokenize lattice: {response.text}')
+            raise Exception(f"Failed to detokenize lattice: {response.text}")

         result = response.json()
-        if not result.get('success'):
-            return Exception('Failed to detokenize the alignment results.')
+        if not result.get("success"):
+            return Exception("Failed to detokenize the alignment results.")

-        alignments = [Supervision.from_dict(s) for s in result['supervisions']]
+        alignments = [Supervision.from_dict(s) for s in result["supervisions"]]

         if return_details:
             # Add emission confidence scores for segments and word-level alignments
@@ -453,8 +454,8 @@ def _add_confidence_scores(
         supervision.score = round(1.0 - diffprobs.mean().item(), ndigits=4)

         # Compute word-level confidence if alignment exists
-        if hasattr(supervision, 'alignment') and supervision.alignment:
-            words = supervision.alignment.get('word', [])
+        if hasattr(supervision, "alignment") and supervision.alignment:
+            words = supervision.alignment.get("word", [])
             for w, item in enumerate(words):
                 start = int(item.start / frame_shift) - start_frame
                 end = int(item.end / frame_shift) - start_frame
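Aside from the quote-style reformatting that dominates this file, the substantive change is the new normalize_html_text import and its call at the top of prenormalize. A minimal sketch of why that call matters, assuming the helper behaves roughly like html.unescape plus whitespace cleanup (the actual implementation lives in lattifai/io and may differ):

import html
import re

def normalize_html_text(text: str) -> str:  # hypothetical stand-in
    text = html.unescape(text)  # "&gt;&gt; HOST:" -> ">> HOST:"
    return re.sub(r"\s+", " ", text).strip()

print(normalize_html_text("&gt;&gt; MIRA MURATI:  Welcome &amp; thanks."))
# >> MIRA MURATI: Welcome & thanks.

Without this step, entity fragments such as "&gt;&gt;" or "&amp;" would be split into bogus out-of-vocabulary words before the G2P lookup.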
lattifai/utils.py
CHANGED
@@ -12,8 +12,8 @@ from lattifai.workers import Lattice1AlphaWorker

 def _get_cache_marker_path(cache_dir: Path) -> Path:
     """Get the path for the cache marker file with current date."""
-    today = datetime.now().strftime('%Y%m%d')
-    return cache_dir / f'.done{today}'
+    today = datetime.now().strftime("%Y%m%d")
+    return cache_dir / f".done{today}"


 def _is_cache_valid(cache_dir: Path) -> bool:
@@ -22,7 +22,7 @@ def _is_cache_valid(cache_dir: Path) -> bool:
         return False

     # Find any .done* marker files
-    marker_files = list(cache_dir.glob('.done*'))
+    marker_files = list(cache_dir.glob(".done*"))
     if not marker_files:
         return False

@@ -31,8 +31,8 @@ def _is_cache_valid(cache_dir: Path) -> bool:

     # Extract date from marker filename (format: .doneYYYYMMDD)
     try:
-        date_str = latest_marker.name.replace('.done', '')
-        marker_date = datetime.strptime(date_str, '%Y%m%d')
+        date_str = latest_marker.name.replace(".done", "")
+        marker_date = datetime.strptime(date_str, "%Y%m%d")
         # Check if marker is older than 1 days
         if datetime.now() - marker_date > timedelta(days=1):
             return False
@@ -45,7 +45,7 @@ def _is_cache_valid(cache_dir: Path) -> bool:
 def _create_cache_marker(cache_dir: Path) -> None:
     """Create a cache marker file with current date and clean old markers."""
     # Remove old marker files
-    for old_marker in cache_dir.glob('.done*'):
+    for old_marker in cache_dir.glob(".done*"):
         old_marker.unlink(missing_ok=True)

     # Create new marker file
@@ -68,7 +68,7 @@ def _resolve_model_path(model_name_or_path: str) -> str:
     # Check if we have a valid cached version
     if _is_cache_valid(cache_dir):
         # Return the snapshot path (latest version)
-        snapshots_dir = cache_dir / 'snapshots'
+        snapshots_dir = cache_dir / "snapshots"
         if snapshots_dir.exists():
             snapshot_dirs = [d for d in snapshots_dir.iterdir() if d.is_dir()]
             if snapshot_dirs:
@@ -77,13 +77,13 @@ def _resolve_model_path(model_name_or_path: str) -> str:
             return str(latest_snapshot)

     try:
-        downloaded_path = snapshot_download(repo_id=model_name_or_path, repo_type='model')
+        downloaded_path = snapshot_download(repo_id=model_name_or_path, repo_type="model")
         _create_cache_marker(cache_dir)
         return downloaded_path
     except LocalEntryNotFoundError:
         try:
-            os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
-            downloaded_path = snapshot_download(repo_id=model_name_or_path, repo_type='model')
+            os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
+            downloaded_path = snapshot_download(repo_id=model_name_or_path, repo_type="model")
             _create_cache_marker(cache_dir)
             return downloaded_path
         except Exception as e:  # pragma: no cover - bubble up for caller context
@@ -99,11 +99,11 @@ def _select_device(device: Optional[str]) -> str:

     import torch

-    detected = 'cpu'
+    detected = "cpu"
     if torch.backends.mps.is_available():
-        detected = 'mps'
+        detected = "mps"
     elif torch.cuda.is_available():
-        detected = 'cuda'
+        detected = "cuda"
     return detected

@@ -122,7 +122,7 @@ def _load_tokenizer(
             device=device,
         )
     except Exception as e:
-        raise ModelLoadError(f'tokenizer from {model_path}', original_error=e)
+        raise ModelLoadError(f"tokenizer from {model_path}", original_error=e)


 def _load_worker(model_path: str, device: str) -> Lattice1AlphaWorker:
@@ -130,4 +130,4 @@ def _load_worker(model_path: str, device: str) -> Lattice1AlphaWorker:
     try:
         return Lattice1AlphaWorker(model_path, device=device, num_threads=8)
     except Exception as e:
-        raise ModelLoadError(f'worker from {model_path}', original_error=e)
+        raise ModelLoadError(f"worker from {model_path}", original_error=e)
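The cache helpers above gate Hugging Face snapshot reuse on a dated marker file: a successful snapshot_download writes .doneYYYYMMDD, and the cache is trusted only while that marker is under a day old. A condensed sketch of the scheme (illustrative; cache_is_fresh is not a function in the package):

from datetime import datetime, timedelta
from pathlib import Path

def cache_is_fresh(cache_dir: Path, max_age_days: int = 1) -> bool:
    # Newest ".doneYYYYMMDD" marker decides freshness; lexicographic
    # sort is chronological because the stamp is zero-padded YYYYMMDD.
    markers = sorted(cache_dir.glob(".done*"))
    if not markers:
        return False
    try:
        marked = datetime.strptime(markers[-1].name.replace(".done", ""), "%Y%m%d")
    except ValueError:
        return False
    return datetime.now() - marked <= timedelta(days=max_age_days)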
lattifai/workers/__init__.py
CHANGED