lattifai 1.0.5__py3-none-any.whl → 1.2.0__py3-none-any.whl

@@ -1,11 +1,12 @@
 """LattifAI speaker diarization implementation."""
 
 import logging
-from collections import defaultdict
-from typing import List, Optional, Tuple
+from dataclasses import dataclass
+from typing import Callable, List, Optional, Tuple
 
-import torch
-from tgt import Interval, IntervalTier, TextGrid
+import numpy as np
+from lattifai_core.diarization import DiarizationOutput
+from tgt import TextGrid
 
 from lattifai.audio2 import AudioData
 from lattifai.caption import Supervision
@@ -60,7 +61,7 @@ class LattifAIDiarizer:
         num_speakers: Optional[int] = None,
         min_speakers: Optional[int] = None,
         max_speakers: Optional[int] = None,
-    ) -> TextGrid:
+    ) -> DiarizationOutput:
         """Perform speaker diarization on the input audio."""
         return self.diarizer.diarize(
             input_media,
@@ -73,11 +74,16 @@ class LattifAIDiarizer:
         self,
         input_media: AudioData,
         alignments: List[Supervision],
-        diarization: Optional[TextGrid] = None,
+        diarization: Optional[DiarizationOutput] = None,
         num_speakers: Optional[int] = None,
         min_speakers: Optional[int] = None,
         max_speakers: Optional[int] = None,
-    ) -> Tuple[TextGrid, List[Supervision]]:
+        alignment_fn: Optional[Callable] = None,
+        transcribe_fn: Optional[Callable] = None,
+        separate_fn: Optional[Callable] = None,
+        debug: bool = False,
+        output_path: Optional[str] = None,
+    ) -> Tuple[DiarizationOutput, List[Supervision]]:
         """Diarize the given media input and return alignments with refined speaker labels."""
         return self.diarizer.diarize_with_alignments(
             input_media,
@@ -86,4 +92,9 @@ class LattifAIDiarizer:
             num_speakers=num_speakers,
             min_speakers=min_speakers,
             max_speakers=max_speakers,
+            alignment_fn=alignment_fn,
+            transcribe_fn=transcribe_fn,
+            separate_fn=separate_fn,
+            debug=debug,
+            output_path=output_path,
         )
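
The new keyword hooks let callers inject their own alignment, transcription, and separation callables, while debug and output_path control debug artifacts. A minimal usage sketch; the diarizer construction, the hook signature, and the output path below are assumptions, not part of this diff:

    import numpy as np

    diarizer = LattifAIDiarizer()  # construction details assumed

    def my_transcribe_fn(segment: np.ndarray):
        # Hypothetical hook: return a transcript for one audio segment.
        ...

    diarization, supervisions = diarizer.diarize_with_alignments(
        input_media,                     # an AudioData instance
        alignments,                      # existing List[Supervision]
        num_speakers=2,
        transcribe_fn=my_transcribe_fn,  # new in 1.2.0
        debug=True,                      # new in 1.2.0
        output_path="debug_out",         # new in 1.2.0; expected format assumed
    )
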
lattifai/mixin.py CHANGED
@@ -184,7 +184,9 @@ class LattifAIClientMixin:
         from lattifai.utils import _resolve_model_path
 
         if transcription_config is not None:
-            transcription_config.lattice_model_path = _resolve_model_path(alignment_config.model_name)
+            transcription_config.lattice_model_path = _resolve_model_path(
+                alignment_config.model_name, getattr(alignment_config, "model_hub", "huggingface")
+            )
 
         # Set client_wrapper for all configs
         alignment_config.client_wrapper = self
@@ -380,6 +382,7 @@ class LattifAIClientMixin:
         media_file: Union[str, Path, AudioData],
         source_lang: Optional[str],
         is_async: bool = False,
+        output_dir: Optional[Path] = None,
     ) -> Caption:
         """
         Get captions by downloading or transcribing.
@@ -406,6 +409,9 @@ class LattifAIClientMixin:
         safe_print(colorful.green(" ✓ Transcription completed."))
 
         if "gemini" in self.transcriber.name.lower():
+            safe_print(colorful.yellow("🔍 Gemini raw output:"))
+            safe_print(colorful.yellow(f"{transcription[:1000]}..."))  # Print first 1000 chars
+
             # write to temp file and use Caption read
             # On Windows, we need to close the file before writing to it
             tmp_file = tempfile.NamedTemporaryFile(
@@ -428,6 +434,18 @@ class LattifAIClientMixin:
             # Clean up temp file
             if tmp_path.exists():
                 tmp_path.unlink()
+        else:
+            safe_print(colorful.yellow(f"🔍 {self.transcriber.name} raw output:"))
+            if isinstance(transcription, Caption):
+                safe_print(colorful.yellow(f"Caption with {len(transcription.transcription)} segments"))
+                if transcription.transcription:
+                    safe_print(colorful.yellow(f"First segment: {transcription.transcription[0].text}"))
+
+            if output_dir:
+                # Generate transcript file path
+                transcript_file = output_dir / f"{Path(str(media_file)).stem}_{self.transcriber.file_name}"
+                await asyncio.to_thread(self.transcriber.write, transcription, transcript_file, encoding="utf-8")
+                safe_print(colorful.green(f" ✓ Transcription saved to: {transcript_file}"))
 
         return transcription
 
@@ -473,10 +491,13 @@ class LattifAIClientMixin:
                 safe_print(colorful.green(f"📄 Using provided caption file: {caption_path}"))
                 return str(caption_path)
             else:
-                raise FileNotFoundError(f"Provided caption path does not exist: {caption_path}")
-
-        # Generate transcript file path
-        transcript_file = output_dir / f"{Path(str(media_file)).stem}_{self.transcriber.file_name}"
+                safe_print(colorful.red(f"Provided caption path does not exist: {caption_path}, use transcription"))
+                use_transcription = True
+                transcript_file = caption_path
+                caption_path.parent.mkdir(parents=True, exist_ok=True)
+        else:
+            # Generate transcript file path
+            transcript_file = output_dir / f"{Path(str(media_file)).stem}_{self.transcriber.file_name}"
 
         if use_transcription:
             # Transcription mode: use Transcriber to transcribe
@@ -70,7 +70,7 @@ def create_transcriber(
     raise ValueError(
         f"Cannot determine transcriber for model_name='{transcription_config.model_name}'. "
         f"Supported patterns: \n"
-        f"  - Gemini API models: 'gemini-2.5-pro', 'gemini-3-pro-preview'\n"
+        f"  - Gemini API models: 'gemini-2.5-pro', 'gemini-3-pro-preview', 'gemini-3-flash-preview'\n"
         f"  - Local HF models: 'nvidia/parakeet-*', 'iic/SenseVoiceSmall', etc.\n"
         f"Please specify a valid model_name."
     )
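
With the added pattern, a config naming the new Flash preview model resolves to the Gemini transcriber. A quick sketch (the keyword argument is inferred from the parameter name used in the message above):

    config = TranscriptionConfig(model_name="gemini-3-flash-preview")
    transcriber = create_transcriber(transcription_config=config)
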
@@ -2,10 +2,12 @@
 
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Optional, Union
+from typing import List, Optional, Union
+
+import numpy as np
 
 from lattifai.audio2 import AudioData
-from lattifai.caption import Caption
+from lattifai.caption import Caption, Supervision
 from lattifai.config import TranscriptionConfig
 from lattifai.logging import get_logger
 
@@ -96,6 +98,23 @@ class BaseTranscriber(ABC):
             language: Optional language code for transcription.
         """
 
+    @abstractmethod
+    def transcribe_numpy(
+        self,
+        audio: Union[np.ndarray, List[np.ndarray]],
+        language: Optional[str] = None,
+    ) -> Union[Supervision, List[Supervision]]:
+        """
+        Transcribe audio from a numpy array (or list of arrays) and return Supervision.
+
+        Args:
+            audio: Audio data as numpy array (shape: [samples]),
+                or a list of such arrays for batch processing.
+            language: Optional language code for transcription.
+
+        Returns:
+            Supervision object (or list of Supervision objects) with transcription info.
+        """
+
     @abstractmethod
     def write(self, transcript: Union[str, Caption], output_file: Path, encoding: str = "utf-8") -> Path:
         """
@@ -2,12 +2,14 @@
 
 import asyncio
 from pathlib import Path
-from typing import Optional, Union
+from typing import List, Optional, Union
 
+import numpy as np
 from google import genai
 from google.genai.types import GenerateContentConfig, Part, ThinkingConfig
 
 from lattifai.audio2 import AudioData
+from lattifai.caption import Supervision
 from lattifai.config import TranscriptionConfig
 from lattifai.transcription.base import BaseTranscriber
 from lattifai.transcription.prompts import get_prompt_loader
@@ -118,6 +120,130 @@ class GeminiTranscriber(BaseTranscriber):
         self.logger.error(f"Gemini transcription failed: {str(e)}")
         raise RuntimeError(f"Gemini transcription failed: {str(e)}")
 
+    def transcribe_numpy(
+        self,
+        audio: Union[np.ndarray, List[np.ndarray]],
+        language: Optional[str] = None,
+    ) -> Union[Supervision, List[Supervision]]:
+        """
+        Transcribe audio from a numpy array (or list of arrays) and return Supervision.
+
+        Note: Gemini API does not support word-level alignment. The returned
+        Supervision will contain only the full transcription text without alignment.
+
+        Args:
+            audio: Audio data as numpy array (shape: [samples]),
+                or a list of such arrays for batch processing.
+            language: Optional language code for transcription.
+
+        Returns:
+            Supervision object (or list of Supervision objects) with transcription text (no alignment).
+
+        Raises:
+            ValueError: If API key not provided
+            RuntimeError: If transcription fails
+        """
+        # Handle batch processing
+        if isinstance(audio, list):
+            return [self.transcribe_numpy(arr, language=language) for arr in audio]
+
+        audio_array = audio
+        # Use default sample rate of 16000 Hz
+        sample_rate = 16000
+
+        if self.config.verbose:
+            self.logger.info(f"🎤 Starting Gemini transcription for numpy array (sample_rate={sample_rate})")
+
+        # Ensure audio is in the correct shape
+        if audio_array.ndim == 1:
+            audio_array = audio_array.reshape(1, -1)
+        elif audio_array.ndim > 2:
+            raise ValueError(f"Audio array must be 1D or 2D, got shape {audio_array.shape}")
+
+        # Save numpy array to temporary file
+        import tempfile
+
+        import soundfile as sf
+
+        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
+            # Transpose to (samples, channels) for soundfile
+            sf.write(tmp_file.name, audio_array.T, sample_rate)
+            tmp_path = Path(tmp_file.name)
+
+        try:
+            # Transcribe using simple ASR prompt
+            import asyncio
+
+            transcript = asyncio.run(self._transcribe_with_simple_prompt(tmp_path, language=language))
+
+            # Create Supervision object from transcript
+            duration = audio_array.shape[-1] / sample_rate
+            supervision = Supervision(
+                id="gemini-transcription",
+                recording_id="numpy-array",
+                start=0.0,
+                duration=duration,
+                text=transcript,
+                speaker=None,
+                alignment=None,  # Gemini does not provide word-level alignment
+            )
+
+            return supervision
+
+        finally:
+            # Clean up temporary file
+            if tmp_path.exists():
+                tmp_path.unlink()
+
+    async def _transcribe_with_simple_prompt(self, media_file: Path, language: Optional[str] = None) -> str:
+        """
+        Transcribe audio using a simple ASR prompt instead of complex instructions.
+
+        Args:
+            media_file: Path to audio file
+            language: Optional language code
+
+        Returns:
+            Transcribed text
+        """
+        client = self._get_client()
+
+        # Upload audio file
+        if self.config.verbose:
+            self.logger.info("📤 Uploading audio file to Gemini...")
+        uploaded_file = client.files.upload(file=str(media_file))
+
+        # Simple ASR prompt
+        system_prompt = "Transcribe the audio."
+        if language:
+            system_prompt = f"Transcribe the audio in {language}."
+
+        # Create simple generation config
+        simple_config = GenerateContentConfig(
+            system_instruction=system_prompt,
+            response_modalities=["TEXT"],
+        )
+
+        contents = Part.from_uri(file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type)
+        response = await asyncio.get_event_loop().run_in_executor(
+            None,
+            lambda: client.models.generate_content(
+                model=self.config.model_name,
+                contents=contents,
+                config=simple_config,
+            ),
+        )
+
+        if not response.text:
+            raise RuntimeError("Empty response from Gemini API")
+
+        transcript = response.text.strip()
+
+        if self.config.verbose:
+            self.logger.info(f"✅ Transcription completed: {len(transcript)} characters")
+
+        return transcript
+
     def _get_transcription_prompt(self) -> str:
         """Get (and cache) transcription system prompt from prompts module."""
         if self._system_prompt is not None:
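
A short usage sketch for the new numpy entry point. The constructor call below is assumed; the 16 kHz default, the batch behavior, and the absent alignment come from the code above:

    import numpy as np

    transcriber = GeminiTranscriber(TranscriptionConfig(model_name="gemini-2.5-pro"))  # construction assumed

    clip = np.zeros(16000, dtype=np.float32)           # one second at the assumed 16 kHz default
    sup = transcriber.transcribe_numpy(clip, language="en")
    print(sup.text, sup.duration)                      # sup.alignment is None for Gemini

    sups = transcriber.transcribe_numpy([clip, clip])  # list in, list of Supervisions out
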
@@ -1,10 +1,12 @@
 """Transcription module with config-driven architecture."""
 
 from pathlib import Path
-from typing import Optional, Union
+from typing import List, Optional, Union
+
+import numpy as np
 
 from lattifai.audio2 import AudioData
-from lattifai.caption import Caption
+from lattifai.caption import Caption, Supervision
 from lattifai.config import TranscriptionConfig
 from lattifai.transcription.base import BaseTranscriber
 from lattifai.transcription.prompts import get_prompt_loader  # noqa: F401
@@ -74,6 +76,32 @@ class LattifAITranscriber(BaseTranscriber):
 
         return caption
 
+    def transcribe_numpy(
+        self,
+        audio: Union[np.ndarray, List[np.ndarray]],
+        language: Optional[str] = None,
+    ) -> Union[Supervision, List[Supervision]]:
+        """
+        Transcribe audio from a numpy array (or list of arrays) and return Supervision.
+
+        Args:
+            audio: Audio data as numpy array (shape: [samples]),
+                or a list of such arrays for batch processing.
+            language: Optional language code for transcription.
+
+        Returns:
+            Supervision object (or list of Supervision objects) with transcription and alignment info.
+        """
+        if self._transcriber is None:
+            from lattifai_core.transcription import LattifAITranscriber as CoreLattifAITranscriber
+
+            self._transcriber = CoreLattifAITranscriber.from_pretrained(model_config=self.config)
+
+        # Delegate to core transcriber, which handles both single arrays and lists
+        return self._transcriber.transcribe(
+            audio, language=language, return_hypotheses=True, progress_bar=False, timestamps=True
+        )[0]
+
     def write(
         self, transcript: Caption, output_file: Path, encoding: str = "utf-8", cache_audio_events: bool = True
     ) -> Path:
lattifai/utils.py CHANGED
@@ -1,10 +1,9 @@
 """Shared utility helpers for the LattifAI SDK."""
 
-import os
 import sys
 from datetime import datetime, timedelta
 from pathlib import Path
-from typing import Any, Optional, Type
+from typing import Optional
 
 from lattifai.errors import ModelLoadError
 
@@ -45,85 +44,79 @@ def safe_print(text: str, **kwargs) -> None:
     print(text.encode("utf-8", errors="replace").decode("utf-8"), **kwargs)
 
 
-def _get_cache_marker_path(cache_dir: Path) -> Path:
-    """Get the path for the cache marker file with current date."""
-    today = datetime.now().strftime("%Y%m%d")
-    return cache_dir / f".done{today}"
+def _resolve_model_path(model_name_or_path: str, model_hub: str = "huggingface") -> str:
+    """Resolve model path, downloading from the specified model hub when necessary.
 
+    Args:
+        model_name_or_path: Local path or remote model identifier.
+        model_hub: Which hub to use for downloads. Supported: "huggingface", "modelscope".
+    """
+    if Path(model_name_or_path).expanduser().exists():
+        return str(Path(model_name_or_path).expanduser())
 
-def _is_cache_valid(cache_dir: Path) -> bool:
-    """Check if cached model is valid (exists and not older than 1 days)."""
-    if not cache_dir.exists():
-        return False
-
-    # Find any .done* marker files
-    marker_files = list(cache_dir.glob(".done*"))
-    if not marker_files:
-        return False
+    # Normalize hub name
+    hub = (model_hub or "huggingface").lower()
 
-    # Get the most recent marker file
-    latest_marker = max(marker_files, key=lambda p: p.stat().st_mtime)
+    if hub not in ("huggingface", "modelscope"):
+        raise ValueError(f"Unsupported model_hub: {model_hub}. Supported: 'huggingface', 'modelscope'.")
 
-    # Extract date from marker filename (format: .doneYYYYMMDD)
-    try:
-        date_str = latest_marker.name.replace(".done", "")
-        marker_date = datetime.strptime(date_str, "%Y%m%d")
-        # Check if marker is older than 1 days
-        if datetime.now() - marker_date > timedelta(days=1):
-            return False
-        return True
-    except (ValueError, IndexError):
-        # Invalid marker file format, treat as invalid cache
-        return False
-
-
-def _create_cache_marker(cache_dir: Path) -> None:
-    """Create a cache marker file with current date and clean old markers."""
-    # Remove old marker files
-    for old_marker in cache_dir.glob(".done*"):
-        old_marker.unlink(missing_ok=True)
-
-    # Create new marker file
-    marker_path = _get_cache_marker_path(cache_dir)
-    marker_path.touch()
-
-
-def _resolve_model_path(model_name_or_path: str) -> str:
-    """Resolve model path, downloading from Hugging Face when necessary."""
+    # If local path exists, return it regardless of hub
     if Path(model_name_or_path).expanduser().exists():
         return str(Path(model_name_or_path).expanduser())
 
-    from huggingface_hub import snapshot_download
-    from huggingface_hub.constants import HF_HUB_CACHE
-    from huggingface_hub.errors import LocalEntryNotFoundError
+    if hub == "huggingface":
+        from huggingface_hub import HfApi, snapshot_download
+        from huggingface_hub.errors import LocalEntryNotFoundError
 
-    # Determine cache directory for this model
-    cache_dir = Path(HF_HUB_CACHE) / f'models--{model_name_or_path.replace("/", "--")}'
+        # Support repo_id@revision syntax
+        hf_repo_id = model_name_or_path
+        revision = None
+        if "@" in model_name_or_path:
+            hf_repo_id, revision = model_name_or_path.split("@", 1)
 
-    # Check if we have a valid cached version
-    if _is_cache_valid(cache_dir):
-        # Return the snapshot path (latest version)
-        snapshots_dir = cache_dir / "snapshots"
-        if snapshots_dir.exists():
-            snapshot_dirs = [d for d in snapshots_dir.iterdir() if d.is_dir()]
-            if snapshot_dirs:
-                # Return the most recent snapshot
-                latest_snapshot = max(snapshot_dirs, key=lambda p: p.stat().st_mtime)
-                return str(latest_snapshot)
+        # If no specific revision/commit is provided, try to fetch the real latest SHA
+        # to bypass Hugging Face's model_info (metadata) sync lag.
+        if not revision:
+            try:
+                api = HfApi()
+                refs = api.list_repo_refs(repo_id=hf_repo_id, repo_type="model")
+                # Look for the default branch (usually 'main')
+                for branch in refs.branches:
+                    if branch.name == "main":
+                        revision = branch.target_commit
+                        break
+            except Exception:
+                # Fallback to default behavior if API call fails
+                revision = None
 
-    try:
-        downloaded_path = snapshot_download(repo_id=model_name_or_path, repo_type="model")
-        _create_cache_marker(cache_dir)
-        return downloaded_path
-    except LocalEntryNotFoundError:
         try:
-            os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
-            downloaded_path = snapshot_download(repo_id=model_name_or_path, repo_type="model")
-            _create_cache_marker(cache_dir)
+            downloaded_path = snapshot_download(repo_id=hf_repo_id, repo_type="model", revision=revision)
             return downloaded_path
-        except Exception as e:  # pragma: no cover - bubble up for caller context
-            raise ModelLoadError(model_name_or_path, original_error=e)
-    except Exception as e:  # pragma: no cover - unexpected download issue
+        except LocalEntryNotFoundError:
+            # Fall back to modelscope if HF entry not found
+            try:
+                from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot
+
+                downloaded_path = ms_snapshot(model_name_or_path)
+                return downloaded_path
+            except Exception as e:  # pragma: no cover - bubble up for caller context
+                raise ModelLoadError(model_name_or_path, original_error=e)
+        except Exception as e:  # pragma: no cover - unexpected download issue
+            import colorful
+
+            print(colorful.red | f"Error downloading from Hugging Face Hub: {e}. Trying ModelScope...")
+            from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot
+
+            downloaded_path = ms_snapshot(model_name_or_path)
+            return downloaded_path
+
+    # modelscope path
+    from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot
+
+    try:
+        downloaded_path = ms_snapshot(model_name_or_path)
+        return downloaded_path
+    except Exception as e:  # pragma: no cover
         raise ModelLoadError(model_name_or_path, original_error=e)
 
 
@@ -429,79 +429,77 @@ class YouTubeDownloader:
             result = await loop.run_in_executor(
                 None, lambda: subprocess.run(ytdlp_options, capture_output=True, text=True, check=True)
             )
-
             # Only log success message, not full yt-dlp output
             self.logger.debug(f"yt-dlp output: {result.stdout.strip()}")
-
-            # Find the downloaded transcript file
-            caption_patterns = [
-                f"{video_id}.*vtt",
-                f"{video_id}.*srt",
-                f"{video_id}.*sub",
-                f"{video_id}.*sbv",
-                f"{video_id}.*ssa",
-                f"{video_id}.*ass",
-            ]
-
-            caption_files = []
-            for pattern in caption_patterns:
-                _caption_files = list(target_dir.glob(pattern))
-                for caption_file in _caption_files:
-                    self.logger.info(f"📥 Downloaded caption: {caption_file}")
-                caption_files.extend(_caption_files)
-
-            if not caption_files:
-                self.logger.warning("No caption available for this video")
-                return None
-
-            # If only one caption file, return it directly
-            if len(caption_files) == 1:
-                self.logger.info(f"✅ Using caption: {caption_files[0]}")
-                return str(caption_files[0])
-
-            # Multiple caption files found, let user choose
-            if FileExistenceManager.is_interactive_mode():
-                self.logger.info(f"📋 Found {len(caption_files)} caption files")
-                caption_choice = FileExistenceManager.prompt_file_selection(
-                    file_type="caption",
-                    files=[str(f) for f in caption_files],
-                    operation="use",
-                    transcriber_name=transcriber_name,
-                )
-
-                if caption_choice == "cancel":
-                    raise RuntimeError("Caption selection cancelled by user")
-                elif caption_choice == TRANSCRIBE_CHOICE:
-                    return caption_choice
-                elif caption_choice:
-                    self.logger.info(f"✅ Selected caption: {caption_choice}")
-                    return caption_choice
-                else:
-                    # Fallback to first file
-                    self.logger.info(f"✅ Using first caption: {caption_files[0]}")
-                    return str(caption_files[0])
-            else:
-                # Non-interactive mode: use first file
-                self.logger.info(f"✅ Using first caption: {caption_files[0]}")
-                return str(caption_files[0])
-
         except subprocess.CalledProcessError as e:
             error_msg = e.stderr.strip() if e.stderr else str(e)
 
             # Check for specific error conditions
             if "No automatic or manual captions found" in error_msg:
                 self.logger.warning("No captions available for this video")
-                return None
             elif "HTTP Error 429" in error_msg or "Too Many Requests" in error_msg:
                 self.logger.error("YouTube rate limit exceeded. Please try again later or use a different method.")
-                raise RuntimeError(
+                self.logger.error(
                     "YouTube rate limit exceeded (HTTP 429). "
                     "Try again later or use --cookies option with authenticated cookies. "
                     "See: https://github.com/yt-dlp/yt-dlp/wiki/FAQ#how-do-i-pass-cookies-to-yt-dlp"
                 )
             else:
                 self.logger.error(f"Failed to download transcript: {error_msg}")
-                raise RuntimeError(f"Failed to download transcript: {error_msg}")
+
+        # Find the downloaded transcript file
+        caption_patterns = [
+            f"{video_id}.*vtt",
+            f"{video_id}.*srt",
+            f"{video_id}.*sub",
+            f"{video_id}.*sbv",
+            f"{video_id}.*ssa",
+            f"{video_id}.*ass",
+        ]
+
+        caption_files = []
+        for pattern in caption_patterns:
+            _caption_files = list(target_dir.glob(pattern))
+            for caption_file in _caption_files:
+                self.logger.info(f"📥 Downloaded caption: {caption_file}")
+            caption_files.extend(_caption_files)
+
+        # If only one caption file, return it directly
+        if len(caption_files) == 1:
+            self.logger.info(f"✅ Using caption: {caption_files[0]}")
+            return str(caption_files[0])
+
+        # Multiple caption files found, let user choose
+        if FileExistenceManager.is_interactive_mode():
+            self.logger.info(f"📋 Found {len(caption_files)} caption files")
+            caption_choice = FileExistenceManager.prompt_file_selection(
+                file_type="caption",
+                files=[str(f) for f in caption_files],
+                operation="use",
+                transcriber_name=transcriber_name,
+            )
+
+            if caption_choice == "cancel":
+                raise RuntimeError("Caption selection cancelled by user")
+            elif caption_choice == TRANSCRIBE_CHOICE:
+                return caption_choice
+            elif caption_choice:
+                self.logger.info(f"✅ Selected caption: {caption_choice}")
+                return caption_choice
+            elif caption_files:
+                # Fallback to first file
+                self.logger.info(f"✅ Using first caption: {caption_files[0]}")
+                return str(caption_files[0])
+            else:
+                self.logger.warning("No caption files available after download")
+                return None
+        elif caption_files:
+            # Non-interactive mode: use first file
+            self.logger.info(f"✅ Using first caption: {caption_files[0]}")
+            return str(caption_files[0])
+        else:
+            self.logger.warning("No caption files available after download")
+            return None
 
     async def list_available_captions(self, url: str) -> List[Dict[str, Any]]:
         """