lattifai 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/__init__.py +10 -0
- lattifai/alignment/lattice1_aligner.py +64 -15
- lattifai/alignment/lattice1_worker.py +135 -50
- lattifai/alignment/segmenter.py +3 -2
- lattifai/alignment/tokenizer.py +14 -13
- lattifai/audio2.py +269 -70
- lattifai/caption/caption.py +213 -19
- lattifai/cli/__init__.py +2 -0
- lattifai/cli/alignment.py +2 -1
- lattifai/cli/app_installer.py +35 -33
- lattifai/cli/caption.py +9 -19
- lattifai/cli/diarization.py +108 -0
- lattifai/cli/server.py +3 -1
- lattifai/cli/transcribe.py +55 -38
- lattifai/cli/youtube.py +1 -0
- lattifai/client.py +42 -121
- lattifai/config/alignment.py +37 -2
- lattifai/config/caption.py +1 -1
- lattifai/config/media.py +23 -3
- lattifai/config/transcription.py +4 -0
- lattifai/diarization/lattifai.py +18 -7
- lattifai/errors.py +7 -3
- lattifai/mixin.py +45 -16
- lattifai/server/app.py +2 -1
- lattifai/transcription/__init__.py +1 -1
- lattifai/transcription/base.py +21 -2
- lattifai/transcription/gemini.py +127 -1
- lattifai/transcription/lattifai.py +30 -2
- lattifai/utils.py +96 -28
- lattifai/workflow/file_manager.py +15 -13
- lattifai/workflow/youtube.py +16 -1
- {lattifai-1.0.4.dist-info → lattifai-1.1.0.dist-info}/METADATA +86 -22
- lattifai-1.1.0.dist-info/RECORD +57 -0
- {lattifai-1.0.4.dist-info → lattifai-1.1.0.dist-info}/entry_points.txt +2 -0
- {lattifai-1.0.4.dist-info → lattifai-1.1.0.dist-info}/licenses/LICENSE +1 -1
- lattifai-1.0.4.dist-info/RECORD +0 -56
- {lattifai-1.0.4.dist-info → lattifai-1.1.0.dist-info}/WHEEL +0 -0
- {lattifai-1.0.4.dist-info → lattifai-1.1.0.dist-info}/top_level.txt +0 -0
lattifai/config/media.py
CHANGED

@@ -52,12 +52,23 @@ class MediaConfig:
     sample_rate: Optional[int] = None
     """Audio sample rate in Hz (e.g., 16000, 44100)."""
 
-    channels: Optional[int] = None
-    """Number of audio channels (1=mono, 2=stereo)."""
-
     channel_selector: Optional[str | int] = "average"
     """Audio channel selection strategy: 'average', 'left', 'right', or channel index."""
 
+    # Audio Streaming Configuration
+    streaming_chunk_secs: Optional[float] = 600.0
+    """Duration in seconds of each audio chunk for streaming mode.
+    When set to a value (e.g., 600.0), enables streaming mode for processing very long audio files (>1 hour).
+    Audio is processed in chunks to keep memory usage low (<4GB peak), suitable for 20+ hour files.
+    When None, disables streaming and loads entire audio into memory.
+    Valid range: 1-1800 seconds (minimum 1 second, maximum 30 minutes).
+    Default: 600 seconds (10 minutes).
+    Recommended: Use 60 seconds or larger for optimal performance.
+    - Smaller chunks: Lower memory usage, more frequent I/O
+    - Larger chunks: Better alignment context, higher memory usage
+    Note: Streaming may add slight processing overhead but enables handling arbitrarily long files.
+    """
+
     # Output / download configuration
     output_dir: Path = field(default_factory=lambda: Path.cwd())
     """Directory for output files (default: current working directory)."""

@@ -87,12 +98,21 @@ class MediaConfig:
         self._normalize_media_format()
         self._process_input_path()
         self._process_output_path()
+        self._validate_streaming_config()
 
     def _setup_output_directory(self) -> None:
         """Ensure output directory exists and is valid."""
         resolved_output_dir = self._ensure_dir(self.output_dir)
         self.output_dir = resolved_output_dir
 
+    def _validate_streaming_config(self) -> None:
+        """Validate streaming configuration parameters."""
+        if self.streaming_chunk_secs is not None:
+            if not 1.0 <= self.streaming_chunk_secs <= 1800.0:
+                raise ValueError(
+                    f"streaming_chunk_secs must be between 1 and 1800 seconds (1 second to 30 minutes), got {self.streaming_chunk_secs}. Recommended: 60 seconds or larger."
+                )
+
     def _validate_default_formats(self) -> None:
         """Validate default audio and video formats."""
         self.default_audio_format = self._normalize_format(self.default_audio_format)
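
The new streaming option is validated in __post_init__ via _validate_streaming_config(). A minimal usage sketch, assuming the remaining MediaConfig fields keep their defaults and the class is importable from the path shown above:

    from lattifai.config.media import MediaConfig  # import path assumed from the file path above

    cfg = MediaConfig()                                  # streaming on, 10-minute chunks (600.0 s)
    in_memory = MediaConfig(streaming_chunk_secs=None)   # streaming off, whole file loaded into memory

    try:
        MediaConfig(streaming_chunk_secs=0.5)            # below the allowed 1-1800 s range
    except ValueError as err:
        print(err)  # "streaming_chunk_secs must be between 1 and 1800 seconds ..."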
lattifai/config/transcription.py
CHANGED

@@ -12,6 +12,7 @@ if TYPE_CHECKING:
 SUPPORTED_TRANSCRIPTION_MODELS = Literal[
     "gemini-2.5-pro",
     "gemini-3-pro-preview",
+    "gemini-3-flash-preview",
     "nvidia/parakeet-tdt-0.6b-v3",
     "nvidia/canary-1b-v2",
     "iic/SenseVoiceSmall",

@@ -50,6 +51,9 @@ class TranscriptionConfig:
     lattice_model_path: Optional[str] = None
     """Path to local LattifAI model. Will be auto-set in LattifAI client."""
 
+    model_hub: Literal["huggingface", "modelscope"] = "huggingface"
+    """Which model hub to use when resolving lattice models for transcription."""
+
     client_wrapper: Optional["SyncAPIClient"] = field(default=None, repr=False)
     """Reference to the SyncAPIClient instance. Auto-set during client initialization."""
 
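
A short sketch of the new model_hub field, assuming TranscriptionConfig can be constructed directly with just a model name:

    from lattifai.config import TranscriptionConfig  # package path assumed

    # Default hub is Hugging Face; ModelScope can be selected where Hugging Face is unreachable.
    hf_cfg = TranscriptionConfig(model_name="gemini-3-flash-preview")
    ms_cfg = TranscriptionConfig(model_name="nvidia/parakeet-tdt-0.6b-v3", model_hub="modelscope")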
lattifai/diarization/lattifai.py
CHANGED

@@ -1,11 +1,12 @@
 """LattifAI speaker diarization implementation."""
 
 import logging
-from …
-from typing import List, Optional, Tuple
+from dataclasses import dataclass
+from typing import Callable, List, Optional, Tuple
 
-import …
-from …
+import numpy as np
+from lattifai_core.diarization import DiarizationOutput
+from tgt import TextGrid
 
 from lattifai.audio2 import AudioData
 from lattifai.caption import Supervision

@@ -60,7 +61,7 @@ class LattifAIDiarizer:
         num_speakers: Optional[int] = None,
         min_speakers: Optional[int] = None,
         max_speakers: Optional[int] = None,
-    ) -> …
+    ) -> DiarizationOutput:
         """Perform speaker diarization on the input audio."""
         return self.diarizer.diarize(
             input_media,

@@ -73,11 +74,16 @@
         self,
         input_media: AudioData,
         alignments: List[Supervision],
-        diarization: Optional[…
+        diarization: Optional[DiarizationOutput] = None,
         num_speakers: Optional[int] = None,
         min_speakers: Optional[int] = None,
         max_speakers: Optional[int] = None,
-    ) -> …
+        alignment_fn: Optional[Callable] = None,
+        transcribe_fn: Optional[Callable] = None,
+        separate_fn: Optional[Callable] = None,
+        debug: bool = False,
+        output_path: Optional[str] = None,
+    ) -> Tuple[DiarizationOutput, List[Supervision]]:
         """Diarize the given media input and return alignments with refined speaker labels."""
         return self.diarizer.diarize_with_alignments(
             input_media,

@@ -86,4 +92,9 @@
             num_speakers=num_speakers,
             min_speakers=min_speakers,
             max_speakers=max_speakers,
+            alignment_fn=alignment_fn,
+            transcribe_fn=transcribe_fn,
+            separate_fn=separate_fn,
+            debug=debug,
+            output_path=output_path,
         )
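
The extra keyword arguments are passed straight through to the core diarizer. A rough call sketch, assuming diarizer, audio, and supervisions were created elsewhere and using hypothetical callables for the new hooks:

    # All names below are placeholders; only the keyword arguments mirror the new signature.
    diarization, refined = diarizer.diarize_with_alignments(
        audio,                            # AudioData
        supervisions,                     # List[Supervision] from a previous alignment pass
        num_speakers=None,                # let the backend estimate the speaker count
        alignment_fn=my_align_fn,         # hypothetical: re-aligns a text/audio pair
        transcribe_fn=my_transcribe_fn,   # hypothetical: e.g. a transcriber's transcribe_numpy
        separate_fn=None,                 # optional source-separation hook
        debug=True,
        output_path="diar_debug",         # hypothetical location for debug artifacts
    )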
lattifai/errors.py
CHANGED

@@ -11,11 +11,15 @@ LATTICE_DECODING_FAILURE_HELP = (
     "1) Media(Audio/Video) and text content mismatch:\n"
     "   - The transcript/caption does not accurately match the media content\n"
     "   - Text may be from a different version or section of the media\n"
-    "…
-    "…
-    "…
+    "2) Text formatting issues:\n"
+    "   - Special characters, HTML entities, or unusual punctuation may cause alignment failures\n"
+    "   - Text normalization is enabled by default (caption.normalize_text=True)\n"
+    "     If you disabled it, try re-enabling: caption.normalize_text=True\n"
+    "3) Unsupported media type:\n"
     "   - Singing is not yet supported, this will be optimized in future versions\n\n"
     "💡 Troubleshooting tips:\n"
+    "   • Text normalization is enabled by default to handle special characters\n"
+    "     (no action needed unless you explicitly set caption.normalize_text=False)\n"
     "   • Verify the transcript matches the media by listening to a few segments\n"
     "   • For YouTube videos, manually check if auto-generated transcript are accurate\n"
     "   • Consider using a different transcription source if Gemini results are incomplete"
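
The updated help text points at the caption-level normalize_text switch. A sketch of what re-enabling it would look like, assuming CaptionConfig (referenced in mixin.py's TYPE_CHECKING imports) exposes that flag:

    from lattifai.config import CaptionConfig  # field name assumed from the help text above

    caption_config = CaptionConfig(normalize_text=True)  # default; only needed if it was disabled earlier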
lattifai/mixin.py
CHANGED

@@ -10,6 +10,7 @@ from lhotse.utils import Pathlike
 from lattifai.audio2 import AudioData
 from lattifai.caption import Caption
 from lattifai.errors import CaptionProcessingError
+from lattifai.utils import safe_print
 
 if TYPE_CHECKING:
     from .config import AlignmentConfig, CaptionConfig, ClientConfig, DiarizationConfig, TranscriptionConfig

@@ -183,7 +184,9 @@ class LattifAIClientMixin:
         from lattifai.utils import _resolve_model_path
 
         if transcription_config is not None:
-            transcription_config.lattice_model_path = _resolve_model_path(…
+            transcription_config.lattice_model_path = _resolve_model_path(
+                alignment_config.model_name, getattr(alignment_config, "model_hub", "huggingface")
+            )
 
         # Set client_wrapper for all configs
         alignment_config.client_wrapper = self

@@ -278,7 +281,7 @@
 
         try:
             if verbose:
-                …
+                safe_print(colorful.cyan(f"📖 Step 1: Reading caption file from {input_caption}"))
             caption = Caption.read(
                 input_caption,
                 format=input_caption_format,

@@ -287,18 +290,18 @@
             diarization_file = Path(str(input_caption)).with_suffix(".SpkDiar")
             if diarization_file.exists():
                 if verbose:
-                    …
+                    safe_print(colorful.cyan(f"📖 Step 1b: Reading speaker diarization from {diarization_file}"))
                 caption.read_speaker_diarization(diarization_file)
             events_file = Path(str(input_caption)).with_suffix(".AED")
             if events_file.exists():
                 if verbose:
-                    …
+                    safe_print(colorful.cyan(f"📖 Step 1c: Reading audio events from {events_file}"))
                 from tgt import read_textgrid
 
                 caption.audio_events = read_textgrid(events_file)
 
             if verbose:
-                …
+                safe_print(colorful.green(f" ✓ Parsed {len(caption)} caption segments"))
             return caption
         except Exception as e:
             raise CaptionProcessingError(

@@ -332,10 +335,10 @@
             )
             diarization_file = Path(str(output_caption_path)).with_suffix(".SpkDiar")
             if not diarization_file.exists() and caption.speaker_diarization:
-                …
+                safe_print(colorful.green(f" Writing speaker diarization to: {diarization_file}"))
                 caption.write_speaker_diarization(diarization_file)
 
-            …
+            safe_print(colorful.green(f"🎉🎉🎉🎉🎉 Caption file written to: {output_caption_path}"))
             return result
         except Exception as e:
             raise CaptionProcessingError(

@@ -352,14 +355,14 @@
         force_overwrite: bool,
     ) -> str:
         """Download media from YouTube (async implementation)."""
-        …
+        safe_print(colorful.cyan("📥 Downloading media from YouTube..."))
         media_file = await self.downloader.download_media(
             url=url,
             output_dir=str(output_dir),
             media_format=media_format,
             force_overwrite=force_overwrite,
         )
-        …
+        safe_print(colorful.green(f" ✓ Media downloaded: {media_file}"))
         return media_file
 
     def _download_media_sync(

@@ -379,6 +382,7 @@
         media_file: Union[str, Path, AudioData],
         source_lang: Optional[str],
         is_async: bool = False,
+        output_dir: Optional[Path] = None,
     ) -> Caption:
         """
         Get captions by downloading or transcribing.

@@ -400,14 +404,23 @@
             # Transcription mode: use Transcriber to transcribe
             self._validate_transcription_setup()
 
-            …
+            safe_print(colorful.cyan(f"🎤 Transcribing({self.transcriber.name}) media: {str(media_file)} ..."))
             transcription = await self.transcriber.transcribe_file(media_file, language=source_lang)
-            …
+            safe_print(colorful.green(" ✓ Transcription completed."))
 
             if "gemini" in self.transcriber.name.lower():
+                safe_print(colorful.yellow("🔍 Gemini raw output:"))
+                safe_print(colorful.yellow(f"{transcription[:1000]}..."))  # Print first 1000 chars
+
                 # write to temp file and use Caption read
-                …
-                …
+                # On Windows, we need to close the file before writing to it
+                tmp_file = tempfile.NamedTemporaryFile(
+                    suffix=self.transcriber.file_suffix, delete=False, mode="w", encoding="utf-8"
+                )
+                tmp_path = Path(tmp_file.name)
+                tmp_file.close()  # Close file before writing
+
+                try:
                     await asyncio.to_thread(
                         self.transcriber.write,
                         transcription,

@@ -417,6 +430,22 @@
                     transcription = self._read_caption(
                         tmp_path, input_caption_format="gemini", normalize_text=False, verbose=False
                     )
+                finally:
+                    # Clean up temp file
+                    if tmp_path.exists():
+                        tmp_path.unlink()
+            else:
+                safe_print(colorful.yellow(f"🔍 {self.transcriber.name} raw output:"))
+                if isinstance(transcription, Caption):
+                    safe_print(colorful.yellow(f"Caption with {len(transcription.transcription)} segments"))
+                    if transcription.transcription:
+                        safe_print(colorful.yellow(f"First segment: {transcription.transcription[0].text}"))
+
+            if output_dir:
+                # Generate transcript file path
+                transcript_file = output_dir / f"{Path(str(media_file)).stem}_{self.transcriber.file_name}"
+                await asyncio.to_thread(self.transcriber.write, transcription, transcript_file, encoding="utf-8")
+                safe_print(colorful.green(f" ✓ Transcription saved to: {transcript_file}"))
 
             return transcription
 

@@ -459,7 +488,7 @@
         if self.caption_config.input_path:
             caption_path = Path(self.caption_config.input_path)
             if caption_path.exists():
-                …
+                safe_print(colorful.green(f"📄 Using provided caption file: {caption_path}"))
                 return str(caption_path)
             else:
                 raise FileNotFoundError(f"Provided caption path does not exist: {caption_path}")

@@ -496,7 +525,7 @@
 
             # elif choice == "overwrite": continue to transcribe below
 
-            …
+            safe_print(colorful.cyan(f"🎤 Transcribing media with {transcriber_name}..."))
             if self.transcriber.supports_url:
                 transcription = await self.transcriber.transcribe(url, language=source_lang)
             else:

@@ -508,7 +537,7 @@
                 caption_file = transcription
             else:
                 caption_file = str(transcript_file)
-            …
+            safe_print(colorful.green(f" ✓ Transcription completed: {caption_file}"))
         else:
             # Download YouTube captions
             caption_file = await self.downloader.download_captions(
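
The temp-file rework above uses the usual Windows-safe pattern: create with delete=False, close the handle before anything else writes to the path, and unlink in a finally block. A standalone illustration of the same pattern, independent of the lattifai APIs:

    import tempfile
    from pathlib import Path

    # Close the handle right away so another writer can reopen the path by name
    # (required on Windows, harmless on other platforms).
    tmp_file = tempfile.NamedTemporaryFile(suffix=".srt", delete=False, mode="w", encoding="utf-8")
    tmp_path = Path(tmp_file.name)
    tmp_file.close()

    try:
        tmp_path.write_text("1\n00:00:00,000 --> 00:00:01,000\nhello\n", encoding="utf-8")
        print(tmp_path.read_text(encoding="utf-8"))
    finally:
        if tmp_path.exists():  # mirrors the finally/unlink cleanup added in _get_captions
            tmp_path.unlink()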
lattifai/server/app.py
CHANGED

@@ -232,7 +232,7 @@ async def align_files(
     normalize_text: bool = Form(False),
     output_format: str = Form("srt"),
     transcription_model: str = Form("nvidia/parakeet-tdt-0.6b-v3"),
-    alignment_model: str = Form("…
+    alignment_model: str = Form("LattifAI/Lattice-1"),
 ):
     # Check if LATTIFAI_API_KEY is set
     if not os.environ.get("LATTIFAI_API_KEY"):

@@ -423,4 +423,5 @@ def process_alignment(
         input_caption=str(caption_path) if caption_path else None,
         output_caption_path=str(output_caption_path) if output_caption_path else None,
         split_sentence=split_sentence,
+        streaming_chunk_secs=None,  # Server API default: no streaming
     )
lattifai/transcription/__init__.py
CHANGED

@@ -70,7 +70,7 @@ def create_transcriber(
     raise ValueError(
         f"Cannot determine transcriber for model_name='{transcription_config.model_name}'. "
         f"Supported patterns: \n"
-        f"  - Gemini API models: 'gemini-2.5-pro', 'gemini-3-pro-preview'\n"
+        f"  - Gemini API models: 'gemini-2.5-pro', 'gemini-3-pro-preview', 'gemini-3-flash-preview'\n"
         f"  - Local HF models: 'nvidia/parakeet-*', 'iic/SenseVoiceSmall', etc.\n"
         f"Please specify a valid model_name."
     )
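
With the new model id registered, the factory shown above can route it to the Gemini backend. A sketch, assuming create_transcriber is exported from lattifai.transcription and accepts the config directly:

    from lattifai.config import TranscriptionConfig
    from lattifai.transcription import create_transcriber  # export location assumed

    transcriber = create_transcriber(TranscriptionConfig(model_name="gemini-3-flash-preview"))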
lattifai/transcription/base.py
CHANGED

@@ -2,10 +2,12 @@
 
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Optional, Union
+from typing import List, Optional, Union
+
+import numpy as np
 
 from lattifai.audio2 import AudioData
-from lattifai.caption import Caption
+from lattifai.caption import Caption, Supervision
 from lattifai.config import TranscriptionConfig
 from lattifai.logging import get_logger
 

@@ -96,6 +98,23 @@ class BaseTranscriber(ABC):
             language: Optional language code for transcription.
         """
 
+    @abstractmethod
+    def transcribe_numpy(
+        self,
+        audio: Union[np.ndarray, List[np.ndarray]],
+        language: Optional[str] = None,
+    ) -> Union[Supervision, List[Supervision]]:
+        """
+        Transcribe audio from a numpy array and return Supervision.
+
+        Args:
+            audio_array: Audio data as numpy array (shape: [samples]).
+            language: Optional language code for transcription.
+
+        Returns:
+            Supervision object with transcription info.
+        """
+
     @abstractmethod
     def write(self, transcript: Union[str, Caption], output_file: Path, encoding: str = "utf-8") -> Path:
         """
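
The new abstract hook only fixes the call shape: a single array yields one Supervision, a list yields one result per chunk. A toy stand-in (not a lattifai class) showing that contract, with a plain dict in place of Supervision so it runs on its own:

    from typing import List, Optional, Union

    import numpy as np


    class EchoTranscriber:
        """Toy example of the transcribe_numpy contract; not part of the package."""

        def transcribe_numpy(
            self,
            audio: Union[np.ndarray, List[np.ndarray]],
            language: Optional[str] = None,
        ):
            if isinstance(audio, list):  # batch mode: one result per chunk
                return [self.transcribe_numpy(a, language=language) for a in audio]
            duration = audio.shape[-1] / 16000  # assumes 16 kHz mono input
            return {"start": 0.0, "duration": duration, "text": "", "language": language}


    print(EchoTranscriber().transcribe_numpy(np.zeros(16000), language="en"))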
lattifai/transcription/gemini.py
CHANGED

@@ -2,12 +2,14 @@
 
 import asyncio
 from pathlib import Path
-from typing import Optional, Union
+from typing import List, Optional, Union
 
+import numpy as np
 from google import genai
 from google.genai.types import GenerateContentConfig, Part, ThinkingConfig
 
 from lattifai.audio2 import AudioData
+from lattifai.caption import Supervision
 from lattifai.config import TranscriptionConfig
 from lattifai.transcription.base import BaseTranscriber
 from lattifai.transcription.prompts import get_prompt_loader

@@ -118,6 +120,130 @@ class GeminiTranscriber(BaseTranscriber):
             self.logger.error(f"Gemini transcription failed: {str(e)}")
             raise RuntimeError(f"Gemini transcription failed: {str(e)}")
 
+    def transcribe_numpy(
+        self,
+        audio: Union[np.ndarray, List[np.ndarray]],
+        language: Optional[str] = None,
+    ) -> Union[Supervision, List[Supervision]]:
+        """
+        Transcribe audio from a numpy array (or list of arrays) and return Supervision.
+
+        Note: Gemini API does not support word-level alignment. The returned
+        Supervision will contain only the full transcription text without alignment.
+
+        Args:
+            audio: Audio data as numpy array (shape: [samples]),
+                or a list of such arrays for batch processing.
+            language: Optional language code for transcription.
+
+        Returns:
+            Supervision object (or list of Supervision objects) with transcription text (no alignment).
+
+        Raises:
+            ValueError: If API key not provided
+            RuntimeError: If transcription fails
+        """
+        # Handle batch processing
+        if isinstance(audio, list):
+            return [self.transcribe_numpy(arr, language=language) for arr in audio]
+
+        audio_array = audio
+        # Use default sample rate of 16000 Hz
+        sample_rate = 16000
+
+        if self.config.verbose:
+            self.logger.info(f"🎤 Starting Gemini transcription for numpy array (sample_rate={sample_rate})")
+
+        # Ensure audio is in the correct shape
+        if audio_array.ndim == 1:
+            audio_array = audio_array.reshape(1, -1)
+        elif audio_array.ndim > 2:
+            raise ValueError(f"Audio array must be 1D or 2D, got shape {audio_array.shape}")
+
+        # Save numpy array to temporary file
+        import tempfile
+
+        import soundfile as sf
+
+        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
+            # Transpose to (samples, channels) for soundfile
+            sf.write(tmp_file.name, audio_array.T, sample_rate)
+            tmp_path = Path(tmp_file.name)
+
+        try:
+            # Transcribe using simple ASR prompt
+            import asyncio
+
+            transcript = asyncio.run(self._transcribe_with_simple_prompt(tmp_path, language=language))
+
+            # Create Supervision object from transcript
+            duration = audio_array.shape[-1] / sample_rate
+            supervision = Supervision(
+                id="gemini-transcription",
+                recording_id="numpy-array",
+                start=0.0,
+                duration=duration,
+                text=transcript,
+                speaker=None,
+                alignment=None,  # Gemini does not provide word-level alignment
+            )
+
+            return supervision
+
+        finally:
+            # Clean up temporary file
+            if tmp_path.exists():
+                tmp_path.unlink()
+
+    async def _transcribe_with_simple_prompt(self, media_file: Path, language: Optional[str] = None) -> str:
+        """
+        Transcribe audio using a simple ASR prompt instead of complex instructions.
+
+        Args:
+            media_file: Path to audio file
+            language: Optional language code
+
+        Returns:
+            Transcribed text
+        """
+        client = self._get_client()
+
+        # Upload audio file
+        if self.config.verbose:
+            self.logger.info("📤 Uploading audio file to Gemini...")
+        uploaded_file = client.files.upload(file=str(media_file))
+
+        # Simple ASR prompt
+        system_prompt = "Transcribe the audio."
+        if language:
+            system_prompt = f"Transcribe the audio in {language}."
+
+        # Create simple generation config
+        simple_config = GenerateContentConfig(
+            system_instruction=system_prompt,
+            response_modalities=["TEXT"],
+        )
+
+        contents = Part.from_uri(file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type)
+        response = await asyncio.get_event_loop().run_in_executor(
+            None,
+            lambda: client.models.generate_content(
+                model=self.config.model_name,
+                contents=contents,
+                config=simple_config,
+            ),
+        )
+
+        if not response.text:
+            raise RuntimeError("Empty response from Gemini API")
+
+        transcript = response.text.strip()
+
+        if self.config.verbose:
+            self.logger.info(f"✅ Transcription completed: {len(transcript)} characters")
+
+        return transcript
+
     def _get_transcription_prompt(self) -> str:
         """Get (and cache) transcription system prompt from prompts module."""
         if self._system_prompt is not None:
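
A usage sketch of the new numpy entry point, assuming the transcriber is constructed from a TranscriptionConfig (as the base class suggests) and Gemini credentials are already configured:

    import numpy as np

    from lattifai.config import TranscriptionConfig
    from lattifai.transcription.gemini import GeminiTranscriber  # module path taken from the diff header

    transcriber = GeminiTranscriber(TranscriptionConfig(model_name="gemini-3-flash-preview"))

    chunk = np.zeros(16000, dtype=np.float32)          # one second at the assumed 16 kHz rate
    sup = transcriber.transcribe_numpy(chunk, language="en")
    print(sup.text)        # full transcription text
    print(sup.alignment)   # None -- Gemini provides no word-level alignment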
lattifai/transcription/lattifai.py
CHANGED

@@ -1,10 +1,12 @@
 """Transcription module with config-driven architecture."""
 
 from pathlib import Path
-from typing import Optional, Union
+from typing import List, Optional, Union
+
+import numpy as np
 
 from lattifai.audio2 import AudioData
-from lattifai.caption import Caption
+from lattifai.caption import Caption, Supervision
 from lattifai.config import TranscriptionConfig
 from lattifai.transcription.base import BaseTranscriber
 from lattifai.transcription.prompts import get_prompt_loader  # noqa: F401

@@ -74,6 +76,32 @@ class LattifAITranscriber(BaseTranscriber):
 
         return caption
 
+    def transcribe_numpy(
+        self,
+        audio: Union[np.ndarray, List[np.ndarray]],
+        language: Optional[str] = None,
+    ) -> Union[Supervision, List[Supervision]]:
+        """
+        Transcribe audio from a numpy array (or list of arrays) and return Supervision.
+
+        Args:
+            audio: Audio data as numpy array (shape: [samples]),
+                or a list of such arrays for batch processing.
+            language: Optional language code for transcription.
+
+        Returns:
+            Supervision object (or list of Supervision objects) with transcription and alignment info.
+        """
+        if self._transcriber is None:
+            from lattifai_core.transcription import LattifAITranscriber as CoreLattifAITranscriber
+
+            self._transcriber = CoreLattifAITranscriber.from_pretrained(model_config=self.config)
+
+        # Delegate to core transcriber which handles both single arrays and lists
+        return self._transcriber.transcribe(
+            audio, language=language, return_hypotheses=True, progress_bar=False, timestamps=True
+        )[0]
+
     def write(
         self, transcript: Caption, output_file: Path, encoding: str = "utf-8", cache_audio_events: bool = True
     ) -> Path: