lattifai 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lattifai/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ from .base_client import LattifAIError
2
+ from .io import SubtitleIO
3
+
4
+ try:
5
+ from importlib.metadata import version
6
+ except ImportError:
7
+ # Python < 3.8
8
+ from importlib_metadata import version
9
+
10
+ try:
11
+ __version__ = version('lattifai')
12
+ except Exception:
13
+ __version__ = '0.1.0' # fallback version
14
+
15
+
16
+ # Lazy import for LattifAI to avoid dependency issues during basic import
17
+ def __getattr__(name):
18
+ if name == 'LattifAI':
19
+ from .client import LattifAI
20
+
21
+ return LattifAI
22
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
23
+
24
+
25
+ __all__ = [
26
+ 'LattifAI', # noqa: F822
27
+ 'LattifAIError',
28
+ 'SubtitleIO',
29
+ '__version__',
30
+ ]
@@ -0,0 +1,118 @@
1
+ """Base client classes for LattifAI SDK."""
2
+
3
+ import os
4
+ from abc import ABC
5
+ from typing import Any, Awaitable, Callable, Dict, Optional, Union # noqa: F401
6
+
7
+ import httpx
8
+
9
+
10
class LattifAIError(Exception):
    """Root exception type raised by the LattifAI SDK."""
14
+
15
+
16
class BaseAPIClient(ABC):
    """Abstract base holding configuration shared by the sync/async clients.

    Resolves the API key (explicit argument first, then the LATTIFAI_API_KEY
    environment variable) and prepares the default request headers used by
    every HTTP call.
    """

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: Union[float, httpx.Timeout] = 60.0,
        max_retries: int = 2,
        default_headers: Optional[Dict[str, str]] = None,
    ) -> None:
        resolved_key = api_key if api_key is not None else os.environ.get('LATTIFAI_API_KEY')
        if resolved_key is None:
            raise LattifAIError(
                'The api_key client option must be set either by passing api_key to the client '
                'or by setting the LATTIFAI_API_KEY environment variable'
            )

        self._api_key = resolved_key
        self._base_url = base_url
        self._timeout = timeout
        self._max_retries = max_retries

        # The Authorization header always derives from the resolved key;
        # callers may add or override any other header via default_headers.
        self._default_headers = {
            'User-Agent': 'LattifAI/Python',
            'Authorization': f'Bearer {self._api_key}',
            **(default_headers or {}),
        }
