lattifai 0.4.5__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,9 @@
  """Reader for YouTube transcript files with speaker labels and timestamps."""

  import re
- from dataclasses import dataclass, field
+ from dataclasses import dataclass
  from pathlib import Path
- from typing import List, Optional, Tuple
+ from typing import List, Optional

  from lhotse.utils import Pathlike

@@ -18,7 +18,7 @@ class GeminiSegment:
  timestamp: Optional[float] = None
  speaker: Optional[str] = None
  section: Optional[str] = None
- segment_type: str = 'dialogue' # 'dialogue', 'event', or 'section_header'
+ segment_type: str = "dialogue" # 'dialogue', 'event', or 'section_header'
  line_number: int = 0

  @property
@@ -31,15 +31,15 @@ class GeminiReader:
  """Parser for YouTube transcript format with speaker labels and timestamps."""

  # Regex patterns for parsing (supports both [HH:MM:SS] and [MM:SS] formats)
- TIMESTAMP_PATTERN = re.compile(r'\[(\d{1,2}):(\d{2}):(\d{2})\]|\[(\d{1,2}):(\d{2})\]')
- SECTION_HEADER_PATTERN = re.compile(r'^##\s*\[(\d{1,2}):(\d{2}):(\d{2})\]\s*(.+)$')
- SPEAKER_PATTERN = re.compile(r'^\*\*(.+?[::])\*\*\s*(.+)$')
- EVENT_PATTERN = re.compile(r'^\[([^\]]+)\]\s*\[(?:(\d{1,2}):(\d{2}):(\d{2})|(\d{1,2}):(\d{2}))\]$')
- INLINE_TIMESTAMP_PATTERN = re.compile(r'^(.+?)\s*\[(?:(\d{1,2}):(\d{2}):(\d{2})|(\d{1,2}):(\d{2}))\]$')
+ TIMESTAMP_PATTERN = re.compile(r"\[(\d{1,2}):(\d{2}):(\d{2})\]|\[(\d{1,2}):(\d{2})\]")
+ SECTION_HEADER_PATTERN = re.compile(r"^##\s*\[(\d{1,2}):(\d{2}):(\d{2})\]\s*(.+)$")
+ SPEAKER_PATTERN = re.compile(r"^\*\*(.+?[::])\*\*\s*(.+)$")
+ EVENT_PATTERN = re.compile(r"^\[([^\]]+)\]\s*\[(?:(\d{1,2}):(\d{2}):(\d{2})|(\d{1,2}):(\d{2}))\]$")
+ INLINE_TIMESTAMP_PATTERN = re.compile(r"^(.+?)\s*\[(?:(\d{1,2}):(\d{2}):(\d{2})|(\d{1,2}):(\d{2}))\]$")

  # New patterns for YouTube link format: [[MM:SS](URL&t=seconds)]
- YOUTUBE_SECTION_PATTERN = re.compile(r'^##\s*\[\[(\d{1,2}):(\d{2})\]\([^)]*&t=(\d+)\)\]\s*(.+)$')
- YOUTUBE_INLINE_PATTERN = re.compile(r'^(.+?)\s*\[\[(\d{1,2}):(\d{2})\]\([^)]*&t=(\d+)\)\]$')
+ YOUTUBE_SECTION_PATTERN = re.compile(r"^##\s*\[\[(\d{1,2}):(\d{2})\]\([^)]*&t=(\d+)\)\]\s*(.+)$")
+ YOUTUBE_INLINE_PATTERN = re.compile(r"^(.+?)\s*\[\[(\d{1,2}):(\d{2})\]\([^)]*&t=(\d+)\)\]$")

  @classmethod
  def parse_timestamp(cls, *args) -> float:
@@ -61,7 +61,7 @@ class GeminiReader:
  # Direct seconds (from YouTube &t= parameter)
  return int(args[0])
  else:
- raise ValueError(f'Invalid timestamp args: {args}')
+ raise ValueError(f"Invalid timestamp args: {args}")

  @classmethod
  def read(
@@ -82,13 +82,13 @@ class GeminiReader:
  """
  transcript_path = Path(transcript_path).expanduser().resolve()
  if not transcript_path.exists():
- raise FileNotFoundError(f'Transcript file not found: {transcript_path}')
+ raise FileNotFoundError(f"Transcript file not found: {transcript_path}")

  segments: List[GeminiSegment] = []
  current_section = None
  current_speaker = None

- with open(transcript_path, 'r', encoding='utf-8') as f:
+ with open(transcript_path, "r", encoding="utf-8") as f:
  lines = f.readlines()

  for line_num, line in enumerate(lines, start=1):
@@ -97,9 +97,9 @@ class GeminiReader:
  continue

  # Skip table of contents
- if line.startswith('* ['):
+ if line.startswith("* ["):
  continue
- if line.startswith('## Table of Contents'):
+ if line.startswith("## Table of Contents"):
  continue

  # Parse section headers
@@ -114,7 +114,7 @@ class GeminiReader:
  text=section_title.strip(),
  timestamp=timestamp,
  section=current_section,
- segment_type='section_header',
+ segment_type="section_header",
  line_number=line_num,
  )
  )
@@ -133,7 +133,7 @@ class GeminiReader:
  text=section_title.strip(),
  timestamp=timestamp,
  section=current_section,
- segment_type='section_header',
+ segment_type="section_header",
  line_number=line_num,
  )
  )
@@ -158,7 +158,7 @@ class GeminiReader:
  text=event_text.strip(),
  timestamp=timestamp,
  section=current_section,
- segment_type='event',
+ segment_type="event",
  line_number=line_num,
  )
  )
@@ -200,7 +200,7 @@ class GeminiReader:
  timestamp=timestamp,
  speaker=current_speaker,
  section=current_section,
- segment_type='dialogue',
+ segment_type="dialogue",
  line_number=line_num,
  )
  )
@@ -228,7 +228,7 @@ class GeminiReader:
  timestamp=timestamp,
  speaker=current_speaker,
  section=current_section,
- segment_type='dialogue',
+ segment_type="dialogue",
  line_number=line_num,
  )
  )
@@ -246,14 +246,14 @@ class GeminiReader:
  timestamp=timestamp,
  speaker=current_speaker,
  section=current_section,
- segment_type='dialogue',
+ segment_type="dialogue",
  line_number=line_num,
  )
  )
  continue

  # Skip markdown headers and other formatting
- if line.startswith('#'):
+ if line.startswith("#"):
  continue

  return segments
@@ -283,10 +283,10 @@ class GeminiReader:
  segments = cls.read(transcript_path, include_events=False, include_sections=False)

  # Filter to only dialogue segments with timestamps
- dialogue_segments = [s for s in segments if s.segment_type == 'dialogue' and s.timestamp is not None]
+ dialogue_segments = [s for s in segments if s.segment_type == "dialogue" and s.timestamp is not None]

  if not dialogue_segments:
- raise ValueError(f'No dialogue segments with timestamps found in {transcript_path}')
+ raise ValueError(f"No dialogue segments with timestamps found in {transcript_path}")

  # Sort by timestamp
  dialogue_segments.sort(key=lambda x: x.timestamp)
@@ -308,7 +308,7 @@ class GeminiReader:
  text=segment.text,
  start=segment.timestamp,
  duration=max(duration, min_duration),
- id=f'segment_{i:05d}',
+ id=f"segment_{i:05d}",
  speaker=segment.speaker,
  )
  )
@@ -337,13 +337,13 @@ class GeminiReader:
  else:
  # Different speaker or gap too large, save previous segment
  if current_texts:
- merged_text = ' '.join(current_texts)
+ merged_text = " ".join(current_texts)
  merged.append(
  Supervision(
  text=merged_text,
  start=current_start,
  duration=last_end_time - current_start,
- id=f'merged_{len(merged):05d}',
+ id=f"merged_{len(merged):05d}",
  )
  )
  current_speaker = segment.speaker
@@ -353,13 +353,13 @@ class GeminiReader:

  # Add final segment
  if current_texts:
- merged_text = ' '.join(current_texts)
+ merged_text = " ".join(current_texts)
  merged.append(
  Supervision(
  text=merged_text,
  start=current_start,
  duration=last_end_time - current_start,
- id=f'merged_{len(merged):05d}',
+ id=f"merged_{len(merged):05d}",
  )
  )

@@ -368,4 +368,4 @@ class GeminiReader:
  return supervisions


- __all__ = ['GeminiReader', 'GeminiSegment']
+ __all__ = ["GeminiReader", "GeminiSegment"]
@@ -19,7 +19,7 @@ class GeminiWriter:
  hours = int(seconds // 3600)
  minutes = int((seconds % 3600) // 60)
  secs = int(seconds % 60)
- return f'[{hours:02d}:{minutes:02d}:{secs:02d}]'
+ return f"[{hours:02d}:{minutes:02d}:{secs:02d}]"

  @classmethod
  def update_timestamps(
@@ -44,7 +44,7 @@ class GeminiWriter:
  output_path = Path(output_path)

  # Read original file
- with open(original_path, 'r', encoding='utf-8') as f:
+ with open(original_path, "r", encoding="utf-8") as f:
  lines = f.readlines()

  # Parse original segments to get line numbers
@@ -66,7 +66,7 @@ class GeminiWriter:

  # Write updated content
  output_path.parent.mkdir(parents=True, exist_ok=True)
- with open(output_path, 'w', encoding='utf-8') as f:
+ with open(output_path, "w", encoding="utf-8") as f:
  f.writelines(updated_lines)

  return output_path
@@ -83,7 +83,7 @@ class GeminiWriter:
  mapping = {}

  # Create a simple text-based matching
- dialogue_segments = [s for s in original_segments if s.segment_type == 'dialogue']
+ dialogue_segments = [s for s in original_segments if s.segment_type == "dialogue"]

  # Try to match based on text content
  for aligned_sup in aligned_supervisions:
@@ -120,7 +120,7 @@ class GeminiWriter:

  # Replace timestamp patterns
  # Pattern 1: [HH:MM:SS] at the end or in brackets
- line = re.sub(r'\[\d{2}:\d{2}:\d{2}\]', new_ts_str, line)
+ line = re.sub(r"\[\d{2}:\d{2}:\d{2}\]", new_ts_str, line)

  return line

@@ -146,28 +146,28 @@ class GeminiWriter:
  output_path = Path(output_path)
  output_path.parent.mkdir(parents=True, exist_ok=True)

- with open(output_path, 'w', encoding='utf-8') as f:
- f.write('# Aligned Transcript\n\n')
+ with open(output_path, "w", encoding="utf-8") as f:
+ f.write("# Aligned Transcript\n\n")

  for i, sup in enumerate(aligned_supervisions):
  # Write segment with timestamp
  start_ts = cls.format_timestamp(sup.start)
- f.write(f'{start_ts} {sup.text}\n')
+ f.write(f"{start_ts} {sup.text}\n")

  # Optionally write word-level timestamps
- if include_word_timestamps and hasattr(sup, 'alignment') and sup.alignment:
- if 'word' in sup.alignment:
- f.write(' Words: ')
+ if include_word_timestamps and hasattr(sup, "alignment") and sup.alignment:
+ if "word" in sup.alignment:
+ f.write(" Words: ")
  word_parts = []
- for word_info in sup.alignment['word']:
- word_ts = cls.format_timestamp(word_info['start'])
+ for word_info in sup.alignment["word"]:
+ word_ts = cls.format_timestamp(word_info["start"])
  word_parts.append(f'{word_info["symbol"]}{word_ts}')
- f.write(' '.join(word_parts))
- f.write('\n')
+ f.write(" ".join(word_parts))
+ f.write("\n")

- f.write('\n')
+ f.write("\n")

  return output_path


- __all__ = ['GeminiWriter']
+ __all__ = ["GeminiWriter"]
lattifai/io/reader.py CHANGED
@@ -7,7 +7,7 @@ from lhotse.utils import Pathlike
  from .supervision import Supervision
  from .text_parser import parse_speaker_text

- SubtitleFormat = Literal['txt', 'srt', 'vtt', 'ass', 'auto']
+ SubtitleFormat = Literal["txt", "srt", "vtt", "ass", "auto"]


  class SubtitleReader(ABCMeta):
@@ -27,28 +27,27 @@ class SubtitleReader(ABCMeta):
  Parsed text in Lhotse Cut
  """
  if not format and Path(str(subtitle)).exists():
- format = Path(str(subtitle)).suffix.lstrip('.').lower()
+ format = Path(str(subtitle)).suffix.lstrip(".").lower()
  elif format:
  format = format.lower()

- if format == 'gemini' or str(subtitle).endswith('Gemini.md'):
+ if format == "gemini" or str(subtitle).endswith("Gemini.md"):
  from .gemini_reader import GeminiReader

  supervisions = GeminiReader.extract_for_alignment(subtitle)
- elif format == 'txt' or (format == 'auto' and str(subtitle)[-4:].lower() == '.txt'):
+ elif format == "txt" or (format == "auto" and str(subtitle)[-4:].lower() == ".txt"):
  if not Path(str(subtitle)).exists(): # str
- lines = [line.strip() for line in str(subtitle).split('\n')]
+ lines = [line.strip() for line in str(subtitle).split("\n")]
  else: # file
  path_str = str(subtitle)
- with open(path_str, encoding='utf-8') as f:
+ with open(path_str, encoding="utf-8") as f:
  lines = [line.strip() for line in f.readlines()]
  supervisions = [Supervision(text=line) for line in lines if line]
  else:
  try:
  supervisions = cls._parse_subtitle(subtitle, format=format)
  except Exception as e:
- del e
- print(f"Failed to parse subtitle with format {format}, trying 'gemini' parser.")
+ print(f"Failed to parse subtitle with Format: {format}, Exception: {e}, trying 'gemini' parser.")
  from .gemini_reader import GeminiReader

  supervisions = GeminiReader.extract_for_alignment(subtitle)
@@ -61,18 +60,20 @@ class SubtitleReader(ABCMeta):

  try:
  subs: pysubs2.SSAFile = pysubs2.load(
- subtitle, encoding='utf-8', format_=format if format != 'auto' else None
+ subtitle, encoding="utf-8", format_=format if format != "auto" else None
  ) # file
  except IOError:
  try:
  subs: pysubs2.SSAFile = pysubs2.SSAFile.from_string(
- subtitle, format_=format if format != 'auto' else None
+ subtitle, format_=format if format != "auto" else None
  ) # str
- except:
- subs: pysubs2.SSAFile = pysubs2.load(subtitle, encoding='utf-8') # auto detect format
+ except Exception as e:
+ del e
+ subs: pysubs2.SSAFile = pysubs2.load(subtitle, encoding="utf-8") # auto detect format

  supervisions = []
  for event in subs.events:
+ # NOT apply text_parser.py:normalize_html_text here, to keep original text in subtitles
  speaker, text = parse_speaker_text(event.text)
  supervisions.append(
  Supervision(
@@ -24,10 +24,10 @@ class Supervision(SupervisionSegment):
  """

  text: Optional[str] = None
- id: str = ''
- recording_id: str = ''
+ id: str = ""
+ recording_id: str = ""
  start: Seconds = 0.0
  duration: Seconds = 0.0


- __all__ = ['Supervision']
+ __all__ = ["Supervision"]
@@ -3,23 +3,50 @@ import re
  from typing import Optional, Tuple

  # 来自于字幕中常见的说话人标记格式
- SPEAKER_PATTERN = re.compile(r'((?:>>|>>|>|>).*?[::])\s*(.*)')
+ SPEAKER_PATTERN = re.compile(r"((?:>>|>>|>|>).*?[::])\s*(.*)")

  # Transcriber Output Example:
  # 26:19.919 --> 26:34.921
  # [SPEAKER_01]: 越来越多的科技巨头入...
- SPEAKER_LATTIFAI = re.compile(r'(^\[SPEAKER_.*?\][::])\s*(.*)')
+ SPEAKER_LATTIFAI = re.compile(r"(^\[SPEAKER_.*?\][::])\s*(.*)")

  # NISHTHA BHATIA: Hey, everyone.
  # DIETER: Oh, hey, Nishtha.
  # GEMINI: That might
- SPEAKER_PATTERN2 = re.compile(r'^([A-Z]{1,15}(?:\s+[A-Z]{1,15})?[::])\s*(.*)$')
+ SPEAKER_PATTERN2 = re.compile(r"^([A-Z]{1,15}(?:\s+[A-Z]{1,15})?[::])\s*(.*)$")
+
+
+ def normalize_html_text(text: str) -> str:
+ """Normalize HTML text by decoding entities and stripping whitespace."""
+ html_entities = {
+ "&": "&",
+ "&lt;": "<",
+ "&gt;": ">",
+ "&quot;": '"',
+ "&#39;": "'",
+ "&nbsp;": " ",
+ "\\N": " ",
+ "…": " ",
+ }
+ for entity, char in html_entities.items():
+ text = text.replace(entity, char)
+
+ text = re.sub(r"\s+", " ", text) # Replace multiple spaces with a single space
+
+ # Convert curly apostrophes to straight apostrophes for common English contractions
+ # Handles: 't 's 'll 're 've 'd 'm
+ # For example, convert "don't" to "don't"
+ text = re.sub(r"([a-zA-Z])’([tsdm]|ll|re|ve)\b", r"\1'\2", text, flags=re.IGNORECASE)
+ # For example, convert "5’s" to "5's"
+ text = re.sub(r"([0-9])’([s])\b", r"\1'\2", text, flags=re.IGNORECASE)
+
+ return text.strip()


  def parse_speaker_text(line) -> Tuple[Optional[str], str]:
- line = line.replace('\\N', ' ')
+ """Parse a line of text to extract speaker and content."""

- if ':' not in line and '' not in line:
+ if ":" not in line and "" not in line:
  return None, line

  # 匹配以 >> 开头的行,并去除开头的名字和冒号
@@ -31,7 +58,7 @@ def parse_speaker_text(line) -> Tuple[Optional[str], str]:
  if match:
  assert len(match.groups()) == 2, match.groups()
  if not match.group(1):
- logging.error(f'ParseSub LINE [{line}]')
+ logging.error(f"ParseSub LINE [{line}]")
  else:
  return match.group(1).strip(), match.group(2).strip()

@@ -43,15 +70,15 @@ def parse_speaker_text(line) -> Tuple[Optional[str], str]:
  return None, line


- if __name__ == '__main__':
- pattern = re.compile(r'>>\s*(.*?)\s*[::]\s*(.*)')
- pattern = re.compile(r'(>>.*?[::])\s*(.*)')
+ if __name__ == "__main__":
+ pattern = re.compile(r">>\s*(.*?)\s*[::]\s*(.*)")
+ pattern = re.compile(r"(>>.*?[::])\s*(.*)")

  test_strings = [
- '>>Key: Value',
- '>> Key with space : Value with space ',
- '>> 全角键 : 全角值',
- '>>Key:Value xxx. >>Key:Value',
+ ">>Key: Value",
+ ">> Key with space : Value with space ",
+ ">> 全角键 : 全角值",
+ ">>Key:Value xxx. >>Key:Value",
  ]

  for text in test_strings:
@@ -60,16 +87,16 @@ if __name__ == '__main__':
  print(f"Input: '{text}'")
  print(f" Key: '{match.group(1)}'")
  print(f" Value: '{match.group(2)}'")
- print('-------------')
+ print("-------------")

  # pattern2
- test_strings2 = ['NISHTHA BHATIA: Hey, everyone.', 'DIETER: Oh, hey, Nishtha.', 'GEMINI: That might']
+ test_strings2 = ["NISHTHA BHATIA: Hey, everyone.", "DIETER: Oh, hey, Nishtha.", "GEMINI: That might"]
  for text in test_strings2:
  match = SPEAKER_PATTERN2.match(text)
  if match:
  print(f" Input: '{text}'")
  print(f"Speaker: '{match.group(1)}'")
  print(f"Content: '{match.group(2)}'")
- print('-------------')
+ print("-------------")
  else:
  raise ValueError(f"No match for: '{text}'")
lattifai/io/utils.py CHANGED
@@ -3,13 +3,13 @@ Utility constants and helper functions for subtitle I/O operations
  """

  # Supported subtitle formats for reading/writing
- SUBTITLE_FORMATS = ['srt', 'vtt', 'ass', 'ssa', 'sub', 'sbv', 'txt', 'md']
+ SUBTITLE_FORMATS = ["srt", "vtt", "ass", "ssa", "sub", "sbv", "txt", "md"]

  # Input subtitle formats (includes special formats like 'auto' and 'gemini')
- INPUT_SUBTITLE_FORMATS = ['srt', 'vtt', 'ass', 'ssa', 'sub', 'sbv', 'txt', 'auto', 'gemini']
+ INPUT_SUBTITLE_FORMATS = ["srt", "vtt", "ass", "ssa", "sub", "sbv", "txt", "auto", "gemini"]

  # Output subtitle formats (includes special formats like 'TextGrid' and 'json')
- OUTPUT_SUBTITLE_FORMATS = ['srt', 'vtt', 'ass', 'ssa', 'sub', 'sbv', 'txt', 'TextGrid', 'json']
+ OUTPUT_SUBTITLE_FORMATS = ["srt", "vtt", "ass", "ssa", "sub", "sbv", "txt", "TextGrid", "json"]

  # All subtitle formats combined (for file detection)
- ALL_SUBTITLE_FORMATS = list(set(SUBTITLE_FORMATS + ['TextGrid', 'json', 'gemini']))
+ ALL_SUBTITLE_FORMATS = list(set(SUBTITLE_FORMATS + ["TextGrid", "json", "gemini"]))
lattifai/io/writer.py CHANGED
@@ -14,45 +14,57 @@ class SubtitleWriter(ABCMeta):

  @classmethod
  def write(cls, alignments: List[Supervision], output_path: Pathlike) -> Pathlike:
- if str(output_path)[-4:].lower() == '.txt':
- with open(output_path, 'w', encoding='utf-8') as f:
+ if str(output_path)[-4:].lower() == ".txt":
+ with open(output_path, "w", encoding="utf-8") as f:
  for sup in alignments:
  word_items = parse_alignment_from_supervision(sup)
  if word_items:
  for item in word_items:
- f.write(f'[{item.start:.2f}-{item.end:.2f}] {item.symbol}\n')
+ f.write(f"[{item.start:.2f}-{item.end:.2f}] {item.symbol}\n")
  else:
- text = f'{sup.speaker} {sup.text}' if sup.speaker is not None else sup.text
- f.write(f'[{sup.start:.2f}-{sup.end:.2f}] {text}\n')
+ text = f"{sup.speaker} {sup.text}" if sup.speaker is not None else sup.text
+ f.write(f"[{sup.start:.2f}-{sup.end:.2f}] {text}\n")

- elif str(output_path)[-5:].lower() == '.json':
- with open(output_path, 'w', encoding='utf-8') as f:
+ elif str(output_path)[-5:].lower() == ".json":
+ with open(output_path, "w", encoding="utf-8") as f:
  # Enhanced JSON export with word-level alignment
  json_data = []
  for sup in alignments:
  sup_dict = sup.to_dict()
  json_data.append(sup_dict)
  json.dump(json_data, f, ensure_ascii=False, indent=4)
- elif str(output_path).endswith('.TextGrid') or str(output_path).endswith('.textgrid'):
+ elif str(output_path).lower().endswith(".textgrid"):
  from tgt import Interval, IntervalTier, TextGrid, write_to_file

  tg = TextGrid()
- supervisions, words = [], []
+ supervisions, words, scores = [], [], {"utterances": [], "words": []}
  for supervision in sorted(alignments, key=lambda x: x.start):
  text = (
- f'{supervision.speaker} {supervision.text}' if supervision.speaker is not None else supervision.text
+ f"{supervision.speaker} {supervision.text}" if supervision.speaker is not None else supervision.text
  )
- supervisions.append(Interval(supervision.start, supervision.end, text or ''))
+ supervisions.append(Interval(supervision.start, supervision.end, text or ""))
  # Extract word-level alignment using helper function
  word_items = parse_alignment_from_supervision(supervision)
  if word_items:
  for item in word_items:
  words.append(Interval(item.start, item.end, item.symbol))
+ if item.score is not None:
+ scores["words"].append(Interval(item.start, item.end, f"{item.score:.2f}"))
+ if supervision.has_custom("score"):
+ scores["utterances"].append(
+ Interval(supervision.start, supervision.end, f"{supervision.score:.2f}")
+ )

- tg.add_tier(IntervalTier(name='utterances', objects=supervisions))
+ tg.add_tier(IntervalTier(name="utterances", objects=supervisions))
  if words:
- tg.add_tier(IntervalTier(name='words', objects=words))
- write_to_file(tg, output_path, format='long')
+ tg.add_tier(IntervalTier(name="words", objects=words))
+
+ if scores["utterances"]:
+ tg.add_tier(IntervalTier(name="utterance_scores", objects=scores["utterances"]))
+ if scores["words"]:
+ tg.add_tier(IntervalTier(name="word_scores", objects=scores["words"]))
+
+ write_to_file(tg, output_path, format="long")
  else:
  subs = pysubs2.SSAFile()
  for sup in alignments:
@@ -64,8 +76,8 @@ class SubtitleWriter(ABCMeta):
  pysubs2.SSAEvent(start=int(word.start * 1000), end=int(word.end * 1000), text=word.symbol)
  )
  else:
- text = f'{sup.speaker} {sup.text}' if sup.speaker is not None else sup.text
- subs.append(pysubs2.SSAEvent(start=int(sup.start * 1000), end=int(sup.end * 1000), text=text or ''))
+ text = f"{sup.speaker} {sup.text}" if sup.speaker is not None else sup.text
+ subs.append(pysubs2.SSAEvent(start=int(sup.start * 1000), end=int(sup.end * 1000), text=text or ""))
  subs.save(output_path)

  return output_path
@@ -81,10 +93,10 @@ def parse_alignment_from_supervision(supervision: Any) -> Optional[List[Alignmen
  Returns:
  List of AlignmentItem objects, or None if no alignment data present
  """
- if not hasattr(supervision, 'alignment') or not supervision.alignment:
+ if not hasattr(supervision, "alignment") or not supervision.alignment:
  return None

- if 'word' not in supervision.alignment:
+ if "word" not in supervision.alignment:
  return None

- return supervision.alignment['word']
+ return supervision.alignment["word"]
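
As a rough illustration of the new score tiers added to the TextGrid branch, the sketch below writes one aligned supervision. The lattifai.io module paths are inferred from this diff, AlignmentItem is the lhotse type referenced by parse_alignment_from_supervision, and attaching the utterance score through a custom "score" field is only an assumption about how has_custom("score") is satisfied upstream.

    # Hedged sketch: module paths and the custom "score" field are assumptions.
    from lhotse.supervision import AlignmentItem

    from lattifai.io.supervision import Supervision
    from lattifai.io.writer import SubtitleWriter

    sup = Supervision(
        id="segment_00000",
        text="hello world",
        start=0.0,
        duration=1.2,
        alignment={
            "word": [
                AlignmentItem(symbol="hello", start=0.0, duration=0.5, score=0.93),
                AlignmentItem(symbol="world", start=0.6, duration=0.6, score=0.88),
            ]
        },
        custom={"score": 0.90},  # assumed way to make has_custom("score") true
    )

    # Produces "utterances" and "words" tiers, plus the new
    # "utterance_scores" and "word_scores" tiers when scores are present.
    SubtitleWriter.write([sup], "aligned.TextGrid")
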
@@ -1,3 +1,3 @@
  from .tokenizer import AsyncLatticeTokenizer, LatticeTokenizer

- __all__ = ['LatticeTokenizer', 'AsyncLatticeTokenizer']
+ __all__ = ["LatticeTokenizer", "AsyncLatticeTokenizer"]
@@ -4,13 +4,13 @@ from typing import List, Optional, Union
  from dp.phonemizer import Phonemizer # g2p-phonemizer
  from num2words import num2words

- LANGUAGE = 'omni'
+ LANGUAGE = "omni"


  class G2Phonemizer:
  def __init__(self, model_checkpoint, device):
  self.phonemizer = Phonemizer.from_checkpoint(model_checkpoint, device=device).predictor
- self.pattern = re.compile(r'\d+')
+ self.pattern = re.compile(r"\d+")

  def num2words(self, word, lang: str):
  matches = self.pattern.findall(word)
@@ -31,7 +31,7 @@ class G2Phonemizer:
  is_list = False

  predictions = self.phonemizer(
- [self.num2words(word.replace(' .', '.').replace('.', ' .'), lang=lang or 'en') for word in words],
+ [self.num2words(word.replace(" .", ".").replace(".", " ."), lang=lang or "en") for word in words],
  lang=LANGUAGE,
  batch_size=min(batch_size or len(words), 128),
  num_prons=num_prons,