ttsds 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ttsds/__about__.py ADDED
@@ -0,0 +1,4 @@
1
+ # SPDX-FileCopyrightText: 2024-present Christoph Minixhofer <christoph.minixhofer@gmail.com>
2
+ #
3
+ # SPDX-License-Identifier: MIT
4
+ __version__ = "0.0.1"
ttsds/__init__.py ADDED
@@ -0,0 +1,345 @@
1
+ from typing import List, Optional
2
+ import importlib.resources
3
+ from time import time
4
+ from pathlib import Path
5
+ import pickle
6
+ import gzip
7
+ import requests
8
+
9
+ import pandas as pd
10
+ from transformers import logging
11
+ import numpy as np
12
+ from sklearn.decomposition import PCA
13
+
14
+ from ttsds.benchmarks.environment.voicefixer import VoiceFixerBenchmark
15
+ from ttsds.benchmarks.environment.wada_snr import WadaSNRBenchmark
16
+ from ttsds.benchmarks.general.hubert import HubertBenchmark
17
+ from ttsds.benchmarks.general.wav2vec2 import Wav2Vec2Benchmark
18
+ from ttsds.benchmarks.general.wavlm import WavLMBenchmark
19
+ from ttsds.benchmarks.intelligibility.w2v2_wer import Wav2Vec2WERBenchmark
20
+ from ttsds.benchmarks.intelligibility.whisper_wer import WhisperWERBenchmark
21
+ from ttsds.benchmarks.prosody.mpm import MPMBenchmark
22
+ from ttsds.benchmarks.prosody.pitch import PitchBenchmark
23
+ from ttsds.benchmarks.prosody.hubert_token import HubertTokenBenchmark
24
+ from ttsds.benchmarks.speaker.wespeaker import WeSpeakerBenchmark
25
+ from ttsds.benchmarks.speaker.dvector import DVectorBenchmark
26
+ from ttsds.benchmarks.benchmark import BenchmarkCategory, BenchmarkDimension
27
+ from ttsds.util.dataset import Dataset, TarDataset, DataDistribution, DEFAULT_BENCHMARKS
28
+
29
# Silence the "some weights of the model checkpoint at ... were not used when
# initializing" warnings that transformers prints while loading models.
logging.set_verbosity_error()


# Registry mapping each benchmark key to the class implementing it.
# Insertion order is kept the same as before because code iterating over
# ``benchmark_dict.items()`` instantiates the classes in this order.
benchmark_dict = dict(
    hubert=HubertBenchmark,
    wav2vec2=Wav2Vec2Benchmark,
    wavlm=WavLMBenchmark,
    wav2vec2_wer=Wav2Vec2WERBenchmark,
    whisper_wer=WhisperWERBenchmark,
    mpm=MPMBenchmark,
    pitch=PitchBenchmark,
    wespeaker=WeSpeakerBenchmark,
    dvector=DVectorBenchmark,
    hubert_token=HubertTokenBenchmark,
    voicefixer=VoiceFixerBenchmark,
    wada_snr=WadaSNRBenchmark,
)
47
+
48
# Names of the precomputed noise distributions shipped with / downloaded for the package.
_NOISE_NAMES = [
    "esc50",
    "all_ones",
    "all_zeros",
    "normal_distribution",
    "uniform_distribution",
]

# Names of the precomputed reference speech distributions.
_SPEECH_NAMES = [
    "blizzard2008",
    "blizzard2013",
    "common_voice",
    "libritts_test",
    "libritts_r_test",
    "lj_speech",
    "vctk",
]

# Location the precomputed ``.pkl.gz`` files are fetched from when missing locally.
_DATA_URL = "https://github.com/ttsds/ttsds/raw/main/src/ttsds/data"


