mkv-episode-matcher 0.3.3__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. mkv_episode_matcher/__init__.py +8 -0
  2. mkv_episode_matcher/__main__.py +2 -177
  3. mkv_episode_matcher/asr_models.py +506 -0
  4. mkv_episode_matcher/cli.py +558 -0
  5. mkv_episode_matcher/core/config_manager.py +100 -0
  6. mkv_episode_matcher/core/engine.py +577 -0
  7. mkv_episode_matcher/core/matcher.py +214 -0
  8. mkv_episode_matcher/core/models.py +91 -0
  9. mkv_episode_matcher/core/providers/asr.py +85 -0
  10. mkv_episode_matcher/core/providers/subtitles.py +341 -0
  11. mkv_episode_matcher/core/utils.py +148 -0
  12. mkv_episode_matcher/episode_identification.py +550 -118
  13. mkv_episode_matcher/subtitle_utils.py +82 -0
  14. mkv_episode_matcher/tmdb_client.py +56 -14
  15. mkv_episode_matcher/ui/flet_app.py +708 -0
  16. mkv_episode_matcher/utils.py +262 -139
  17. mkv_episode_matcher-1.0.0.dist-info/METADATA +242 -0
  18. mkv_episode_matcher-1.0.0.dist-info/RECORD +23 -0
  19. {mkv_episode_matcher-0.3.3.dist-info → mkv_episode_matcher-1.0.0.dist-info}/WHEEL +1 -1
  20. mkv_episode_matcher-1.0.0.dist-info/licenses/LICENSE +21 -0
  21. mkv_episode_matcher/config.py +0 -82
  22. mkv_episode_matcher/episode_matcher.py +0 -100
  23. mkv_episode_matcher/libraries/pgs2srt/.gitignore +0 -2
  24. mkv_episode_matcher/libraries/pgs2srt/Libraries/SubZero/SubZero.py +0 -321
  25. mkv_episode_matcher/libraries/pgs2srt/Libraries/SubZero/dictionaries/data.py +0 -16700
  26. mkv_episode_matcher/libraries/pgs2srt/Libraries/SubZero/post_processing.py +0 -260
  27. mkv_episode_matcher/libraries/pgs2srt/README.md +0 -26
  28. mkv_episode_matcher/libraries/pgs2srt/__init__.py +0 -0
  29. mkv_episode_matcher/libraries/pgs2srt/imagemaker.py +0 -89
  30. mkv_episode_matcher/libraries/pgs2srt/pgs2srt.py +0 -150
  31. mkv_episode_matcher/libraries/pgs2srt/pgsreader.py +0 -225
  32. mkv_episode_matcher/libraries/pgs2srt/requirements.txt +0 -4
  33. mkv_episode_matcher/mkv_to_srt.py +0 -302
  34. mkv_episode_matcher/speech_to_text.py +0 -90
  35. mkv_episode_matcher-0.3.3.dist-info/METADATA +0 -125
  36. mkv_episode_matcher-0.3.3.dist-info/RECORD +0 -25
  37. {mkv_episode_matcher-0.3.3.dist-info → mkv_episode_matcher-1.0.0.dist-info}/entry_points.txt +0 -0
  38. {mkv_episode_matcher-0.3.3.dist-info → mkv_episode_matcher-1.0.0.dist-info}/top_level.txt +0 -0
@@ -1 +1,9 @@
1
+ """MKV Episode Matcher package."""
1
2
 
3
+ from importlib.metadata import PackageNotFoundError, version
4
+
5
+ try:
6
+ __version__ = version("mkv-episode-matcher")
7
+ except PackageNotFoundError:
8
+ # package is not installed
9
+ __version__ = "unknown"
@@ -1,184 +1,9 @@
1
- # __main__.py
2
- import argparse
3
- import os
1
+ from mkv_episode_matcher.cli import app
4
2
 
5
- from loguru import logger
6
3
 
