lattifai 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. lattifai/__init__.py +0 -24
  2. lattifai/alignment/__init__.py +10 -1
  3. lattifai/alignment/lattice1_aligner.py +66 -58
  4. lattifai/alignment/lattice1_worker.py +1 -6
  5. lattifai/alignment/punctuation.py +38 -0
  6. lattifai/alignment/segmenter.py +1 -1
  7. lattifai/alignment/sentence_splitter.py +350 -0
  8. lattifai/alignment/text_align.py +440 -0
  9. lattifai/alignment/tokenizer.py +91 -220
  10. lattifai/caption/__init__.py +82 -6
  11. lattifai/caption/caption.py +335 -1143
  12. lattifai/caption/formats/__init__.py +199 -0
  13. lattifai/caption/formats/base.py +211 -0
  14. lattifai/caption/formats/gemini.py +722 -0
  15. lattifai/caption/formats/json.py +194 -0
  16. lattifai/caption/formats/lrc.py +309 -0
  17. lattifai/caption/formats/nle/__init__.py +9 -0
  18. lattifai/caption/formats/nle/audition.py +561 -0
  19. lattifai/caption/formats/nle/avid.py +423 -0
  20. lattifai/caption/formats/nle/fcpxml.py +549 -0
  21. lattifai/caption/formats/nle/premiere.py +589 -0
  22. lattifai/caption/formats/pysubs2.py +642 -0
  23. lattifai/caption/formats/sbv.py +147 -0
  24. lattifai/caption/formats/tabular.py +338 -0
  25. lattifai/caption/formats/textgrid.py +193 -0
  26. lattifai/caption/formats/ttml.py +652 -0
  27. lattifai/caption/formats/vtt.py +469 -0
  28. lattifai/caption/parsers/__init__.py +9 -0
  29. lattifai/caption/{text_parser.py → parsers/text_parser.py} +4 -2
  30. lattifai/caption/standardize.py +636 -0
  31. lattifai/caption/utils.py +474 -0
  32. lattifai/cli/__init__.py +2 -1
  33. lattifai/cli/caption.py +108 -1
  34. lattifai/cli/transcribe.py +4 -9
  35. lattifai/cli/youtube.py +4 -1
  36. lattifai/client.py +48 -84
  37. lattifai/config/__init__.py +11 -1
  38. lattifai/config/alignment.py +9 -2
  39. lattifai/config/caption.py +267 -23
  40. lattifai/config/media.py +20 -0
  41. lattifai/diarization/__init__.py +41 -1
  42. lattifai/mixin.py +36 -18
  43. lattifai/transcription/base.py +6 -1
  44. lattifai/transcription/lattifai.py +19 -54
  45. lattifai/utils.py +81 -13
  46. lattifai/workflow/__init__.py +28 -4
  47. lattifai/workflow/file_manager.py +2 -5
  48. lattifai/youtube/__init__.py +43 -0
  49. lattifai/youtube/client.py +1170 -0
  50. lattifai/youtube/types.py +23 -0
  51. lattifai-1.2.2.dist-info/METADATA +615 -0
  52. lattifai-1.2.2.dist-info/RECORD +76 -0
  53. {lattifai-1.2.0.dist-info → lattifai-1.2.2.dist-info}/entry_points.txt +1 -2
  54. lattifai/caption/gemini_reader.py +0 -371
  55. lattifai/caption/gemini_writer.py +0 -173
  56. lattifai/cli/app_installer.py +0 -142
  57. lattifai/cli/server.py +0 -44
  58. lattifai/server/app.py +0 -427
  59. lattifai/workflow/youtube.py +0 -577
  60. lattifai-1.2.0.dist-info/METADATA +0 -1133
  61. lattifai-1.2.0.dist-info/RECORD +0 -57
  62. {lattifai-1.2.0.dist-info → lattifai-1.2.2.dist-info}/WHEEL +0 -0
  63. {lattifai-1.2.0.dist-info → lattifai-1.2.2.dist-info}/licenses/LICENSE +0 -0
  64. {lattifai-1.2.0.dist-info → lattifai-1.2.2.dist-info}/top_level.txt +0 -0
lattifai/caption/utils.py ADDED
@@ -0,0 +1,474 @@
+"""Utility functions for caption processing.
+
+This module provides utility functions for:
+- Timecode offset handling (for professional timelines starting at 01:00:00:00)
+- Overlap/collision resolution (merge or trim modes)
+- SRT format optimization (UTF-8 BOM, comma-separated milliseconds)
+"""
+
+from copy import deepcopy
+from dataclasses import dataclass
+from enum import Enum
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+if TYPE_CHECKING:
+    from .supervision import Supervision
+
+
+class CollisionMode(Enum):
+    """Mode for resolving overlapping captions."""
+
+    MERGE = "merge"  # Merge overlapping lines with line break
+    TRIM = "trim"  # Trim earlier caption to end before later starts
+    KEEP = "keep"  # Keep overlaps as-is (may cause issues in some NLE)
+
+
+@dataclass
+class TimecodeOffset:
+    """Configuration for timecode offset.
+
+    Professional timelines often start at 01:00:00:00 instead of 00:00:00:00.
+    This class handles the offset conversion.
+
+    Attributes:
+        hours: Hour offset (default 0)
+        minutes: Minute offset (default 0)
+        seconds: Second offset (default 0)
+        frames: Frame offset (default 0)
+        fps: Frame rate for frame-based offset calculation
+    """
+
+    hours: int = 0
+    minutes: int = 0
+    seconds: float = 0.0
+    frames: int = 0
+    fps: float = 25.0
+
+    @property
+    def total_seconds(self) -> float:
+        """Calculate total offset in seconds."""
+        return self.hours * 3600 + self.minutes * 60 + self.seconds + (self.frames / self.fps)
+
+    @classmethod
+    def from_timecode(cls, timecode: str, fps: float = 25.0) -> "TimecodeOffset":
+        """Create offset from timecode string.
+
+        Args:
+            timecode: Timecode string (HH:MM:SS:FF or HH:MM:SS.mmm)
+            fps: Frame rate
+
+        Returns:
+            TimecodeOffset instance
+        """
+        # Handle different separators
+        if ";" in timecode:
+            # Drop-frame format
+            parts = timecode.replace(";", ":").split(":")
+        else:
+            parts = timecode.split(":")
+
+        hours = int(parts[0]) if len(parts) > 0 else 0
+        minutes = int(parts[1]) if len(parts) > 1 else 0
+
+        # Handle seconds (may have frames or milliseconds)
+        if len(parts) > 2:
+            sec_part = parts[2]
+            if "." in sec_part:
+                # Millisecond format
+                seconds = float(sec_part)
+                frames = 0
+            else:
+                seconds = float(sec_part)
+                frames = int(parts[3]) if len(parts) > 3 else 0
+        else:
+            seconds = 0.0
+            frames = 0
+
+        return cls(hours=hours, minutes=minutes, seconds=seconds, frames=frames, fps=fps)
+
+    @classmethod
+    def broadcast_start(cls, fps: float = 25.0) -> "TimecodeOffset":
+        """Create standard broadcast start offset (01:00:00:00).
+
+        Args:
+            fps: Frame rate
+
+        Returns:
+            TimecodeOffset for broadcast start
+        """
+        return cls(hours=1, fps=fps)
+
+
+def apply_timecode_offset(
+    supervisions: List["Supervision"],
+    offset: TimecodeOffset,
+) -> List["Supervision"]:
+    """Apply timecode offset to all supervisions.
+
+    Args:
+        supervisions: List of supervision segments
+        offset: Timecode offset to apply
+
+    Returns:
+        New list of supervisions with offset applied
+    """
+    from .supervision import Supervision
+
+    offset_seconds = offset.total_seconds
+    result = []
+
+    for sup in supervisions:
+        new_sup = Supervision(
+            text=sup.text,
+            start=sup.start + offset_seconds,
+            duration=sup.duration,
+            speaker=sup.speaker,
+            id=sup.id,
+            language=sup.language,
+            alignment=deepcopy(getattr(sup, "alignment", None)),
+            custom=sup.custom.copy() if sup.custom else None,
+        )
+
+        # Also offset word-level alignments if present
+        if new_sup.alignment and "word" in new_sup.alignment:
+            from lhotse.supervision import AlignmentItem
+
+            new_words = []
+            for word in new_sup.alignment["word"]:
+                new_words.append(
+                    AlignmentItem(
+                        symbol=word.symbol,
+                        start=word.start + offset_seconds,
+                        duration=word.duration,
+                        score=word.score,
+                    )
+                )
+            new_sup.alignment["word"] = new_words
+
+        result.append(new_sup)
+
+    return result
+
+
+def resolve_overlaps(
+    supervisions: List["Supervision"],
+    mode: CollisionMode = CollisionMode.MERGE,
+    gap_threshold: float = 0.05,
+) -> List["Supervision"]:
+    """Resolve overlapping supervisions.
+
+    Args:
+        supervisions: List of supervision segments (should be sorted by start time)
+        mode: How to handle overlaps (MERGE, TRIM, or KEEP)
+        gap_threshold: Minimum gap between captions in seconds (for TRIM mode)
+
+    Returns:
+        New list of supervisions with overlaps resolved
+    """
+    from .supervision import Supervision
+
+    if not supervisions or mode == CollisionMode.KEEP:
+        return supervisions
+
+    # Sort by start time
+    sorted_sups = sorted(supervisions, key=lambda x: x.start)
+    result = []
+
+    i = 0
+    while i < len(sorted_sups):
+        current = sorted_sups[i]
+
+        # Find all overlapping supervisions
+        overlapping = [current]
+        j = i + 1
+        while j < len(sorted_sups):
+            next_sup = sorted_sups[j]
+            # Check if next overlaps with any in our group
+            current_end = max(s.end for s in overlapping)
+            if next_sup.start < current_end:
+                overlapping.append(next_sup)
+                j += 1
+            else:
+                break
+
+        if len(overlapping) == 1:
+            # No overlap
+            result.append(current)
+            i += 1
+        elif mode == CollisionMode.MERGE:
+            # Merge all overlapping into one
+            merged = _merge_supervisions(overlapping)
+            result.append(merged)
+            i = j
+        elif mode == CollisionMode.TRIM:
+            # Trim each to not overlap with next
+            for k, sup in enumerate(overlapping[:-1]):
+                next_sup = overlapping[k + 1]
+                # Trim current to end before next starts
+                new_duration = max(gap_threshold, next_sup.start - sup.start - gap_threshold)
+                trimmed = Supervision(
+                    text=sup.text,
+                    start=sup.start,
+                    duration=min(sup.duration, new_duration),
+                    speaker=sup.speaker,
+                    id=sup.id,
+                    language=sup.language,
+                    alignment=sup.alignment,
+                    custom=sup.custom,
+                )
+                result.append(trimmed)
+            # Add last one as-is
+            result.append(overlapping[-1])
+            i = j
+        else:
+            result.append(current)
+            i += 1
+
+    return result
+
+
+def _merge_supervisions(supervisions: List["Supervision"]) -> "Supervision":
+    """Merge multiple overlapping supervisions into one.
+
+    Args:
+        supervisions: List of overlapping supervisions
+
+    Returns:
+        Single merged supervision
+    """
+    from .supervision import Supervision
+
+    if not supervisions:
+        raise ValueError("Cannot merge empty supervision list")
+
+    if len(supervisions) == 1:
+        return supervisions[0]
+
+    # Calculate merged timing
+    start = min(s.start for s in supervisions)
+    end = max(s.end for s in supervisions)
+
+    # Merge text with line breaks, indicating speakers
+    texts = []
+    for sup in supervisions:
+        text = sup.text.strip() if sup.text else ""
+        if sup.speaker:
+            texts.append(f"- {sup.speaker}: {text}")
+        else:
+            texts.append(f"- {text}")
+
+    merged_text = "\n".join(texts)
+
+    # Use first supervision's speaker or None for mixed speakers
+    speakers = set(s.speaker for s in supervisions if s.speaker)
+    speaker = supervisions[0].speaker if len(speakers) == 1 else None
+
+    return Supervision(
+        text=merged_text,
+        start=start,
+        duration=end - start,
+        speaker=speaker,
+        id=supervisions[0].id,
+        language=supervisions[0].language,
+    )
+
+
+def format_srt_timestamp(seconds: float) -> str:
+    """Format timestamp for SRT format.
+
+    SRT uses comma as millisecond separator: HH:MM:SS,mmm
+
+    Args:
+        seconds: Time in seconds
+
+    Returns:
+        SRT-formatted timestamp string
+    """
+    if seconds < 0:
+        seconds = 0
+
+    hours = int(seconds // 3600)
+    minutes = int((seconds % 3600) // 60)
+    secs = int(seconds % 60)
+    millis = int((seconds % 1) * 1000)
+
+    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
+
+
+def generate_srt_content(
+    supervisions: List["Supervision"],
+    include_speaker: bool = True,
+    use_bom: bool = True,
+) -> bytes:
+    """Generate SRT content with proper formatting.
+
+    Args:
+        supervisions: List of supervision segments
+        include_speaker: Include speaker labels in text
+        use_bom: Include UTF-8 BOM for Windows compatibility
+
+    Returns:
+        SRT content as bytes
+    """
+    lines = []
+
+    for i, sup in enumerate(supervisions, 1):
+        # Sequence number
+        lines.append(str(i))
+
+        # Timestamp line with comma separator
+        start_ts = format_srt_timestamp(sup.start)
+        end_ts = format_srt_timestamp(sup.end)
+        lines.append(f"{start_ts} --> {end_ts}")
+
+        # Text content
+        text = sup.text.strip() if sup.text else ""
+        if include_speaker and sup.speaker:
+            # Check if speaker was originally in text
+            if not (hasattr(sup, "custom") and sup.custom and not sup.custom.get("original_speaker", True)):
+                text = f"{sup.speaker}: {text}"
+        lines.append(text)
+
+        # Blank line between entries
+        lines.append("")
+
+    content = "\n".join(lines)
+
+    if use_bom:
+        # UTF-8 with BOM for Windows compatibility
+        return b"\xef\xbb\xbf" + content.encode("utf-8")
+    else:
+        return content.encode("utf-8")
+
+
+def detect_overlaps(supervisions: List["Supervision"]) -> List[Tuple[int, int]]:
+    """Detect all overlapping supervision pairs.
+
+    Args:
+        supervisions: List of supervision segments
+
+    Returns:
+        List of tuples (index1, index2) where supervisions overlap
+    """
+    overlaps = []
+    sorted_sups = sorted(enumerate(supervisions), key=lambda x: x[1].start)
+
+    for i in range(len(sorted_sups) - 1):
+        idx1, sup1 = sorted_sups[i]
+        for j in range(i + 1, len(sorted_sups)):
+            idx2, sup2 = sorted_sups[j]
+            if sup2.start >= sup1.end:
+                break
+            overlaps.append((idx1, idx2))
+
+    return overlaps
+
+
+def split_long_lines(
+    supervisions: List["Supervision"],
+    max_chars_per_line: int = 42,
+    max_lines: int = 2,
+) -> List["Supervision"]:
+    """Split supervisions with long text into multiple segments.
+
+    Useful for broadcast compliance where line length limits are strict.
+
+    Args:
+        supervisions: List of supervision segments
+        max_chars_per_line: Maximum characters per line
+        max_lines: Maximum lines per supervision
+
+    Returns:
+        New list with long supervisions split
+    """
+    from .supervision import Supervision
+
+    result = []
+    max_total_chars = max_chars_per_line * max_lines
+
+    for sup in supervisions:
+        text = sup.text.strip() if sup.text else ""
+
+        if len(text) <= max_total_chars:
+            # Text fits, just wrap lines if needed
+            wrapped = _wrap_text(text, max_chars_per_line, max_lines)
+            new_sup = Supervision(
+                text=wrapped,
+                start=sup.start,
+                duration=sup.duration,
+                speaker=sup.speaker,
+                id=sup.id,
+                language=sup.language,
+                alignment=sup.alignment,
+                custom=sup.custom,
+            )
+            result.append(new_sup)
+        else:
+            # Split into multiple supervisions
+            chunks = _split_text_chunks(text, max_total_chars)
+            chunk_duration = sup.duration / len(chunks)
+
+            for i, chunk in enumerate(chunks):
+                wrapped = _wrap_text(chunk, max_chars_per_line, max_lines)
+                new_sup = Supervision(
+                    text=wrapped,
+                    start=sup.start + i * chunk_duration,
+                    duration=chunk_duration,
+                    speaker=sup.speaker if i == 0 else None,
+                    id=f"{sup.id}_{i}" if sup.id else None,
+                    language=sup.language,
+                )
+                result.append(new_sup)
+
+    return result
+
+
+def _wrap_text(text: str, max_chars: int, max_lines: int) -> str:
+    """Wrap text to fit within character and line limits."""
+    words = text.split()
+    lines = []
+    current_line = []
+    current_length = 0
+
+    for word in words:
+        word_len = len(word)
+        if current_length + word_len + (1 if current_line else 0) <= max_chars:
+            current_line.append(word)
+            current_length += word_len + (1 if len(current_line) > 1 else 0)
+        else:
+            if current_line:
+                lines.append(" ".join(current_line))
+            current_line = [word]
+            current_length = word_len
+
+        if len(lines) >= max_lines:
+            break
+
+    if current_line and len(lines) < max_lines:
+        lines.append(" ".join(current_line))
+
+    return "\n".join(lines[:max_lines])
+
+
+def _split_text_chunks(text: str, max_chars: int) -> List[str]:
+    """Split text into chunks that fit within character limit."""
+    words = text.split()
+    chunks = []
+    current_chunk = []
+    current_length = 0
+
+    for word in words:
+        word_len = len(word)
+        if current_length + word_len + (1 if current_chunk else 0) <= max_chars:
+            current_chunk.append(word)
+            current_length += word_len + (1 if len(current_chunk) > 1 else 0)
+        else:
+            if current_chunk:
+                chunks.append(" ".join(current_chunk))
+            current_chunk = [word]
+            current_length = word_len
+
+    if current_chunk:
+        chunks.append(" ".join(current_chunk))
+
+    return chunks
lattifai/cli/__init__.py CHANGED
@@ -4,7 +4,7 @@ import nemo_run as run # noqa: F401
 
 # Import and re-export entrypoints at package level so NeMo Run can find them
 from lattifai.cli.alignment import align
-from lattifai.cli.caption import convert
+from lattifai.cli.caption import convert, diff
 from lattifai.cli.diarization import diarize
 from lattifai.cli.transcribe import transcribe, transcribe_align
 from lattifai.cli.youtube import youtube
@@ -12,6 +12,7 @@ from lattifai.cli.youtube import youtube
 __all__ = [
     "align",
     "convert",
+    "diff",
     "diarize",
     "transcribe",
     "transcribe_align",
lattifai/cli/caption.py CHANGED
@@ -7,6 +7,7 @@ from lhotse.utils import Pathlike
 from typing_extensions import Annotated
 
 from lattifai.config import CaptionConfig
+from lattifai.config.caption import KaraokeConfig
 from lattifai.utils import safe_print
 
 
@@ -16,6 +17,8 @@ def convert(
     output_path: Pathlike,
     include_speaker_in_text: bool = False,
     normalize_text: bool = False,
+    word_level: bool = False,
+    karaoke: bool = False,
 ):
     """
     Convert caption file to another format.
@@ -33,6 +36,11 @@ def convert(
         normalize_text: Whether to normalize caption text during conversion.
             This applies text cleaning such as removing HTML tags, decoding entities,
            collapsing whitespace, and standardizing punctuation.
+        word_level: Use word-level output format if supported.
+            When True without karaoke: outputs word-per-segment (each word as separate segment).
+            JSON format will include a 'words' field with word-level timestamps.
+        karaoke: Enable karaoke styling (requires word_level=True).
+            When True: outputs karaoke format (ASS \\kf tags, enhanced LRC, etc.).
 
     Examples:
         # Basic format conversion (positional arguments)
@@ -41,6 +49,15 @@ def convert(
        # Convert with text normalization
        lai caption convert input.srt output.json normalize_text=true
 
+        # Convert to word-per-segment output (if input has alignment)
+        lai caption convert input.json output.srt word_level=true
+
+        # Convert to karaoke format (ASS with \\kf tags)
+        lai caption convert input.json output.ass word_level=true karaoke=true
+
+        # Export JSON with word-level timestamps
+        lai caption convert input.srt output.json word_level=true
+
        # Mixing positional and keyword arguments
        lai caption convert input.srt output.vtt \\
            include_speaker_in_text=false \\
@@ -53,8 +70,16 @@ def convert(
     """
     from lattifai.caption import Caption
 
+    # Create karaoke_config if karaoke flag is set
+    karaoke_config = KaraokeConfig(enabled=True) if karaoke else None
+
     caption = Caption.read(input_path, normalize_text=normalize_text)
-    caption.write(output_path, include_speaker_in_text=include_speaker_in_text)
+    caption.write(
+        output_path,
+        include_speaker_in_text=include_speaker_in_text,
+        word_level=word_level,
+        karaoke_config=karaoke_config,
+    )
 
     safe_print(f"✅ Converted {input_path} -> {output_path}")
     return output_path
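
Editor's note: the hunk above shows exactly how the karaoke flag is mapped onto the Python API, so the CLI call has a direct programmatic equivalent (file names here are illustrative):

# Programmatic equivalent of:
#   lai caption convert input.json output.ass word_level=true karaoke=true
from lattifai.caption import Caption
from lattifai.config.caption import KaraokeConfig

caption = Caption.read("input.json", normalize_text=False)
caption.write(
    "output.ass",
    include_speaker_in_text=False,
    word_level=True,                             # word-per-segment output
    karaoke_config=KaraokeConfig(enabled=True),  # ASS \kf-style highlighting
)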
@@ -178,6 +203,88 @@ def shift(
     return output_path
 
 
+@run.cli.entrypoint(name="diff", namespace="caption")
+def diff(
+    ref_path: Pathlike,
+    hyp_path: Pathlike,
+    split_sentence: bool = True,
+    verbose: bool = True,
+):
+    """
+    Compare and align caption supervisions with transcription segments.
+
+    This command reads a reference caption file and a hypothesis file, then performs
+    text alignment to show how they match up. It's useful for comparing
+    original subtitles against ASR (Automatic Speech Recognition) results.
+
+    Args:
+        ref_path: Path to reference caption file (ground truth)
+        hyp_path: Path to hypothesis file (e.g., ASR results)
+        split_sentence: Enable sentence splitting before alignment (default: True)
+        verbose: Enable verbose output to show detailed alignment info (default: True)
+
+    Examples:
+        # Compare reference with hypothesis (positional arguments)
+        lai caption diff subtitles.srt transcription.json
+
+        # Disable sentence splitting
+        lai caption diff subtitles.srt transcription.json split_sentence=false
+
+        # Disable verbose output
+        lai caption diff subtitles.srt transcription.json verbose=false
+    """
+    from pathlib import Path
+
+    from lattifai.alignment.sentence_splitter import SentenceSplitter
+    from lattifai.alignment.text_align import align_supervisions_and_transcription
+    from lattifai.caption import Caption
+
+    ref_path = Path(ref_path).expanduser()
+    hyp_path = Path(hyp_path).expanduser()
+
+    # Read reference caption (supervisions)
+    caption_obj = Caption.read(ref_path)
+
+    # Read hypothesis
+    hyp_obj = Caption.read(hyp_path)
+
+    # Apply sentence splitting if enabled
+    if split_sentence:
+        splitter = SentenceSplitter(device="cpu", lazy_init=True)
+        caption_obj.supervisions = splitter.split_sentences(caption_obj.supervisions)
+        hyp_obj.supervisions = splitter.split_sentences(hyp_obj.supervisions)
+
+    # Set transcription on caption object
+    caption_obj.transcription = hyp_obj.supervisions
+
+    safe_print(f"📖 Reference: {len(caption_obj.supervisions)} segments from {ref_path}")
+    safe_print(f"🎤 Hypothesis: {len(caption_obj.transcription)} segments from {hyp_path}")
+    if split_sentence:
+        safe_print("✂️ Sentence splitting: enabled")
+    safe_print("")
+
+    # Perform alignment
+    results = align_supervisions_and_transcription(
+        caption=caption_obj,
+        verbose=verbose,
+    )
+
+    # # Print summary
+    # safe_print("")
+    # safe_print("=" * 72)
+    # safe_print(f"📊 Alignment Summary: {len(results)} groups")
+    # for idx, (sub_align, asr_align, quality, timestamp, typing) in enumerate(results):
+    #     sub_count = len(sub_align) if sub_align else 0
+    #     asr_count = len(asr_align) if asr_align else 0
+    #     safe_print(f"  Group {idx + 1}: ref={sub_count}, hyp={asr_count}, {quality.info}, typing={typing}")
+
+    return results
+
+
+def main_diff():
+    run.cli.main(diff)
+
+
 def main_convert():
     run.cli.main(convert)
 
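
Editor's note: the new entrypoint is also usable as a library call. A condensed sketch using the same imports and calls as the entrypoint body above; the five-tuple shape of each result is inferred from the commented-out summary code and may differ in detail:

from lattifai.alignment.sentence_splitter import SentenceSplitter
from lattifai.alignment.text_align import align_supervisions_and_transcription
from lattifai.caption import Caption

ref = Caption.read("subtitles.srt")
hyp = Caption.read("transcription.json")

# Optional: split both sides into sentences before aligning, as the CLI does.
splitter = SentenceSplitter(device="cpu", lazy_init=True)
ref.supervisions = splitter.split_sentences(ref.supervisions)
hyp.supervisions = splitter.split_sentences(hyp.supervisions)

ref.transcription = hyp.supervisions
results = align_supervisions_and_transcription(caption=ref, verbose=False)
for sub_align, asr_align, quality, timestamp, typing in results:
    print(len(sub_align or []), len(asr_align or []), typing)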
lattifai/cli/transcribe.py CHANGED
@@ -108,12 +108,7 @@ def transcribe(
     is_url = media_config.is_input_remote()
 
     # Prepare output paths
-    if is_url:
-        # For URLs, use output_dir from media_config or current directory
-        output_path = media_config.output_dir
-    else:
-        # For files, use input path directory
-        output_path = Path(media_config.input_path).parent
+    output_dir = media_config.output_dir or Path(media_config.input_path).parent
 
     # Create transcriber
     if not transcription_config.lattice_model_path:
@@ -134,13 +129,13 @@ def transcribe(
     if is_url:
         # Download media first, then transcribe
         safe_print(colorful.cyan(" Downloading media from URL..."))
-        from lattifai.workflow.youtube import YouTubeDownloader
+        from lattifai.youtube import YouTubeDownloader
 
         downloader = YouTubeDownloader()
         input_path = asyncio.run(
             downloader.download_media(
                 url=media_config.input_path,
-                output_dir=str(output_path),
+                output_dir=str(output_dir),
                 media_format=media_config.normalize_format(),
                 force_overwrite=media_config.force_overwrite,
             )
@@ -167,7 +162,7 @@ def transcribe(
     if is_url:
         # For URLs, generate output filename based on transcriber
        output_format = transcriber.file_suffix.lstrip(".")
-        final_output = output_path / f"youtube_LattifAI_{transcriber.name}.{output_format}"
+        final_output = output_dir / f"youtube_LattifAI_{transcriber.name}.{output_format}"
     else:
         # For files, use input filename with suffix
         final_output = Path(media_config.input_path).with_suffix(".LattifAI.srt")
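
Editor's note: the refactor collapses the URL/file branch into a single truthiness fallback. One nuance: `or` also falls back when output_dir is an empty string, not just None. An illustrative helper (pick_output_dir is hypothetical, not part of the package):

from pathlib import Path

def pick_output_dir(output_dir, input_path):
    # Mirrors the new one-liner: any falsy output_dir falls back to the
    # input file's parent directory.
    return output_dir or Path(input_path).parent

assert pick_output_dir(None, "/data/talk.wav") == Path("/data")
assert pick_output_dir("outputs", "/data/talk.wav") == "outputs"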
lattifai/cli/youtube.py CHANGED
@@ -44,7 +44,8 @@ def youtube(
     Args:
         yt_url: YouTube video URL (can be provided as positional argument)
         media: Media configuration for controlling formats and output directories.
-            Fields: input_path (YouTube URL), output_dir, output_format, force_overwrite
+            Fields: input_path (YouTube URL), output_dir, output_format, force_overwrite,
+            audio_track_id (default: "original"), quality (default: "best")
         client: API client configuration.
             Fields: api_key, timeout, max_retries
         alignment: Alignment configuration (model selection and inference settings).
@@ -129,6 +130,8 @@ def youtube(
         channel_selector=media_config.channel_selector,
         streaming_chunk_secs=media_config.streaming_chunk_secs,
         use_transcription=use_transcription,
+        audio_track_id=media_config.audio_track_id,
+        quality=media_config.quality,
     )
 
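
Editor's note: the two new media fields flow straight through to the workflow call above. A hypothetical configuration sketch; the class name MediaConfig and its import path are assumed from lattifai/config/media.py, and field names are taken from the docstring:

from lattifai.config import MediaConfig  # assumed export path

media = MediaConfig(
    input_path="https://www.youtube.com/watch?v=...",
    output_dir="downloads",
    audio_track_id="original",  # new in 1.2.2 (docstring default: "original")
    quality="best",             # new in 1.2.2 (docstring default: "best")
)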