48
+
49
+
50
class SyncAPIClient(BaseAPIClient):
    """Synchronous API client backed by ``httpx.Client``.

    Usable as a context manager so the underlying connection pool is
    released deterministically.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # base_url may legitimately be None on the base class, but httpx
        # requires a string ('' means "no base URL"). Also apply the stored
        # max_retries to connection establishment; it was previously unused.
        self._client = httpx.Client(
            base_url=self._base_url or '',
            timeout=self._timeout,
            headers=self._default_headers,
            transport=httpx.HTTPTransport(retries=self._max_retries),
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning None: exceptions are never suppressed, only cleanup runs.
        self.close()

    def close(self) -> None:
        """Close the HTTP client and its connection pool."""
        self._client.close()

    def _request(
        self,
        method: str,
        url: str,
        *,
        json: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> httpx.Response:
        """Make an HTTP request.

        Args:
            method: HTTP verb (e.g. 'GET', 'POST').
            url: Endpoint path, resolved against the client's base URL.
            json: Optional JSON-serializable request body.
            **kwargs: Passed through to ``httpx.Client.request``.

        Returns:
            The raw ``httpx.Response``.
        """
        return self._client.request(method=method, url=url, json=json, **kwargs)

    def post(self, api_endpoint: str, *, json: Optional[Dict[str, Any]] = None, **kwargs) -> httpx.Response:
        """Make a POST request to the specified API endpoint."""
        return self._request('POST', api_endpoint, json=json, **kwargs)
85
+
86
+
87
class AsyncAPIClient(BaseAPIClient):
    """Asynchronous API client backed by ``httpx.AsyncClient``.

    Usable as an async context manager so the underlying connection pool is
    released deterministically.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # base_url may legitimately be None on the base class, but httpx
        # requires a string ('' means "no base URL"). Also apply the stored
        # max_retries to connection establishment; it was previously unused.
        self._client = httpx.AsyncClient(
            base_url=self._base_url or '',
            timeout=self._timeout,
            headers=self._default_headers,
            transport=httpx.AsyncHTTPTransport(retries=self._max_retries),
        )

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Returning None: exceptions are never suppressed, only cleanup runs.
        await self.close()

    async def close(self) -> None:
        """Close the HTTP client and its connection pool."""
        await self._client.aclose()

    async def _request(
        self,
        method: str,
        url: str,
        *,
        json: Optional[Dict[str, Any]] = None,
        files: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> httpx.Response:
        """Make an HTTP request.

        Args:
            method: HTTP verb (e.g. 'GET', 'POST').
            url: Endpoint path, resolved against the client's base URL.
            json: Optional JSON-serializable request body.
            files: Optional multipart file payload.
            **kwargs: Passed through to ``httpx.AsyncClient.request``.

        Returns:
            The raw ``httpx.Response``.
        """
        return await self._client.request(method=method, url=url, json=json, files=files, **kwargs)
@@ -0,0 +1,2 @@
1
+ from .align import * # noqa
2
+ from .subtitle import * # noqa
lattifai/bin/align.py ADDED
@@ -0,0 +1,42 @@
1
+ import click
2
+ import colorful
3
+ from lhotse.utils import Pathlike
4
+
5
+ from lattifai.bin.cli_base import cli
6
+
7
+
8
@cli.command()
@click.option(
    '-F',
    '--input_format',
    type=click.Choice(['srt', 'vtt', 'ass', 'txt', 'auto'], case_sensitive=False),
    default='auto',
    help='Input Subtitle format.',
)
@click.argument(
    'input_audio_path',
    type=click.Path(exists=True, dir_okay=False),
)
@click.argument(
    'input_subtitle_path',
    type=click.Path(exists=True, dir_okay=False),
)
@click.argument(
    'output_subtitle_path',
    type=click.Path(allow_dash=True),
)
def align(
    input_audio_path: Pathlike,
    input_subtitle_path: Pathlike,
    output_subtitle_path: Pathlike,
    input_format: str = 'auto',
):
    """
    Command used to align audio with subtitles
    """
    # Imported here so that merely loading the CLI does not pull in the
    # heavyweight client dependencies.
    from lattifai import LattifAI

    client = LattifAI()
    client.alignment(
        audio=input_audio_path,
        subtitle=input_subtitle_path,
        format=input_format,
        output_subtitle_path=output_subtitle_path,
    )
@@ -0,0 +1,14 @@
1
+ import logging
2
+
3
+ import click
4
+
5
+
6
@click.group()
def cli():
    """
    The shell entry point to Lattifai, a tool for audio data manipulation.
    """
    # Configure root logging once, before any subcommand runs.
    log_format = '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s'
    logging.basicConfig(format=log_format, level=logging.INFO)
@@ -0,0 +1,32 @@
1
+ import click
2
+ from lhotse.utils import Pathlike
3
+
4
+ from lattifai.bin.cli_base import cli
5
+
6
+
7
# Container group: subcommands (e.g. `convert`) attach themselves to it.
@cli.group()
def subtitle():
    """Group of commands used to convert subtitle format."""
11
+
12
+
13
@subtitle.command()
@click.argument(
    'input_subtitle_path',
    type=click.Path(exists=True, dir_okay=False),
)
@click.argument(
    'output_subtitle_path',
    type=click.Path(allow_dash=True),
)
def convert(
    input_subtitle_path: Pathlike,
    output_subtitle_path: Pathlike,
):
    """
    Convert subtitle file to another format.
    """
    # Imported lazily so the CLI loads even without pysubs2 installed.
    import pysubs2

    # Explicit UTF-8 keeps this consistent with SubtitleReader, which always
    # reads subtitle files as UTF-8 (pysubs2 would otherwise use the locale
    # default encoding).
    subtitle = pysubs2.load(input_subtitle_path, encoding='utf-8')
    subtitle.save(output_subtitle_path)
lattifai/client.py ADDED
@@ -0,0 +1,131 @@
1
+ """LattifAI client implementation."""
2
+
3
+ import logging
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Any, Awaitable, BinaryIO, Callable, Dict, Optional, Union
7
+
8
+ import colorful
9
+ from dotenv import load_dotenv
10
+ from lhotse.utils import Pathlike
11
+
12
+ from lattifai.base_client import AsyncAPIClient, LattifAIError, SyncAPIClient
13
+ from lattifai.io import SubtitleFormat, SubtitleIO
14
+ from lattifai.tokenizers import LatticeTokenizer
15
+ from lattifai.workers import Lattice1AlphaWorker
16
+
17
+ load_dotenv()
18
+
19
+
20
class LattifAI(SyncAPIClient):
    """Synchronous LattifAI client.

    Wires together three pieces: a tokenizer that turns subtitles into a
    lattice graph (via the LattifAI API), a local alignment worker, and
    subtitle I/O.
    """

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        device: str = 'cpu',
        timeout: Union[float, int] = 60.0,
        max_retries: int = 2,
        default_headers: Optional[Dict[str, str]] = None,
        model_name_or_path: Optional[str] = None,
    ) -> None:
        """Create a client.

        Args:
            api_key: API key; falls back to the LATTIFAI_API_KEY env var.
            base_url: API base URL; falls back to LATTIFAI_BASE_URL, then
                the public endpoint.
            device: Device for the tokenizer and alignment worker.
            timeout: HTTP timeout in seconds.
            max_retries: HTTP connection retry count.
            model_name_or_path: Local model directory or Hugging Face repo id;
                falls back to the LATTIFAI_MODEL_PATH env var.

        Raises:
            LattifAIError: If no API key could be resolved.
        """
        if base_url is None:
            base_url = os.environ.get('LATTIFAI_BASE_URL')
        if not base_url:
            base_url = 'https://api.lattifai.com/v1'

        # api_key resolution and validation (argument first, then the
        # LATTIFAI_API_KEY env var) is handled by BaseAPIClient.__init__;
        # no need to duplicate that check here.
        super().__init__(
            api_key=api_key,
            base_url=base_url,
            timeout=timeout,
            max_retries=max_retries,
            default_headers=default_headers,
        )

        # Model location: explicit argument > environment override > Hub repo.
        # Previously this was a hard-coded developer-machine path.
        # NOTE(review): default repo id assumed — confirm the published model name.
        if model_name_or_path is None:
            model_name_or_path = os.environ.get('LATTIFAI_MODEL_PATH', 'LattifAI/Lattice-1-Alpha')

        if Path(model_name_or_path).exists():
            model_path = model_name_or_path
        else:
            # The model is a directory of files (words.bin, optional g2p.bin),
            # so fetch the whole repo snapshot; hf_hub_download would require
            # a single filename argument.
            from huggingface_hub import snapshot_download

            model_path = snapshot_download(repo_id=model_name_or_path, repo_type='model')

        self.tokenizer = LatticeTokenizer.from_pretrained(
            client_wrapper=self,
            model_path=f'{model_path}/words.bin',
            g2p_model_path=f'{model_path}/g2p.bin' if Path(f'{model_path}/g2p.bin').exists() else None,
            device=device,
        )
        self.worker = Lattice1AlphaWorker(model_path, device=device, num_threads=8)

    def alignment(
        self,
        audio: Pathlike,
        subtitle: Pathlike,
        format: Optional[SubtitleFormat] = None,
        output_subtitle_path: Optional[Pathlike] = None,
    ) -> str:
        """Perform alignment on audio and subtitle/text.

        Args:
            audio: Audio file path.
            subtitle: Subtitle/text (file path or raw text) to align with audio.
            format: Input subtitle format ('srt', 'vtt', 'ass', 'txt', 'auto');
                inferred from the file suffix when omitted.
            output_subtitle_path: When given, aligned subtitles are written
                here and this path is returned.

        Returns:
            ``output_subtitle_path`` when provided, otherwise the list of
            aligned supervision segments.
        """
        # step1: parse text or subtitles
        print(colorful.cyan(f'📖 Step 1: Reading subtitle file from {subtitle}'))
        supervisions = SubtitleIO.read(subtitle, format=format)
        print(colorful.green(f'   ✓ Parsed {len(supervisions)} supervision segments'))

        # step2: make lattice by call Lattifai API
        print(colorful.cyan('🔗 Step 2: Creating lattice graph from text'))
        lattice_id, lattice_graph = self.tokenizer.tokenize(supervisions)
        print(colorful.green(f'   ✓ Generated lattice graph with ID: {lattice_id}'))

        # step3: align audio with text
        print(colorful.cyan(f'🎵 Step 3: Performing alignment on audio file: {audio}'))
        lattice_results = self.worker.alignment(audio, lattice_graph)
        print(colorful.green('   ✓ Alignment completed successfully'))

        # step4: decode the lattice paths
        print(colorful.cyan('🔍 Step 4: Decoding lattice paths to final alignments'))
        alignments = self.tokenizer.detokenize(lattice_id, lattice_results)
        print(colorful.green(f'   ✓ Decoded {len(alignments)} aligned segments'))

        # step5: export alignments to target format
        if output_subtitle_path:
            SubtitleIO.write(alignments, output_path=output_subtitle_path)
            print(colorful.green(f'🎉🎉🎉🎉🎉 Subtitle file written to: {output_subtitle_path}'))

        return output_subtitle_path or alignments