def _download_if_missing(data_dir, filename):
    """
    Download ``filename`` from ``_DATA_URL`` into ``data_dir`` unless it already exists.

    Args:
        data_dir: Directory the file is stored in.
        filename: Name of the file, both remotely and locally.

    Raises:
        requests.HTTPError: If the server responds with an error status. Nothing is
            written in that case, so an error page is never saved as a ``.pkl.gz``.
        requests.Timeout: If the server stops responding.
    """
    target = Path(f"{data_dir}/{filename}")
    if target.exists():
        return
    print(f"Downloading {filename}")
    # The timeout bounds connection set-up and each read, not the whole
    # transfer, so large files still download as long as data keeps arriving.
    response = requests.get(f"{_DATA_URL}/{filename}", timeout=60)
    response.raise_for_status()
    with open(target, "wb") as f:
        f.write(response.content)


# NOTE(review): the original indentation is not recoverable from the diff
# view; everything that uses ``data_path`` is kept inside the ``with`` block
# so the resource path is guaranteed to be valid while it is used.
with importlib.resources.path("ttsds", "data") as data_path:
    # if they don't exist, download from github
    for noise_name in _NOISE_NAMES:
        _download_if_missing(data_path, f"noise_{noise_name}.pkl.gz")
    for speech_name in _SPEECH_NAMES:
        _download_if_missing(data_path, f"reference_speech_{speech_name}.pkl.gz")

    # check if the reference and noise distributions are already saved;
    # after a successful download above this branch is not taken
    if not Path(f"{data_path}/reference_speech_blizzard2008.pkl.gz").exists():
        print("Creating reference distributions")
        ref_benchmark_dict = {k: v() for k, v in benchmark_dict.items()}
        REFERENCE_DISTS = [
            DataDistribution(
                TarDataset(data_path / "original" / f"speech_{name}.tar.gz"),
                ref_benchmark_dict,
                benchmarks=DEFAULT_BENCHMARKS,
                name=f"speech_{name}",
            )
            for name in _SPEECH_NAMES
        ]
        # save the reference distributions
        for dist in REFERENCE_DISTS:
            dist.to_pickle(f"{data_path}/reference_{dist.name}.pkl.gz")

    # NOTE(review): these files may have been fetched over the network and are
    # presumably unpickled by ``from_pickle``; consider verifying a checksum
    # before loading them.
    REFERENCE_DISTS = [
        DataDistribution.from_pickle(f"{data_path}/reference_speech_{name}.pkl.gz")
        for name in _SPEECH_NAMES
    ]

    if not Path(f"{data_path}/noise_esc50.pkl.gz").exists():
        print("Creating noise distributions")
        ref_benchmark_dict = {k: v() for k, v in benchmark_dict.items()}
        NOISE_DISTS = [
            DataDistribution(
                TarDataset(data_path / "original" / f"noise_{name}.tar.gz"),
                ref_benchmark_dict,
                benchmarks=DEFAULT_BENCHMARKS,
                name=name,
            )
            for name in _NOISE_NAMES
        ]
        # save the noise distributions
        for dist in NOISE_DISTS:
            dist.to_pickle(f"{data_path}/noise_{dist.name}.pkl.gz")

    NOISE_DISTS = [
        DataDistribution.from_pickle(f"{data_path}/noise_{name}.pkl.gz")
        for name in _NOISE_NAMES
    ]
