easytranscriber 0.2.1__tar.gz → 0.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {easytranscriber-0.2.1/src/easytranscriber.egg-info → easytranscriber-0.2.2}/PKG-INFO +1 -1
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/pyproject.toml +1 -1
- easytranscriber-0.2.2/src/easytranscriber/asr/cohere.py +136 -0
- easytranscriber-0.2.2/src/easytranscriber/data/collators.py +57 -0
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber/data/dataset.py +13 -0
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber/pipelines.py +93 -30
- {easytranscriber-0.2.1 → easytranscriber-0.2.2/src/easytranscriber.egg-info}/PKG-INFO +1 -1
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber.egg-info/SOURCES.txt +1 -0
- easytranscriber-0.2.1/src/easytranscriber/data/collators.py +0 -30
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/LICENSE +0 -0
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/README.md +0 -0
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/setup.cfg +0 -0
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber/asr/ct2.py +0 -0
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber/asr/hf.py +0 -0
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber/audio.py +0 -0
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber/data/__init__.py +0 -0
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber/data/datamodel.py +0 -0
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber/search/__init__.py +0 -0
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber/search/__main__.py +0 -0
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber/search/app.py +0 -0
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber/search/db.py +0 -0
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber/search/indexer.py +0 -0
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber/text/normalization.py +0 -0
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber/utils.py +0 -0
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber.egg-info/dependency_links.txt +0 -0
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber.egg-info/entry_points.txt +0 -0
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber.egg-info/requires.txt +0 -0
- {easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber.egg-info/top_level.txt +0 -0
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from easyaligner.utils import save_metadata_json
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
|
|
8
|
+
from easytranscriber.data.collators import cohere_transcribe_collate_fn
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)

