lattifai 0.4.6__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. lattifai/__init__.py +42 -27
  2. lattifai/alignment/__init__.py +6 -0
  3. lattifai/alignment/lattice1_aligner.py +119 -0
  4. lattifai/{workers/lattice1_alpha.py → alignment/lattice1_worker.py} +33 -132
  5. lattifai/{tokenizer → alignment}/phonemizer.py +1 -1
  6. lattifai/alignment/segmenter.py +166 -0
  7. lattifai/{tokenizer → alignment}/tokenizer.py +186 -112
  8. lattifai/audio2.py +211 -0
  9. lattifai/caption/__init__.py +20 -0
  10. lattifai/caption/caption.py +1275 -0
  11. lattifai/{io → caption}/supervision.py +1 -0
  12. lattifai/{io → caption}/text_parser.py +53 -10
  13. lattifai/cli/__init__.py +17 -0
  14. lattifai/cli/alignment.py +153 -0
  15. lattifai/cli/caption.py +204 -0
  16. lattifai/cli/server.py +19 -0
  17. lattifai/cli/transcribe.py +197 -0
  18. lattifai/cli/youtube.py +128 -0
  19. lattifai/client.py +455 -246
  20. lattifai/config/__init__.py +20 -0
  21. lattifai/config/alignment.py +73 -0
  22. lattifai/config/caption.py +178 -0
  23. lattifai/config/client.py +46 -0
  24. lattifai/config/diarization.py +67 -0
  25. lattifai/config/media.py +335 -0
  26. lattifai/config/transcription.py +84 -0
  27. lattifai/diarization/__init__.py +5 -0
  28. lattifai/diarization/lattifai.py +89 -0
  29. lattifai/errors.py +41 -34
  30. lattifai/logging.py +116 -0
  31. lattifai/mixin.py +552 -0
  32. lattifai/server/app.py +420 -0
  33. lattifai/transcription/__init__.py +76 -0
  34. lattifai/transcription/base.py +108 -0
  35. lattifai/transcription/gemini.py +219 -0
  36. lattifai/transcription/lattifai.py +103 -0
  37. lattifai/types.py +30 -0
  38. lattifai/utils.py +3 -31
  39. lattifai/workflow/__init__.py +22 -0
  40. lattifai/workflow/agents.py +6 -0
  41. lattifai/{workflows → workflow}/file_manager.py +81 -57
  42. lattifai/workflow/youtube.py +564 -0
  43. lattifai-1.0.0.dist-info/METADATA +736 -0
  44. lattifai-1.0.0.dist-info/RECORD +52 -0
  45. {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/WHEEL +1 -1
  46. lattifai-1.0.0.dist-info/entry_points.txt +13 -0
  47. lattifai/base_client.py +0 -126
  48. lattifai/bin/__init__.py +0 -3
  49. lattifai/bin/agent.py +0 -324
  50. lattifai/bin/align.py +0 -295
  51. lattifai/bin/cli_base.py +0 -25
  52. lattifai/bin/subtitle.py +0 -210
  53. lattifai/io/__init__.py +0 -43
  54. lattifai/io/reader.py +0 -86
  55. lattifai/io/utils.py +0 -15
  56. lattifai/io/writer.py +0 -102
  57. lattifai/tokenizer/__init__.py +0 -3
  58. lattifai/workers/__init__.py +0 -3
  59. lattifai/workflows/__init__.py +0 -34
  60. lattifai/workflows/agents.py +0 -12
  61. lattifai/workflows/gemini.py +0 -167
  62. lattifai/workflows/prompts/README.md +0 -22
  63. lattifai/workflows/prompts/gemini/README.md +0 -24
  64. lattifai/workflows/prompts/gemini/transcription_gem.txt +0 -81
  65. lattifai/workflows/youtube.py +0 -931
  66. lattifai-0.4.6.dist-info/METADATA +0 -806
  67. lattifai-0.4.6.dist-info/RECORD +0 -39
  68. lattifai-0.4.6.dist-info/entry_points.txt +0 -3
  69. /lattifai/{io → caption}/gemini_reader.py +0 -0
  70. /lattifai/{io → caption}/gemini_writer.py +0 -0
  71. /lattifai/{workflows → transcription}/prompts/__init__.py +0 -0
  72. /lattifai/{workflows → workflow}/base.py +0 -0
  73. {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/licenses/LICENSE +0 -0
  74. {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,4 @@
  import gzip
- import inspect
  import pickle
  import re
  from collections import defaultdict
@@ -7,9 +6,15 @@ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar

  import torch

- from lattifai.errors import LATTICE_DECODING_FAILURE_HELP, LatticeDecodingError
- from lattifai.io import Supervision, normalize_html_text
- from lattifai.tokenizer.phonemizer import G2Phonemizer
+ from lattifai.alignment.phonemizer import G2Phonemizer
+ from lattifai.caption import Supervision
+ from lattifai.caption import normalize_text as normalize_html_text
+ from lattifai.errors import (
+     LATTICE_DECODING_FAILURE_HELP,
+     LatticeDecodingError,
+     ModelLoadError,
+     QuotaExceededError,
+ )

  PUNCTUATION = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~'
  END_PUNCTUATION = '.!?"]。!?”】'
@@ -24,11 +29,98 @@ MAXIMUM_WORD_LENGTH = 40
  TokenizerT = TypeVar("TokenizerT", bound="LatticeTokenizer")


+ def _is_punctuation(char: str) -> bool:
+     """Check if a character is punctuation (not space, not alphanumeric, not CJK)."""
+     if len(char) != 1:
+         return False
+     if char.isspace():
+         return False
+     if char.isalnum():
+         return False
+     # Check if it's a CJK character
+     if "\u4e00" <= char <= "\u9fff":
+         return False
+     # Check if it's an accented Latin character
+     if "\u00c0" <= char <= "\u024f":
+         return False
+     return True
+
+
+ def tokenize_multilingual_text(text: str, keep_spaces: bool = True, attach_punctuation: bool = False) -> list[str]:
+     """
+     Tokenize a mixed Chinese-English string into individual units.
+
+     Tokenization rules:
+     - Chinese characters (CJK) are split individually
+     - Consecutive Latin letters (including accented characters) and digits are grouped as one unit
+     - English contractions ('s, 't, 'm, 'll, 're, 've) are kept with the preceding word
+     - Other characters (punctuation, spaces) are split individually by default
+     - If attach_punctuation=True, punctuation marks are attached to the preceding token
+
+     Args:
+         text: Input string containing mixed Chinese and English text
+         keep_spaces: If True, spaces are included in the output as separate tokens.
+             If False, spaces are excluded from the output. Default is True.
+         attach_punctuation: If True, punctuation marks are attached to the preceding token.
+             For example, "Hello, World!" becomes ["Hello,", " ", "World!"].
+             Default is False.
+
+     Returns:
+         List of tokenized units
+
+     Examples:
+         >>> tokenize_multilingual_text("Hello世界")
+         ['Hello', '世', '界']
+         >>> tokenize_multilingual_text("I'm fine")
+         ["I'm", ' ', 'fine']
+         >>> tokenize_multilingual_text("I'm fine", keep_spaces=False)
+         ["I'm", 'fine']
+         >>> tokenize_multilingual_text("Kühlschrank")
+         ['Kühlschrank']
+         >>> tokenize_multilingual_text("Hello, World!", attach_punctuation=True)
+         ['Hello,', ' ', 'World!']
+     """
+     # Regex pattern:
+     # - [a-zA-Z0-9\u00C0-\u024F]+ matches Latin letters (including accented chars like ü, ö, ä, ß, é, etc.)
+     # - (?:'[a-zA-Z]{1,2})? optionally matches contractions like 's, 't, 'm, 'll, 're, 've
+     # - [\u4e00-\u9fff] matches CJK characters
+     # - . matches any other single character
+     # Unicode ranges:
+     # - \u00C0-\u00FF: Latin-1 Supplement (À-ÿ)
+     # - \u0100-\u017F: Latin Extended-A
+     # - \u0180-\u024F: Latin Extended-B
+     pattern = re.compile(r"([a-zA-Z0-9\u00C0-\u024F]+(?:'[a-zA-Z]{1,2})?|[\u4e00-\u9fff]|.)")
+
+     # filter(None, ...) removes any empty strings from re.findall results
+     tokens = list(filter(None, pattern.findall(text)))
+
+     if attach_punctuation and len(tokens) > 1:
+         # Attach punctuation to the preceding token
+         # Punctuation characters (excluding spaces) are merged with the previous token
+         merged_tokens = []
+         i = 0
+         while i < len(tokens):
+             token = tokens[i]
+             # Look ahead to collect consecutive punctuation (non-space, non-alphanumeric, non-CJK)
+             if merged_tokens and _is_punctuation(token):
+                 merged_tokens[-1] = merged_tokens[-1] + token
+             else:
+                 merged_tokens.append(token)
+             i += 1
+         tokens = merged_tokens
+
+     if not keep_spaces:
+         tokens = [t for t in tokens if not t.isspace()]
+
+     return tokens
+
+
  class LatticeTokenizer:
      """Tokenizer for converting Lhotse Cut to LatticeGraph."""

      def __init__(self, client_wrapper: Any):
          self.client_wrapper = client_wrapper
+         self.model_name = ""
          self.words: List[str] = []
          self.g2p_model: Any = None  # Placeholder for G2P model
          self.dictionaries = defaultdict(lambda: [])
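
The new tokenizer is self-contained, so its behavior can be checked in isolation. A minimal sanity check mirroring the docstring examples above (the sample strings are illustrative):

# Mixed CJK/Latin input: Latin runs stay whole, CJK characters split individually.
assert tokenize_multilingual_text("Hello世界") == ["Hello", "世", "界"]
# Contractions stay attached to the preceding word; spaces can be dropped.
assert tokenize_multilingual_text("I'm fine", keep_spaces=False) == ["I'm", "fine"]
# attach_punctuation merges trailing punctuation into the previous token.
assert tokenize_multilingual_text("Hello, World!", attach_punctuation=True) == ["Hello,", " ", "World!"]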
@@ -107,6 +199,7 @@ class LatticeTokenizer:
          cls: Type[TokenizerT],
          client_wrapper: Any,
          model_path: str,
+         model_name: str,
          device: str = "cpu",
          compressed: bool = True,
      ) -> TokenizerT:
@@ -114,21 +207,37 @@ class LatticeTokenizer:
          from pathlib import Path

          words_model_path = f"{model_path}/words.bin"
-         if compressed:
-             with gzip.open(words_model_path, "rb") as f:
-                 data = pickle.load(f)
-         else:
-             with open(words_model_path, "rb") as f:
-                 data = pickle.load(f)
+         try:
+             if compressed:
+                 with gzip.open(words_model_path, "rb") as f:
+                     data = pickle.load(f)
+             else:
+                 with open(words_model_path, "rb") as f:
+                     data = pickle.load(f)
+         except pickle.UnpicklingError as e:
+             del e
+             import msgpack
+
+             if compressed:
+                 with gzip.open(words_model_path, "rb") as f:
+                     data = msgpack.unpack(f, raw=False, strict_map_key=False)
+             else:
+                 with open(words_model_path, "rb") as f:
+                     data = msgpack.unpack(f, raw=False, strict_map_key=False)

          tokenizer = cls(client_wrapper=client_wrapper)
+         tokenizer.model_name = model_name
          tokenizer.words = data["words"]
          tokenizer.dictionaries = defaultdict(list, data["dictionaries"])
          tokenizer.oov_word = data["oov_word"]

-         g2p_model_path = f"{model_path}/g2p.bin" if Path(f"{model_path}/g2p.bin").exists() else None
-         if g2p_model_path:
-             tokenizer.g2p_model = G2Phonemizer(g2p_model_path, device=device)
+         g2pp_model_path = f"{model_path}/g2pp.bin" if Path(f"{model_path}/g2pp.bin").exists() else None
+         if g2pp_model_path:
+             tokenizer.g2p_model = G2Phonemizer(g2pp_model_path, device=device)
+         else:
+             g2p_model_path = f"{model_path}/g2p.bin" if Path(f"{model_path}/g2p.bin").exists() else None
+             if g2p_model_path:
+                 tokenizer.g2p_model = G2Phonemizer(g2p_model_path, device=device)

          tokenizer.device = device
          tokenizer.add_special_tokens()
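
The try/except above makes the on-disk model format self-describing: legacy files stay pickled, newer ones can ship as msgpack, and callers never need to know which they have. A standalone sketch of the same fallback pattern (the helper name is hypothetical):

import gzip
import pickle

def load_words_model(path: str, compressed: bool = True) -> dict:
    """Try pickle first; fall back to msgpack for newer model files."""
    opener = gzip.open if compressed else open
    try:
        with opener(path, "rb") as f:
            return pickle.load(f)
    except pickle.UnpicklingError:
        import msgpack  # optional dependency, only needed for msgpack models
        with opener(path, "rb") as f:
            return msgpack.unpack(f, raw=False, strict_map_key=False)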
@@ -148,7 +257,10 @@ class LatticeTokenizer:
          oov_words = []
          for text in texts:
              text = normalize_html_text(text)
-             words = text.lower().replace("-", " ").replace("—", " ").replace("–", " ").split()
+             # support english, chinese and german tokenization
+             words = tokenize_multilingual_text(
+                 text.lower().replace("-", " ").replace("—", " ").replace("–", " "), keep_spaces=False
+             )
              oovs = [w.strip(PUNCTUATION) for w in words if w not in self.words]
              if oovs:
                  oov_words.extend([w for w in oovs if (w not in self.words and len(w) <= MAXIMUM_WORD_LENGTH)])
@@ -188,28 +300,39 @@ class LatticeTokenizer:

          Carefull about speaker changes.
          """
-         texts, text_len, sidx = [], 0, 0
-         speakers = []
+         texts, speakers = [], []
+         text_len, sidx = 0, 0
+
+         def flush_segment(end_idx: int, speaker: Optional[str] = None):
+             """Flush accumulated text from sidx to end_idx with given speaker."""
+             nonlocal text_len, sidx
+             if sidx <= end_idx:
+                 if len(speakers) < len(texts) + 1:
+                     speakers.append(speaker)
+                 text = " ".join(sup.text for sup in supervisions[sidx : end_idx + 1])
+                 texts.append(text)
+                 sidx = end_idx + 1
+                 text_len = 0
+
          for s, supervision in enumerate(supervisions):
              text_len += len(supervision.text)
+             is_last = s == len(supervisions) - 1
+
              if supervision.speaker:
+                 # Flush previous segment without speaker (if any)
                  if sidx < s:
-                     if len(speakers) < len(texts) + 1:
-                         speakers.append(None)
-                     text = " ".join([sup.text for sup in supervisions[sidx:s]])
-                     texts.append(text)
-                     sidx = s
+                     flush_segment(s - 1, None)
                      text_len = len(supervision.text)
-                 speakers.append(supervision.speaker)

-             else:
-                 if text_len >= 2000 or s == len(supervisions) - 1:
-                     if len(speakers) < len(texts) + 1:
-                         speakers.append(None)
-                     text = " ".join([sup.text for sup in supervisions[sidx : s + 1]])
-                     texts.append(text)
-                     sidx = s + 1
-                     text_len = 0
+                 # Check if we should flush this speaker's segment now
+                 next_has_speaker = not is_last and supervisions[s + 1].speaker
+                 if is_last or next_has_speaker:
+                     flush_segment(s, supervision.speaker)
+                 else:
+                     speakers.append(supervision.speaker)
+
+             elif text_len >= 2000 or is_last:
+                 flush_segment(s, None)

          assert len(speakers) == len(texts), f"len(speakers)={len(speakers)} != len(texts)={len(texts)}"
          sentences = self.sentence_splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace)
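
The refactored loop batches consecutive supervisions into roughly 2000-character chunks and forces a boundary around speaker turns, so each flushed text carries at most one speaker label. A small worked example with a stand-in Supervision (hypothetical data):

from collections import namedtuple

Sup = namedtuple("Sup", ["text", "speaker"])  # stand-in for Supervision
sups = [
    Sup("hello there", None),        # accumulates into the first chunk
    Sup("general kenobi", None),     # still well under 2000 chars
    Sup("you are bold", "obi-wan"),  # speaker turn: flush the unlabeled chunk,
]                                    # then flush this one with its speaker
# Expected grouping:
#   texts    == ["hello there general kenobi", "you are bold"]
#   speakers == [None, "obi-wan"]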
@@ -288,10 +411,13 @@ class LatticeTokenizer:
          response = self.client_wrapper.post(
              "tokenize",
              json={
+                 "model_name": self.model_name,
                  "supervisions": [s.to_dict() for s in supervisions],
                  "pronunciation_dictionaries": pronunciation_dictionaries,
              },
          )
+         if response.status_code == 402:
+             raise QuotaExceededError(response.json().get("detail", "Quota exceeded"))
          if response.status_code != 200:
              raise Exception(f"Failed to tokenize texts: {response.text}")
          result = response.json()
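
Quota problems now surface as a typed QuotaExceededError (HTTP 402) rather than a bare Exception, so callers can distinguish billing failures from transient ones. A hedged sketch of consumer-side handling (the tokenize call and its return shape are illustrative):

from lattifai.errors import QuotaExceededError

try:
    supervisions, lattice_id, graph = tokenizer.tokenize(supervisions)
except QuotaExceededError as exc:
    # 402 from the service: stop and report instead of retrying.
    print(f"Lattifai quota exhausted: {exc}")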
@@ -313,13 +439,14 @@ class LatticeTokenizer:
          response = self.client_wrapper.post(
              "detokenize",
              json={
+                 "model_name": self.model_name,
                  "lattice_id": lattice_id,
                  "frame_shift": frame_shift,
                  "results": [t.to_dict() for t in results[0]],
                  "labels": labels[0],
                  "offset": offset,
                  "channel": channel,
-                 "return_details": return_details,
+                 "return_details": False if return_details is None else return_details,
                  "destroy_lattice": True,
              },
          )
@@ -328,6 +455,8 @@ class LatticeTokenizer:
                  lattice_id,
                  original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
              )
+         if response.status_code == 402:
+             raise QuotaExceededError(response.json().get("detail", "Quota exceeded"))
          if response.status_code != 200:
              raise Exception(f"Failed to detokenize lattice: {response.text}")

@@ -339,83 +468,7 @@ class LatticeTokenizer:

          if return_details:
              # Add emission confidence scores for segments and word-level alignments
-             _add_confidence_scores(alignments, emission, labels[0], frame_shift)
-
-         alignments = _update_alignments_speaker(supervisions, alignments)
-
-         return alignments
-
-
- class AsyncLatticeTokenizer(LatticeTokenizer):
-     async def _post_async(self, endpoint: str, **kwargs):
-         response = self.client_wrapper.post(endpoint, **kwargs)
-         if inspect.isawaitable(response):
-             return await response
-         return response
-
-     async def tokenize(
-         self, supervisions: List[Supervision], split_sentence: bool = False
-     ) -> Tuple[str, Dict[str, Any]]:
-         if split_sentence:
-             self.init_sentence_splitter()
-             supervisions = self.split_sentences(supervisions)
-
-         pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions])
-         response = await self._post_async(
-             "tokenize",
-             json={
-                 "supervisions": [s.to_dict() for s in supervisions],
-                 "pronunciation_dictionaries": pronunciation_dictionaries,
-             },
-         )
-         if response.status_code != 200:
-             raise Exception(f"Failed to tokenize texts: {response.text}")
-         result = response.json()
-         lattice_id = result["id"]
-         return (
-             supervisions,
-             lattice_id,
-             (result["lattice_graph"], result["final_state"], result.get("acoustic_scale", 1.0)),
-         )
-
-     async def detokenize(
-         self,
-         lattice_id: str,
-         lattice_results: Tuple[torch.Tensor, Any, Any, float, float],
-         supervisions: List[Supervision],
-         return_details: bool = False,
-     ) -> List[Supervision]:
-         emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
-         response = await self._post_async(
-             "detokenize",
-             json={
-                 "lattice_id": lattice_id,
-                 "frame_shift": frame_shift,
-                 "results": [t.to_dict() for t in results[0]],
-                 "labels": labels[0],
-                 "offset": offset,
-                 "channel": channel,
-                 "return_details": return_details,
-                 "destroy_lattice": True,
-             },
-         )
-         if response.status_code == 422:
-             raise LatticeDecodingError(
-                 lattice_id,
-                 original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
-             )
-         if response.status_code != 200:
-             raise Exception(f"Failed to detokenize lattice: {response.text}")
-
-         result = response.json()
-         if not result.get("success"):
-             return Exception("Failed to detokenize the alignment results.")
-
-         alignments = [Supervision.from_dict(s) for s in result["supervisions"]]
-
-         if return_details:
-             # Add emission confidence scores for segments and word-level alignments
-             _add_confidence_scores(alignments, emission, labels[0], frame_shift)
+             _add_confidence_scores(alignments, emission, labels[0], frame_shift, offset)

          alignments = _update_alignments_speaker(supervisions, alignments)

@@ -427,6 +480,7 @@ def _add_confidence_scores(
      emission: torch.Tensor,
      labels: List[int],
      frame_shift: float,
+     offset: float = 0.0,
  ) -> None:
      """
      Add confidence scores to supervisions and their word-level alignments.
@@ -444,8 +498,8 @@ def _add_confidence_scores(
      tokens = torch.tensor(labels, dtype=torch.int64, device=emission.device)

      for supervision in supervisions:
-         start_frame = int(supervision.start / frame_shift)
-         end_frame = int(supervision.end / frame_shift)
+         start_frame = int((supervision.start - offset) / frame_shift)
+         end_frame = int((supervision.end - offset) / frame_shift)

          # Compute segment-level confidence
          probabilities = emission[0, start_frame:end_frame].softmax(dim=-1)
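
The new offset parameter matters because emissions for a chunked recording start at the chunk boundary, not at t=0, while supervision timestamps are absolute. A worked example of the corrected frame arithmetic (values are illustrative):

frame_shift = 0.02  # 20 ms per emission frame
offset = 10.0       # this chunk's emissions begin at t = 10 s
start, end = 12.5, 13.1  # supervision times in absolute seconds

start_frame = int((start - offset) / frame_shift)  # int(2.5 / 0.02) == 125
end_frame = int((end - offset) / frame_shift)      # int(3.1 / 0.02) == 155
# Without subtracting offset these would be frames 625..655, indexing far
# beyond the frames this chunk's emission tensor actually contains.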
@@ -457,8 +511,8 @@ def _add_confidence_scores(
          if hasattr(supervision, "alignment") and supervision.alignment:
              words = supervision.alignment.get("word", [])
              for w, item in enumerate(words):
-                 start = int(item.start / frame_shift) - start_frame
-                 end = int(item.end / frame_shift) - start_frame
+                 start = int((item.start - offset) / frame_shift) - start_frame
+                 end = int((item.end - offset) / frame_shift) - start_frame
                  words[w] = item._replace(score=round(1.0 - diffprobs[start:end].mean().item(), ndigits=4))

@@ -473,3 +527,23 @@ def _update_alignments_speaker(supervisions: List[Supervision], alignments: List
      for supervision, alignment in zip(supervisions, alignments):
          alignment.speaker = supervision.speaker
      return alignments
+
+
+ def _load_tokenizer(
+     client_wrapper: Any,
+     model_path: str,
+     model_name: str,
+     device: str,
+     *,
+     tokenizer_cls: Type[LatticeTokenizer] = LatticeTokenizer,
+ ) -> LatticeTokenizer:
+     """Instantiate tokenizer with consistent error handling."""
+     try:
+         return tokenizer_cls.from_pretrained(
+             client_wrapper=client_wrapper,
+             model_path=model_path,
+             model_name=model_name,
+             device=device,
+         )
+     except Exception as e:
+         raise ModelLoadError(f"tokenizer from {model_path}", original_error=e)
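
The new module-level helper wraps from_pretrained so every call site gets a uniform ModelLoadError instead of leaking gzip/pickle/msgpack details. A hedged usage sketch (the client wrapper and paths are placeholders):

from lattifai.errors import ModelLoadError

try:
    tokenizer = _load_tokenizer(
        client_wrapper=client,         # placeholder: any wrapper exposing .post()
        model_path="models/lattice1",  # hypothetical local model directory
        model_name="lattice1",         # hypothetical model identifier
        device="cpu",
    )
except ModelLoadError as exc:
    print(f"Could not load tokenizer: {exc}")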
lattifai/audio2.py ADDED
@@ -0,0 +1,211 @@
+ """Audio loading and resampling utilities."""
+
+ from collections import namedtuple
+ from pathlib import Path
+ from typing import BinaryIO, Iterable, Optional, Tuple, Union
+
+ import numpy as np
+ import soundfile as sf
+ import torch
+ from lhotse.augmentation import get_or_create_resampler
+ from lhotse.utils import Pathlike
+
+ from lattifai.errors import AudioLoadError
+
+ # ChannelSelectorType = Union[int, Iterable[int], str]
+ ChannelSelectorType = Union[int, str]
+
+
+ class AudioData(namedtuple("AudioData", ["sampling_rate", "ndarray", "tensor", "device", "path"])):
+     """Audio data container with sampling rate, numpy array, tensor, and device information."""
+
+     def __str__(self) -> str:
+         return self.path
+
+     @property
+     def duration(self) -> float:
+         """Duration of the audio in seconds."""
+         return self.ndarray.shape[-1] / self.sampling_rate
+
+
+ class AudioLoader:
+     """Load and preprocess audio files into AudioData format."""
+
+     def __init__(
+         self,
+         device: str = "cpu",
+     ):
+         """Initialize AudioLoader.
+
+         Args:
+             device: Device to load audio tensors on (default: "cpu").
+         """
+         self.device = device
+         self._resampler_cache = {}
+
+     def _resample_audio(
+         self,
+         audio_sr: Tuple[torch.Tensor, int],
+         sampling_rate: int,
+         device: Optional[str],
+         channel_selector: Optional[ChannelSelectorType],
+     ) -> torch.Tensor:
+         """Resample audio to target sampling rate with channel selection.
+
+         Args:
+             audio_sr: Tuple of (audio_tensor, original_sample_rate).
+             sampling_rate: Target sampling rate.
+             device: Device to perform resampling on.
+             channel_selector: How to select channels.
+
+         Returns:
+             Resampled audio tensor of shape (1, T) or (C, T).
+         """
+         audio, sr = audio_sr
+
+         if channel_selector is None:
+             # keep the original multi-channel signal
+             tensor = audio
+         elif isinstance(channel_selector, int):
+             assert audio.shape[0] >= channel_selector, f"Invalid channel: {channel_selector}"
+             tensor = audio[channel_selector : channel_selector + 1].clone()
+             del audio
+         elif isinstance(channel_selector, str):
+             assert channel_selector == "average"
+             tensor = torch.mean(audio.to(device), dim=0, keepdim=True)
+             del audio
+         else:
+             raise ValueError(f"Unsupported channel_selector: {channel_selector}")
+             # assert isinstance(channel_selector, Iterable)
+             # num_channels = audio.shape[0]
+             # print(f"Selecting channels {channel_selector} from the signal with {num_channels} channels.")
+             # if max(channel_selector) >= num_channels:
+             #     raise ValueError(
+             #         f"Cannot select channel subset {channel_selector} from a signal with {num_channels} channels."
+             #     )
+             # tensor = audio[channel_selector]
+
+         tensor = tensor.to(device)
+         if sr != sampling_rate:
+             cache_key = (sr, sampling_rate, device)
+             if cache_key not in self._resampler_cache:
+                 self._resampler_cache[cache_key] = get_or_create_resampler(sr, sampling_rate).to(device=device)
+             resampler = self._resampler_cache[cache_key]
+
+             length = tensor.size(-1)
+             chunk_size = sampling_rate * 3600
+             if length > chunk_size:
+                 resampled_chunks = []
+                 for i in range(0, length, chunk_size):
+                     resampled_chunks.append(resampler(tensor[..., i : i + chunk_size]))
+                 tensor = torch.cat(resampled_chunks, dim=-1)
+             else:
+                 tensor = resampler(tensor)
+
+         return tensor
+
+     def _load_audio(
+         self,
+         audio: Union[Pathlike, BinaryIO],
+         sampling_rate: int,
+         channel_selector: Optional[ChannelSelectorType],
+     ) -> torch.Tensor:
+         """Load audio from file or binary stream and resample to target rate.
+
+         Args:
+             audio: Path to audio file or binary stream.
+             sampling_rate: Target sampling rate.
+             channel_selector: How to select channels.
+
+         Returns:
+             Resampled audio tensor.
+
+         Raises:
+             ImportError: If PyAV is needed but not installed.
+             ValueError: If no audio stream found.
+             RuntimeError: If audio loading fails.
+         """
+         if isinstance(audio, Pathlike):
+             audio = str(Path(str(audio)).expanduser())
+
+         # load audio
+         try:
+             waveform, sample_rate = sf.read(audio, always_2d=True, dtype="float32")  # numpy array
+             waveform = waveform.T  # (channels, samples)
+         except Exception as primary_error:
+             # Fallback to PyAV for formats not supported by soundfile
+             try:
+                 import av
+             except ImportError:
+                 raise AudioLoadError(
+                     "PyAV (av) is required for loading certain audio formats. "
+                     f"Install it with: pip install av\n"
+                     f"Primary error was: {primary_error}"
+                 )
+
+             try:
+                 container = av.open(audio)
+                 audio_stream = next((s for s in container.streams if s.type == "audio"), None)
+
+                 if audio_stream is None:
+                     raise ValueError(f"No audio stream found in file: {audio}")
+
+                 # Resample to target sample rate during decoding
+                 audio_stream.codec_context.format = av.AudioFormat("flt")  # 32-bit float
+
+                 frames = []
+                 for frame in container.decode(audio_stream):
+                     # Convert frame to numpy array
+                     array = frame.to_ndarray()
+                     # Ensure shape is (channels, samples)
+                     if array.ndim == 1:
+                         array = array.reshape(1, -1)
+                     elif array.ndim == 2 and array.shape[0] > array.shape[1]:
+                         array = array.T
+                     frames.append(array)
+
+                 container.close()
+
+                 if not frames:
+                     raise ValueError(f"No audio data found in file: {audio}")
+
+                 # Concatenate all frames
+                 waveform = np.concatenate(frames, axis=1).astype(np.float32)  # (channels, samples)
+                 sample_rate = audio_stream.codec_context.sample_rate
+             except Exception as e:
+                 raise RuntimeError(f"Failed to load audio file {audio}: {e}")
+
+         return self._resample_audio(
+             (torch.from_numpy(waveform), sample_rate),
+             sampling_rate,
+             device=self.device,
+             channel_selector=channel_selector,
+         )
+
+     def __call__(
+         self,
+         audio: Union[Pathlike, BinaryIO],
+         sampling_rate: int = 16000,
+         channel_selector: Optional[ChannelSelectorType] = "average",
+     ) -> AudioData:
+         """
+         Args:
+             audio: Path to audio file or binary stream.
+             channel_selector: How to select channels (default: "average").
+             sampling_rate: Target sampling rate (default: use instance sampling_rate).
+
+         Returns:
+             AudioData namedtuple with sampling_rate, ndarray, and tensor fields.
+         """
+         tensor = self._load_audio(audio, sampling_rate, channel_selector)
+
+         # tensor is (1, T) or (C, T)
+         ndarray = tensor.cpu().numpy()
+
+         return AudioData(
+             sampling_rate=sampling_rate,
+             ndarray=ndarray,
+             tensor=tensor,
+             device=self.device,
+             path=str(audio) if isinstance(audio, Pathlike) else "<BinaryIO>",
+         )
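
A minimal usage sketch for the new loader (the file name is illustrative): soundfile handles common formats directly, and PyAV is only imported if that first read fails.

loader = AudioLoader(device="cpu")

# Downmix to mono ("average") and resample to 16 kHz in one call.
audio = loader("example.wav", sampling_rate=16000, channel_selector="average")

print(audio.duration)      # length in seconds, derived from the ndarray shape
print(audio.tensor.shape)  # torch.Size([1, T]) after the "average" downmix
print(audio.path)          # "example.wav"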
lattifai/caption/__init__.py ADDED
@@ -0,0 +1,20 @@
+ from typing import List, Optional
+
+ from lhotse.utils import Pathlike
+
+ from ..config.caption import InputCaptionFormat
+ from .caption import Caption
+ from .gemini_reader import GeminiReader, GeminiSegment
+ from .gemini_writer import GeminiWriter
+ from .supervision import Supervision
+ from .text_parser import normalize_text
+
+ __all__ = [
+     "Caption",
+     "Supervision",
+     "GeminiReader",
+     "GeminiWriter",
+     "GeminiSegment",
+     "normalize_text",
+     "InputCaptionFormat",
+ ]