lattifai 1.0.5__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/__init__.py +11 -12
- lattifai/alignment/lattice1_aligner.py +39 -7
- lattifai/alignment/lattice1_worker.py +135 -147
- lattifai/alignment/tokenizer.py +38 -22
- lattifai/audio2.py +1 -1
- lattifai/caption/caption.py +55 -19
- lattifai/cli/__init__.py +2 -0
- lattifai/cli/caption.py +1 -1
- lattifai/cli/diarization.py +110 -0
- lattifai/cli/transcribe.py +3 -1
- lattifai/cli/youtube.py +11 -0
- lattifai/client.py +32 -111
- lattifai/config/alignment.py +14 -0
- lattifai/config/client.py +5 -0
- lattifai/config/transcription.py +4 -0
- lattifai/diarization/lattifai.py +18 -7
- lattifai/mixin.py +26 -5
- lattifai/transcription/__init__.py +1 -1
- lattifai/transcription/base.py +21 -2
- lattifai/transcription/gemini.py +127 -1
- lattifai/transcription/lattifai.py +30 -2
- lattifai/utils.py +62 -69
- lattifai/workflow/youtube.py +55 -57
- {lattifai-1.0.5.dist-info → lattifai-1.2.0.dist-info}/METADATA +352 -56
- {lattifai-1.0.5.dist-info → lattifai-1.2.0.dist-info}/RECORD +29 -28
- {lattifai-1.0.5.dist-info → lattifai-1.2.0.dist-info}/entry_points.txt +2 -0
- {lattifai-1.0.5.dist-info → lattifai-1.2.0.dist-info}/WHEEL +0 -0
- {lattifai-1.0.5.dist-info → lattifai-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {lattifai-1.0.5.dist-info → lattifai-1.2.0.dist-info}/top_level.txt +0 -0
lattifai/diarization/lattifai.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
"""LattifAI speaker diarization implementation."""
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
|
-
from
|
|
5
|
-
from typing import List, Optional, Tuple
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Callable, List, Optional, Tuple
|
|
6
6
|
|
|
7
|
-
import
|
|
8
|
-
from
|
|
7
|
+
import numpy as np
|
|
8
|
+
from lattifai_core.diarization import DiarizationOutput
|
|
9
|
+
from tgt import TextGrid
|
|
9
10
|
|
|
10
11
|
from lattifai.audio2 import AudioData
|
|
11
12
|
from lattifai.caption import Supervision
|
|
@@ -60,7 +61,7 @@ class LattifAIDiarizer:
|
|
|
60
61
|
num_speakers: Optional[int] = None,
|
|
61
62
|
min_speakers: Optional[int] = None,
|
|
62
63
|
max_speakers: Optional[int] = None,
|
|
63
|
-
) ->
|
|
64
|
+
) -> DiarizationOutput:
|
|
64
65
|
"""Perform speaker diarization on the input audio."""
|
|
65
66
|
return self.diarizer.diarize(
|
|
66
67
|
input_media,
|
|
@@ -73,11 +74,16 @@ class LattifAIDiarizer:
|
|
|
73
74
|
self,
|
|
74
75
|
input_media: AudioData,
|
|
75
76
|
alignments: List[Supervision],
|
|
76
|
-
diarization: Optional[
|
|
77
|
+
diarization: Optional[DiarizationOutput] = None,
|
|
77
78
|
num_speakers: Optional[int] = None,
|
|
78
79
|
min_speakers: Optional[int] = None,
|
|
79
80
|
max_speakers: Optional[int] = None,
|
|
80
|
-
|
|
81
|
+
alignment_fn: Optional[Callable] = None,
|
|
82
|
+
transcribe_fn: Optional[Callable] = None,
|
|
83
|
+
separate_fn: Optional[Callable] = None,
|
|
84
|
+
debug: bool = False,
|
|
85
|
+
output_path: Optional[str] = None,
|
|
86
|
+
) -> Tuple[DiarizationOutput, List[Supervision]]:
|
|
81
87
|
"""Diarize the given media input and return alignments with refined speaker labels."""
|
|
82
88
|
return self.diarizer.diarize_with_alignments(
|
|
83
89
|
input_media,
|
|
@@ -86,4 +92,9 @@ class LattifAIDiarizer:
|
|
|
86
92
|
num_speakers=num_speakers,
|
|
87
93
|
min_speakers=min_speakers,
|
|
88
94
|
max_speakers=max_speakers,
|
|
95
|
+
alignment_fn=alignment_fn,
|
|
96
|
+
transcribe_fn=transcribe_fn,
|
|
97
|
+
separate_fn=separate_fn,
|
|
98
|
+
debug=debug,
|
|
99
|
+
output_path=output_path,
|
|
89
100
|
)
|
lattifai/mixin.py
CHANGED
|
@@ -184,7 +184,9 @@ class LattifAIClientMixin:
|
|
|
184
184
|
from lattifai.utils import _resolve_model_path
|
|
185
185
|
|
|
186
186
|
if transcription_config is not None:
|
|
187
|
-
transcription_config.lattice_model_path = _resolve_model_path(
|
|
187
|
+
transcription_config.lattice_model_path = _resolve_model_path(
|
|
188
|
+
alignment_config.model_name, getattr(alignment_config, "model_hub", "huggingface")
|
|
189
|
+
)
|
|
188
190
|
|
|
189
191
|
# Set client_wrapper for all configs
|
|
190
192
|
alignment_config.client_wrapper = self
|
|
@@ -380,6 +382,7 @@ class LattifAIClientMixin:
|
|
|
380
382
|
media_file: Union[str, Path, AudioData],
|
|
381
383
|
source_lang: Optional[str],
|
|
382
384
|
is_async: bool = False,
|
|
385
|
+
output_dir: Optional[Path] = None,
|
|
383
386
|
) -> Caption:
|
|
384
387
|
"""
|
|
385
388
|
Get captions by downloading or transcribing.
|
|
@@ -406,6 +409,9 @@ class LattifAIClientMixin:
|
|
|
406
409
|
safe_print(colorful.green(" ✓ Transcription completed."))
|
|
407
410
|
|
|
408
411
|
if "gemini" in self.transcriber.name.lower():
|
|
412
|
+
safe_print(colorful.yellow("🔍 Gemini raw output:"))
|
|
413
|
+
safe_print(colorful.yellow(f"{transcription[:1000]}...")) # Print first 1000 chars
|
|
414
|
+
|
|
409
415
|
# write to temp file and use Caption read
|
|
410
416
|
# On Windows, we need to close the file before writing to it
|
|
411
417
|
tmp_file = tempfile.NamedTemporaryFile(
|
|
@@ -428,6 +434,18 @@ class LattifAIClientMixin:
|
|
|
428
434
|
# Clean up temp file
|
|
429
435
|
if tmp_path.exists():
|
|
430
436
|
tmp_path.unlink()
|
|
437
|
+
else:
|
|
438
|
+
safe_print(colorful.yellow(f"🔍 {self.transcriber.name} raw output:"))
|
|
439
|
+
if isinstance(transcription, Caption):
|
|
440
|
+
safe_print(colorful.yellow(f"Caption with {len(transcription.transcription)} segments"))
|
|
441
|
+
if transcription.transcription:
|
|
442
|
+
safe_print(colorful.yellow(f"First segment: {transcription.transcription[0].text}"))
|
|
443
|
+
|
|
444
|
+
if output_dir:
|
|
445
|
+
# Generate transcript file path
|
|
446
|
+
transcript_file = output_dir / f"{Path(str(media_file)).stem}_{self.transcriber.file_name}"
|
|
447
|
+
await asyncio.to_thread(self.transcriber.write, transcription, transcript_file, encoding="utf-8")
|
|
448
|
+
safe_print(colorful.green(f" ✓ Transcription saved to: {transcript_file}"))
|
|
431
449
|
|
|
432
450
|
return transcription
|
|
433
451
|
|
|
@@ -473,10 +491,13 @@ class LattifAIClientMixin:
|
|
|
473
491
|
safe_print(colorful.green(f"📄 Using provided caption file: {caption_path}"))
|
|
474
492
|
return str(caption_path)
|
|
475
493
|
else:
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
494
|
+
safe_print(colorful.red(f"Provided caption path does not exist: {caption_path}, use transcription"))
|
|
495
|
+
use_transcription = True
|
|
496
|
+
transcript_file = caption_path
|
|
497
|
+
caption_path.parent.mkdir(parents=True, exist_ok=True)
|
|
498
|
+
else:
|
|
499
|
+
# Generate transcript file path
|
|
500
|
+
transcript_file = output_dir / f"{Path(str(media_file)).stem}_{self.transcriber.file_name}"
|
|
480
501
|
|
|
481
502
|
if use_transcription:
|
|
482
503
|
# Transcription mode: use Transcriber to transcribe
|
|
@@ -70,7 +70,7 @@ def create_transcriber(
|
|
|
70
70
|
raise ValueError(
|
|
71
71
|
f"Cannot determine transcriber for model_name='{transcription_config.model_name}'. "
|
|
72
72
|
f"Supported patterns: \n"
|
|
73
|
-
f" - Gemini API models: 'gemini-2.5-pro', 'gemini-3-pro-preview'\n"
|
|
73
|
+
f" - Gemini API models: 'gemini-2.5-pro', 'gemini-3-pro-preview', 'gemini-3-flash-preview'\n"
|
|
74
74
|
f" - Local HF models: 'nvidia/parakeet-*', 'iic/SenseVoiceSmall', etc.\n"
|
|
75
75
|
f"Please specify a valid model_name."
|
|
76
76
|
)
|
lattifai/transcription/base.py
CHANGED
|
@@ -2,10 +2,12 @@
|
|
|
2
2
|
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from typing import Optional, Union
|
|
5
|
+
from typing import List, Optional, Union
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
6
8
|
|
|
7
9
|
from lattifai.audio2 import AudioData
|
|
8
|
-
from lattifai.caption import Caption
|
|
10
|
+
from lattifai.caption import Caption, Supervision
|
|
9
11
|
from lattifai.config import TranscriptionConfig
|
|
10
12
|
from lattifai.logging import get_logger
|
|
11
13
|
|
|
@@ -96,6 +98,23 @@ class BaseTranscriber(ABC):
|
|
|
96
98
|
language: Optional language code for transcription.
|
|
97
99
|
"""
|
|
98
100
|
|
|
101
|
+
@abstractmethod
|
|
102
|
+
def transcribe_numpy(
|
|
103
|
+
self,
|
|
104
|
+
audio: Union[np.ndarray, List[np.ndarray]],
|
|
105
|
+
language: Optional[str] = None,
|
|
106
|
+
) -> Union[Supervision, List[Supervision]]:
|
|
107
|
+
"""
|
|
108
|
+
Transcribe audio from a numpy array and return Supervision.
|
|
109
|
+
|
|
110
|
+
Args:
|
|
111
|
+
audio_array: Audio data as numpy array (shape: [samples]).
|
|
112
|
+
language: Optional language code for transcription.
|
|
113
|
+
|
|
114
|
+
Returns:
|
|
115
|
+
Supervision object with transcription info.
|
|
116
|
+
"""
|
|
117
|
+
|
|
99
118
|
@abstractmethod
|
|
100
119
|
def write(self, transcript: Union[str, Caption], output_file: Path, encoding: str = "utf-8") -> Path:
|
|
101
120
|
"""
|
lattifai/transcription/gemini.py
CHANGED
|
@@ -2,12 +2,14 @@
|
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from typing import Optional, Union
|
|
5
|
+
from typing import List, Optional, Union
|
|
6
6
|
|
|
7
|
+
import numpy as np
|
|
7
8
|
from google import genai
|
|
8
9
|
from google.genai.types import GenerateContentConfig, Part, ThinkingConfig
|
|
9
10
|
|
|
10
11
|
from lattifai.audio2 import AudioData
|
|
12
|
+
from lattifai.caption import Supervision
|
|
11
13
|
from lattifai.config import TranscriptionConfig
|
|
12
14
|
from lattifai.transcription.base import BaseTranscriber
|
|
13
15
|
from lattifai.transcription.prompts import get_prompt_loader
|
|
@@ -118,6 +120,130 @@ class GeminiTranscriber(BaseTranscriber):
|
|
|
118
120
|
self.logger.error(f"Gemini transcription failed: {str(e)}")
|
|
119
121
|
raise RuntimeError(f"Gemini transcription failed: {str(e)}")
|
|
120
122
|
|
|
123
|
+
def transcribe_numpy(
|
|
124
|
+
self,
|
|
125
|
+
audio: Union[np.ndarray, List[np.ndarray]],
|
|
126
|
+
language: Optional[str] = None,
|
|
127
|
+
) -> Union[Supervision, List[Supervision]]:
|
|
128
|
+
"""
|
|
129
|
+
Transcribe audio from a numpy array (or list of arrays) and return Supervision.
|
|
130
|
+
|
|
131
|
+
Note: Gemini API does not support word-level alignment. The returned
|
|
132
|
+
Supervision will contain only the full transcription text without alignment.
|
|
133
|
+
|
|
134
|
+
Args:
|
|
135
|
+
audio: Audio data as numpy array (shape: [samples]),
|
|
136
|
+
or a list of such arrays for batch processing.
|
|
137
|
+
language: Optional language code for transcription.
|
|
138
|
+
|
|
139
|
+
Returns:
|
|
140
|
+
Supervision object (or list of Supervision objects) with transcription text (no alignment).
|
|
141
|
+
|
|
142
|
+
Raises:
|
|
143
|
+
ValueError: If API key not provided
|
|
144
|
+
RuntimeError: If transcription fails
|
|
145
|
+
"""
|
|
146
|
+
# Handle batch processing
|
|
147
|
+
if isinstance(audio, list):
|
|
148
|
+
return [self.transcribe_numpy(arr, language=language) for arr in audio]
|
|
149
|
+
|
|
150
|
+
audio_array = audio
|
|
151
|
+
# Use default sample rate of 16000 Hz
|
|
152
|
+
sample_rate = 16000
|
|
153
|
+
|
|
154
|
+
if self.config.verbose:
|
|
155
|
+
self.logger.info(f"🎤 Starting Gemini transcription for numpy array (sample_rate={sample_rate})")
|
|
156
|
+
|
|
157
|
+
# Ensure audio is in the correct shape
|
|
158
|
+
if audio_array.ndim == 1:
|
|
159
|
+
audio_array = audio_array.reshape(1, -1)
|
|
160
|
+
elif audio_array.ndim > 2:
|
|
161
|
+
raise ValueError(f"Audio array must be 1D or 2D, got shape {audio_array.shape}")
|
|
162
|
+
|
|
163
|
+
# Save numpy array to temporary file
|
|
164
|
+
import tempfile
|
|
165
|
+
|
|
166
|
+
import soundfile as sf
|
|
167
|
+
|
|
168
|
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
|
|
169
|
+
# Transpose to (samples, channels) for soundfile
|
|
170
|
+
sf.write(tmp_file.name, audio_array.T, sample_rate)
|
|
171
|
+
tmp_path = Path(tmp_file.name)
|
|
172
|
+
|
|
173
|
+
try:
|
|
174
|
+
# Transcribe using simple ASR prompt
|
|
175
|
+
import asyncio
|
|
176
|
+
|
|
177
|
+
transcript = asyncio.run(self._transcribe_with_simple_prompt(tmp_path, language=language))
|
|
178
|
+
|
|
179
|
+
# Create Supervision object from transcript
|
|
180
|
+
duration = audio_array.shape[-1] / sample_rate
|
|
181
|
+
supervision = Supervision(
|
|
182
|
+
id="gemini-transcription",
|
|
183
|
+
recording_id="numpy-array",
|
|
184
|
+
start=0.0,
|
|
185
|
+
duration=duration,
|
|
186
|
+
text=transcript,
|
|
187
|
+
speaker=None,
|
|
188
|
+
alignment=None, # Gemini does not provide word-level alignment
|
|
189
|
+
)
|
|
190
|
+
|
|
191
|
+
return supervision
|
|
192
|
+
|
|
193
|
+
finally:
|
|
194
|
+
# Clean up temporary file
|
|
195
|
+
if tmp_path.exists():
|
|
196
|
+
tmp_path.unlink()
|
|
197
|
+
|
|
198
|
+
async def _transcribe_with_simple_prompt(self, media_file: Path, language: Optional[str] = None) -> str:
|
|
199
|
+
"""
|
|
200
|
+
Transcribe audio using a simple ASR prompt instead of complex instructions.
|
|
201
|
+
|
|
202
|
+
Args:
|
|
203
|
+
media_file: Path to audio file
|
|
204
|
+
language: Optional language code
|
|
205
|
+
|
|
206
|
+
Returns:
|
|
207
|
+
Transcribed text
|
|
208
|
+
"""
|
|
209
|
+
client = self._get_client()
|
|
210
|
+
|
|
211
|
+
# Upload audio file
|
|
212
|
+
if self.config.verbose:
|
|
213
|
+
self.logger.info("📤 Uploading audio file to Gemini...")
|
|
214
|
+
uploaded_file = client.files.upload(file=str(media_file))
|
|
215
|
+
|
|
216
|
+
# Simple ASR prompt
|
|
217
|
+
system_prompt = "Transcribe the audio."
|
|
218
|
+
if language:
|
|
219
|
+
system_prompt = f"Transcribe the audio in {language}."
|
|
220
|
+
|
|
221
|
+
# Create simple generation config
|
|
222
|
+
simple_config = GenerateContentConfig(
|
|
223
|
+
system_instruction=system_prompt,
|
|
224
|
+
response_modalities=["TEXT"],
|
|
225
|
+
)
|
|
226
|
+
|
|
227
|
+
contents = Part.from_uri(file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type)
|
|
228
|
+
response = await asyncio.get_event_loop().run_in_executor(
|
|
229
|
+
None,
|
|
230
|
+
lambda: client.models.generate_content(
|
|
231
|
+
model=self.config.model_name,
|
|
232
|
+
contents=contents,
|
|
233
|
+
config=simple_config,
|
|
234
|
+
),
|
|
235
|
+
)
|
|
236
|
+
|
|
237
|
+
if not response.text:
|
|
238
|
+
raise RuntimeError("Empty response from Gemini API")
|
|
239
|
+
|
|
240
|
+
transcript = response.text.strip()
|
|
241
|
+
|
|
242
|
+
if self.config.verbose:
|
|
243
|
+
self.logger.info(f"✅ Transcription completed: {len(transcript)} characters")
|
|
244
|
+
|
|
245
|
+
return transcript
|
|
246
|
+
|
|
121
247
|
def _get_transcription_prompt(self) -> str:
|
|
122
248
|
"""Get (and cache) transcription system prompt from prompts module."""
|
|
123
249
|
if self._system_prompt is not None:
|
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
"""Transcription module with config-driven architecture."""
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Optional, Union
|
|
4
|
+
from typing import List, Optional, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
5
7
|
|
|
6
8
|
from lattifai.audio2 import AudioData
|
|
7
|
-
from lattifai.caption import Caption
|
|
9
|
+
from lattifai.caption import Caption, Supervision
|
|
8
10
|
from lattifai.config import TranscriptionConfig
|
|
9
11
|
from lattifai.transcription.base import BaseTranscriber
|
|
10
12
|
from lattifai.transcription.prompts import get_prompt_loader # noqa: F401
|
|
@@ -74,6 +76,32 @@ class LattifAITranscriber(BaseTranscriber):
|
|
|
74
76
|
|
|
75
77
|
return caption
|
|
76
78
|
|
|
79
|
+
def transcribe_numpy(
|
|
80
|
+
self,
|
|
81
|
+
audio: Union[np.ndarray, List[np.ndarray]],
|
|
82
|
+
language: Optional[str] = None,
|
|
83
|
+
) -> Union[Supervision, List[Supervision]]:
|
|
84
|
+
"""
|
|
85
|
+
Transcribe audio from a numpy array (or list of arrays) and return Supervision.
|
|
86
|
+
|
|
87
|
+
Args:
|
|
88
|
+
audio: Audio data as numpy array (shape: [samples]),
|
|
89
|
+
or a list of such arrays for batch processing.
|
|
90
|
+
language: Optional language code for transcription.
|
|
91
|
+
|
|
92
|
+
Returns:
|
|
93
|
+
Supervision object (or list of Supervision objects) with transcription and alignment info.
|
|
94
|
+
"""
|
|
95
|
+
if self._transcriber is None:
|
|
96
|
+
from lattifai_core.transcription import LattifAITranscriber as CoreLattifAITranscriber
|
|
97
|
+
|
|
98
|
+
self._transcriber = CoreLattifAITranscriber.from_pretrained(model_config=self.config)
|
|
99
|
+
|
|
100
|
+
# Delegate to core transcriber which handles both single arrays and lists
|
|
101
|
+
return self._transcriber.transcribe(
|
|
102
|
+
audio, language=language, return_hypotheses=True, progress_bar=False, timestamps=True
|
|
103
|
+
)[0]
|
|
104
|
+
|
|
77
105
|
def write(
|
|
78
106
|
self, transcript: Caption, output_file: Path, encoding: str = "utf-8", cache_audio_events: bool = True
|
|
79
107
|
) -> Path:
|
lattifai/utils.py
CHANGED
|
@@ -1,10 +1,9 @@
|
|
|
1
1
|
"""Shared utility helpers for the LattifAI SDK."""
|
|
2
2
|
|
|
3
|
-
import os
|
|
4
3
|
import sys
|
|
5
4
|
from datetime import datetime, timedelta
|
|
6
5
|
from pathlib import Path
|
|
7
|
-
from typing import
|
|
6
|
+
from typing import Optional
|
|
8
7
|
|
|
9
8
|
from lattifai.errors import ModelLoadError
|
|
10
9
|
|
|
@@ -45,85 +44,79 @@ def safe_print(text: str, **kwargs) -> None:
|
|
|
45
44
|
print(text.encode("utf-8", errors="replace").decode("utf-8"), **kwargs)
|
|
46
45
|
|
|
47
46
|
|
|
48
|
-
def
|
|
49
|
-
"""
|
|
50
|
-
today = datetime.now().strftime("%Y%m%d")
|
|
51
|
-
return cache_dir / f".done{today}"
|
|
47
|
+
def _resolve_model_path(model_name_or_path: str, model_hub: str = "huggingface") -> str:
|
|
48
|
+
"""Resolve model path, downloading from the specified model hub when necessary.
|
|
52
49
|
|
|
50
|
+
Args:
|
|
51
|
+
model_name_or_path: Local path or remote model identifier.
|
|
52
|
+
model_hub: Which hub to use for downloads. Supported: "huggingface", "modelscope".
|
|
53
|
+
"""
|
|
54
|
+
if Path(model_name_or_path).expanduser().exists():
|
|
55
|
+
return str(Path(model_name_or_path).expanduser())
|
|
53
56
|
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
if not cache_dir.exists():
|
|
57
|
-
return False
|
|
58
|
-
|
|
59
|
-
# Find any .done* marker files
|
|
60
|
-
marker_files = list(cache_dir.glob(".done*"))
|
|
61
|
-
if not marker_files:
|
|
62
|
-
return False
|
|
57
|
+
# Normalize hub name
|
|
58
|
+
hub = (model_hub or "huggingface").lower()
|
|
63
59
|
|
|
64
|
-
|
|
65
|
-
|
|
60
|
+
if hub not in ("huggingface", "modelscope"):
|
|
61
|
+
raise ValueError(f"Unsupported model_hub: {model_hub}. Supported: 'huggingface', 'modelscope'.")
|
|
66
62
|
|
|
67
|
-
#
|
|
68
|
-
try:
|
|
69
|
-
date_str = latest_marker.name.replace(".done", "")
|
|
70
|
-
marker_date = datetime.strptime(date_str, "%Y%m%d")
|
|
71
|
-
# Check if marker is older than 1 days
|
|
72
|
-
if datetime.now() - marker_date > timedelta(days=1):
|
|
73
|
-
return False
|
|
74
|
-
return True
|
|
75
|
-
except (ValueError, IndexError):
|
|
76
|
-
# Invalid marker file format, treat as invalid cache
|
|
77
|
-
return False
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
def _create_cache_marker(cache_dir: Path) -> None:
|
|
81
|
-
"""Create a cache marker file with current date and clean old markers."""
|
|
82
|
-
# Remove old marker files
|
|
83
|
-
for old_marker in cache_dir.glob(".done*"):
|
|
84
|
-
old_marker.unlink(missing_ok=True)
|
|
85
|
-
|
|
86
|
-
# Create new marker file
|
|
87
|
-
marker_path = _get_cache_marker_path(cache_dir)
|
|
88
|
-
marker_path.touch()
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
def _resolve_model_path(model_name_or_path: str) -> str:
|
|
92
|
-
"""Resolve model path, downloading from Hugging Face when necessary."""
|
|
63
|
+
# If local path exists, return it regardless of hub
|
|
93
64
|
if Path(model_name_or_path).expanduser().exists():
|
|
94
65
|
return str(Path(model_name_or_path).expanduser())
|
|
95
66
|
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
67
|
+
if hub == "huggingface":
|
|
68
|
+
from huggingface_hub import HfApi, snapshot_download
|
|
69
|
+
from huggingface_hub.errors import LocalEntryNotFoundError
|
|
99
70
|
|
|
100
|
-
|
|
101
|
-
|
|
71
|
+
# Support repo_id@revision syntax
|
|
72
|
+
hf_repo_id = model_name_or_path
|
|
73
|
+
revision = None
|
|
74
|
+
if "@" in model_name_or_path:
|
|
75
|
+
hf_repo_id, revision = model_name_or_path.split("@", 1)
|
|
102
76
|
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
77
|
+
# If no specific revision/commit is provided, try to fetch the real latest SHA
|
|
78
|
+
# to bypass Hugging Face's model_info (metadata) sync lag.
|
|
79
|
+
if not revision:
|
|
80
|
+
try:
|
|
81
|
+
api = HfApi()
|
|
82
|
+
refs = api.list_repo_refs(repo_id=hf_repo_id, repo_type="model")
|
|
83
|
+
# Look for the default branch (usually 'main')
|
|
84
|
+
for branch in refs.branches:
|
|
85
|
+
if branch.name == "main":
|
|
86
|
+
revision = branch.target_commit
|
|
87
|
+
break
|
|
88
|
+
except Exception:
|
|
89
|
+
# Fallback to default behavior if API call fails
|
|
90
|
+
revision = None
|
|
113
91
|
|
|
114
|
-
try:
|
|
115
|
-
downloaded_path = snapshot_download(repo_id=model_name_or_path, repo_type="model")
|
|
116
|
-
_create_cache_marker(cache_dir)
|
|
117
|
-
return downloaded_path
|
|
118
|
-
except LocalEntryNotFoundError:
|
|
119
92
|
try:
|
|
120
|
-
|
|
121
|
-
downloaded_path = snapshot_download(repo_id=model_name_or_path, repo_type="model")
|
|
122
|
-
_create_cache_marker(cache_dir)
|
|
93
|
+
downloaded_path = snapshot_download(repo_id=hf_repo_id, repo_type="model", revision=revision)
|
|
123
94
|
return downloaded_path
|
|
124
|
-
except
|
|
125
|
-
|
|
126
|
-
|
|
95
|
+
except LocalEntryNotFoundError:
|
|
96
|
+
# Fall back to modelscope if HF entry not found
|
|
97
|
+
try:
|
|
98
|
+
from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot
|
|
99
|
+
|
|
100
|
+
downloaded_path = ms_snapshot(model_name_or_path)
|
|
101
|
+
return downloaded_path
|
|
102
|
+
except Exception as e: # pragma: no cover - bubble up for caller context
|
|
103
|
+
raise ModelLoadError(model_name_or_path, original_error=e)
|
|
104
|
+
except Exception as e: # pragma: no cover - unexpected download issue
|
|
105
|
+
import colorful
|
|
106
|
+
|
|
107
|
+
print(colorful.red | f"Error downloading from Hugging Face Hub: {e}. Trying ModelScope...")
|
|
108
|
+
from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot
|
|
109
|
+
|
|
110
|
+
downloaded_path = ms_snapshot(model_name_or_path)
|
|
111
|
+
return downloaded_path
|
|
112
|
+
|
|
113
|
+
# modelscope path
|
|
114
|
+
from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot
|
|
115
|
+
|
|
116
|
+
try:
|
|
117
|
+
downloaded_path = ms_snapshot(model_name_or_path)
|
|
118
|
+
return downloaded_path
|
|
119
|
+
except Exception as e: # pragma: no cover
|
|
127
120
|
raise ModelLoadError(model_name_or_path, original_error=e)
|
|
128
121
|
|
|
129
122
|
|
lattifai/workflow/youtube.py
CHANGED
|
@@ -429,79 +429,77 @@ class YouTubeDownloader:
|
|
|
429
429
|
result = await loop.run_in_executor(
|
|
430
430
|
None, lambda: subprocess.run(ytdlp_options, capture_output=True, text=True, check=True)
|
|
431
431
|
)
|
|
432
|
-
|
|
433
432
|
# Only log success message, not full yt-dlp output
|
|
434
433
|
self.logger.debug(f"yt-dlp output: {result.stdout.strip()}")
|
|
435
|
-
|
|
436
|
-
# Find the downloaded transcript file
|
|
437
|
-
caption_patterns = [
|
|
438
|
-
f"{video_id}.*vtt",
|
|
439
|
-
f"{video_id}.*srt",
|
|
440
|
-
f"{video_id}.*sub",
|
|
441
|
-
f"{video_id}.*sbv",
|
|
442
|
-
f"{video_id}.*ssa",
|
|
443
|
-
f"{video_id}.*ass",
|
|
444
|
-
]
|
|
445
|
-
|
|
446
|
-
caption_files = []
|
|
447
|
-
for pattern in caption_patterns:
|
|
448
|
-
_caption_files = list(target_dir.glob(pattern))
|
|
449
|
-
for caption_file in _caption_files:
|
|
450
|
-
self.logger.info(f"📥 Downloaded caption: {caption_file}")
|
|
451
|
-
caption_files.extend(_caption_files)
|
|
452
|
-
|
|
453
|
-
if not caption_files:
|
|
454
|
-
self.logger.warning("No caption available for this video")
|
|
455
|
-
return None
|
|
456
|
-
|
|
457
|
-
# If only one caption file, return it directly
|
|
458
|
-
if len(caption_files) == 1:
|
|
459
|
-
self.logger.info(f"✅ Using caption: {caption_files[0]}")
|
|
460
|
-
return str(caption_files[0])
|
|
461
|
-
|
|
462
|
-
# Multiple caption files found, let user choose
|
|
463
|
-
if FileExistenceManager.is_interactive_mode():
|
|
464
|
-
self.logger.info(f"📋 Found {len(caption_files)} caption files")
|
|
465
|
-
caption_choice = FileExistenceManager.prompt_file_selection(
|
|
466
|
-
file_type="caption",
|
|
467
|
-
files=[str(f) for f in caption_files],
|
|
468
|
-
operation="use",
|
|
469
|
-
transcriber_name=transcriber_name,
|
|
470
|
-
)
|
|
471
|
-
|
|
472
|
-
if caption_choice == "cancel":
|
|
473
|
-
raise RuntimeError("Caption selection cancelled by user")
|
|
474
|
-
elif caption_choice == TRANSCRIBE_CHOICE:
|
|
475
|
-
return caption_choice
|
|
476
|
-
elif caption_choice:
|
|
477
|
-
self.logger.info(f"✅ Selected caption: {caption_choice}")
|
|
478
|
-
return caption_choice
|
|
479
|
-
else:
|
|
480
|
-
# Fallback to first file
|
|
481
|
-
self.logger.info(f"✅ Using first caption: {caption_files[0]}")
|
|
482
|
-
return str(caption_files[0])
|
|
483
|
-
else:
|
|
484
|
-
# Non-interactive mode: use first file
|
|
485
|
-
self.logger.info(f"✅ Using first caption: {caption_files[0]}")
|
|
486
|
-
return str(caption_files[0])
|
|
487
|
-
|
|
488
434
|
except subprocess.CalledProcessError as e:
|
|
489
435
|
error_msg = e.stderr.strip() if e.stderr else str(e)
|
|
490
436
|
|
|
491
437
|
# Check for specific error conditions
|
|
492
438
|
if "No automatic or manual captions found" in error_msg:
|
|
493
439
|
self.logger.warning("No captions available for this video")
|
|
494
|
-
return None
|
|
495
440
|
elif "HTTP Error 429" in error_msg or "Too Many Requests" in error_msg:
|
|
496
441
|
self.logger.error("YouTube rate limit exceeded. Please try again later or use a different method.")
|
|
497
|
-
|
|
442
|
+
self.logger.error(
|
|
498
443
|
"YouTube rate limit exceeded (HTTP 429). "
|
|
499
444
|
"Try again later or use --cookies option with authenticated cookies. "
|
|
500
445
|
"See: https://github.com/yt-dlp/yt-dlp/wiki/FAQ#how-do-i-pass-cookies-to-yt-dlp"
|
|
501
446
|
)
|
|
502
447
|
else:
|
|
503
448
|
self.logger.error(f"Failed to download transcript: {error_msg}")
|
|
504
|
-
|
|
449
|
+
|
|
450
|
+
# Find the downloaded transcript file
|
|
451
|
+
caption_patterns = [
|
|
452
|
+
f"{video_id}.*vtt",
|
|
453
|
+
f"{video_id}.*srt",
|
|
454
|
+
f"{video_id}.*sub",
|
|
455
|
+
f"{video_id}.*sbv",
|
|
456
|
+
f"{video_id}.*ssa",
|
|
457
|
+
f"{video_id}.*ass",
|
|
458
|
+
]
|
|
459
|
+
|
|
460
|
+
caption_files = []
|
|
461
|
+
for pattern in caption_patterns:
|
|
462
|
+
_caption_files = list(target_dir.glob(pattern))
|
|
463
|
+
for caption_file in _caption_files:
|
|
464
|
+
self.logger.info(f"📥 Downloaded caption: {caption_file}")
|
|
465
|
+
caption_files.extend(_caption_files)
|
|
466
|
+
|
|
467
|
+
# If only one caption file, return it directly
|
|
468
|
+
if len(caption_files) == 1:
|
|
469
|
+
self.logger.info(f"✅ Using caption: {caption_files[0]}")
|
|
470
|
+
return str(caption_files[0])
|
|
471
|
+
|
|
472
|
+
# Multiple caption files found, let user choose
|
|
473
|
+
if FileExistenceManager.is_interactive_mode():
|
|
474
|
+
self.logger.info(f"📋 Found {len(caption_files)} caption files")
|
|
475
|
+
caption_choice = FileExistenceManager.prompt_file_selection(
|
|
476
|
+
file_type="caption",
|
|
477
|
+
files=[str(f) for f in caption_files],
|
|
478
|
+
operation="use",
|
|
479
|
+
transcriber_name=transcriber_name,
|
|
480
|
+
)
|
|
481
|
+
|
|
482
|
+
if caption_choice == "cancel":
|
|
483
|
+
raise RuntimeError("Caption selection cancelled by user")
|
|
484
|
+
elif caption_choice == TRANSCRIBE_CHOICE:
|
|
485
|
+
return caption_choice
|
|
486
|
+
elif caption_choice:
|
|
487
|
+
self.logger.info(f"✅ Selected caption: {caption_choice}")
|
|
488
|
+
return caption_choice
|
|
489
|
+
elif caption_files:
|
|
490
|
+
# Fallback to first file
|
|
491
|
+
self.logger.info(f"✅ Using first caption: {caption_files[0]}")
|
|
492
|
+
return str(caption_files[0])
|
|
493
|
+
else:
|
|
494
|
+
self.logger.warning("No caption files available after download")
|
|
495
|
+
return None
|
|
496
|
+
elif caption_files:
|
|
497
|
+
# Non-interactive mode: use first file
|
|
498
|
+
self.logger.info(f"✅ Using first caption: {caption_files[0]}")
|
|
499
|
+
return str(caption_files[0])
|
|
500
|
+
else:
|
|
501
|
+
self.logger.warning("No caption files available after download")
|
|
502
|
+
return None
|
|
505
503
|
|
|
506
504
|
async def list_available_captions(self, url: str) -> List[Dict[str, Any]]:
|
|
507
505
|
"""
|