ttsds 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. ttsds-0.0.1/.gitignore +11 -0
  2. ttsds-0.0.1/LICENSE.txt +9 -0
  3. ttsds-0.0.1/PKG-INFO +66 -0
  4. ttsds-0.0.1/README.md +30 -0
  5. ttsds-0.0.1/pyproject.toml +89 -0
  6. ttsds-0.0.1/src/ttsds/__about__.py +4 -0
  7. ttsds-0.0.1/src/ttsds/__init__.py +345 -0
  8. ttsds-0.0.1/src/ttsds/benchmarks/__init__.py +0 -0
  9. ttsds-0.0.1/src/ttsds/benchmarks/benchmark.py +182 -0
  10. ttsds-0.0.1/src/ttsds/benchmarks/environment/voicefixer.py +80 -0
  11. ttsds-0.0.1/src/ttsds/benchmarks/environment/wada_snr.py +230 -0
  12. ttsds-0.0.1/src/ttsds/benchmarks/external/pesq.py +67 -0
  13. ttsds-0.0.1/src/ttsds/benchmarks/external/utmos/__init__.py +70 -0
  14. ttsds-0.0.1/src/ttsds/benchmarks/external/utmos/change_sample_rate.py +21 -0
  15. ttsds-0.0.1/src/ttsds/benchmarks/external/utmos/lightning_module.py +71 -0
  16. ttsds-0.0.1/src/ttsds/benchmarks/external/utmos/model.py +225 -0
  17. ttsds-0.0.1/src/ttsds/benchmarks/external/wv_mos.py +43 -0
  18. ttsds-0.0.1/src/ttsds/benchmarks/general/hubert.py +83 -0
  19. ttsds-0.0.1/src/ttsds/benchmarks/general/wav2vec2.py +83 -0
  20. ttsds-0.0.1/src/ttsds/benchmarks/general/wavlm.py +83 -0
  21. ttsds-0.0.1/src/ttsds/benchmarks/intelligibility/w2v2_wer.py +62 -0
  22. ttsds-0.0.1/src/ttsds/benchmarks/intelligibility/whisper_wer.py +63 -0
  23. ttsds-0.0.1/src/ttsds/benchmarks/prosody/hubert_token.py +143 -0
  24. ttsds-0.0.1/src/ttsds/benchmarks/prosody/mpm.py +111 -0
  25. ttsds-0.0.1/src/ttsds/benchmarks/prosody/pitch.py +40 -0
  26. ttsds-0.0.1/src/ttsds/benchmarks/speaker/dvector.py +161 -0
  27. ttsds-0.0.1/src/ttsds/benchmarks/speaker/wespeaker.py +68 -0
  28. ttsds-0.0.1/src/ttsds/data/dvector/README.md +1 -0
  29. ttsds-0.0.1/src/ttsds/data/dvector/dvector.pt +0 -0
  30. ttsds-0.0.1/src/ttsds/util/__init__.py +0 -0
  31. ttsds-0.0.1/src/ttsds/util/cache.py +70 -0
  32. ttsds-0.0.1/src/ttsds/util/dataset.py +285 -0
  33. ttsds-0.0.1/src/ttsds/util/distances.py +80 -0
  34. ttsds-0.0.1/src/ttsds/util/measures.py +142 -0
  35. ttsds-0.0.1/src/ttsds/util/mpm.py +142 -0
  36. ttsds-0.0.1/src/ttsds/util/mpm_modules.py +147 -0