115
+
116
+
117
if __name__ == '__main__':
    # Ad-hoc smoke-test entry point; the real CLI lives in lattifai.bin.
    import sys

    client = LattifAI()

    if len(sys.argv) == 4:
        # Usage: python -m lattifai.client AUDIO SUBTITLE OUTPUT
        # (previously this branch was a dead `pass`)
        audio, subtitle, output = sys.argv[1], sys.argv[2], sys.argv[3]
        client.alignment(audio, subtitle, output_subtitle_path=output)
    else:
        audio = 'tests/data/SA1.wav'
        text = 'tests/data/SA1.TXT'

        alignments = client.alignment(audio, text)
        print(alignments)

        alignments = client.alignment(audio, 'not paired texttttt', format='txt')
        print(alignments)
@@ -0,0 +1,22 @@
1
+ from typing import List, Optional
2
+
3
+ from lhotse.utils import Pathlike
4
+
5
+ from .reader import SubtitleFormat, SubtitleReader
6
+ from .supervision import Supervision
7
+ from .writer import SubtitleWriter
8
+
9
+ __all__ = ['SubtitleReader', 'SubtitleWriter', 'SubtitleIO', 'Supervision']
10
+
11
+
12
class SubtitleIO:
    """Thin facade over SubtitleReader / SubtitleWriter."""

    def __init__(self):
        # Stateless: all functionality is exposed via classmethods.
        pass

    @classmethod
    def read(cls, subtitle: Pathlike, format: Optional[SubtitleFormat] = None) -> List[Supervision]:
        """Parse a subtitle file (or raw text) into Supervision segments."""
        return SubtitleReader.read(subtitle, format=format)

    @classmethod
    def write(cls, alignments: List[Supervision], output_path: Pathlike) -> Pathlike:
        """Serialize supervisions to ``output_path``; format follows the suffix."""
        return SubtitleWriter.write(alignments, output_path)