7
- from mkv_episode_matcher.config import get_config, set_config
8
-
9
- # Log the start of the application
10
- logger.info("Starting the application")
11
-
12
-
13
- # Check if the configuration directory exists, if not create it
14
- if not os.path.exists(os.path.join(os.path.expanduser("~"), ".mkv-episode-matcher")):
15
- os.makedirs(os.path.join(os.path.expanduser("~"), ".mkv-episode-matcher"))
16
-
17
- # Define the paths for the configuration file and cache directory
18
- CONFIG_FILE = os.path.join(
19
- os.path.expanduser("~"), ".mkv-episode-matcher", "config.ini"
20
- )
21
- CACHE_DIR = os.path.join(os.path.expanduser("~"), ".mkv-episode-matcher", "cache")
22
-
23
- # Check if the cache directory exists, if not create it
24
- if not os.path.exists(CACHE_DIR):
25
- os.makedirs(CACHE_DIR)
26
-
27
- # Check if logs directory exists, if not create it
28
- log_dir = os.path.join(os.path.expanduser("~"), ".mkv-episode-matcher", "logs")
29
- if not os.path.exists(log_dir):
30
- os.mkdir(log_dir)
31
-
32
- # Add a new handler for stdout logs
33
- logger.add(
34
- os.path.join(log_dir, "stdout.log"),
35
- format="{time} {level} {message}",
36
- level="DEBUG",
37
- rotation="10 MB",
38
- )
39
-
40
- # Add a new handler for error logs
41
- logger.add(os.path.join(log_dir, "stderr.log"), level="ERROR", rotation="10 MB")
42
-
43
-
44
- @logger.catch
45
4
  def main():
46
- """
47
- Entry point of the application.
48
-
49
- This function is responsible for starting the application, parsing command-line arguments,
50
- setting the configuration, and processing the show.
51
-
52
- Command-line arguments:
53
- --tmdb-api-key: The API key for the TMDb API. If not provided, the function will try to get it from the cache or prompt the user to input it.
54
- --show-dir: The main directory of the show. If not provided, the function will prompt the user to input it.
55
- --season: The season number to be processed. If not provided, all seasons will be processed.
56
- --dry-run: A boolean flag indicating whether to perform a dry run (i.e., not rename any files). If not provided, the function will rename files.
57
- --get-subs: A boolean flag indicating whether to download subtitles for the show. If not provided, the function will not download subtitles.
58
- --tesseract-path: The path to the tesseract executable. If not provided, the function will try to get it from the cache or prompt the user to input it.
59
-
60
- The function logs its progress to two separate log files: one for standard output and one for errors.
61
- """
62
-
63
- # Parse command-line arguments
64
- parser = argparse.ArgumentParser(description="Process shows with TMDb API")
65
- parser.add_argument("--tmdb-api-key", help="TMDb API key")
66
- parser.add_argument("--show-dir", help="Main directory of the show")
67
- parser.add_argument(
68
- "--season",
69
- type=int,
70
- default=None,
71
- nargs="?",
72
- help="Specify the season number to be processed (default: None)",
73
- )
74
- parser.add_argument(
75
- "--dry-run",
76
- type=bool,
77
- default=None,
78
- nargs="?",
79
- help="Don't rename any files (default: None)",
80
- )
81
- parser.add_argument(
82
- "--get-subs",
83
- type=bool,
84
- default=None,
85
- nargs="?",
86
- help="Download subtitles for the show (default: None)",
87
- )
88
- parser.add_argument(
89
- "--tesseract-path",
90
- type=str,
91
- default=None,
92
- nargs="?",
93
- help="Path to the tesseract executable (default: None)",
94
- )
95
- args = parser.parse_args()
96
- logger.debug(f"Command-line arguments: {args}")
97
- open_subtitles_api_key = ""
98
- open_subtitles_user_agent = ""
99
- open_subtitles_username = ""
100
- open_subtitles_password = ""
101
- # Check if API key is provided via command-line argument
102
- tmdb_api_key = args.tmdb_api_key
103
-
104
- # If API key is not provided, try to get it from the cache
105
- if not tmdb_api_key:
106
- cached_config = get_config(CONFIG_FILE)
107
- if cached_config:
108
- tmdb_api_key = cached_config.get("tmdb_api_key")
109
-
110
- # If API key is still not available, prompt the user to input it
111
- if not tmdb_api_key:
112
- tmdb_api_key = input("Enter your TMDb API key: ")
113
- # Cache the API key
114
-
115
- logger.debug(f"TMDb API Key: {tmdb_api_key}")
116
- logger.debug("Getting OpenSubtitles API key")
117
- cached_config = get_config(CONFIG_FILE)
118
- try:
119
- open_subtitles_api_key = cached_config.get("open_subtitles_api_key")
120
- open_subtitles_user_agent = cached_config.get("open_subtitles_user_agent")
121
- open_subtitles_username = cached_config.get("open_subtitles_username")
122
- open_subtitles_password = cached_config.get("open_subtitles_password")
123
- except:
124
- pass
125
- if args.get_subs:
126
- if not open_subtitles_api_key:
127
- open_subtitles_api_key = input("Enter your OpenSubtitles API key: ")
128
-
129
- if not open_subtitles_user_agent:
130
- open_subtitles_user_agent = input("Enter your OpenSubtitles User Agent: ")
131
-
132
- if not open_subtitles_username:
133
- open_subtitles_username = input("Enter your OpenSubtitles Username: ")
134
-
135
- if not open_subtitles_password:
136
- open_subtitles_password = input("Enter your OpenSubtitles Password: ")
137
-
138
- # If show directory is provided via command-line argument, use it
139
- show_dir = args.show_dir
140
- if not show_dir:
141
- show_dir = cached_config.get("show_dir")
142
- if not show_dir:
143
- # If show directory is not provided, prompt the user to input it
144
- show_dir = input("Enter the main directory of the show:")
145
- logger.info(f"Show Directory: {show_dir}")
146
- # if the user does not provide a show directory, make the default show directory the current working directory
147
- if not show_dir:
148
- show_dir = os.getcwd()
149
- if not args.tesseract_path:
150
- tesseract_path = cached_config.get("tesseract_path")
151
-
152
- if not tesseract_path:
153
- tesseract_path = input(
154
- r"Enter the path to the tesseract executable: ['C:\Program Files\Tesseract-OCR\tesseract.exe']"
155
- )
156
-
157
- else:
158
- tesseract_path = args.tesseract_path
159
- logger.debug(f"Teesseract Path: {tesseract_path}")
160
- logger.debug(f"Show Directory: {show_dir}")
161
-
162
- # Set the configuration
163
- set_config(
164
- tmdb_api_key,
165
- open_subtitles_api_key,
166
- open_subtitles_user_agent,
167
- open_subtitles_username,
168
- open_subtitles_password,
169
- show_dir,
170
- CONFIG_FILE,
171
- tesseract_path=tesseract_path,
172
- )
173
- logger.info("Configuration set")
174
-
175
- # Process the show
176
- from mkv_episode_matcher.episode_matcher import process_show
177
-
178
- process_show(args.season, dry_run=args.dry_run, get_subs=args.get_subs)
179
- logger.info("Show processing completed")
5
+ app()
180
6
 
