karaoke-gen 0.76.20__py3-none-any.whl → 0.81.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only.
- karaoke_gen/instrumental_review/static/index.html +179 -16
- karaoke_gen/karaoke_gen.py +5 -4
- karaoke_gen/lyrics_processor.py +25 -6
- {karaoke_gen-0.76.20.dist-info → karaoke_gen-0.81.1.dist-info}/METADATA +79 -3
- {karaoke_gen-0.76.20.dist-info → karaoke_gen-0.81.1.dist-info}/RECORD +26 -23
- lyrics_transcriber/core/config.py +8 -0
- lyrics_transcriber/core/controller.py +43 -1
- lyrics_transcriber/correction/agentic/providers/config.py +6 -0
- lyrics_transcriber/correction/agentic/providers/model_factory.py +24 -1
- lyrics_transcriber/correction/agentic/router.py +17 -13
- lyrics_transcriber/frontend/.gitignore +1 -0
- lyrics_transcriber/frontend/e2e/agentic-corrections.spec.ts +207 -0
- lyrics_transcriber/frontend/e2e/fixtures/agentic-correction-data.json +226 -0
- lyrics_transcriber/frontend/package.json +4 -1
- lyrics_transcriber/frontend/playwright.config.ts +1 -1
- lyrics_transcriber/frontend/src/components/CorrectedWordWithActions.tsx +34 -30
- lyrics_transcriber/frontend/src/components/Header.tsx +141 -34
- lyrics_transcriber/frontend/src/components/LyricsAnalyzer.tsx +120 -3
- lyrics_transcriber/frontend/src/components/TranscriptionView.tsx +11 -1
- lyrics_transcriber/frontend/src/components/shared/components/HighlightedText.tsx +122 -35
- lyrics_transcriber/frontend/src/components/shared/types.ts +6 -0
- lyrics_transcriber/output/generator.py +50 -3
- lyrics_transcriber/transcribers/local_whisper.py +260 -0
- {karaoke_gen-0.76.20.dist-info → karaoke_gen-0.81.1.dist-info}/WHEEL +0 -0
- {karaoke_gen-0.76.20.dist-info → karaoke_gen-0.81.1.dist-info}/entry_points.txt +0 -0
- {karaoke_gen-0.76.20.dist-info → karaoke_gen-0.81.1.dist-info}/licenses/LICENSE +0 -0
lyrics_transcriber/frontend/src/components/shared/components/HighlightedText.tsx

@@ -1,4 +1,4 @@
-import { Typography, Box } from '@mui/material'
+import { Typography, Box, useMediaQuery, useTheme } from '@mui/material'
 import { WordComponent } from './Word'
 import { useWordClick } from '../hooks/useWordClick'
 import {
@@ -16,6 +16,7 @@ import React from 'react'
 import ContentCopyIcon from '@mui/icons-material/ContentCopy'
 import IconButton from '@mui/material/IconButton'
 import { getWordsFromIds } from '../utils/wordUtils'
+import CorrectedWordWithActions from '../../CorrectedWordWithActions'

 export interface HighlightedTextProps {
     text?: string
@@ -36,6 +37,12 @@ export interface HighlightedTextProps {
     gaps?: GapSequence[]
     flashingHandler?: string | null
     corrections?: WordCorrection[]
+    // Review mode props for agentic corrections
+    reviewMode?: boolean
+    onRevertCorrection?: (wordId: string) => void
+    onEditCorrection?: (wordId: string) => void
+    onAcceptCorrection?: (wordId: string) => void
+    onShowCorrectionDetail?: (wordId: string) => void
 }

 export function HighlightedText({
@@ -57,7 +64,15 @@ export function HighlightedText({
     gaps = [],
     flashingHandler,
     corrections = [],
+    reviewMode = false,
+    onRevertCorrection,
+    onEditCorrection,
+    onAcceptCorrection,
+    onShowCorrectionDetail,
 }: HighlightedTextProps) {
+    const theme = useTheme()
+    const isMobile = useMediaQuery(theme.breakpoints.down('sm'))
+
     const { handleWordClick } = useWordClick({
         mode,
         onElementClick,
@@ -157,43 +172,83 @@

     const renderContent = () => {
         if (wordPositions && !segments) {
-            return wordPositions.map((wordPos, index) =>
-                … (removed body not captured in this diff view)
+            return wordPositions.map((wordPos, index) => {
+                // Find correction for this word
+                const correction = corrections?.find(c =>
+                    c.corrected_word_id === wordPos.word.id ||
+                    c.word_id === wordPos.word.id
+                );
+
+                // Use CorrectedWordWithActions for agentic corrections
+                if (correction && correction.handler === 'AgenticCorrector') {
+                    return (
+                        <React.Fragment key={wordPos.word.id}>
+                            <CorrectedWordWithActions
+                                word={wordPos.word.text}
+                                originalWord={correction.original_word}
+                                correction={{
+                                    originalWord: correction.original_word,
+                                    handler: correction.handler,
+                                    confidence: correction.confidence,
+                                    source: correction.source,
+                                    reason: correction.reason
+                                }}
+                                shouldFlash={shouldWordFlash(wordPos)}
+                                showActions={reviewMode && !isMobile}
+                                onRevert={() => onRevertCorrection?.(wordPos.word.id)}
+                                onEdit={() => onEditCorrection?.(wordPos.word.id)}
+                                onAccept={() => onAcceptCorrection?.(wordPos.word.id)}
+                                onClick={() => {
+                                    if (isMobile) {
+                                        onShowCorrectionDetail?.(wordPos.word.id)
+                                    } else {
+                                        handleWordClick(
+                                            wordPos.word.text,
+                                            wordPos.word.id,
+                                            wordPos.type === 'anchor' ? wordPos.sequence as AnchorSequence : undefined,
+                                            wordPos.type === 'gap' ? wordPos.sequence as GapSequence : undefined
+                                        )
+                                    }
+                                }}
+                            />
+                            {index < wordPositions.length - 1 && ' '}
+                        </React.Fragment>
+                    );
+                }
+
+                // Default rendering with WordComponent
+                return (
+                    <React.Fragment key={wordPos.word.id}>
+                        <WordComponent
+                            key={`${wordPos.word.id}-${index}`}
+                            word={wordPos.word.text}
+                            shouldFlash={shouldWordFlash(wordPos)}
+                            isAnchor={wordPos.type === 'anchor'}
+                            isCorrectedGap={wordPos.isCorrected}
+                            isUncorrectedGap={wordPos.type === 'gap' && !wordPos.isCorrected}
+                            isCurrentlyPlaying={shouldHighlightWord(wordPos)}
+                            onClick={() => handleWordClick(
+                                wordPos.word.text,
+                                wordPos.word.id,
+                                wordPos.type === 'anchor' ? wordPos.sequence as AnchorSequence : undefined,
+                                wordPos.type === 'gap' ? wordPos.sequence as GapSequence : undefined
+                            )}
+                            correction={correction ? {
                                 originalWord: correction.original_word,
                                 handler: correction.handler,
                                 confidence: correction.confidence,
                                 source: correction.source,
                                 reason: correction.reason
-            } : null
-                …
-            )
+                            } : null}
+                        />
+                        {index < wordPositions.length - 1 && ' '}
+                    </React.Fragment>
+                );
+            })
         } else if (segments) {
             return segments.map((segment) => (
-                <Box key={segment.id} sx={{
-                    display: 'flex',
+                <Box key={segment.id} sx={{
+                    display: 'flex',
                     alignItems: 'flex-start',
                     mb: 0
                 }}>
@@ -212,12 +267,44 @@

             const sequence = wordPos?.type === 'gap' ? wordPos.sequence as GapSequence : undefined;

-            // Find correction information
-            const correction = corrections?.find(c =>
-                c.corrected_word_id === word.id ||
+            // Find correction information
+            const correction = corrections?.find(c =>
+                c.corrected_word_id === word.id ||
                 c.word_id === word.id
             );
-
+
+            // Use CorrectedWordWithActions for agentic corrections
+            if (correction && correction.handler === 'AgenticCorrector') {
+                return (
+                    <React.Fragment key={word.id}>
+                        <CorrectedWordWithActions
+                            word={word.text}
+                            originalWord={correction.original_word}
+                            correction={{
+                                originalWord: correction.original_word,
+                                handler: correction.handler,
+                                confidence: correction.confidence,
+                                source: correction.source,
+                                reason: correction.reason
+                            }}
+                            shouldFlash={shouldWordFlash(wordPos || { word: word.text, id: word.id })}
+                            showActions={reviewMode && !isMobile}
+                            onRevert={() => onRevertCorrection?.(word.id)}
+                            onEdit={() => onEditCorrection?.(word.id)}
+                            onAccept={() => onAcceptCorrection?.(word.id)}
+                            onClick={() => {
+                                if (isMobile) {
+                                    onShowCorrectionDetail?.(word.id)
+                                } else {
+                                    handleWordClick(word.text, word.id, anchor, sequence)
+                                }
+                            }}
+                        />
+                        {wordIndex < segment.words.length - 1 && ' '}
+                    </React.Fragment>
+                );
+            }
+
             const correctionInfo = correction ? {
                 originalWord: correction.original_word,
                 handler: correction.handler,
lyrics_transcriber/frontend/src/components/TranscriptionView.tsx

@@ -85,6 +85,12 @@ export interface TranscriptionViewProps {
     anchors?: AnchorSequence[]
     flashingHandler?: string | null
     onDataChange?: (updatedData: CorrectionData) => void
+    // Review mode props for agentic corrections
+    reviewMode?: boolean
+    onRevertCorrection?: (wordId: string) => void
+    onEditCorrection?: (wordId: string) => void
+    onAcceptCorrection?: (wordId: string) => void
+    onShowCorrectionDetail?: (wordId: string) => void
 }

 // Add LinePosition type here since it's used in multiple places
lyrics_transcriber/output/generator.py

@@ -52,7 +52,7 @@ class OutputGenerator:

         self.logger.info(f"Initializing OutputGenerator with config: {self.config}")

-        # Load output styles from JSON if provided
+        # Load output styles from JSON if provided, otherwise use defaults
         if self.config.output_styles_json and os.path.exists(self.config.output_styles_json):
             try:
                 with open(self.config.output_styles_json, "r") as f:
@@ -67,9 +67,10 @@ class OutputGenerator:
                 self.logger.warning(f"Failed to load output styles file: {str(e)}")
                 self.config.styles = {}
         else:
-            # No styles file provided or doesn't exist
+            # No styles file provided or doesn't exist - use defaults
             if self.config.render_video or self.config.generate_cdg:
-                …
+                self.logger.info("No output styles file provided, using default karaoke styles")
+                self.config.styles = self._get_default_styles()
             else:
                 self.config.styles = {}

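Taken together, these two hunks change the fallback order for output styles. Below is a minimal standalone sketch of that order; the function name and signature are illustrative, since the real logic lives inside `OutputGenerator.__init__` and additionally wraps the file read in the try/except shown above:

```python
import json
import os

def resolve_styles(output_styles_json, render_video, generate_cdg, get_default_styles):
    """Mirror the selection order in the diff: an explicit JSON file wins,
    built-in defaults apply when video/CDG rendering is requested, and an
    empty dict is used otherwise."""
    if output_styles_json and os.path.exists(output_styles_json):
        with open(output_styles_json, "r") as f:
            return json.load(f)
    if render_video or generate_cdg:
        return get_default_styles()
    return {}
```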
@@ -242,6 +243,52 @@ class OutputGenerator:

         return resolution_dims, font_size, line_height

+    def _get_default_styles(self) -> dict:
+        """Get default styles for video/CDG generation when no styles file is provided."""
+        return {
+            "karaoke": {
+                # Video background
+                "background_color": "#000000",
+                "background_image": None,
+                # Font settings
+                "font": "Arial",
+                "font_path": "",  # Must be string, not None (for ASS generator)
+                "ass_name": "Default",
+                # Colors in "R, G, B, A" format (required by ASS)
+                "primary_color": "112, 112, 247, 255",
+                "secondary_color": "255, 255, 255, 255",
+                "outline_color": "26, 58, 235, 255",
+                "back_color": "0, 0, 0, 0",
+                # Boolean style options
+                "bold": False,
+                "italic": False,
+                "underline": False,
+                "strike_out": False,
+                # Numeric style options (all required for ASS)
+                "scale_x": 100,
+                "scale_y": 100,
+                "spacing": 0,
+                "angle": 0.0,
+                "border_style": 1,
+                "outline": 1,
+                "shadow": 0,
+                "margin_l": 0,
+                "margin_r": 0,
+                "margin_v": 0,
+                "encoding": 0,
+                # Layout settings
+                "max_line_length": 40,
+                "top_padding": 200,
+                "font_size": 100,
+            },
+            "cdg": {
+                "font_path": None,
+                "instrumental_background": None,
+                "title_screen_background": None,
+                "outro_background": None,
+            },
+        }
+
     def write_corrections_data(self, correction_result: CorrectionResult, output_prefix: str) -> str:
         """Write corrections data to JSON file."""
         self.logger.info("Writing corrections data JSON")
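Because `_get_default_styles()` now doubles as documentation of the expected schema, a caller who wants custom styling can mirror that shape in their own JSON file. A hedged sketch follows; the file name and override values are hypothetical, and since the loader appears to replace `config.styles` wholesale, every key the renderer needs should be present, not just the overrides:

```python
import json

# Start from the schema shown in _get_default_styles() and change a few values.
styles = {
    "karaoke": {
        "background_color": "#101020",          # override: dark blue background
        "background_image": None,
        "font": "Arial",
        "font_path": "",                        # must stay a string for the ASS generator
        "ass_name": "Default",
        "primary_color": "255, 200, 0, 255",    # override: gold highlight, "R, G, B, A"
        "secondary_color": "255, 255, 255, 255",
        "outline_color": "26, 58, 235, 255",
        "back_color": "0, 0, 0, 0",
        "bold": False, "italic": False, "underline": False, "strike_out": False,
        "scale_x": 100, "scale_y": 100, "spacing": 0, "angle": 0.0,
        "border_style": 1, "outline": 1, "shadow": 0,
        "margin_l": 0, "margin_r": 0, "margin_v": 0, "encoding": 0,
        "max_line_length": 40, "top_padding": 200, "font_size": 100,
    },
    "cdg": {
        "font_path": None,
        "instrumental_background": None,
        "title_screen_background": None,
        "outro_background": None,
    },
}

with open("styles.json", "w") as f:
    json.dump(styles, f, indent=2)
# Then point the generator config's output_styles_json at "styles.json"
# (the config class itself is not shown in this diff).
```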
lyrics_transcriber/transcribers/local_whisper.py (new file)

@@ -0,0 +1,260 @@
+"""Local Whisper transcription service using whisper-timestamped for word-level timestamps."""
+
+from dataclasses import dataclass
+import os
+import logging
+from typing import Optional, Dict, Any, Union
+from pathlib import Path
+
+from lyrics_transcriber.types import TranscriptionData, LyricsSegment, Word
+from lyrics_transcriber.transcribers.base_transcriber import BaseTranscriber, TranscriptionError
+from lyrics_transcriber.utils.word_utils import WordUtils
+
+
+@dataclass
+class LocalWhisperConfig:
+    """Configuration for local Whisper transcription service."""
+
+    model_size: str = "medium"  # tiny, base, small, medium, large, large-v2, large-v3
+    device: Optional[str] = None  # None for auto-detect, or "cpu", "cuda", "mps"
+    cache_dir: Optional[str] = None  # Directory for model downloads (~/.cache/whisper by default)
+    language: Optional[str] = None  # Language code for transcription, None for auto-detect
+    compute_type: str = "auto"  # float16, float32, int8, auto
+
+
+class LocalWhisperTranscriber(BaseTranscriber):
+    """
+    Transcription service using local Whisper inference via whisper-timestamped.
+
+    This transcriber runs Whisper models locally on your machine, supporting
+    CPU, CUDA GPU, and Apple Silicon MPS acceleration. It uses the
+    whisper-timestamped library to get accurate word-level timestamps.
+
+    Requirements:
+        pip install karaoke-gen[local-whisper]
+
+    Configuration:
+        Set environment variables to customize behavior:
+        - WHISPER_MODEL_SIZE: Model size (tiny, base, small, medium, large)
+        - WHISPER_DEVICE: Device to use (cpu, cuda, mps, or auto)
+        - WHISPER_CACHE_DIR: Directory for model downloads
+        - WHISPER_LANGUAGE: Language code (en, es, fr, etc.) or auto-detect
+    """
+
+    def __init__(
+        self,
+        cache_dir: Union[str, Path],
+        config: Optional[LocalWhisperConfig] = None,
+        logger: Optional[logging.Logger] = None,
+    ):
+        """
+        Initialize local Whisper transcriber.
+
+        Args:
+            cache_dir: Directory for caching transcription results
+            config: Configuration options for the transcriber
+            logger: Logger instance to use
+        """
+        super().__init__(cache_dir=cache_dir, logger=logger)
+
+        # Initialize configuration from env vars or defaults
+        self.config = config or LocalWhisperConfig(
+            model_size=os.getenv("WHISPER_MODEL_SIZE", "medium"),
+            device=os.getenv("WHISPER_DEVICE"),  # None for auto-detect
+            cache_dir=os.getenv("WHISPER_CACHE_DIR"),
+            language=os.getenv("WHISPER_LANGUAGE"),  # None for auto-detect
+        )
+
+        # Lazy-loaded model instance (loaded on first use)
+        self._model = None
+        self._whisper_module = None
+
+        self.logger.debug(
+            f"LocalWhisperTranscriber initialized with model_size={self.config.model_size}, "
+            f"device={self.config.device or 'auto'}, language={self.config.language or 'auto-detect'}"
+        )
+
+    def get_name(self) -> str:
+        """Return the name of this transcription service."""
+        return "LocalWhisper"
+
+    def _check_dependencies(self) -> None:
+        """Check that required dependencies are installed."""
+        try:
+            import whisper_timestamped  # noqa: F401
+        except ImportError:
+            raise TranscriptionError(
+                "whisper-timestamped is not installed. "
+                "Install it with: pip install karaoke-gen[local-whisper] "
+                "or: pip install whisper-timestamped"
+            )
+
+    def _get_device(self) -> str:
+        """Determine the best device to use for inference."""
+        if self.config.device:
+            return self.config.device
+
+        # Auto-detect best available device
+        try:
+            import torch
+
+            if torch.cuda.is_available():
+                self.logger.info("Using CUDA GPU for Whisper inference")
+                return "cuda"
+            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+                self.logger.info("Using Apple Silicon MPS for Whisper inference")
+                return "cpu"  # whisper-timestamped works better with CPU on MPS
+            else:
+                self.logger.info("Using CPU for Whisper inference (no GPU detected)")
+                return "cpu"
+        except ImportError:
+            self.logger.warning("PyTorch not available, defaulting to CPU")
+            return "cpu"
+
+    def _load_model(self):
+        """Load the Whisper model (lazy loading on first use)."""
+        if self._model is not None:
+            return self._model
+
+        self._check_dependencies()
+        import whisper_timestamped as whisper
+
+        self._whisper_module = whisper
+
+        device = self._get_device()
+        self.logger.info(f"Loading Whisper model '{self.config.model_size}' on device '{device}'...")
+
+        try:
+            # Load model with optional custom cache directory
+            download_root = self.config.cache_dir
+            self._model = whisper.load_model(
+                self.config.model_size,
+                device=device,
+                download_root=download_root,
+            )
+            self.logger.info(f"Whisper model '{self.config.model_size}' loaded successfully")
+            return self._model
+        except RuntimeError as e:
+            if "out of memory" in str(e).lower() or "CUDA" in str(e):
+                raise TranscriptionError(
+                    f"GPU out of memory loading model '{self.config.model_size}'. "
+                    "Try using a smaller model (set WHISPER_MODEL_SIZE=small or tiny) "
+                    "or force CPU mode (set WHISPER_DEVICE=cpu)"
+                ) from e
+            raise TranscriptionError(f"Failed to load Whisper model: {e}") from e
+        except Exception as e:
+            raise TranscriptionError(f"Failed to load Whisper model: {e}") from e
+
+    def _perform_transcription(self, audio_filepath: str) -> Dict[str, Any]:
+        """
+        Perform local Whisper transcription with word-level timestamps.
+
+        Args:
+            audio_filepath: Path to the audio file to transcribe
+
+        Returns:
+            Raw transcription result dictionary
+        """
+        self.logger.info(f"Starting local Whisper transcription for {audio_filepath}")
+
+        # Load model if not already loaded
+        model = self._load_model()
+
+        try:
+            # Perform transcription with word-level timestamps
+            transcribe_kwargs = {
+                "verbose": False,
+            }
+
+            # Add language if specified
+            if self.config.language:
+                transcribe_kwargs["language"] = self.config.language
+
+            self.logger.debug(f"Transcribing with options: {transcribe_kwargs}")
+            result = self._whisper_module.transcribe_timestamped(
+                model,
+                audio_filepath,
+                **transcribe_kwargs,
+            )
+
+            self.logger.info("Local Whisper transcription completed successfully")
+            return result
+
+        except RuntimeError as e:
+            if "out of memory" in str(e).lower():
+                raise TranscriptionError(
+                    f"GPU out of memory during transcription. "
+                    "Try using a smaller model (WHISPER_MODEL_SIZE=small) "
+                    "or force CPU mode (WHISPER_DEVICE=cpu)"
+                ) from e
+            raise TranscriptionError(f"Transcription failed: {e}") from e
+        except Exception as e:
+            raise TranscriptionError(f"Transcription failed: {e}") from e
+
+    def _convert_result_format(self, raw_data: Dict[str, Any]) -> TranscriptionData:
+        """
+        Convert whisper-timestamped output to standard TranscriptionData format.
+
+        The whisper-timestamped library returns results in this format:
+            {
+                "text": "Full transcription text",
+                "segments": [
+                    {
+                        "id": 0,
+                        "text": "Segment text",
+                        "start": 0.0,
+                        "end": 2.5,
+                        "words": [
+                            {"text": "word", "start": 0.0, "end": 0.5, "confidence": 0.95},
+                            ...
+                        ]
+                    },
+                    ...
+                ],
+                "language": "en"
+            }
+
+        Args:
+            raw_data: Raw output from whisper_timestamped.transcribe_timestamped()
+
+        Returns:
+            TranscriptionData with segments, words, and metadata
+        """
+        segments = []
+        all_words = []
+
+        for seg in raw_data.get("segments", []):
+            segment_words = []
+
+            for word_data in seg.get("words", []):
+                word = Word(
+                    id=WordUtils.generate_id(),
+                    text=word_data.get("text", "").strip(),
+                    start_time=word_data.get("start", 0.0),
+                    end_time=word_data.get("end", 0.0),
+                    confidence=word_data.get("confidence"),
+                )
+                segment_words.append(word)
+                all_words.append(word)
+
+            # Create segment with its words
+            segment = LyricsSegment(
+                id=WordUtils.generate_id(),
+                text=seg.get("text", "").strip(),
+                words=segment_words,
+                start_time=seg.get("start", 0.0),
+                end_time=seg.get("end", 0.0),
+            )
+            segments.append(segment)
+
+        return TranscriptionData(
+            segments=segments,
+            words=all_words,
+            text=raw_data.get("text", "").strip(),
+            source=self.get_name(),
+            metadata={
+                "model_size": self.config.model_size,
+                "detected_language": raw_data.get("language", "unknown"),
+                "device": self._get_device(),
+            },
+        )
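A hedged usage sketch for the new transcriber. Only `LocalWhisperConfig`, `LocalWhisperTranscriber`, and the constructor arguments are taken from the file above; the public `transcribe()` entry point is assumed to be inherited from `BaseTranscriber`, which this diff does not show:

```python
import logging

from lyrics_transcriber.transcribers.local_whisper import (
    LocalWhisperConfig,
    LocalWhisperTranscriber,
)

# Explicit config; omit it to fall back to the WHISPER_* environment variables.
config = LocalWhisperConfig(model_size="small", device="cpu", language="en")
transcriber = LocalWhisperTranscriber(
    cache_dir="/tmp/lyrics-transcriber-cache",  # cache for transcription results
    config=config,
    logger=logging.getLogger(__name__),
)

# Assumption: BaseTranscriber exposes a transcribe(audio_filepath) method that
# wraps _perform_transcription() and _convert_result_format() shown above.
data = transcriber.transcribe("song.flac")
print(data.text)
for word in data.words[:5]:
    print(word.text, word.start_time, word.end_time)
```

Note that the model itself is loaded lazily on the first transcription call, with weights downloaded into `WHISPER_CACHE_DIR` (or the whisper default) rather than into the results cache directory.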