lattifai 0.4.6__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- lattifai/__init__.py +42 -27
- lattifai/alignment/__init__.py +6 -0
- lattifai/alignment/lattice1_aligner.py +119 -0
- lattifai/{workers/lattice1_alpha.py → alignment/lattice1_worker.py} +33 -132
- lattifai/{tokenizer → alignment}/phonemizer.py +1 -1
- lattifai/alignment/segmenter.py +166 -0
- lattifai/{tokenizer → alignment}/tokenizer.py +186 -112
- lattifai/audio2.py +211 -0
- lattifai/caption/__init__.py +20 -0
- lattifai/caption/caption.py +1275 -0
- lattifai/{io → caption}/supervision.py +1 -0
- lattifai/{io → caption}/text_parser.py +53 -10
- lattifai/cli/__init__.py +17 -0
- lattifai/cli/alignment.py +153 -0
- lattifai/cli/caption.py +204 -0
- lattifai/cli/server.py +19 -0
- lattifai/cli/transcribe.py +197 -0
- lattifai/cli/youtube.py +128 -0
- lattifai/client.py +455 -246
- lattifai/config/__init__.py +20 -0
- lattifai/config/alignment.py +73 -0
- lattifai/config/caption.py +178 -0
- lattifai/config/client.py +46 -0
- lattifai/config/diarization.py +67 -0
- lattifai/config/media.py +335 -0
- lattifai/config/transcription.py +84 -0
- lattifai/diarization/__init__.py +5 -0
- lattifai/diarization/lattifai.py +89 -0
- lattifai/errors.py +41 -34
- lattifai/logging.py +116 -0
- lattifai/mixin.py +552 -0
- lattifai/server/app.py +420 -0
- lattifai/transcription/__init__.py +76 -0
- lattifai/transcription/base.py +108 -0
- lattifai/transcription/gemini.py +219 -0
- lattifai/transcription/lattifai.py +103 -0
- lattifai/types.py +30 -0
- lattifai/utils.py +3 -31
- lattifai/workflow/__init__.py +22 -0
- lattifai/workflow/agents.py +6 -0
- lattifai/{workflows → workflow}/file_manager.py +81 -57
- lattifai/workflow/youtube.py +564 -0
- lattifai-1.0.0.dist-info/METADATA +736 -0
- lattifai-1.0.0.dist-info/RECORD +52 -0
- {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/WHEEL +1 -1
- lattifai-1.0.0.dist-info/entry_points.txt +13 -0
- lattifai/base_client.py +0 -126
- lattifai/bin/__init__.py +0 -3
- lattifai/bin/agent.py +0 -324
- lattifai/bin/align.py +0 -295
- lattifai/bin/cli_base.py +0 -25
- lattifai/bin/subtitle.py +0 -210
- lattifai/io/__init__.py +0 -43
- lattifai/io/reader.py +0 -86
- lattifai/io/utils.py +0 -15
- lattifai/io/writer.py +0 -102
- lattifai/tokenizer/__init__.py +0 -3
- lattifai/workers/__init__.py +0 -3
- lattifai/workflows/__init__.py +0 -34
- lattifai/workflows/agents.py +0 -12
- lattifai/workflows/gemini.py +0 -167
- lattifai/workflows/prompts/README.md +0 -22
- lattifai/workflows/prompts/gemini/README.md +0 -24
- lattifai/workflows/prompts/gemini/transcription_gem.txt +0 -81
- lattifai/workflows/youtube.py +0 -931
- lattifai-0.4.6.dist-info/METADATA +0 -806
- lattifai-0.4.6.dist-info/RECORD +0 -39
- lattifai-0.4.6.dist-info/entry_points.txt +0 -3
- /lattifai/{io → caption}/gemini_reader.py +0 -0
- /lattifai/{io → caption}/gemini_writer.py +0 -0
- /lattifai/{workflows → transcription}/prompts/__init__.py +0 -0
- /lattifai/{workflows → workflow}/base.py +0 -0
- {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/top_level.txt +0 -0
lattifai/caption/caption.py
@@ -0,0 +1,1275 @@
"""Caption data structure for storing subtitle information with metadata."""

import json
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from lhotse.supervision import AlignmentItem
from lhotse.utils import Pathlike
from tgt import TextGrid

from ..config.caption import InputCaptionFormat, OutputCaptionFormat
from .supervision import Supervision
from .text_parser import normalize_text as normalize_text_fn
from .text_parser import parse_speaker_text, parse_timestamp_text


@dataclass
class Caption:
    """
    Container for caption/subtitle data with metadata.

    This class encapsulates a list of supervisions (subtitle segments) along with
    metadata such as language, kind, format information, and source file details.

    Attributes:
        supervisions: List of supervision segments containing text and timing information
        language: Language code (e.g., 'en', 'zh', 'es')
        kind: Caption kind/type (e.g., 'captions', 'subtitles', 'descriptions')
        source_format: Original format of the caption file (e.g., 'vtt', 'srt', 'json')
        source_path: Path to the source caption file
        metadata: Additional custom metadata as key-value pairs
    """

    # read from subtitle file
    supervisions: List[Supervision] = field(default_factory=list)
    # Transcription results
    transcription: List[Supervision] = field(default_factory=list)
    # Audio Event Detection results
    audio_events: Optional[TextGrid] = None
    # Speaker Diarization results
    speaker_diarization: Optional[TextGrid] = None
    # Alignment results
    alignments: List[Supervision] = field(default_factory=list)

    language: Optional[str] = None
    kind: Optional[str] = None
    source_format: Optional[str] = None
    source_path: Optional[Pathlike] = None
    metadata: Dict[str, str] = field(default_factory=dict)

    def __len__(self) -> int:
        """Return the number of supervision segments."""
        return len(self.supervisions or self.transcription)

    def __iter__(self):
        """Iterate over supervision segments."""
        return iter(self.supervisions)

    def __getitem__(self, index):
        """Get supervision segment by index."""
        return self.supervisions[index]

    def __bool__(self) -> bool:
        """Return True if caption has supervisions."""
        return self.__len__() > 0

    @property
    def is_empty(self) -> bool:
        """Check if caption has no supervisions."""
        return len(self.supervisions) == 0

    @property
    def duration(self) -> Optional[float]:
        """
        Get total duration of the caption in seconds.

        Returns:
            Total duration from first to last supervision, or None if empty
        """
        if not self.supervisions:
            return None
        return self.supervisions[-1].end - self.supervisions[0].start

    @property
    def start_time(self) -> Optional[float]:
        """Get start time of first supervision."""
        if not self.supervisions:
            return None
        return self.supervisions[0].start

    @property
    def end_time(self) -> Optional[float]:
        """Get end time of last supervision."""
        if not self.supervisions:
            return None
        return self.supervisions[-1].end

    def append(self, supervision: Supervision) -> None:
        """Add a supervision segment to the caption."""
        self.supervisions.append(supervision)

    def extend(self, supervisions: List[Supervision]) -> None:
        """Add multiple supervision segments to the caption."""
        self.supervisions.extend(supervisions)

    def filter_by_speaker(self, speaker: str) -> "Caption":
        """
        Create a new Caption with only supervisions from a specific speaker.

        Args:
            speaker: Speaker identifier to filter by

        Returns:
            New Caption instance with filtered supervisions
        """
        filtered_sups = [sup for sup in self.supervisions if sup.speaker == speaker]
        return Caption(
            supervisions=filtered_sups,
            language=self.language,
            kind=self.kind,
            source_format=self.source_format,
            source_path=self.source_path,
            metadata=self.metadata.copy(),
        )

    def get_speakers(self) -> List[str]:
        """
        Get list of unique speakers in the caption.

        Returns:
            Sorted list of unique speaker identifiers
        """
        speakers = {sup.speaker for sup in self.supervisions if sup.speaker}
        return sorted(speakers)

    def shift_time(self, seconds: float) -> "Caption":
        """
        Create a new Caption with all timestamps shifted by given seconds.

        Args:
            seconds: Number of seconds to shift (positive delays, negative advances)

        Returns:
            New Caption instance with shifted timestamps
        """
        shifted_sups = [
            Supervision(
                text=sup.text,
                start=sup.start + seconds,
                duration=sup.duration,
                speaker=sup.speaker,
                id=sup.id,
                language=sup.language,
                alignment=sup.alignment if hasattr(sup, "alignment") else None,
                custom=sup.custom,
            )
            for sup in self.supervisions
        ]

        return Caption(
            supervisions=shifted_sups,
            language=self.language,
            kind=self.kind,
            source_format=self.source_format,
            source_path=self.source_path,
            metadata=self.metadata.copy(),
        )

    def to_string(self, format: str = "srt") -> str:
        """
        Return caption content in specified format.

        Args:
            format: Output format (e.g., 'srt', 'vtt', 'ass')

        Returns:
            String containing formatted captions
        """
        import pysubs2

        subs = pysubs2.SSAFile()

        if self.alignments:
            alignments = self.alignments
        else:
            alignments = self.supervisions

        if not alignments:
            alignments = self.transcription

        for sup in alignments:
            # Add word-level timing as metadata in the caption text
            word_items = self._parse_alignment_from_supervision(sup)
            if word_items:
                for word in word_items:
                    subs.append(
                        pysubs2.SSAEvent(
                            start=int(word.start * 1000),
                            end=int(word.end * 1000),
                            text=word.symbol,
                            name=sup.speaker or "",
                        )
                    )
            else:
                subs.append(
                    pysubs2.SSAEvent(
                        start=int(sup.start * 1000),
                        end=int(sup.end * 1000),
                        text=sup.text or "",
                        name=sup.speaker or "",
                    )
                )

        return subs.to_string(format_=format)

    def to_dict(self) -> Dict:
        """
        Convert Caption to dictionary representation.

        Returns:
            Dictionary with caption data and metadata
        """
        return {
            "supervisions": [sup.to_dict() for sup in self.supervisions],
            "language": self.language,
            "kind": self.kind,
            "source_format": self.source_format,
            "source_path": str(self.source_path) if self.source_path else None,
            "metadata": self.metadata,
            "duration": self.duration,
            "num_segments": len(self.supervisions),
            "speakers": self.get_speakers(),
        }

    @classmethod
    def from_supervisions(
        cls,
        supervisions: List[Supervision],
        language: Optional[str] = None,
        kind: Optional[str] = None,
        source_format: Optional[str] = None,
        source_path: Optional[Pathlike] = None,
        metadata: Optional[Dict[str, str]] = None,
    ) -> "Caption":
        """
        Create Caption from a list of supervisions.

        Args:
            supervisions: List of supervision segments
            language: Language code
            kind: Caption kind/type
            source_format: Original format
            source_path: Source file path
            metadata: Additional metadata

        Returns:
            New Caption instance
        """
        return cls(
            supervisions=supervisions,
            language=language,
            kind=kind,
            source_format=source_format,
            source_path=source_path,
            metadata=metadata or {},
        )

    @classmethod
    def from_transcription_results(
        cls,
        transcription: List[Supervision],
        audio_events: Optional[TextGrid] = None,
        speaker_diarization: Optional[TextGrid] = None,
        language: Optional[str] = None,
        source_path: Optional[Pathlike] = None,
        metadata: Optional[Dict[str, str]] = None,
    ) -> "Caption":
        """
        Create Caption from transcription results including audio events and diarization.

        Args:
            transcription: List of transcription supervision segments
            audio_events: Optional TextGrid with audio event detection results
            speaker_diarization: Optional TextGrid with speaker diarization results
            language: Language code
            source_path: Source file path
            metadata: Additional metadata

        Returns:
            New Caption instance with transcription data
        """
        return cls(
            transcription=transcription,
            audio_events=audio_events,
            speaker_diarization=speaker_diarization,
            language=language,
            kind="transcription",
            source_format="asr",
            source_path=source_path,
            metadata=metadata or {},
        )

    @classmethod
    def read(
        cls,
        path: Pathlike,
        format: Optional[str] = None,
        normalize_text: bool = False,
    ) -> "Caption":
        """
        Read caption file and return Caption object.

        Args:
            path: Path to caption file
            format: Caption format (auto-detected if not provided)
            normalize_text: Whether to normalize text during reading

        Returns:
            Caption object containing supervisions and metadata

        Example:
            >>> caption = Caption.read("subtitles.srt")
            >>> print(f"Loaded {len(caption)} segments")
        """
        caption_path = Path(str(path)) if not isinstance(path, Path) else path

        # Detect format if not provided
        if not format and caption_path.exists():
            format = caption_path.suffix.lstrip(".").lower()
        elif format:
            format = format.lower()

        # Extract metadata from file
        metadata = cls._extract_metadata(path, format)

        # Parse supervisions
        supervisions = cls._parse_supervisions(path, format, normalize_text)

        # Create Caption object
        return cls(
            supervisions=supervisions,
            language=metadata.get("language"),
            kind=metadata.get("kind"),
            source_format=format,
            source_path=str(caption_path) if caption_path.exists() else None,
            metadata=metadata,
        )

    def write(
        self,
        path: Pathlike,
        include_speaker_in_text: bool = True,
    ) -> Pathlike:
        """
        Write caption to file.

        Args:
            path: Path to output caption file
            include_speaker_in_text: Whether to include speaker labels in text

        Returns:
            Path to the written file

        Example:
            >>> caption = Caption.read("input.srt")
            >>> caption.write("output.vtt", include_speaker_in_text=False)
        """
        if self.alignments:
            alignments = self.alignments
        else:
            alignments = self.supervisions

        if not alignments:
            alignments = self.transcription

        return self._write_caption(alignments, path, include_speaker_in_text)

    def read_speaker_diarization(
        self,
        path: Pathlike,
    ) -> TextGrid:
        """
        Read speaker diarization TextGrid from file.
        """
        from tgt import read_textgrid

        self.speaker_diarization = read_textgrid(path)
        return self.speaker_diarization

    def write_speaker_diarization(
        self,
        path: Pathlike,
    ) -> Pathlike:
        """
        Write speaker diarization TextGrid to file.
        """
        if not self.speaker_diarization:
            raise ValueError("No speaker diarization data to write.")

        from tgt import write_to_file

        write_to_file(self.speaker_diarization, path, format="long")
        return path

    @staticmethod
    def _parse_alignment_from_supervision(supervision: Any) -> Optional[List[AlignmentItem]]:
        """
        Extract word-level alignment items from Supervision object.

        Args:
            supervision: Supervision object with potential alignment data

        Returns:
            List of AlignmentItem objects, or None if no alignment data present
        """
        if not hasattr(supervision, "alignment") or not supervision.alignment:
            return None

        if "word" not in supervision.alignment:
            return None

        return supervision.alignment["word"]

    @classmethod
    def _write_caption(
        cls,
        alignments: List[Supervision],
        output_path: Pathlike,
        include_speaker_in_text: bool = True,
    ) -> Pathlike:
        """
        Write caption to file in various formats.

        Args:
            alignments: List of supervision segments to write
            output_path: Path to output file
            include_speaker_in_text: Whether to include speaker in text

        Returns:
            Path to written file
        """
        if str(output_path)[-4:].lower() == ".txt":
            with open(output_path, "w", encoding="utf-8") as f:
                for sup in alignments:
                    word_items = cls._parse_alignment_from_supervision(sup)
                    if word_items:
                        for item in word_items:
                            f.write(f"[{item.start:.2f}-{item.end:.2f}] {item.symbol}\n")
                    else:
                        if include_speaker_in_text and sup.speaker is not None:
                            # Use [SPEAKER]: format for consistency with parsing
                            text = f"[{sup.speaker}]: {sup.text}"
                        else:
                            text = sup.text
                        f.write(f"[{sup.start:.2f}-{sup.end:.2f}] {text}\n")

        elif str(output_path)[-5:].lower() == ".json":
            with open(output_path, "w", encoding="utf-8") as f:
                # Enhanced JSON export with word-level alignment
                json_data = []
                for sup in alignments:
                    sup_dict = sup.to_dict()
                    json_data.append(sup_dict)
                json.dump(json_data, f, ensure_ascii=False, indent=4)

        elif str(output_path).lower().endswith(".textgrid"):
            from tgt import Interval, IntervalTier, TextGrid, write_to_file

            tg = TextGrid()
            supervisions, words, scores = [], [], {"utterances": [], "words": []}
            for supervision in sorted(alignments, key=lambda x: x.start):
                if include_speaker_in_text and supervision.speaker is not None:
                    text = f"{supervision.speaker} {supervision.text}"
                else:
                    text = supervision.text
                supervisions.append(Interval(supervision.start, supervision.end, text or ""))
                # Extract word-level alignment using helper function
                word_items = cls._parse_alignment_from_supervision(supervision)
                if word_items:
                    for item in word_items:
                        words.append(Interval(item.start, item.end, item.symbol))
                        if item.score is not None:
                            scores["words"].append(Interval(item.start, item.end, f"{item.score:.2f}"))
                if supervision.has_custom("score"):
                    scores["utterances"].append(
                        Interval(supervision.start, supervision.end, f"{supervision.score:.2f}")
                    )

            tg.add_tier(IntervalTier(name="utterances", objects=supervisions))
            if words:
                tg.add_tier(IntervalTier(name="words", objects=words))

            if scores["utterances"]:
                tg.add_tier(IntervalTier(name="utterance_scores", objects=scores["utterances"]))
            if scores["words"]:
                tg.add_tier(IntervalTier(name="word_scores", objects=scores["words"]))

            write_to_file(tg, output_path, format="long")

        elif str(output_path)[-4:].lower() == ".tsv":
            cls._write_tsv(alignments, output_path, include_speaker_in_text)
        elif str(output_path)[-4:].lower() == ".csv":
            cls._write_csv(alignments, output_path, include_speaker_in_text)
        elif str(output_path)[-4:].lower() == ".aud":
            cls._write_aud(alignments, output_path, include_speaker_in_text)
        else:
            import pysubs2

            subs = pysubs2.SSAFile()
            for sup in alignments:
                # Add word-level timing as metadata in the caption text
                word_items = cls._parse_alignment_from_supervision(sup)
                if word_items:
                    for word in word_items:
                        subs.append(
                            pysubs2.SSAEvent(
                                start=int(word.start * 1000),
                                end=int(word.end * 1000),
                                text=word.symbol,
                                name=sup.speaker or "",
                            )
                        )
                else:
                    if include_speaker_in_text and sup.speaker is not None:
                        text = f"{sup.speaker} {sup.text}"
                    else:
                        text = sup.text
                    subs.append(
                        pysubs2.SSAEvent(
                            start=int(sup.start * 1000),
                            end=int(sup.end * 1000),
                            text=text or "",
                            name=sup.speaker or "",
                        )
                    )
            subs.save(output_path)

        return output_path

    @classmethod
    def _extract_metadata(cls, caption: Pathlike, format: Optional[str]) -> Dict[str, str]:
        """
        Extract metadata from caption file header.

        Args:
            caption: Caption file path or content
            format: Caption format

        Returns:
            Dictionary of metadata key-value pairs
        """
        metadata = {}
        caption_path = Path(str(caption))

        if not caption_path.exists():
            return metadata

        try:
            with open(caption_path, "r", encoding="utf-8") as f:
                content = f.read(2048)  # Read first 2KB for metadata

            # WebVTT metadata extraction
            if format == "vtt" or content.startswith("WEBVTT"):
                lines = content.split("\n")
                for line in lines[:10]:  # Check first 10 lines
                    line = line.strip()
                    if line.startswith("Kind:"):
                        metadata["kind"] = line.split(":", 1)[1].strip()
                    elif line.startswith("Language:"):
                        metadata["language"] = line.split(":", 1)[1].strip()
                    elif line.startswith("NOTE"):
                        # Extract metadata from NOTE comments
                        match = re.search(r"NOTE\s+(\w+):\s*(.+)", line)
                        if match:
                            key, value = match.groups()
                            metadata[key.lower()] = value.strip()

            # SRT doesn't have standard metadata, but check for BOM
            elif format == "srt":
                if content.startswith("\ufeff"):
                    metadata["encoding"] = "utf-8-sig"

            # TextGrid metadata
            elif format == "textgrid" or caption_path.suffix.lower() == ".textgrid":
                match = re.search(r"xmin\s*=\s*([\d.]+)", content)
                if match:
                    metadata["xmin"] = match.group(1)
                match = re.search(r"xmax\s*=\s*([\d.]+)", content)
                if match:
                    metadata["xmax"] = match.group(1)

        except Exception:
            # If metadata extraction fails, continue with empty metadata
            pass

        return metadata

    @classmethod
    def _parse_youtube_vtt_with_word_timestamps(
        cls, content: str, normalize_text: Optional[bool] = False
    ) -> List[Supervision]:
        """
        Parse YouTube VTT format with word-level timestamps.

        YouTube auto-generated captions use this format:
        Word1<00:00:10.559><c> Word2</c><00:00:11.120><c> Word3</c>...

        Args:
            content: VTT file content
            normalize_text: Whether to normalize text

        Returns:
            List of Supervision objects with word-level alignments
        """
        from lhotse.supervision import AlignmentItem

        supervisions = []

        # Pattern to match timestamp lines: 00:00:14.280 --> 00:00:17.269 align:start position:0%
        timestamp_pattern = re.compile(r"(\d{2}:\d{2}:\d{2}[.,]\d{3})\s*-->\s*(\d{2}:\d{2}:\d{2}[.,]\d{3})")

        # Pattern to match word-level timestamps: <00:00:10.559><c> word</c>
        word_timestamp_pattern = re.compile(r"<(\d{2}:\d{2}:\d{2}[.,]\d{3})><c>\s*([^<]+)</c>")

        # Pattern to match the first word (before first timestamp)
        first_word_pattern = re.compile(r"^([^<\n]+?)<(\d{2}:\d{2}:\d{2}[.,]\d{3})>")

        def parse_timestamp(ts: str) -> float:
            """Convert timestamp string to seconds."""
            ts = ts.replace(",", ".")
            parts = ts.split(":")
            hours = int(parts[0])
            minutes = int(parts[1])
            seconds = float(parts[2])
            return hours * 3600 + minutes * 60 + seconds

        lines = content.split("\n")
        i = 0
        while i < len(lines):
            line = lines[i].strip()

            # Look for timestamp line
            ts_match = timestamp_pattern.search(line)
            if ts_match:
                cue_start = parse_timestamp(ts_match.group(1))
                cue_end = parse_timestamp(ts_match.group(2))

                # Read the next non-empty lines for cue content
                cue_lines = []
                i += 1
                while i < len(lines) and lines[i].strip() and not timestamp_pattern.search(lines[i]):
                    cue_lines.append(lines[i])
                    i += 1

                # Process cue content
                for cue_line in cue_lines:
                    cue_line = cue_line.strip()
                    if not cue_line:
                        continue

                    # Check if this line has word-level timestamps
                    word_matches = word_timestamp_pattern.findall(cue_line)
                    if word_matches:
                        # This line has word-level timing
                        word_alignments = []

                        # Get the first word (before the first timestamp)
                        first_match = first_word_pattern.match(cue_line)
                        if first_match:
                            first_word = first_match.group(1).strip()
                            first_word_next_ts = parse_timestamp(first_match.group(2))
                            if first_word:
                                # First word starts at cue_start
                                word_alignments.append(
                                    AlignmentItem(
                                        symbol=first_word,
                                        start=cue_start,
                                        duration=first_word_next_ts - cue_start,
                                    )
                                )

                        # Process remaining words with timestamps
                        for idx, (ts, word) in enumerate(word_matches):
                            word_start = parse_timestamp(ts)
                            word = word.strip()
                            if not word:
                                continue

                            # Calculate duration based on next word's timestamp or cue end
                            if idx + 1 < len(word_matches):
                                next_ts = parse_timestamp(word_matches[idx + 1][0])
                                duration = next_ts - word_start
                            else:
                                duration = cue_end - word_start

                            word_alignments.append(
                                AlignmentItem(
                                    symbol=word,
                                    start=word_start,
                                    duration=max(0.01, duration),  # Ensure positive duration
                                )
                            )

                        if word_alignments:
                            # Create supervision with word-level alignment
                            full_text = " ".join(item.symbol for item in word_alignments)
                            if normalize_text:
                                full_text = normalize_text_fn(full_text)

                            sup_start = word_alignments[0].start
                            sup_end = word_alignments[-1].start + word_alignments[-1].duration

                            supervisions.append(
                                Supervision(
                                    text=full_text,
                                    start=sup_start,
                                    duration=sup_end - sup_start,
                                    alignment={"word": word_alignments},
                                )
                            )
                    else:
                        # Plain text line without word-level timing - skip duplicate lines
                        # (YouTube VTT often repeats the previous line without timestamps)
                        pass

                continue
            i += 1

        # Merge consecutive supervisions to form complete utterances
        if supervisions:
            supervisions = cls._merge_youtube_vtt_supervisions(supervisions)

        return supervisions

    @classmethod
    def _merge_youtube_vtt_supervisions(cls, supervisions: List[Supervision]) -> List[Supervision]:
        """
        Merge consecutive YouTube VTT supervisions into complete utterances.

        YouTube VTT splits utterances across multiple cues. This method merges
        cues that are close together in time.

        Args:
            supervisions: List of supervisions to merge

        Returns:
            List of merged supervisions
        """
        if not supervisions:
            return supervisions

        merged = []
        current = supervisions[0]

        for next_sup in supervisions[1:]:
            # Check if next supervision is close enough to merge (within 0.5 seconds)
            gap = next_sup.start - (current.start + current.duration)

            if gap < 0.5 and current.alignment and next_sup.alignment:
                # Merge alignments
                current_words = current.alignment.get("word", [])
                next_words = next_sup.alignment.get("word", [])
                merged_words = list(current_words) + list(next_words)

                # Create merged supervision
                merged_text = current.text + " " + next_sup.text
                merged_end = next_sup.start + next_sup.duration

                current = Supervision(
                    text=merged_text,
                    start=current.start,
                    duration=merged_end - current.start,
                    alignment={"word": merged_words},
                )
            else:
                merged.append(current)
                current = next_sup

        merged.append(current)
        return merged

    @classmethod
    def _is_youtube_vtt_with_word_timestamps(cls, content: str) -> bool:
        """
        Check if content is YouTube VTT format with word-level timestamps.

        Args:
            content: File content to check

        Returns:
            True if content contains YouTube-style word timestamps
        """
        # Look for pattern like <00:00:10.559><c> word</c>
        return bool(re.search(r"<\d{2}:\d{2}:\d{2}[.,]\d{3}><c>", content))

    @classmethod
    def _parse_supervisions(
        cls, caption: Pathlike, format: Optional[str], normalize_text: Optional[bool] = False
    ) -> List[Supervision]:
        """
        Parse supervisions from caption file.

        Args:
            caption: Caption file path or content
            format: Caption format
            normalize_text: Whether to normalize text

        Returns:
            List of Supervision objects
        """
        if format:
            format = format.lower()

        # Check for YouTube VTT with word-level timestamps first
        caption_path = Path(str(caption))
        if caption_path.exists():
            with open(caption_path, "r", encoding="utf-8") as f:
                content = f.read()
            if cls._is_youtube_vtt_with_word_timestamps(content):
                return cls._parse_youtube_vtt_with_word_timestamps(content, normalize_text)

        if format == "gemini" or str(caption).endswith("Gemini.md"):
            from .gemini_reader import GeminiReader

            supervisions = GeminiReader.extract_for_alignment(caption)
        elif format and (format == "textgrid" or str(caption).lower().endswith("textgrid")):
            # Internal usage
            from tgt import read_textgrid

            tgt = read_textgrid(caption)
            supervisions = []
            for tier in tgt.tiers:
                supervisions.extend(
                    [
                        Supervision(
                            text=interval.text,
                            start=interval.start_time,
                            duration=interval.end_time - interval.start_time,
                            speaker=tier.name,
                        )
                        for interval in tier.intervals
                    ]
                )
            supervisions = sorted(supervisions, key=lambda x: x.start)
        elif format == "tsv" or str(caption)[-4:].lower() == ".tsv":
            supervisions = cls._parse_tsv(caption, normalize_text)
        elif format == "csv" or str(caption)[-4:].lower() == ".csv":
            supervisions = cls._parse_csv(caption, normalize_text)
        elif format == "aud" or str(caption)[-4:].lower() == ".aud":
            supervisions = cls._parse_aud(caption, normalize_text)
        elif format == "txt" or (format == "auto" and str(caption)[-4:].lower() == ".txt"):
            if not Path(str(caption)).exists():  # str
                lines = [line.strip() for line in str(caption).split("\n")]
            else:  # file
                path_str = str(caption)
                with open(path_str, encoding="utf-8") as f:
                    lines = [line.strip() for line in f.readlines()]
            if normalize_text:
                lines = [normalize_text_fn(line) for line in lines]
            supervisions = []
            for line in lines:
                if line:
                    # First try to parse timestamp format: [start-end] text
                    start, end, remaining_text = parse_timestamp_text(line)
                    if start is not None and end is not None:
                        # Has timestamp, now check for speaker in the remaining text
                        speaker, text = parse_speaker_text(remaining_text)
                        supervisions.append(
                            Supervision(
                                text=text,
                                start=start,
                                duration=end - start,
                                speaker=speaker,
                            )
                        )
                    else:
                        # No timestamp, just parse speaker and text
                        speaker, text = parse_speaker_text(line)
                        supervisions.append(Supervision(text=text, speaker=speaker))
        else:
            try:
                supervisions = cls._parse_caption(caption, format=format, normalize_text=normalize_text)
            except Exception as e:
                print(f"Failed to parse caption with Format: {format}, Exception: {e}, trying 'gemini' parser.")
                from .gemini_reader import GeminiReader

                supervisions = GeminiReader.extract_for_alignment(caption)

        return supervisions

    @classmethod
    def _parse_tsv(cls, caption: Pathlike, normalize_text: Optional[bool] = False) -> List[Supervision]:
        """
        Parse TSV (Tab-Separated Values) format caption file.

        Format specifications:
        - With speaker: speaker\tstart\tend\ttext
        - Without speaker: start\tend\ttext
        - Times are in milliseconds

        Args:
            caption: Caption file path
            normalize_text: Whether to normalize text

        Returns:
            List of Supervision objects
        """
        caption_path = Path(str(caption))
        if not caption_path.exists():
            raise FileNotFoundError(f"Caption file not found: {caption}")

        supervisions = []

        with open(caption_path, "r", encoding="utf-8") as f:
            lines = f.readlines()

        # Check if first line is a header
        first_line = lines[0].strip().lower()
        has_header = "start" in first_line and "end" in first_line and "text" in first_line
        has_speaker_column = "speaker" in first_line

        start_idx = 1 if has_header else 0

        for line in lines[start_idx:]:
            line = line.strip()
            if not line:
                continue

            parts = line.split("\t")
            if len(parts) < 3:
                continue

            try:
                if has_speaker_column and len(parts) >= 4:
                    # Format: speaker\tstart\tend\ttext
                    speaker = parts[0].strip() if parts[0].strip() else None
                    start = float(parts[1]) / 1000.0  # Convert milliseconds to seconds
                    end = float(parts[2]) / 1000.0
                    text = "\t".join(parts[3:]).strip()
                else:
                    # Format: start\tend\ttext
                    start = float(parts[0]) / 1000.0  # Convert milliseconds to seconds
                    end = float(parts[1]) / 1000.0
                    text = "\t".join(parts[2:]).strip()
                    speaker = None

                if normalize_text:
                    text = normalize_text_fn(text)

                duration = end - start
                if duration < 0:
                    continue

                supervisions.append(
                    Supervision(
                        text=text,
                        start=start,
                        duration=duration,
                        speaker=speaker,
                    )
                )
            except (ValueError, IndexError):
                # Skip malformed lines
                continue

        return supervisions

    @classmethod
    def _parse_csv(cls, caption: Pathlike, normalize_text: Optional[bool] = False) -> List[Supervision]:
        """
        Parse CSV (Comma-Separated Values) format caption file.

        Format specifications:
        - With speaker: speaker,start,end,text
        - Without speaker: start,end,text
        - Times are in milliseconds

        Args:
            caption: Caption file path
            normalize_text: Whether to normalize text

        Returns:
            List of Supervision objects
        """
        import csv

        caption_path = Path(str(caption))
        if not caption_path.exists():
            raise FileNotFoundError(f"Caption file not found: {caption}")

        supervisions = []

        with open(caption_path, "r", encoding="utf-8", newline="") as f:
            reader = csv.reader(f)
            lines = list(reader)

        if not lines:
            return supervisions

        # Check if first line is a header
        first_line = [col.strip().lower() for col in lines[0]]
        has_header = "start" in first_line and "end" in first_line and "text" in first_line
        has_speaker_column = "speaker" in first_line

        start_idx = 1 if has_header else 0

        for parts in lines[start_idx:]:
            if len(parts) < 3:
                continue

            try:
                if has_speaker_column and len(parts) >= 4:
                    # Format: speaker,start,end,text
                    speaker = parts[0].strip() if parts[0].strip() else None
                    start = float(parts[1]) / 1000.0  # Convert milliseconds to seconds
                    end = float(parts[2]) / 1000.0
                    text = ",".join(parts[3:]).strip()
                else:
                    # Format: start,end,text
                    start = float(parts[0]) / 1000.0  # Convert milliseconds to seconds
                    end = float(parts[1]) / 1000.0
                    text = ",".join(parts[2:]).strip()
                    speaker = None

                if normalize_text:
                    text = normalize_text_fn(text)

                duration = end - start
                if duration < 0:
                    continue

                supervisions.append(
                    Supervision(
                        text=text,
                        start=start,
                        duration=duration,
                        speaker=speaker,
                    )
                )
            except (ValueError, IndexError):
                # Skip malformed lines
                continue

        return supervisions

    @classmethod
    def _parse_aud(cls, caption: Pathlike, normalize_text: Optional[bool] = False) -> List[Supervision]:
        """
        Parse AUD (Audacity Labels) format caption file.

        Format: start\tend\t[[speaker]]text
        - Times are in seconds (float)
        - Speaker is optional and enclosed in [[brackets]]

        Args:
            caption: Caption file path
            normalize_text: Whether to normalize text

        Returns:
            List of Supervision objects
        """
        caption_path = Path(str(caption))
        if not caption_path.exists():
            raise FileNotFoundError(f"Caption file not found: {caption}")

        supervisions = []

        with open(caption_path, "r", encoding="utf-8") as f:
            lines = f.readlines()

        for line in lines:
            line = line.strip()
            if not line:
                continue

            parts = line.split("\t")
            if len(parts) < 3:
                continue

            try:
                # AUD format: start\tend\ttext (speaker in [[brackets]])
                start = float(parts[0])
                end = float(parts[1])
                text = "\t".join(parts[2:]).strip()

                # Extract speaker from [[speaker]] prefix
                speaker = None
                speaker_match = re.match(r"^\[\[([^\]]+)\]\]\s*(.*)$", text)
                if speaker_match:
                    speaker = speaker_match.group(1)
                    text = speaker_match.group(2)

                if normalize_text:
                    text = normalize_text_fn(text)

                duration = end - start
                if duration < 0:
                    continue

                supervisions.append(
                    Supervision(
                        text=text,
                        start=start,
                        duration=duration,
                        speaker=speaker,
                    )
                )
            except (ValueError, IndexError):
                # Skip malformed lines
                continue

        return supervisions

    @classmethod
    def _write_tsv(
        cls,
        alignments: List[Supervision],
        output_path: Pathlike,
        include_speaker_in_text: bool = True,
    ) -> None:
        """
        Write caption to TSV format.

        Format: speaker\tstart\tend\ttext (with speaker)
            or: start\tend\ttext (without speaker)

        Args:
            alignments: List of supervision segments to write
            output_path: Path to output TSV file
            include_speaker_in_text: Whether to include speaker column
        """
        with open(output_path, "w", encoding="utf-8") as file:
            # Write header
            if include_speaker_in_text:
                file.write("speaker\tstart\tend\ttext\n")
                for supervision in alignments:
                    speaker = supervision.speaker or ""
                    start_ms = round(1000 * supervision.start)
                    end_ms = round(1000 * supervision.end)
                    text = supervision.text.strip().replace("\t", " ")
                    file.write(f"{speaker}\t{start_ms}\t{end_ms}\t{text}\n")
            else:
                file.write("start\tend\ttext\n")
                for supervision in alignments:
                    start_ms = round(1000 * supervision.start)
                    end_ms = round(1000 * supervision.end)
                    text = supervision.text.strip().replace("\t", " ")
                    file.write(f"{start_ms}\t{end_ms}\t{text}\n")

    @classmethod
    def _write_csv(
        cls,
        alignments: List[Supervision],
        output_path: Pathlike,
        include_speaker_in_text: bool = True,
    ) -> None:
        """
        Write caption to CSV format.

        Format: speaker,start,end,text (with speaker)
            or: start,end,text (without speaker)

        Args:
            alignments: List of supervision segments to write
            output_path: Path to output CSV file
            include_speaker_in_text: Whether to include speaker column
        """
        import csv

        with open(output_path, "w", encoding="utf-8", newline="") as file:
            if include_speaker_in_text:
                writer = csv.writer(file)
                writer.writerow(["speaker", "start", "end", "text"])
                for supervision in alignments:
                    speaker = supervision.speaker or ""
                    start_ms = round(1000 * supervision.start)
                    end_ms = round(1000 * supervision.end)
                    text = supervision.text.strip()
                    writer.writerow([speaker, start_ms, end_ms, text])
            else:
                writer = csv.writer(file)
                writer.writerow(["start", "end", "text"])
                for supervision in alignments:
                    start_ms = round(1000 * supervision.start)
                    end_ms = round(1000 * supervision.end)
                    text = supervision.text.strip()
                    writer.writerow([start_ms, end_ms, text])

    @classmethod
    def _write_aud(
        cls,
        alignments: List[Supervision],
        output_path: Pathlike,
        include_speaker_in_text: bool = True,
    ) -> None:
        """
        Write caption to AUD format.

        Format: start\tend\t[[speaker]]text
            or: start\tend\ttext (without speaker)

        Args:
            alignments: List of supervision segments to write
            output_path: Path to output AUD file
            include_speaker_in_text: Whether to include speaker in [[brackets]]
        """
        with open(output_path, "w", encoding="utf-8") as file:
            for supervision in alignments:
                start = supervision.start
                end = supervision.end
                text = supervision.text.strip().replace("\t", " ")

                if include_speaker_in_text and supervision.speaker:
                    text = f"[[{supervision.speaker}]]{text}"

                file.write(f"{start}\t{end}\t{text}\n")

    @classmethod
    def _parse_caption(
        cls, caption: Pathlike, format: Optional[OutputCaptionFormat], normalize_text: Optional[bool] = False
    ) -> List[Supervision]:
        """
        Parse caption using pysubs2.

        Args:
            caption: Caption file path or content
            format: Caption format
            normalize_text: Whether to normalize text

        Returns:
            List of Supervision objects
        """
        import pysubs2

        try:
            subs: pysubs2.SSAFile = pysubs2.load(
                caption, encoding="utf-8", format_=format if format != "auto" else None
            )  # file
        except IOError:
            try:
                subs: pysubs2.SSAFile = pysubs2.SSAFile.from_string(
                    caption, format_=format if format != "auto" else None
                )  # str
            except Exception as e:
                del e
                subs: pysubs2.SSAFile = pysubs2.load(caption, encoding="utf-8")  # auto detect format

        # Parse supervisions
        supervisions = []
        for event in subs.events:
            if normalize_text:
                event.text = normalize_text_fn(event.text)
            speaker, text = parse_speaker_text(event.text)
            supervisions.append(
                Supervision(
                    text=text,
                    speaker=speaker or event.name,
                    start=event.start / 1000.0 if event.start is not None else None,
                    duration=(event.end - event.start) / 1000.0 if event.end is not None else None,
                )
            )
        return supervisions

    def __repr__(self) -> str:
        """String representation of Caption."""
        lang = f"lang={self.language}" if self.language else "lang=unknown"
        kind_str = f"kind={self.kind}" if self.kind else ""
        parts = [f"Caption({len(self.supervisions or self.transcription)} segments", lang]
        if kind_str:
            parts.append(kind_str)
        if self.duration:
            parts.append(f"duration={self.duration:.2f}s")
        return ", ".join(parts) + ")"