ttsds 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ttsds-0.0.1/.gitignore +11 -0
- ttsds-0.0.1/LICENSE.txt +9 -0
- ttsds-0.0.1/PKG-INFO +66 -0
- ttsds-0.0.1/README.md +30 -0
- ttsds-0.0.1/pyproject.toml +89 -0
- ttsds-0.0.1/src/ttsds/__about__.py +4 -0
- ttsds-0.0.1/src/ttsds/__init__.py +345 -0
- ttsds-0.0.1/src/ttsds/benchmarks/__init__.py +0 -0
- ttsds-0.0.1/src/ttsds/benchmarks/benchmark.py +182 -0
- ttsds-0.0.1/src/ttsds/benchmarks/environment/voicefixer.py +80 -0
- ttsds-0.0.1/src/ttsds/benchmarks/environment/wada_snr.py +230 -0
- ttsds-0.0.1/src/ttsds/benchmarks/external/pesq.py +67 -0
- ttsds-0.0.1/src/ttsds/benchmarks/external/utmos/__init__.py +70 -0
- ttsds-0.0.1/src/ttsds/benchmarks/external/utmos/change_sample_rate.py +21 -0
- ttsds-0.0.1/src/ttsds/benchmarks/external/utmos/lightning_module.py +71 -0
- ttsds-0.0.1/src/ttsds/benchmarks/external/utmos/model.py +225 -0
- ttsds-0.0.1/src/ttsds/benchmarks/external/wv_mos.py +43 -0
- ttsds-0.0.1/src/ttsds/benchmarks/general/hubert.py +83 -0
- ttsds-0.0.1/src/ttsds/benchmarks/general/wav2vec2.py +83 -0
- ttsds-0.0.1/src/ttsds/benchmarks/general/wavlm.py +83 -0
- ttsds-0.0.1/src/ttsds/benchmarks/intelligibility/w2v2_wer.py +62 -0
- ttsds-0.0.1/src/ttsds/benchmarks/intelligibility/whisper_wer.py +63 -0
- ttsds-0.0.1/src/ttsds/benchmarks/prosody/hubert_token.py +143 -0
- ttsds-0.0.1/src/ttsds/benchmarks/prosody/mpm.py +111 -0
- ttsds-0.0.1/src/ttsds/benchmarks/prosody/pitch.py +40 -0
- ttsds-0.0.1/src/ttsds/benchmarks/speaker/dvector.py +161 -0
- ttsds-0.0.1/src/ttsds/benchmarks/speaker/wespeaker.py +68 -0
- ttsds-0.0.1/src/ttsds/data/dvector/README.md +1 -0
- ttsds-0.0.1/src/ttsds/data/dvector/dvector.pt +0 -0
- ttsds-0.0.1/src/ttsds/util/__init__.py +0 -0
- ttsds-0.0.1/src/ttsds/util/cache.py +70 -0
- ttsds-0.0.1/src/ttsds/util/dataset.py +285 -0
- ttsds-0.0.1/src/ttsds/util/distances.py +80 -0
- ttsds-0.0.1/src/ttsds/util/measures.py +142 -0
- ttsds-0.0.1/src/ttsds/util/mpm.py +142 -0
- ttsds-0.0.1/src/ttsds/util/mpm_modules.py +147 -0
ttsds-0.0.1/.gitignore
ADDED
ttsds-0.0.1/LICENSE.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024-present Christoph Minixhofer <christoph.minixhofer@gmail.com>
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
|
6
|
+
|
|
7
|
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
|
8
|
+
|
|
9
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
ttsds-0.0.1/PKG-INFO
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: ttsds
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Project-URL: Documentation, https://github.com/ttsds/ttsds#readme
|
|
5
|
+
Project-URL: Issues, https://github.com/ttsds/ttsds/issues
|
|
6
|
+
Project-URL: Source, https://github.com/ttsds/ttsds
|
|
7
|
+
Author-email: Christoph Minixhofer <christoph.minixhofer@gmail.com>
|
|
8
|
+
License-Expression: MIT
|
|
9
|
+
License-File: LICENSE.txt
|
|
10
|
+
Classifier: Development Status :: 4 - Beta
|
|
11
|
+
Classifier: Programming Language :: Python
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.8
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
17
|
+
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
18
|
+
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
|
19
|
+
Requires-Python: >=3.8
|
|
20
|
+
Requires-Dist: allosaurus>=0.1.0
|
|
21
|
+
Requires-Dist: jiwer>=2.2.0
|
|
22
|
+
Requires-Dist: librosa>=0.10.0
|
|
23
|
+
Requires-Dist: lightning>=1.3.0
|
|
24
|
+
Requires-Dist: numpy>=1.21.0
|
|
25
|
+
Requires-Dist: openai-whisper==20231117
|
|
26
|
+
Requires-Dist: pandas>=1.3.0
|
|
27
|
+
Requires-Dist: pesq>=0.0.1
|
|
28
|
+
Requires-Dist: pyannote-audio==3.1.*
|
|
29
|
+
Requires-Dist: pyworld>=0.2.0
|
|
30
|
+
Requires-Dist: statsmodels>=0.12.0
|
|
31
|
+
Requires-Dist: torch>=2.0.0
|
|
32
|
+
Requires-Dist: tqdm>=4.61.0
|
|
33
|
+
Requires-Dist: transformers>=4.0.0
|
|
34
|
+
Requires-Dist: voicefixer>=0.1.0
|
|
35
|
+
Description-Content-Type: text/markdown
|
|
36
|
+
|
|
37
|
+
# ttsds
|
|
38
|
+
|
|
39
|
+
[![PyPI - Version](https://img.shields.io/pypi/v/ttsds.svg)](https://pypi.org/project/ttsds)
|
|
40
|
+
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/ttsds.svg)](https://pypi.org/project/ttsds)
|
|
41
|
+
|
|
42
|
+
## Installation
|
|
43
|
+
|
|
44
|
+
### Requirements
|
|
45
|
+
|
|
46
|
+
- Python 3.8+
|
|
47
|
+
- System packages: ffmpeg, automake, autoconf, unzip, sox, gfortran, subversion, libtool
|
|
48
|
+
- Simple_hifigan, wvmos and wespeaker are not available on PyPI, so you need to install them manually:
|
|
49
|
+
- https://github.com/wenet-e2e/wespeaker
|
|
50
|
+
- https://github.com/AndreevP/wvmos
|
|
51
|
+
- https://github.com/MiniXC/simple_hifigan
|
|
52
|
+
- On some systems, the fairseq installation may fail due to conflicting dependencies. In this case, you can install this fork of fairseq: https://github.com/MiniXC/fairseq-noconf
|
|
53
|
+
|
|
54
|
+
### Pip
|
|
55
|
+
|
|
56
|
+
```console
|
|
57
|
+
pip install ttsds
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
### Caching
|
|
61
|
+
|
|
62
|
+
Please set the `TTSDS_CACHE_DIR` environment variable to the directory in which the downloaded models and data should be cached.
|
|
63
|
+
|
|
64
|
+
## License
|
|
65
|
+
|
|
66
|
+
`ttsds` is distributed under the terms of the [MIT](https://spdx.org/licenses/MIT.html) license.
|
ttsds-0.0.1/README.md
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# ttsds
|
|
2
|
+
|
|
3
|
+
[![PyPI - Version](https://img.shields.io/pypi/v/ttsds.svg)](https://pypi.org/project/ttsds)
|
|
4
|
+
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/ttsds.svg)](https://pypi.org/project/ttsds)
|
|
5
|
+
|
|
6
|
+
## Installation
|
|
7
|
+
|
|
8
|
+
### Requirements
|
|
9
|
+
|
|
10
|
+
- Python 3.8+
|
|
11
|
+
- System packages: ffmpeg, automake, autoconf, unzip, sox, gfortran, subversion, libtool
|
|
12
|
+
- Simple_hifigan, wvmos and wespeaker are not available on PyPI, so you need to install them manually:
|
|
13
|
+
- https://github.com/wenet-e2e/wespeaker
|
|
14
|
+
- https://github.com/AndreevP/wvmos
|
|
15
|
+
- https://github.com/MiniXC/simple_hifigan
|
|
16
|
+
- On some systems, the fairseq installation may fail due to conflicting dependencies. In this case, you can install this fork of fairseq: https://github.com/MiniXC/fairseq-noconf
|
|
17
|
+
|
|
18
|
+
### Pip
|
|
19
|
+
|
|
20
|
+
```console
|
|
21
|
+
pip install ttsds
|
|
22
|
+
```
|
|
23
|
+
|
|
24
|
+
### Caching
|
|
25
|
+
|
|
26
|
+
Please set the `TTSDS_CACHE_DIR` environment variable to the directory in which the downloaded models and data should be cached.
|
|
27
|
+
|
|
28
|
+
## License
|
|
29
|
+
|
|
30
|
+
`ttsds` is distributed under the terms of the [MIT](https://spdx.org/licenses/MIT.html) license.
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "ttsds"
|
|
7
|
+
dynamic = ["version"]
|
|
8
|
+
description = ''
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.8"
|
|
11
|
+
license = "MIT"
|
|
12
|
+
keywords = []
|
|
13
|
+
authors = [
|
|
14
|
+
{ name = "Christoph Minixhofer", email = "christoph.minixhofer@gmail.com" },
|
|
15
|
+
]
|
|
16
|
+
classifiers = [
|
|
17
|
+
"Development Status :: 4 - Beta",
|
|
18
|
+
"Programming Language :: Python",
|
|
19
|
+
"Programming Language :: Python :: 3.8",
|
|
20
|
+
"Programming Language :: Python :: 3.9",
|
|
21
|
+
"Programming Language :: Python :: 3.10",
|
|
22
|
+
"Programming Language :: Python :: 3.11",
|
|
23
|
+
"Programming Language :: Python :: 3.12",
|
|
24
|
+
"Programming Language :: Python :: Implementation :: CPython",
|
|
25
|
+
"Programming Language :: Python :: Implementation :: PyPy",
|
|
26
|
+
]
|
|
27
|
+
dependencies = [
|
|
28
|
+
"allosaurus>=0.1.0",
|
|
29
|
+
"jiwer>=2.2.0",
|
|
30
|
+
"librosa>=0.10.0",
|
|
31
|
+
"lightning>=1.3.0",
|
|
32
|
+
"numpy>=1.21.0",
|
|
33
|
+
"openai-whisper==20231117",
|
|
34
|
+
"pandas>=1.3.0",
|
|
35
|
+
"pesq>=0.0.1",
|
|
36
|
+
"pyannote.audio==3.1.*",
|
|
37
|
+
"pyworld>=0.2.0",
|
|
38
|
+
"statsmodels>=0.12.0",
|
|
39
|
+
"torch>=2.0.0",
|
|
40
|
+
"tqdm>=4.61.0",
|
|
41
|
+
"transformers>=4.0.0",
|
|
42
|
+
"voicefixer>=0.1.0",
|
|
43
|
+
]
|
|
44
|
+
|
|
45
|
+
[project.urls]
|
|
46
|
+
Documentation = "https://github.com/ttsds/ttsds#readme"
|
|
47
|
+
Issues = "https://github.com/ttsds/ttsds/issues"
|
|
48
|
+
Source = "https://github.com/ttsds/ttsds"
|
|
49
|
+
|
|
50
|
+
[tool.hatch.version]
|
|
51
|
+
path = "src/ttsds/__about__.py"
|
|
52
|
+
|
|
53
|
+
[tool.hatch.build.targets.sdist]
|
|
54
|
+
include = ["src/ttsds"]
|
|
55
|
+
exclude = [
|
|
56
|
+
"src/ttsds/data/*.pkl.gz"
|
|
57
|
+
]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
[tool.hatch.envs.types]
|
|
61
|
+
extra-dependencies = [
|
|
62
|
+
"mypy>=1.0.0",
|
|
63
|
+
"pytest>=6.0.0",
|
|
64
|
+
"streamlit>=1.0.0",
|
|
65
|
+
]
|
|
66
|
+
[tool.hatch.envs.types.scripts]
|
|
67
|
+
check = "mypy --install-types --non-interactive {args:src/ttsds tests}"
|
|
68
|
+
|
|
69
|
+
[tool.coverage.run]
|
|
70
|
+
source_pkgs = ["ttsds", "tests"]
|
|
71
|
+
branch = true
|
|
72
|
+
parallel = true
|
|
73
|
+
omit = [
|
|
74
|
+
"src/ttsds/__about__.py",
|
|
75
|
+
]
|
|
76
|
+
|
|
77
|
+
[tool.coverage.paths]
|
|
78
|
+
ttsds = ["src/ttsds", "*/ttsds/src/ttsds"]
|
|
79
|
+
tests = ["tests", "*/ttsds/tests"]
|
|
80
|
+
|
|
81
|
+
[tool.coverage.report]
|
|
82
|
+
exclude_lines = [
|
|
83
|
+
"no cov",
|
|
84
|
+
"if __name__ == .__main__.:",
|
|
85
|
+
"if TYPE_CHECKING:",
|
|
86
|
+
]
|
|
87
|
+
|
|
88
|
+
[tool.hatch.metadata]
|
|
89
|
+
allow-direct-references = true
|
|
@@ -0,0 +1,345 @@
|
|
|
1
|
+
from typing import List, Optional
|
|
2
|
+
import importlib.resources
|
|
3
|
+
from time import time
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
import pickle
|
|
6
|
+
import gzip
|
|
7
|
+
import requests
|
|
8
|
+
|
|
9
|
+
import pandas as pd
|
|
10
|
+
from transformers import logging
|
|
11
|
+
import numpy as np
|
|
12
|
+
from sklearn.decomposition import PCA
|
|
13
|
+
|
|
14
|
+
from ttsds.benchmarks.environment.voicefixer import VoiceFixerBenchmark
|
|
15
|
+
from ttsds.benchmarks.environment.wada_snr import WadaSNRBenchmark
|
|
16
|
+
from ttsds.benchmarks.general.hubert import HubertBenchmark
|
|
17
|
+
from ttsds.benchmarks.general.wav2vec2 import Wav2Vec2Benchmark
|
|
18
|
+
from ttsds.benchmarks.general.wavlm import WavLMBenchmark
|
|
19
|
+
from ttsds.benchmarks.intelligibility.w2v2_wer import Wav2Vec2WERBenchmark
|
|
20
|
+
from ttsds.benchmarks.intelligibility.whisper_wer import WhisperWERBenchmark
|
|
21
|
+
from ttsds.benchmarks.prosody.mpm import MPMBenchmark
|
|
22
|
+
from ttsds.benchmarks.prosody.pitch import PitchBenchmark
|
|
23
|
+
from ttsds.benchmarks.prosody.hubert_token import HubertTokenBenchmark
|
|
24
|
+
from ttsds.benchmarks.speaker.wespeaker import WeSpeakerBenchmark
|
|
25
|
+
from ttsds.benchmarks.speaker.dvector import DVectorBenchmark
|
|
26
|
+
from ttsds.benchmarks.benchmark import BenchmarkCategory, BenchmarkDimension
|
|
27
|
+
from ttsds.util.dataset import Dataset, TarDataset, DataDistribution, DEFAULT_BENCHMARKS
|
|
28
|
+
|
|
29
|
+
# Only report errors from transformers; this hides the "some weights of the
# model checkpoint at ... were not used when initializing" warnings.
logging.set_verbosity_error()


# Registry of the available benchmark classes, keyed by the short name that
# callers pass in a ``benchmarks`` list. Values are classes (not instances);
# they are instantiated on demand.
benchmark_dict = {
    "hubert": HubertBenchmark,
    "wav2vec2": Wav2Vec2Benchmark,
    "wavlm": WavLMBenchmark,
    "wav2vec2_wer": Wav2Vec2WERBenchmark,
    "whisper_wer": WhisperWERBenchmark,
    "mpm": MPMBenchmark,
    "pitch": PitchBenchmark,
    "wespeaker": WeSpeakerBenchmark,
    "dvector": DVectorBenchmark,
    "hubert_token": HubertTokenBenchmark,
    "voicefixer": VoiceFixerBenchmark,
    "wada_snr": WadaSNRBenchmark,
}
|
|
47
|
+
|
|
48
|
+
# Names of the precomputed noise / reference-speech distributions that are
# expected in the package's ``data`` directory.
_NOISE_NAMES = [
    "esc50",
    "all_ones",
    "all_zeros",
    "normal_distribution",
    "uniform_distribution",
]
_SPEECH_NAMES = [
    "blizzard2008",
    "blizzard2013",
    "common_voice",
    "libritts_test",
    "libritts_r_test",
    "lj_speech",
    "vctk",
]
_DATA_URL = "https://github.com/ttsds/ttsds/raw/main/src/ttsds/data"


def _download_if_missing(target: Path, url: str) -> None:
    """Download ``url`` to ``target`` unless ``target`` already exists.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    if target.exists():
        return
    print(f"Downloading {target.name}")
    response = requests.get(url)
    # Without this check an HTTP error page would be written to disk as a
    # ``.pkl.gz`` file and, because the file then exists, never re-downloaded.
    response.raise_for_status()
    with open(target, "wb") as f:
        f.write(response.content)


with importlib.resources.path("ttsds", "data") as data_path:
    # if they don't exist, download from github
    for noise_name in _NOISE_NAMES:
        _download_if_missing(
            Path(f"{data_path}/noise_{noise_name}.pkl.gz"),
            f"{_DATA_URL}/noise_{noise_name}.pkl.gz",
        )

    for speech_name in _SPEECH_NAMES:
        _download_if_missing(
            Path(f"{data_path}/reference_speech_{speech_name}.pkl.gz"),
            f"{_DATA_URL}/reference_speech_{speech_name}.pkl.gz",
        )

    # check if the reference and noise distributions are already saved
    # NOTE(review): after the downloads above these files normally exist, so
    # the two "Creating ..." branches look like a fallback for building the
    # distributions from the original tar archives -- confirm they are still
    # needed.
    if not Path(f"{data_path}/reference_speech_blizzard2008.pkl.gz").exists():
        print("Creating reference distributions")
        ref_benchmark_dict = {k: v() for k, v in benchmark_dict.items()}
        REFERENCE_DISTS = [
            DataDistribution(
                TarDataset(data_path / "original" / f"speech_{name}.tar.gz"),
                ref_benchmark_dict,
                benchmarks=DEFAULT_BENCHMARKS,
                name=f"speech_{name}",
            )
            for name in _SPEECH_NAMES
        ]
        # save the reference distributions
        for dist in REFERENCE_DISTS:
            dist.to_pickle(f"{data_path}/reference_{dist.name}.pkl.gz")

    # NOTE(review): these files come from the network and are loaded via
    # ``from_pickle`` -- presumably unpickled; only trusted sources should be
    # used here.
    REFERENCE_DISTS = [
        DataDistribution.from_pickle(
            f"{data_path}/reference_speech_{name}.pkl.gz"
        )
        for name in _SPEECH_NAMES
    ]

    if not Path(f"{data_path}/noise_esc50.pkl.gz").exists():
        print("Creating noise distributions")
        ref_benchmark_dict = {k: v() for k, v in benchmark_dict.items()}
        NOISE_DISTS = [
            DataDistribution(
                TarDataset(data_path / "original" / f"noise_{name}.tar.gz"),
                ref_benchmark_dict,
                benchmarks=DEFAULT_BENCHMARKS,
                name=name,
            )
            for name in _NOISE_NAMES
        ]
        # save the noise distributions
        for dist in NOISE_DISTS:
            dist.to_pickle(f"{data_path}/noise_{dist.name}.pkl.gz")

    NOISE_DISTS = [
        DataDistribution.from_pickle(f"{data_path}/noise_{name}.pkl.gz")
        for name in _NOISE_NAMES
    ]
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class BenchmarkSuite:
    """Run a set of benchmarks over a set of datasets and collect the scores.

    Results are accumulated in ``self.database``, one row per
    benchmark/dataset pair. When ``write_to_file`` is given, an existing CSV
    at that path is loaded on construction and the table is written back after
    every new result; pairs already present in the table are skipped by
    :meth:`run`.
    """

    def __init__(
        self,
        datasets: List[Dataset],
        benchmarks: List[str] = DEFAULT_BENCHMARKS,
        print_results: bool = True,
        skip_errors: bool = False,
        noise_distributions: List[DataDistribution] = NOISE_DISTS,
        reference_distributions: List[DataDistribution] = REFERENCE_DISTS,
        write_to_file: Optional[str] = None,
    ):
        """Set up the suite.

        Args:
            datasets: Datasets to evaluate; stored sorted by ``name``.
            benchmarks: Keys of ``benchmark_dict`` selecting the benchmarks
                to instantiate.
            print_results: Print each result row as it is produced.
            skip_errors: If true, a failing benchmark is recorded with NaN
                values instead of aborting the run.
            noise_distributions: Distributions passed to each benchmark's
                ``compute_score`` as the noise side.
            reference_distributions: Distributions passed to each benchmark's
                ``compute_score`` as the reference side.
            write_to_file: Optional CSV path used to load and persist results.
        """
        self.benchmarks = benchmarks
        # sort by category and then by name
        self.benchmark_objects = sorted(
            (benchmark_dict[benchmark]() for benchmark in benchmarks),
            key=lambda x: (x.category.value, x.name),
        )
        self.datasets = sorted(datasets, key=lambda x: x.name)
        self.database = pd.DataFrame(
            columns=[
                "benchmark_name",
                "benchmark_category",
                "dataset",
                "score",
                "ci",
                "time_taken",
                "noise_dataset",
                "reference_dataset",
            ]
        )
        self.print_results = print_results
        self.skip_errors = skip_errors
        self.noise_distributions = noise_distributions
        self.reference_distributions = reference_distributions
        self.write_to_file = write_to_file
        # ``write_to_file`` defaults to None; Path(None) would raise TypeError.
        if write_to_file is not None and Path(write_to_file).exists():
            self.database = pd.read_csv(write_to_file, index_col=0)
            self.database = self.database.reset_index()

    def run(self) -> pd.DataFrame:
        """Run every benchmark on every dataset and return the results table.

        Pairs already present in ``self.database`` are skipped. With
        ``skip_errors`` set, an exception raised by a benchmark is printed and
        recorded as a row of missing values; otherwise it propagates.

        Returns:
            ``self.database`` including the newly added rows.
        """
        for benchmark in self.benchmark_objects:
            for dataset in self.datasets:
                # empty lines for better readability
                print("\n")
                print(f"{'='*80}")
                print(f"Benchmark Category: {benchmark.category.name}")
                print(f"Running {benchmark.name} on {dataset.root_dir}")
                try:
                    # check if it's in the database
                    if (
                        (self.database["benchmark_name"] == benchmark.name)
                        & (self.database["dataset"] == dataset.name)
                    ).any():
                        print(
                            f"Skipping {benchmark.name} on {dataset.name} as it's already in the database"
                        )
                        continue
                    start = time()
                    if "wer" in benchmark.name.lower():
                        print(
                            [
                                x.get_distribution(benchmark.key)
                                for x in self.reference_distributions
                            ]
                        )
                        print(
                            [
                                x.get_distribution(benchmark.key)
                                for x in self.noise_distributions
                            ]
                        )
                        print(benchmark.get_distribution(dataset))
                    score = benchmark.compute_score(
                        dataset, self.reference_distributions, self.noise_distributions
                    )
                    time_taken = time() - start
                except Exception as e:
                    if not self.skip_errors:
                        raise
                    print(f"Error: {e}")
                    # Same shape as a real score -- (score, ci, (noise name,
                    # reference name)) -- so the indexing below cannot fail.
                    score = (np.nan, np.nan, (None, None))
                    time_taken = np.nan
                result = {
                    "benchmark_name": [benchmark.name],
                    "benchmark_category": [benchmark.category.value],
                    "dataset": [dataset.name],
                    "score": [score[0]],
                    "ci": [score[1]],
                    "time_taken": [time_taken],
                    "noise_dataset": [score[2][0]],
                    "reference_dataset": [score[2][1]],
                }
                if self.print_results:
                    print(result)
                self.database = pd.concat(
                    [
                        self.database,
                        pd.DataFrame(result),
                    ],
                    ignore_index=True,
                )
                if self.write_to_file is not None:
                    # Persist after every result so partial progress is kept.
                    self.database["score"] = self.database["score"].astype(float)
                    self.database = self.database.sort_values(
                        ["benchmark_category", "benchmark_name", "score"],
                        ascending=[True, True, False],
                    )
                    self.database.to_csv(self.write_to_file, index=False)
        return self.database

    @staticmethod
    def aggregate_df(df: pd.DataFrame) -> pd.DataFrame:
        """Aggregate a results table per benchmark category and dataset.

        ``score``, ``ci`` and ``time_taken`` are averaged; the noise and
        reference dataset names are joined with ``", "``. The input frame is
        left unmodified.

        Args:
            df: Table with the columns produced by :meth:`run`, where
                ``benchmark_category`` holds ``BenchmarkCategory`` values.

        Returns:
            One row per (``benchmark_category``, ``dataset``) with the
            category given by its enum name and without ``benchmark_name``.
        """

        def concat_text(x):
            # Rows recorded for failed benchmarks carry missing names.
            return ", ".join(str(v) for v in x if pd.notna(v))

        # replace benchmark_category number with string (on a copy, so the
        # caller's frame is not altered and repeated calls keep working)
        df = df.assign(
            benchmark_category=df["benchmark_category"].apply(
                lambda x: BenchmarkCategory(x).name
            )
        )
        df = (
            df.groupby(
                [
                    "benchmark_category",
                    "dataset",
                ]
            )
            .agg(
                {
                    "score": ["mean"],
                    "ci": ["mean"],
                    "time_taken": ["mean"],
                    "noise_dataset": [concat_text],
                    "reference_dataset": [concat_text],
                    "benchmark_name": [concat_text],
                }
            )
            .reset_index()
        )
        # remove multiindex: keep only the first level of each column tuple
        df.columns = [x[0] for x in df.columns]
        # drop the benchmark_name column
        df = df.drop("benchmark_name", axis=1)
        return df

    def get_aggregated_results(self) -> pd.DataFrame:
        """Return :meth:`aggregate_df` applied to a copy of the results."""
        df = self.database.copy()
        return BenchmarkSuite.aggregate_df(df)

    def get_benchmark_distribution(
        self,
        benchmark_name: str,
        dataset_name: str,
        pca_components: Optional[int] = None,
    ) -> dict:
        """Collect the distributions behind one recorded result.

        Looks up the noise and reference datasets stored in ``self.database``
        for the given benchmark/dataset pair and returns the benchmark's
        distributions for the dataset, for those two, and for the first other
        noise and reference distribution.

        Args:
            benchmark_name: ``name`` of one of the suite's benchmark objects.
            dataset_name: ``name`` of one of the suite's datasets.
            pca_components: If given, a PCA with this many components is
                fitted on all distributions except the benchmark's own and
                every distribution is transformed with it.

        Returns:
            Dict with the keys ``benchmark_distribution``,
            ``noise_distribution``, ``reference_distribution``,
            ``other_noise_distribution`` and ``other_reference_distribution``.

        Raises:
            IndexError: if a name is unknown, no result is recorded for the
                pair, or there is no second noise/reference distribution.
        """
        benchmark = [x for x in self.benchmark_objects if x.name == benchmark_name][0]
        dataset = [x for x in self.datasets if x.name == dataset_name][0]
        recorded = self.database[
            (self.database["benchmark_name"] == benchmark_name)
            & (self.database["dataset"] == dataset_name)
        ]
        closest_noise_name = recorded["noise_dataset"].values[0]
        closest_noise = [
            x for x in self.noise_distributions if x.name == closest_noise_name
        ][0]
        other_noise = [
            x for x in self.noise_distributions if x.name != closest_noise.name
        ][0]
        closest_reference_name = recorded["reference_dataset"].values[0]
        closest_reference = [
            x for x in self.reference_distributions if x.name == closest_reference_name
        ][0]
        other_reference = [
            x for x in self.reference_distributions if x.name != closest_reference.name
        ][0]
        result = {
            "benchmark_distribution": benchmark.get_distribution(dataset),
            "noise_distribution": benchmark.get_distribution(closest_noise),
            "reference_distribution": benchmark.get_distribution(closest_reference),
            "other_noise_distribution": benchmark.get_distribution(other_noise),
            "other_reference_distribution": benchmark.get_distribution(other_reference),
        }
        if pca_components is not None:
            pca = PCA(n_components=pca_components)
            # fit on all except the benchmark distribution
            pca.fit(
                np.vstack(
                    [v for k, v in result.items() if k != "benchmark_distribution"]
                )
            )
            result = {k: pca.transform(v) for k, v in result.items()}
        return result
|
|
File without changes
|