easytranscriber 0.2.0__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29) hide show
  1. {easytranscriber-0.2.0/src/easytranscriber.egg-info → easytranscriber-0.2.2}/PKG-INFO +18 -2
  2. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/README.md +17 -1
  3. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/pyproject.toml +2 -2
  4. easytranscriber-0.2.2/src/easytranscriber/asr/cohere.py +136 -0
  5. easytranscriber-0.2.2/src/easytranscriber/data/collators.py +57 -0
  6. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/src/easytranscriber/data/dataset.py +13 -0
  7. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/src/easytranscriber/pipelines.py +93 -30
  8. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/src/easytranscriber/search/__main__.py +20 -2
  9. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/src/easytranscriber/search/app.py +3 -3
  10. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/src/easytranscriber/search/db.py +1 -0
  11. easytranscriber-0.2.2/src/easytranscriber/search/indexer.py +180 -0
  12. {easytranscriber-0.2.0 → easytranscriber-0.2.2/src/easytranscriber.egg-info}/PKG-INFO +18 -2
  13. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/src/easytranscriber.egg-info/SOURCES.txt +1 -0
  14. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/src/easytranscriber.egg-info/requires.txt +1 -1
  15. easytranscriber-0.2.0/src/easytranscriber/data/collators.py +0 -30
  16. easytranscriber-0.2.0/src/easytranscriber/search/indexer.py +0 -128
  17. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/LICENSE +0 -0
  18. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/setup.cfg +0 -0
  19. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/src/easytranscriber/asr/ct2.py +0 -0
  20. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/src/easytranscriber/asr/hf.py +0 -0
  21. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/src/easytranscriber/audio.py +0 -0
  22. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/src/easytranscriber/data/__init__.py +0 -0
  23. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/src/easytranscriber/data/datamodel.py +0 -0
  24. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/src/easytranscriber/search/__init__.py +0 -0
  25. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/src/easytranscriber/text/normalization.py +0 -0
  26. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/src/easytranscriber/utils.py +0 -0
  27. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/src/easytranscriber.egg-info/dependency_links.txt +0 -0
  28. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/src/easytranscriber.egg-info/entry_points.txt +0 -0
  29. {easytranscriber-0.2.0 → easytranscriber-0.2.2}/src/easytranscriber.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: easytranscriber
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Summary: Speech recognition with accurate word-level timestamps.
5
5
  Author: Faton Rekathati
6
6
  Project-URL: Repository, https://github.com/kb-labb/easytranscriber
@@ -19,7 +19,7 @@ Requires-Dist: ctranslate2>=4.4.0
19
19
  Requires-Dist: msgspec
20
20
  Requires-Dist: easyaligner==0.*
21
21
  Provides-Extra: search
22
- Requires-Dist: fastapi>=0.104.0; extra == "search"
22
+ Requires-Dist: fastapi>=0.109.0; extra == "search"
23
23
  Requires-Dist: uvicorn[standard]>=0.24.0; extra == "search"
24
24
  Requires-Dist: jinja2>=3.1.0; extra == "search"
25
25
  Dynamic: license-file