lattifai/io/reader.py ADDED
@@ -0,0 +1,71 @@
1
+ from abc import ABCMeta
2
+ from pathlib import Path
3
+ from typing import List, Literal, Optional, Union
4
+
5
+ from lhotse.utils import Pathlike
6
+
7
+ from .supervision import Supervision
8
+
9
+ SubtitleFormat = Literal['txt', 'srt', 'vtt', 'ass', 'auto']
10
+
11
+
12
class SubtitleReader(ABC):
    """Parser for converting different subtitle formats to List[Supervision].

    Note: the base class is ABC (was ABCMeta, which would have made this
    class a metaclass rather than an abstract class).
    """

    @classmethod
    def read(cls, subtitle: Pathlike, format: Optional[SubtitleFormat] = None) -> List[Supervision]:
        """Parse subtitle content into Supervision segments.

        Args:
            subtitle: Either raw subtitle/text content (str) or a path to a
                subtitle file.
            format: Input format ('txt', 'srt', 'vtt', 'ass' or 'auto'). When
                omitted and ``subtitle`` is an existing file, it is inferred
                from the file suffix.

        Returns:
            List of parsed Supervision segments.
        """
        path = Path(str(subtitle))
        if not format and path.exists():
            format = path.suffix.lstrip('.').lower()
        elif format:
            format = format.lower()

        # str(subtitle) so suffix sniffing also works for Path objects
        # (slicing a Path raised TypeError before).
        if format == 'txt' or str(subtitle)[-4:].lower() == '.txt':
            if path.exists():  # file
                # Explicit encoding and deterministic close; matches the
                # UTF-8 used by _parse_subtitle.
                with open(subtitle, encoding='utf-8') as f:
                    lines = [line.strip() for line in f]
            else:  # raw text content
                lines = [line.strip() for line in str(subtitle).split('\n')]
            return [Supervision(text=line) for line in lines if line]

        return cls._parse_subtitle(subtitle, format=format)

    @classmethod
    def _parse_subtitle(cls, subtitle: Pathlike, format: Optional[SubtitleFormat]) -> List[Supervision]:
        """Parse a non-txt subtitle (file path or raw string) via pysubs2."""
        import pysubs2

        fmt = format if format != 'auto' else None
        try:
            subs: pysubs2.SSAFile = pysubs2.load(subtitle, encoding='utf-8', format_=fmt)  # file
        except IOError:
            try:
                subs = pysubs2.SSAFile.from_string(subtitle, format_=fmt)  # raw string
            except Exception:
                # Last resort: let pysubs2 auto-detect the format.
                # (was a bare `except:`, which also swallowed KeyboardInterrupt)
                subs = pysubs2.load(subtitle, encoding='utf-8')

        # NOTE(review): event start/end timestamps are deliberately dropped
        # here — only the text is carried into Supervision. Confirm this is
        # intended (alignment recomputes the timings later).
        return [Supervision(text=event.text) for event in subs.events]