ttsds-0.0.1/.gitignore ADDED
@@ -0,0 +1,11 @@
1
+ .pytest_cache
2
+ __pycache__
3
+ .DS_Store
4
+ *.txt
5
+ *.wav
6
+ examples/blizzard08/data
7
+ examples/blizzard08/processed_data
8
+ examples/blizzard13-ext/data
9
+ examples/blizzard13-ext/processed_data
10
+ src/ttsds/data/original
11
+ dist
@@ -0,0 +1,9 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024-present Christoph Minixhofer <christoph.minixhofer@gmail.com>
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6
+
7
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8
+
9
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
ttsds-0.0.1/PKG-INFO ADDED
@@ -0,0 +1,66 @@
1
+ Metadata-Version: 2.3
2
+ Name: ttsds
3
+ Version: 0.0.1
4
+ Project-URL: Documentation, https://github.com/ttsds/ttsds#readme
5
+ Project-URL: Issues, https://github.com/ttsds/ttsds/issues
6
+ Project-URL: Source, https://github.com/ttsds/ttsds
7
+ Author-email: Christoph Minixhofer <christoph.minixhofer@gmail.com>
8
+ License-Expression: MIT
9
+ License-File: LICENSE.txt
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Programming Language :: Python
12
+ Classifier: Programming Language :: Python :: 3.8
13
+ Classifier: Programming Language :: Python :: 3.9
14
+ Classifier: Programming Language :: Python :: 3.10
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Programming Language :: Python :: Implementation :: CPython
18
+ Classifier: Programming Language :: Python :: Implementation :: PyPy
19
+ Requires-Python: >=3.8
20
+ Requires-Dist: allosaurus>=0.1.0
21
+ Requires-Dist: jiwer>=2.2.0
22
+ Requires-Dist: librosa>=0.10.0
23
+ Requires-Dist: lightning>=1.3.0
24
+ Requires-Dist: numpy>=1.21.0
25
+ Requires-Dist: openai-whisper==20231117
26
+ Requires-Dist: pandas>=1.3.0
27
+ Requires-Dist: pesq>=0.0.1
28
+ Requires-Dist: pyannote-audio==3.1.*
29
+ Requires-Dist: pyworld>=0.2.0
30
+ Requires-Dist: statsmodels>=0.12.0
31
+ Requires-Dist: torch>=2.0.0
32
+ Requires-Dist: tqdm>=4.61.0
33
+ Requires-Dist: transformers>=4.0.0
34
+ Requires-Dist: voicefixer>=0.1.0
35
+ Description-Content-Type: text/markdown
36
+
37
+ # ttsds
38
+
39
+ [![PyPI - Version](https://img.shields.io/pypi/v/ttsds.svg)](https://pypi.org/project/ttsds)
40
+ [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/ttsds.svg)](https://pypi.org/project/ttsds)
41
+
42
+ ## Installation
43
+
44
+ ### Requirements
45
+
46
+ - Python 3.8+
47
+ - System packages: ffmpeg, automake, autoconf, unzip, sox, gfortran, subversion, libtool
48
+ - Simple_hifigan, wvmos and wespeaker are not available on PyPI, so you need to install them manually.
49
+ - https://github.com/wenet-e2e/wespeaker
50
+ - https://github.com/AndreevP/wvmos
51
+ - https://github.com/MiniXC/simple_hifigan
52
+ - On some systems, the fairseq installation may fail due to conflicting dependencies. In this case, you can install this fork of fairseq https://github.com/MiniXC/fairseq-noconf
53
+
54
+ ### Pip
55
+
56
+ ```console
57
+ pip install ttsds
58
+ ```
59
+
60
+ ### Caching
61
+
62
+ Please set the ``TTSDS_CACHE_DIR`` environment variable to a directory where you want to cache the downloaded models and data.
63
+
64
+ ## License
65
+
66
+ `ttsds` is distributed under the terms of the [MIT](https://spdx.org/licenses/MIT.html) license.
ttsds-0.0.1/README.md ADDED
@@ -0,0 +1,30 @@
1
+ # ttsds
2
+
3
+ [![PyPI - Version](https://img.shields.io/pypi/v/ttsds.svg)](https://pypi.org/project/ttsds)
4
+ [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/ttsds.svg)](https://pypi.org/project/ttsds)
5
+
6
+ ## Installation
7
+
8
+ ### Requirements
9
+
10
+ - Python 3.8+
11
+ - System packages: ffmpeg, automake, autoconf, unzip, sox, gfortran, subversion, libtool
12
+ - Simple_hifigan, wvmos and wespeaker are not available on PyPI, so you need to install them manually.
13
+ - https://github.com/wenet-e2e/wespeaker
14
+ - https://github.com/AndreevP/wvmos
15
+ - https://github.com/MiniXC/simple_hifigan
16
+ - On some systems, the fairseq installation may fail due to conflicting dependencies. In this case, you can install this fork of fairseq https://github.com/MiniXC/fairseq-noconf
17
+
18
+ ### Pip
19
+
20
+ ```console
21
+ pip install ttsds
22
+ ```
23
+
24
+ ### Caching
25
+
26
+ Please set the ``TTSDS_CACHE_DIR`` environment variable to a directory where you want to cache the downloaded models and data.
27
+
28
+ ## License
29
+
30
+ `ttsds` is distributed under the terms of the [MIT](https://spdx.org/licenses/MIT.html) license.
@@ -0,0 +1,89 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "ttsds"
7
+ dynamic = ["version"]
8
+ description = ''
9
+ readme = "README.md"
10
+ requires-python = ">=3.8"
11
+ license = "MIT"
12
+ keywords = []
13
+ authors = [
14
+ { name = "Christoph Minixhofer", email = "christoph.minixhofer@gmail.com" },
15
+ ]
16
+ classifiers = [
17
+ "Development Status :: 4 - Beta",
18
+ "Programming Language :: Python",
19
+ "Programming Language :: Python :: 3.8",
20
+ "Programming Language :: Python :: 3.9",
21
+ "Programming Language :: Python :: 3.10",
22
+ "Programming Language :: Python :: 3.11",
23
+ "Programming Language :: Python :: 3.12",
24
+ "Programming Language :: Python :: Implementation :: CPython",
25
+ "Programming Language :: Python :: Implementation :: PyPy",
26
+ ]
27
dependencies = [
  "allosaurus>=0.1.0",
  "jiwer>=2.2.0",
  "librosa>=0.10.0",
  "lightning>=1.3.0",
  "numpy>=1.21.0",
  "openai-whisper==20231117",
  "pandas>=1.3.0",
  "pesq>=0.0.1",
  "pyannote.audio==3.1.*",
  "pyworld>=0.2.0",
  # requests and scikit-learn are imported directly by src/ttsds/__init__.py,
  # so they are declared explicitly instead of relying on transitive installs.
  "requests",
  "scikit-learn",
  "statsmodels>=0.12.0",
  "torch>=2.0.0",
  "tqdm>=4.61.0",
  "transformers>=4.0.0",
  "voicefixer>=0.1.0",
]
44
+
45
+ [project.urls]
46
+ Documentation = "https://github.com/ttsds/ttsds#readme"
47
+ Issues = "https://github.com/ttsds/ttsds/issues"
48
+ Source = "https://github.com/ttsds/ttsds"
49
+
50
+ [tool.hatch.version]
51
+ path = "src/ttsds/__about__.py"
52
+
53
+ [tool.hatch.build.targets.sdist]
54
+ include = ["src/ttsds"]
55
+ exclude = [
56
+ "src/ttsds/data/*.pkl.gz"
57
+ ]
58
+
59
+
60
+ [tool.hatch.envs.types]
61
+ extra-dependencies = [
62
+ "mypy>=1.0.0",
63
+ "pytest>=6.0.0",
64
+ "streamlit>=1.0.0",
65
+ ]
66
+ [tool.hatch.envs.types.scripts]
67
+ check = "mypy --install-types --non-interactive {args:src/ttsds tests}"
68
+
69
+ [tool.coverage.run]
70
+ source_pkgs = ["ttsds", "tests"]
71
+ branch = true
72
+ parallel = true
73
+ omit = [
74
+ "src/ttsds/__about__.py",
75
+ ]
76
+
77
+ [tool.coverage.paths]
78
+ ttsds = ["src/ttsds", "*/ttsds/src/ttsds"]
79
+ tests = ["tests", "*/ttsds/tests"]
80
+
81
+ [tool.coverage.report]
82
+ exclude_lines = [
83
+ "no cov",
84
+ "if __name__ == .__main__.:",
85
+ "if TYPE_CHECKING:",
86
+ ]
87
+
88
+ [tool.hatch.metadata]
89
+ allow-direct-references = true
@@ -0,0 +1,4 @@
1
+ # SPDX-FileCopyrightText: 2024-present Christoph Minixhofer <christoph.minixhofer@gmail.com>
2
+ #
3
+ # SPDX-License-Identifier: MIT
4
# Package version string; pyproject.toml's [tool.hatch.version] table points
# at this file, so the build reads the version from here.
__version__ = "0.0.1"
@@ -0,0 +1,345 @@
1
+ from typing import List, Optional
2
+ import importlib.resources
3
+ from time import time
4
+ from pathlib import Path
5
+ import pickle
6
+ import gzip
7
+ import requests
8
+
9
+ import pandas as pd
10
+ from transformers import logging
11
+ import numpy as np
12
+ from sklearn.decomposition import PCA
13
+
14
+ from ttsds.benchmarks.environment.voicefixer import VoiceFixerBenchmark
15
+ from ttsds.benchmarks.environment.wada_snr import WadaSNRBenchmark
16
+ from ttsds.benchmarks.general.hubert import HubertBenchmark
17
+ from ttsds.benchmarks.general.wav2vec2 import Wav2Vec2Benchmark
18
+ from ttsds.benchmarks.general.wavlm import WavLMBenchmark
19
+ from ttsds.benchmarks.intelligibility.w2v2_wer import Wav2Vec2WERBenchmark
20
+ from ttsds.benchmarks.intelligibility.whisper_wer import WhisperWERBenchmark
21
+ from ttsds.benchmarks.prosody.mpm import MPMBenchmark
22
+ from ttsds.benchmarks.prosody.pitch import PitchBenchmark
23
+ from ttsds.benchmarks.prosody.hubert_token import HubertTokenBenchmark
24
+ from ttsds.benchmarks.speaker.wespeaker import WeSpeakerBenchmark
25
+ from ttsds.benchmarks.speaker.dvector import DVectorBenchmark
26
+ from ttsds.benchmarks.benchmark import BenchmarkCategory, BenchmarkDimension
27
+ from ttsds.util.dataset import Dataset, TarDataset, DataDistribution, DEFAULT_BENCHMARKS
28
+
29
# Only report errors from transformers; this hides the "some weights of the
# model checkpoint at ... were not used when initializing" warnings.
logging.set_verbosity_error()


# Registry of available benchmarks: identifier -> benchmark class.
# BenchmarkSuite looks identifiers up here and instantiates the class.
benchmark_dict = dict(
    hubert=HubertBenchmark,
    wav2vec2=Wav2Vec2Benchmark,
    wavlm=WavLMBenchmark,
    wav2vec2_wer=Wav2Vec2WERBenchmark,
    whisper_wer=WhisperWERBenchmark,
    mpm=MPMBenchmark,
    pitch=PitchBenchmark,
    wespeaker=WeSpeakerBenchmark,
    dvector=DVectorBenchmark,
    hubert_token=HubertTokenBenchmark,
    voicefixer=VoiceFixerBenchmark,
    wada_snr=WadaSNRBenchmark,
)
47
+
48
# Names of the precomputed distributions shipped alongside the package.
_NOISE_NAMES = [
    "esc50",
    "all_ones",
    "all_zeros",
    "normal_distribution",
    "uniform_distribution",
]
_SPEECH_NAMES = [
    "blizzard2008",
    "blizzard2013",
    "common_voice",
    "libritts_test",
    "libritts_r_test",
    "lj_speech",
    "vctk",
]
# Base URL the precomputed ``.pkl.gz`` files are fetched from when missing.
_DATA_URL = "https://github.com/ttsds/ttsds/raw/main/src/ttsds/data"


def _download_if_missing(target: Path, url: str) -> None:
    """Download ``url`` to ``target`` unless ``target`` already exists.

    Raises ``requests.HTTPError`` on a non-success status so that an error
    page is never written to disk in place of the expected archive.
    """
    if target.exists():
        return
    print(f"Downloading {target.name}")
    # A timeout keeps ``import ttsds`` from hanging forever on a dead connection.
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    with open(target, "wb") as f:
        f.write(response.content)


# Everything that touches ``data_path`` stays inside the ``with`` block so the
# path is guaranteed to be valid while it is used.
with importlib.resources.path("ttsds", "data") as data_path:
    # if they don't exist, download from github
    for noise_name in _NOISE_NAMES:
        _download_if_missing(
            Path(f"{data_path}/noise_{noise_name}.pkl.gz"),
            f"{_DATA_URL}/noise_{noise_name}.pkl.gz",
        )

    for speech_name in _SPEECH_NAMES:
        _download_if_missing(
            Path(f"{data_path}/reference_speech_{speech_name}.pkl.gz"),
            f"{_DATA_URL}/reference_speech_{speech_name}.pkl.gz",
        )

    # NOTE(review): the download step above either creates these files or
    # raises, so the two "Creating ..." branches below look unreachable during
    # a normal import - confirm whether they are still needed.

    # check if the reference and noise distributions are already saved
    if not Path(f"{data_path}/reference_speech_blizzard2008.pkl.gz").exists():
        print("Creating reference distributions")
        ref_benchmark_dict = {k: v() for k, v in benchmark_dict.items()}
        REFERENCE_DISTS = [
            DataDistribution(
                TarDataset(data_path / "original" / f"speech_{name}.tar.gz"),
                ref_benchmark_dict,
                benchmarks=DEFAULT_BENCHMARKS,
                name=f"speech_{name}",
            )
            for name in _SPEECH_NAMES
        ]
        # save the reference distributions
        for dist in REFERENCE_DISTS:
            dist.to_pickle(f"{data_path}/reference_{dist.name}.pkl.gz")

    # NOTE(review): these files may have just been fetched over the network;
    # if ``DataDistribution.from_pickle`` unpickles them, a compromised
    # download could execute arbitrary code. Consider verifying a checksum.
    REFERENCE_DISTS = [
        DataDistribution.from_pickle(f"{data_path}/reference_speech_{name}.pkl.gz")
        for name in _SPEECH_NAMES
    ]

    if not Path(f"{data_path}/noise_esc50.pkl.gz").exists():
        print("Creating noise distributions")
        ref_benchmark_dict = {k: v() for k, v in benchmark_dict.items()}
        NOISE_DISTS = [
            DataDistribution(
                TarDataset(data_path / "original" / f"noise_{name}.tar.gz"),
                ref_benchmark_dict,
                benchmarks=DEFAULT_BENCHMARKS,
                name=name,
            )
            for name in _NOISE_NAMES
        ]
        # save the noise distributions
        for dist in NOISE_DISTS:
            dist.to_pickle(f"{data_path}/noise_{dist.name}.pkl.gz")

    NOISE_DISTS = [
        DataDistribution.from_pickle(f"{data_path}/noise_{name}.pkl.gz")
        for name in _NOISE_NAMES
    ]
155
+
156
+
157
class BenchmarkSuite:
    """Run a set of benchmarks over datasets and collect the scores in a table.

    Results accumulate in ``self.database``, a DataFrame with one row per
    (benchmark, dataset) pair. When ``write_to_file`` is given, the table is
    loaded from that CSV if it exists and rewritten as results come in, so
    pairs that are already present are skipped on later runs.
    """

    def __init__(
        self,
        datasets: List[Dataset],
        benchmarks: List[str] = DEFAULT_BENCHMARKS,
        print_results: bool = True,
        skip_errors: bool = False,
        noise_distributions: List[DataDistribution] = NOISE_DISTS,
        reference_distributions: List[DataDistribution] = REFERENCE_DISTS,
        write_to_file: Optional[str] = None,
    ):
        """Set up the suite.

        Args:
            datasets: Datasets to evaluate; stored sorted by ``name``.
            benchmarks: Keys of ``benchmark_dict`` selecting the benchmarks.
            print_results: Print each result row as it is produced.
            skip_errors: Record NaN results instead of re-raising when a
                benchmark fails.
            noise_distributions: Noise distributions passed to each benchmark.
            reference_distributions: Reference distributions passed to each
                benchmark.
            write_to_file: Optional CSV path used to load and persist results.
        """
        self.benchmarks = benchmarks
        # instantiate the benchmarks, sorted by category and then by name
        self.benchmark_objects = sorted(
            (benchmark_dict[benchmark]() for benchmark in benchmarks),
            key=lambda x: (x.category.value, x.name),
        )
        self.datasets = sorted(datasets, key=lambda x: x.name)
        self.database = pd.DataFrame(
            columns=[
                "benchmark_name",
                "benchmark_category",
                "dataset",
                "score",
                "ci",
                "time_taken",
                "noise_dataset",
                "reference_dataset",
            ]
        )
        self.print_results = print_results
        self.skip_errors = skip_errors
        self.noise_distributions = noise_distributions
        self.reference_distributions = reference_distributions
        self.write_to_file = write_to_file
        # ``write_to_file`` defaults to None, which ``Path`` cannot accept,
        # so only look for an existing results file when a path was given.
        if write_to_file is not None and Path(write_to_file).exists():
            self.database = pd.read_csv(write_to_file, index_col=0)
            self.database = self.database.reset_index()

    def run(self) -> pd.DataFrame:
        """Run every benchmark on every dataset and return the results table.

        Pairs already present in ``self.database`` are skipped. Each
        benchmark's ``compute_score`` result is read as
        ``(score, ci, (noise_dataset, reference_dataset))``.

        Raises:
            Exception: Whatever ``compute_score`` raises, unless
                ``skip_errors`` is set, in which case a NaN row is recorded.
        """
        for benchmark in self.benchmark_objects:
            for dataset in self.datasets:
                # empty lines for better readability
                print("\n")
                print(f"{'='*80}")
                print(f"Benchmark Category: {benchmark.category.name}")
                print(f"Running {benchmark.name} on {dataset.root_dir}")
                # check if it's in the database
                already_done = (
                    (self.database["benchmark_name"] == benchmark.name)
                    & (self.database["dataset"] == dataset.name)
                ).any()
                if already_done:
                    print(
                        f"Skipping {benchmark.name} on {dataset.name} as it's already in the database"
                    )
                    continue
                start = time()
                try:
                    score = benchmark.compute_score(
                        dataset, self.reference_distributions, self.noise_distributions
                    )
                    time_taken = time() - start
                except Exception as e:
                    if not self.skip_errors:
                        raise
                    print(f"Error: {e}")
                    # Same shape as a real result so the row below can be
                    # built without an IndexError on ``score[2]``.
                    score = (np.nan, np.nan, (np.nan, np.nan))
                    time_taken = np.nan
                result = {
                    "benchmark_name": [benchmark.name],
                    "benchmark_category": [benchmark.category.value],
                    "dataset": [dataset.name],
                    "score": [score[0]],
                    "ci": [score[1]],
                    "time_taken": [time_taken],
                    "noise_dataset": [score[2][0]],
                    "reference_dataset": [score[2][1]],
                }
                if self.print_results:
                    print(result)
                self.database = pd.concat(
                    [self.database, pd.DataFrame(result)],
                    ignore_index=True,
                )
                # Persist after every result so an interrupted run can resume.
                if self.write_to_file is not None:
                    self.database["score"] = self.database["score"].astype(float)
                    self.database = self.database.sort_values(
                        ["benchmark_category", "benchmark_name", "score"],
                        ascending=[True, True, False],
                    )
                    self.database.to_csv(self.write_to_file, index=False)
        return self.database

    @staticmethod
    def aggregate_df(df: pd.DataFrame) -> pd.DataFrame:
        """Average results per (benchmark category, dataset).

        ``score``, ``ci`` and ``time_taken`` are averaged; the noise and
        reference dataset names are joined with ", ". The input frame is not
        modified.
        """

        def concat_text(x):
            # str() keeps the join working when a failed run left NaN entries
            return ", ".join(str(item) for item in x)

        # Work on a copy: the category column is rewritten below and the
        # caller's frame must stay usable (e.g. for a second aggregation).
        df = df.copy()
        # replace benchmark_category number with string
        df["benchmark_category"] = df["benchmark_category"].apply(
            lambda x: BenchmarkCategory(x).name
        )
        df = (
            df.groupby(["benchmark_category", "dataset"])
            .agg(
                {
                    "score": ["mean"],
                    "ci": ["mean"],
                    "time_taken": ["mean"],
                    "noise_dataset": [concat_text],
                    "reference_dataset": [concat_text],
                    "benchmark_name": [concat_text],
                }
            )
            .reset_index()
        )
        # remove multiindex (keep only the first level of each column tuple)
        df.columns = [col[0] for col in df.columns]
        # drop the benchmark_name column
        df = df.drop("benchmark_name", axis=1)
        return df

    def get_aggregated_results(self) -> pd.DataFrame:
        """Return ``aggregate_df`` applied to a copy of the results table."""
        return BenchmarkSuite.aggregate_df(self.database.copy())

    def get_benchmark_distribution(
        self,
        benchmark_name: str,
        dataset_name: str,
        pca_components: Optional[int] = None,
    ) -> dict:
        """Collect the distributions behind one (benchmark, dataset) result.

        The noise and reference datasets recorded in the database for that
        result are looked up, together with the first other noise and
        reference distribution, and the benchmark's distribution is computed
        for each of them and for the dataset itself.

        Args:
            benchmark_name: ``name`` of one of the suite's benchmarks.
            dataset_name: ``name`` of one of the suite's datasets.
            pca_components: If given, project every distribution with a PCA
                fitted on all distributions except the dataset's own.
                Assumes the distributions are 2-D arrays - TODO confirm.

        Raises:
            IndexError: If the benchmark, dataset, database row or a matching
                distribution cannot be found.
        """
        benchmark = [x for x in self.benchmark_objects if x.name == benchmark_name][0]
        dataset = [x for x in self.datasets if x.name == dataset_name][0]
        # one lookup serves both the noise and the reference column
        rows = self.database[
            (self.database["benchmark_name"] == benchmark_name)
            & (self.database["dataset"] == dataset_name)
        ]
        noise_name = rows["noise_dataset"].values[0]
        closest_noise = [x for x in self.noise_distributions if x.name == noise_name][0]
        other_noise = [
            x for x in self.noise_distributions if x.name != closest_noise.name
        ][0]
        reference_name = rows["reference_dataset"].values[0]
        closest_reference = [
            x for x in self.reference_distributions if x.name == reference_name
        ][0]
        other_reference = [
            x for x in self.reference_distributions if x.name != closest_reference.name
        ][0]
        result = {
            "benchmark_distribution": benchmark.get_distribution(dataset),
            "noise_distribution": benchmark.get_distribution(closest_noise),
            "reference_distribution": benchmark.get_distribution(closest_reference),
            "other_noise_distribution": benchmark.get_distribution(other_noise),
            "other_reference_distribution": benchmark.get_distribution(other_reference),
        }
        if pca_components is not None:
            pca = PCA(n_components=pca_components)
            # fit on all except the benchmark distribution
            pca.fit(
                np.vstack(
                    [v for k, v in result.items() if k != "benchmark_distribution"]
                )
            )
            result = {k: pca.transform(v) for k, v in result.items()}
        return result
File without changes