lattifai 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/__init__.py +30 -0
- lattifai/base_client.py +118 -0
- lattifai/bin/__init__.py +2 -0
- lattifai/bin/align.py +42 -0
- lattifai/bin/cli_base.py +14 -0
- lattifai/bin/subtitle.py +32 -0
- lattifai/client.py +131 -0
- lattifai/io/__init__.py +22 -0
- lattifai/io/reader.py +71 -0
- lattifai/io/supervision.py +17 -0
- lattifai/io/writer.py +49 -0
- lattifai/tokenizers/__init__.py +3 -0
- lattifai/tokenizers/phonemizer.py +50 -0
- lattifai/tokenizers/tokenizer.py +143 -0
- lattifai/workers/__init__.py +3 -0
- lattifai/workers/lattice1_alpha.py +119 -0
- lattifai-0.1.4.dist-info/METADATA +467 -0
- lattifai-0.1.4.dist-info/RECORD +22 -0
- lattifai-0.1.4.dist-info/WHEEL +5 -0
- lattifai-0.1.4.dist-info/entry_points.txt +3 -0
- lattifai-0.1.4.dist-info/licenses/LICENSE +21 -0
- lattifai-0.1.4.dist-info/top_level.txt +1 -0
lattifai/__init__.py
ADDED
@@ -0,0 +1,30 @@
from .base_client import LattifAIError
from .io import SubtitleIO

try:
    from importlib.metadata import version
except ImportError:
    # Python < 3.8
    from importlib_metadata import version

try:
    __version__ = version('lattifai')
except Exception:
    __version__ = '0.1.0'  # fallback version


# Lazy import for LattifAI to avoid dependency issues during basic import
def __getattr__(name):
    if name == 'LattifAI':
        from .client import LattifAI

        return LattifAI
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")


__all__ = [
    'LattifAI',  # noqa: F822
    'LattifAIError',
    'SubtitleIO',
    '__version__',
]
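
The `__getattr__` hook above keeps `import lattifai` lightweight and only pulls in the full client stack when `LattifAI` is first accessed. A minimal usage sketch, assuming `LATTIFAI_API_KEY` is set and the optional model dependencies are installed:

    import lattifai

    print(lattifai.__version__)   # resolved from package metadata, with '0.1.0' as fallback
    client = lattifai.LattifAI()  # first access triggers the lazy import in __getattr__
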
lattifai/base_client.py
ADDED
@@ -0,0 +1,118 @@
"""Base client classes for LattifAI SDK."""

import os
from abc import ABC
from typing import Any, Awaitable, Callable, Dict, Optional, Union  # noqa: F401

import httpx


class LattifAIError(Exception):
    """Base exception for LattifAI errors."""

    pass


class BaseAPIClient(ABC):
    """Abstract base class for API clients."""

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: Union[float, httpx.Timeout] = 60.0,
        max_retries: int = 2,
        default_headers: Optional[Dict[str, str]] = None,
    ) -> None:
        if api_key is None:
            api_key = os.environ.get('LATTIFAI_API_KEY')
        if api_key is None:
            raise LattifAIError(
                'The api_key client option must be set either by passing api_key to the client '
                'or by setting the LATTIFAI_API_KEY environment variable'
            )

        self._api_key = api_key
        self._base_url = base_url
        self._timeout = timeout
        self._max_retries = max_retries

        headers = {
            'User-Agent': 'LattifAI/Python',
            'Authorization': f'Bearer {self._api_key}',
        }
        if default_headers:
            headers.update(default_headers)
        self._default_headers = headers


class SyncAPIClient(BaseAPIClient):
    """Synchronous API client."""

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self._client = httpx.Client(
            base_url=self._base_url,
            timeout=self._timeout,
            headers=self._default_headers,
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self) -> None:
        """Close the HTTP client."""
        self._client.close()

    def _request(
        self,
        method: str,
        url: str,
        *,
        json: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> httpx.Response:
        """Make an HTTP request."""
        return self._client.request(method=method, url=url, json=json, **kwargs)

    def post(self, api_endpoint: str, *, json: Optional[Dict[str, Any]] = None, **kwargs) -> httpx.Response:
        """Make a POST request to the specified API endpoint."""
        return self._request('POST', api_endpoint, json=json, **kwargs)


class AsyncAPIClient(BaseAPIClient):
    """Asynchronous API client."""

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self._client = httpx.AsyncClient(
            base_url=self._base_url,
            timeout=self._timeout,
            headers=self._default_headers,
        )

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    async def close(self) -> None:
        """Close the HTTP client."""
        await self._client.aclose()

    async def _request(
        self,
        method: str,
        url: str,
        *,
        json: Optional[Dict[str, Any]] = None,
        files: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> httpx.Response:
        """Make an HTTP request."""
        return await self._client.request(method=method, url=url, json=json, files=files, **kwargs)
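
Both concrete clients reuse the credential and header handling from `BaseAPIClient` and differ only in the underlying `httpx` client. A rough sketch of the synchronous variant used as a context manager; the endpoint path here is illustrative, not something defined by this package:

    from lattifai.base_client import SyncAPIClient

    with SyncAPIClient(api_key='sk-...', base_url='https://api.lattifai.com/v1') as client:
        # hypothetical endpoint; real request paths live in higher-level components
        response = client.post('/example', json={'ping': 'pong'})
        print(response.status_code)
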
lattifai/bin/__init__.py
ADDED
lattifai/bin/align.py
ADDED
@@ -0,0 +1,42 @@
import click
import colorful
from lhotse.utils import Pathlike

from lattifai.bin.cli_base import cli


@cli.command()
@click.option(
    '-F',
    '--input_format',
    type=click.Choice(['srt', 'vtt', 'ass', 'txt', 'auto'], case_sensitive=False),
    default='auto',
    help='Input Subtitle format.',
)
@click.argument(
    'input_audio_path',
    type=click.Path(exists=True, dir_okay=False),
)
@click.argument(
    'input_subtitle_path',
    type=click.Path(exists=True, dir_okay=False),
)
@click.argument(
    'output_subtitle_path',
    type=click.Path(allow_dash=True),
)
def align(
    input_audio_path: Pathlike,
    input_subtitle_path: Pathlike,
    output_subtitle_path: Pathlike,
    input_format: str = 'auto',
):
    """
    Command used to align audio with subtitles
    """
    from lattifai import LattifAI

    client = LattifAI()
    client.alignment(
        input_audio_path, input_subtitle_path, format=input_format, output_subtitle_path=output_subtitle_path
    )
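
Because `@cli.command()` registers `align` on the shared group at import time, the command can be exercised from Python with click's test runner. A sketch assuming the audio and subtitle files exist locally and an API key is configured:

    from click.testing import CliRunner

    import lattifai.bin.align  # noqa: F401  (importing registers the `align` command)
    from lattifai.bin.cli_base import cli

    runner = CliRunner()
    result = runner.invoke(cli, ['align', '-F', 'srt', 'audio.wav', 'subs.srt', 'aligned.srt'])
    print(result.exit_code, result.output)
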
lattifai/bin/cli_base.py
ADDED
@@ -0,0 +1,14 @@
import logging

import click


@click.group()
def cli():
    """
    The shell entry point to Lattifai, a tool for audio data manipulation.
    """
    logging.basicConfig(
        format='%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s',
        level=logging.INFO,
    )
lattifai/bin/subtitle.py
ADDED
@@ -0,0 +1,32 @@
import click
from lhotse.utils import Pathlike

from lattifai.bin.cli_base import cli


@cli.group()
def subtitle():
    """Group of commands used to convert subtitle format."""
    pass


@subtitle.command()
@click.argument(
    'input_subtitle_path',
    type=click.Path(exists=True, dir_okay=False),
)
@click.argument(
    'output_subtitle_path',
    type=click.Path(allow_dash=True),
)
def convert(
    input_subtitle_path: Pathlike,
    output_subtitle_path: Pathlike,
):
    """
    Convert subtitle file to another format.
    """
    import pysubs2

    subtitle = pysubs2.load(input_subtitle_path)
    subtitle.save(output_subtitle_path)
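
Since `convert` is nested under the `subtitle` group, its invocation mirrors the align command; a sketch with hypothetical file names:

    from click.testing import CliRunner

    import lattifai.bin.subtitle  # noqa: F401  (registers the `subtitle` group and its `convert` command)
    from lattifai.bin.cli_base import cli

    result = CliRunner().invoke(cli, ['subtitle', 'convert', 'input.srt', 'output.vtt'])
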
lattifai/client.py
ADDED
@@ -0,0 +1,131 @@
"""LattifAI client implementation."""

import logging
import os
from pathlib import Path
from typing import Any, Awaitable, BinaryIO, Callable, Dict, Optional, Union

import colorful
from dotenv import load_dotenv
from lhotse.utils import Pathlike

from lattifai.base_client import AsyncAPIClient, LattifAIError, SyncAPIClient
from lattifai.io import SubtitleFormat, SubtitleIO
from lattifai.tokenizers import LatticeTokenizer
from lattifai.workers import Lattice1AlphaWorker

load_dotenv()


class LattifAI(SyncAPIClient):
    """Synchronous LattifAI client."""

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        device: str = 'cpu',
        timeout: Union[float, int] = 60.0,
        max_retries: int = 2,
        default_headers: Optional[Dict[str, str]] = None,
    ) -> None:
        if api_key is None:
            api_key = os.environ.get('LATTIFAI_API_KEY')
        if api_key is None:
            raise LattifAIError(
                'The api_key client option must be set either by passing api_key to the client '
                'or by setting the LATTIFAI_API_KEY environment variable'
            )

        if base_url is None:
            base_url = os.environ.get('LATTIFAI_BASE_URL')
            if not base_url:
                base_url = 'https://api.lattifai.com/v1'

        super().__init__(
            api_key=api_key,
            base_url=base_url,
            timeout=timeout,
            max_retries=max_retries,
            default_headers=default_headers,
        )

        # Initialize components
        model_name_or_path = '/Users/feiteng/GEEK/OmniCaptions/HF_models/Lattice-1-Alpha'

        if not Path(model_name_or_path).exists():
            from huggingface_hub import hf_hub_download

            model_path = hf_hub_download(repo_id=model_name_or_path, repo_type='model')
        else:
            model_path = model_name_or_path

        self.tokenizer = LatticeTokenizer.from_pretrained(
            client_wrapper=self,
            model_path=f'{model_path}/words.bin',
            g2p_model_path=f'{model_path}/g2p.bin' if Path(f'{model_path}/g2p.bin').exists() else None,
            device=device,
        )
        self.worker = Lattice1AlphaWorker(model_path, device=device, num_threads=8)

    def alignment(
        self,
        audio: Pathlike,
        subtitle: Pathlike,
        format: Optional[SubtitleFormat] = None,
        output_subtitle_path: Optional[Pathlike] = None,
    ) -> str:
        """Perform alignment on audio and subtitle/text.

        Args:
            audio: Audio file path
            subtitle: Subtitle/Text to align with audio
            format: Input subtitle format (srt, vtt, ass, txt); auto-detected when omitted
            output_subtitle_path: Optional path to write the aligned subtitles to

        Returns:
            The output subtitle path if one was given, otherwise the aligned segments
        """
        # step1: parse text or subtitles
        print(colorful.cyan(f'📖 Step 1: Reading subtitle file from {subtitle}'))
        supervisions = SubtitleIO.read(subtitle, format=format)
        print(colorful.green(f' ✓ Parsed {len(supervisions)} supervision segments'))

        # step2: make lattice by calling the LattifAI API
        print(colorful.cyan('🔗 Step 2: Creating lattice graph from text'))
        lattice_id, lattice_graph = self.tokenizer.tokenize(supervisions)
        print(colorful.green(f' ✓ Generated lattice graph with ID: {lattice_id}'))

        # step3: align audio with text
        print(colorful.cyan(f'🎵 Step 3: Performing alignment on audio file: {audio}'))
        lattice_results = self.worker.alignment(audio, lattice_graph)
        print(colorful.green(' ✓ Alignment completed successfully'))

        # step4: decode the lattice paths
        print(colorful.cyan('🔍 Step 4: Decoding lattice paths to final alignments'))
        alignments = self.tokenizer.detokenize(lattice_id, lattice_results)
        print(colorful.green(f' ✓ Decoded {len(alignments)} aligned segments'))

        # step5: export alignments to target format
        if output_subtitle_path:
            SubtitleIO.write(alignments, output_path=output_subtitle_path)
            print(colorful.green(f'🎉🎉🎉🎉🎉 Subtitle file written to: {output_subtitle_path}'))

        return output_subtitle_path or alignments


if __name__ == '__main__':
    client = LattifAI()
    import sys

    if len(sys.argv) == 4:
        pass
    else:
        audio = 'tests/data/SA1.wav'
        text = 'tests/data/SA1.TXT'

    alignments = client.alignment(audio, text)
    print(alignments)

    alignments = client.alignment(audio, 'not paired texttttt', format='txt')
    print(alignments)
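
Putting the five steps together, typical programmatic use looks like the sketch below; the file paths are illustrative, and constructing the client requires `LATTIFAI_API_KEY` plus a local or downloadable Lattice-1-Alpha model:

    from lattifai import LattifAI

    client = LattifAI(device='cpu')
    client.alignment(
        'episode.wav',                            # audio file to align against
        'episode.srt',                            # existing subtitles or plain text
        format='srt',
        output_subtitle_path='episode_aligned.srt',
    )
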
lattifai/io/__init__.py
ADDED
@@ -0,0 +1,22 @@
from typing import List, Optional

from lhotse.utils import Pathlike

from .reader import SubtitleFormat, SubtitleReader
from .supervision import Supervision
from .writer import SubtitleWriter

__all__ = ['SubtitleReader', 'SubtitleWriter', 'SubtitleIO', 'Supervision']


class SubtitleIO:
    def __init__(self):
        pass

    @classmethod
    def read(cls, subtitle: Pathlike, format: Optional[SubtitleFormat] = None) -> List[Supervision]:
        return SubtitleReader.read(subtitle, format=format)

    @classmethod
    def write(cls, alignments: List[Supervision], output_path: Pathlike) -> Pathlike:
        return SubtitleWriter.write(alignments, output_path)
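
`SubtitleIO` is a thin facade over `SubtitleReader` and `SubtitleWriter`; a small round-trip sketch with hypothetical paths (note that the bundled reader currently leaves event timestamps commented out, so only the text survives the round trip):

    from lattifai.io import SubtitleIO

    supervisions = SubtitleIO.read('input.srt')               # format inferred from the extension
    SubtitleIO.write(supervisions, output_path='output.vtt')  # target chosen by the output extension
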
lattifai/io/reader.py
ADDED
@@ -0,0 +1,71 @@
from abc import ABCMeta
from pathlib import Path
from typing import List, Literal, Optional, Union

from lhotse.utils import Pathlike

from .supervision import Supervision

SubtitleFormat = Literal['txt', 'srt', 'vtt', 'ass', 'auto']


class SubtitleReader(ABCMeta):
    """Parser for converting different subtitle formats to List[Supervision]."""

    @classmethod
    def read(cls, subtitle: Pathlike, format: Optional[SubtitleFormat] = None) -> List[Supervision]:
        """Parse subtitle content and convert it to a List[Supervision].

        Args:
            subtitle: Input to parse. Can be either:
                - str: Direct text content to parse
                - Path: File path to read and parse
            format: Input subtitle format (txt, srt, vtt, ass, auto)

        Returns:
            Parsed content as a list of Supervision segments
        """
        if not format and Path(str(subtitle)).exists():
            format = Path(str(subtitle)).suffix.lstrip('.').lower()
        elif format:
            format = format.lower()

        if format == 'txt' or subtitle[-4:].lower() == '.txt':
            if not Path(str(subtitle)).exists():  # str
                lines = [line.strip() for line in subtitle.split('\n')]
            else:  # file
                lines = [line.strip() for line in open(subtitle).readlines()]
            examples = [Supervision(text=line) for line in lines if line]
        else:
            examples = cls._parse_subtitle(subtitle, format=format)

        return examples

    @classmethod
    def _parse_subtitle(cls, subtitle: Pathlike, format: Optional[SubtitleFormat]) -> List[Supervision]:
        import pysubs2

        try:
            subs: pysubs2.SSAFile = pysubs2.load(
                subtitle, encoding='utf-8', format_=format if format != 'auto' else None
            )  # file
        except IOError:
            try:
                subs: pysubs2.SSAFile = pysubs2.SSAFile.from_string(
                    subtitle, format_=format if format != 'auto' else None
                )  # str
            except Exception:
                subs: pysubs2.SSAFile = pysubs2.load(subtitle, encoding='utf-8')  # auto detect format

        supervisions = []

        for event in subs.events:
            supervisions.append(
                Supervision(
                    text=event.text,
                    # "start": event.start / 1000.0 if event.start is not None else None,
                    # "duration": event.end / 1000.0 if event.end is not None else None,
                    # }
                )
            )
        return supervisions
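
`read` also accepts raw text instead of a path: when the argument is not an existing file and the format is 'txt', each non-empty line becomes one Supervision. A minimal sketch:

    from lattifai.io.reader import SubtitleReader

    sups = SubtitleReader.read('hello world\nsecond line', format='txt')
    print([s.text for s in sups])  # ['hello world', 'second line']
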
lattifai/io/supervision.py
ADDED
@@ -0,0 +1,17 @@
from dataclasses import dataclass
from typing import Optional

from lhotse.supervision import SupervisionSegment
from lhotse.utils import Seconds


@dataclass
class Supervision(SupervisionSegment):
    text: Optional[str] = None
    id: str = ''
    recording_id: str = ''
    start: Seconds = 0.0
    duration: Seconds = 0.0


__all__ = ['Supervision']
lattifai/io/writer.py
ADDED
@@ -0,0 +1,49 @@
from abc import ABCMeta
from typing import List

from lhotse.utils import Pathlike

from .reader import SubtitleFormat, Supervision


class SubtitleWriter(ABCMeta):
    """Class for writing subtitle files."""

    @classmethod
    def write(cls, alignments: List[Supervision], output_path: Pathlike) -> Pathlike:
        if str(output_path)[-4:].lower() == '.txt':
            with open(output_path, 'w', encoding='utf-8') as f:
                for sup in alignments:
                    f.write(f'{sup.text}\n')
        elif str(output_path)[-5:].lower() == '.json':
            with open(output_path, 'w', encoding='utf-8') as f:
                import json

                json.dump([sup.to_dict() for sup in alignments], f, ensure_ascii=False, indent=4)
        elif str(output_path).endswith('.TextGrid') or str(output_path).endswith('.textgrid'):
            from tgt import Interval, IntervalTier, TextGrid, write_to_file

            tg = TextGrid()
            supervisions, words = [], []
            for supervision in sorted(alignments, key=lambda x: x.start):
                supervisions.append(Interval(supervision.start, supervision.end, supervision.text or ''))
                if supervision.alignment and 'word' in supervision.alignment:
                    for alignment in supervision.alignment['word']:
                        words.append(Interval(alignment.start, alignment.end, alignment.symbol))

            tg.add_tier(IntervalTier(name='utterances', objects=supervisions))
            if words:
                tg.add_tier(IntervalTier(name='words', objects=words))
            write_to_file(tg, output_path, format='long')
        else:
            import pysubs2

            subs = pysubs2.SSAFile()
            for sup in alignments:
                start = int(sup.start * 1000)
                end = int(sup.end * 1000)
                text = sup.text or ''
                subs.append(pysubs2.SSAEvent(start=start, end=end, text=text))
            subs.save(output_path)

        return output_path
lattifai/tokenizers/phonemizer.py
ADDED
@@ -0,0 +1,50 @@
import re
from typing import List, Optional, Union

from dp.phonemizer import Phonemizer
from num2words import num2words

LANGUAGE = 'omni'


class G2Phonemizer:
    def __init__(self, model_checkpoint, device):
        self.phonemizer = Phonemizer.from_checkpoint(model_checkpoint, device=device).predictor
        self.pattern = re.compile(r'\d+')

    def num2words(self, word, lang: str):
        matches = self.pattern.findall(word)
        for match in matches:
            word_equivalent = num2words(int(match), lang=lang)
            word = word.replace(match, word_equivalent)
        return word

    def remove_special_tokens(self, decoded: List[str]) -> List[str]:
        return [d for d in decoded if d not in self.phonemizer.phoneme_tokenizer.special_tokens]

    def __call__(
        self, words: Union[str, List[str]], lang: Optional[str], batch_size: int = 0, num_prons: int = 1
    ):
        is_list = True
        if not isinstance(words, list):
            words = [words]
            is_list = False

        predictions = self.phonemizer(
            [self.num2words(word.replace(' .', '.').replace('.', ' .'), lang=lang or 'en') for word in words],
            lang=LANGUAGE,
            batch_size=min(batch_size or len(words), 128),
            num_prons=num_prons,
        )
        if num_prons > 1:
            predictions = [
                [self.remove_special_tokens(_prediction.phoneme_tokens) for _prediction in prediction]
                for prediction in predictions
            ]
        else:
            predictions = [self.remove_special_tokens(prediction.phoneme_tokens) for prediction in predictions]

        if is_list:
            return predictions

        return predictions[0]
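
A rough usage sketch for the phonemizer wrapper, assuming a DeepPhonemizer checkpoint is available locally (the checkpoint path is hypothetical):

    from lattifai.tokenizers.phonemizer import G2Phonemizer

    g2p = G2Phonemizer('g2p.bin', device='cpu')       # hypothetical checkpoint path
    phonemes = g2p(['hello', 'route 66'], lang='en')  # digits are expanded via num2words first
    print(phonemes)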