lattifai 0.4.5__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/__init__.py +26 -27
- lattifai/base_client.py +7 -7
- lattifai/bin/agent.py +90 -91
- lattifai/bin/align.py +110 -111
- lattifai/bin/cli_base.py +3 -3
- lattifai/bin/subtitle.py +45 -45
- lattifai/client.py +56 -56
- lattifai/errors.py +73 -73
- lattifai/io/__init__.py +12 -11
- lattifai/io/gemini_reader.py +30 -30
- lattifai/io/gemini_writer.py +17 -17
- lattifai/io/reader.py +13 -12
- lattifai/io/supervision.py +3 -3
- lattifai/io/text_parser.py +43 -16
- lattifai/io/utils.py +4 -4
- lattifai/io/writer.py +31 -19
- lattifai/tokenizer/__init__.py +1 -1
- lattifai/tokenizer/phonemizer.py +3 -3
- lattifai/tokenizer/tokenizer.py +83 -82
- lattifai/utils.py +15 -15
- lattifai/workers/__init__.py +1 -1
- lattifai/workers/lattice1_alpha.py +46 -46
- lattifai/workflows/__init__.py +11 -11
- lattifai/workflows/agents.py +2 -0
- lattifai/workflows/base.py +22 -22
- lattifai/workflows/file_manager.py +182 -182
- lattifai/workflows/gemini.py +29 -29
- lattifai/workflows/prompts/__init__.py +4 -4
- lattifai/workflows/youtube.py +233 -233
- {lattifai-0.4.5.dist-info → lattifai-0.4.6.dist-info}/METADATA +7 -9
- lattifai-0.4.6.dist-info/RECORD +39 -0
- {lattifai-0.4.5.dist-info → lattifai-0.4.6.dist-info}/licenses/LICENSE +1 -1
- lattifai-0.4.5.dist-info/RECORD +0 -39
- {lattifai-0.4.5.dist-info → lattifai-0.4.6.dist-info}/WHEEL +0 -0
- {lattifai-0.4.5.dist-info → lattifai-0.4.6.dist-info}/entry_points.txt +0 -0
- {lattifai-0.4.5.dist-info → lattifai-0.4.6.dist-info}/top_level.txt +0 -0
lattifai/io/gemini_reader.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
"""Reader for YouTube transcript files with speaker labels and timestamps."""
|
|
2
2
|
|
|
3
3
|
import re
|
|
4
|
-
from dataclasses import dataclass
|
|
4
|
+
from dataclasses import dataclass
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import List, Optional
|
|
6
|
+
from typing import List, Optional
|
|
7
7
|
|
|
8
8
|
from lhotse.utils import Pathlike
|
|
9
9
|
|
|
@@ -18,7 +18,7 @@ class GeminiSegment:
|
|
|
18
18
|
timestamp: Optional[float] = None
|
|
19
19
|
speaker: Optional[str] = None
|
|
20
20
|
section: Optional[str] = None
|
|
21
|
-
segment_type: str =
|
|
21
|
+
segment_type: str = "dialogue" # 'dialogue', 'event', or 'section_header'
|
|
22
22
|
line_number: int = 0
|
|
23
23
|
|
|
24
24
|
@property
|
|
@@ -31,15 +31,15 @@ class GeminiReader:
|
|
|
31
31
|
"""Parser for YouTube transcript format with speaker labels and timestamps."""
|
|
32
32
|
|
|
33
33
|
# Regex patterns for parsing (supports both [HH:MM:SS] and [MM:SS] formats)
|
|
34
|
-
TIMESTAMP_PATTERN = re.compile(r
|
|
35
|
-
SECTION_HEADER_PATTERN = re.compile(r
|
|
36
|
-
SPEAKER_PATTERN = re.compile(r
|
|
37
|
-
EVENT_PATTERN = re.compile(r
|
|
38
|
-
INLINE_TIMESTAMP_PATTERN = re.compile(r
|
|
34
|
+
TIMESTAMP_PATTERN = re.compile(r"\[(\d{1,2}):(\d{2}):(\d{2})\]|\[(\d{1,2}):(\d{2})\]")
|
|
35
|
+
SECTION_HEADER_PATTERN = re.compile(r"^##\s*\[(\d{1,2}):(\d{2}):(\d{2})\]\s*(.+)$")
|
|
36
|
+
SPEAKER_PATTERN = re.compile(r"^\*\*(.+?[:：])\*\*\s*(.+)$")
|
|
37
|
+
EVENT_PATTERN = re.compile(r"^\[([^\]]+)\]\s*\[(?:(\d{1,2}):(\d{2}):(\d{2})|(\d{1,2}):(\d{2}))\]$")
|
|
38
|
+
INLINE_TIMESTAMP_PATTERN = re.compile(r"^(.+?)\s*\[(?:(\d{1,2}):(\d{2}):(\d{2})|(\d{1,2}):(\d{2}))\]$")
|
|
39
39
|
|
|
40
40
|
# New patterns for YouTube link format: [[MM:SS](URL&t=seconds)]
|
|
41
|
-
YOUTUBE_SECTION_PATTERN = re.compile(r
|
|
42
|
-
YOUTUBE_INLINE_PATTERN = re.compile(r
|
|
41
|
+
YOUTUBE_SECTION_PATTERN = re.compile(r"^##\s*\[\[(\d{1,2}):(\d{2})\]\([^)]*&t=(\d+)\)\]\s*(.+)$")
|
|
42
|
+
YOUTUBE_INLINE_PATTERN = re.compile(r"^(.+?)\s*\[\[(\d{1,2}):(\d{2})\]\([^)]*&t=(\d+)\)\]$")
|
|
43
43
|
|
|
44
44
|
@classmethod
|
|
45
45
|
def parse_timestamp(cls, *args) -> float:
|
|
@@ -61,7 +61,7 @@ class GeminiReader:
|
|
|
61
61
|
# Direct seconds (from YouTube &t= parameter)
|
|
62
62
|
return int(args[0])
|
|
63
63
|
else:
|
|
64
|
-
raise ValueError(f
|
|
64
|
+
raise ValueError(f"Invalid timestamp args: {args}")
|
|
65
65
|
|
|
66
66
|
@classmethod
|
|
67
67
|
def read(
|
|
@@ -82,13 +82,13 @@ class GeminiReader:
|
|
|
82
82
|
"""
|
|
83
83
|
transcript_path = Path(transcript_path).expanduser().resolve()
|
|
84
84
|
if not transcript_path.exists():
|
|
85
|
-
raise FileNotFoundError(f
|
|
85
|
+
raise FileNotFoundError(f"Transcript file not found: {transcript_path}")
|
|
86
86
|
|
|
87
87
|
segments: List[GeminiSegment] = []
|
|
88
88
|
current_section = None
|
|
89
89
|
current_speaker = None
|
|
90
90
|
|
|
91
|
-
with open(transcript_path,
|
|
91
|
+
with open(transcript_path, "r", encoding="utf-8") as f:
|
|
92
92
|
lines = f.readlines()
|
|
93
93
|
|
|
94
94
|
for line_num, line in enumerate(lines, start=1):
|
|
@@ -97,9 +97,9 @@ class GeminiReader:
|
|
|
97
97
|
continue
|
|
98
98
|
|
|
99
99
|
# Skip table of contents
|
|
100
|
-
if line.startswith(
|
|
100
|
+
if line.startswith("* ["):
|
|
101
101
|
continue
|
|
102
|
-
if line.startswith(
|
|
102
|
+
if line.startswith("## Table of Contents"):
|
|
103
103
|
continue
|
|
104
104
|
|
|
105
105
|
# Parse section headers
|
|
@@ -114,7 +114,7 @@ class GeminiReader:
|
|
|
114
114
|
text=section_title.strip(),
|
|
115
115
|
timestamp=timestamp,
|
|
116
116
|
section=current_section,
|
|
117
|
-
segment_type=
|
|
117
|
+
segment_type="section_header",
|
|
118
118
|
line_number=line_num,
|
|
119
119
|
)
|
|
120
120
|
)
|
|
@@ -133,7 +133,7 @@ class GeminiReader:
|
|
|
133
133
|
text=section_title.strip(),
|
|
134
134
|
timestamp=timestamp,
|
|
135
135
|
section=current_section,
|
|
136
|
-
segment_type=
|
|
136
|
+
segment_type="section_header",
|
|
137
137
|
line_number=line_num,
|
|
138
138
|
)
|
|
139
139
|
)
|
|
@@ -158,7 +158,7 @@ class GeminiReader:
|
|
|
158
158
|
text=event_text.strip(),
|
|
159
159
|
timestamp=timestamp,
|
|
160
160
|
section=current_section,
|
|
161
|
-
segment_type=
|
|
161
|
+
segment_type="event",
|
|
162
162
|
line_number=line_num,
|
|
163
163
|
)
|
|
164
164
|
)
|
|
@@ -200,7 +200,7 @@ class GeminiReader:
|
|
|
200
200
|
timestamp=timestamp,
|
|
201
201
|
speaker=current_speaker,
|
|
202
202
|
section=current_section,
|
|
203
|
-
segment_type=
|
|
203
|
+
segment_type="dialogue",
|
|
204
204
|
line_number=line_num,
|
|
205
205
|
)
|
|
206
206
|
)
|
|
@@ -228,7 +228,7 @@ class GeminiReader:
|
|
|
228
228
|
timestamp=timestamp,
|
|
229
229
|
speaker=current_speaker,
|
|
230
230
|
section=current_section,
|
|
231
|
-
segment_type=
|
|
231
|
+
segment_type="dialogue",
|
|
232
232
|
line_number=line_num,
|
|
233
233
|
)
|
|
234
234
|
)
|
|
@@ -246,14 +246,14 @@ class GeminiReader:
|
|
|
246
246
|
timestamp=timestamp,
|
|
247
247
|
speaker=current_speaker,
|
|
248
248
|
section=current_section,
|
|
249
|
-
segment_type=
|
|
249
|
+
segment_type="dialogue",
|
|
250
250
|
line_number=line_num,
|
|
251
251
|
)
|
|
252
252
|
)
|
|
253
253
|
continue
|
|
254
254
|
|
|
255
255
|
# Skip markdown headers and other formatting
|
|
256
|
-
if line.startswith(
|
|
256
|
+
if line.startswith("#"):
|
|
257
257
|
continue
|
|
258
258
|
|
|
259
259
|
return segments
|
|
@@ -283,10 +283,10 @@ class GeminiReader:
|
|
|
283
283
|
segments = cls.read(transcript_path, include_events=False, include_sections=False)
|
|
284
284
|
|
|
285
285
|
# Filter to only dialogue segments with timestamps
|
|
286
|
-
dialogue_segments = [s for s in segments if s.segment_type ==
|
|
286
|
+
dialogue_segments = [s for s in segments if s.segment_type == "dialogue" and s.timestamp is not None]
|
|
287
287
|
|
|
288
288
|
if not dialogue_segments:
|
|
289
|
-
raise ValueError(f
|
|
289
|
+
raise ValueError(f"No dialogue segments with timestamps found in {transcript_path}")
|
|
290
290
|
|
|
291
291
|
# Sort by timestamp
|
|
292
292
|
dialogue_segments.sort(key=lambda x: x.timestamp)
|
|
@@ -308,7 +308,7 @@ class GeminiReader:
|
|
|
308
308
|
text=segment.text,
|
|
309
309
|
start=segment.timestamp,
|
|
310
310
|
duration=max(duration, min_duration),
|
|
311
|
-
id=f
|
|
311
|
+
id=f"segment_{i:05d}",
|
|
312
312
|
speaker=segment.speaker,
|
|
313
313
|
)
|
|
314
314
|
)
|
|
@@ -337,13 +337,13 @@ class GeminiReader:
|
|
|
337
337
|
else:
|
|
338
338
|
# Different speaker or gap too large, save previous segment
|
|
339
339
|
if current_texts:
|
|
340
|
-
merged_text =
|
|
340
|
+
merged_text = " ".join(current_texts)
|
|
341
341
|
merged.append(
|
|
342
342
|
Supervision(
|
|
343
343
|
text=merged_text,
|
|
344
344
|
start=current_start,
|
|
345
345
|
duration=last_end_time - current_start,
|
|
346
|
-
id=f
|
|
346
|
+
id=f"merged_{len(merged):05d}",
|
|
347
347
|
)
|
|
348
348
|
)
|
|
349
349
|
current_speaker = segment.speaker
|
|
@@ -353,13 +353,13 @@ class GeminiReader:
|
|
|
353
353
|
|
|
354
354
|
# Add final segment
|
|
355
355
|
if current_texts:
|
|
356
|
-
merged_text =
|
|
356
|
+
merged_text = " ".join(current_texts)
|
|
357
357
|
merged.append(
|
|
358
358
|
Supervision(
|
|
359
359
|
text=merged_text,
|
|
360
360
|
start=current_start,
|
|
361
361
|
duration=last_end_time - current_start,
|
|
362
|
-
id=f
|
|
362
|
+
id=f"merged_{len(merged):05d}",
|
|
363
363
|
)
|
|
364
364
|
)
|
|
365
365
|
|
|
@@ -368,4 +368,4 @@ class GeminiReader:
|
|
|
368
368
|
return supervisions
|
|
369
369
|
|
|
370
370
|
|
|
371
|
-
__all__ = [
|
|
371
|
+
__all__ = ["GeminiReader", "GeminiSegment"]
|
lattifai/io/gemini_writer.py
CHANGED
|
@@ -19,7 +19,7 @@ class GeminiWriter:
|
|
|
19
19
|
hours = int(seconds // 3600)
|
|
20
20
|
minutes = int((seconds % 3600) // 60)
|
|
21
21
|
secs = int(seconds % 60)
|
|
22
|
-
return f
|
|
22
|
+
return f"[{hours:02d}:{minutes:02d}:{secs:02d}]"
|
|
23
23
|
|
|
24
24
|
@classmethod
|
|
25
25
|
def update_timestamps(
|
|
@@ -44,7 +44,7 @@ class GeminiWriter:
|
|
|
44
44
|
output_path = Path(output_path)
|
|
45
45
|
|
|
46
46
|
# Read original file
|
|
47
|
-
with open(original_path,
|
|
47
|
+
with open(original_path, "r", encoding="utf-8") as f:
|
|
48
48
|
lines = f.readlines()
|
|
49
49
|
|
|
50
50
|
# Parse original segments to get line numbers
|
|
@@ -66,7 +66,7 @@ class GeminiWriter:
|
|
|
66
66
|
|
|
67
67
|
# Write updated content
|
|
68
68
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
69
|
-
with open(output_path,
|
|
69
|
+
with open(output_path, "w", encoding="utf-8") as f:
|
|
70
70
|
f.writelines(updated_lines)
|
|
71
71
|
|
|
72
72
|
return output_path
|
|
@@ -83,7 +83,7 @@ class GeminiWriter:
|
|
|
83
83
|
mapping = {}
|
|
84
84
|
|
|
85
85
|
# Create a simple text-based matching
|
|
86
|
-
dialogue_segments = [s for s in original_segments if s.segment_type ==
|
|
86
|
+
dialogue_segments = [s for s in original_segments if s.segment_type == "dialogue"]
|
|
87
87
|
|
|
88
88
|
# Try to match based on text content
|
|
89
89
|
for aligned_sup in aligned_supervisions:
|
|
@@ -120,7 +120,7 @@ class GeminiWriter:
|
|
|
120
120
|
|
|
121
121
|
# Replace timestamp patterns
|
|
122
122
|
# Pattern 1: [HH:MM:SS] at the end or in brackets
|
|
123
|
-
line = re.sub(r
|
|
123
|
+
line = re.sub(r"\[\d{2}:\d{2}:\d{2}\]", new_ts_str, line)
|
|
124
124
|
|
|
125
125
|
return line
|
|
126
126
|
|
|
@@ -146,28 +146,28 @@ class GeminiWriter:
|
|
|
146
146
|
output_path = Path(output_path)
|
|
147
147
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
148
148
|
|
|
149
|
-
with open(output_path,
|
|
150
|
-
f.write(
|
|
149
|
+
with open(output_path, "w", encoding="utf-8") as f:
|
|
150
|
+
f.write("# Aligned Transcript\n\n")
|
|
151
151
|
|
|
152
152
|
for i, sup in enumerate(aligned_supervisions):
|
|
153
153
|
# Write segment with timestamp
|
|
154
154
|
start_ts = cls.format_timestamp(sup.start)
|
|
155
|
-
f.write(f
|
|
155
|
+
f.write(f"{start_ts} {sup.text}\n")
|
|
156
156
|
|
|
157
157
|
# Optionally write word-level timestamps
|
|
158
|
-
if include_word_timestamps and hasattr(sup,
|
|
159
|
-
if
|
|
160
|
-
f.write(
|
|
158
|
+
if include_word_timestamps and hasattr(sup, "alignment") and sup.alignment:
|
|
159
|
+
if "word" in sup.alignment:
|
|
160
|
+
f.write(" Words: ")
|
|
161
161
|
word_parts = []
|
|
162
|
-
for word_info in sup.alignment[
|
|
163
|
-
word_ts = cls.format_timestamp(word_info[
|
|
162
|
+
for word_info in sup.alignment["word"]:
|
|
163
|
+
word_ts = cls.format_timestamp(word_info["start"])
|
|
164
164
|
word_parts.append(f'{word_info["symbol"]}{word_ts}')
|
|
165
|
-
f.write(
|
|
166
|
-
f.write(
|
|
165
|
+
f.write(" ".join(word_parts))
|
|
166
|
+
f.write("\n")
|
|
167
167
|
|
|
168
|
-
f.write(
|
|
168
|
+
f.write("\n")
|
|
169
169
|
|
|
170
170
|
return output_path
|
|
171
171
|
|
|
172
172
|
|
|
173
|
-
__all__ = [
|
|
173
|
+
__all__ = ["GeminiWriter"]
|
lattifai/io/reader.py
CHANGED
|
@@ -7,7 +7,7 @@ from lhotse.utils import Pathlike
|
|
|
7
7
|
from .supervision import Supervision
|
|
8
8
|
from .text_parser import parse_speaker_text
|
|
9
9
|
|
|
10
|
-
SubtitleFormat = Literal[
|
|
10
|
+
SubtitleFormat = Literal["txt", "srt", "vtt", "ass", "auto"]
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class SubtitleReader(ABCMeta):
|
|
@@ -27,28 +27,27 @@ class SubtitleReader(ABCMeta):
|
|
|
27
27
|
Parsed text in Lhotse Cut
|
|
28
28
|
"""
|
|
29
29
|
if not format and Path(str(subtitle)).exists():
|
|
30
|
-
format = Path(str(subtitle)).suffix.lstrip(
|
|
30
|
+
format = Path(str(subtitle)).suffix.lstrip(".").lower()
|
|
31
31
|
elif format:
|
|
32
32
|
format = format.lower()
|
|
33
33
|
|
|
34
|
-
if format ==
|
|
34
|
+
if format == "gemini" or str(subtitle).endswith("Gemini.md"):
|
|
35
35
|
from .gemini_reader import GeminiReader
|
|
36
36
|
|
|
37
37
|
supervisions = GeminiReader.extract_for_alignment(subtitle)
|
|
38
|
-
elif format ==
|
|
38
|
+
elif format == "txt" or (format == "auto" and str(subtitle)[-4:].lower() == ".txt"):
|
|
39
39
|
if not Path(str(subtitle)).exists(): # str
|
|
40
|
-
lines = [line.strip() for line in str(subtitle).split(
|
|
40
|
+
lines = [line.strip() for line in str(subtitle).split("\n")]
|
|
41
41
|
else: # file
|
|
42
42
|
path_str = str(subtitle)
|
|
43
|
-
with open(path_str, encoding=
|
|
43
|
+
with open(path_str, encoding="utf-8") as f:
|
|
44
44
|
lines = [line.strip() for line in f.readlines()]
|
|
45
45
|
supervisions = [Supervision(text=line) for line in lines if line]
|
|
46
46
|
else:
|
|
47
47
|
try:
|
|
48
48
|
supervisions = cls._parse_subtitle(subtitle, format=format)
|
|
49
49
|
except Exception as e:
|
|
50
|
-
|
|
51
|
-
print(f"Failed to parse subtitle with format {format}, trying 'gemini' parser.")
|
|
50
|
+
print(f"Failed to parse subtitle with Format: {format}, Exception: {e}, trying 'gemini' parser.")
|
|
52
51
|
from .gemini_reader import GeminiReader
|
|
53
52
|
|
|
54
53
|
supervisions = GeminiReader.extract_for_alignment(subtitle)
|
|
@@ -61,18 +60,20 @@ class SubtitleReader(ABCMeta):
|
|
|
61
60
|
|
|
62
61
|
try:
|
|
63
62
|
subs: pysubs2.SSAFile = pysubs2.load(
|
|
64
|
-
subtitle, encoding=
|
|
63
|
+
subtitle, encoding="utf-8", format_=format if format != "auto" else None
|
|
65
64
|
) # file
|
|
66
65
|
except IOError:
|
|
67
66
|
try:
|
|
68
67
|
subs: pysubs2.SSAFile = pysubs2.SSAFile.from_string(
|
|
69
|
-
subtitle, format_=format if format !=
|
|
68
|
+
subtitle, format_=format if format != "auto" else None
|
|
70
69
|
) # str
|
|
71
|
-
except:
|
|
72
|
-
|
|
70
|
+
except Exception as e:
|
|
71
|
+
del e
|
|
72
|
+
subs: pysubs2.SSAFile = pysubs2.load(subtitle, encoding="utf-8") # auto detect format
|
|
73
73
|
|
|
74
74
|
supervisions = []
|
|
75
75
|
for event in subs.events:
|
|
76
|
+
# NOT apply text_parser.py:normalize_html_text here, to keep original text in subtitles
|
|
76
77
|
speaker, text = parse_speaker_text(event.text)
|
|
77
78
|
supervisions.append(
|
|
78
79
|
Supervision(
|
lattifai/io/supervision.py
CHANGED
|
@@ -24,10 +24,10 @@ class Supervision(SupervisionSegment):
|
|
|
24
24
|
"""
|
|
25
25
|
|
|
26
26
|
text: Optional[str] = None
|
|
27
|
-
id: str =
|
|
28
|
-
recording_id: str =
|
|
27
|
+
id: str = ""
|
|
28
|
+
recording_id: str = ""
|
|
29
29
|
start: Seconds = 0.0
|
|
30
30
|
duration: Seconds = 0.0
|
|
31
31
|
|
|
32
32
|
|
|
33
|
-
__all__ = [
|
|
33
|
+
__all__ = ["Supervision"]
|
lattifai/io/text_parser.py
CHANGED
|
@@ -3,23 +3,50 @@ import re
|
|
|
3
3
|
from typing import Optional, Tuple
|
|
4
4
|
|
|
5
5
|
# 来自于字幕中常见的说话人标记格式
|
|
6
|
-
SPEAKER_PATTERN = re.compile(r
|
|
6
|
+
SPEAKER_PATTERN = re.compile(r"((?:>>|&gt;&gt;|>|&gt;).*?[:：])\s*(.*)")
|
|
7
7
|
|
|
8
8
|
# Transcriber Output Example:
|
|
9
9
|
# 26:19.919 --> 26:34.921
|
|
10
10
|
# [SPEAKER_01]: 越来越多的科技巨头入...
|
|
11
|
-
SPEAKER_LATTIFAI = re.compile(r
|
|
11
|
+
SPEAKER_LATTIFAI = re.compile(r"(^\[SPEAKER_.*?\][:：])\s*(.*)")
|
|
12
12
|
|
|
13
13
|
# NISHTHA BHATIA: Hey, everyone.
|
|
14
14
|
# DIETER: Oh, hey, Nishtha.
|
|
15
15
|
# GEMINI: That might
|
|
16
|
-
SPEAKER_PATTERN2 = re.compile(r
|
|
16
|
+
SPEAKER_PATTERN2 = re.compile(r"^([A-Z]{1,15}(?:\s+[A-Z]{1,15})?[:：])\s*(.*)$")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def normalize_html_text(text: str) -> str:
|
|
20
|
+
"""Normalize HTML text by decoding entities and stripping whitespace."""
|
|
21
|
+
html_entities = {
|
|
22
|
+
"&amp;": "&",
|
|
23
|
+
"&lt;": "<",
|
|
24
|
+
"&gt;": ">",
|
|
25
|
+
"&quot;": '"',
|
|
26
|
+
"&#39;": "'",
|
|
27
|
+
"&nbsp;": " ",
|
|
28
|
+
"\\N": " ",
|
|
29
|
+
"…": " ",
|
|
30
|
+
}
|
|
31
|
+
for entity, char in html_entities.items():
|
|
32
|
+
text = text.replace(entity, char)
|
|
33
|
+
|
|
34
|
+
text = re.sub(r"\s+", " ", text) # Replace multiple spaces with a single space
|
|
35
|
+
|
|
36
|
+
# Convert curly apostrophes to straight apostrophes for common English contractions
|
|
37
|
+
# Handles: 't 's 'll 're 've 'd 'm
|
|
38
|
+
# For example, convert "don’t" to "don't"
|
|
39
|
+
text = re.sub(r"([a-zA-Z])’([tsdm]|ll|re|ve)\b", r"\1'\2", text, flags=re.IGNORECASE)
|
|
40
|
+
# For example, convert "5’s" to "5's"
|
|
41
|
+
text = re.sub(r"([0-9])’([s])\b", r"\1'\2", text, flags=re.IGNORECASE)
|
|
42
|
+
|
|
43
|
+
return text.strip()
|
|
17
44
|
|
|
18
45
|
|
|
19
46
|
def parse_speaker_text(line) -> Tuple[Optional[str], str]:
|
|
20
|
-
|
|
47
|
+
"""Parse a line of text to extract speaker and content."""
|
|
21
48
|
|
|
22
|
-
if
|
|
49
|
+
if ":" not in line and "：" not in line:
|
|
23
50
|
return None, line
|
|
24
51
|
|
|
25
52
|
# 匹配以 >> 开头的行,并去除开头的名字和冒号
|
|
@@ -31,7 +58,7 @@ def parse_speaker_text(line) -> Tuple[Optional[str], str]:
|
|
|
31
58
|
if match:
|
|
32
59
|
assert len(match.groups()) == 2, match.groups()
|
|
33
60
|
if not match.group(1):
|
|
34
|
-
logging.error(f
|
|
61
|
+
logging.error(f"ParseSub LINE [{line}]")
|
|
35
62
|
else:
|
|
36
63
|
return match.group(1).strip(), match.group(2).strip()
|
|
37
64
|
|
|
@@ -43,15 +70,15 @@ def parse_speaker_text(line) -> Tuple[Optional[str], str]:
|
|
|
43
70
|
return None, line
|
|
44
71
|
|
|
45
72
|
|
|
46
|
-
if __name__ ==
|
|
47
|
-
pattern = re.compile(r
|
|
48
|
-
pattern = re.compile(r
|
|
73
|
+
if __name__ == "__main__":
|
|
74
|
+
pattern = re.compile(r">>\s*(.*?)\s*[:：]\s*(.*)")
|
|
75
|
+
pattern = re.compile(r"(>>.*?[:：])\s*(.*)")
|
|
49
76
|
|
|
50
77
|
test_strings = [
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
78
|
+
">>Key: Value",
|
|
79
|
+
">> Key with space : Value with space ",
|
|
80
|
+
">> 全角键 ： 全角值",
|
|
81
|
+
">>Key:Value xxx. >>Key:Value",
|
|
55
82
|
]
|
|
56
83
|
|
|
57
84
|
for text in test_strings:
|
|
@@ -60,16 +87,16 @@ if __name__ == '__main__':
|
|
|
60
87
|
print(f"Input: '{text}'")
|
|
61
88
|
print(f" Key: '{match.group(1)}'")
|
|
62
89
|
print(f" Value: '{match.group(2)}'")
|
|
63
|
-
print(
|
|
90
|
+
print("-------------")
|
|
64
91
|
|
|
65
92
|
# pattern2
|
|
66
|
-
test_strings2 = [
|
|
93
|
+
test_strings2 = ["NISHTHA BHATIA: Hey, everyone.", "DIETER: Oh, hey, Nishtha.", "GEMINI: That might"]
|
|
67
94
|
for text in test_strings2:
|
|
68
95
|
match = SPEAKER_PATTERN2.match(text)
|
|
69
96
|
if match:
|
|
70
97
|
print(f" Input: '{text}'")
|
|
71
98
|
print(f"Speaker: '{match.group(1)}'")
|
|
72
99
|
print(f"Content: '{match.group(2)}'")
|
|
73
|
-
print(
|
|
100
|
+
print("-------------")
|
|
74
101
|
else:
|
|
75
102
|
raise ValueError(f"No match for: '{text}'")
|
lattifai/io/utils.py
CHANGED
|
@@ -3,13 +3,13 @@ Utility constants and helper functions for subtitle I/O operations
|
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
# Supported subtitle formats for reading/writing
|
|
6
|
-
SUBTITLE_FORMATS = [
|
|
6
|
+
SUBTITLE_FORMATS = ["srt", "vtt", "ass", "ssa", "sub", "sbv", "txt", "md"]
|
|
7
7
|
|
|
8
8
|
# Input subtitle formats (includes special formats like 'auto' and 'gemini')
|
|
9
|
-
INPUT_SUBTITLE_FORMATS = [
|
|
9
|
+
INPUT_SUBTITLE_FORMATS = ["srt", "vtt", "ass", "ssa", "sub", "sbv", "txt", "auto", "gemini"]
|
|
10
10
|
|
|
11
11
|
# Output subtitle formats (includes special formats like 'TextGrid' and 'json')
|
|
12
|
-
OUTPUT_SUBTITLE_FORMATS = [
|
|
12
|
+
OUTPUT_SUBTITLE_FORMATS = ["srt", "vtt", "ass", "ssa", "sub", "sbv", "txt", "TextGrid", "json"]
|
|
13
13
|
|
|
14
14
|
# All subtitle formats combined (for file detection)
|
|
15
|
-
ALL_SUBTITLE_FORMATS = list(set(SUBTITLE_FORMATS + [
|
|
15
|
+
ALL_SUBTITLE_FORMATS = list(set(SUBTITLE_FORMATS + ["TextGrid", "json", "gemini"]))
|
lattifai/io/writer.py
CHANGED
|
@@ -14,45 +14,57 @@ class SubtitleWriter(ABCMeta):
|
|
|
14
14
|
|
|
15
15
|
@classmethod
|
|
16
16
|
def write(cls, alignments: List[Supervision], output_path: Pathlike) -> Pathlike:
|
|
17
|
-
if str(output_path)[-4:].lower() ==
|
|
18
|
-
with open(output_path,
|
|
17
|
+
if str(output_path)[-4:].lower() == ".txt":
|
|
18
|
+
with open(output_path, "w", encoding="utf-8") as f:
|
|
19
19
|
for sup in alignments:
|
|
20
20
|
word_items = parse_alignment_from_supervision(sup)
|
|
21
21
|
if word_items:
|
|
22
22
|
for item in word_items:
|
|
23
|
-
f.write(f
|
|
23
|
+
f.write(f"[{item.start:.2f}-{item.end:.2f}] {item.symbol}\n")
|
|
24
24
|
else:
|
|
25
|
-
text = f
|
|
26
|
-
f.write(f
|
|
25
|
+
text = f"{sup.speaker} {sup.text}" if sup.speaker is not None else sup.text
|
|
26
|
+
f.write(f"[{sup.start:.2f}-{sup.end:.2f}] {text}\n")
|
|
27
27
|
|
|
28
|
-
elif str(output_path)[-5:].lower() ==
|
|
29
|
-
with open(output_path,
|
|
28
|
+
elif str(output_path)[-5:].lower() == ".json":
|
|
29
|
+
with open(output_path, "w", encoding="utf-8") as f:
|
|
30
30
|
# Enhanced JSON export with word-level alignment
|
|
31
31
|
json_data = []
|
|
32
32
|
for sup in alignments:
|
|
33
33
|
sup_dict = sup.to_dict()
|
|
34
34
|
json_data.append(sup_dict)
|
|
35
35
|
json.dump(json_data, f, ensure_ascii=False, indent=4)
|
|
36
|
-
elif str(output_path).
|
|
36
|
+
elif str(output_path).lower().endswith(".textgrid"):
|
|
37
37
|
from tgt import Interval, IntervalTier, TextGrid, write_to_file
|
|
38
38
|
|
|
39
39
|
tg = TextGrid()
|
|
40
|
-
supervisions, words = [], []
|
|
40
|
+
supervisions, words, scores = [], [], {"utterances": [], "words": []}
|
|
41
41
|
for supervision in sorted(alignments, key=lambda x: x.start):
|
|
42
42
|
text = (
|
|
43
|
-
f
|
|
43
|
+
f"{supervision.speaker} {supervision.text}" if supervision.speaker is not None else supervision.text
|
|
44
44
|
)
|
|
45
|
-
supervisions.append(Interval(supervision.start, supervision.end, text or
|
|
45
|
+
supervisions.append(Interval(supervision.start, supervision.end, text or ""))
|
|
46
46
|
# Extract word-level alignment using helper function
|
|
47
47
|
word_items = parse_alignment_from_supervision(supervision)
|
|
48
48
|
if word_items:
|
|
49
49
|
for item in word_items:
|
|
50
50
|
words.append(Interval(item.start, item.end, item.symbol))
|
|
51
|
+
if item.score is not None:
|
|
52
|
+
scores["words"].append(Interval(item.start, item.end, f"{item.score:.2f}"))
|
|
53
|
+
if supervision.has_custom("score"):
|
|
54
|
+
scores["utterances"].append(
|
|
55
|
+
Interval(supervision.start, supervision.end, f"{supervision.score:.2f}")
|
|
56
|
+
)
|
|
51
57
|
|
|
52
|
-
tg.add_tier(IntervalTier(name=
|
|
58
|
+
tg.add_tier(IntervalTier(name="utterances", objects=supervisions))
|
|
53
59
|
if words:
|
|
54
|
-
tg.add_tier(IntervalTier(name=
|
|
55
|
-
|
|
60
|
+
tg.add_tier(IntervalTier(name="words", objects=words))
|
|
61
|
+
|
|
62
|
+
if scores["utterances"]:
|
|
63
|
+
tg.add_tier(IntervalTier(name="utterance_scores", objects=scores["utterances"]))
|
|
64
|
+
if scores["words"]:
|
|
65
|
+
tg.add_tier(IntervalTier(name="word_scores", objects=scores["words"]))
|
|
66
|
+
|
|
67
|
+
write_to_file(tg, output_path, format="long")
|
|
56
68
|
else:
|
|
57
69
|
subs = pysubs2.SSAFile()
|
|
58
70
|
for sup in alignments:
|
|
@@ -64,8 +76,8 @@ class SubtitleWriter(ABCMeta):
|
|
|
64
76
|
pysubs2.SSAEvent(start=int(word.start * 1000), end=int(word.end * 1000), text=word.symbol)
|
|
65
77
|
)
|
|
66
78
|
else:
|
|
67
|
-
text = f
|
|
68
|
-
subs.append(pysubs2.SSAEvent(start=int(sup.start * 1000), end=int(sup.end * 1000), text=text or
|
|
79
|
+
text = f"{sup.speaker} {sup.text}" if sup.speaker is not None else sup.text
|
|
80
|
+
subs.append(pysubs2.SSAEvent(start=int(sup.start * 1000), end=int(sup.end * 1000), text=text or ""))
|
|
69
81
|
subs.save(output_path)
|
|
70
82
|
|
|
71
83
|
return output_path
|
|
@@ -81,10 +93,10 @@ def parse_alignment_from_supervision(supervision: Any) -> Optional[List[Alignmen
|
|
|
81
93
|
Returns:
|
|
82
94
|
List of AlignmentItem objects, or None if no alignment data present
|
|
83
95
|
"""
|
|
84
|
-
if not hasattr(supervision,
|
|
96
|
+
if not hasattr(supervision, "alignment") or not supervision.alignment:
|
|
85
97
|
return None
|
|
86
98
|
|
|
87
|
-
if
|
|
99
|
+
if "word" not in supervision.alignment:
|
|
88
100
|
return None
|
|
89
101
|
|
|
90
|
-
return supervision.alignment[
|
|
102
|
+
return supervision.alignment["word"]
|
lattifai/tokenizer/__init__.py
CHANGED
lattifai/tokenizer/phonemizer.py
CHANGED
|
@@ -4,13 +4,13 @@ from typing import List, Optional, Union
|
|
|
4
4
|
from dp.phonemizer import Phonemizer # g2p-phonemizer
|
|
5
5
|
from num2words import num2words
|
|
6
6
|
|
|
7
|
-
LANGUAGE =
|
|
7
|
+
LANGUAGE = "omni"
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
class G2Phonemizer:
|
|
11
11
|
def __init__(self, model_checkpoint, device):
|
|
12
12
|
self.phonemizer = Phonemizer.from_checkpoint(model_checkpoint, device=device).predictor
|
|
13
|
-
self.pattern = re.compile(r
|
|
13
|
+
self.pattern = re.compile(r"\d+")
|
|
14
14
|
|
|
15
15
|
def num2words(self, word, lang: str):
|
|
16
16
|
matches = self.pattern.findall(word)
|
|
@@ -31,7 +31,7 @@ class G2Phonemizer:
|
|
|
31
31
|
is_list = False
|
|
32
32
|
|
|
33
33
|
predictions = self.phonemizer(
|
|
34
|
-
[self.num2words(word.replace(
|
|
34
|
+
[self.num2words(word.replace(" .", ".").replace(".", " ."), lang=lang or "en") for word in words],
|
|
35
35
|
lang=LANGUAGE,
|
|
36
36
|
batch_size=min(batch_size or len(words), 128),
|
|
37
37
|
num_prons=num_prons,
|