pelican-nlp 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. pelican_nlp/Nils_backup/__init__.py +0 -0
  2. pelican_nlp/Nils_backup/extract_acoustic_features.py +274 -0
  3. pelican_nlp/Nils_backup/fluency/__init__.py +0 -0
  4. pelican_nlp/Nils_backup/fluency/aggregate_fluency_results.py +186 -0
  5. pelican_nlp/Nils_backup/fluency/behavioral_data.py +42 -0
  6. pelican_nlp/Nils_backup/fluency/check_duplicates.py +169 -0
  7. pelican_nlp/Nils_backup/fluency/coherence.py +653 -0
  8. pelican_nlp/Nils_backup/fluency/config.py +231 -0
  9. pelican_nlp/Nils_backup/fluency/main.py +182 -0
  10. pelican_nlp/Nils_backup/fluency/optimality_without_tsa.py +466 -0
  11. pelican_nlp/Nils_backup/fluency/plot_fluency.py +573 -0
  12. pelican_nlp/Nils_backup/fluency/plotting_utils.py +170 -0
  13. pelican_nlp/Nils_backup/fluency/questionnaires_data.py +43 -0
  14. pelican_nlp/Nils_backup/fluency/stats_fluency.py +930 -0
  15. pelican_nlp/Nils_backup/fluency/utils.py +41 -0
  16. pelican_nlp/Nils_backup/speaker_diarization_Nils.py +328 -0
  17. pelican_nlp/Nils_backup/transcription/__init__.py +0 -0
  18. pelican_nlp/Nils_backup/transcription/annotation_tool.py +1001 -0
  19. pelican_nlp/Nils_backup/transcription/annotation_tool_boundaries.py +1122 -0
  20. pelican_nlp/Nils_backup/transcription/annotation_tool_sandbox.py +985 -0
  21. pelican_nlp/Nils_backup/transcription/output/holmes_control_nova_all_outputs.json +7948 -0
  22. pelican_nlp/Nils_backup/transcription/test.json +1 -0
  23. pelican_nlp/Nils_backup/transcription/transcribe_audio.py +314 -0
  24. pelican_nlp/Nils_backup/transcription/transcribe_audio_chunked.py +695 -0
  25. pelican_nlp/Nils_backup/transcription/transcription.py +801 -0
  26. pelican_nlp/Nils_backup/transcription/transcription_gui.py +955 -0
  27. pelican_nlp/Nils_backup/transcription/word_boundaries.py +190 -0
  28. pelican_nlp/Silvia_files/Opensmile/opensmile_feature_extraction.py +66 -0
  29. pelican_nlp/Silvia_files/prosogram/prosogram.py +104 -0
  30. pelican_nlp/__init__.py +1 -1
  31. pelican_nlp/_version.py +1 -0
  32. pelican_nlp/configuration_files/config_audio.yml +150 -0
  33. pelican_nlp/configuration_files/config_discourse.yml +104 -0
  34. pelican_nlp/configuration_files/config_fluency.yml +108 -0
  35. pelican_nlp/configuration_files/config_general.yml +131 -0
  36. pelican_nlp/configuration_files/config_morteza.yml +103 -0
  37. pelican_nlp/praat/__init__.py +29 -0
  38. {pelican_nlp-0.1.0.dist-info → pelican_nlp-0.1.2.dist-info}/METADATA +14 -21
  39. pelican_nlp-0.1.2.dist-info/RECORD +75 -0
  40. pelican_nlp-0.1.0.dist-info/RECORD +0 -39
  41. {pelican_nlp-0.1.0.dist-info → pelican_nlp-0.1.2.dist-info}/WHEEL +0 -0
  42. {pelican_nlp-0.1.0.dist-info → pelican_nlp-0.1.2.dist-info}/licenses/LICENSE +0 -0
  43. {pelican_nlp-0.1.0.dist-info → pelican_nlp-0.1.2.dist-info}/top_level.txt +0 -0
pelican_nlp/Nils_backup/transcription/transcription.py
@@ -0,0 +1,801 @@
+# Standard Library Imports
+import io
+import re
+import unicodedata
+from typing import List, Dict
+from pathlib import Path
+import json
+
+# Third-party Library Imports
+import librosa
+import numpy as np
+import soundfile as sf
+import torch
+import torchaudio
+import torchaudio.transforms as T
+from pydub import AudioSegment
+from pydub.silence import detect_silence
+from transformers import pipeline
+from pyannote.audio import Pipeline as DiarizationPipeline
+import uroman as ur
+import pandas as pd
+
+
+class Chunk:
+    def __init__(self, audio_segment: AudioSegment, start_time: float):
+        """
+        Represents a chunk of audio.
+
+        :param audio_segment: The audio segment.
+        :param start_time: Start time in the original audio (seconds).
+        """
+        self.audio_segment = audio_segment
+        self.start_time = start_time  # Start time in seconds
+        self.transcript = ""
+        self.whisper_alignments = []
+        self.forced_alignments = []
+
+
+class AudioFile:
+    def __init__(self, file_path: str, target_rms_db: float = -20):
+        """
+        Handles all operations related to an audio file.
+
+        :param file_path: Path to the audio file.
+        :param target_rms_db: Target RMS in dB for normalization.
+        """
+        self.file_path = file_path
+        self.target_rms_db = target_rms_db
+        self.normalized_path = None
+        self.audio = None
+        self.sample_rate = None
+        self.chunks: List[Chunk] = []
+        self.speaker_segments = []
+
+        self.metadata = {
+            "file_path": file_path,
+            "length_seconds": None,
+            "sample_rate": None,
+            "target_rms_db": target_rms_db,
+            "models_used": {}
+        }
+
+        self.load_audio()
+
+    def load_audio(self):
+        """
+        Loads the audio file using librosa.
+        """
+        self.audio, self.sample_rate = librosa.load(self.file_path, sr=None)
+        self.metadata["sample_rate"] = self.sample_rate
+        print(f"Loaded audio file: {self.file_path}")
+
+    def register_model(self, model_name: str, parameters: dict):
+        """
+        Registers a model and its parameters in the metadata.
+
+        :param model_name: Name of the model.
+        :param parameters: Parameters used for the model.
+        """
+        self.metadata["models_used"][model_name] = parameters
+
+    def rms_normalization(self):
+        """
+        Normalizes the audio to the target RMS level and saves it.
+        """
+        target_rms = 10 ** (self.target_rms_db / 20)
+        rms = np.sqrt(np.mean(self.audio ** 2))
+        gain = target_rms / rms
+        normalized_audio = self.audio * gain
+        self.normalized_path = self.file_path.replace(".wav", "_normalized.wav")
+        sf.write(self.normalized_path, normalized_audio, self.sample_rate)
+        print(f"Normalized audio saved as: {self.normalized_path}")
+
+    def split_on_silence(self, min_silence_len=1000, silence_thresh=-30,
+                         min_length=30000, max_length=180000):
+        """
+        Splits the audio into chunks based on silence.
+
+        :param min_silence_len: Minimum length of silence to be used for a split (ms).
+        :param silence_thresh: Silence threshold in dBFS.
+        :param min_length: Minimum length of a chunk (ms).
+        :param max_length: Maximum length of a chunk (ms).
+        """
+        audio_segment = AudioSegment.from_file(self.normalized_path)
+        audio_length_ms = len(audio_segment)
+        self.metadata["length_seconds"] = audio_length_ms / 1000
+        silence_ranges = self._detect_silence_intervals(audio_segment, min_silence_len, silence_thresh)
+        splitting_points = self._get_splitting_points(silence_ranges, audio_length_ms)
+        initial_intervals = self._create_initial_chunks(splitting_points)
+        adjusted_intervals = self._adjust_intervals_by_length(initial_intervals, min_length, max_length)
+        chunks_with_timestamps = self._split_audio_by_intervals(audio_segment, adjusted_intervals)
+
+        self.chunks = [Chunk(chunk_audio, start_i / 1000.0) for chunk_audio, start_i, end_i in chunks_with_timestamps]
+        print(f"Total chunks after splitting: {len(self.chunks)}")
+
+        # Validate the combined length of chunks
+        self.validate_chunk_lengths(audio_length_ms)
+
+        self.register_model("Chunking", {
+            "min_silence_len": min_silence_len,
+            "silence_thresh": silence_thresh,
+            "min_length": min_length,
+            "max_length": max_length,
+            "num_chunks": len(self.chunks)
+        })
+
+
+    def _detect_silence_intervals(self, audio_segment: AudioSegment, min_silence_len: int, silence_thresh: int) -> List[List[int]]:
+        """
+        Detects silent intervals in the audio segment.
+
+        :param audio_segment: The audio segment.
+        :param min_silence_len: Minimum length of silence to be used for a split (ms).
+        :param silence_thresh: Silence threshold in dBFS.
+        :return: List of [start_ms, end_ms] pairs representing silence periods.
+        """
+        return detect_silence(audio_segment, min_silence_len=min_silence_len, silence_thresh=silence_thresh)
+
+    def _get_splitting_points(self, silence_ranges: List[List[int]], audio_length_ms: int) -> List[int]:
+        """
+        Computes splitting points based on silence ranges.
+
+        :param silence_ranges: List of silence intervals.
+        :param audio_length_ms: Total length of the audio in ms.
+        :return: Sorted list of splitting points in ms.
+        """
+        splitting_points = [0] + [(start + end) // 2 for start, end in silence_ranges] + [audio_length_ms]
+        return splitting_points
+
+    def _create_initial_chunks(self, splitting_points: List[int]) -> List[tuple]:
+        """
+        Creates initial chunks based on splitting points.
+
+        :param splitting_points: List of splitting points in ms.
+        :return: List of (start_ms, end_ms) tuples.
+        """
+        return list(zip(splitting_points[:-1], splitting_points[1:]))
+
+    def _adjust_intervals_by_length(self, intervals: List[tuple], min_length: int, max_length: int) -> List[tuple]:
+        """
+        Adjusts intervals based on minimum and maximum length constraints.
+
+        :param intervals: List of (start_ms, end_ms) tuples.
+        :param min_length: Minimum length of a chunk (ms).
+        :param max_length: Maximum length of a chunk (ms).
+        :return: Adjusted list of intervals.
+        """
+        adjusted_intervals = []
+        buffer_start, buffer_end = intervals[0]
+
+        for start, end in intervals[1:]:
+            buffer_end = end
+            buffer_length = buffer_end - buffer_start
+
+            if buffer_length < min_length:
+                # Merge with the next interval by extending the buffer
+                continue
+            else:
+                if buffer_length > max_length:
+                    # Split the buffer into multiple chunks of `max_length`
+                    num_splits = int(np.ceil(buffer_length / max_length))
+                    split_size = int(np.ceil(buffer_length / num_splits))
+                    for i in range(num_splits):
+                        split_start = buffer_start + i * split_size
+                        split_end = min(buffer_start + (i + 1) * split_size, buffer_end)
+                        adjusted_intervals.append((split_start, split_end))
+                else:
+                    # Add the buffer as a valid interval
+                    adjusted_intervals.append((buffer_start, buffer_end))
+                buffer_start = buffer_end  # Reset buffer_start to the end of the current buffer
+
+        # Handle any remaining buffer (final chunk)
+        buffer_length = buffer_end - buffer_start
+        if buffer_length > 0:
+            if buffer_length >= min_length:
+                # Include the final chunk if it's greater than `min_length`
+                adjusted_intervals.append((buffer_start, buffer_end))
+            else:
+                # Optionally include shorter chunks
+                print(f"Final chunk is shorter than min_length ({buffer_length} ms), including it anyway.")
+                adjusted_intervals.append((buffer_start, buffer_end))
+
+        return adjusted_intervals
+
+    def validate_chunk_lengths(self, audio_length_ms: int, tolerance: float = 1.0):
+        """
+        Validates that the combined length of all chunks matches the original audio length.
+
+        :param audio_length_ms: Length of the original audio in milliseconds.
+        :param tolerance: Allowed tolerance in milliseconds.
+        """
+        # Sum up the duration of all chunks
+        combined_length = sum(len(chunk.audio_segment) for chunk in self.chunks)
+
+        # Calculate the difference
+        difference = abs(combined_length - audio_length_ms)
+        if difference > tolerance:
+            raise AssertionError(
+                f"Chunk lengths validation failed! Combined chunk length ({combined_length} ms) "
+                f"differs from original audio length ({audio_length_ms} ms) by {difference} ms, "
+                f"which exceeds the allowed tolerance of {tolerance} ms."
+            )
+        print(f"Chunk length validation passed: Total chunks = {combined_length} ms, Original = {audio_length_ms} ms.")
+
+    def _split_audio_by_intervals(self, audio_segment: AudioSegment, intervals: List[tuple]) -> List[tuple]:
+        """
+        Splits the audio segment into chunks based on the provided intervals.
+
+        :param audio_segment: The audio segment.
+        :param intervals: List of (start_ms, end_ms) tuples.
+        :return: List of (chunk_audio, start_ms, end_ms) tuples.
+        """
+        return [(audio_segment[start_ms:end_ms], start_ms, end_ms) for start_ms, end_ms in intervals]
+
+    def combine_chunks(self):
+        """
+        Combines transcripts and alignments from all chunks into file-level
+        attributes (transcript_text, whisper_alignments, forced_alignments).
+        """
+        self.transcript_text = " ".join([chunk.transcript for chunk in self.chunks])
+        self.whisper_alignments = []
+        self.forced_alignments = []
+        for chunk in self.chunks:
+            self.whisper_alignments.extend(chunk.whisper_alignments)
+            self.forced_alignments.extend(chunk.forced_alignments)
+        print("Combined transcripts and alignments from all chunks into Transcript.")
+
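For context, the chunking helpers above reduce to a small amount of interval arithmetic: each silence range contributes its midpoint as a splitting point, and consecutive points become candidate chunks. A standalone sketch with hypothetical silence ranges (all values in milliseconds):

# Hypothetical silence ranges of the kind detect_silence would return (ms).
silence_ranges = [[4000, 5500], [11000, 12500]]
audio_length_ms = 20000

# Midpoint of each silence becomes a splitting point, framed by the two ends.
splitting_points = [0] + [(start + end) // 2 for start, end in silence_ranges] + [audio_length_ms]
initial_intervals = list(zip(splitting_points[:-1], splitting_points[1:]))
print(initial_intervals)  # [(0, 4750), (4750, 11750), (11750, 20000)]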
+class Transcript:
+    def __init__(self, audio_file: AudioFile = None, json_data: dict = None):
+        """
+        Initializes the Transcript class.
+
+        :param audio_file: AudioFile object to initialize from.
+        :param json_data: Dictionary loaded from a JSON file.
+        """
+        if audio_file:
+            self.audio_file_path = audio_file.file_path
+            self.transcript_text = audio_file.transcript_text
+            self.whisper_alignments = audio_file.whisper_alignments
+            self.forced_alignments = audio_file.forced_alignments
+            self.speaker_segments = audio_file.speaker_segments
+            self.combined_data = []
+            self.combined_utterances = []
+            self.metadata = audio_file.metadata
+        elif json_data:
+            self.audio_file_path = json_data["audio_file_path"]
+            self.metadata = json_data["metadata"]
+            self.transcript_text = json_data.get("transcript_text", "")
+            self.whisper_alignments = json_data.get("whisper_alignments", [])
+            self.forced_alignments = json_data.get("forced_alignments", [])
+            self.speaker_segments = json_data.get("speaker_segments", [])
+            self.combined_data = json_data.get("combined_data", [])
+            self.combined_utterances = json_data.get("utterance_data", [])
+        else:
+            raise ValueError("Either an AudioFile object or JSON data must be provided.")
+
+    @classmethod
+    def from_json_file(cls, json_file: str):
+        """
+        Creates a Transcript instance from a JSON file.
+
+        :param json_file: Path to the JSON file.
+        :return: Transcript instance.
+        """
+        try:
+            with open(json_file, "r", encoding="utf-8") as f:
+                json_data = json.load(f)
+            print(f"Loaded transcript data from '{json_file}'.")
+            return cls(json_data=json_data)
+        except Exception as e:
+            print(f"Error loading JSON file: {e}")
+            raise
+
+    def aggregate_to_utterances(self):
+        """
+        Aggregates word-level data into utterances based on sentence endings.
+        """
+        if not self.combined_data:
+            print("No combined data available to aggregate.")
+            return
+
+        utterances = []
+        current_utterance = {
+            "text": "",
+            "start_time": None,
+            "end_time": None,
+            "speakers": {}
+        }
+
+        sentence_endings = re.compile(r'[.?!]$')
+        print("Aggregating words into utterances...")
+        for word_data in self.combined_data:
+            word = word_data["word"]
+            start_time = word_data["start_time"]
+            end_time = word_data["end_time"]
+            speaker = word_data["speaker"]
+
+            if current_utterance["start_time"] is None:
+                current_utterance["start_time"] = start_time
+
+            current_utterance["text"] += ("" if current_utterance["text"] == "" else " ") + word
+            current_utterance["end_time"] = end_time
+
+            if speaker not in current_utterance["speakers"]:
+                current_utterance["speakers"][speaker] = 0
+            current_utterance["speakers"][speaker] += 1
+
+            if sentence_endings.search(word):
+                majority_speaker, majority_count = max(
+                    current_utterance["speakers"].items(), key=lambda item: item[1]
+                )
+                total_words = sum(current_utterance["speakers"].values())
+                confidence = round(majority_count / total_words, 2)
+
+                utterances.append({
+                    "text": current_utterance["text"],
+                    "start_time": current_utterance["start_time"],
+                    "end_time": current_utterance["end_time"],
+                    "speaker": majority_speaker,
+                    "confidence": confidence,
+                })
+
+                current_utterance = {
+                    "text": "",
+                    "start_time": None,
+                    "end_time": None,
+                    "speakers": {}
+                }
+
+        # Handle any remaining words as the last utterance
+        if current_utterance["text"]:
+            majority_speaker, majority_count = max(
+                current_utterance["speakers"].items(), key=lambda item: item[1]
+            )
+            total_words = sum(current_utterance["speakers"].values())
+            confidence = round(majority_count / total_words, 2)
+
+            utterances.append({
+                "text": current_utterance["text"],
+                "start_time": current_utterance["start_time"],
+                "end_time": current_utterance["end_time"],
+                "speaker": majority_speaker,
+                "confidence": confidence,
+            })
+
+        self.combined_utterances = utterances
+        print("Aggregated utterances from combined data.")
+
+    def combine_alignment_and_diarization(self, alignment_source: str):
+        """
+        Combines alignment and diarization data by assigning a speaker label to
+        each word, based on its overlap with the segments in self.speaker_segments.
+
+        :param alignment_source: The alignment data to use ('whisper_alignments' or 'forced_alignments').
+        """
+        if alignment_source not in ['whisper_alignments', 'forced_alignments']:
+            raise ValueError("Invalid alignment_source. Choose 'whisper_alignments' or 'forced_alignments'.")
+
+        alignment = getattr(self, alignment_source, None)
+        if alignment is None:
+            raise ValueError(f"The alignment source '{alignment_source}' does not exist in the Transcript object.")
+
+        if not self.speaker_segments:
+            print("No speaker segments available for diarization. All words will be labeled as 'UNKNOWN'.")
+            self.combined_data = [{**word, 'speaker': 'UNKNOWN'} for word in alignment]
+            return
+
+        combined = []
+        seg_idx = 0
+        num_segments = len(self.speaker_segments)
+
+        for word in alignment:
+            word_start = word['start_time']
+            word_end = word['end_time']
+            word_duration = max(1e-6, word_end - word_start)  # Avoid zero-duration
+
+            speaker_overlap = {}
+
+            # Advance segments that have ended before the word starts
+            while seg_idx < num_segments and self.speaker_segments[seg_idx]['end'] < word_start:
+                seg_idx += 1
+
+            temp_idx = seg_idx
+            while temp_idx < num_segments and self.speaker_segments[temp_idx]['start'] < word_end:
+                seg = self.speaker_segments[temp_idx]
+                seg_start = seg['start']
+                seg_end = seg['end']
+                speaker = seg['speaker']
+
+                if seg_start <= word_start < seg_end:
+                    overlap = word_duration  # Full overlap
+                else:
+                    overlap_start = max(word_start, seg_start)
+                    overlap_end = min(word_end, seg_end)
+                    overlap = max(0.0, overlap_end - overlap_start)
+
+                if overlap > 0:
+                    speaker_overlap[speaker] = speaker_overlap.get(speaker, 0.0) + overlap
+
+                temp_idx += 1
+
+            assigned_speaker = max(speaker_overlap, key=speaker_overlap.get) if speaker_overlap else 'UNKNOWN'
+            word_with_speaker = word.copy()
+            word_with_speaker['speaker'] = assigned_speaker
+            combined.append(word_with_speaker)
+
+        self.combined_data = combined
+        self.metadata["alignment_source"] = alignment_source
+        print(f"Combined alignment and diarization data with {len(self.combined_data)} entries.")
+
+    def save_as_json(self, output_file="all_transcript_data.json"):
+        """
+        Saves all transcript data to a JSON file.
+
+        :param output_file: Path to the output JSON file.
+        """
+        if not self.combined_data:
+            print("No combined data available to save. Ensure 'combine_alignment_and_diarization' is run first.")
+            return
+
+        data = {
+            "audio_file_path": self.audio_file_path,
+            "metadata": self.metadata,
+            "transcript_text": self.transcript_text,
+            "whisper_alignments": self.whisper_alignments,
+            "forced_alignments": self.forced_alignments,
+            "combined_data": self.combined_data,
+            "utterance_data": self.combined_utterances,
+            "speaker_segments": self.speaker_segments
+        }
+
+        try:
+            with open(output_file, "w", encoding="utf-8") as f:
+                json.dump(data, f, indent=4)
+            print(f"All transcript data successfully saved to '{output_file}'.")
+        except Exception as e:
+            print(f"Error saving JSON file: {e}")
+
+
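To illustrate the speaker-assignment rule in combine_alignment_and_diarization, here is a simplified, self-contained sketch with hypothetical timings; it omits the shortcut in the method above that credits a word with its full duration when it starts inside a segment:

# One word and two hypothetical diarization segments (times in seconds).
word = {"word": "hello", "start_time": 3.75, "end_time": 4.5}
speaker_segments = [
    {"start": 0.0, "end": 4.0, "speaker": "SPEAKER_00"},
    {"start": 4.0, "end": 9.0, "speaker": "SPEAKER_01"},
]

# Accumulate the overlap of the word with each speaker's segments.
speaker_overlap = {}
for seg in speaker_segments:
    overlap_start = max(word["start_time"], seg["start"])
    overlap_end = min(word["end_time"], seg["end"])
    overlap = max(0.0, overlap_end - overlap_start)
    if overlap > 0:
        speaker_overlap[seg["speaker"]] = speaker_overlap.get(seg["speaker"], 0.0) + overlap

print(speaker_overlap)                                # {'SPEAKER_00': 0.25, 'SPEAKER_01': 0.5}
print(max(speaker_overlap, key=speaker_overlap.get))  # SPEAKER_01 gets the word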
+class AudioTranscriber:
+    """
+    Handles transcription of audio chunks using Whisper.
+    """
+    def __init__(self, model="openai/whisper-medium"):
+        # Determine device
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            self.device = torch.device("mps")
+        else:
+            self.device = torch.device("cpu")
+        self.model = model
+        # Initialize the Whisper pipeline
+        self.transcriber = pipeline(
+            "automatic-speech-recognition",
+            model=model,
+            device=self.device,
+            return_timestamps="word"  # Ensure word-level timestamps are returned
+        )
+        print(f"Initialized AudioTranscriber on device: {self.device}")
+
+    def transcribe(self, audio_file: AudioFile):
+        """
+        Transcribes each audio chunk and stores the transcript and word-level
+        Whisper alignments on the chunk.
+
+        :param audio_file: AudioFile instance containing audio chunks.
+        """
+        print("Starting transcription of audio chunks...")
+        for idx, chunk in enumerate(audio_file.chunks, start=1):
+            try:
+                with io.BytesIO() as wav_io:
+                    chunk.audio_segment.export(wav_io, format="wav")
+                    wav_io.seek(0)
+                    transcription_result = self.transcriber(wav_io.read())
+
+                # Assign transcript to the chunk
+                chunk.transcript = transcription_result.get('text', "").strip()
+
+                # Extract word alignments
+                raw_chunks = transcription_result.get('chunks', [])
+                clean_chunks = []
+                for word_info in raw_chunks:
+                    if 'timestamp' in word_info and len(word_info['timestamp']) == 2:
+                        start_time = float(word_info['timestamp'][0]) + chunk.start_time
+                        end_time = float(word_info['timestamp'][1]) + chunk.start_time
+                        word_text = word_info.get('text', "").strip()
+                        if word_text:
+                            clean_chunks.append({
+                                "word": word_text,
+                                "start_time": start_time,
+                                "end_time": end_time
+                            })
+                chunk.whisper_alignments = clean_chunks
+                print(f"Transcribed chunk {idx} with {len(clean_chunks)} words.")
+            except Exception as e:
+                print(f"Error during transcription of chunk {idx}: {e}")
+                chunk.transcript = ""
+                chunk.whisper_alignments = []
+
+        audio_file.register_model("Transcription", {
+            "model": self.model,
+            "device": str(self.device)
+        })
+
+
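For orientation, a minimal usage sketch wiring AudioFile and AudioTranscriber together. The import path and the file interview.wav are assumptions, and running it downloads openai/whisper-medium plus the other heavy dependencies imported at the top of the module:

from pelican_nlp.Nils_backup.transcription.transcription import AudioFile, AudioTranscriber

audio = AudioFile("interview.wav", target_rms_db=-20)   # hypothetical input file
audio.rms_normalization()                               # writes interview_normalized.wav
audio.split_on_silence(min_silence_len=1000, silence_thresh=-30,
                       min_length=30000, max_length=180000)

transcriber = AudioTranscriber()                        # Whisper pipeline with word timestamps
transcriber.transcribe(audio)
audio.combine_chunks()
print(audio.transcript_text)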
+class ForcedAligner:
+    """
+    Handles forced alignment of transcripts with audio.
+    """
+    def __init__(self, device: str = None):
+        # Use the requested device if given, otherwise pick one automatically
+        if device:
+            self.device = torch.device(device)
+        elif torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        else:
+            self.device = torch.device("cpu")
+
+        # Initialize forced aligner components
+        self.bundle = torchaudio.pipelines.MMS_FA
+        self.model = self.bundle.get_model().to(self.device)
+        self.tokenizer = self.bundle.get_tokenizer()
+        self.aligner = self.bundle.get_aligner()
+        self.uroman = ur.Uroman()
+        self.sample_rate = self.bundle.sample_rate
+        print(f"Initialized ForcedAligner on device: {self.device}")
+
+    def normalize_uroman(self, text: str) -> str:
+        """
+        Normalizes text using Uroman.
+
+        :param text: Text to normalize.
+        :return: Normalized text.
+        """
+        text = text.encode('utf-8').decode('utf-8')
+        text = text.lower()
+        text = text.replace("’", "'")
+        text = unicodedata.normalize('NFC', text)
+        text = re.sub("([^a-z' ])", " ", text)
+        text = re.sub(' +', ' ', text)
+        return text.strip()
+
+    def align(self, audio_file: AudioFile):
+        """
+        Performs forced alignment for each chunk and stores word-level
+        alignments on the chunk.
+
+        :param audio_file: AudioFile instance containing audio chunks.
+        """
+        print("Starting forced alignment of transcripts...")
+        for idx, chunk in enumerate(audio_file.chunks, start=1):
+            try:
+                with io.BytesIO() as wav_io:
+                    chunk.audio_segment.export(wav_io, format="wav")
+                    wav_io.seek(0)
+                    waveform, sample_rate = torchaudio.load(wav_io)
+
+                # Resample if necessary
+                if sample_rate != self.sample_rate:
+                    resampler = T.Resample(orig_freq=sample_rate, new_freq=self.sample_rate)
+                    waveform = resampler(waveform)
+                    sample_rate = self.sample_rate
+
+                # Normalize and tokenize the transcript
+                text_roman = self.uroman.romanize_string(chunk.transcript)
+                text_normalized = self.normalize_uroman(text_roman)
+                transcript_list = text_normalized.split()
+                tokens = self.tokenizer(transcript_list)
+
+                # Perform forced alignment
+                with torch.inference_mode():
+                    emission, _ = self.model(waveform.to(self.device))
+                    token_spans = self.aligner(emission[0], tokens)
+
+                # Extract timestamps
+                num_frames = emission.size(1)
+                ratio = waveform.size(1) / num_frames
+                for spans, word in zip(token_spans, transcript_list):
+                    start_sec = (spans[0].start * ratio / sample_rate) + chunk.start_time
+                    end_sec = (spans[-1].end * ratio / sample_rate) + chunk.start_time
+                    chunk.forced_alignments.append({
+                        "word": word,
+                        "start_time": start_sec,
+                        "end_time": end_sec
+                    })
+                print(f"Aligned chunk {idx} successfully.")
+            except Exception as e:
+                print(f"Error during alignment of chunk {idx}: {e}")
+
+        audio_file.register_model("Forced Alignment", {
+            "model": "torchaudio.pipelines.MMS_FA",
+            "device": str(self.device)
+        })
+
+
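The romanized transcript is reduced to lowercase a-z, apostrophes, and spaces before tokenization. A standalone sketch of those normalization rules, with a hypothetical input string:

import re
import unicodedata

def normalize(text: str) -> str:
    text = text.lower().replace("’", "'")
    text = unicodedata.normalize("NFC", text)
    text = re.sub(r"[^a-z' ]", " ", text)  # keep only a-z, apostrophes, spaces
    text = re.sub(r" +", " ", text)        # collapse runs of spaces
    return text.strip()

print(normalize("Hello, it’s 5 o'clock!"))  # -> "hello it's o'clock"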
+class SpeakerDiarizer:
+    """
+    Handles speaker diarization of audio files.
+    """
+    def __init__(self, hf_token: str, parameters: Dict, model="pyannote/speaker-diarization-3.1"):
+        """
+        Initializes the SpeakerDiarizer.
+
+        :param hf_token: Hugging Face token for accessing diarization models.
+        :param parameters: Parameters for the diarization pipeline.
+        :param model: Diarization model identifier on the Hugging Face Hub.
+        """
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            self.device = torch.device("mps")
+        else:
+            self.device = torch.device("cpu")
+
+        self.diarization_pipeline = DiarizationPipeline.from_pretrained(
+            model,
+            use_auth_token=hf_token
+        )
+        self.model = model
+        print("Initializing SpeakerDiarizer with parameters...")
+        self.parameters = parameters
+        self.diarization_pipeline.instantiate(parameters)
+        self.diarization_pipeline.to(self.device)
+        print("Initialized SpeakerDiarizer successfully.")
+
+    def diarize(self, audio_file: AudioFile, num_speakers: int = None):
+        """
+        Performs speaker diarization on the given audio file.
+
+        :param audio_file: AudioFile instance containing audio data.
+        :param num_speakers: Expected number of speakers.
+        """
+        print("Starting speaker diarization...")
+        try:
+            if num_speakers is not None:
+                diarization_result = self.diarization_pipeline(
+                    audio_file.normalized_path,
+                    num_speakers=num_speakers
+                )
+                print(f"Diarization completed with {num_speakers} speakers.")
+            else:
+                diarization_result = self.diarization_pipeline(
+                    audio_file.normalized_path
+                )
+                print("Diarization completed without specifying number of speakers.")
+
+            # Extract speaker segments
+            audio_file.speaker_segments = []
+            for segment, _, speaker in diarization_result.itertracks(yield_label=True):
+                audio_file.speaker_segments.append({
+                    "start": segment.start,
+                    "end": segment.end,
+                    "speaker": speaker
+                })
+            print(f"Detected {len(audio_file.speaker_segments)} speaker segments.")
+        except Exception as e:
+            print(f"An error occurred during diarization: {e}")
+
+        audio_file.register_model("Speaker Diarization", {
+            "model": self.model,
+            "device": str(self.device),
+            "parameters": self.parameters,
+            "speakers": num_speakers if num_speakers else "not specified"
+        })
+
+
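The segments stored on audio_file.speaker_segments are plain dicts with start, end, and speaker keys, so they tabulate easily; one plausible use of the pandas import at the top of the module is summarizing speaking time, sketched here with hypothetical values:

import pandas as pd

segments = [  # hypothetical diarization output (seconds)
    {"start": 0.0, "end": 4.0, "speaker": "SPEAKER_00"},
    {"start": 4.0, "end": 9.5, "speaker": "SPEAKER_01"},
    {"start": 9.5, "end": 12.0, "speaker": "SPEAKER_00"},
]
df = pd.DataFrame(segments)
df["duration"] = df["end"] - df["start"]
print(df.groupby("speaker")["duration"].sum())  # total speaking time per speaker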
+def process_audio_files(files: List[str],
+                        hf_token: str,
+                        diarizer_params: Dict = {
+                            "segmentation": {
+                                "min_duration_off": 0.0,
+                            },
+                            "clustering": {
+                                "method": "centroid",
+                                "min_cluster_size": 12,
+                                "threshold": 0.8,
+                            }
+                        },
+                        num_speakers: int = 2,
+                        output_folder: str = "output",
+                        min_silence_len: int = 1000,
+                        silence_thresh: int = -30,
+                        min_length: int = 90000,
+                        max_length: int = 150000,
+                        timestamp_source: str = "whisper_alignments"):
+    """
+    Processes one or more audio files through the entire pipeline.
+
+    :param files: List of file paths to process.
+    :param hf_token: Hugging Face token for accessing diarization models.
+    :param diarizer_params: Parameters for the SpeakerDiarizer model.
+    :param num_speakers: Expected number of speakers.
+    :param output_folder: Folder to save the output JSON files.
+    :param min_silence_len: Minimum silence length for splitting (ms).
+    :param silence_thresh: Silence threshold in dBFS for splitting.
+    :param min_length: Minimum chunk length in ms.
+    :param max_length: Maximum chunk length in ms.
+    :param timestamp_source: Alignment source to use ('whisper_alignments' or 'forced_alignments').
+    """
+    Path(output_folder).mkdir(exist_ok=True)  # Create output folder if it doesn't exist
+    print("Starting processing of audio files...")
+
+    # Initialize processing classes
+    transcriber = AudioTranscriber()
+    aligner = ForcedAligner()
+    diarizer = SpeakerDiarizer(hf_token, parameters=diarizer_params)
+
+    for file_path in files:
+        print(f"\nProcessing file: {file_path}")
+        audio_file = AudioFile(file_path)
+
+        # Step 1: Normalize audio
+        print("Step 1/6: Normalizing audio...")
+        audio_file.rms_normalization()
+
+        # Step 2: Split audio into chunks based on silence
+        print("Step 2/6: Splitting audio on silence...")
+        audio_file.split_on_silence(
+            min_silence_len=min_silence_len,
+            silence_thresh=silence_thresh,
+            min_length=min_length,
+            max_length=max_length
+        )
+
+        # Step 3: Transcribe audio chunks
+        print("Step 3/6: Transcribing audio chunks...")
+        transcriber.transcribe(audio_file)
+        for idx, chunk in enumerate(audio_file.chunks, start=1):
+            print(f"Chunk {idx} Transcript: {chunk.transcript}\n")
+
+        # Step 4: Perform forced alignment
+        print("Step 4/6: Performing forced alignment...")
+        aligner.align(audio_file)
+        audio_file.combine_chunks()
+
+        # Step 5: Perform speaker diarization
+        print("Step 5/6: Performing speaker diarization...")
+        diarizer.diarize(audio_file, num_speakers)
+
+        # Step 6: Combine alignment and diarization data
+        print("Step 6/6: Combining alignment and diarization data...")
+        transcript = Transcript(audio_file)
+        transcript.combine_alignment_and_diarization(timestamp_source)
+        transcript.aggregate_to_utterances()
+
+        # Save all data as JSON
+        all_output_file = Path(output_folder) / f"{Path(file_path).stem}_all_outputs.json"
+        print(f"Saving results to: {all_output_file}")
+        transcript.save_as_json(all_output_file)
+        print(f"Finished processing: {file_path}\n{'-' * 60}")
+
+    del transcriber
+    del aligner
+    del diarizer
+
+    print("All files have been processed.")
+
+
+# Example Usage
+if __name__ == "__main__":
+    import os
+
+    # Define input and output paths
+    audio_file_path = "audio.wav"  # Replace with your actual audio file path
+    output_directory = "output"
+
+    # Ensure output directory exists
+    os.makedirs(output_directory, exist_ok=True)
+
+    # List of files to process
+    files_to_process = [audio_file_path]
+
+    # Hugging Face token (replace with your actual token)
+    hugging_face_token = "hf_your_token_here"
+
+    # Process the audio files
+    process_audio_files(
+        files=files_to_process,
+        hf_token=hugging_face_token,
+        output_folder=output_directory,
+        num_speakers=2
+    )