181
7
 
182
- # Run the main function if the script is run directly
183
8
  if __name__ == "__main__":
184
9
  main()
@@ -0,0 +1,506 @@
1
+ """
2
+ ASR Model Abstraction Layer
3
+
4
+ This module provides a unified interface for different Automatic Speech Recognition models,
5
+ including OpenAI Whisper and NVIDIA Parakeet models.
6
+ """
7
+
8
+ import abc
9
+ import os
10
+ import re
11
+ import tempfile
12
+ from pathlib import Path
13
+
14
+ import librosa
15
+ import numpy as np
16
+ import soundfile as sf
17
+ import torch
18
+ from loguru import logger
19
+ from rapidfuzz import fuzz
20
+
21
+ # Cache for loaded models to avoid reloading
22
+ _model_cache = {}
23
+
24
+
25
+ class ASRModel(abc.ABC):
26
+ """Abstract base class for ASR models."""
27
+
28
+ def __init__(self, model_name: str, device: str | None = None):
29
+ """
30
+ Initialize ASR model.
31
+
32
+ Args:
33
+ model_name: Name/identifier of the model
34
+ device: Device to run on ('cpu', 'cuda', or None for auto-detect)
35
+ """
36
+ self.model_name = model_name
37
+ self.device = device or self._get_default_device()
38
+ self._model = None
39
+
40
+ def _get_default_device(self) -> str:
41
+ """Get default device for this model type."""
42
+ return "cuda" if torch.cuda.is_available() else "cpu"
43
+
44
+ @abc.abstractmethod
45
+ def load(self):
46
+ """Load the model. Should be called before transcription."""
47
+ pass
48
+
49
+ @abc.abstractmethod
50
+ def transcribe(self, audio_path: str | Path) -> dict:
51
+ """
52
+ Transcribe audio file.
53
+
54
+ Args:
55
+ audio_path: Path to audio file
56
+
57
+ Returns:
58
+ Dictionary with at least 'text' key containing transcription
59
+ """
60
+ pass
61
+
62
+ def calculate_match_score(self, transcription: str, reference: str) -> float:
63
+ """
64
+ Calculate similarity score between transcription and reference.
65
+
66
+ Args:
67
+ transcription: Transcribed text
68
+ reference: Reference subtitle text
69
+
70
+ Returns:
71
+ Float score between 0.0 and 1.0
72
+ """
73
+ # Default implementation: Standard weights
74
+ # Token sort ratio (70%) + Partial ratio (30%)
75
+ token_weight = 0.7
76
+ partial_weight = 0.3
77
+
78
+ score = (
79
+ fuzz.token_sort_ratio(transcription, reference) * token_weight
80
+ + fuzz.partial_ratio(transcription, reference) * partial_weight
81
+ ) / 100.0
82
+
83
+ return score
84
+
85
+ @property
86
+ def is_loaded(self) -> bool:
87
+ """Check if model is loaded."""
88
+ return self._model is not None
89
+
90
+ def unload(self):
91
+ """Unload model to free memory."""
92
+ self._model = None
93
+
94
+
95
+ class ParakeetTDTModel(ASRModel):
96
+ """
97
+ NVIDIA Parakeet TDT ASR model implementation.
98
+
99
+ WARNING: This model (TDT) uses the Transducer decoder which requires significant GPU resources
100
+ and may be unstable on some Windows configurations (CUDA errors).
101
+ """
102
+
103
+ def __init__(
104
+ self, model_name: str = "nvidia/parakeet-tdt-0.6b-v2", device: str | None = None
105
+ ):
106
+ """
107
+ Initialize Parakeet TDT model.
108
+
109
+ Args:
110
+ model_name: Parakeet model identifier from HuggingFace
111
+ device: Device to run on
112
+ """
113
+ super().__init__(model_name, device)
114
+
115
+ def load(self):
116
+ """Load Parakeet model with caching."""
117
+ if self.is_loaded:
118
+ return
119
+
120
+ cache_key = f"parakeet_tdt_{self.model_name}_{self.device}"
121
+
122
+ if cache_key in _model_cache:
123
+ self._model = _model_cache[cache_key]
124
+ logger.debug(
125
+ f"Using cached Parakeet TDT model: {self.model_name} on {self.device}"
126
+ )
127
+ return
128
+
129
+ try:
130
+ # Windows compatibility: Patch signal module before importing NeMo
131
+ if os.name == "nt": # Windows
132
+ import signal
133
+
134
+ if not hasattr(signal, "SIGKILL"):
135
+ # Add missing signal constants for Windows compatibility
136
+ signal.SIGKILL = 9
137
+ signal.SIGTERM = 15
138
+
139
+ import nemo.collections.asr as nemo_asr
140
+
141
+ # Store original environment variables for restoration
142
+ original_env = {}
143
+
144
+ # Configure environment to suppress NeMo warnings and optimize performance
145
+ nemo_env_settings = {
146
+ "NEMO_DISABLE_TRAINING_LOGS": "1",
147
+ "NEMO_DISABLE_HYDRA_LOGS": "1",
148
+ "HYDRA_FULL_ERROR": "0",
149
+ "PYTHONWARNINGS": "ignore::UserWarning",
150
+ "TOKENIZERS_PARALLELISM": "false", # Avoid tokenizer warnings
151
+ }
152
+
153
+ # Windows compatibility: Add optimizations but avoid signal issues
154
+ if os.name == "nt": # Windows
155
+ nemo_env_settings.update({
156
+ "OMP_NUM_THREADS": "1",
157
+ "MKL_NUM_THREADS": "1",
158
+ "NEMO_BYPASS_SIGNALS": "1", # Bypass NeMo signal handling on Windows
159
+ })
160
+
161
+ for key, value in nemo_env_settings.items():
162
+ original_env[key] = os.environ.get(key)
163
+ os.environ[key] = value
164
+
165
+ try:
166
+ # Set device for NeMo
167
+ if self.device == "cuda" and torch.cuda.is_available():
168
+ # NeMo will automatically use CUDA if available
169
+ pass
170
+ elif self.device == "cpu":
171
+ # Force CPU usage - NeMo respects CUDA_VISIBLE_DEVICES=""
172
+ original_env["CUDA_VISIBLE_DEVICES"] = os.environ.get(
173
+ "CUDA_VISIBLE_DEVICES"
174
+ )
175
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
176
+
177
+ # Load model with reduced verbosity
178
+ self._model = nemo_asr.models.ASRModel.from_pretrained(
179
+ model_name=self.model_name,
180
+ strict=False, # Allow loading with missing keys to reduce warnings
181
+ )
182
+
183
+ # Configure model for optimal inference
184
+ if hasattr(self._model, "set_batch_size"):
185
+ self._model.set_batch_size(1) # Optimize for single file processing
186
+
187
+ # Fix for Windows: Force num_workers to 0 to avoid multiprocessing errors/locks
188
+ if hasattr(self._model, "cfg"):
189
+ for ds_config in ["test_ds", "validation_ds"]:
190
+ if ds_config in self._model.cfg:
191
+ self._model.cfg[ds_config].num_workers = 0
192
+
193
+ if hasattr(self._model, "eval"):
194
+ self._model.eval() # Set to evaluation mode
195
+
196
+ finally:
197
+ # Restore original environment variables
198
+ for key, original_value in original_env.items():
199
+ if original_value is not None:
200
+ os.environ[key] = original_value
201
+ elif key in os.environ:
202
+ del os.environ[key]
203
+
204
+ _model_cache[cache_key] = self._model
205
+ logger.info(
206
+ f"Loaded Parakeet TDT model: {self.model_name} on {self.device}"
207
+ )
208
+
209
+ except ImportError as e:
210
+ raise ImportError(
211
+ "NVIDIA NeMo not installed. Run: pip install nemo_toolkit[asr]"
212
+ ) from e
213
+ except Exception as e:
214
+ logger.error(f"Failed to load Parakeet TDT model {self.model_name}: {e}")
215
+ raise
216
+
217
+ def _preprocess_audio(self, audio_path: str | Path) -> str:
218
+ """
219
+ Preprocess audio for Parakeet model requirements.
220
+
221
+ Args:
222
+ audio_path: Path to input audio file
223
+
224
+ Returns:
225
+ Path to preprocessed audio file
226
+ """
227
+ try:
228
+ # Load audio with librosa
229
+ audio, original_sr = librosa.load(str(audio_path), sr=None)
230
+
231
+ # Target sample rate for Parakeet models (16kHz is optimal)
232
+ target_sr = 16000
233
+
234
+ # Resample if necessary
235
+ if original_sr != target_sr:
236
+ audio = librosa.resample(
237
+ audio, orig_sr=original_sr, target_sr=target_sr
238
+ )
239
+ logger.debug(f"Resampled audio from {original_sr}Hz to {target_sr}Hz")
240
+
241
+ # Normalize audio to [-1, 1] range
242
+ if np.max(np.abs(audio)) > 0:
243
+ audio = audio / np.max(np.abs(audio))
244
+
245
+ # Create temporary file for preprocessed audio
246
+ temp_dir = Path(tempfile.gettempdir()) / "parakeet_preprocessed"
247
+ temp_dir.mkdir(exist_ok=True)
248
+
249
+ temp_audio_path = temp_dir / f"preprocessed_{Path(audio_path).stem}.wav"
250
+
251
+ # Save preprocessed audio
252
+ sf.write(str(temp_audio_path), audio, target_sr)
253
+
254
+ logger.debug(f"Preprocessed audio saved to {temp_audio_path}")
255
+ return str(temp_audio_path)
256
+
257
+ except Exception as e:
258
+ logger.warning(f"Audio preprocessing failed, using original: {e}")
259
+ return str(audio_path)
260
+
261
+ def _clean_transcription_text(self, text: str) -> str:
262
+ """
263
+ Clean and normalize transcription text using EXACT same method as EpisodeMatcher.
264
+
265
+ This ensures compatibility with the existing matching algorithm.
266
+
267
+ Args:
268
+ text: Raw transcription text
269
+
270
+ Returns:
271
+ Cleaned text using identical cleaning as EpisodeMatcher.clean_text()
272
+ """
273
+ if not text:
274
+ return ""
275
+
276
+ # Use EXACT same cleaning logic as EpisodeMatcher.clean_text()
277
+ text = text.lower().strip()
278
+ text = re.sub(r"\[.*?\]|\<.*?\>", "", text)
279
+ text = re.sub(r"([A-Za-z])-\1+", r"\1", text)
280
+ return " ".join(text.split())
281
+
282
+ def calculate_match_score(self, transcription: str, reference: str) -> float:
283
+ """
284
+ Calculate similarity score with Parakeet-specific weights.
285
+ Parakeet produces longer, more detailed transcriptions, so we favor partial matches.
286
+ """
287
+ # Parakeet weights: Boost partial_ratio
288
+ token_weight = 0.4
289
+ partial_weight = 0.6
290
+
291
+ # Additional boost for very detailed transcriptions
292
+ length_ratio = len(transcription) / max(len(reference), 1)
293
+ if length_ratio > 2.0: # Much longer transcription
294
+ partial_weight = 0.8
295
+ token_weight = 0.2
296
+
297
+ score = (
298
+ fuzz.token_sort_ratio(transcription, reference) * token_weight
299
+ + fuzz.partial_ratio(transcription, reference) * partial_weight
300
+ ) / 100.0
301
+
302
+ return score
303
+
304
+ def transcribe(self, audio_path: str | Path) -> dict:
305
+ """
306
+ Transcribe audio using Parakeet with preprocessing and text normalization.
307
+
308
+ Args:
309
+ audio_path: Path to audio file
310
+
311
+ Returns:
312
+ Dictionary with 'text' and 'segments' from Parakeet
313
+ """
314
+ if not self.is_loaded:
315
+ self.load()
316
+
317
+ preprocessed_audio = None
318
+ try:
319
+ logger.debug(f"Starting Parakeet transcription for {audio_path}")
320
+
321
+ # Preprocess audio for optimal Parakeet performance
322
+ preprocessed_audio = self._preprocess_audio(audio_path)
323
+
324
+ # Configure NeMo model settings to reduce warnings
325
+ old_env_vars = {}
326
+ try:
327
+ # Set environment variables to reduce NeMo warnings
328
+ env_settings = {
329
+ "CUDA_LAUNCH_BLOCKING": "0",
330
+ "NEMO_DISABLE_TRAINING_LOGS": "1",
331
+ }
332
+
333
+ for key, value in env_settings.items():
334
+ old_env_vars[key] = os.environ.get(key)
335
+ os.environ[key] = value
336
+
337
+ # Parakeet expects list of file paths
338
+ result = self._model.transcribe([preprocessed_audio])
339
+
340
+ finally:
341
+ # Restore original environment variables
342
+ for key, old_value in old_env_vars.items():
343
+ if old_value is not None:
344
+ os.environ[key] = old_value
345
+ elif key in os.environ:
346
+ del os.environ[key]
347
+
348
+ logger.debug(f"Parakeet raw result: {result}, type: {type(result)}")
349
+
350
+ # Extract text from result
351
+ raw_text = ""
352
+ if isinstance(result, list) and len(result) > 0:
353
+ if hasattr(result[0], "text"):
354
+ raw_text = result[0].text
355
+ elif isinstance(result[0], str):
356
+ raw_text = result[0]
357
+ else:
358
+ raw_text = str(result[0])
359
+ else:
360
+ logger.warning(f"Unexpected Parakeet result format: {result}")
361
+ raw_text = ""
362
+
363
+ # Clean and normalize the transcription
364
+ cleaned_text = self._clean_transcription_text(raw_text)
365
+
366
+ logger.debug(f"Raw transcription: '{raw_text}'")
367
+ logger.debug(f"Cleaned transcription: '{cleaned_text}'")
368
+
369
+ return {
370
+ "text": cleaned_text,
371
+ "raw_text": raw_text,
372
+ "segments": [],
373
+ "language": "en",
374
+ }
375
+
376
+ except Exception as e:
377
+ logger.error(
378
+ f"Parakeet transcription failed for {audio_path}: {type(e).__name__}: {e}"
379
+ )
380
+ import traceback
381
+
382
+ traceback.print_exc()
383
+ # Return empty result instead of raising to allow fallback
384
+ return {"text": "", "raw_text": "", "segments": [], "language": "en"}
385
+ finally:
386
+ # Clean up preprocessed audio file
387
+ if preprocessed_audio and preprocessed_audio != str(audio_path):
388
+ try:
389
+ Path(preprocessed_audio).unlink(missing_ok=True)
390
+ except Exception as e:
391
+ logger.debug(f"Failed to clean up preprocessed audio: {e}")
392
+
393
+
394
+ class ParakeetCTCModel(ParakeetTDTModel):
395
+ """
396
+ NVIDIA Parakeet CTC ASR model implementation.
397
+
398
+ This uses the CTC decoder which is more stable and robust on various hardware
399
+ than the TDT version, though potentially slightly less accurate.
400
+ """
401
+
402
+ def __init__(
403
+ self, model_name: str = "nvidia/parakeet-ctc-0.6b", device: str | None = None
404
+ ):
405
+ """
406
+ Initialize Parakeet CTC model.
407
+
408
+ Args:
409
+ model_name: Parakeet model identifier (default: nvidia/parakeet-ctc-0.6b)
410
+ device: Device to run on
411
+ """
412
+ # Ensure we use a CTC-compatible model name if not specified
413
+ # But we trust the user input if provided.
414
+ super().__init__(model_name, device)
415
+
416
+ def load(self):
417
+ """Load Parakeet CTC model with caching."""
418
+ # We override load simply to use a different cache key if needed, or we can just reuse parent load
419
+ # reusing parent load is fine as it uses self.model_name in cache key.
420
+ # But we need to ensure the logging says CTC.
421
+ super().load()
422
+
423
+
424
+ def create_asr_model(model_config: dict) -> ASRModel:
425
+ """
426
+ Factory function to create ASR models from configuration.
427
+
428
+ Args:
429
+ model_config: Dictionary with 'type' and 'name' keys
430
+
431
+ Returns:
432
+ Configured ASRModel instance
433
+
434
+ Example:
435
+ model_config = {"type": "parakeet", "name": "nvidia/parakeet-ctc-0.6b"}
436
+ model = create_asr_model(model_config)
437
+ """
438
+ model_type = model_config.get("type", "").lower()
439
+ model_name = model_config.get("name", "")
440
+ device = model_config.get("device")
441
+
442
+ if model_type == "parakeet":
443
+ # Always use the specific working model
444
+ if not model_name:
445
+ model_name = "nvidia/parakeet-ctc-0.6b"
446
+ return ParakeetCTCModel(model_name, device)
447
+ else:
448
+ raise ValueError(
449
+ f"Unsupported model type: {model_type}. Only 'parakeet' is supported."
450
+ )
451
+
452
+
453
+ def get_cached_model(model_config: dict) -> ASRModel:
454
+ """
455
+ Get a cached model instance, creating it if necessary.
456
+
457
+ Args:
458
+ model_config: Dictionary with model configuration
459
+
460
+ Returns:
461
+ ASRModel instance (loaded and ready for use)
462
+ """
463
+ cache_key = f"{model_config.get('type', '')}_{model_config.get('name', '')}_{model_config.get('device', 'auto')}"
464
+
465
+ if cache_key not in _model_cache:
466
+ model = create_asr_model(model_config)
467
+ model.load() # Load immediately for caching
468
+ _model_cache[cache_key] = model
469
+
470
+ return _model_cache[cache_key]
471
+
472
+
473
+ def clear_model_cache():
474
+ """Clear all cached models to free memory."""
475
+ global _model_cache
476
+ for model in _model_cache.values():
477
+ if hasattr(model, "unload"):
478
+ model.unload()
479
+ _model_cache.clear()
480
+ logger.info("Cleared ASR model cache")
481
+
482
+
483
+ def list_available_models() -> dict:
484
+ """
485
+ List available model types and their requirements.
486
+
487
+ Returns:
488
+ Dictionary with model types and their availability status
489
+ """
490
+ availability = {}
491
+
492
+ # Check Parakeet availability
493
+ try:
494
+ import nemo.collections.asr # noqa: F401
495
+
496
+ availability["parakeet"] = {
497
+ "available": True,
498
+ "models": ["nvidia/parakeet-ctc-0.6b"],
499
+ }
500
+ except ImportError:
501
+ availability["parakeet"] = {
502
+ "available": False,
503
+ "error": "NVIDIA NeMo not installed",
504
+ }
505
+
506
+ return availability