155
+
156
+
157
class BenchmarkSuite:
    """
    Run a list of benchmarks over a list of datasets and collect the results.

    Results are accumulated in ``self.database`` (one row per benchmark/dataset
    pair) and, if ``write_to_file`` is given, written to that CSV file after
    every pair.
    """

    def __init__(
        self,
        datasets: List[Dataset],
        benchmarks: List[str] = DEFAULT_BENCHMARKS,
        print_results: bool = True,
        skip_errors: bool = False,
        noise_distributions: List[DataDistribution] = NOISE_DISTS,
        reference_distributions: List[DataDistribution] = REFERENCE_DISTS,
        write_to_file: Optional[str] = None,
    ):
        """
        Args:
            datasets: Datasets to evaluate; stored sorted by ``name``.
            benchmarks: Keys of ``benchmark_dict`` selecting the benchmarks to
                instantiate; the instances are stored sorted by category value
                and then by name.
            print_results: Whether ``run`` prints every result row.
            skip_errors: If True, ``run`` records NaN results for a failing
                benchmark/dataset pair instead of re-raising the exception.
            noise_distributions: Noise distributions handed to
                ``compute_score``.
            reference_distributions: Reference distributions handed to
                ``compute_score``.
            write_to_file: Optional CSV path. If the file already exists, its
                rows are loaded so that ``run`` skips pairs it contains.
        """
        self.benchmarks = benchmarks
        # sort by category and then by name
        self.benchmark_objects = sorted(
            (benchmark_dict[benchmark]() for benchmark in benchmarks),
            key=lambda x: (x.category.value, x.name),
        )
        self.datasets = sorted(datasets, key=lambda x: x.name)
        self.database = pd.DataFrame(
            columns=[
                "benchmark_name",
                "benchmark_category",
                "dataset",
                "score",
                "ci",
                "time_taken",
                "noise_dataset",
                "reference_dataset",
            ]
        )
        self.print_results = print_results
        self.skip_errors = skip_errors
        self.noise_distributions = noise_distributions
        self.reference_distributions = reference_distributions
        self.write_to_file = write_to_file
        # ``write_to_file`` defaults to None, which ``Path`` cannot handle.
        if write_to_file is not None and Path(write_to_file).exists():
            self.database = pd.read_csv(write_to_file, index_col=0)
            self.database = self.database.reset_index()

    def run(self) -> pd.DataFrame:
        """
        Compute every benchmark on every dataset that is not yet in the database.

        Returns:
            The database with one row per benchmark/dataset pair.

        Raises:
            Exception: Whatever a benchmark raises, unless ``skip_errors`` is
                set, in which case the pair is recorded with NaN results.
        """
        for benchmark in self.benchmark_objects:
            for dataset in self.datasets:
                # empty lines for better readability
                print("\n")
                print(f"{'='*80}")
                print(f"Benchmark Category: {benchmark.category.name}")
                print(f"Running {benchmark.name} on {dataset.root_dir}")
                try:
                    # check if it's in the database
                    already_done = (
                        (self.database["benchmark_name"] == benchmark.name)
                        & (self.database["dataset"] == dataset.name)
                    ).any()
                    if already_done:
                        print(
                            f"Skipping {benchmark.name} on {dataset.name} as it's already in the database"
                        )
                        continue
                    start = time()
                    # extra diagnostic output for the WER benchmarks
                    if "wer" in benchmark.name.lower():
                        print(
                            [
                                x.get_distribution(benchmark.key)
                                for x in self.reference_distributions
                            ]
                        )
                        print(
                            [
                                x.get_distribution(benchmark.key)
                                for x in self.noise_distributions
                            ]
                        )
                        print(benchmark.get_distribution(dataset))
                    score = benchmark.compute_score(
                        dataset, self.reference_distributions, self.noise_distributions
                    )
                    time_taken = time() - start
                except Exception as e:
                    if not self.skip_errors:
                        raise
                    print(f"Error: {e}")
                    # Same shape as a real ``compute_score`` result so the
                    # row below can be built: (score, ci, (noise, reference)).
                    score = (np.nan, np.nan, (None, None))
                    time_taken = np.nan
                result = {
                    "benchmark_name": [benchmark.name],
                    "benchmark_category": [benchmark.category.value],
                    "dataset": [dataset.name],
                    "score": [score[0]],
                    "ci": [score[1]],
                    "time_taken": [time_taken],
                    "noise_dataset": [score[2][0]],
                    "reference_dataset": [score[2][1]],
                }
                if self.print_results:
                    print(result)
                self.database = pd.concat(
                    [
                        self.database,
                        pd.DataFrame(result),
                    ],
                    ignore_index=True,
                )
                if self.write_to_file is not None:
                    self.database["score"] = self.database["score"].astype(float)
                    self.database = self.database.sort_values(
                        ["benchmark_category", "benchmark_name", "score"],
                        ascending=[True, True, False],
                    )
                    self.database.to_csv(self.write_to_file, index=False)
        return self.database

    @staticmethod
    def aggregate_df(df: pd.DataFrame) -> pd.DataFrame:
        """
        Average the results per benchmark category and dataset.

        Args:
            df: Result table with the columns produced by ``run``; the
                ``benchmark_category`` column holds ``BenchmarkCategory``
                values. The frame is not modified.

        Returns:
            A new frame with one row per (category name, dataset) holding the
            mean ``score``, ``ci`` and ``time_taken`` and the comma-joined
            ``noise_dataset`` and ``reference_dataset`` entries.
        """

        def concat_text(x):
            # ``str`` keeps this working for the non-string placeholders that
            # ``run`` stores for failed pairs.
            return ", ".join(str(v) for v in x)

        # Work on a copy: the category column is rewritten below and the
        # caller's frame must stay usable for another aggregation.
        df = df.copy()
        # replace benchmark_category number with string
        df["benchmark_category"] = df["benchmark_category"].apply(
            lambda x: BenchmarkCategory(x).name
        )
        df = (
            df.groupby(
                [
                    "benchmark_category",
                    "dataset",
                ]
            )
            .agg(
                {
                    "score": ["mean"],
                    "ci": ["mean"],
                    "time_taken": ["mean"],
                    "noise_dataset": [concat_text],
                    "reference_dataset": [concat_text],
                    "benchmark_name": [concat_text],
                }
            )
            .reset_index()
        )
        # remove multiindex
        df.columns = df.columns.get_level_values(0)
        # drop the benchmark_name column
        df = df.drop("benchmark_name", axis=1)
        return df

    def get_aggregated_results(self) -> pd.DataFrame:
        """Return the database aggregated per category and dataset (see ``aggregate_df``)."""
        df = self.database.copy()
        return BenchmarkSuite.aggregate_df(df)

    def get_benchmark_distribution(
        self,
        benchmark_name: str,
        dataset_name: str,
        pca_components: Optional[int] = None,
    ) -> dict:
        """
        Collect the distributions behind one benchmark/dataset result.

        Args:
            benchmark_name: ``name`` of one of the suite's benchmark objects.
            dataset_name: ``name`` of one of the suite's datasets.
            pca_components: If given, a PCA with that many components is
                fitted on all distributions except the dataset's own and
                applied to every distribution.

        Returns:
            A dict with the keys ``benchmark_distribution``,
            ``noise_distribution``, ``reference_distribution``,
            ``other_noise_distribution`` and ``other_reference_distribution``.
            The "closest" noise/reference names are read from the database;
            the "other" entries are the first distribution with a different
            name.

        Raises:
            IndexError: If the benchmark, dataset, database row or a matching
                distribution cannot be found.
        """
        benchmark = [x for x in self.benchmark_objects if x.name == benchmark_name][0]
        dataset = [x for x in self.datasets if x.name == dataset_name][0]
        rows = self.database[
            (self.database["benchmark_name"] == benchmark_name)
            & (self.database["dataset"] == dataset_name)
        ]
        closest_noise_name = rows["noise_dataset"].values[0]
        closest_noise = [
            x for x in self.noise_distributions if x.name == closest_noise_name
        ][0]
        other_noise = [
            x for x in self.noise_distributions if x.name != closest_noise.name
        ][0]
        closest_reference_name = rows["reference_dataset"].values[0]
        closest_reference = [
            x for x in self.reference_distributions if x.name == closest_reference_name
        ][0]
        other_reference = [
            x for x in self.reference_distributions if x.name != closest_reference.name
        ][0]
        result = {
            "benchmark_distribution": benchmark.get_distribution(dataset),
            "noise_distribution": benchmark.get_distribution(closest_noise),
            "reference_distribution": benchmark.get_distribution(closest_reference),
            "other_noise_distribution": benchmark.get_distribution(other_noise),
            "other_reference_distribution": benchmark.get_distribution(other_reference),
        }
        if pca_components is not None:
            pca = PCA(n_components=pca_components)
            # fit on all except the benchmark distribution
            pca.fit(
                np.vstack(
                    [v for k, v in result.items() if k != "benchmark_distribution"]
                )
            )
            result = {k: pca.transform(v) for k, v in result.items()}
        return result
