lattifai 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/alignment/__init__.py +10 -1
- lattifai/alignment/lattice1_aligner.py +66 -58
- lattifai/alignment/punctuation.py +38 -0
- lattifai/alignment/sentence_splitter.py +152 -21
- lattifai/alignment/text_align.py +440 -0
- lattifai/alignment/tokenizer.py +82 -40
- lattifai/caption/__init__.py +82 -6
- lattifai/caption/caption.py +335 -1141
- lattifai/caption/formats/__init__.py +199 -0
- lattifai/caption/formats/base.py +211 -0
- lattifai/caption/{gemini_reader.py → formats/gemini.py} +320 -60
- lattifai/caption/formats/json.py +194 -0
- lattifai/caption/formats/lrc.py +309 -0
- lattifai/caption/formats/nle/__init__.py +9 -0
- lattifai/caption/formats/nle/audition.py +561 -0
- lattifai/caption/formats/nle/avid.py +423 -0
- lattifai/caption/formats/nle/fcpxml.py +549 -0
- lattifai/caption/formats/nle/premiere.py +589 -0
- lattifai/caption/formats/pysubs2.py +642 -0
- lattifai/caption/formats/sbv.py +147 -0
- lattifai/caption/formats/tabular.py +338 -0
- lattifai/caption/formats/textgrid.py +193 -0
- lattifai/caption/formats/ttml.py +652 -0
- lattifai/caption/formats/vtt.py +469 -0
- lattifai/caption/parsers/__init__.py +9 -0
- lattifai/caption/{text_parser.py → parsers/text_parser.py} +4 -2
- lattifai/caption/standardize.py +636 -0
- lattifai/caption/utils.py +474 -0
- lattifai/cli/__init__.py +2 -1
- lattifai/cli/caption.py +108 -1
- lattifai/cli/transcribe.py +1 -1
- lattifai/cli/youtube.py +4 -1
- lattifai/client.py +33 -113
- lattifai/config/__init__.py +11 -1
- lattifai/config/alignment.py +7 -0
- lattifai/config/caption.py +267 -23
- lattifai/config/media.py +20 -0
- lattifai/diarization/__init__.py +41 -1
- lattifai/mixin.py +27 -15
- lattifai/transcription/base.py +6 -1
- lattifai/transcription/lattifai.py +19 -54
- lattifai/utils.py +7 -13
- lattifai/workflow/__init__.py +28 -4
- lattifai/workflow/file_manager.py +2 -5
- lattifai/youtube/__init__.py +43 -0
- lattifai/youtube/client.py +1170 -0
- lattifai/youtube/types.py +23 -0
- lattifai-1.2.2.dist-info/METADATA +615 -0
- lattifai-1.2.2.dist-info/RECORD +76 -0
- {lattifai-1.2.1.dist-info → lattifai-1.2.2.dist-info}/entry_points.txt +1 -2
- lattifai/caption/gemini_writer.py +0 -173
- lattifai/cli/app_installer.py +0 -142
- lattifai/cli/server.py +0 -44
- lattifai/server/app.py +0 -427
- lattifai/workflow/youtube.py +0 -577
- lattifai-1.2.1.dist-info/METADATA +0 -1134
- lattifai-1.2.1.dist-info/RECORD +0 -58
- {lattifai-1.2.1.dist-info → lattifai-1.2.2.dist-info}/WHEEL +0 -0
- {lattifai-1.2.1.dist-info → lattifai-1.2.2.dist-info}/licenses/LICENSE +0 -0
- {lattifai-1.2.1.dist-info → lattifai-1.2.2.dist-info}/top_level.txt +0 -0
lattifai/alignment/__init__.py
CHANGED
@@ -2,5 +2,14 @@
 
 from .lattice1_aligner import Lattice1Aligner
 from .segmenter import Segmenter
+from .sentence_splitter import SentenceSplitter
+from .text_align import align_supervisions_and_transcription
+from .tokenizer import tokenize_multilingual_text
 
-__all__ = [
+__all__ = [
+    "Lattice1Aligner",
+    "Segmenter",
+    "SentenceSplitter",
+    "align_supervisions_and_transcription",
+    "tokenize_multilingual_text",
+]
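
The package's public surface grows accordingly: sentence splitting, caption/transcript text alignment, and multilingual tokenization are now re-exported directly from lattifai.alignment. A minimal usage sketch of the new import surface (only the import names come from this diff; the constructor arguments are shown later in the diff and the tokenizer call is an illustrative assumption):

# Hypothetical usage of the names newly exported from lattifai.alignment.
from lattifai.alignment import (
    Lattice1Aligner,
    Segmenter,
    SentenceSplitter,
    align_supervisions_and_transcription,
    tokenize_multilingual_text,
)

# SentenceSplitter's signature is confirmed by this diff; the call below is assumed.
splitter = SentenceSplitter(device="cpu", model_hub="modelscope", lazy_init=True)
tokens = tokenize_multilingual_text("Hello 世界!")  # assumed example input/return
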
lattifai/alignment/lattice1_aligner.py
CHANGED

@@ -1,6 +1,6 @@
 """Lattice-1 Aligner implementation."""
 
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple, Union
 
 import colorful
 import numpy as np
@@ -16,11 +16,22 @@ from lattifai.errors import (
 from lattifai.utils import _resolve_model_path, safe_print
 
 from .lattice1_worker import _load_worker
+from .text_align import TextAlignResult
 from .tokenizer import _load_tokenizer
 
 ClientType = Any
 
 
+def _extract_text_for_error(supervisions: Union[list, tuple]) -> str:
+    """Extract text from supervisions for error messages."""
+    if not supervisions:
+        return ""
+    # TextAlignResult is a tuple: (caption_sups, transcript_sups, ...)
+    if isinstance(supervisions, tuple):
+        supervisions = supervisions[0] or supervisions[1] or []
+    return " ".join(s.text for s in supervisions if s and s.text)
+
+
 class Lattice1Aligner(object):
     """Synchronous LattifAI client with config-driven architecture."""
 
@@ -79,7 +90,7 @@ class Lattice1Aligner(object):
     def alignment(
         self,
         audio: AudioData,
-        supervisions: List[Supervision],
+        supervisions: Union[List[Supervision], TextAlignResult],
         split_sentence: Optional[bool] = False,
         return_details: Optional[bool] = False,
         emission: Optional[np.ndarray] = None,
@@ -102,69 +113,66 @@
             AlignmentError: If audio alignment fails
             LatticeDecodingError: If lattice decoding fails
         """
+        # Step 2: Create lattice graph
+        if verbose:
+            safe_print(colorful.cyan("🔗 Step 2: Creating lattice graph from segments"))
         try:
+            supervisions, lattice_id, lattice_graph = self.tokenizer.tokenize(
+                supervisions, split_sentence=split_sentence, boost=self.config.boost
+            )
             if verbose:
-                safe_print(colorful.
-
-
-
-
-
-
-
-
-
-
-
-                safe_print(colorful.cyan(f"🔍 Step 3: Searching lattice graph with media: {audio}"))
-                if audio.streaming_mode:
-                    safe_print(
-                        colorful.yellow(
-                            f" ⚡Using streaming mode with {audio.streaming_chunk_secs}s (chunk duration)"
-
+                safe_print(colorful.green(f" ✓ Generated lattice graph with ID: {lattice_id}"))
+        except Exception as e:
+            text_content = _extract_text_for_error(supervisions)
+            raise LatticeEncodingError(text_content, original_error=e)
+
+        # Step 3: Search lattice graph
+        if verbose:
+            safe_print(colorful.cyan(f"🔍 Step 3: Searching lattice graph with media: {audio}"))
+            if audio.streaming_mode:
+                safe_print(
+                    colorful.yellow(
+                        f" ⚡Using streaming mode with {audio.streaming_chunk_secs}s (chunk duration)"
                     )
-        try:
-            lattice_results = self.worker.alignment(
-                audio,
-                lattice_graph,
-                emission=emission,
-                offset=offset,
-            )
-            if verbose:
-                safe_print(colorful.green(" ✓ Lattice search completed"))
-        except Exception as e:
-            raise AlignmentError(
-                f"Audio alignment failed for {audio}",
-                media_path=str(audio),
-                context={"original_error": str(e)},
                 )
-
+        try:
+            lattice_results = self.worker.alignment(
+                audio,
+                lattice_graph,
+                emission=emission,
+                offset=offset,
+            )
             if verbose:
-                safe_print(colorful.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                safe_print(colorful.green(" ✓ Lattice search completed"))
+        except Exception as e:
+            raise AlignmentError(
+                f"Audio alignment failed for {audio}",
+                media_path=str(audio),
+                context={"original_error": str(e)},
+            )
+
+        # Step 4: Decode lattice results
+        if verbose:
+            safe_print(colorful.cyan("🎯 Step 4: Decoding lattice results to aligned segments"))
+        try:
+            alignments = self.tokenizer.detokenize(
+                lattice_id,
+                lattice_results,
+                supervisions=supervisions,
+                return_details=return_details,
+                start_margin=self.config.start_margin,
+                end_margin=self.config.end_margin,
+            )
+            if verbose:
+                safe_print(colorful.green(f" ✓ Successfully aligned {len(alignments)} segments"))
+        except LatticeDecodingError:
+            safe_print(colorful.red(" x Failed to decode lattice alignment results"))
             raise
         except Exception as e:
-
+            safe_print(colorful.red(" x Failed to decode lattice alignment results"))
+            raise LatticeDecodingError(lattice_id, original_error=e)
+
+        return (supervisions, alignments)
 
     def profile(self) -> None:
         """Print profiling statistics."""
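
The alignment() method is reorganized into explicit stages, each wrapped in its own error type: Step 2 tokenizes the supervisions into a lattice graph (LatticeEncodingError on failure), Step 3 searches the graph against the audio (AlignmentError), and Step 4 detokenizes the lattice results back into aligned segments (LatticeDecodingError), returning a (supervisions, alignments) tuple. Because the input may now be either a plain list of Supervision objects or a TextAlignResult tuple, error reporting goes through the new _extract_text_for_error helper. A self-contained sketch of that helper's behavior (SimpleNamespace stands in for the real Supervision objects):

from types import SimpleNamespace

def _extract_text_for_error(supervisions):
    """Extract text from supervisions for error messages (copied from the diff)."""
    if not supervisions:
        return ""
    # TextAlignResult is a tuple: (caption_sups, transcript_sups, ...)
    if isinstance(supervisions, tuple):
        supervisions = supervisions[0] or supervisions[1] or []
    return " ".join(s.text for s in supervisions if s and s.text)

caption_sups = [SimpleNamespace(text="Hello world."), SimpleNamespace(text="Bye.")]
print(_extract_text_for_error(caption_sups))              # -> "Hello world. Bye."
print(_extract_text_for_error((caption_sups, [], None)))  # tuple input: falls back to the caption side
print(_extract_text_for_error([]))                        # -> ""
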
lattifai/alignment/punctuation.py
ADDED

@@ -0,0 +1,38 @@
+# Multilingual punctuation characters (no duplicates)
+PUNCTUATION = (
+    # ASCII punctuation
+    "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
+    # Chinese/CJK punctuation
+    "。，、？！：；·¥…—~“”‘’"
+    "《》〈〉【】〔〕〖〗()"  # Brackets
+    # Japanese punctuation (unique only)
+    "「」『』・"
+    # CJK punctuation (unique only)
+    "〃〆〇〒〓〘〙〚〛〜〝〞〟"
+    # Arabic punctuation
+    "،؛؟"
+    # Thai punctuation
+    "๏๚๛"
+    # Hebrew punctuation
+    "־׀׃׆"
+    # Other common punctuation
+    "¡¿"  # Spanish inverted marks
+    "«»‹›"  # Guillemets
+    "‐‑‒–―"  # Dashes (excluding — already above)
+    "‚„"  # Low quotation marks
+    "†‡•‣"  # Daggers and bullets
+    "′″‴"  # Prime marks
+    "‰‱"  # Per mille
+)
+PUNCTUATION_SPACE = PUNCTUATION + " "
+STAR_TOKEN = "※"
+
+# End of sentence punctuation marks (multilingual)
+# - ASCII: .!?"'])
+# - Chinese/CJK: 。！？”】」』〗〙〛 (including right double quote U+201D)
+# - Japanese: 。 (halfwidth period)
+# - Arabic: ؟
+# - Ellipsis: …
+END_PUNCTUATION = ".!?\"'])。！？\u201d】」』〗〙〛。؟…"
+
+GROUPING_SEPARATOR = "✹"
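
This new module centralizes punctuation constants that were previously defined inline (sentence_splitter.py below drops its local END_PUNCTUATION in favor of this import). END_PUNCTUATION is consumed character by character, so membership is checked with endswith over each character. A small illustrative sketch (the constant here is abbreviated to a subset for readability; the real value is the full string above):

# Abbreviated subset of END_PUNCTUATION, for illustration only.
END_PUNCTUATION = ".!?。！？؟…"

def ends_sentence(text: str) -> bool:
    # Mirrors the `any(_sentences[-1].endswith(ep) for ep in END_PUNCTUATION)`
    # check used by SentenceSplitter in this release.
    return any(text.endswith(ep) for ep in END_PUNCTUATION)

print(ends_sentence("He left."))  # True
print(ends_sentence("他走了。"))   # True
print(ends_sentence("and then"))  # False
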
lattifai/alignment/sentence_splitter.py
CHANGED

@@ -1,16 +1,15 @@
 import re
 from typing import List, Optional
 
+from lattifai.alignment.punctuation import END_PUNCTUATION
 from lattifai.caption import Supervision
 from lattifai.utils import _resolve_model_path
 
-END_PUNCTUATION = '.!?"]。！？"】'
-
 
 class SentenceSplitter:
     """Lazy-initialized sentence splitter using wtpsplit."""
 
-    def __init__(self, device: str = "cpu", model_hub: Optional[str] =
+    def __init__(self, device: str = "cpu", model_hub: Optional[str] = "modelscope", lazy_init: bool = True):
         """Initialize sentence splitter with lazy loading.
 
         Args:
@@ -19,9 +18,8 @@ class SentenceSplitter:
         """
         self.device = device
        self.model_hub = model_hub
-
-
-        else:
+        self._splitter = None
+        if not lazy_init:
             self._init_splitter()
 
     def _init_splitter(self):
@@ -56,6 +54,121 @@ class SentenceSplitter:
         )
         self._splitter = sat
 
+    @staticmethod
+    def _distribute_time_info(
+        input_supervisions: List[Supervision],
+        split_texts: List[str],
+    ) -> List[Supervision]:
+        """Distribute time information from input supervisions to split sentences.
+
+        Args:
+            input_supervisions: Original supervisions with time information
+            split_texts: List of split sentence texts
+
+        Returns:
+            List of Supervision objects with distributed time information.
+            Custom attributes are inherited from first_sup with conflict markers.
+        """
+        if not input_supervisions:
+            return [Supervision(text=text, id="", recording_id="", start=0, duration=0) for text in split_texts]
+
+        # Build concatenated input text
+        input_text = " ".join(sup.text for sup in input_supervisions)
+
+        # Pre-compute supervision position mapping for O(1) lookup
+        # Format: [(start_pos, end_pos, supervision), ...]
+        sup_ranges = []
+        char_pos = 0
+        for sup in input_supervisions:
+            sup_start = char_pos
+            sup_end = char_pos + len(sup.text)
+            sup_ranges.append((sup_start, sup_end, sup))
+            char_pos = sup_end + 1  # +1 for space separator
+
+        # Process each split text
+        result = []
+        search_start = 0
+        sup_idx = 0  # Track current supervision index to skip processed ones
+
+        for split_text in split_texts:
+            text_start = input_text.find(split_text, search_start)
+            if text_start == -1:
+                raise ValueError(f"Could not find split text '{split_text}' in input supervisions.")
+
+            text_end = text_start + len(split_text)
+            search_start = text_end
+
+            # Find overlapping supervisions, starting from last used index
+            first_sup = None
+            last_sup = None
+            first_char_idx = None
+            last_char_idx = None
+            overlapping_customs = []  # Track all custom dicts for conflict detection
+
+            # Start from sup_idx, which is the first supervision that might overlap
+            for i in range(sup_idx, len(sup_ranges)):
+                sup_start, sup_end, sup = sup_ranges[i]
+
+                # Skip if no overlap (before text_start)
+                if sup_end <= text_start:
+                    sup_idx = i + 1  # Update starting point for next iteration
+                    continue
+
+                # Stop if no overlap (after text_end)
+                if sup_start >= text_end:
+                    break
+
+                # Found overlap
+                if first_sup is None:
+                    first_sup = sup
+                    first_char_idx = max(0, text_start - sup_start)
+
+                last_sup = sup
+                last_char_idx = min(len(sup.text) - 1, text_end - 1 - sup_start)
+
+                # Collect custom dict for conflict detection
+                if getattr(sup, "custom", None):
+                    overlapping_customs.append(sup.custom)
+
+            if first_sup is None or last_sup is None:
+                raise ValueError(f"Could not find supervisions for split text: {split_text}")
+
+            # Calculate timing
+            start_time = first_sup.start + (first_char_idx / len(first_sup.text)) * first_sup.duration
+            end_time = last_sup.start + ((last_char_idx + 1) / len(last_sup.text)) * last_sup.duration
+
+            # Inherit custom from first_sup, mark conflicts if multiple sources
+            merged_custom = None
+            if overlapping_customs:
+                # Start with first_sup's custom (inherit strategy)
+                merged_custom = overlapping_customs[0].copy() if overlapping_customs[0] else {}
+
+                # Detect conflicts if multiple overlapping supervisions have different custom values
+                if len(overlapping_customs) > 1:
+                    has_conflict = False
+                    for other_custom in overlapping_customs[1:]:
+                        if other_custom and other_custom != overlapping_customs[0]:
+                            has_conflict = True
+                            break
+
+                    if has_conflict:
+                        # Mark that this supervision spans multiple sources with different customs
+                        merged_custom["_split_from_multiple"] = True
+                        merged_custom["_source_count"] = len(overlapping_customs)
+
+            result.append(
+                Supervision(
+                    id="",
+                    text=split_text,
+                    start=start_time,
+                    duration=end_time - start_time,
+                    recording_id=first_sup.recording_id,
+                    custom=merged_custom,
+                )
+            )
+
+        return result
+
     @staticmethod
     def _resplit_special_sentence_types(sentence: str) -> List[str]:
         """
@@ -150,15 +263,22 @@
         elif text_len >= 2000 or is_last:
             flush_segment(s, None)
 
-
+        if len(speakers) != len(texts):
+            raise ValueError(f"len(speakers)={len(speakers)} != len(texts)={len(texts)}")
         sentences = self._splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace, batch_size=8)
 
-
+        # First pass: collect all split texts with their speakers
+        split_texts_with_speakers = []
+        remainder = ""
+        remainder_speaker = None
+
         for k, (_speaker, _sentences) in enumerate(zip(speakers, sentences)):
             # Prepend remainder from previous iteration to the first sentence
             if _sentences and remainder:
                 _sentences[0] = remainder + _sentences[0]
+                _speaker = remainder_speaker if remainder_speaker else _speaker
                 remainder = ""
+                remainder_speaker = None
 
             if not _sentences:
                 continue
@@ -188,32 +308,43 @@
                 continue
 
             if any(_sentences[-1].endswith(ep) for ep in END_PUNCTUATION):
-
-
+                split_texts_with_speakers.extend(
+                    (text, _speaker if s == 0 else None) for s, text in enumerate(_sentences)
                 )
                 _speaker = None  # reset speaker after use
             else:
-
-
-                    for s, text in enumerate(_sentences[:-1])
+                split_texts_with_speakers.extend(
+                    (text, _speaker if s == 0 else None) for s, text in enumerate(_sentences[:-1])
                 )
                 remainder = _sentences[-1] + " " + remainder
                 if k < len(speakers) - 1 and speakers[k + 1] is not None:  # next speaker is set
-
-                        Supervision(text=remainder.strip(), speaker=_speaker if len(_sentences) == 1 else None)
-                    )
+                    split_texts_with_speakers.append((remainder.strip(), _speaker if len(_sentences) == 1 else None))
                     remainder = ""
+                    remainder_speaker = None
                 elif len(_sentences) == 1:
+                    remainder_speaker = _speaker
                     if k == len(speakers) - 1:
                         pass  # keep _speaker for the last supervision
+                    elif speakers[k + 1] is not None:
+                        raise ValueError(f"Expected speakers[{k + 1}] to be None, got {speakers[k + 1]}")
                     else:
-                        assert speakers[k + 1] is None
                        speakers[k + 1] = _speaker
-
-                    assert len(_sentences) > 1
+                elif len(_sentences) > 1:
                    _speaker = None  # reset speaker if sentence not ended
+                    remainder_speaker = None
+                else:
+                    raise ValueError(f"Unexpected state: len(_sentences)={len(_sentences)}")
 
         if remainder.strip():
-
+            split_texts_with_speakers.append((remainder.strip(), remainder_speaker))
+
+        # Second pass: distribute time information
+        split_texts = [text for text, _ in split_texts_with_speakers]
+        result_supervisions = self._distribute_time_info(supervisions, split_texts)
+
+        # Third pass: add speaker information
+        for sup, (_, speaker) in zip(result_supervisions, split_texts_with_speakers):
+            if speaker:
+                sup.speaker = speaker
 
-        return
+        return result_supervisions