@@ -0,0 +1,17 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ from lhotse.supervision import SupervisionSegment
5
+ from lhotse.utils import Seconds
6
+
7
+
8
@dataclass
class Supervision(SupervisionSegment):
    """A SupervisionSegment whose fields all have defaults.

    Re-declares the parent's fields with neutral defaults so a segment can be
    built from text alone (``Supervision(text=...)``) before any timing or
    identifier information is known. Field order must mirror the lhotse
    parent dataclass.
    """

    text: Optional[str] = None  # subtitle/transcript text for this segment
    id: str = ''  # segment identifier; empty until assigned
    recording_id: str = ''  # owning recording id; empty until assigned
    start: Seconds = 0.0  # segment start time in seconds
    duration: Seconds = 0.0  # segment duration in seconds
15
+
16
+
17
+ __all__ = ['Supervision']
lattifai/io/writer.py ADDED
@@ -0,0 +1,49 @@
1
+ from abc import ABCMeta
2
+ from typing import List
3
+
4
+ from lhotse.utils import Pathlike
5
+
6
+ from .reader import SubtitleFormat, Supervision
7
+
8
+
9
class SubtitleWriter(ABC):
    """Serialize supervisions to subtitle files; format chosen by file suffix.

    Note: the base class is ABC (was ABCMeta, which would have made this
    class a metaclass rather than an abstract class).
    """

    @classmethod
    def write(cls, alignments: List[Supervision], output_path: Pathlike) -> Pathlike:
        """Write ``alignments`` to ``output_path``.

        Supported suffixes: .txt (one text line per segment), .json (list of
        supervision dicts), .TextGrid/.textgrid (Praat, with utterance and
        word tiers); anything else is delegated to pysubs2 (.srt/.vtt/.ass/...).

        Returns:
            The ``output_path`` that was written.
        """
        # Path.suffix is more robust than slicing fixed-width substrings.
        suffix = Path(str(output_path)).suffix.lower()

        if suffix == '.txt':
            cls._write_txt(alignments, output_path)
        elif suffix == '.json':
            cls._write_json(alignments, output_path)
        elif suffix == '.textgrid':
            cls._write_textgrid(alignments, output_path)
        else:
            cls._write_pysubs2(alignments, output_path)

        return output_path

    @classmethod
    def _write_txt(cls, alignments, output_path):
        """One segment text per line."""
        with open(output_path, 'w', encoding='utf-8') as f:
            for sup in alignments:
                f.write(f'{sup.text}\n')

    @classmethod
    def _write_json(cls, alignments, output_path):
        """Dump supervisions as a JSON array of dicts."""
        import json

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump([sup.to_dict() for sup in alignments], f, ensure_ascii=False, indent=4)

    @classmethod
    def _write_textgrid(cls, alignments, output_path):
        """Write a Praat TextGrid with an 'utterances' tier and, when word-level
        alignment is present, a 'words' tier."""
        from tgt import Interval, IntervalTier, TextGrid, write_to_file

        tg = TextGrid()
        utterances, words = [], []
        for sup in sorted(alignments, key=lambda s: s.start):
            utterances.append(Interval(sup.start, sup.end, sup.text or ''))
            if sup.alignment and 'word' in sup.alignment:
                words.extend(Interval(item.start, item.end, item.symbol) for item in sup.alignment['word'])

        tg.add_tier(IntervalTier(name='utterances', objects=utterances))
        if words:
            tg.add_tier(IntervalTier(name='words', objects=words))
        write_to_file(tg, output_path, format='long')

    @classmethod
    def _write_pysubs2(cls, alignments, output_path):
        """Delegate to pysubs2; times are converted from seconds to integer ms."""
        import pysubs2

        subs = pysubs2.SSAFile()
        for sup in alignments:
            subs.append(pysubs2.SSAEvent(start=int(sup.start * 1000), end=int(sup.end * 1000), text=sup.text or ''))
        subs.save(output_path)