File without changes
@@ -0,0 +1,182 @@
1
+ """
2
+ This file contains the Benchmark abstract class.
3
+ """
4
+
5
+ from abc import ABC, abstractmethod
6
+ from enum import Enum
7
+ import hashlib
8
+ import importlib.resources
9
+ import json
10
+ from typing import List, Union
11
+ from functools import lru_cache
12
+ from pathlib import Path
13
+
14
+ import numpy as np
15
+
16
+ from ttsds.util.dataset import Dataset, DataDistribution
17
+ from ttsds.util.cache import cache, load_cache, check_cache, hash_md5
18
+ from ttsds.util.distances import wasserstein_distance, frechet_distance
19
+
20
+
21
class BenchmarkCategory(Enum):
    """
    The category a benchmark belongs to.

    Each member has a distinct integer value from 1 to 8.
    """

    OVERALL = 1
    PROSODY = 2
    ENVIRONMENT = 3
    SPEAKER = 4
    PHONETICS = 5
    INTELLIGIBILITY = 6
    TRAINABILITY = 7
    EXTERNAL = 8
34
+
35
+
36
class BenchmarkDimension(Enum):
    """
    Whether a benchmark's distribution is one-dimensional or n-dimensional.
    """

    ONE_DIMENSIONAL = 1
    N_DIMENSIONAL = 2