# Language codes (ISO 639-1) that Cohere Transcribe was trained on -- 14 in
# total, as listed on the model card:
# https://huggingface.co/CohereLabs/cohere-transcribe-03-2026
COHERE_SUPPORTED_LANGUAGES = frozenset(
    (
        "ar", "de", "el", "en", "es", "fr", "it",
        "ja", "ko", "nl", "pl", "pt", "vi", "zh",
    )
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _require_transformers():
|
|
20
|
+
try:
|
|
21
|
+
from transformers import CohereAsrForConditionalGeneration
|
|
22
|
+
except ImportError as e:
|
|
23
|
+
raise ImportError(
|
|
24
|
+
"The 'cohere' ASR backend requires transformers>=5.4.0 "
|
|
25
|
+
"(CohereAsrForConditionalGeneration is not available in the installed version). "
|
|
26
|
+
"Upgrade with: pip install --upgrade 'transformers>=5.4.0'"
|
|
27
|
+
) from e
|
|
28
|
+
return CohereAsrForConditionalGeneration
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def transcribe(
    model,
    processor,
    file_dataloader: torch.utils.data.DataLoader,
    language: str,
    batch_size: int = 4,
    max_new_tokens: int = 256,
    punctuation: bool = True,
    sample_rate: int = 16000,
    num_workers: int = 2,
    prefetch_factor: int = 2,
    output_dir: str = "output/transcriptions",
    generate_kwargs: dict | None = None,
):
    """
    Transcribe audio files using the Cohere Transcribe model.

    For every item yielded by ``file_dataloader``, the per-file slice dataset
    is batched and run through ``processor`` one batch at a time. The batches
    are decoded with ``model.generate()``, the decoded texts are written into
    ``chunk.text`` of the file's metadata, and the metadata is saved as JSON
    under ``output_dir``.

    Parameters
    ----------
    model : transformers.CohereAsrForConditionalGeneration
        Cohere ASR model.
    processor : transformers.AutoProcessor
        Cohere ASR processor.
    file_dataloader : torch.utils.data.DataLoader
        DataLoader yielding audio file datasets. The underlying
        ``StreamingAudioFileDataset`` must be constructed with
        ``return_raw_audio=True`` so the processor can be called on whole
        batches (per-sample calls return variable-length features).
    language : str
        ISO 639-1 language code (e.g. 'en', 'ja'). Required — Cohere has
        no built-in language detection.
    batch_size : int, optional
        Batch size for inference.
    max_new_tokens : int, optional
        Maximum number of tokens to generate per chunk. Default is 256.
        A ``max_new_tokens`` entry in ``generate_kwargs`` takes precedence.
    punctuation : bool, optional
        Emit punctuation in transcriptions. Default is True.
    sample_rate : int, optional
        Sample rate of audio passed to the processor. Default is 16000.
    num_workers : int, optional
        Number of workers for the feature dataloader. With ``0`` the data is
        loaded in the main process and ``prefetch_factor`` is ignored.
    prefetch_factor : int, optional
        Prefetch factor for the feature dataloader (only used when
        ``num_workers > 0``).
    output_dir : str, optional
        Directory to save transcription JSON files.
    generate_kwargs : dict, optional
        Extra keyword arguments forwarded to ``model.generate()`` (e.g.
        ``num_beams``, ``length_penalty``).

    Raises
    ------
    ImportError
        If the installed transformers lacks ``CohereAsrForConditionalGeneration``.
    ValueError
        If ``language`` is missing or not supported by Cohere Transcribe.
    RuntimeError
        If fewer transcriptions were produced for a file than it has chunks.
    """
    _require_transformers()

    if language is None:
        raise ValueError(
            "The 'cohere' backend requires an explicit `language` — "
            "CohereAsrForConditionalGeneration does not perform language detection."
        )
    if language not in COHERE_SUPPORTED_LANGUAGES:
        raise ValueError(
            f"Language {language!r} is not supported by Cohere Transcribe. "
            f"Supported: {sorted(COHERE_SUPPORTED_LANGUAGES)}."
        )

    # Merge rather than pass both: a caller-provided ``max_new_tokens`` in
    # ``generate_kwargs`` would otherwise raise a duplicate-keyword TypeError.
    generation_kwargs = {"max_new_tokens": max_new_tokens, **(generate_kwargs or {})}

    # torch's DataLoader rejects an explicit prefetch_factor when
    # num_workers == 0 (single-process loading), so only pass it with workers.
    loader_prefetch_factor = prefetch_factor if num_workers > 0 else None

    for features in tqdm(file_dataloader, desc="Transcribing audio files"):
        slice_dataset = features[0]["dataset"]
        metadata = slice_dataset.metadata
        transcription_texts = []

        feature_dataloader = torch.utils.data.DataLoader(
            slice_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            prefetch_factor=loader_prefetch_factor,
            collate_fn=cohere_transcribe_collate_fn,
        )

        logger.info("Transcribing %s ...", metadata.audio_path)

        for batch in feature_dataloader:
            # The processor is called on the whole batch so it can pad the
            # variable-length raw waveforms consistently.
            inputs = processor(
                batch["audio"],
                sampling_rate=sample_rate,
                return_tensors="pt",
                language=language,
                punctuation=punctuation,
            )
            inputs = inputs.to(model.device, dtype=model.dtype)

            with torch.inference_mode():
                outputs = model.generate(**inputs, **generation_kwargs)

            transcription_texts.extend(
                processor.batch_decode(outputs, skip_special_tokens=True)
            )

        # ``transcription_texts`` is one flat list for the whole file, so it
        # is matched to chunks with a single running position instead of a
        # per-speech index (which restarted at 0 for every speech).
        # NOTE(review): assumes the slice dataset yields chunks speech by
        # speech in the same order as metadata.speeches[*].chunks.
        chunks = [chunk for speech in metadata.speeches for chunk in speech.chunks]
        if len(transcription_texts) < len(chunks):
            raise RuntimeError(
                f"Got {len(transcription_texts)} transcriptions for "
                f"{len(chunks)} chunks in {metadata.audio_path}."
            )
        if len(transcription_texts) > len(chunks):
            logger.warning(
                "Got %d transcriptions for %d chunks in %s; extra texts ignored.",
                len(transcription_texts),
                len(chunks),
                metadata.audio_path,
            )
        for chunk, text in zip(chunks, transcription_texts):
            chunk.text = text.strip()

        output_path = Path(output_dir) / Path(metadata.audio_path).with_suffix(".json")
        output_path.parent.mkdir(parents=True, exist_ok=True)
        save_metadata_json(metadata, output_dir=output_dir)
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def transcribe_collate_fn(batch: list[dict]) -> dict:
    """
    Collate function for transcription.

    Samples that are ``None`` are dropped. The remaining samples' feature
    tensors are concatenated along dim 0, and their metadata is collected
    into parallel lists.

    Parameters
    ----------
    batch : list of dict
        List of samples from the dataset, each with ``"feature"``,
        ``"start_time_global"`` and ``"speech_id"`` entries (or ``None``).

    Returns
    -------
    dict
        Collated batch with 'features', 'start_times', and 'speech_ids'.
    """
    kept = [sample for sample in batch if sample is not None]

    ids = [sample["speech_id"] for sample in kept]
    starts = [sample["start_time_global"] for sample in kept]
    tensors = [sample["feature"] for sample in kept]

    # Concatenate (not stack) so each feature's existing leading batch
    # dimension becomes the batch dimension of the result.
    return {
        "features": torch.cat(tensors, dim=0),
        "start_times": starts,
        "speech_ids": ids,
    }
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def cohere_transcribe_collate_fn(batch: list[dict]) -> dict:
    """
    Collate function for Cohere ASR transcription.

    Raw waveforms are collected into a plain list instead of being stacked.
    The caller then runs Cohere's processor on the whole batch, which is
    required because per-sample processor calls return variable-length
    features that cannot be stacked. ``None`` samples are skipped.

    Parameters
    ----------
    batch : list of dict
        List of samples from the dataset, each with ``"audio"`` (raw waveform),
        ``"start_time_global"`` and ``"speech_id"`` entries (or ``None``).

    Returns
    -------
    dict
        Collated batch with ``'audio'`` (list of waveforms), ``'start_times'``,
        and ``'speech_ids'``.
    """
    waveforms, starts, ids = [], [], []
    for sample in batch:
        if sample is None:
            continue
        waveforms.append(sample["audio"])
        starts.append(sample["start_time_global"])
        ids.append(sample["speech_id"])
    return {"audio": waveforms, "start_times": starts, "speech_ids": ids}
|
|
@@ -48,12 +48,14 @@ class StreamingAudioSliceDataset(Dataset):
|
|
|
48
48
|
processor: Wav2Vec2Processor | WhisperProcessor,
|
|
49
49
|
sample_rate: int = 16000,
|
|
50
50
|
metadata: AudioMetadata | None = None,
|
|
51
|
+
return_raw_audio: bool = False,
|
|
51
52
|
):
|
|
52
53
|
self.audio_path = str(audio_path)
|
|
53
54
|
self.chunk_specs = chunk_specs
|
|
54
55
|
self.processor = processor
|
|
55
56
|
self.sample_rate = sample_rate
|
|
56
57
|
self.metadata = metadata
|
|
58
|
+
self.return_raw_audio = return_raw_audio
|
|
57
59
|
self.processor_attribute = (
|
|
58
60
|
"input_values" if isinstance(processor, Wav2Vec2Processor) else "input_features"
|
|
59
61
|
)
|
|
@@ -75,6 +77,14 @@ class StreamingAudioSliceDataset(Dataset):
|
|
|
75
77
|
sample_rate=self.sample_rate,
|
|
76
78
|
)
|
|
77
79
|
|
|
80
|
+
if self.return_raw_audio:
|
|
81
|
+
# Caller (e.g. cohere backend) preprocesses with the batch to handle padding.
|
|
82
|
+
return {
|
|
83
|
+
"audio": audio,
|
|
84
|
+
"start_time_global": start_sec,
|
|
85
|
+
"speech_id": spec["speech_id"],
|
|
86
|
+
}
|
|
87
|
+
|
|
78
88
|
# Convert to tensor and add batch dimension for processor
|
|
79
89
|
if isinstance(self.processor, Wav2Vec2Processor):
|
|
80
90
|
audio = torch.tensor(audio).unsqueeze(0)
|
|
@@ -165,6 +175,7 @@ class StreamingAudioFileDataset(Dataset):
|
|
|
165
175
|
sample_rate: int = 16000,
|
|
166
176
|
chunk_size: int = 30,
|
|
167
177
|
alignment_strategy: str = "chunk",
|
|
178
|
+
return_raw_audio: bool = False,
|
|
168
179
|
):
|
|
169
180
|
if isinstance(metadata, AudioMetadata):
|
|
170
181
|
self.metadata = [metadata]
|
|
@@ -176,6 +187,7 @@ class StreamingAudioFileDataset(Dataset):
|
|
|
176
187
|
self.chunk_size = chunk_size
|
|
177
188
|
self.processor = processor
|
|
178
189
|
self.alignment_strategy = alignment_strategy
|
|
190
|
+
self.return_raw_audio = return_raw_audio
|
|
179
191
|
|
|
180
192
|
def _get_speech_chunk_specs(self, metadata: AudioMetadata) -> list[dict]:
|
|
181
193
|
"""
|
|
@@ -281,6 +293,7 @@ class StreamingAudioFileDataset(Dataset):
|
|
|
281
293
|
processor=self.processor,
|
|
282
294
|
sample_rate=self.sr,
|
|
283
295
|
metadata=metadata,
|
|
296
|
+
return_raw_audio=self.return_raw_audio,
|
|
284
297
|
)
|
|
285
298
|
|
|
286
299
|
return {
|
|
@@ -5,11 +5,7 @@ import ctranslate2
|
|
|
5
5
|
import torch
|
|
6
6
|
from easyaligner.data.collators import audiofile_collate_fn, metadata_collate_fn
|
|
7
7
|
from easyaligner.data.datamodel import SpeechSegment
|
|
8
|
-
from easyaligner.data.dataset import
|
|
9
|
-
AudioFileDataset,
|
|
10
|
-
JSONMetadataDataset,
|
|
11
|
-
StreamingAudioFileDataset,
|
|
12
|
-
)
|
|
8
|
+
from easyaligner.data.dataset import AudioFileDataset, JSONMetadataDataset
|
|
13
9
|
from easyaligner.pipelines import alignment_pipeline, emissions_pipeline, vad_pipeline
|
|
14
10
|
from easyaligner.vad.pyannote import load_vad_model as load_pyannote_vad_model
|
|
15
11
|
from easyaligner.vad.silero import load_vad_model as load_silero_vad_model
|
|
@@ -22,15 +18,30 @@ from transformers import (
|
|
|
22
18
|
|
|
23
19
|
from easytranscriber.asr.ct2 import transcribe as ct2_transcribe
|
|
24
20
|
from easytranscriber.asr.hf import transcribe as hf_transcribe
|
|
21
|
+
from easytranscriber.data.dataset import StreamingAudioFileDataset
|
|
25
22
|
from easytranscriber.text.normalization import text_normalizer
|
|
26
23
|
from easytranscriber.utils import hf_to_ct2_converter
|
|
27
24
|
|
|
28
25
|
logger = logging.getLogger(__name__)
|
|
29
26
|
|
|
27
|
+
|
|
28
|
+
def _load_cohere_transcribe():
|
|
29
|
+
"""Lazy loader for the cohere backend.
|
|
30
|
+
|
|
31
|
+
Defers the ``CohereAsrForConditionalGeneration`` import (which requires
|
|
32
|
+
transformers>=5.4.0) so that users on older transformers who only use
|
|
33
|
+
the ct2/hf backends can still import this module.
|
|
34
|
+
"""
|
|
35
|
+
from easytranscriber.asr.cohere import transcribe as cohere_transcribe
|
|
36
|
+
|
|
37
|
+
return cohere_transcribe
|
|
38
|
+
|
|
39
|
+
|
|
30
40
|
# Dispatch mapping from backend name to its transcription entry point.
# NOTE: "cohere" maps to a zero-argument loader that returns the actual
# transcribe function, so callers must call it once before use.
TRANSCRIBE_BACKENDS = dict(
    ct2=ct2_transcribe,
    hf=hf_transcribe,
    cohere=_load_cohere_transcribe,
)
|
|
35
46
|
|
|
36
47
|
VAD_BACKENDS = {
|
|
@@ -56,10 +67,13 @@ def pipeline(
|
|
|
56
67
|
task: str = "transcribe",
|
|
57
68
|
beam_size: int = 5,
|
|
58
69
|
max_length: int = 250,
|
|
70
|
+
max_new_tokens: int = 256,
|
|
59
71
|
repetition_penalty: float = 1.0,
|
|
60
72
|
length_penalty: float = 1.0,
|
|
61
73
|
patience: float = 1.0,
|
|
62
74
|
no_repeat_ngram_size: int = 0,
|
|
75
|
+
punctuation: bool = True,
|
|
76
|
+
generate_kwargs: dict | None = None,
|
|
63
77
|
start_wildcard: bool = False,
|
|
64
78
|
end_wildcard: bool = False,
|
|
65
79
|
blank_id: int | None = None,
|
|
@@ -103,7 +117,9 @@ def pipeline(
|
|
|
103
117
|
speeches : list[list[SpeechSegment]], optional
|
|
104
118
|
Existing speech segments for alignment.
|
|
105
119
|
backend : str, optional
|
|
106
|
-
Backend to use for the transcription model: "ct2" or "
|
|
120
|
+
Backend to use for the transcription model: "ct2", "hf", or "cohere". Default is "ct2".
|
|
121
|
+
The "cohere" backend requires `transformers>=5.4.0`, `streaming=True`, and an explicit
|
|
122
|
+
`language` (Cohere has no language detection).
|
|
107
123
|
sample_rate : int, optional
|
|
108
124
|
Sample rate.
|
|
109
125
|
chunk_size : int, optional
|
|
@@ -127,7 +143,14 @@ def pipeline(
|
|
|
127
143
|
repetition_penalty : float, optional
|
|
128
144
|
See HF [source code](https://github.com/huggingface/transformers/blob/v4.57.5/src/transformers/generation/configuration_utils.py#L188-L190) for details.
|
|
129
145
|
max_length : int, optional
|
|
130
|
-
Maximum length of generated text.
|
|
146
|
+
Maximum length of generated text. Applies to Whisper backends (ct2, hf).
|
|
147
|
+
max_new_tokens : int, optional
|
|
148
|
+
Maximum number of new tokens to generate per chunk. Applies to the cohere backend.
|
|
149
|
+
punctuation : bool, optional
|
|
150
|
+
Emit punctuation in Cohere transcriptions. Applies to the cohere backend only.
|
|
151
|
+
generate_kwargs : dict, optional
|
|
152
|
+
Extra kwargs forwarded to ``model.generate()`` for the cohere backend
|
|
153
|
+
(e.g. ``num_beams``, ``length_penalty``).
|
|
131
154
|
start_wildcard : bool, optional
|
|
132
155
|
Add start wildcard to forced alignment.
|
|
133
156
|
end_wildcard : bool, optional
|
|
@@ -244,32 +267,68 @@ def pipeline(
|
|
|
244
267
|
)
|
|
245
268
|
|
|
246
269
|
# Step 2: Run Transcription
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
270
|
+
dataset_kwargs: dict = {}
|
|
271
|
+
|
|
272
|
+
if backend == "cohere":
|
|
273
|
+
if language is None:
|
|
274
|
+
raise ValueError(
|
|
275
|
+
"The 'cohere' backend requires an explicit `language` — "
|
|
276
|
+
"CohereAsrForConditionalGeneration does not perform language detection."
|
|
277
|
+
)
|
|
278
|
+
if not streaming:
|
|
279
|
+
raise ValueError(
|
|
280
|
+
"The 'cohere' backend requires `streaming=True` "
|
|
281
|
+
"(the non-streaming AudioFileDataset does not support return_raw_audio)."
|
|
282
|
+
)
|
|
283
|
+
|
|
284
|
+
transcription_args = {
|
|
285
|
+
"language": language,
|
|
286
|
+
"max_new_tokens": max_new_tokens,
|
|
287
|
+
"punctuation": punctuation,
|
|
288
|
+
"sample_rate": sample_rate,
|
|
289
|
+
"generate_kwargs": generate_kwargs,
|
|
290
|
+
}
|
|
291
|
+
|
|
292
|
+
from transformers import AutoProcessor, CohereAsrForConditionalGeneration
|
|
293
|
+
|
|
294
|
+
logger.info(f"Loading Cohere ASR model from {transcription_model}...")
|
|
295
|
+
model = (
|
|
296
|
+
CohereAsrForConditionalGeneration.from_pretrained(
|
|
297
|
+
transcription_model, torch_dtype=torch.float16, cache_dir=cache_dir
|
|
298
|
+
)
|
|
299
|
+
.to(device)
|
|
300
|
+
.eval()
|
|
265
301
|
)
|
|
302
|
+
processor = AutoProcessor.from_pretrained(transcription_model, cache_dir=cache_dir)
|
|
303
|
+
dataset_kwargs = {"return_raw_audio": True}
|
|
266
304
|
else:
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
305
|
+
transcription_args = {
|
|
306
|
+
"language": language,
|
|
307
|
+
"task": task,
|
|
308
|
+
"beam_size": beam_size,
|
|
309
|
+
"max_length": max_length,
|
|
310
|
+
"repetition_penalty": repetition_penalty,
|
|
311
|
+
"length_penalty": length_penalty,
|
|
312
|
+
}
|
|
313
|
+
|
|
314
|
+
if backend == "ct2":
|
|
315
|
+
model_path = hf_to_ct2_converter(transcription_model, cache_dir=cache_dir)
|
|
316
|
+
logger.info(f"Loading CTranslate2 model from {model_path}...")
|
|
317
|
+
model = ctranslate2.models.Whisper(model_path.as_posix(), device=device)
|
|
318
|
+
transcription_args.update(
|
|
319
|
+
{
|
|
320
|
+
"patience": patience,
|
|
321
|
+
"no_repeat_ngram_size": no_repeat_ngram_size,
|
|
322
|
+
}
|
|
323
|
+
)
|
|
324
|
+
else:
|
|
325
|
+
logger.info(f"Loading Hugging Face model from {transcription_model}...")
|
|
326
|
+
model = WhisperForConditionalGeneration.from_pretrained(
|
|
327
|
+
transcription_model, torch_dtype=torch.float16, cache_dir=cache_dir
|
|
328
|
+
).to(device)
|
|
329
|
+
|
|
330
|
+
processor = WhisperProcessor.from_pretrained(transcription_model, cache_dir=cache_dir)
|
|
271
331
|
|
|
272
|
-
processor = WhisperProcessor.from_pretrained(transcription_model, cache_dir=cache_dir)
|
|
273
332
|
json_dataset = JSONMetadataDataset(
|
|
274
333
|
json_paths=[str(Path(output_vad_dir) / p) for p in json_paths]
|
|
275
334
|
)
|
|
@@ -281,6 +340,7 @@ def pipeline(
|
|
|
281
340
|
sample_rate=sample_rate,
|
|
282
341
|
chunk_size=chunk_size,
|
|
283
342
|
alignment_strategy="chunk",
|
|
343
|
+
**dataset_kwargs,
|
|
284
344
|
)
|
|
285
345
|
|
|
286
346
|
file_dataloader = torch.utils.data.DataLoader(
|
|
@@ -293,6 +353,9 @@ def pipeline(
|
|
|
293
353
|
)
|
|
294
354
|
|
|
295
355
|
transcribe = TRANSCRIBE_BACKENDS[backend]
|
|
356
|
+
if backend == "cohere":
|
|
357
|
+
transcribe = transcribe() # lazy-load to avoid importing on older transformers
|
|
358
|
+
|
|
296
359
|
transcribe(
|
|
297
360
|
model=model,
|
|
298
361
|
processor=processor,
|
|
@@ -10,6 +10,7 @@ src/easytranscriber.egg-info/dependency_links.txt
|
|
|
10
10
|
src/easytranscriber.egg-info/entry_points.txt
|
|
11
11
|
src/easytranscriber.egg-info/requires.txt
|
|
12
12
|
src/easytranscriber.egg-info/top_level.txt
|
|
13
|
+
src/easytranscriber/asr/cohere.py
|
|
13
14
|
src/easytranscriber/asr/ct2.py
|
|
14
15
|
src/easytranscriber/asr/hf.py
|
|
15
16
|
src/easytranscriber/data/__init__.py
|
|
@@ -1,30 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
def transcribe_collate_fn(batch: list[dict]) -> dict:
|
|
5
|
-
"""
|
|
6
|
-
Collate function for transcription.
|
|
7
|
-
|
|
8
|
-
Parameters
|
|
9
|
-
----------
|
|
10
|
-
batch : list of dict
|
|
11
|
-
List of samples from the dataset.
|
|
12
|
-
|
|
13
|
-
Returns
|
|
14
|
-
-------
|
|
15
|
-
dict
|
|
16
|
-
Collated batch with 'features', 'start_times', and 'speech_ids'.
|
|
17
|
-
"""
|
|
18
|
-
# Remove None values
|
|
19
|
-
speech_ids = [b["speech_id"] for b in batch if b is not None]
|
|
20
|
-
start_times = [b["start_time_global"] for b in batch if b is not None]
|
|
21
|
-
batch = [b["feature"] for b in batch if b is not None]
|
|
22
|
-
|
|
23
|
-
# Concat, keep batch dimension
|
|
24
|
-
batch = torch.cat(batch, dim=0)
|
|
25
|
-
|
|
26
|
-
return {
|
|
27
|
-
"features": batch,
|
|
28
|
-
"start_times": start_times,
|
|
29
|
-
"speech_ids": speech_ids,
|
|
30
|
-
}
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
{easytranscriber-0.2.1 → easytranscriber-0.2.2}/src/easytranscriber.egg-info/entry_points.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|