@@ -0,0 +1,3 @@
1
+ from .tokenizer import LatticeTokenizer
2
+
3
+ __all__ = ['LatticeTokenizer']
@@ -0,0 +1,50 @@
1
+ import re
2
+ from typing import List, Optional, Union
3
+
4
+ from dp.phonemizer import Phonemizer
5
+ from num2words import num2words
6
+
7
+ LANGUAGE = 'omni'
8
+
9
+
10
class G2Phonemizer:
    """Grapheme-to-phoneme front-end wrapping a DeepPhonemizer model.

    Digit runs in the input are expanded to words (via num2words) before
    phonemization.
    """

    def __init__(self, model_checkpoint, device):
        # Underlying DeepPhonemizer predictor; performs the actual G2P inference.
        self.phonemizer = Phonemizer.from_checkpoint(model_checkpoint, device=device).predictor
        # Matches runs of digits so they can be expanded to words.
        self.pattern = re.compile(r'\d+')

    def num2words(self, word: str, lang: str) -> str:
        """Replace every digit run in ``word`` with its spelled-out form."""
        matches = self.pattern.findall(word)
        for match in matches:
            word_equivalent = num2words(int(match), lang=lang)
            word = word.replace(match, word_equivalent)
        return word

    def remove_special_tokens(self, decoded: List[str]) -> List[str]:
        """Drop tokenizer special tokens (padding, boundaries, ...) from a phoneme sequence."""
        return [d for d in decoded if d not in self.phonemizer.phoneme_tokenizer.special_tokens]

    def __call__(
        self, words: Union[str, List[str]], lang: Optional[str] = None, batch_size: int = 0, num_prons: int = 1
    ):
        """Phonemize a word or a list of words.

        Args:
            words: Single word/string or a list of them.
            lang: Language code for number expansion; defaults to 'en'.
                (was annotated ``Optional[StopIteration]`` — a typo; also now
                carries a default since the body already handles ``lang or 'en'``)
            batch_size: Phonemizer batch size; 0 means "all at once" (capped at 128).
            num_prons: Number of candidate pronunciations per word.

        Returns:
            For a list input: a list of phoneme-token lists (or, when
            ``num_prons > 1``, a list of lists of candidates). For a single
            word: just that word's entry.
        """
        is_list = True
        if not isinstance(words, list):
            words = [words]
            is_list = False

        # Normalize '.'-attached tokens and expand digits before inference.
        predictions = self.phonemizer(
            [self.num2words(word.replace(' .', '.').replace('.', ' .'), lang=lang or 'en') for word in words],
            lang=LANGUAGE,
            batch_size=min(batch_size or len(words), 128),
            num_prons=num_prons,
        )
        if num_prons > 1:
            # One list of candidate pronunciations per word.
            predictions = [
                [self.remove_special_tokens(_prediction.phoneme_tokens) for _prediction in prediction]
                for prediction in predictions
            ]
        else:
            predictions = [self.remove_special_tokens(prediction.phoneme_tokens) for prediction in predictions]

        if is_list:
            return predictions

        return predictions[0]