lattifai 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/__init__.py +0 -24
- lattifai/alignment/__init__.py +10 -1
- lattifai/alignment/lattice1_aligner.py +66 -58
- lattifai/alignment/lattice1_worker.py +1 -6
- lattifai/alignment/punctuation.py +38 -0
- lattifai/alignment/segmenter.py +1 -1
- lattifai/alignment/sentence_splitter.py +350 -0
- lattifai/alignment/text_align.py +440 -0
- lattifai/alignment/tokenizer.py +91 -220
- lattifai/caption/__init__.py +82 -6
- lattifai/caption/caption.py +335 -1143
- lattifai/caption/formats/__init__.py +199 -0
- lattifai/caption/formats/base.py +211 -0
- lattifai/caption/formats/gemini.py +722 -0
- lattifai/caption/formats/json.py +194 -0
- lattifai/caption/formats/lrc.py +309 -0
- lattifai/caption/formats/nle/__init__.py +9 -0
- lattifai/caption/formats/nle/audition.py +561 -0
- lattifai/caption/formats/nle/avid.py +423 -0
- lattifai/caption/formats/nle/fcpxml.py +549 -0
- lattifai/caption/formats/nle/premiere.py +589 -0
- lattifai/caption/formats/pysubs2.py +642 -0
- lattifai/caption/formats/sbv.py +147 -0
- lattifai/caption/formats/tabular.py +338 -0
- lattifai/caption/formats/textgrid.py +193 -0
- lattifai/caption/formats/ttml.py +652 -0
- lattifai/caption/formats/vtt.py +469 -0
- lattifai/caption/parsers/__init__.py +9 -0
- lattifai/caption/{text_parser.py → parsers/text_parser.py} +4 -2
- lattifai/caption/standardize.py +636 -0
- lattifai/caption/utils.py +474 -0
- lattifai/cli/__init__.py +2 -1
- lattifai/cli/caption.py +108 -1
- lattifai/cli/transcribe.py +4 -9
- lattifai/cli/youtube.py +4 -1
- lattifai/client.py +48 -84
- lattifai/config/__init__.py +11 -1
- lattifai/config/alignment.py +9 -2
- lattifai/config/caption.py +267 -23
- lattifai/config/media.py +20 -0
- lattifai/diarization/__init__.py +41 -1
- lattifai/mixin.py +36 -18
- lattifai/transcription/base.py +6 -1
- lattifai/transcription/lattifai.py +19 -54
- lattifai/utils.py +81 -13
- lattifai/workflow/__init__.py +28 -4
- lattifai/workflow/file_manager.py +2 -5
- lattifai/youtube/__init__.py +43 -0
- lattifai/youtube/client.py +1170 -0
- lattifai/youtube/types.py +23 -0
- lattifai-1.2.2.dist-info/METADATA +615 -0
- lattifai-1.2.2.dist-info/RECORD +76 -0
- {lattifai-1.2.0.dist-info → lattifai-1.2.2.dist-info}/entry_points.txt +1 -2
- lattifai/caption/gemini_reader.py +0 -371
- lattifai/caption/gemini_writer.py +0 -173
- lattifai/cli/app_installer.py +0 -142
- lattifai/cli/server.py +0 -44
- lattifai/server/app.py +0 -427
- lattifai/workflow/youtube.py +0 -577
- lattifai-1.2.0.dist-info/METADATA +0 -1133
- lattifai-1.2.0.dist-info/RECORD +0 -57
- {lattifai-1.2.0.dist-info → lattifai-1.2.2.dist-info}/WHEEL +0 -0
- {lattifai-1.2.0.dist-info → lattifai-1.2.2.dist-info}/licenses/LICENSE +0 -0
- {lattifai-1.2.0.dist-info → lattifai-1.2.2.dist-info}/top_level.txt +0 -0
lattifai/__init__.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import os
|
|
2
|
-
import sys
|
|
3
2
|
import warnings
|
|
4
3
|
from importlib.metadata import version
|
|
5
4
|
|
|
@@ -52,29 +51,6 @@ except Exception:
|
|
|
52
51
|
__version__ = "0.1.0" # fallback version
|
|
53
52
|
|
|
54
53
|
|
|
55
|
-
# Check and auto-install k2py if not present
|
|
56
|
-
def _check_and_install_k2py():
|
|
57
|
-
"""Check if k2py is installed and attempt to install it if not."""
|
|
58
|
-
try:
|
|
59
|
-
import k2py
|
|
60
|
-
except ImportError:
|
|
61
|
-
import subprocess
|
|
62
|
-
|
|
63
|
-
print("k2py is not installed. Attempting to install k2py...")
|
|
64
|
-
try:
|
|
65
|
-
subprocess.check_call([sys.executable, "-m", "pip", "install", "k2py"])
|
|
66
|
-
import k2py # Try importing again after installation
|
|
67
|
-
|
|
68
|
-
print("k2py installed successfully.")
|
|
69
|
-
except Exception as e:
|
|
70
|
-
warnings.warn(f"Failed to install k2py automatically. Please install it manually. Error: {e}")
|
|
71
|
-
return True
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
# Auto-install k2py on first import
|
|
75
|
-
_check_and_install_k2py()
|
|
76
|
-
|
|
77
|
-
|
|
78
54
|
__all__ = [
|
|
79
55
|
# Client classes
|
|
80
56
|
"LattifAI",
|
lattifai/alignment/__init__.py
CHANGED
|
@@ -2,5 +2,14 @@
|
|
|
2
2
|
|
|
3
3
|
from .lattice1_aligner import Lattice1Aligner
|
|
4
4
|
from .segmenter import Segmenter
|
|
5
|
+
from .sentence_splitter import SentenceSplitter
|
|
6
|
+
from .text_align import align_supervisions_and_transcription
|
|
7
|
+
from .tokenizer import tokenize_multilingual_text
|
|
5
8
|
|
|
6
|
-
__all__ = [
|
|
9
|
+
__all__ = [
|
|
10
|
+
"Lattice1Aligner",
|
|
11
|
+
"Segmenter",
|
|
12
|
+
"SentenceSplitter",
|
|
13
|
+
"align_supervisions_and_transcription",
|
|
14
|
+
"tokenize_multilingual_text",
|
|
15
|
+
]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Lattice-1 Aligner implementation."""
|
|
2
2
|
|
|
3
|
-
from typing import Any, List, Optional, Tuple
|
|
3
|
+
from typing import Any, List, Optional, Tuple, Union
|
|
4
4
|
|
|
5
5
|
import colorful
|
|
6
6
|
import numpy as np
|
|
@@ -16,11 +16,22 @@ from lattifai.errors import (
|
|
|
16
16
|
from lattifai.utils import _resolve_model_path, safe_print
|
|
17
17
|
|
|
18
18
|
from .lattice1_worker import _load_worker
|
|
19
|
+
from .text_align import TextAlignResult
|
|
19
20
|
from .tokenizer import _load_tokenizer
|
|
20
21
|
|
|
21
22
|
ClientType = Any
|
|
22
23
|
|
|
23
24
|
|
|
25
|
+
def _extract_text_for_error(supervisions: Union[list, tuple]) -> str:
|
|
26
|
+
"""Extract text from supervisions for error messages."""
|
|
27
|
+
if not supervisions:
|
|
28
|
+
return ""
|
|
29
|
+
# TextAlignResult is a tuple: (caption_sups, transcript_sups, ...)
|
|
30
|
+
if isinstance(supervisions, tuple):
|
|
31
|
+
supervisions = supervisions[0] or supervisions[1] or []
|
|
32
|
+
return " ".join(s.text for s in supervisions if s and s.text)
|
|
33
|
+
|
|
34
|
+
|
|
24
35
|
class Lattice1Aligner(object):
|
|
25
36
|
"""Synchronous LattifAI client with config-driven architecture."""
|
|
26
37
|
|
|
@@ -79,7 +90,7 @@ class Lattice1Aligner(object):
|
|
|
79
90
|
def alignment(
|
|
80
91
|
self,
|
|
81
92
|
audio: AudioData,
|
|
82
|
-
supervisions: List[Supervision],
|
|
93
|
+
supervisions: Union[List[Supervision], TextAlignResult],
|
|
83
94
|
split_sentence: Optional[bool] = False,
|
|
84
95
|
return_details: Optional[bool] = False,
|
|
85
96
|
emission: Optional[np.ndarray] = None,
|
|
@@ -102,69 +113,66 @@ class Lattice1Aligner(object):
|
|
|
102
113
|
AlignmentError: If audio alignment fails
|
|
103
114
|
LatticeDecodingError: If lattice decoding fails
|
|
104
115
|
"""
|
|
116
|
+
# Step 2: Create lattice graph
|
|
117
|
+
if verbose:
|
|
118
|
+
safe_print(colorful.cyan("🔗 Step 2: Creating lattice graph from segments"))
|
|
105
119
|
try:
|
|
120
|
+
supervisions, lattice_id, lattice_graph = self.tokenizer.tokenize(
|
|
121
|
+
supervisions, split_sentence=split_sentence, boost=self.config.boost
|
|
122
|
+
)
|
|
106
123
|
if verbose:
|
|
107
|
-
safe_print(colorful.
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
safe_print(colorful.cyan(f"🔍 Step 3: Searching lattice graph with media: {audio}"))
|
|
120
|
-
if audio.streaming_chunk_secs:
|
|
121
|
-
safe_print(
|
|
122
|
-
colorful.yellow(
|
|
123
|
-
f" ⚡Using streaming mode with {audio.streaming_chunk_secs}s (chunk duration)"
|
|
124
|
-
)
|
|
124
|
+
safe_print(colorful.green(f" ✓ Generated lattice graph with ID: {lattice_id}"))
|
|
125
|
+
except Exception as e:
|
|
126
|
+
text_content = _extract_text_for_error(supervisions)
|
|
127
|
+
raise LatticeEncodingError(text_content, original_error=e)
|
|
128
|
+
|
|
129
|
+
# Step 3: Search lattice graph
|
|
130
|
+
if verbose:
|
|
131
|
+
safe_print(colorful.cyan(f"🔍 Step 3: Searching lattice graph with media: {audio}"))
|
|
132
|
+
if audio.streaming_mode:
|
|
133
|
+
safe_print(
|
|
134
|
+
colorful.yellow(
|
|
135
|
+
f" ⚡Using streaming mode with {audio.streaming_chunk_secs}s (chunk duration)"
|
|
125
136
|
)
|
|
126
|
-
try:
|
|
127
|
-
lattice_results = self.worker.alignment(
|
|
128
|
-
audio,
|
|
129
|
-
lattice_graph,
|
|
130
|
-
emission=emission,
|
|
131
|
-
offset=offset,
|
|
132
|
-
)
|
|
133
|
-
if verbose:
|
|
134
|
-
safe_print(colorful.green(" ✓ Lattice search completed"))
|
|
135
|
-
except Exception as e:
|
|
136
|
-
raise AlignmentError(
|
|
137
|
-
f"Audio alignment failed for {audio}",
|
|
138
|
-
media_path=str(audio),
|
|
139
|
-
context={"original_error": str(e)},
|
|
140
137
|
)
|
|
141
|
-
|
|
138
|
+
try:
|
|
139
|
+
lattice_results = self.worker.alignment(
|
|
140
|
+
audio,
|
|
141
|
+
lattice_graph,
|
|
142
|
+
emission=emission,
|
|
143
|
+
offset=offset,
|
|
144
|
+
)
|
|
142
145
|
if verbose:
|
|
143
|
-
safe_print(colorful.
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
146
|
+
safe_print(colorful.green(" ✓ Lattice search completed"))
|
|
147
|
+
except Exception as e:
|
|
148
|
+
raise AlignmentError(
|
|
149
|
+
f"Audio alignment failed for {audio}",
|
|
150
|
+
media_path=str(audio),
|
|
151
|
+
context={"original_error": str(e)},
|
|
152
|
+
)
|
|
153
|
+
|
|
154
|
+
# Step 4: Decode lattice results
|
|
155
|
+
if verbose:
|
|
156
|
+
safe_print(colorful.cyan("🎯 Step 4: Decoding lattice results to aligned segments"))
|
|
157
|
+
try:
|
|
158
|
+
alignments = self.tokenizer.detokenize(
|
|
159
|
+
lattice_id,
|
|
160
|
+
lattice_results,
|
|
161
|
+
supervisions=supervisions,
|
|
162
|
+
return_details=return_details,
|
|
163
|
+
start_margin=self.config.start_margin,
|
|
164
|
+
end_margin=self.config.end_margin,
|
|
165
|
+
)
|
|
166
|
+
if verbose:
|
|
167
|
+
safe_print(colorful.green(f" ✓ Successfully aligned {len(alignments)} segments"))
|
|
168
|
+
except LatticeDecodingError:
|
|
169
|
+
safe_print(colorful.red(" x Failed to decode lattice alignment results"))
|
|
165
170
|
raise
|
|
166
171
|
except Exception as e:
|
|
167
|
-
|
|
172
|
+
safe_print(colorful.red(" x Failed to decode lattice alignment results"))
|
|
173
|
+
raise LatticeDecodingError(lattice_id, original_error=e)
|
|
174
|
+
|
|
175
|
+
return (supervisions, alignments)
|
|
168
176
|
|
|
169
177
|
def profile(self) -> None:
|
|
170
178
|
"""Print profiling statistics."""
|
|
@@ -7,8 +7,6 @@ from typing import Any, Dict, Optional, Tuple
|
|
|
7
7
|
import colorful
|
|
8
8
|
import numpy as np
|
|
9
9
|
import onnxruntime as ort
|
|
10
|
-
from lhotse import FbankConfig
|
|
11
|
-
from lhotse.features.kaldi.layers import Wav2LogFilterBank
|
|
12
10
|
from lhotse.utils import Pathlike
|
|
13
11
|
from tqdm import tqdm
|
|
14
12
|
|
|
@@ -159,10 +157,7 @@ class Lattice1Worker:
|
|
|
159
157
|
DependencyError: If required dependencies are missing
|
|
160
158
|
AlignmentError: If alignment process fails
|
|
161
159
|
"""
|
|
162
|
-
|
|
163
|
-
import k2py as k2
|
|
164
|
-
except ImportError:
|
|
165
|
-
raise DependencyError("k2py", install_command="pip install k2py")
|
|
160
|
+
import k2py as k2
|
|
166
161
|
|
|
167
162
|
lattice_graph_str, final_state, acoustic_scale = lattice_graph
|
|
168
163
|
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# Multilingual punctuation characters (no duplicates)
|
|
2
|
+
PUNCTUATION = (
|
|
3
|
+
# ASCII punctuation
|
|
4
|
+
"!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
|
|
5
|
+
# Chinese/CJK punctuation
|
|
6
|
+
"。,、?!:;·¥…—~“”‘’"
|
|
7
|
+
"《》〈〉【】〔〕〖〗()" # Brackets
|
|
8
|
+
# Japanese punctuation (unique only)
|
|
9
|
+
"「」『』・"
|
|
10
|
+
# CJK punctuation (unique only)
|
|
11
|
+
"〃〆〇〒〓〘〙〚〛〜〝〞〟"
|
|
12
|
+
# Arabic punctuation
|
|
13
|
+
"،؛؟"
|
|
14
|
+
# Thai punctuation
|
|
15
|
+
"๏๚๛"
|
|
16
|
+
# Hebrew punctuation
|
|
17
|
+
"־׀׃׆"
|
|
18
|
+
# Other common punctuation
|
|
19
|
+
"¡¿" # Spanish inverted marks
|
|
20
|
+
"«»‹›" # Guillemets
|
|
21
|
+
"‐‑‒–―" # Dashes (excluding — already above)
|
|
22
|
+
"‚„" # Low quotation marks
|
|
23
|
+
"†‡•‣" # Daggers and bullets
|
|
24
|
+
"′″‴" # Prime marks
|
|
25
|
+
"‰‱" # Per mille
|
|
26
|
+
)
|
|
27
|
+
PUNCTUATION_SPACE = PUNCTUATION + " "
|
|
28
|
+
STAR_TOKEN = "※"
|
|
29
|
+
|
|
30
|
+
# End of sentence punctuation marks (multilingual)
|
|
31
|
+
# - ASCII: .!?"'])
|
|
32
|
+
# - Chinese/CJK: 。!?"】」』〗〙〛 (including right double quote U+201D)
|
|
33
|
+
# - Japanese: 。 (halfwidth period)
|
|
34
|
+
# - Arabic: ؟
|
|
35
|
+
# - Ellipsis: …
|
|
36
|
+
END_PUNCTUATION = ".!?\"'])。!?\u201d】」』〗〙〛。؟…"
|
|
37
|
+
|
|
38
|
+
GROUPING_SEPARATOR = "✹"
|
lattifai/alignment/segmenter.py
CHANGED
|
@@ -0,0 +1,350 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from typing import List, Optional
|
|
3
|
+
|
|
4
|
+
from lattifai.alignment.punctuation import END_PUNCTUATION
|
|
5
|
+
from lattifai.caption import Supervision
|
|
6
|
+
from lattifai.utils import _resolve_model_path
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class SentenceSplitter:
|
|
10
|
+
"""Lazy-initialized sentence splitter using wtpsplit."""
|
|
11
|
+
|
|
12
|
+
def __init__(self, device: str = "cpu", model_hub: Optional[str] = "modelscope", lazy_init: bool = True):
|
|
13
|
+
"""Initialize sentence splitter with lazy loading.
|
|
14
|
+
|
|
15
|
+
Args:
|
|
16
|
+
device: Device to run the model on (cpu, cuda, mps)
|
|
17
|
+
model_hub: Model hub to use (None for huggingface, "modelscope" for modelscope)
|
|
18
|
+
"""
|
|
19
|
+
self.device = device
|
|
20
|
+
self.model_hub = model_hub
|
|
21
|
+
self._splitter = None
|
|
22
|
+
if not lazy_init:
|
|
23
|
+
self._init_splitter()
|
|
24
|
+
|
|
25
|
+
def _init_splitter(self):
|
|
26
|
+
"""Initialize the sentence splitter model on first use."""
|
|
27
|
+
if self._splitter is not None:
|
|
28
|
+
return
|
|
29
|
+
|
|
30
|
+
import onnxruntime as ort
|
|
31
|
+
from wtpsplit import SaT
|
|
32
|
+
|
|
33
|
+
providers = []
|
|
34
|
+
device = self.device
|
|
35
|
+
if device.startswith("cuda") and ort.get_all_providers().count("CUDAExecutionProvider") > 0:
|
|
36
|
+
providers.append("CUDAExecutionProvider")
|
|
37
|
+
elif device.startswith("mps") and ort.get_all_providers().count("MPSExecutionProvider") > 0:
|
|
38
|
+
providers.append("MPSExecutionProvider")
|
|
39
|
+
|
|
40
|
+
if self.model_hub == "modelscope":
|
|
41
|
+
downloaded_path = _resolve_model_path("LattifAI/OmniTokenizer", model_hub="modelscope")
|
|
42
|
+
sat = SaT(
|
|
43
|
+
f"{downloaded_path}/sat-3l-sm",
|
|
44
|
+
tokenizer_name_or_path=f"{downloaded_path}/xlm-roberta-base",
|
|
45
|
+
ort_providers=providers + ["CPUExecutionProvider"],
|
|
46
|
+
)
|
|
47
|
+
else:
|
|
48
|
+
sat_path = _resolve_model_path("segment-any-text/sat-3l-sm", model_hub="huggingface")
|
|
49
|
+
sat = SaT(
|
|
50
|
+
sat_path,
|
|
51
|
+
tokenizer_name_or_path="facebookAI/xlm-roberta-base",
|
|
52
|
+
hub_prefix="segment-any-text",
|
|
53
|
+
ort_providers=providers + ["CPUExecutionProvider"],
|
|
54
|
+
)
|
|
55
|
+
self._splitter = sat
|
|
56
|
+
|
|
57
|
+
@staticmethod
|
|
58
|
+
def _distribute_time_info(
|
|
59
|
+
input_supervisions: List[Supervision],
|
|
60
|
+
split_texts: List[str],
|
|
61
|
+
) -> List[Supervision]:
|
|
62
|
+
"""Distribute time information from input supervisions to split sentences.
|
|
63
|
+
|
|
64
|
+
Args:
|
|
65
|
+
input_supervisions: Original supervisions with time information
|
|
66
|
+
split_texts: List of split sentence texts
|
|
67
|
+
|
|
68
|
+
Returns:
|
|
69
|
+
List of Supervision objects with distributed time information.
|
|
70
|
+
Custom attributes are inherited from first_sup with conflict markers.
|
|
71
|
+
"""
|
|
72
|
+
if not input_supervisions:
|
|
73
|
+
return [Supervision(text=text, id="", recording_id="", start=0, duration=0) for text in split_texts]
|
|
74
|
+
|
|
75
|
+
# Build concatenated input text
|
|
76
|
+
input_text = " ".join(sup.text for sup in input_supervisions)
|
|
77
|
+
|
|
78
|
+
# Pre-compute supervision position mapping for O(1) lookup
|
|
79
|
+
# Format: [(start_pos, end_pos, supervision), ...]
|
|
80
|
+
sup_ranges = []
|
|
81
|
+
char_pos = 0
|
|
82
|
+
for sup in input_supervisions:
|
|
83
|
+
sup_start = char_pos
|
|
84
|
+
sup_end = char_pos + len(sup.text)
|
|
85
|
+
sup_ranges.append((sup_start, sup_end, sup))
|
|
86
|
+
char_pos = sup_end + 1 # +1 for space separator
|
|
87
|
+
|
|
88
|
+
# Process each split text
|
|
89
|
+
result = []
|
|
90
|
+
search_start = 0
|
|
91
|
+
sup_idx = 0 # Track current supervision index to skip processed ones
|
|
92
|
+
|
|
93
|
+
for split_text in split_texts:
|
|
94
|
+
text_start = input_text.find(split_text, search_start)
|
|
95
|
+
if text_start == -1:
|
|
96
|
+
raise ValueError(f"Could not find split text '{split_text}' in input supervisions.")
|
|
97
|
+
|
|
98
|
+
text_end = text_start + len(split_text)
|
|
99
|
+
search_start = text_end
|
|
100
|
+
|
|
101
|
+
# Find overlapping supervisions, starting from last used index
|
|
102
|
+
first_sup = None
|
|
103
|
+
last_sup = None
|
|
104
|
+
first_char_idx = None
|
|
105
|
+
last_char_idx = None
|
|
106
|
+
overlapping_customs = [] # Track all custom dicts for conflict detection
|
|
107
|
+
|
|
108
|
+
# Start from sup_idx, which is the first supervision that might overlap
|
|
109
|
+
for i in range(sup_idx, len(sup_ranges)):
|
|
110
|
+
sup_start, sup_end, sup = sup_ranges[i]
|
|
111
|
+
|
|
112
|
+
# Skip if no overlap (before text_start)
|
|
113
|
+
if sup_end <= text_start:
|
|
114
|
+
sup_idx = i + 1 # Update starting point for next iteration
|
|
115
|
+
continue
|
|
116
|
+
|
|
117
|
+
# Stop if no overlap (after text_end)
|
|
118
|
+
if sup_start >= text_end:
|
|
119
|
+
break
|
|
120
|
+
|
|
121
|
+
# Found overlap
|
|
122
|
+
if first_sup is None:
|
|
123
|
+
first_sup = sup
|
|
124
|
+
first_char_idx = max(0, text_start - sup_start)
|
|
125
|
+
|
|
126
|
+
last_sup = sup
|
|
127
|
+
last_char_idx = min(len(sup.text) - 1, text_end - 1 - sup_start)
|
|
128
|
+
|
|
129
|
+
# Collect custom dict for conflict detection
|
|
130
|
+
if getattr(sup, "custom", None):
|
|
131
|
+
overlapping_customs.append(sup.custom)
|
|
132
|
+
|
|
133
|
+
if first_sup is None or last_sup is None:
|
|
134
|
+
raise ValueError(f"Could not find supervisions for split text: {split_text}")
|
|
135
|
+
|
|
136
|
+
# Calculate timing
|
|
137
|
+
start_time = first_sup.start + (first_char_idx / len(first_sup.text)) * first_sup.duration
|
|
138
|
+
end_time = last_sup.start + ((last_char_idx + 1) / len(last_sup.text)) * last_sup.duration
|
|
139
|
+
|
|
140
|
+
# Inherit custom from first_sup, mark conflicts if multiple sources
|
|
141
|
+
merged_custom = None
|
|
142
|
+
if overlapping_customs:
|
|
143
|
+
# Start with first_sup's custom (inherit strategy)
|
|
144
|
+
merged_custom = overlapping_customs[0].copy() if overlapping_customs[0] else {}
|
|
145
|
+
|
|
146
|
+
# Detect conflicts if multiple overlapping supervisions have different custom values
|
|
147
|
+
if len(overlapping_customs) > 1:
|
|
148
|
+
has_conflict = False
|
|
149
|
+
for other_custom in overlapping_customs[1:]:
|
|
150
|
+
if other_custom and other_custom != overlapping_customs[0]:
|
|
151
|
+
has_conflict = True
|
|
152
|
+
break
|
|
153
|
+
|
|
154
|
+
if has_conflict:
|
|
155
|
+
# Mark that this supervision spans multiple sources with different customs
|
|
156
|
+
merged_custom["_split_from_multiple"] = True
|
|
157
|
+
merged_custom["_source_count"] = len(overlapping_customs)
|
|
158
|
+
|
|
159
|
+
result.append(
|
|
160
|
+
Supervision(
|
|
161
|
+
id="",
|
|
162
|
+
text=split_text,
|
|
163
|
+
start=start_time,
|
|
164
|
+
duration=end_time - start_time,
|
|
165
|
+
recording_id=first_sup.recording_id,
|
|
166
|
+
custom=merged_custom,
|
|
167
|
+
)
|
|
168
|
+
)
|
|
169
|
+
|
|
170
|
+
return result
|
|
171
|
+
|
|
172
|
+
@staticmethod
|
|
173
|
+
def _resplit_special_sentence_types(sentence: str) -> List[str]:
|
|
174
|
+
"""
|
|
175
|
+
Re-split special sentence types.
|
|
176
|
+
|
|
177
|
+
Examples:
|
|
178
|
+
'[APPLAUSE] >> MIRA MURATI:' -> ['[APPLAUSE]', '>> MIRA MURATI:']
|
|
179
|
+
'[MUSIC] >> SPEAKER:' -> ['[MUSIC]', '>> SPEAKER:']
|
|
180
|
+
|
|
181
|
+
Special handling patterns:
|
|
182
|
+
1. Separate special marks at the beginning (e.g., [APPLAUSE], [MUSIC], etc.) from subsequent speaker marks
|
|
183
|
+
2. Use speaker marks (>> or other separators) as split points
|
|
184
|
+
|
|
185
|
+
Args:
|
|
186
|
+
sentence: Input sentence string
|
|
187
|
+
|
|
188
|
+
Returns:
|
|
189
|
+
List of re-split sentences. If no special marks are found, returns the original sentence in a list
|
|
190
|
+
"""
|
|
191
|
+
# Detect special mark patterns: [SOMETHING] >> SPEAKER:
|
|
192
|
+
# or other forms like [SOMETHING] SPEAKER:
|
|
193
|
+
|
|
194
|
+
# Pattern 1: [mark] HTML-encoded separator speaker:
|
|
195
|
+
pattern1 = r"^(\[[^\]]+\])\s+(>>|>>)\s+(.+)$"
|
|
196
|
+
match1 = re.match(pattern1, sentence.strip())
|
|
197
|
+
if match1:
|
|
198
|
+
special_mark = match1.group(1)
|
|
199
|
+
separator = match1.group(2)
|
|
200
|
+
speaker_part = match1.group(3)
|
|
201
|
+
return [special_mark, f"{separator} {speaker_part}"]
|
|
202
|
+
|
|
203
|
+
# Pattern 2: [mark] speaker:
|
|
204
|
+
pattern2 = r"^(\[[^\]]+\])\s+([^:]+:)(.*)$"
|
|
205
|
+
match2 = re.match(pattern2, sentence.strip())
|
|
206
|
+
if match2:
|
|
207
|
+
special_mark = match2.group(1)
|
|
208
|
+
speaker_label = match2.group(2)
|
|
209
|
+
remaining = match2.group(3).strip()
|
|
210
|
+
if remaining:
|
|
211
|
+
return [special_mark, f"{speaker_label} {remaining}"]
|
|
212
|
+
else:
|
|
213
|
+
return [special_mark, speaker_label]
|
|
214
|
+
|
|
215
|
+
# If no special pattern matches, return the original sentence
|
|
216
|
+
return [sentence]
|
|
217
|
+
|
|
218
|
+
def split_sentences(self, supervisions: List[Supervision], strip_whitespace=True) -> List[Supervision]:
|
|
219
|
+
"""Split supervisions into sentences using the sentence splitter.
|
|
220
|
+
|
|
221
|
+
Careful about speaker changes.
|
|
222
|
+
|
|
223
|
+
Args:
|
|
224
|
+
supervisions: List of Supervision objects to split
|
|
225
|
+
strip_whitespace: Whether to strip whitespace from split sentences
|
|
226
|
+
|
|
227
|
+
Returns:
|
|
228
|
+
List of Supervision objects with split sentences
|
|
229
|
+
"""
|
|
230
|
+
self._init_splitter()
|
|
231
|
+
|
|
232
|
+
texts, speakers = [], []
|
|
233
|
+
text_len, sidx = 0, 0
|
|
234
|
+
|
|
235
|
+
def flush_segment(end_idx: int, speaker: Optional[str] = None):
|
|
236
|
+
"""Flush accumulated text from sidx to end_idx with given speaker."""
|
|
237
|
+
nonlocal text_len, sidx
|
|
238
|
+
if sidx <= end_idx:
|
|
239
|
+
if len(speakers) < len(texts) + 1:
|
|
240
|
+
speakers.append(speaker)
|
|
241
|
+
text = " ".join(sup.text for sup in supervisions[sidx : end_idx + 1])
|
|
242
|
+
texts.append(text)
|
|
243
|
+
sidx = end_idx + 1
|
|
244
|
+
text_len = 0
|
|
245
|
+
|
|
246
|
+
for s, supervision in enumerate(supervisions):
|
|
247
|
+
text_len += len(supervision.text)
|
|
248
|
+
is_last = s == len(supervisions) - 1
|
|
249
|
+
|
|
250
|
+
if supervision.speaker:
|
|
251
|
+
# Flush previous segment without speaker (if any)
|
|
252
|
+
if sidx < s:
|
|
253
|
+
flush_segment(s - 1, None)
|
|
254
|
+
text_len = len(supervision.text)
|
|
255
|
+
|
|
256
|
+
# Check if we should flush this speaker's segment now
|
|
257
|
+
next_has_speaker = not is_last and supervisions[s + 1].speaker
|
|
258
|
+
if is_last or next_has_speaker:
|
|
259
|
+
flush_segment(s, supervision.speaker)
|
|
260
|
+
else:
|
|
261
|
+
speakers.append(supervision.speaker)
|
|
262
|
+
|
|
263
|
+
elif text_len >= 2000 or is_last:
|
|
264
|
+
flush_segment(s, None)
|
|
265
|
+
|
|
266
|
+
if len(speakers) != len(texts):
|
|
267
|
+
raise ValueError(f"len(speakers)={len(speakers)} != len(texts)={len(texts)}")
|
|
268
|
+
sentences = self._splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace, batch_size=8)
|
|
269
|
+
|
|
270
|
+
# First pass: collect all split texts with their speakers
|
|
271
|
+
split_texts_with_speakers = []
|
|
272
|
+
remainder = ""
|
|
273
|
+
remainder_speaker = None
|
|
274
|
+
|
|
275
|
+
for k, (_speaker, _sentences) in enumerate(zip(speakers, sentences)):
|
|
276
|
+
# Prepend remainder from previous iteration to the first sentence
|
|
277
|
+
if _sentences and remainder:
|
|
278
|
+
_sentences[0] = remainder + _sentences[0]
|
|
279
|
+
_speaker = remainder_speaker if remainder_speaker else _speaker
|
|
280
|
+
remainder = ""
|
|
281
|
+
remainder_speaker = None
|
|
282
|
+
|
|
283
|
+
if not _sentences:
|
|
284
|
+
continue
|
|
285
|
+
|
|
286
|
+
# Process and re-split special sentence types
|
|
287
|
+
processed_sentences = []
|
|
288
|
+
for s, _sentence in enumerate(_sentences):
|
|
289
|
+
if remainder:
|
|
290
|
+
_sentence = remainder + _sentence
|
|
291
|
+
remainder = ""
|
|
292
|
+
# Detect and split special sentence types: e.g., '[APPLAUSE] >> MIRA MURATI:' -> ['[APPLAUSE]', '>> MIRA MURATI:'] # noqa: E501
|
|
293
|
+
resplit_parts = self._resplit_special_sentence_types(_sentence)
|
|
294
|
+
if any(resplit_parts[-1].endswith(sp) for sp in [":", ":"]):
|
|
295
|
+
if s < len(_sentences) - 1:
|
|
296
|
+
_sentences[s + 1] = resplit_parts[-1] + " " + _sentences[s + 1]
|
|
297
|
+
else: # last part
|
|
298
|
+
remainder = resplit_parts[-1] + " "
|
|
299
|
+
processed_sentences.extend(resplit_parts[:-1])
|
|
300
|
+
else:
|
|
301
|
+
processed_sentences.extend(resplit_parts)
|
|
302
|
+
_sentences = processed_sentences
|
|
303
|
+
|
|
304
|
+
if not _sentences:
|
|
305
|
+
if remainder:
|
|
306
|
+
_sentences, remainder = [remainder.strip()], ""
|
|
307
|
+
else:
|
|
308
|
+
continue
|
|
309
|
+
|
|
310
|
+
if any(_sentences[-1].endswith(ep) for ep in END_PUNCTUATION):
|
|
311
|
+
split_texts_with_speakers.extend(
|
|
312
|
+
(text, _speaker if s == 0 else None) for s, text in enumerate(_sentences)
|
|
313
|
+
)
|
|
314
|
+
_speaker = None # reset speaker after use
|
|
315
|
+
else:
|
|
316
|
+
split_texts_with_speakers.extend(
|
|
317
|
+
(text, _speaker if s == 0 else None) for s, text in enumerate(_sentences[:-1])
|
|
318
|
+
)
|
|
319
|
+
remainder = _sentences[-1] + " " + remainder
|
|
320
|
+
if k < len(speakers) - 1 and speakers[k + 1] is not None: # next speaker is set
|
|
321
|
+
split_texts_with_speakers.append((remainder.strip(), _speaker if len(_sentences) == 1 else None))
|
|
322
|
+
remainder = ""
|
|
323
|
+
remainder_speaker = None
|
|
324
|
+
elif len(_sentences) == 1:
|
|
325
|
+
remainder_speaker = _speaker
|
|
326
|
+
if k == len(speakers) - 1:
|
|
327
|
+
pass # keep _speaker for the last supervision
|
|
328
|
+
elif speakers[k + 1] is not None:
|
|
329
|
+
raise ValueError(f"Expected speakers[{k + 1}] to be None, got {speakers[k + 1]}")
|
|
330
|
+
else:
|
|
331
|
+
speakers[k + 1] = _speaker
|
|
332
|
+
elif len(_sentences) > 1:
|
|
333
|
+
_speaker = None # reset speaker if sentence not ended
|
|
334
|
+
remainder_speaker = None
|
|
335
|
+
else:
|
|
336
|
+
raise ValueError(f"Unexpected state: len(_sentences)={len(_sentences)}")
|
|
337
|
+
|
|
338
|
+
if remainder.strip():
|
|
339
|
+
split_texts_with_speakers.append((remainder.strip(), remainder_speaker))
|
|
340
|
+
|
|
341
|
+
# Second pass: distribute time information
|
|
342
|
+
split_texts = [text for text, _ in split_texts_with_speakers]
|
|
343
|
+
result_supervisions = self._distribute_time_info(supervisions, split_texts)
|
|
344
|
+
|
|
345
|
+
# Third pass: add speaker information
|
|
346
|
+
for sup, (_, speaker) in zip(result_supervisions, split_texts_with_speakers):
|
|
347
|
+
if speaker:
|
|
348
|
+
sup.speaker = speaker
|
|
349
|
+
|
|
350
|
+
return result_supervisions
|