43
+
44
+
45
class Benchmark(ABC):
    """
    Abstract base class for a benchmark.

    Subclasses implement ``_get_distribution``; caching of distributions,
    distance computation and scoring are provided here.
    """

    def __init__(
        self,
        name: str,
        category: BenchmarkCategory,
        dimension: BenchmarkDimension,
        description: str,
        **kwargs,
    ):
        """
        Args:
            name: Display name; ``key`` is derived from it by lower-casing and
                replacing spaces with underscores.
            category: Category of the benchmark.
            dimension: Whether the benchmark is one- or n-dimensional; selects
                the distance used by ``compute_distance``.
            description: Free-text description (also part of ``__hash__``).
            **kwargs: Extra settings, stored as-is and included in ``__hash__``.
        """
        self.name = name
        self.key = name.lower().replace(" ", "_")
        self.category = category
        self.dimension = dimension
        self.description = description
        self.kwargs = kwargs

    def get_distribution(
        self, dataset: Union[Dataset, DataDistribution]
    ) -> Union[np.ndarray, tuple]:
        """
        Return the (cached) distribution of the benchmark for ``dataset``.

        Lookup order:

        1. a cached distribution for this dataset/benchmark pair,
        2. a cached ``(mu, sig)`` pair,
        3. for a ``DataDistribution``, the distribution stored under
           ``self.key`` (cached afterwards),
        4. otherwise ``self._get_distribution(dataset)`` (cached afterwards).

        Returns:
            For an n-dimensional benchmark on a ``DataDistribution`` a
            ``(mu, sig)`` tuple; otherwise whatever the ``DataDistribution``
            or ``_get_distribution`` provides (a numpy array per the
            annotation of ``_get_distribution``).
        """
        ds_hash = hash_md5(dataset)
        benchmark_hash = hash_md5(self)
        cache_name = f"benchmarks/{self.name}/{ds_hash}_{benchmark_hash}"
        if check_cache(cache_name):
            return load_cache(cache_name)
        if check_cache(cache_name + "_mu") and check_cache(cache_name + "_sig"):
            mu = load_cache(cache_name + "_mu")
            sig = load_cache(cache_name + "_sig")
            return (mu, sig)
        if isinstance(dataset, DataDistribution):
            if self.dimension == BenchmarkDimension.N_DIMENSIONAL:
                mu, sig = dataset.get_distribution(self.key)
                cache(mu, cache_name + "_mu")
                cache(sig, cache_name + "_sig")
                return (mu, sig)
            if self.dimension == BenchmarkDimension.ONE_DIMENSIONAL:
                distribution = dataset.get_distribution(self.key)
                cache(distribution, cache_name)
                return distribution
        distribution = self._get_distribution(dataset)
        cache(distribution, cache_name)
        return distribution

    @abstractmethod
    def _get_distribution(self, dataset: Dataset) -> np.ndarray:
        """
        Compute the distribution of the benchmark for ``dataset``.

        Must be implemented by subclasses; called by ``get_distribution``
        when nothing is cached.
        """
        raise NotImplementedError

    def __str__(self) -> str:
        return f"{self.category.name}/{self.name}"

    def __repr__(self) -> str:
        return str(self)

    def __hash__(self) -> int:
        """Hash the name, category, dimension, description and kwargs with MD5."""
        h = hashlib.md5()
        h.update(self.name.encode())
        h.update(self.category.name.encode())
        h.update(self.dimension.name.encode())
        h.update(self.description.encode())
        # convert the kwargs to strings
        kwargs_str = {
            k: str(v) if not isinstance(v, dict) else json.dumps(v, sort_keys=True)
            for k, v in self.kwargs.items()
        }
        h.update(json.dumps(kwargs_str, sort_keys=True).encode())
        return int(h.hexdigest(), 16)

    # NOTE(review): ``lru_cache`` on a method keeps every benchmark instance and
    # every dataset argument alive for the lifetime of the process, and requires
    # the arguments to be hashable. Kept as-is to preserve behaviour.
    @lru_cache(maxsize=None)
    def compute_distance(
        self,
        one_dataset: Union[Dataset, DataDistribution],
        other_dataset: Union[Dataset, DataDistribution],
    ) -> float:
        """
        Compute the distance between the distributions of the benchmark in two datasets.

        Uses the Wasserstein distance for one-dimensional benchmarks and the
        Frechet distance for n-dimensional ones.

        Raises:
            ValueError: If ``self.dimension`` is neither of the two.
        """
        one_distribution = self.get_distribution(one_dataset)
        other_distribution = self.get_distribution(other_dataset)
        if self.dimension == BenchmarkDimension.ONE_DIMENSIONAL:
            return wasserstein_distance(one_distribution, other_distribution)
        if self.dimension == BenchmarkDimension.N_DIMENSIONAL:
            return frechet_distance(one_distribution, other_distribution)
        raise ValueError("Invalid benchmark dimension")

    def compute_score(
        self,
        dataset: Dataset,
        reference_datasets: List[Union[Dataset, DataDistribution]],
        noise_datasets: List[Union[Dataset, DataDistribution]],
    ) -> tuple:
        """
        Compute the score of the benchmark on a dataset.

        The score is ``noise / (reference + noise) * 100``, where ``noise`` and
        ``reference`` are the smallest distances from ``dataset`` to any noise
        and any reference dataset respectively. If both distances are zero the
        division yields NaN.

        Returns:
            A tuple ``(score, ci, (closest_noise_name, closest_reference_name))``.
            ``ci`` is currently always ``1.0`` (the confidence interval is not
            computed yet).

        Raises:
            ValueError: If ``reference_datasets`` or ``noise_datasets`` is empty.
        """
        if len(noise_datasets) == 0 or len(reference_datasets) == 0:
            raise ValueError(
                "compute_score needs at least one reference dataset and one noise dataset"
            )

        noise_scores = np.array(
            [self.compute_distance(noise_ds, dataset) for noise_ds in noise_datasets]
        )
        dataset_scores = np.array(
            [self.compute_distance(ref_ds, dataset) for ref_ds in reference_datasets]
        )

        closest_noise_idx = np.argmin(noise_scores)
        closest_dataset_idx = np.argmin(dataset_scores)
        closest_noise_name = noise_datasets[closest_noise_idx].name
        closest_reference_name = reference_datasets[closest_dataset_idx].name

        print(f"Closest noise dataset: {closest_noise_name}")
        print(f"Closest reference dataset: {closest_reference_name}")

        noise_score = np.min(noise_scores)
        dataset_score = np.min(dataset_scores)
        combined_score = dataset_score + noise_score
        score = (noise_score / combined_score) * 100
        # TODO: compute confidence interval
        return (
            score,
            1.0,
            (closest_noise_name, closest_reference_name),
        )