@@ -127,6 +127,9 @@ The documentation is available at [kb-labb.github.io/easytranscriber/](https://k
127
127
  * [Text normalization tutorial](https://kb-labb.github.io/easytranscriber/get-started/text-processing.html).
128
128
  * [API reference](https://kb-labb.github.io/easytranscriber/reference/).
129
129
 
130
+ > [!TIP]
131
+ > Check out the [`easyaligner`](https://kb-labb.github.io/easyaligner/) library for a user-friendly pipeline for forced alignment of text and audio.
132
+
130
133
  ## Acknowledgements
131
134
 
132
135
  `easytranscriber` draws heavy inspiration from [`WhisperX`](https://github.com/m-bain/whisperX) [(Bain et al., 2023)](https://www.isca-archive.org/interspeech_2023/bain23_interspeech.pdf).
@@ -134,3 +137,16 @@ The documentation is available at [kb-labb.github.io/easytranscriber/](https://k
134
137
  The forced alignment component of `easytranscriber` is based on Pytorch's forced alignment API, which implements a GPU-accelerated version of the Viterbi algorithm as described in [Pratap et al., 2024](https://jmlr.org/papers/volume25/23-1318/23-1318.pdf#page=8).
135
138
 
136
139
  LibriVox for public domain audiobooks used as tutorial examples.
140
+
141
+ ## Citation
142
+
143
+ ```
144
+ @online{rekathati2026,
145
+ author = {Rekathati, Faton},
146
+ title = {Easytranscriber: {Speech} Recognition with Precise
147
+ Timestamps},
148
+ date = {2026-02-26},
149
+ url = {https://kb-labb.github.io/posts/2026-02-26-easytranscriber/},
150
+ langid = {en}
151
+ }
152
+ ```
@@ -101,10 +101,26 @@ The documentation is available at [kb-labb.github.io/easytranscriber/](https://k
101
101
  * [Text normalization tutorial](https://kb-labb.github.io/easytranscriber/get-started/text-processing.html).
102
102
  * [API reference](https://kb-labb.github.io/easytranscriber/reference/).
103
103
 
104
+ > [!TIP]
105
+ > Check out the [`easyaligner`](https://kb-labb.github.io/easyaligner/) library for a user-friendly pipeline for forced alignment of text and audio.
106
+
104
107
  ## Acknowledgements
105
108
 
106
109
  `easytranscriber` draws heavy inspiration from [`WhisperX`](https://github.com/m-bain/whisperX) [(Bain et al., 2023)](https://www.isca-archive.org/interspeech_2023/bain23_interspeech.pdf).
107
110
 
108
111
  The forced alignment component of `easytranscriber` is based on Pytorch's forced alignment API, which implements a GPU-accelerated version of the Viterbi algorithm as described in [Pratap et al., 2024](https://jmlr.org/papers/volume25/23-1318/23-1318.pdf#page=8).
109
112
 
110
- LibriVox for public domain audiobooks used as tutorial examples.
113
+ LibriVox for public domain audiobooks used as tutorial examples.
114
+
115
+ ## Citation
116
+
117
+ ```
118
+ @online{rekathati2026,
119
+ author = {Rekathati, Faton},
120
+ title = {Easytranscriber: {Speech} Recognition with Precise
121
+ Timestamps},
122
+ date = {2026-02-26},
123
+ url = {https://kb-labb.github.io/posts/2026-02-26-easytranscriber/},
124
+ langid = {en}
125
+ }
126
+ ```
@@ -3,7 +3,7 @@ requires = ["setuptools>=67.0.0"]
3
3
  build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
- version = "0.2.0"
6
+ version = "0.2.2"
7
7
  name = "easytranscriber"
8
8
  requires-python = ">= 3.10"
9
9
  description = "Speech recognition with accurate word-level timestamps."
@@ -26,7 +26,7 @@ dependencies = [
26
26
 
27
27
  [project.optional-dependencies]
28
28
  search = [
29
- "fastapi>=0.104.0",
29
+ "fastapi>=0.109.0",
30
30
  "uvicorn[standard]>=0.24.0",
31
31
  "jinja2>=3.1.0",
32
32
  ]
@@ -0,0 +1,136 @@
1
+ import logging
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from easyaligner.utils import save_metadata_json
6
+ from tqdm import tqdm
7
+
8
+ from easytranscriber.data.collators import cohere_transcribe_collate_fn
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ # The 14 languages Cohere Transcribe was trained on.
13
+ # https://huggingface.co/CohereLabs/cohere-transcribe-03-2026
14
+ COHERE_SUPPORTED_LANGUAGES = frozenset(
15
+ {"ar", "de", "el", "en", "es", "fr", "it", "ja", "ko", "nl", "pl", "pt", "vi", "zh"}
16
+ )
17
+
18
+
19
+ def _require_transformers():
20
+ try:
21
+ from transformers import CohereAsrForConditionalGeneration
22
+ except ImportError as e:
23
+ raise ImportError(
24
+ "The 'cohere' ASR backend requires transformers>=5.4.0 "
25
+ "(CohereAsrForConditionalGeneration is not available in the installed version). "
26
+ "Upgrade with: pip install --upgrade 'transformers>=5.4.0'"
27
+ ) from e
28
+ return CohereAsrForConditionalGeneration
29
+
30
+
31
def transcribe(
    model,
    processor,
    file_dataloader: torch.utils.data.DataLoader,
    language: str,
    batch_size: int = 4,
    max_new_tokens: int = 256,
    punctuation: bool = True,
    sample_rate: int = 16000,
    num_workers: int = 2,
    prefetch_factor: int = 2,
    output_dir: str = "output/transcriptions",
    generate_kwargs: dict | None = None,
):
    """
    Transcribe audio files using the Cohere Transcribe model.

    For every file yielded by ``file_dataloader``, the raw audio chunks are
    batched, preprocessed with ``processor``, decoded with ``model.generate``
    and the resulting text is written onto the chunks of the file's metadata,
    which is then saved as JSON under ``output_dir``.

    Parameters
    ----------
    model : transformers.CohereAsrForConditionalGeneration
        Cohere ASR model.
    processor : transformers.AutoProcessor
        Cohere ASR processor.
    file_dataloader : torch.utils.data.DataLoader
        DataLoader yielding audio file datasets. The underlying
        ``StreamingAudioFileDataset`` must be constructed with
        ``return_raw_audio=True`` so the processor can be called on whole
        batches (per-sample calls return variable-length features).
    language : str
        ISO 639-1 language code (e.g. 'en', 'ja'). Required — Cohere has
        no built-in language detection.
    batch_size : int, optional
        Batch size for inference.
    max_new_tokens : int, optional
        Maximum number of tokens to generate per chunk. Default is 256.
    punctuation : bool, optional
        Emit punctuation in transcriptions. Default is True.
    sample_rate : int, optional
        Sample rate of audio passed to the processor. Default is 16000.
    num_workers : int, optional
        Number of workers for the feature dataloader.
    prefetch_factor : int, optional
        Prefetch factor for the feature dataloader. Ignored when
        ``num_workers`` is 0, because ``DataLoader`` only accepts it in
        multiprocessing mode.
    output_dir : str, optional
        Directory to save transcription JSON files.
    generate_kwargs : dict, optional
        Extra keyword arguments forwarded to ``model.generate()`` (e.g.
        ``num_beams``, ``length_penalty``).

    Raises
    ------
    ImportError
        If ``CohereAsrForConditionalGeneration`` cannot be imported.
    ValueError
        If ``language`` is None or not in ``COHERE_SUPPORTED_LANGUAGES``.
    """
    _require_transformers()

    if language is None:
        raise ValueError(
            "The 'cohere' backend requires an explicit `language` — "
            "CohereAsrForConditionalGeneration does not perform language detection."
        )
    if language not in COHERE_SUPPORTED_LANGUAGES:
        raise ValueError(
            f"Language {language!r} is not supported by Cohere Transcribe. "
            f"Supported: {sorted(COHERE_SUPPORTED_LANGUAGES)}."
        )

    generate_kwargs = generate_kwargs or {}

    # DataLoader raises ValueError if prefetch_factor is given while
    # num_workers == 0, so only forward it when worker processes are used.
    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "collate_fn": cohere_transcribe_collate_fn,
    }
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = prefetch_factor

    for features in tqdm(file_dataloader, desc="Transcribing audio files"):
        slice_dataset = features[0]["dataset"]
        metadata = slice_dataset.metadata
        transcription_texts = []

        feature_dataloader = torch.utils.data.DataLoader(slice_dataset, **loader_kwargs)

        logger.info(f"Transcribing {metadata.audio_path} ...")

        for batch in feature_dataloader:
            inputs = processor(
                batch["audio"],
                sampling_rate=sample_rate,
                return_tensors="pt",
                language=language,
                punctuation=punctuation,
            )
            inputs = inputs.to(model.device, dtype=model.dtype)

            with torch.inference_mode():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    **generate_kwargs,
                )

            transcription = processor.batch_decode(outputs, skip_special_tokens=True)
            transcription_texts.extend(transcription)

        # transcription_texts is one flat list for the whole file, so the index
        # must keep advancing across speeches instead of restarting at 0 for
        # each speech (which would copy the first speech's texts onto later ones).
        # NOTE(review): assumes the slice dataset yields chunks speech by speech,
        # in metadata order — confirm against StreamingAudioFileDataset.
        text_idx = 0
        for speech in metadata.speeches:
            for chunk in speech.chunks:
                chunk.text = transcription_texts[text_idx].strip()
                text_idx += 1

        # Ensure the target directory exists before the metadata is written.
        output_path = Path(output_dir) / Path(metadata.audio_path).with_suffix(".json")
        output_path.parent.mkdir(parents=True, exist_ok=True)
        save_metadata_json(metadata, output_dir=output_dir)
@@ -0,0 +1,57 @@
1
+ import torch
2
+
3
+
4
def transcribe_collate_fn(batch: list[dict]) -> dict:
    """
    Collate dataset samples into a single batch for transcription.

    ``None`` entries in ``batch`` are dropped before collation.

    Parameters
    ----------
    batch : list of dict
        Samples from the dataset, each with ``"feature"``,
        ``"start_time_global"`` and ``"speech_id"``.

    Returns
    -------
    dict
        ``'features'`` (per-sample features concatenated along dim 0),
        ``'start_times'`` and ``'speech_ids'``.
    """
    valid_samples = [sample for sample in batch if sample is not None]

    # Each feature already carries a leading batch dimension, so concatenating
    # along dim 0 stacks the samples.
    feature_tensors = [sample["feature"] for sample in valid_samples]

    return {
        "features": torch.cat(feature_tensors, dim=0),
        "start_times": [sample["start_time_global"] for sample in valid_samples],
        "speech_ids": [sample["speech_id"] for sample in valid_samples],
    }
31
+
32
+
33
def cohere_transcribe_collate_fn(batch: list[dict]) -> dict:
    """
    Collate raw-audio samples for Cohere ASR transcription.

    The waveforms are kept as a plain list (not stacked) so the caller can
    run Cohere's processor on the whole batch. ``None`` entries are dropped.

    Parameters
    ----------
    batch : list of dict
        Samples from the dataset, each with ``"audio"`` (raw waveform),
        ``"start_time_global"`` and ``"speech_id"``.

    Returns
    -------
    dict
        ``'audio'`` (list of waveforms), ``'start_times'`` and
        ``'speech_ids'``.
    """
    waveforms = []
    start_times = []
    speech_ids = []
    for sample in batch:
        if sample is None:
            continue
        waveforms.append(sample["audio"])
        start_times.append(sample["start_time_global"])
        speech_ids.append(sample["speech_id"])

    return {
        "audio": waveforms,
        "start_times": start_times,
        "speech_ids": speech_ids,
    }
@@ -48,12 +48,14 @@ class StreamingAudioSliceDataset(Dataset):
48
48
  processor: Wav2Vec2Processor | WhisperProcessor,
49
49
  sample_rate: int = 16000,
50
50
  metadata: AudioMetadata | None = None,
51
+ return_raw_audio: bool = False,
51
52
  ):
52
53
  self.audio_path = str(audio_path)
53
54
  self.chunk_specs = chunk_specs
54
55
  self.processor = processor
55
56
  self.sample_rate = sample_rate
56
57
  self.metadata = metadata
58
+ self.return_raw_audio = return_raw_audio
57
59
  self.processor_attribute = (
58
60
  "input_values" if isinstance(processor, Wav2Vec2Processor) else "input_features"
59
61
  )
@@ -75,6 +77,14 @@ class StreamingAudioSliceDataset(Dataset):
75
77
  sample_rate=self.sample_rate,
76
78
  )
77
79
 
80
+ if self.return_raw_audio:
81
+ # Caller (e.g. cohere backend) preprocesses with the batch to handle padding.
82
+ return {
83
+ "audio": audio,
84
+ "start_time_global": start_sec,
85
+ "speech_id": spec["speech_id"],
86
+ }
87
+
78
88
  # Convert to tensor and add batch dimension for processor
79
89
  if isinstance(self.processor, Wav2Vec2Processor):
80
90
  audio = torch.tensor(audio).unsqueeze(0)
@@ -165,6 +175,7 @@ class StreamingAudioFileDataset(Dataset):
165
175
  sample_rate: int = 16000,
166
176
  chunk_size: int = 30,
167
177
  alignment_strategy: str = "chunk",
178
+ return_raw_audio: bool = False,
168
179
  ):
169
180
  if isinstance(metadata, AudioMetadata):
170
181
  self.metadata = [metadata]
@@ -176,6 +187,7 @@ class StreamingAudioFileDataset(Dataset):
176
187
  self.chunk_size = chunk_size
177
188
  self.processor = processor
178
189
  self.alignment_strategy = alignment_strategy
190
+ self.return_raw_audio = return_raw_audio
179
191
 
180
192
  def _get_speech_chunk_specs(self, metadata: AudioMetadata) -> list[dict]:
181
193
  """
@@ -281,6 +293,7 @@ class StreamingAudioFileDataset(Dataset):
281
293
  processor=self.processor,
282
294
  sample_rate=self.sr,
283
295
  metadata=metadata,
296
+ return_raw_audio=self.return_raw_audio,
284
297
  )
285
298
 
286
299
  return {
@@ -5,11 +5,7 @@ import ctranslate2
5
5
  import torch
6
6
  from easyaligner.data.collators import audiofile_collate_fn, metadata_collate_fn
7
7
  from easyaligner.data.datamodel import SpeechSegment
8
- from easyaligner.data.dataset import (
9
- AudioFileDataset,
10
- JSONMetadataDataset,
11
- StreamingAudioFileDataset,
12
- )
8
+ from easyaligner.data.dataset import AudioFileDataset, JSONMetadataDataset
13
9
  from easyaligner.pipelines import alignment_pipeline, emissions_pipeline, vad_pipeline
14
10
  from easyaligner.vad.pyannote import load_vad_model as load_pyannote_vad_model
15
11
  from easyaligner.vad.silero import load_vad_model as load_silero_vad_model
@@ -22,15 +18,30 @@ from transformers import (
22
18
 
23
19
  from easytranscriber.asr.ct2 import transcribe as ct2_transcribe
24
20
  from easytranscriber.asr.hf import transcribe as hf_transcribe
21
+ from easytranscriber.data.dataset import StreamingAudioFileDataset
25
22
  from easytranscriber.text.normalization import text_normalizer
26
23
  from easytranscriber.utils import hf_to_ct2_converter
27
24
 
28
25
  logger = logging.getLogger(__name__)
29
26
 
27
+
28
def _load_cohere_transcribe():
    """Return the cohere backend's ``transcribe`` function, importing it on demand.

    The import is done inside the function so that merely importing this
    module does not pull in ``easytranscriber.asr.cohere``; users who only
    use the ct2/hf backends never trigger it.
    """
    from easytranscriber.asr import cohere as cohere_backend

    return cohere_backend.transcribe
38
+
39
+
30
40
  # dispatch mapping
31
41
  TRANSCRIBE_BACKENDS = {
32
42
  "ct2": ct2_transcribe,
33
43
  "hf": hf_transcribe,
44
+ "cohere": _load_cohere_transcribe,
34
45
  }
35
46
 
36
47
  VAD_BACKENDS = {
@@ -56,10 +67,13 @@ def pipeline(
56
67
  task: str = "transcribe",
57
68
  beam_size: int = 5,
58
69
  max_length: int = 250,
70
+ max_new_tokens: int = 256,
59
71
  repetition_penalty: float = 1.0,
60
72
  length_penalty: float = 1.0,
61
73
  patience: float = 1.0,
62
74
  no_repeat_ngram_size: int = 0,
75
+ punctuation: bool = True,
76
+ generate_kwargs: dict | None = None,
63
77
  start_wildcard: bool = False,
64
78
  end_wildcard: bool = False,
65
79
  blank_id: int | None = None,
@@ -103,7 +117,9 @@ def pipeline(
103
117
  speeches : list[list[SpeechSegment]], optional
104
118
  Existing speech segments for alignment.
105
119
  backend : str, optional
106
- Backend to use for the transcription model: "ct2" or "hf". Default is "ct2".
120
+ Backend to use for the transcription model: "ct2", "hf", or "cohere". Default is "ct2".
121
+ The "cohere" backend requires `transformers>=5.4.0`, `streaming=True`, and an explicit
122
+ `language` (Cohere has no language detection).
107
123
  sample_rate : int, optional
108
124
  Sample rate.
109
125
  chunk_size : int, optional
@@ -127,7 +143,14 @@ def pipeline(
127
143
  repetition_penalty : float, optional
128
144
  See HF [source code](https://github.com/huggingface/transformers/blob/v4.57.5/src/transformers/generation/configuration_utils.py#L188-L190) for details.
129
145
  max_length : int, optional
130
- Maximum length of generated text.
146
+ Maximum length of generated text. Applies to Whisper backends (ct2, hf).
147
+ max_new_tokens : int, optional
148
+ Maximum number of new tokens to generate per chunk. Applies to the cohere backend.
149
+ punctuation : bool, optional
150
+ Emit punctuation in Cohere transcriptions. Applies to the cohere backend only.
151
+ generate_kwargs : dict, optional
152
+ Extra kwargs forwarded to ``model.generate()`` for the cohere backend
153
+ (e.g. ``num_beams``, ``length_penalty``).
131
154
  start_wildcard : bool, optional
132
155
  Add start wildcard to forced alignment.
133
156
  end_wildcard : bool, optional
@@ -244,32 +267,68 @@ def pipeline(
244
267
  )
245
268
 
246
269
  # Step 2: Run Transcription
247
- transcription_args = {
248
- "language": language,
249
- "task": task,
250
- "beam_size": beam_size,
251
- "max_length": max_length,
252
- "repetition_penalty": repetition_penalty,
253
- "length_penalty": length_penalty,
254
- }
255
-
256
- if backend == "ct2":
257
- model_path = hf_to_ct2_converter(transcription_model, cache_dir=cache_dir)
258
- logger.info(f"Loading CTranslate2 model from {model_path}...")
259
- model = ctranslate2.models.Whisper(model_path.as_posix(), device=device)
260
- transcription_args.update(
261
- {
262
- "patience": patience,
263
- "no_repeat_ngram_size": no_repeat_ngram_size,
264
- }
270
+ dataset_kwargs: dict = {}
271
+
272
+ if backend == "cohere":
273
+ if language is None:
274
+ raise ValueError(
275
+ "The 'cohere' backend requires an explicit `language` — "
276
+ "CohereAsrForConditionalGeneration does not perform language detection."
277
+ )
278
+ if not streaming:
279
+ raise ValueError(
280
+ "The 'cohere' backend requires `streaming=True` "
281
+ "(the non-streaming AudioFileDataset does not support return_raw_audio)."
282
+ )
283
+
284
+ transcription_args = {
285
+ "language": language,
286
+ "max_new_tokens": max_new_tokens,
287
+ "punctuation": punctuation,
288
+ "sample_rate": sample_rate,
289
+ "generate_kwargs": generate_kwargs,
290
+ }
291
+
292
+ from transformers import AutoProcessor, CohereAsrForConditionalGeneration
293
+
294
+ logger.info(f"Loading Cohere ASR model from {transcription_model}...")
295
+ model = (
296
+ CohereAsrForConditionalGeneration.from_pretrained(
297
+ transcription_model, torch_dtype=torch.float16, cache_dir=cache_dir
298
+ )
299
+ .to(device)
300
+ .eval()
265
301
  )
302
+ processor = AutoProcessor.from_pretrained(transcription_model, cache_dir=cache_dir)
303
+ dataset_kwargs = {"return_raw_audio": True}
266
304
  else:
267
- logger.info(f"Loading Hugging Face model from {transcription_model}...")
268
- model = WhisperForConditionalGeneration.from_pretrained(
269
- transcription_model, torch_dtype=torch.float16, cache_dir=cache_dir
270
- ).to(device)
305
+ transcription_args = {
306
+ "language": language,
307
+ "task": task,
308
+ "beam_size": beam_size,
309
+ "max_length": max_length,
310
+ "repetition_penalty": repetition_penalty,
311
+ "length_penalty": length_penalty,
312
+ }
313
+
314
+ if backend == "ct2":
315
+ model_path = hf_to_ct2_converter(transcription_model, cache_dir=cache_dir)
316
+ logger.info(f"Loading CTranslate2 model from {model_path}...")
317
+ model = ctranslate2.models.Whisper(model_path.as_posix(), device=device)
318
+ transcription_args.update(
319
+ {
320
+ "patience": patience,
321
+ "no_repeat_ngram_size": no_repeat_ngram_size,
322
+ }
323
+ )
324
+ else:
325
+ logger.info(f"Loading Hugging Face model from {transcription_model}...")
326
+ model = WhisperForConditionalGeneration.from_pretrained(
327
+ transcription_model, torch_dtype=torch.float16, cache_dir=cache_dir
328
+ ).to(device)
329
+
330
+ processor = WhisperProcessor.from_pretrained(transcription_model, cache_dir=cache_dir)
271
331
 
272
- processor = WhisperProcessor.from_pretrained(transcription_model, cache_dir=cache_dir)
273
332
  json_dataset = JSONMetadataDataset(
274
333
  json_paths=[str(Path(output_vad_dir) / p) for p in json_paths]
275
334
  )
@@ -281,6 +340,7 @@ def pipeline(
281
340
  sample_rate=sample_rate,
282
341
  chunk_size=chunk_size,
283
342
  alignment_strategy="chunk",
343
+ **dataset_kwargs,
284
344
  )
285
345
 
286
346
  file_dataloader = torch.utils.data.DataLoader(
@@ -293,6 +353,9 @@ def pipeline(
293
353
  )
294
354
 
295
355
  transcribe = TRANSCRIBE_BACKENDS[backend]
356
+ if backend == "cohere":
357
+ transcribe = transcribe() # lazy-load to avoid importing on older transformers
358
+
296
359
  transcribe(
297
360
  model=model,
298
361
  processor=processor,
@@ -51,6 +51,18 @@ def main():
51
51
  parser.add_argument(
52
52
  "--reindex", action="store_true", help="Force full re-index of all JSON files."
53
53
  )
54
+ parser.add_argument(
55
+ "--index-mode",
56
+ choices=["chunks", "alignments"],
57
+ default=None,
58
+ help=(
59
+ "How to index transcription JSON files. "
60
+ "'chunks' indexes VAD chunks produced by the ASR pipeline. "
61
+ "'alignments' indexes AlignmentSegments, for use with "
62
+ "easyaligner ground-truth alignment outputs where chunks carry no text. "
63
+ "If omitted, the mode is detected automatically per file."
64
+ ),
65
+ )
54
66
  args = parser.parse_args()
55
67
 
56
68
  logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
@@ -75,8 +87,14 @@ def main():
75
87
  # Initialize database and index
76
88
  conn = init_db(args.db)
77
89
 
78
- logger.info("Indexing %s ...", args.alignments_dir)
79
- indexed, skipped = index_directory(args.alignments_dir, conn, force=args.reindex)
90
+ logger.info(
91
+ "Indexing %s (mode: %s) ...",
92
+ args.alignments_dir,
93
+ args.index_mode or "auto",
94
+ )
95
+ indexed, skipped = index_directory(
96
+ args.alignments_dir, conn, force=args.reindex, index_mode=args.index_mode
97
+ )
80
98
  logger.info("Indexed %d file(s), skipped %d unchanged.", indexed, skipped)
81
99
 
82
100
  # Create and run the app
@@ -73,9 +73,9 @@ def create_app(
73
73
  total_pages = max(1, math.ceil(total / per_page))
74
74
 
75
75
  return templates.TemplateResponse(
76
+ request,
76
77
  "search.html",
77
78
  {
78
- "request": request,
79
79
  "query": q,
80
80
  "results": results,
81
81
  "total": total,
@@ -94,9 +94,9 @@ def create_app(
94
94
  total_pages = max(1, math.ceil(total / per_page))
95
95
 
96
96
  return templates.TemplateResponse(
97
+ request,
97
98
  "documents.html",
98
99
  {
99
- "request": request,
100
100
  "results": results,
101
101
  "total": total,
102
102
  "page": page,
@@ -116,9 +116,9 @@ def create_app(
116
116
  raise HTTPException(status_code=404, detail="Document not found")
117
117
 
118
118
  return templates.TemplateResponse(
119
+ request,
119
120
  "document.html",
120
121
  {
121
- "request": request,
122
122
  "document": doc,
123
123
  "seek_time": t,
124
124
  "query": q,
@@ -10,6 +10,7 @@ CREATE TABLE IF NOT EXISTS documents (
10
10
  sample_rate INTEGER NOT NULL,
11
11
  num_speeches INTEGER NOT NULL,
12
12
  num_chunks INTEGER NOT NULL,
13
+ index_mode TEXT NOT NULL DEFAULT 'chunks',
13
14
  mtime REAL NOT NULL,
14
15
  indexed_at TEXT NOT NULL DEFAULT (datetime('now'))
15
16
  );
@@ -0,0 +1,180 @@
1
+ import logging
2
+ import sqlite3
3
+ from pathlib import Path
4
+
5
+ import msgspec
6
+
7
+ from easytranscriber.data.datamodel import AudioMetadata
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ def _detect_index_mode(metadata: "AudioMetadata") -> str:
13
+ """Detect whether to index by chunks or alignments.
14
+
15
+ Checks if the first chunk of the first speech has text. If not, the file
16
+ was produced by easyaligner (ground-truth alignment) and should be indexed
17
+ by alignments. ASR pipeline output always populates chunk text.
18
+ """
19
+ if metadata.speeches:
20
+ first_speech = metadata.speeches[0]
21
+ if first_speech.chunks and first_speech.chunks[0].text is None:
22
+ return "alignments"
23
+ return "chunks"
24
+
25
+
26
def index_file(conn: sqlite3.Connection, json_path: Path, index_mode: str | None = None) -> bool:
    """Index one alignment JSON file into the ``documents`` and ``chunks`` tables.

    Parameters
    ----------
    conn : sqlite3.Connection
        Open database connection. Rows are accessed by column name, so the
        connection is expected to use a name-addressable row factory.
    json_path : Path
        Path to the JSON file, decoded as ``AudioMetadata``.
    index_mode : str or None
        ``"chunks"`` indexes the chunks of each speech (chunks without text
        are skipped). ``"alignments"`` indexes the alignments of each speech.
        If ``None`` (default), the mode is detected from the file contents.

    Returns
    -------
    bool
        ``True`` if the file was (re)indexed, ``False`` if an entry with the
        same mtime and the same mode already exists.
    """
    mtime = json_path.stat().st_mtime

    # The file is decoded up front because auto-detection needs its contents.
    metadata = msgspec.json.decode(json_path.read_bytes(), type=AudioMetadata)
    resolved_mode = _detect_index_mode(metadata) if index_mode is None else index_mode

    existing = conn.execute(
        "SELECT id, mtime, index_mode FROM documents WHERE json_path = ?", (str(json_path),)
    ).fetchone()

    if existing:
        unchanged = existing["mtime"] == mtime and existing["index_mode"] == resolved_mode
        if unchanged:
            return False
        # Stale entry (mtime or mode changed): drop it before re-inserting.
        # NOTE(review): only the documents row is deleted here; presumably the
        # schema cascades the delete to chunks — confirm in db.py.
        conn.execute("DELETE FROM documents WHERE id = ?", (existing["id"],))

    speeches = metadata.speeches or []
    use_alignments = resolved_mode == "alignments"

    num_speeches = len(speeches)
    num_segments = sum(
        len(speech.alignments) if use_alignments else len(speech.chunks) for speech in speeches
    )

    cursor = conn.execute(
        """INSERT INTO documents (audio_path, json_path, duration, sample_rate,
        num_speeches, num_chunks, index_mode, mtime)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
        (
            metadata.audio_path,
            str(json_path),
            metadata.duration,
            metadata.sample_rate,
            num_speeches,
            num_segments,
            resolved_mode,
            mtime,
        ),
    )
    doc_id = cursor.lastrowid

    if speeches:
        if use_alignments:
            rows = [
                (doc_id, speech_idx, seg_idx, seg.text, seg.start, seg.end, seg.duration)
                for speech_idx, speech in enumerate(speeches)
                for seg_idx, seg in enumerate(speech.alignments)
            ]
        else:
            # Chunks without text are not searchable, so they are left out.
            rows = [
                (doc_id, speech_idx, chunk_idx, chunk.text, chunk.start, chunk.end, chunk.duration)
                for speech_idx, speech in enumerate(speeches)
                for chunk_idx, chunk in enumerate(speech.chunks)
                if chunk.text
            ]
        conn.executemany(
            """INSERT INTO chunks
            (document_id, speech_idx, chunk_idx, text, start_time, end_time, duration)
            VALUES (?, ?, ?, ?, ?, ?, ?)""",
            rows,
        )

    return True
124
+
125
+
126
def index_directory(
    alignments_dir: Path,
    conn: sqlite3.Connection,
    force: bool = False,
    index_mode: str | None = None,
) -> tuple[int, int]:
    """
    Index every ``*.json`` file directly inside ``alignments_dir``.

    Parameters
    ----------
    alignments_dir : Path
        Directory whose top-level ``*.json`` files are indexed.
    conn : sqlite3.Connection
        Open database connection; committed before returning when files exist.
    force : bool
        If True, empty the ``chunks`` and ``documents`` tables and rebuild
        ``chunks_fts`` before indexing.
    index_mode : str or None
        ``"chunks"``, ``"alignments"``, or ``None`` to auto-detect per file.
        See :func:`index_file`.

    Returns
    -------
    tuple of (int, int)
        ``(indexed_count, skipped_count)``. Files that fail to index are
        logged and counted in neither number.
    """
    if force:
        reset_statements = (
            "DELETE FROM chunks",
            "DELETE FROM documents",
            # Rebuild FTS index
            "INSERT INTO chunks_fts(chunks_fts) VALUES('rebuild')",
        )
        for statement in reset_statements:
            conn.execute(statement)
        conn.commit()

    json_files = sorted(alignments_dir.glob("*.json"))
    total_files = len(json_files)
    if total_files == 0:
        logger.warning("No JSON files found in %s", alignments_dir)
        return 0, 0

    indexed = skipped = 0
    for file_num, json_path in enumerate(json_files, start=1):
        try:
            if index_file(conn, json_path, index_mode=index_mode):
                indexed += 1
                status = "indexed"
            else:
                skipped += 1
                status = "skipped (unchanged)"
            logger.info("[%d/%d] %s — %s", file_num, total_files, json_path.name, status)
        except Exception:
            # One bad file must not abort the whole run; log it and move on.
            logger.exception("[%d/%d] Failed to index %s", file_num, total_files, json_path)

    # Drop documents whose JSON file is no longer present in the directory.
    existing_paths = {str(path) for path in json_files}
    stale_ids = [
        row["id"]
        for row in conn.execute("SELECT id, json_path FROM documents").fetchall()
        if row["json_path"] not in existing_paths
    ]
    if stale_ids:
        placeholders = ",".join(["?"] * len(stale_ids))
        conn.execute(f"DELETE FROM documents WHERE id IN ({placeholders})", stale_ids)
        logger.info("Removed %d stale documents from index", len(stale_ids))

    conn.commit()
    return indexed, skipped
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: easytranscriber
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Summary: Speech recognition with accurate word-level timestamps.
5
5
  Author: Faton Rekathati
6
6
  Project-URL: Repository, https://github.com/kb-labb/easytranscriber
@@ -19,7 +19,7 @@ Requires-Dist: ctranslate2>=4.4.0
19
19
  Requires-Dist: msgspec
20
20
  Requires-Dist: easyaligner==0.*
21
21
  Provides-Extra: search
22
- Requires-Dist: fastapi>=0.104.0; extra == "search"
22
+ Requires-Dist: fastapi>=0.109.0; extra == "search"
23
23
  Requires-Dist: uvicorn[standard]>=0.24.0; extra == "search"
24
24
  Requires-Dist: jinja2>=3.1.0; extra == "search"
25
25
  Dynamic: license-file
@@ -127,6 +127,9 @@ The documentation is available at [kb-labb.github.io/easytranscriber/](https://k
127
127
  * [Text normalization tutorial](https://kb-labb.github.io/easytranscriber/get-started/text-processing.html).
128
128
  * [API reference](https://kb-labb.github.io/easytranscriber/reference/).
129
129
 
130
+ > [!TIP]
131
+ > Check out the [`easyaligner`](https://kb-labb.github.io/easyaligner/) library for a user-friendly pipeline for forced alignment of text and audio.
132
+
130
133
  ## Acknowledgements
131
134
 
132
135
  `easytranscriber` draws heavy inspiration from [`WhisperX`](https://github.com/m-bain/whisperX) [(Bain et al., 2023)](https://www.isca-archive.org/interspeech_2023/bain23_interspeech.pdf).
@@ -134,3 +137,16 @@ The documentation is available at [kb-labb.github.io/easytranscriber/](https://k
134
137
  The forced alignment component of `easytranscriber` is based on PyTorch's forced alignment API, which implements a GPU-accelerated version of the Viterbi algorithm as described in [Pratap et al., 2024](https://jmlr.org/papers/volume25/23-1318/23-1318.pdf#page=8).
135
138
 
136
139
  Thanks to LibriVox for the public domain audiobooks used as tutorial examples.
140
+
141
+ ## Citation
142
+
143
+ ```
144
+ @online{rekathati2026,
145
+ author = {Rekathati, Faton},
146
+ title = {Easytranscriber: {Speech} Recognition with Precise
147
+ Timestamps},
148
+ date = {2026-02-26},
149
+ url = {https://kb-labb.github.io/posts/2026-02-26-easytranscriber/},
150
+ langid = {en}
151
+ }
152
+ ```
@@ -10,6 +10,7 @@ src/easytranscriber.egg-info/dependency_links.txt
10
10
  src/easytranscriber.egg-info/entry_points.txt
11
11
  src/easytranscriber.egg-info/requires.txt
12
12
  src/easytranscriber.egg-info/top_level.txt
13
+ src/easytranscriber/asr/cohere.py
13
14
  src/easytranscriber/asr/ct2.py
14
15
  src/easytranscriber/asr/hf.py
15
16
  src/easytranscriber/data/__init__.py
@@ -11,6 +11,6 @@ msgspec
11
11
  easyaligner==0.*
12
12
 
13
13
  [search]
14
- fastapi>=0.104.0
14
+ fastapi>=0.109.0
15
15
  uvicorn[standard]>=0.24.0
16
16
  jinja2>=3.1.0
@@ -1,30 +0,0 @@
1
- import torch
2
-
3
-
4
def transcribe_collate_fn(batch: list[dict]) -> dict:
    """
    Collate function for transcription.

    ``None`` entries in the batch are dropped before collation.

    Parameters
    ----------
    batch : list of dict
        Samples from the dataset. Each non-``None`` sample must provide the
        keys ``"speech_id"``, ``"start_time_global"`` and ``"feature"``.

    Returns
    -------
    dict
        Collated batch with 'features', 'start_times', and 'speech_ids'.
        'features' is the concatenation of the per-sample features along
        dimension 0; the two lists keep the order of the surviving samples.
    """
    # Filter once so all three outputs are guaranteed to stay aligned.
    samples = [sample for sample in batch if sample is not None]

    speech_ids = [sample["speech_id"] for sample in samples]
    start_times = [sample["start_time_global"] for sample in samples]

    # Concatenate along dim 0 (presumably each feature already carries a
    # leading batch dimension -- confirm against the dataset).
    features = torch.cat([sample["feature"] for sample in samples], dim=0)

    return {
        "features": features,
        "start_times": start_times,
        "speech_ids": speech_ids,
    }
@@ -1,128 +0,0 @@
1
- import logging
2
- import sqlite3
3
- from pathlib import Path
4
-
5
- import msgspec
6
-
7
- from easytranscriber.data.datamodel import AudioMetadata
8
-
9
- logger = logging.getLogger(__name__)
10
-
11
-
12
def index_file(conn: sqlite3.Connection, json_path: Path) -> bool:
    """
    Index a single alignment JSON file.

    The file's modification time is compared with the one stored in
    ``documents``; an unchanged file is left alone. Otherwise any existing
    document row for the path is deleted and the file is decoded and
    inserted again together with its chunks. Nothing is committed here.

    Parameters
    ----------
    conn : sqlite3.Connection
        Open database connection. Rows are read by column name, so the
        connection needs a name-addressable row factory.
    json_path : Path
        Alignment JSON file to index.

    Returns
    -------
    bool
        True if the file was (re)indexed, False if it was skipped because
        the stored mtime matches.
    """
    mtime = json_path.stat().st_mtime
    path_key = str(json_path)

    known = conn.execute(
        "SELECT id, mtime FROM documents WHERE json_path = ?", (path_key,)
    ).fetchone()
    if known:
        if known["mtime"] == mtime:
            # Already indexed with the same mtime.
            return False
        # The mtime changed: drop the stale entry before re-inserting.
        conn.execute("DELETE FROM documents WHERE id = ?", (known["id"],))

    # Parse JSON using the project's own data model.
    metadata = msgspec.json.decode(json_path.read_bytes(), type=AudioMetadata)
    speeches = metadata.speeches or []

    # The chunk count includes chunks without text, although those are not
    # inserted below.
    document_values = (
        metadata.audio_path,
        path_key,
        metadata.duration,
        metadata.sample_rate,
        len(speeches),
        sum(len(speech.chunks) for speech in speeches),
        mtime,
    )
    doc_id = conn.execute(
        """INSERT INTO documents (audio_path, json_path, duration, sample_rate,
        num_speeches, num_chunks, mtime)
        VALUES (?, ?, ?, ?, ?, ?, ?)""",
        document_values,
    ).lastrowid

    if speeches:
        # One row per chunk that has text; indices are positions within the
        # file, so skipped chunks leave gaps in chunk_idx.
        chunk_rows = [
            (
                doc_id,
                speech_idx,
                chunk_idx,
                chunk.text,
                chunk.start,
                chunk.end,
                chunk.duration,
            )
            for speech_idx, speech in enumerate(speeches)
            for chunk_idx, chunk in enumerate(speech.chunks)
            if chunk.text
        ]
        conn.executemany(
            """INSERT INTO chunks
            (document_id, speech_idx, chunk_idx, text, start_time, end_time, duration)
            VALUES (?, ?, ?, ?, ?, ?, ?)""",
            chunk_rows,
        )

    return True
81
-
82
-
83
def index_directory(
    alignments_dir: Path, conn: sqlite3.Connection, force: bool = False
) -> tuple[int, int]:
    """
    Index all JSON files in the alignments directory.

    Every ``*.json`` file directly inside ``alignments_dir`` is passed to
    :func:`index_file`. Afterwards, documents whose JSON file is no longer
    present in the directory are removed from the index, including when the
    directory exists but has become empty.

    Parameters
    ----------
    alignments_dir : Path
        Directory containing the alignment JSON files (not searched
        recursively).
    conn : sqlite3.Connection
        Open database connection. Rows are read by column name, so the
        connection needs a name-addressable row factory.
    force : bool
        If True, delete all rows from ``chunks`` and ``documents`` and
        rebuild the FTS index before indexing, so every file is re-indexed.

    Returns
    -------
    tuple of (int, int)
        ``(indexed_count, skipped_count)``. Files that fail to index are
        logged and counted in neither number.
    """
    if force:
        conn.execute("DELETE FROM chunks")
        conn.execute("DELETE FROM documents")
        # Rebuild FTS index
        conn.execute("INSERT INTO chunks_fts(chunks_fts) VALUES('rebuild')")
        conn.commit()

    json_files = sorted(alignments_dir.glob("*.json"))
    total_files = len(json_files)
    if not json_files:
        logger.warning("No JSON files found in %s", alignments_dir)
        if not alignments_dir.is_dir():
            # A missing directory is more likely a wrong path than an
            # emptied corpus, so leave the existing index untouched.
            return 0, 0
        # The directory exists but is empty: fall through so documents of
        # deleted files are still pruned below.

    indexed = 0
    skipped = 0
    for file_num, json_path in enumerate(json_files, 1):
        try:
            was_indexed = index_file(conn, json_path)
        except Exception:
            # Per-file boundary: one broken file must not abort the run.
            logger.exception("[%d/%d] Failed to index %s", file_num, total_files, json_path)
            continue

        if was_indexed:
            indexed += 1
            status = "indexed"
        else:
            skipped += 1
            status = "skipped (unchanged)"
        logger.info("[%d/%d] %s — %s", file_num, total_files, json_path.name, status)

    # Remove documents whose JSON files no longer exist
    existing_paths = {str(p) for p in json_files}
    all_db_paths = conn.execute("SELECT id, json_path FROM documents").fetchall()
    stale_ids = [r["id"] for r in all_db_paths if r["json_path"] not in existing_paths]
    if stale_ids:
        # One single-parameter statement per id: a single IN (...) list
        # could exceed SQLite's bound-variable limit for large indexes.
        conn.executemany(
            "DELETE FROM documents WHERE id = ?",
            [(doc_id,) for doc_id in stale_ids],
        )
        logger.info("Removed %d stale documents from index", len(stale_ids))

    conn.commit()
    return indexed, skipped
File without changes