mkv-episode-matcher 0.3.3__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mkv_episode_matcher/__init__.py +8 -0
- mkv_episode_matcher/__main__.py +2 -177
- mkv_episode_matcher/asr_models.py +506 -0
- mkv_episode_matcher/cli.py +558 -0
- mkv_episode_matcher/core/config_manager.py +100 -0
- mkv_episode_matcher/core/engine.py +577 -0
- mkv_episode_matcher/core/matcher.py +214 -0
- mkv_episode_matcher/core/models.py +91 -0
- mkv_episode_matcher/core/providers/asr.py +85 -0
- mkv_episode_matcher/core/providers/subtitles.py +341 -0
- mkv_episode_matcher/core/utils.py +148 -0
- mkv_episode_matcher/episode_identification.py +550 -118
- mkv_episode_matcher/subtitle_utils.py +82 -0
- mkv_episode_matcher/tmdb_client.py +56 -14
- mkv_episode_matcher/ui/flet_app.py +708 -0
- mkv_episode_matcher/utils.py +262 -139
- mkv_episode_matcher-1.0.0.dist-info/METADATA +242 -0
- mkv_episode_matcher-1.0.0.dist-info/RECORD +23 -0
- {mkv_episode_matcher-0.3.3.dist-info → mkv_episode_matcher-1.0.0.dist-info}/WHEEL +1 -1
- mkv_episode_matcher-1.0.0.dist-info/licenses/LICENSE +21 -0
- mkv_episode_matcher/config.py +0 -82
- mkv_episode_matcher/episode_matcher.py +0 -100
- mkv_episode_matcher/libraries/pgs2srt/.gitignore +0 -2
- mkv_episode_matcher/libraries/pgs2srt/Libraries/SubZero/SubZero.py +0 -321
- mkv_episode_matcher/libraries/pgs2srt/Libraries/SubZero/dictionaries/data.py +0 -16700
- mkv_episode_matcher/libraries/pgs2srt/Libraries/SubZero/post_processing.py +0 -260
- mkv_episode_matcher/libraries/pgs2srt/README.md +0 -26
- mkv_episode_matcher/libraries/pgs2srt/__init__.py +0 -0
- mkv_episode_matcher/libraries/pgs2srt/imagemaker.py +0 -89
- mkv_episode_matcher/libraries/pgs2srt/pgs2srt.py +0 -150
- mkv_episode_matcher/libraries/pgs2srt/pgsreader.py +0 -225
- mkv_episode_matcher/libraries/pgs2srt/requirements.txt +0 -4
- mkv_episode_matcher/mkv_to_srt.py +0 -302
- mkv_episode_matcher/speech_to_text.py +0 -90
- mkv_episode_matcher-0.3.3.dist-info/METADATA +0 -125
- mkv_episode_matcher-0.3.3.dist-info/RECORD +0 -25
- {mkv_episode_matcher-0.3.3.dist-info → mkv_episode_matcher-1.0.0.dist-info}/entry_points.txt +0 -0
- {mkv_episode_matcher-0.3.3.dist-info → mkv_episode_matcher-1.0.0.dist-info}/top_level.txt +0 -0
mkv_episode_matcher/__init__.py
CHANGED
mkv_episode_matcher/__main__.py
CHANGED
@@ -1,184 +1,9 @@
-
-import argparse
-import os
+from mkv_episode_matcher.cli import app

-from loguru import logger

-from mkv_episode_matcher.config import get_config, set_config
-
-# Log the start of the application
-logger.info("Starting the application")
-
-
-# Check if the configuration directory exists, if not create it
-if not os.path.exists(os.path.join(os.path.expanduser("~"), ".mkv-episode-matcher")):
-    os.makedirs(os.path.join(os.path.expanduser("~"), ".mkv-episode-matcher"))
-
-# Define the paths for the configuration file and cache directory
-CONFIG_FILE = os.path.join(
-    os.path.expanduser("~"), ".mkv-episode-matcher", "config.ini"
-)
-CACHE_DIR = os.path.join(os.path.expanduser("~"), ".mkv-episode-matcher", "cache")
-
-# Check if the cache directory exists, if not create it
-if not os.path.exists(CACHE_DIR):
-    os.makedirs(CACHE_DIR)
-
-# Check if logs directory exists, if not create it
-log_dir = os.path.join(os.path.expanduser("~"), ".mkv-episode-matcher", "logs")
-if not os.path.exists(log_dir):
-    os.mkdir(log_dir)
-
-# Add a new handler for stdout logs
-logger.add(
-    os.path.join(log_dir, "stdout.log"),
-    format="{time} {level} {message}",
-    level="DEBUG",
-    rotation="10 MB",
-)
-
-# Add a new handler for error logs
-logger.add(os.path.join(log_dir, "stderr.log"), level="ERROR", rotation="10 MB")
-
-
-@logger.catch
 def main():
-    """
-    Entry point of the application.
-
-    This function is responsible for starting the application, parsing command-line arguments,
-    setting the configuration, and processing the show.
-
-    Command-line arguments:
-    --tmdb-api-key: The API key for the TMDb API. If not provided, the function will try to get it from the cache or prompt the user to input it.
-    --show-dir: The main directory of the show. If not provided, the function will prompt the user to input it.
-    --season: The season number to be processed. If not provided, all seasons will be processed.
-    --dry-run: A boolean flag indicating whether to perform a dry run (i.e., not rename any files). If not provided, the function will rename files.
-    --get-subs: A boolean flag indicating whether to download subtitles for the show. If not provided, the function will not download subtitles.
-    --tesseract-path: The path to the tesseract executable. If not provided, the function will try to get it from the cache or prompt the user to input it.
-
-    The function logs its progress to two separate log files: one for standard output and one for errors.
-    """
-
-    # Parse command-line arguments
-    parser = argparse.ArgumentParser(description="Process shows with TMDb API")
-    parser.add_argument("--tmdb-api-key", help="TMDb API key")
-    parser.add_argument("--show-dir", help="Main directory of the show")
-    parser.add_argument(
-        "--season",
-        type=int,
-        default=None,
-        nargs="?",
-        help="Specify the season number to be processed (default: None)",
-    )
-    parser.add_argument(
-        "--dry-run",
-        type=bool,
-        default=None,
-        nargs="?",
-        help="Don't rename any files (default: None)",
-    )
-    parser.add_argument(
-        "--get-subs",
-        type=bool,
-        default=None,
-        nargs="?",
-        help="Download subtitles for the show (default: None)",
-    )
-    parser.add_argument(
-        "--tesseract-path",
-        type=str,
-        default=None,
-        nargs="?",
-        help="Path to the tesseract executable (default: None)",
-    )
-    args = parser.parse_args()
-    logger.debug(f"Command-line arguments: {args}")
-    open_subtitles_api_key = ""
-    open_subtitles_user_agent = ""
-    open_subtitles_username = ""
-    open_subtitles_password = ""
-    # Check if API key is provided via command-line argument
-    tmdb_api_key = args.tmdb_api_key
-
-    # If API key is not provided, try to get it from the cache
-    if not tmdb_api_key:
-        cached_config = get_config(CONFIG_FILE)
-        if cached_config:
-            tmdb_api_key = cached_config.get("tmdb_api_key")
-
-    # If API key is still not available, prompt the user to input it
-    if not tmdb_api_key:
-        tmdb_api_key = input("Enter your TMDb API key: ")
-        # Cache the API key
-
-    logger.debug(f"TMDb API Key: {tmdb_api_key}")
-    logger.debug("Getting OpenSubtitles API key")
-    cached_config = get_config(CONFIG_FILE)
-    try:
-        open_subtitles_api_key = cached_config.get("open_subtitles_api_key")
-        open_subtitles_user_agent = cached_config.get("open_subtitles_user_agent")
-        open_subtitles_username = cached_config.get("open_subtitles_username")
-        open_subtitles_password = cached_config.get("open_subtitles_password")
-    except:
-        pass
-    if args.get_subs:
-        if not open_subtitles_api_key:
-            open_subtitles_api_key = input("Enter your OpenSubtitles API key: ")
-
-        if not open_subtitles_user_agent:
-            open_subtitles_user_agent = input("Enter your OpenSubtitles User Agent: ")
-
-        if not open_subtitles_username:
-            open_subtitles_username = input("Enter your OpenSubtitles Username: ")
-
-        if not open_subtitles_password:
-            open_subtitles_password = input("Enter your OpenSubtitles Password: ")
-
-    # If show directory is provided via command-line argument, use it
-    show_dir = args.show_dir
-    if not show_dir:
-        show_dir = cached_config.get("show_dir")
-    if not show_dir:
-        # If show directory is not provided, prompt the user to input it
-        show_dir = input("Enter the main directory of the show:")
-    logger.info(f"Show Directory: {show_dir}")
-    # if the user does not provide a show directory, make the default show directory the current working directory
-    if not show_dir:
-        show_dir = os.getcwd()
-    if not args.tesseract_path:
-        tesseract_path = cached_config.get("tesseract_path")
-
-        if not tesseract_path:
-            tesseract_path = input(
-                r"Enter the path to the tesseract executable: ['C:\Program Files\Tesseract-OCR\tesseract.exe']"
-            )
-
-    else:
-        tesseract_path = args.tesseract_path
-    logger.debug(f"Teesseract Path: {tesseract_path}")
-    logger.debug(f"Show Directory: {show_dir}")
-
-    # Set the configuration
-    set_config(
-        tmdb_api_key,
-        open_subtitles_api_key,
-        open_subtitles_user_agent,
-        open_subtitles_username,
-        open_subtitles_password,
-        show_dir,
-        CONFIG_FILE,
-        tesseract_path=tesseract_path,
-    )
-    logger.info("Configuration set")
-
-    # Process the show
-    from mkv_episode_matcher.episode_matcher import process_show
-
-    process_show(args.season, dry_run=args.dry_run, get_subs=args.get_subs)
-    logger.info("Show processing completed")
+    app()


-# Run the main function if the script is run directly
 if __name__ == "__main__":
     main()
mkv_episode_matcher/asr_models.py
ADDED
@@ -0,0 +1,506 @@
+"""
+ASR Model Abstraction Layer
+
+This module provides a unified interface for different Automatic Speech Recognition models,
+including OpenAI Whisper and NVIDIA Parakeet models.
+"""
+
+import abc
+import os
+import re
+import tempfile
+from pathlib import Path
+
+import librosa
+import numpy as np
+import soundfile as sf
+import torch
+from loguru import logger
+from rapidfuzz import fuzz
+
+# Cache for loaded models to avoid reloading
+_model_cache = {}
+
+
+class ASRModel(abc.ABC):
+    """Abstract base class for ASR models."""
+
+    def __init__(self, model_name: str, device: str | None = None):
+        """
+        Initialize ASR model.
+
+        Args:
+            model_name: Name/identifier of the model
+            device: Device to run on ('cpu', 'cuda', or None for auto-detect)
+        """
+        self.model_name = model_name
+        self.device = device or self._get_default_device()
+        self._model = None
+
+    def _get_default_device(self) -> str:
+        """Get default device for this model type."""
+        return "cuda" if torch.cuda.is_available() else "cpu"
+
+    @abc.abstractmethod
+    def load(self):
+        """Load the model. Should be called before transcription."""
+        pass
+
+    @abc.abstractmethod
+    def transcribe(self, audio_path: str | Path) -> dict:
+        """
+        Transcribe audio file.
+
+        Args:
+            audio_path: Path to audio file
+
+        Returns:
+            Dictionary with at least 'text' key containing transcription
+        """
+        pass
+
+    def calculate_match_score(self, transcription: str, reference: str) -> float:
+        """
+        Calculate similarity score between transcription and reference.
+
+        Args:
+            transcription: Transcribed text
+            reference: Reference subtitle text
+
+        Returns:
+            Float score between 0.0 and 1.0
+        """
+        # Default implementation: Standard weights
+        # Token sort ratio (70%) + Partial ratio (30%)
+        token_weight = 0.7
+        partial_weight = 0.3
+
+        score = (
+            fuzz.token_sort_ratio(transcription, reference) * token_weight
+            + fuzz.partial_ratio(transcription, reference) * partial_weight
+        ) / 100.0
+
+        return score
+
+    @property
+    def is_loaded(self) -> bool:
+        """Check if model is loaded."""
+        return self._model is not None
+
+    def unload(self):
+        """Unload model to free memory."""
+        self._model = None
+
+
+class ParakeetTDTModel(ASRModel):
+    """
+    NVIDIA Parakeet TDT ASR model implementation.
+
+    WARNING: This model (TDT) uses the Transducer decoder which requires significant GPU resources
+    and may be unstable on some Windows configurations (CUDA errors).
+    """
+
+    def __init__(
+        self, model_name: str = "nvidia/parakeet-tdt-0.6b-v2", device: str | None = None
+    ):
+        """
+        Initialize Parakeet TDT model.
+
+        Args:
+            model_name: Parakeet model identifier from HuggingFace
+            device: Device to run on
+        """
+        super().__init__(model_name, device)
+
+    def load(self):
+        """Load Parakeet model with caching."""
+        if self.is_loaded:
+            return
+
+        cache_key = f"parakeet_tdt_{self.model_name}_{self.device}"
+
+        if cache_key in _model_cache:
+            self._model = _model_cache[cache_key]
+            logger.debug(
+                f"Using cached Parakeet TDT model: {self.model_name} on {self.device}"
+            )
+            return
+
+        try:
+            # Windows compatibility: Patch signal module before importing NeMo
+            if os.name == "nt":  # Windows
+                import signal
+
+                if not hasattr(signal, "SIGKILL"):
+                    # Add missing signal constants for Windows compatibility
+                    signal.SIGKILL = 9
+                    signal.SIGTERM = 15
+
+            import nemo.collections.asr as nemo_asr
+
+            # Store original environment variables for restoration
+            original_env = {}
+
+            # Configure environment to suppress NeMo warnings and optimize performance
+            nemo_env_settings = {
+                "NEMO_DISABLE_TRAINING_LOGS": "1",
+                "NEMO_DISABLE_HYDRA_LOGS": "1",
+                "HYDRA_FULL_ERROR": "0",
+                "PYTHONWARNINGS": "ignore::UserWarning",
+                "TOKENIZERS_PARALLELISM": "false",  # Avoid tokenizer warnings
+            }
+
+            # Windows compatibility: Add optimizations but avoid signal issues
+            if os.name == "nt":  # Windows
+                nemo_env_settings.update({
+                    "OMP_NUM_THREADS": "1",
+                    "MKL_NUM_THREADS": "1",
+                    "NEMO_BYPASS_SIGNALS": "1",  # Bypass NeMo signal handling on Windows
+                })
+
+            for key, value in nemo_env_settings.items():
+                original_env[key] = os.environ.get(key)
+                os.environ[key] = value
+
+            try:
+                # Set device for NeMo
+                if self.device == "cuda" and torch.cuda.is_available():
+                    # NeMo will automatically use CUDA if available
+                    pass
+                elif self.device == "cpu":
+                    # Force CPU usage - NeMo respects CUDA_VISIBLE_DEVICES=""
+                    original_env["CUDA_VISIBLE_DEVICES"] = os.environ.get(
+                        "CUDA_VISIBLE_DEVICES"
+                    )
+                    os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+                # Load model with reduced verbosity
+                self._model = nemo_asr.models.ASRModel.from_pretrained(
+                    model_name=self.model_name,
+                    strict=False,  # Allow loading with missing keys to reduce warnings
+                )
+
+                # Configure model for optimal inference
+                if hasattr(self._model, "set_batch_size"):
+                    self._model.set_batch_size(1)  # Optimize for single file processing
+
+                # Fix for Windows: Force num_workers to 0 to avoid multiprocessing errors/locks
+                if hasattr(self._model, "cfg"):
+                    for ds_config in ["test_ds", "validation_ds"]:
+                        if ds_config in self._model.cfg:
+                            self._model.cfg[ds_config].num_workers = 0
+
+                if hasattr(self._model, "eval"):
+                    self._model.eval()  # Set to evaluation mode
+
+            finally:
+                # Restore original environment variables
+                for key, original_value in original_env.items():
+                    if original_value is not None:
+                        os.environ[key] = original_value
+                    elif key in os.environ:
+                        del os.environ[key]
+
+            _model_cache[cache_key] = self._model
+            logger.info(
+                f"Loaded Parakeet TDT model: {self.model_name} on {self.device}"
+            )
+
+        except ImportError as e:
+            raise ImportError(
+                "NVIDIA NeMo not installed. Run: pip install nemo_toolkit[asr]"
+            ) from e
+        except Exception as e:
+            logger.error(f"Failed to load Parakeet TDT model {self.model_name}: {e}")
+            raise
+
+    def _preprocess_audio(self, audio_path: str | Path) -> str:
+        """
+        Preprocess audio for Parakeet model requirements.
+
+        Args:
+            audio_path: Path to input audio file
+
+        Returns:
+            Path to preprocessed audio file
+        """
+        try:
+            # Load audio with librosa
+            audio, original_sr = librosa.load(str(audio_path), sr=None)
+
+            # Target sample rate for Parakeet models (16kHz is optimal)
+            target_sr = 16000
+
+            # Resample if necessary
+            if original_sr != target_sr:
+                audio = librosa.resample(
+                    audio, orig_sr=original_sr, target_sr=target_sr
+                )
+                logger.debug(f"Resampled audio from {original_sr}Hz to {target_sr}Hz")
+
+            # Normalize audio to [-1, 1] range
+            if np.max(np.abs(audio)) > 0:
+                audio = audio / np.max(np.abs(audio))
+
+            # Create temporary file for preprocessed audio
+            temp_dir = Path(tempfile.gettempdir()) / "parakeet_preprocessed"
+            temp_dir.mkdir(exist_ok=True)
+
+            temp_audio_path = temp_dir / f"preprocessed_{Path(audio_path).stem}.wav"
+
+            # Save preprocessed audio
+            sf.write(str(temp_audio_path), audio, target_sr)
+
+            logger.debug(f"Preprocessed audio saved to {temp_audio_path}")
+            return str(temp_audio_path)
+
+        except Exception as e:
+            logger.warning(f"Audio preprocessing failed, using original: {e}")
+            return str(audio_path)
+
+    def _clean_transcription_text(self, text: str) -> str:
+        """
+        Clean and normalize transcription text using EXACT same method as EpisodeMatcher.
+
+        This ensures compatibility with the existing matching algorithm.
+
+        Args:
+            text: Raw transcription text
+
+        Returns:
+            Cleaned text using identical cleaning as EpisodeMatcher.clean_text()
+        """
+        if not text:
+            return ""
+
+        # Use EXACT same cleaning logic as EpisodeMatcher.clean_text()
+        text = text.lower().strip()
+        text = re.sub(r"\[.*?\]|\<.*?\>", "", text)
+        text = re.sub(r"([A-Za-z])-\1+", r"\1", text)
+        return " ".join(text.split())
+
+    def calculate_match_score(self, transcription: str, reference: str) -> float:
+        """
+        Calculate similarity score with Parakeet-specific weights.
+        Parakeet produces longer, more detailed transcriptions, so we favor partial matches.
+        """
+        # Parakeet weights: Boost partial_ratio
+        token_weight = 0.4
+        partial_weight = 0.6
+
+        # Additional boost for very detailed transcriptions
+        length_ratio = len(transcription) / max(len(reference), 1)
+        if length_ratio > 2.0:  # Much longer transcription
+            partial_weight = 0.8
+            token_weight = 0.2
+
+        score = (
+            fuzz.token_sort_ratio(transcription, reference) * token_weight
+            + fuzz.partial_ratio(transcription, reference) * partial_weight
+        ) / 100.0
+
+        return score
+
+    def transcribe(self, audio_path: str | Path) -> dict:
+        """
+        Transcribe audio using Parakeet with preprocessing and text normalization.
+
+        Args:
+            audio_path: Path to audio file
+
+        Returns:
+            Dictionary with 'text' and 'segments' from Parakeet
+        """
+        if not self.is_loaded:
+            self.load()
+
+        preprocessed_audio = None
+        try:
+            logger.debug(f"Starting Parakeet transcription for {audio_path}")
+
+            # Preprocess audio for optimal Parakeet performance
+            preprocessed_audio = self._preprocess_audio(audio_path)
+
+            # Configure NeMo model settings to reduce warnings
+            old_env_vars = {}
+            try:
+                # Set environment variables to reduce NeMo warnings
+                env_settings = {
+                    "CUDA_LAUNCH_BLOCKING": "0",
+                    "NEMO_DISABLE_TRAINING_LOGS": "1",
+                }
+
+                for key, value in env_settings.items():
+                    old_env_vars[key] = os.environ.get(key)
+                    os.environ[key] = value
+
+                # Parakeet expects list of file paths
+                result = self._model.transcribe([preprocessed_audio])
+
+            finally:
+                # Restore original environment variables
+                for key, old_value in old_env_vars.items():
+                    if old_value is not None:
+                        os.environ[key] = old_value
+                    elif key in os.environ:
+                        del os.environ[key]
+
+            logger.debug(f"Parakeet raw result: {result}, type: {type(result)}")
+
+            # Extract text from result
+            raw_text = ""
+            if isinstance(result, list) and len(result) > 0:
+                if hasattr(result[0], "text"):
+                    raw_text = result[0].text
+                elif isinstance(result[0], str):
+                    raw_text = result[0]
+                else:
+                    raw_text = str(result[0])
+            else:
+                logger.warning(f"Unexpected Parakeet result format: {result}")
+                raw_text = ""
+
+            # Clean and normalize the transcription
+            cleaned_text = self._clean_transcription_text(raw_text)
+
+            logger.debug(f"Raw transcription: '{raw_text}'")
+            logger.debug(f"Cleaned transcription: '{cleaned_text}'")
+
+            return {
+                "text": cleaned_text,
+                "raw_text": raw_text,
+                "segments": [],
+                "language": "en",
+            }
+
+        except Exception as e:
+            logger.error(
+                f"Parakeet transcription failed for {audio_path}: {type(e).__name__}: {e}"
+            )
+            import traceback
+
+            traceback.print_exc()
+            # Return empty result instead of raising to allow fallback
+            return {"text": "", "raw_text": "", "segments": [], "language": "en"}
+        finally:
+            # Clean up preprocessed audio file
+            if preprocessed_audio and preprocessed_audio != str(audio_path):
+                try:
+                    Path(preprocessed_audio).unlink(missing_ok=True)
+                except Exception as e:
+                    logger.debug(f"Failed to clean up preprocessed audio: {e}")
+
+
+class ParakeetCTCModel(ParakeetTDTModel):
+    """
+    NVIDIA Parakeet CTC ASR model implementation.
+
+    This uses the CTC decoder which is more stable and robust on various hardware
+    than the TDT version, though potentially slightly less accurate.
+    """
+
+    def __init__(
+        self, model_name: str = "nvidia/parakeet-ctc-0.6b", device: str | None = None
+    ):
+        """
+        Initialize Parakeet CTC model.
+
+        Args:
+            model_name: Parakeet model identifier (default: nvidia/parakeet-ctc-0.6b)
+            device: Device to run on
+        """
+        # Ensure we use a CTC-compatible model name if not specified
+        # But we trust the user input if provided.
+        super().__init__(model_name, device)
+
+    def load(self):
+        """Load Parakeet CTC model with caching."""
+        # We override load simply to use a different cache key if needed, or we can just reuse parent load
+        # reusing parent load is fine as it uses self.model_name in cache key.
+        # But we need to ensure the logging says CTC.
+        super().load()
+
+
+def create_asr_model(model_config: dict) -> ASRModel:
+    """
+    Factory function to create ASR models from configuration.
+
+    Args:
+        model_config: Dictionary with 'type' and 'name' keys
+
+    Returns:
+        Configured ASRModel instance
+
+    Example:
+        model_config = {"type": "parakeet", "name": "nvidia/parakeet-ctc-0.6b"}
+        model = create_asr_model(model_config)
+    """
+    model_type = model_config.get("type", "").lower()
+    model_name = model_config.get("name", "")
+    device = model_config.get("device")
+
+    if model_type == "parakeet":
+        # Always use the specific working model
+        if not model_name:
+            model_name = "nvidia/parakeet-ctc-0.6b"
+        return ParakeetCTCModel(model_name, device)
+    else:
+        raise ValueError(
+            f"Unsupported model type: {model_type}. Only 'parakeet' is supported."
+        )


+def get_cached_model(model_config: dict) -> ASRModel:
+    """
+    Get a cached model instance, creating it if necessary.
+
+    Args:
+        model_config: Dictionary with model configuration
+
+    Returns:
+        ASRModel instance (loaded and ready for use)
+    """
+    cache_key = f"{model_config.get('type', '')}_{model_config.get('name', '')}_{model_config.get('device', 'auto')}"
+
+    if cache_key not in _model_cache:
+        model = create_asr_model(model_config)
+        model.load()  # Load immediately for caching
+        _model_cache[cache_key] = model
+
+    return _model_cache[cache_key]
+
+
+def clear_model_cache():
+    """Clear all cached models to free memory."""
+    global _model_cache
+    for model in _model_cache.values():
+        if hasattr(model, "unload"):
+            model.unload()
+    _model_cache.clear()
+    logger.info("Cleared ASR model cache")
+
+
+def list_available_models() -> dict:
+    """
+    List available model types and their requirements.
+
+    Returns:
+        Dictionary with model types and their availability status
+    """
+    availability = {}
+
+    # Check Parakeet availability
+    try:
+        import nemo.collections.asr  # noqa: F401
+
+        availability["parakeet"] = {
+            "available": True,
+            "models": ["nvidia/parakeet-ctc-0.6b"],
+        }
+    except ImportError:
+        availability["parakeet"] = {
+            "available": False,
+            "error": "NVIDIA NeMo not installed",
+        }
+
+    return availability
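
A minimal usage sketch of the factory API introduced in asr_models.py; the audio path is an illustrative placeholder, and the Parakeet backend only loads when NVIDIA NeMo is installed (pip install nemo_toolkit[asr], per the module's own error message):

from mkv_episode_matcher.asr_models import get_cached_model

# Hypothetical clip extracted from an MKV; _preprocess_audio() resamples
# whatever is supplied to the 16 kHz the Parakeet models expect.
audio_clip = "episode_sample.wav"

# Config keys mirror the create_asr_model() docstring example.
model = get_cached_model({"type": "parakeet", "name": "nvidia/parakeet-ctc-0.6b"})

result = model.transcribe(audio_clip)  # {"text": ..., "raw_text": ..., "segments": [], "language": "en"}
score = model.calculate_match_score(result["text"], "reference subtitle text")
print(f"Match score: {score:.2f}")  # 0.0 to 1.0, weighted toward partial matches for Parakeet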