@@ -0,0 +1,80 @@
1
+ import tempfile
2
+
3
+ from pesq import pesq
4
+ from voicefixer import VoiceFixer
5
+ from simple_hifigan import Synthesiser
6
+ import numpy as np
7
+ import soundfile as sf
8
+ import librosa
9
+ from tqdm import tqdm
10
+
11
+ from ttsds.benchmarks.benchmark import Benchmark, BenchmarkCategory, BenchmarkDimension
12
+ from ttsds.util.dataset import Dataset
13
+
14
+
15
class VoiceFixerBenchmark(Benchmark):
    """
    Environment benchmark based on VoiceFixer.

    Each utterance is restored with VoiceFixer, the restored audio is
    re-synthesised from its mel spectrogram, and the result is compared with
    the original utterance using wide-band PESQ.
    """

    def __init__(
        self,
    ):
        super().__init__(
            name="VoiceFixer",
            category=BenchmarkCategory.ENVIRONMENT,
            dimension=BenchmarkDimension.ONE_DIMENSIONAL,
            # NOTE(review): this text does not match what the benchmark
            # computes (PESQ scores), but it is part of ``Benchmark.__hash__``,
            # which presumably feeds the cache keys, so it is left unchanged.
            description="The phone counts of VoiceFixer.",
        )
        self.model = VoiceFixer()
        self.synthesiser = Synthesiser()

    def _get_distribution(self, dataset: Dataset) -> np.ndarray:
        """
        Compute one PESQ score per utterance of the dataset.

        Utterances are resampled to 16 kHz and, if longer than 32000 samples
        (2 seconds), cropped to a random 2-second window, so results are not
        deterministic unless numpy's global random state is seeded.

        Args:
            dataset (Dataset): The dataset to compute the scores on.

        Returns:
            np.ndarray: One wide-band PESQ score per processed utterance.
                Utterances whose restored mel spectrogram contains NaN are
                skipped; utterances for which PESQ fails are recorded as 0.
        """
        pesq_scores = []
        for wav, _ in tqdm(dataset, desc=f"computing noise for {self.name}"):
            if dataset.sample_rate != 16000:
                wav = librosa.resample(
                    wav, orig_sr=dataset.sample_rate, target_sr=16000
                )
            # take random 2 seconds
            if len(wav) > 32000:
                start = np.random.randint(0, len(wav) - 32000)
                wav = wav[start : start + 32000]
            # VoiceFixer works on files, so round-trip through temporary wavs
            with tempfile.NamedTemporaryFile(suffix=".wav") as f:
                sf.write(f.name, wav, 16000)
                with tempfile.NamedTemporaryFile(suffix=".wav") as f_out:
                    self.model.restore(f.name, f_out.name)
                    wav_out, _ = librosa.load(f_out.name, sr=16000)
            wav = wav / (np.max(np.abs(wav)) + 1e-5)
            # NOTE(review): unlike the line above there is no epsilon here, so
            # an all-zero restoration yields NaN; that case presumably ends up
            # in the NaN check below. Left unchanged to keep scores identical.
            wav_out = wav_out / np.max(np.abs(wav_out))
            mel = self.synthesiser.wav_to_mel(wav, 16000)[0].T
            restored_mel = self.synthesiser.wav_to_mel(wav_out, 16000)[0].T
            # never keep more frames than the input has
            restored_mel = restored_mel[: mel.shape[0]]
            # check if there is any nan
            if np.isnan(restored_mel).any():
                print("nan found, skip")
                continue
            # convert back to wav
            resynthesised = self.synthesiser(restored_mel.T)[0]
            resynthesised = resynthesised / np.max(np.abs(resynthesised) + 1e-5)
            resynthesised = librosa.resample(
                resynthesised, orig_sr=22050, target_sr=16000
            )
            # calculate the difference
            try:
                pesq_scores.append(pesq(16000, wav, resynthesised, "wb"))
            except Exception:
                # best effort: a failing PESQ computation counts as score 0,
                # but KeyboardInterrupt/SystemExit are no longer swallowed
                pesq_scores.append(0)
        return np.array(pesq_scores)