lattifai-1.2.0-py3-none-any.whl → lattifai-1.2.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. lattifai/__init__.py +0 -24
  2. lattifai/alignment/__init__.py +10 -1
  3. lattifai/alignment/lattice1_aligner.py +66 -58
  4. lattifai/alignment/lattice1_worker.py +1 -6
  5. lattifai/alignment/punctuation.py +38 -0
  6. lattifai/alignment/segmenter.py +1 -1
  7. lattifai/alignment/sentence_splitter.py +350 -0
  8. lattifai/alignment/text_align.py +440 -0
  9. lattifai/alignment/tokenizer.py +91 -220
  10. lattifai/caption/__init__.py +82 -6
  11. lattifai/caption/caption.py +335 -1143
  12. lattifai/caption/formats/__init__.py +199 -0
  13. lattifai/caption/formats/base.py +211 -0
  14. lattifai/caption/formats/gemini.py +722 -0
  15. lattifai/caption/formats/json.py +194 -0
  16. lattifai/caption/formats/lrc.py +309 -0
  17. lattifai/caption/formats/nle/__init__.py +9 -0
  18. lattifai/caption/formats/nle/audition.py +561 -0
  19. lattifai/caption/formats/nle/avid.py +423 -0
  20. lattifai/caption/formats/nle/fcpxml.py +549 -0
  21. lattifai/caption/formats/nle/premiere.py +589 -0
  22. lattifai/caption/formats/pysubs2.py +642 -0
  23. lattifai/caption/formats/sbv.py +147 -0
  24. lattifai/caption/formats/tabular.py +338 -0
  25. lattifai/caption/formats/textgrid.py +193 -0
  26. lattifai/caption/formats/ttml.py +652 -0
  27. lattifai/caption/formats/vtt.py +469 -0
  28. lattifai/caption/parsers/__init__.py +9 -0
  29. lattifai/caption/{text_parser.py → parsers/text_parser.py} +4 -2
  30. lattifai/caption/standardize.py +636 -0
  31. lattifai/caption/utils.py +474 -0
  32. lattifai/cli/__init__.py +2 -1
  33. lattifai/cli/caption.py +108 -1
  34. lattifai/cli/transcribe.py +4 -9
  35. lattifai/cli/youtube.py +4 -1
  36. lattifai/client.py +48 -84
  37. lattifai/config/__init__.py +11 -1
  38. lattifai/config/alignment.py +9 -2
  39. lattifai/config/caption.py +267 -23
  40. lattifai/config/media.py +20 -0
  41. lattifai/diarization/__init__.py +41 -1
  42. lattifai/mixin.py +36 -18
  43. lattifai/transcription/base.py +6 -1
  44. lattifai/transcription/lattifai.py +19 -54
  45. lattifai/utils.py +81 -13
  46. lattifai/workflow/__init__.py +28 -4
  47. lattifai/workflow/file_manager.py +2 -5
  48. lattifai/youtube/__init__.py +43 -0
  49. lattifai/youtube/client.py +1170 -0
  50. lattifai/youtube/types.py +23 -0
  51. lattifai-1.2.2.dist-info/METADATA +615 -0
  52. lattifai-1.2.2.dist-info/RECORD +76 -0
  53. {lattifai-1.2.0.dist-info → lattifai-1.2.2.dist-info}/entry_points.txt +1 -2
  54. lattifai/caption/gemini_reader.py +0 -371
  55. lattifai/caption/gemini_writer.py +0 -173
  56. lattifai/cli/app_installer.py +0 -142
  57. lattifai/cli/server.py +0 -44
  58. lattifai/server/app.py +0 -427
  59. lattifai/workflow/youtube.py +0 -577
  60. lattifai-1.2.0.dist-info/METADATA +0 -1133
  61. lattifai-1.2.0.dist-info/RECORD +0 -57
  62. {lattifai-1.2.0.dist-info → lattifai-1.2.2.dist-info}/WHEEL +0 -0
  63. {lattifai-1.2.0.dist-info → lattifai-1.2.2.dist-info}/licenses/LICENSE +0 -0
  64. {lattifai-1.2.0.dist-info → lattifai-1.2.2.dist-info}/top_level.txt +0 -0
lattifai/__init__.py CHANGED
@@ -1,5 +1,4 @@
 import os
-import sys
 import warnings
 from importlib.metadata import version
 
@@ -52,29 +51,6 @@ except Exception:
     __version__ = "0.1.0"  # fallback version
 
 
-# Check and auto-install k2py if not present
-def _check_and_install_k2py():
-    """Check if k2py is installed and attempt to install it if not."""
-    try:
-        import k2py
-    except ImportError:
-        import subprocess
-
-        print("k2py is not installed. Attempting to install k2py...")
-        try:
-            subprocess.check_call([sys.executable, "-m", "pip", "install", "k2py"])
-            import k2py  # Try importing again after installation
-
-            print("k2py installed successfully.")
-        except Exception as e:
-            warnings.warn(f"Failed to install k2py automatically. Please install it manually. Error: {e}")
-    return True
-
-
-# Auto-install k2py on first import
-_check_and_install_k2py()
-
-
 __all__ = [
     # Client classes
     "LattifAI",
lattifai/alignment/__init__.py CHANGED
@@ -2,5 +2,14 @@
 
 from .lattice1_aligner import Lattice1Aligner
 from .segmenter import Segmenter
+from .sentence_splitter import SentenceSplitter
+from .text_align import align_supervisions_and_transcription
+from .tokenizer import tokenize_multilingual_text
 
-__all__ = ["Lattice1Aligner", "Segmenter"]
+__all__ = [
+    "Lattice1Aligner",
+    "Segmenter",
+    "SentenceSplitter",
+    "align_supervisions_and_transcription",
+    "tokenize_multilingual_text",
+]
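For orientation, a minimal sketch of driving the newly exported SentenceSplitter; the sample supervisions and text are invented, and only constructor and method arguments visible in this diff are used:

# Hypothetical usage sketch (not from the package documentation).
from lattifai.alignment import SentenceSplitter
from lattifai.caption import Supervision

sups = [
    Supervision(id="", recording_id="rec1", start=0.0, duration=3.0,
                text="Hello everyone. Welcome back"),
    Supervision(id="", recording_id="rec1", start=3.0, duration=2.5,
                text="to the channel! Today we talk about alignment."),
]

splitter = SentenceSplitter(device="cpu")  # wtpsplit model is loaded lazily on first use
for sup in splitter.split_sentences(sups):
    print(f"{sup.start:.2f}s +{sup.duration:.2f}s  {sup.text}")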
lattifai/alignment/lattice1_aligner.py CHANGED
@@ -1,6 +1,6 @@
 """Lattice-1 Aligner implementation."""
 
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple, Union
 
 import colorful
 import numpy as np
@@ -16,11 +16,22 @@ from lattifai.errors import (
 from lattifai.utils import _resolve_model_path, safe_print
 
 from .lattice1_worker import _load_worker
+from .text_align import TextAlignResult
 from .tokenizer import _load_tokenizer
 
 ClientType = Any
 
 
+def _extract_text_for_error(supervisions: Union[list, tuple]) -> str:
+    """Extract text from supervisions for error messages."""
+    if not supervisions:
+        return ""
+    # TextAlignResult is a tuple: (caption_sups, transcript_sups, ...)
+    if isinstance(supervisions, tuple):
+        supervisions = supervisions[0] or supervisions[1] or []
+    return " ".join(s.text for s in supervisions if s and s.text)
+
+
 class Lattice1Aligner(object):
     """Synchronous LattifAI client with config-driven architecture."""
 
@@ -79,7 +90,7 @@ class Lattice1Aligner(object):
     def alignment(
         self,
         audio: AudioData,
-        supervisions: List[Supervision],
+        supervisions: Union[List[Supervision], TextAlignResult],
         split_sentence: Optional[bool] = False,
         return_details: Optional[bool] = False,
         emission: Optional[np.ndarray] = None,
@@ -102,69 +113,66 @@ class Lattice1Aligner(object):
             AlignmentError: If audio alignment fails
             LatticeDecodingError: If lattice decoding fails
         """
+        # Step 2: Create lattice graph
+        if verbose:
+            safe_print(colorful.cyan("🔗 Step 2: Creating lattice graph from segments"))
         try:
+            supervisions, lattice_id, lattice_graph = self.tokenizer.tokenize(
+                supervisions, split_sentence=split_sentence, boost=self.config.boost
+            )
             if verbose:
-                safe_print(colorful.cyan("🔗 Step 2: Creating lattice graph from segments"))
-            try:
-                supervisions, lattice_id, lattice_graph = self.tokenizer.tokenize(
-                    supervisions, split_sentence=split_sentence
-                )
-                if verbose:
-                    safe_print(colorful.green(f"  ✓ Generated lattice graph with ID: {lattice_id}"))
-            except Exception as e:
-                text_content = " ".join([sup.text for sup in supervisions]) if supervisions else ""
-                raise LatticeEncodingError(text_content, original_error=e)
-
-            if verbose:
-                safe_print(colorful.cyan(f"🔍 Step 3: Searching lattice graph with media: {audio}"))
-                if audio.streaming_chunk_secs:
-                    safe_print(
-                        colorful.yellow(
-                            f"  ⚡Using streaming mode with {audio.streaming_chunk_secs}s (chunk duration)"
-                        )
+                safe_print(colorful.green(f"  Generated lattice graph with ID: {lattice_id}"))
+        except Exception as e:
+            text_content = _extract_text_for_error(supervisions)
+            raise LatticeEncodingError(text_content, original_error=e)
+
+        # Step 3: Search lattice graph
+        if verbose:
+            safe_print(colorful.cyan(f"🔍 Step 3: Searching lattice graph with media: {audio}"))
+            if audio.streaming_mode:
+                safe_print(
+                    colorful.yellow(
+                        f"  ⚡Using streaming mode with {audio.streaming_chunk_secs}s (chunk duration)"
                     )
-            try:
-                lattice_results = self.worker.alignment(
-                    audio,
-                    lattice_graph,
-                    emission=emission,
-                    offset=offset,
-                )
-                if verbose:
-                    safe_print(colorful.green("  ✓ Lattice search completed"))
-            except Exception as e:
-                raise AlignmentError(
-                    f"Audio alignment failed for {audio}",
-                    media_path=str(audio),
-                    context={"original_error": str(e)},
                 )
-
+        try:
+            lattice_results = self.worker.alignment(
+                audio,
+                lattice_graph,
+                emission=emission,
+                offset=offset,
+            )
             if verbose:
-                safe_print(colorful.cyan("🎯 Step 4: Decoding lattice results to aligned segments"))
-            try:
-                alignments = self.tokenizer.detokenize(
-                    lattice_id,
-                    lattice_results,
-                    supervisions=supervisions,
-                    return_details=return_details,
-                    start_margin=self.config.start_margin,
-                    end_margin=self.config.end_margin,
-                )
-                if verbose:
-                    safe_print(colorful.green(f"  ✓ Successfully aligned {len(alignments)} segments"))
-            except LatticeDecodingError as e:
-                safe_print(colorful.red("  x Failed to decode lattice alignment results"))
-                raise e
-            except Exception as e:
-                safe_print(colorful.red("  x Failed to decode lattice alignment results"))
-                raise LatticeDecodingError(lattice_id, original_error=e)
-
-            return (supervisions, alignments)
-
-        except (LatticeEncodingError, AlignmentError, LatticeDecodingError):
+                safe_print(colorful.green("  Lattice search completed"))
+        except Exception as e:
+            raise AlignmentError(
+                f"Audio alignment failed for {audio}",
+                media_path=str(audio),
+                context={"original_error": str(e)},
+            )
+
+        # Step 4: Decode lattice results
+        if verbose:
+            safe_print(colorful.cyan("🎯 Step 4: Decoding lattice results to aligned segments"))
+        try:
+            alignments = self.tokenizer.detokenize(
+                lattice_id,
+                lattice_results,
+                supervisions=supervisions,
+                return_details=return_details,
+                start_margin=self.config.start_margin,
+                end_margin=self.config.end_margin,
+            )
            if verbose:
+                safe_print(colorful.green(f"  ✓ Successfully aligned {len(alignments)} segments"))
+        except LatticeDecodingError:
+            safe_print(colorful.red("  x Failed to decode lattice alignment results"))
             raise
         except Exception as e:
-            raise e
+            safe_print(colorful.red("  x Failed to decode lattice alignment results"))
+            raise LatticeDecodingError(lattice_id, original_error=e)
+
+        return (supervisions, alignments)
 
     def profile(self) -> None:
         """Print profiling statistics."""
lattifai/alignment/lattice1_worker.py CHANGED
@@ -7,8 +7,6 @@ from typing import Any, Dict, Optional, Tuple
 import colorful
 import numpy as np
 import onnxruntime as ort
-from lhotse import FbankConfig
-from lhotse.features.kaldi.layers import Wav2LogFilterBank
 from lhotse.utils import Pathlike
 from tqdm import tqdm
 
@@ -159,10 +157,7 @@ class Lattice1Worker:
             DependencyError: If required dependencies are missing
             AlignmentError: If alignment process fails
         """
-        try:
-            import k2py as k2
-        except ImportError:
-            raise DependencyError("k2py", install_command="pip install k2py")
+        import k2py as k2
 
         lattice_graph_str, final_state, acoustic_scale = lattice_graph
 
lattifai/alignment/punctuation.py ADDED
@@ -0,0 +1,38 @@
+# Multilingual punctuation characters (no duplicates)
+PUNCTUATION = (
+    # ASCII punctuation
+    "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
+    # Chinese/CJK punctuation
+    "。，、？！：；·¥…—～“”‘’"
+    "《》〈〉【】〔〕〖〗（）"  # Brackets
+    # Japanese punctuation (unique only)
+    "「」『』・"
+    # CJK punctuation (unique only)
+    "〃〆〇〒〓〘〙〚〛〜〝〞〟"
+    # Arabic punctuation
+    "،؛؟"
+    # Thai punctuation
+    "๏๚๛"
+    # Hebrew punctuation
+    "־׀׃׆"
+    # Other common punctuation
+    "¡¿"  # Spanish inverted marks
+    "«»‹›"  # Guillemets
+    "‐‑‒–―"  # Dashes (excluding — already above)
+    "‚„"  # Low quotation marks
+    "†‡•‣"  # Daggers and bullets
+    "′″‴"  # Prime marks
+    "‰‱"  # Per mille
+)
+PUNCTUATION_SPACE = PUNCTUATION + " "
+STAR_TOKEN = "※"
+
+# End of sentence punctuation marks (multilingual)
+# - ASCII: .!?"'])
+# - Chinese/CJK: 。！？”】」』〗〙〛 (including right double quote U+201D)
+# - Japanese: ｡ (halfwidth period)
+# - Arabic: ؟
+# - Ellipsis: …
+END_PUNCTUATION = ".!?\"'])。！？\u201d】」』〗〙〛｡؟…"
+
+GROUPING_SEPARATOR = "✹"
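As an illustration of how these constants are typically consumed, a small sketch; the ends_sentence helper below is hypothetical, not part of the package:

# Illustration only; assumes lattifai 1.2.2 is installed.
from lattifai.alignment.punctuation import END_PUNCTUATION, PUNCTUATION

def ends_sentence(line: str) -> bool:
    """Return True if the line ends with a sentence-final mark."""
    line = line.rstrip()
    return bool(line) and line[-1] in END_PUNCTUATION

print(ends_sentence("How are you?"))      # True
print(ends_sentence("今天天气不错。"))       # True (CJK full stop)
print(ends_sentence("to be continued"))   # False

# Strip punctuation before tokenization
print("Hello, world!".translate(str.maketrans("", "", PUNCTUATION)))  # Hello world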
lattifai/alignment/segmenter.py CHANGED
@@ -9,7 +9,7 @@ from lattifai.caption import Caption, Supervision
 from lattifai.config import AlignmentConfig
 from lattifai.utils import safe_print
 
-from .tokenizer import END_PUNCTUATION
+from .sentence_splitter import END_PUNCTUATION
 
 
 class Segmenter:
lattifai/alignment/sentence_splitter.py ADDED
@@ -0,0 +1,350 @@
+import re
+from typing import List, Optional
+
+from lattifai.alignment.punctuation import END_PUNCTUATION
+from lattifai.caption import Supervision
+from lattifai.utils import _resolve_model_path
+
+
+class SentenceSplitter:
+    """Lazy-initialized sentence splitter using wtpsplit."""
+
+    def __init__(self, device: str = "cpu", model_hub: Optional[str] = "modelscope", lazy_init: bool = True):
+        """Initialize sentence splitter with lazy loading.
+
+        Args:
+            device: Device to run the model on (cpu, cuda, mps)
+            model_hub: Model hub to use (None for huggingface, "modelscope" for modelscope)
+        """
+        self.device = device
+        self.model_hub = model_hub
+        self._splitter = None
+        if not lazy_init:
+            self._init_splitter()
+
+    def _init_splitter(self):
+        """Initialize the sentence splitter model on first use."""
+        if self._splitter is not None:
+            return
+
+        import onnxruntime as ort
+        from wtpsplit import SaT
+
+        providers = []
+        device = self.device
+        if device.startswith("cuda") and ort.get_all_providers().count("CUDAExecutionProvider") > 0:
+            providers.append("CUDAExecutionProvider")
+        elif device.startswith("mps") and ort.get_all_providers().count("MPSExecutionProvider") > 0:
+            providers.append("MPSExecutionProvider")
+
+        if self.model_hub == "modelscope":
+            downloaded_path = _resolve_model_path("LattifAI/OmniTokenizer", model_hub="modelscope")
+            sat = SaT(
+                f"{downloaded_path}/sat-3l-sm",
+                tokenizer_name_or_path=f"{downloaded_path}/xlm-roberta-base",
+                ort_providers=providers + ["CPUExecutionProvider"],
+            )
+        else:
+            sat_path = _resolve_model_path("segment-any-text/sat-3l-sm", model_hub="huggingface")
+            sat = SaT(
+                sat_path,
+                tokenizer_name_or_path="facebookAI/xlm-roberta-base",
+                hub_prefix="segment-any-text",
+                ort_providers=providers + ["CPUExecutionProvider"],
+            )
+        self._splitter = sat
+
+    @staticmethod
+    def _distribute_time_info(
+        input_supervisions: List[Supervision],
+        split_texts: List[str],
+    ) -> List[Supervision]:
+        """Distribute time information from input supervisions to split sentences.
+
+        Args:
+            input_supervisions: Original supervisions with time information
+            split_texts: List of split sentence texts
+
+        Returns:
+            List of Supervision objects with distributed time information.
+            Custom attributes are inherited from first_sup with conflict markers.
+        """
+        if not input_supervisions:
+            return [Supervision(text=text, id="", recording_id="", start=0, duration=0) for text in split_texts]
+
+        # Build concatenated input text
+        input_text = " ".join(sup.text for sup in input_supervisions)
+
+        # Pre-compute supervision position mapping for O(1) lookup
+        # Format: [(start_pos, end_pos, supervision), ...]
+        sup_ranges = []
+        char_pos = 0
+        for sup in input_supervisions:
+            sup_start = char_pos
+            sup_end = char_pos + len(sup.text)
+            sup_ranges.append((sup_start, sup_end, sup))
+            char_pos = sup_end + 1  # +1 for space separator
+
+        # Process each split text
+        result = []
+        search_start = 0
+        sup_idx = 0  # Track current supervision index to skip processed ones
+
+        for split_text in split_texts:
+            text_start = input_text.find(split_text, search_start)
+            if text_start == -1:
+                raise ValueError(f"Could not find split text '{split_text}' in input supervisions.")
+
+            text_end = text_start + len(split_text)
+            search_start = text_end
+
+            # Find overlapping supervisions, starting from last used index
+            first_sup = None
+            last_sup = None
+            first_char_idx = None
+            last_char_idx = None
+            overlapping_customs = []  # Track all custom dicts for conflict detection
+
+            # Start from sup_idx, which is the first supervision that might overlap
+            for i in range(sup_idx, len(sup_ranges)):
+                sup_start, sup_end, sup = sup_ranges[i]
+
+                # Skip if no overlap (before text_start)
+                if sup_end <= text_start:
+                    sup_idx = i + 1  # Update starting point for next iteration
+                    continue
+
+                # Stop if no overlap (after text_end)
+                if sup_start >= text_end:
+                    break
+
+                # Found overlap
+                if first_sup is None:
+                    first_sup = sup
+                    first_char_idx = max(0, text_start - sup_start)
+
+                last_sup = sup
+                last_char_idx = min(len(sup.text) - 1, text_end - 1 - sup_start)
+
+                # Collect custom dict for conflict detection
+                if getattr(sup, "custom", None):
+                    overlapping_customs.append(sup.custom)
+
+            if first_sup is None or last_sup is None:
+                raise ValueError(f"Could not find supervisions for split text: {split_text}")
+
+            # Calculate timing
+            start_time = first_sup.start + (first_char_idx / len(first_sup.text)) * first_sup.duration
+            end_time = last_sup.start + ((last_char_idx + 1) / len(last_sup.text)) * last_sup.duration
+
+            # Inherit custom from first_sup, mark conflicts if multiple sources
+            merged_custom = None
+            if overlapping_customs:
+                # Start with first_sup's custom (inherit strategy)
+                merged_custom = overlapping_customs[0].copy() if overlapping_customs[0] else {}
+
+                # Detect conflicts if multiple overlapping supervisions have different custom values
+                if len(overlapping_customs) > 1:
+                    has_conflict = False
+                    for other_custom in overlapping_customs[1:]:
+                        if other_custom and other_custom != overlapping_customs[0]:
+                            has_conflict = True
+                            break
+
+                    if has_conflict:
+                        # Mark that this supervision spans multiple sources with different customs
+                        merged_custom["_split_from_multiple"] = True
+                        merged_custom["_source_count"] = len(overlapping_customs)
+
+            result.append(
+                Supervision(
+                    id="",
+                    text=split_text,
+                    start=start_time,
+                    duration=end_time - start_time,
+                    recording_id=first_sup.recording_id,
+                    custom=merged_custom,
+                )
+            )
+
+        return result
+
+    @staticmethod
+    def _resplit_special_sentence_types(sentence: str) -> List[str]:
+        """
+        Re-split special sentence types.
+
+        Examples:
+            '[APPLAUSE] &gt;&gt; MIRA MURATI:' -> ['[APPLAUSE]', '&gt;&gt; MIRA MURATI:']
+            '[MUSIC] &gt;&gt; SPEAKER:' -> ['[MUSIC]', '&gt;&gt; SPEAKER:']
+
+        Special handling patterns:
+            1. Separate special marks at the beginning (e.g., [APPLAUSE], [MUSIC], etc.) from subsequent speaker marks
+            2. Use speaker marks (&gt;&gt; or other separators) as split points
+
+        Args:
+            sentence: Input sentence string
+
+        Returns:
+            List of re-split sentences. If no special marks are found, returns the original sentence in a list
+        """
+        # Detect special mark patterns: [SOMETHING] &gt;&gt; SPEAKER:
+        # or other forms like [SOMETHING] SPEAKER:
+
+        # Pattern 1: [mark] HTML-encoded separator speaker:
+        pattern1 = r"^(\[[^\]]+\])\s+(&gt;&gt;|>>)\s+(.+)$"
+        match1 = re.match(pattern1, sentence.strip())
+        if match1:
+            special_mark = match1.group(1)
+            separator = match1.group(2)
+            speaker_part = match1.group(3)
+            return [special_mark, f"{separator} {speaker_part}"]
+
+        # Pattern 2: [mark] speaker:
+        pattern2 = r"^(\[[^\]]+\])\s+([^:]+:)(.*)$"
+        match2 = re.match(pattern2, sentence.strip())
+        if match2:
+            special_mark = match2.group(1)
+            speaker_label = match2.group(2)
+            remaining = match2.group(3).strip()
+            if remaining:
+                return [special_mark, f"{speaker_label} {remaining}"]
+            else:
+                return [special_mark, speaker_label]
+
+        # If no special pattern matches, return the original sentence
+        return [sentence]
+
+    def split_sentences(self, supervisions: List[Supervision], strip_whitespace=True) -> List[Supervision]:
+        """Split supervisions into sentences using the sentence splitter.
+
+        Careful about speaker changes.
+
+        Args:
+            supervisions: List of Supervision objects to split
+            strip_whitespace: Whether to strip whitespace from split sentences
+
+        Returns:
+            List of Supervision objects with split sentences
+        """
+        self._init_splitter()
+
+        texts, speakers = [], []
+        text_len, sidx = 0, 0
+
+        def flush_segment(end_idx: int, speaker: Optional[str] = None):
+            """Flush accumulated text from sidx to end_idx with given speaker."""
+            nonlocal text_len, sidx
+            if sidx <= end_idx:
+                if len(speakers) < len(texts) + 1:
+                    speakers.append(speaker)
+                text = " ".join(sup.text for sup in supervisions[sidx : end_idx + 1])
+                texts.append(text)
+                sidx = end_idx + 1
+                text_len = 0
+
+        for s, supervision in enumerate(supervisions):
+            text_len += len(supervision.text)
+            is_last = s == len(supervisions) - 1
+
+            if supervision.speaker:
+                # Flush previous segment without speaker (if any)
+                if sidx < s:
+                    flush_segment(s - 1, None)
+                    text_len = len(supervision.text)
+
+                # Check if we should flush this speaker's segment now
+                next_has_speaker = not is_last and supervisions[s + 1].speaker
+                if is_last or next_has_speaker:
+                    flush_segment(s, supervision.speaker)
+                else:
+                    speakers.append(supervision.speaker)
+
+            elif text_len >= 2000 or is_last:
+                flush_segment(s, None)
+
+        if len(speakers) != len(texts):
+            raise ValueError(f"len(speakers)={len(speakers)} != len(texts)={len(texts)}")
+        sentences = self._splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace, batch_size=8)
+
+        # First pass: collect all split texts with their speakers
+        split_texts_with_speakers = []
+        remainder = ""
+        remainder_speaker = None
+
+        for k, (_speaker, _sentences) in enumerate(zip(speakers, sentences)):
+            # Prepend remainder from previous iteration to the first sentence
+            if _sentences and remainder:
+                _sentences[0] = remainder + _sentences[0]
+                _speaker = remainder_speaker if remainder_speaker else _speaker
+                remainder = ""
+                remainder_speaker = None
+
+            if not _sentences:
+                continue
+
+            # Process and re-split special sentence types
+            processed_sentences = []
+            for s, _sentence in enumerate(_sentences):
+                if remainder:
+                    _sentence = remainder + _sentence
+                    remainder = ""
+                # Detect and split special sentence types: e.g., '[APPLAUSE] &gt;&gt; MIRA MURATI:' -> ['[APPLAUSE]', '&gt;&gt; MIRA MURATI:']  # noqa: E501
+                resplit_parts = self._resplit_special_sentence_types(_sentence)
+                if any(resplit_parts[-1].endswith(sp) for sp in [":", "："]):
+                    if s < len(_sentences) - 1:
+                        _sentences[s + 1] = resplit_parts[-1] + " " + _sentences[s + 1]
+                    else:  # last part
+                        remainder = resplit_parts[-1] + " "
+                    processed_sentences.extend(resplit_parts[:-1])
+                else:
+                    processed_sentences.extend(resplit_parts)
+            _sentences = processed_sentences
+
+            if not _sentences:
+                if remainder:
+                    _sentences, remainder = [remainder.strip()], ""
+                else:
+                    continue
+
+            if any(_sentences[-1].endswith(ep) for ep in END_PUNCTUATION):
+                split_texts_with_speakers.extend(
+                    (text, _speaker if s == 0 else None) for s, text in enumerate(_sentences)
+                )
+                _speaker = None  # reset speaker after use
+            else:
+                split_texts_with_speakers.extend(
+                    (text, _speaker if s == 0 else None) for s, text in enumerate(_sentences[:-1])
+                )
+                remainder = _sentences[-1] + " " + remainder
+                if k < len(speakers) - 1 and speakers[k + 1] is not None:  # next speaker is set
+                    split_texts_with_speakers.append((remainder.strip(), _speaker if len(_sentences) == 1 else None))
+                    remainder = ""
+                    remainder_speaker = None
+                elif len(_sentences) == 1:
+                    remainder_speaker = _speaker
+                    if k == len(speakers) - 1:
+                        pass  # keep _speaker for the last supervision
+                    elif speakers[k + 1] is not None:
+                        raise ValueError(f"Expected speakers[{k + 1}] to be None, got {speakers[k + 1]}")
+                    else:
+                        speakers[k + 1] = _speaker
+                elif len(_sentences) > 1:
+                    _speaker = None  # reset speaker if sentence not ended
+                    remainder_speaker = None
+                else:
+                    raise ValueError(f"Unexpected state: len(_sentences)={len(_sentences)}")
+
+        if remainder.strip():
+            split_texts_with_speakers.append((remainder.strip(), remainder_speaker))
+
+        # Second pass: distribute time information
+        split_texts = [text for text, _ in split_texts_with_speakers]
+        result_supervisions = self._distribute_time_info(supervisions, split_texts)
+
+        # Third pass: add speaker information
+        for sup, (_, speaker) in zip(result_supervisions, split_texts_with_speakers):
+            if speaker:
+                sup.speaker = speaker
+
+        return result_supervisions
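To make the proportional timing in _distribute_time_info concrete, here is a small worked example with invented numbers; it calls the private static method directly purely for illustration:

# Worked example (illustration only; numbers are made up).
from lattifai.alignment.sentence_splitter import SentenceSplitter
from lattifai.caption import Supervision

sups = [
    Supervision(id="", recording_id="rec1", start=0.0, duration=4.0, text="Hello there everyone"),  # 20 chars
    Supervision(id="", recording_id="rec1", start=4.0, duration=3.6, text="how are you today?"),    # 18 chars
]

splits = SentenceSplitter._distribute_time_info(
    sups, ["Hello there", "everyone how are you today?"]
)
for s in splits:
    print(f"{s.start:.1f} +{s.duration:.1f}  {s.text!r}")
# 0.0 +2.2  'Hello there'                  -> 11/20 of the first supervision's duration
# 2.4 +5.2  'everyone how are you today?'  -> starts at char 12 of the first supervision, ends with the second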