nkululeko 0.90.0__py3-none-any.whl → 0.90.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nkululeko/aug_train.py +6 -4
- nkululeko/augment.py +6 -4
- nkululeko/augmenting/augmenter.py +4 -4
- nkululeko/augmenting/randomsplicer.py +6 -6
- nkululeko/augmenting/randomsplicing.py +2 -3
- nkululeko/augmenting/resampler.py +9 -6
- nkululeko/autopredict/ap_age.py +4 -2
- nkululeko/autopredict/ap_arousal.py +4 -2
- nkululeko/autopredict/ap_dominance.py +3 -2
- nkululeko/autopredict/ap_gender.py +4 -2
- nkululeko/autopredict/ap_mos.py +5 -2
- nkululeko/autopredict/ap_pesq.py +5 -2
- nkululeko/autopredict/ap_sdr.py +5 -2
- nkululeko/autopredict/ap_snr.py +5 -2
- nkululeko/autopredict/ap_stoi.py +5 -2
- nkululeko/autopredict/ap_valence.py +4 -2
- nkululeko/autopredict/estimate_snr.py +10 -14
- nkululeko/cacheddataset.py +1 -1
- nkululeko/constants.py +1 -1
- nkululeko/data/dataset.py +11 -14
- nkululeko/data/dataset_csv.py +5 -3
- nkululeko/demo-ft.py +29 -0
- nkululeko/demo_feats.py +5 -4
- nkululeko/demo_predictor.py +3 -4
- nkululeko/ensemble.py +27 -28
- nkululeko/experiment.py +3 -5
- nkululeko/experiment_felix.py +728 -0
- nkululeko/explore.py +1 -0
- nkululeko/export.py +7 -5
- nkululeko/feat_extract/feats_agender.py +5 -4
- nkululeko/feat_extract/feats_agender_agender.py +7 -6
- nkululeko/feat_extract/feats_analyser.py +18 -16
- nkululeko/feat_extract/feats_ast.py +9 -8
- nkululeko/feat_extract/feats_auddim.py +3 -5
- nkululeko/feat_extract/feats_audmodel.py +2 -2
- nkululeko/feat_extract/feats_clap.py +9 -12
- nkululeko/feat_extract/feats_hubert.py +2 -3
- nkululeko/feat_extract/feats_import.py +5 -4
- nkululeko/feat_extract/feats_mld.py +3 -5
- nkululeko/feat_extract/feats_mos.py +4 -3
- nkululeko/feat_extract/feats_opensmile.py +4 -3
- nkululeko/feat_extract/feats_oxbow.py +5 -4
- nkululeko/feat_extract/feats_praat.py +4 -7
- nkululeko/feat_extract/feats_snr.py +3 -5
- nkululeko/feat_extract/feats_spectra.py +8 -9
- nkululeko/feat_extract/feats_spkrec.py +6 -11
- nkululeko/feat_extract/feats_squim.py +2 -4
- nkululeko/feat_extract/feats_trill.py +2 -5
- nkululeko/feat_extract/feats_wav2vec2.py +8 -4
- nkululeko/feat_extract/feats_wavlm.py +2 -3
- nkululeko/feat_extract/feats_whisper.py +4 -6
- nkululeko/feat_extract/featureset.py +4 -2
- nkululeko/feat_extract/feinberg_praat.py +1 -3
- nkululeko/feat_extract/transformer_feature_extractor.py +147 -0
- nkululeko/file_checker.py +3 -3
- nkululeko/filter_data.py +3 -1
- nkululeko/fixedsegment.py +83 -0
- nkululeko/models/model.py +3 -5
- nkululeko/models/model_bayes.py +1 -0
- nkululeko/models/model_cnn.py +4 -6
- nkululeko/models/model_gmm.py +13 -9
- nkululeko/models/model_knn.py +1 -0
- nkululeko/models/model_knn_reg.py +1 -0
- nkululeko/models/model_lin_reg.py +1 -0
- nkululeko/models/model_mlp.py +2 -3
- nkululeko/models/model_mlp_regression.py +1 -6
- nkululeko/models/model_svm.py +2 -2
- nkululeko/models/model_svr.py +1 -0
- nkululeko/models/model_tree.py +2 -3
- nkululeko/models/model_tree_reg.py +1 -0
- nkululeko/models/model_tuned.py +54 -33
- nkululeko/models/model_xgb.py +1 -0
- nkululeko/models/model_xgr.py +1 -0
- nkululeko/multidb.py +1 -0
- nkululeko/nkululeko.py +1 -1
- nkululeko/predict.py +4 -5
- nkululeko/reporting/defines.py +6 -8
- nkululeko/reporting/latex_writer.py +3 -3
- nkululeko/reporting/report.py +2 -2
- nkululeko/reporting/report_item.py +1 -0
- nkululeko/reporting/reporter.py +20 -19
- nkululeko/resample.py +8 -12
- nkululeko/resample_cli.py +99 -0
- nkululeko/runmanager.py +3 -1
- nkululeko/scaler.py +1 -1
- nkululeko/segment.py +6 -5
- nkululeko/segmenting/seg_inaspeechsegmenter.py +3 -3
- nkululeko/segmenting/seg_silero.py +4 -4
- nkululeko/syllable_nuclei.py +9 -22
- nkululeko/test_pretrain.py +6 -7
- nkululeko/utils/stats.py +0 -1
- nkululeko/utils/util.py +2 -3
- {nkululeko-0.90.0.dist-info → nkululeko-0.90.1.dist-info}/METADATA +6 -2
- nkululeko-0.90.1.dist-info/RECORD +119 -0
- {nkululeko-0.90.0.dist-info → nkululeko-0.90.1.dist-info}/WHEEL +1 -1
- nkululeko-0.90.0.dist-info/RECORD +0 -114
- {nkululeko-0.90.0.dist-info → nkululeko-0.90.1.dist-info}/LICENSE +0 -0
- {nkululeko-0.90.0.dist-info → nkululeko-0.90.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,728 @@
|
|
1
|
+
# experiment.py: Main class for an experiment (nkululeko.nkululeko)
|
2
|
+
import ast
|
3
|
+
import os
|
4
|
+
import pickle
|
5
|
+
import random
|
6
|
+
import time
|
7
|
+
|
8
|
+
import audeer
|
9
|
+
import audformat
|
10
|
+
import numpy as np
|
11
|
+
import pandas as pd
|
12
|
+
from sklearn.preprocessing import LabelEncoder
|
13
|
+
|
14
|
+
import nkululeko.glob_conf as glob_conf
|
15
|
+
from nkululeko.data.dataset import Dataset
|
16
|
+
from nkululeko.data.dataset_csv import Dataset_CSV
|
17
|
+
from nkululeko.demo_predictor import Demo_predictor
|
18
|
+
from nkululeko.feat_extract.feats_analyser import FeatureAnalyser
|
19
|
+
from nkululeko.feature_extractor import FeatureExtractor
|
20
|
+
from nkululeko.file_checker import FileChecker
|
21
|
+
from nkululeko.filter_data import DataFilter
|
22
|
+
from nkululeko.plots import Plots
|
23
|
+
from nkululeko.reporting.report import Report
|
24
|
+
from nkululeko.runmanager import Runmanager
|
25
|
+
from nkululeko.scaler import Scaler
|
26
|
+
from nkululeko.test_predictor import TestPredictor
|
27
|
+
from nkululeko.utils.util import Util
|
28
|
+
|
29
|
+
|
30
|
+
class Experiment:
    """Main class specifying an experiment"""

    def __init__(self, config_obj):
        """
        Parameters
        ----------
        config_obj : a config parser object that sets the experiment parameters and being set as a global object.
        """
        self.set_globals(config_obj)
        self.name = glob_conf.config["EXP"]["name"]
        self.root = os.path.join(glob_conf.config["EXP"]["root"], "")
        self.data_dir = os.path.join(self.root, self.name)
        # make sure the experiment directory exists
        audeer.mkdir(self.data_dir)
        self.util = Util("experiment")
        glob_conf.set_util(self.util)
        # NOTE(review): config values are eval()-ed throughout this class;
        # acceptable for trusted config files, not for untrusted input.
        if eval(self.util.config_val("REPORT", "fresh", "False")):
            self.util.debug("starting a fresh report")
            self.report = Report()
        else:
            # try to resume a previously stored report
            try:
                with open(os.path.join(self.data_dir, "report.pkl"), "rb") as fh:
                    self.report = pickle.load(fh)
            except FileNotFoundError:
                self.report = Report()
        glob_conf.set_report(self.report)
        self.loso = self.util.config_val("MODEL", "loso", False)
        self.logo = self.util.config_val("MODEL", "logo", False)
        self.xfoldx = self.util.config_val("MODEL", "k_fold_cross", False)
        self.start = time.process_time()

    def set_module(self, module):
        """Announce the currently running nkululeko module globally."""
        glob_conf.set_module(module)

    def store_report(self):
        """Persist the report to disk and optionally show or LaTeX-export it."""
        with open(os.path.join(self.data_dir, "report.pkl"), "wb") as fh:
            pickle.dump(self.report, fh)
        if eval(self.util.config_val("REPORT", "show", "False")):
            self.report.print()
        if self.util.config_val("REPORT", "latex", False):
            self.report.export_latex()

    def get_name(self):
        """Return the experiment name as computed by the util object."""
        return self.util.get_exp_name()

    def set_globals(self, config_obj):
        """install a config object in the global space"""
        glob_conf.init_config(config_obj)
|
80
|
+
|
81
|
+
def load_datasets(self):
    """Load all databases specified in the configuration and map the labels.

    Populates self.datasets, self.target, self.labels and the aggregated
    got_speaker / got_gender / got_age flags.
    """
    ds = ast.literal_eval(glob_conf.config["DATA"]["databases"])
    self.datasets = {}
    self.got_speaker, self.got_gender, self.got_age = False, False, False
    for d in ds:
        ds_type = self.util.config_val_data(d, "type", "audformat")
        if ds_type == "audformat":
            data = Dataset(d)
        elif ds_type == "csv":
            data = Dataset_CSV(d)
        else:
            self.util.error(f"unknown data type: {ds_type}")
        data.load()
        data.prepare()
        # aggregate metadata availability over all databases
        if data.got_gender:
            self.got_gender = True
        if data.got_age:
            self.got_age = True
        if data.got_speaker:
            self.got_speaker = True
        self.datasets.update({d: data})
    self.target = self.util.config_val("DATA", "target", "emotion")
    glob_conf.set_target(self.target)
    # print target via debug
    self.util.debug(f"target: {self.target}")
    # print keys/column
    dbs = ",".join(list(self.datasets.keys()))
    labels = self.util.config_val("DATA", "labels", False)
    if labels:
        self.labels = ast.literal_eval(labels)
        self.util.debug(f"Target labels (from config): {labels}")
    else:
        self.labels = list(
            next(iter(self.datasets.values())).df[self.target].unique()
        )
        # BUG FIX: this used to log the config value `labels` (False in this
        # branch) instead of the labels actually read from the database
        self.util.debug(f"Target labels (from database): {self.labels}")
    glob_conf.set_labels(self.labels)
    self.util.debug(f"loaded databases {dbs}")

def _import_csv(self, storage):
    """Read a stored dataframe back from CSV via audformat.

    Sets the (non-persistent) `is_labeled` attribute depending on whether
    the target column is present in the restored frame.
    """
    df = audformat.utils.read_csv(storage)
    # direct bool expression instead of `True if ... else False`
    df.is_labeled = self.target in df
    return df
|
129
|
+
|
130
|
+
def fill_tests(self):
    """Only fill a new test set"""
    test_dbs = ast.literal_eval(glob_conf.config["DATA"]["tests"])
    self.df_test = pd.DataFrame()
    start_fresh = eval(self.util.config_val("DATA", "no_reuse", "False"))
    store = self.util.get_path("store")
    storage_test = f"{store}extra_testdf.csv"
    if os.path.isfile(storage_test) and not start_fresh:
        # a previously assembled extra test set can simply be re-read
        self.util.debug(f"reusing previously stored {storage_test}")
        self.df_test = self._import_csv(storage_test)
        return
    # otherwise assemble the test set from all configured test databases
    for db_name in test_dbs:
        ds_type = self.util.config_val_data(db_name, "type", "audformat")
        if ds_type == "audformat":
            data = Dataset(db_name)
        elif ds_type == "csv":
            data = Dataset_CSV(db_name)
        else:
            self.util.error(f"unknown data type: {ds_type}")
        data.load()
        # remember which metadata any of the databases provides
        self.got_gender = self.got_gender or data.got_gender
        self.got_age = self.got_age or data.got_age
        self.got_speaker = self.got_speaker or data.got_speaker
        data.split()
        data.prepare_labels()
        self.df_test = pd.concat(
            [self.df_test, self.util.make_segmented_index(data.df_test)]
        )
    # NOTE(review): is_labeled reflects only the last processed database — confirm
    self.df_test.is_labeled = data.is_labeled
    self.df_test.got_gender = self.got_gender
    self.df_test.got_speaker = self.got_speaker
    self.df_test["class_label"] = self.df_test[self.target]
    self.df_test[self.target] = self.label_encoder.transform(
        self.df_test[self.target]
    )
    self.df_test.to_csv(storage_test)
|
172
|
+
|
173
|
+
def fill_train_and_tests(self):
    """Set up train and development sets. The method should be specified in the config."""
    store = self.util.get_path("store")
    storage_test = f"{store}testdf.csv"
    storage_train = f"{store}traindf.csv"
    start_fresh = eval(self.util.config_val("DATA", "no_reuse", "False"))
    if (
        os.path.isfile(storage_train)
        and os.path.isfile(storage_test)
        and not start_fresh
    ):
        self.util.debug(
            f"reusing previously stored {storage_test} and {storage_train}"
        )
        self.df_test = self._import_csv(storage_test)
        self.df_train = self._import_csv(storage_train)
    else:
        self.df_train, self.df_test = pd.DataFrame(), pd.DataFrame()
        for d in self.datasets.values():
            d.split()
            d.prepare_labels()
            if d.df_train.shape[0] == 0:
                self.util.debug(f"warn: {d.name} train empty")
            self.df_train = pd.concat([self.df_train, d.df_train])
            self.util.copy_flags(d, self.df_train)
            if d.df_test.shape[0] == 0:
                self.util.debug(f"warn: {d.name} test empty")
            self.df_test = pd.concat([self.df_test, d.df_test])
            self.util.copy_flags(d, self.df_test)
        store = self.util.get_path("store")
        storage_test = f"{store}testdf.csv"
        storage_train = f"{store}traindf.csv"
        self.df_test.to_csv(storage_test)
        self.df_train.to_csv(storage_train)

    self.util.copy_flags(self, self.df_test)
    self.util.copy_flags(self, self.df_train)
    # Try data checks
    datachecker = FileChecker(self.df_train)
    self.df_train = datachecker.all_checks()
    datachecker.set_data(self.df_test)
    self.df_test = datachecker.all_checks()

    # Check for filters
    filter_sample_selection = self.util.config_val(
        "DATA", "filter.sample_selection", "all"
    )
    if filter_sample_selection == "all":
        datafilter = DataFilter(self.df_train)
        self.df_train = datafilter.all_filters()
        datafilter = DataFilter(self.df_test)
        self.df_test = datafilter.all_filters()
    elif filter_sample_selection == "train":
        datafilter = DataFilter(self.df_train)
        self.df_train = datafilter.all_filters()
    elif filter_sample_selection == "test":
        datafilter = DataFilter(self.df_test)
        self.df_test = datafilter.all_filters()
    else:
        self.util.error(
            "unkown filter sample selection specifier"
            f" {filter_sample_selection}, should be [all | train | test]"
        )

    # encode the labels
    if self.util.exp_is_classification():
        datatype = self.util.config_val("DATA", "type", "dummy")
        if datatype == "continuous":
            # continuous targets were binned into class_label earlier
            test_cats = self.df_test["class_label"].unique()
            train_cats = self.df_train["class_label"].unique()
        else:
            if self.df_test.is_labeled:
                test_cats = self.df_test[self.target].unique()
            else:
                # if there is no target, copy a dummy label
                self.df_test = self._add_random_target(self.df_test).astype("str")
            train_cats = self.df_train[self.target].unique()
        if self.df_test.is_labeled:
            # isinstance instead of `type(x) == np.ndarray`: idiomatic and
            # robust against subclasses
            if isinstance(test_cats, np.ndarray):
                self.util.debug(f"Categories test (nd.array): {test_cats}")
            else:
                self.util.debug(f"Categories test (list): {list(test_cats)}")
        if isinstance(train_cats, np.ndarray):
            self.util.debug(f"Categories train (nd.array): {train_cats}")
        else:
            self.util.debug(f"Categories train (list): {list(train_cats)}")

        # encode the labels as numbers
        self.label_encoder = LabelEncoder()
        self.df_train[self.target] = self.label_encoder.fit_transform(
            self.df_train[self.target]
        )
        self.df_test[self.target] = self.label_encoder.transform(
            self.df_test[self.target]
        )
        glob_conf.set_label_encoder(self.label_encoder)
    if self.got_speaker:
        self.util.debug(
            f"{self.df_test.speaker.nunique()} speakers in test and"
            f" {self.df_train.speaker.nunique()} speakers in train"
        )

    target_factor = self.util.config_val("DATA", "target_divide_by", False)
    if target_factor:
        self.df_test[self.target] = self.df_test[self.target] / float(target_factor)
        self.df_train[self.target] = self.df_train[self.target] / float(
            target_factor
        )
        if not self.util.exp_is_classification():
            self.df_test["class_label"] = self.df_test["class_label"] / float(
                target_factor
            )
            self.df_train["class_label"] = self.df_train["class_label"] / float(
                target_factor
            )

def _add_random_target(self, df):
    """Assign a random label (drawn from the globally known labels) to every sample."""
    labels = glob_conf.labels
    # comprehension instead of pre-allocated list + index loop
    df[self.target] = [random.choice(labels) for _ in range(len(df))]
    return df
|
311
|
+
|
312
|
+
def plot_distribution(self, df_labels):
    """Plot the distribution of samples and speaker per target class and biological sex"""
    plotter = Plots()
    # NOTE(review): sample_selection is read but never used here — confirm intent
    sample_selection = self.util.config_val("EXPL", "sample_selection", "all")
    plotter.plot_distributions(df_labels)
    if self.got_speaker:
        plotter.plot_distributions_speaker(df_labels)

def extract_test_feats(self):
    """Extract features for the extra test set only."""
    self.feats_test = pd.DataFrame()
    feats_name = "_".join(ast.literal_eval(glob_conf.config["DATA"]["tests"]))
    feats_types = self.util.config_val_list("FEATS", "type", ["os"])
    self.feature_extractor = FeatureExtractor(
        self.df_test, feats_types, feats_name, "test"
    )
    self.feats_test = self.feature_extractor.extract()
    self.util.debug(f"Test features shape:{self.feats_test.shape}")
|
329
|
+
|
330
|
+
def extract_feats(self):
    """Extract the features for train and dev sets.

    They will be stored on disk and need to be removed manually.

    The string FEATS.feats_type is read from the config, defaults to os.
    """
    df_train, df_test = self.df_train, self.df_test
    feats_name = "_".join(ast.literal_eval(glob_conf.config["DATA"]["databases"]))
    self.feats_test, self.feats_train = pd.DataFrame(), pd.DataFrame()
    feats_types = self.util.config_val_list("FEATS", "type", ["os"])
    self.feature_extractor = FeatureExtractor(
        df_train, feats_types, feats_name, "train"
    )
    self.feats_train = self.feature_extractor.extract()
    self.feature_extractor = FeatureExtractor(
        df_test, feats_types, feats_name, "test"
    )
    self.feats_test = self.feature_extractor.extract()
    self.util.debug(
        f"All features: train shape : {self.feats_train.shape}, test"
        f" shape:{self.feats_test.shape}"
    )
    # if fewer features than labels were extracted (e.g. some files could not
    # be processed), restrict the label frames to the samples that have features
    if self.feats_train.shape[0] < self.df_train.shape[0]:
        self.util.warn(
            f"train feats ({self.feats_train.shape[0]}) != train labels"
            f" ({self.df_train.shape[0]})"
        )
        self.df_train = self.df_train[
            self.df_train.index.isin(self.feats_train.index)
        ]
        self.util.warn(f"new train labels shape: {self.df_train.shape[0]}")
    if self.feats_test.shape[0] < self.df_test.shape[0]:
        self.util.warn(
            f"test feats ({self.feats_test.shape[0]}) != test labels"
            f" ({self.df_test.shape[0]})"
        )
        self.df_test = self.df_test[self.df_test.index.isin(self.feats_test.index)]
        # BUG FIX: message previously read "mew test labels shape"
        self.util.warn(f"new test labels shape: {self.df_test.shape[0]}")

    self._check_scale()
|
372
|
+
|
373
|
+
def augment(self):
    """Augment the selected samples and return the augmented dataframe."""
    # imported lazily so the augmentation dependencies only load when used
    from nkululeko.augmenting.augmenter import Augmenter

    selection = self.util.config_val("AUGMENT", "sample_selection", "all")
    if selection == "train":
        df = self.df_train
    elif selection == "test":
        df = self.df_test
    elif selection == "all":
        df = pd.concat([self.df_train, self.df_test])
    else:
        self.util.error(
            f"unknown augmentation selection specifier {selection},"
            " should be [all | train | test]"
        )
    return Augmenter(df).augment(selection)
|
395
|
+
|
396
|
+
def autopredict(self):
    """Predict labels for samples with existing models and add them to the dataframe."""
    selection = self.util.config_val("PREDICT", "split", "all")
    if selection == "train":
        df = self.df_train
    elif selection == "test":
        df = self.df_test
    elif selection == "all":
        df = pd.concat([self.df_train, self.df_test])
    else:
        self.util.error(
            f"unknown augmentation selection specifier {selection},"
            " should be [all | train | test]"
        )
    targets = self.util.config_val_list("PREDICT", "targets", ["gender"])
    for target in targets:
        # predictor modules are imported lazily so that only the requested
        # models (and their dependencies) get loaded
        predictor_cls = None
        if target == "gender":
            from nkululeko.autopredict.ap_gender import GenderPredictor

            predictor_cls = GenderPredictor
        elif target == "age":
            from nkululeko.autopredict.ap_age import AgePredictor

            predictor_cls = AgePredictor
        elif target == "snr":
            from nkululeko.autopredict.ap_snr import SNRPredictor

            predictor_cls = SNRPredictor
        elif target == "mos":
            from nkululeko.autopredict.ap_mos import MOSPredictor

            predictor_cls = MOSPredictor
        elif target == "pesq":
            from nkululeko.autopredict.ap_pesq import PESQPredictor

            predictor_cls = PESQPredictor
        elif target == "sdr":
            from nkululeko.autopredict.ap_sdr import SDRPredictor

            predictor_cls = SDRPredictor
        elif target == "stoi":
            from nkululeko.autopredict.ap_stoi import STOIPredictor

            predictor_cls = STOIPredictor
        elif target == "arousal":
            from nkululeko.autopredict.ap_arousal import ArousalPredictor

            predictor_cls = ArousalPredictor
        elif target == "valence":
            from nkululeko.autopredict.ap_valence import ValencePredictor

            predictor_cls = ValencePredictor
        elif target == "dominance":
            from nkululeko.autopredict.ap_dominance import DominancePredictor

            predictor_cls = DominancePredictor
        else:
            self.util.error(f"unknown auto predict target: {target}")
        if predictor_cls is not None:
            df = predictor_cls(df).predict(selection)
    return df
|
467
|
+
|
468
|
+
def random_splice(self):
    """Random-splice the selected samples and return the resulting dataframe."""
    # imported lazily so the splicing dependencies only load when used
    from nkululeko.augmenting.randomsplicer import Randomsplicer

    selection = self.util.config_val("AUGMENT", "sample_selection", "all")
    if selection == "train":
        df = self.df_train
    elif selection == "test":
        df = self.df_test
    elif selection == "all":
        df = pd.concat([self.df_train, self.df_test])
    else:
        self.util.error(
            f"unknown augmentation selection specifier {selection},"
            " should be [all | train | test]"
        )
    return Randomsplicer(df).run(selection)
|
489
|
+
|
490
|
+
def analyse_features(self, needs_feats):
    """Do a feature exploration."""
    plot_feats = eval(
        self.util.config_val("EXPL", "feature_distributions", "False")
    )
    sample_selection = self.util.config_val("EXPL", "sample_selection", "all")
    # pick the data labels according to the sample selection
    if sample_selection == "train":
        df_labels = self.df_train
        self.util.copy_flags(self.df_train, df_labels)
    elif sample_selection == "test":
        df_labels = self.df_test
        self.util.copy_flags(self.df_test, df_labels)
    elif sample_selection == "all":
        df_labels = pd.concat([self.df_train, self.df_test])
        self.util.copy_flags(self.df_train, df_labels)
    else:
        self.util.error(
            f"unknown sample selection specifier {sample_selection}, should"
            " be [all | train | test]"
        )
    self.util.debug(f"sampling selection: {sample_selection}")
    if self.util.config_val("EXPL", "value_counts", False):
        self.plot_distribution(df_labels)

    # check if data should be shown with the spotlight data visualizer
    if eval(self.util.config_val("EXPL", "spotlight", "False")):
        self.util.debug("opening spotlight tab in web browser")
        from renumics import spotlight

        spotlight.show(df_labels.reset_index())

    if not needs_feats:
        return
    # pick the feature values accordingly
    if sample_selection == "train":
        df_feats = self.feats_train
    elif sample_selection == "test":
        df_feats = self.feats_test
    elif sample_selection == "all":
        df_feats = pd.concat([self.feats_train, self.feats_test])
    else:
        self.util.error(
            f"unknown sample selection specifier {sample_selection}, should"
            " be [all | train | test]"
        )
    feat_analyser = FeatureAnalyser(sample_selection, df_labels, df_feats)
    # check if SHAP features should be analysed
    if eval(self.util.config_val("EXPL", "shap", "False")):
        feat_analyser.analyse_shap(self.runmgr.get_best_model())

    if plot_feats:
        feat_analyser.analyse()

    # check if a scatterplot should be done
    scatter_var = eval(self.util.config_val("EXPL", "scatter", "False"))
    scatter_target = self.util.config_val(
        "EXPL", "scatter.target", "['class_label']"
    )
    if scatter_var:
        scatters = ast.literal_eval(glob_conf.config["EXPL"]["scatter"])
        scat_targets = ast.literal_eval(scatter_target)
        plots = Plots()
        for scat_target in scat_targets:
            if self.util.is_categorical(df_labels[scat_target]):
                for scatter in scatters:
                    plots.scatter_plot(df_feats, df_labels, scat_target, scatter)
            else:
                self.util.debug(
                    f"{self.name}: binning continuous variable to categories"
                )
                binned = self.util.continuous_to_categorical(
                    df_labels[scat_target]
                )
                df_labels[f"{scat_target}_bins"] = binned.values
                for scatter in scatters:
                    plots.scatter_plot(
                        df_feats, df_labels, f"{scat_target}_bins", scatter
                    )
|
571
|
+
|
572
|
+
def _check_scale(self):
    """Scale the features if FEATS.scale is configured, and store the scaled versions."""
    scale_feats = self.util.config_val("FEATS", "scale", False)
    self.util.debug(f"scaler: {scale_feats}")
    if not scale_feats:
        return
    self.scaler_feats = Scaler(
        self.df_train,
        self.df_test,
        self.feats_train,
        self.feats_test,
        scale_feats,
    )
    self.feats_train, self.feats_test = self.scaler_feats.scale()
    # store versions
    self.util.save_to_store(self.feats_train, "feats_train_scaled")
    self.util.save_to_store(self.feats_test, "feats_test_scaled")

def init_runmanager(self):
    """Initialize the manager object for the runs."""
    self.runmgr = Runmanager(
        self.df_train, self.df_test, self.feats_train, self.feats_test
    )
|
594
|
+
|
595
|
+
def run(self):
    """Do the runs."""
    self.runmgr.do_runs()

    # access the best results of all runs
    self.reports = self.runmgr.best_results
    last_epochs = self.runmgr.last_epochs
    # optionally persist the whole experiment for later reuse
    if self.util.config_val("EXP", "save", False):
        self.save(self.util.get_save_name())

    self.util.print_best_results(self.reports)

    # check if the test predictions should be saved to disk
    test_pred_file = self.util.config_val("EXP", "save_test", False)
    if test_pred_file:
        self.predict_test_and_save(test_pred_file)

    # check if the majority voting for all speakers should be plotted
    combine_fn = self.util.config_val("PLOT", "combine_per_speaker", False)
    if combine_fn:
        self.plot_confmat_per_speaker(combine_fn)
    elapsed = time.process_time() - self.start
    self.util.debug(f"Done, used {elapsed:.3f} seconds")

    # check if a test set should be labeled by the model:
    label_data = self.util.config_val("DATA", "label_data", False)
    label_result = self.util.config_val("DATA", "label_result", False)
    if label_data and label_result:
        self.predict_test_and_save(label_result)

    return self.reports, last_epochs
|
633
|
+
|
634
|
+
def plot_confmat_per_speaker(self, function):
    """Plot a confusion matrix with the predictions combined per speaker."""
    if self.loso or self.logo or self.xfoldx:
        self.util.debug(
            "plot combined speaker predictions not possible for cross" " validation"
        )
        return
    best = self.get_best_report(self.reports)
    truths, preds = best.truths, best.preds
    speakers = self.df_test.speaker.values
    print(f"{len(truths)} {len(preds)} {len(speakers) }")
    df = pd.DataFrame(data={"truth": truths, "pred": preds, "speaker": speakers})
    plot_name = "result_combined_per_speaker"
    self.util.debug(
        f"plotting speaker combination ({function}) confusion matrix to"
        f" {plot_name}"
    )
    best.plot_per_speaker(df, plot_name, function)

def get_best_report(self, reports):
    """Return the best result among the given reports."""
    return self.runmgr.get_best_result(reports)

def print_best_model(self):
    """Print the runs that produced the best result."""
    self.runmgr.print_best_result_runs()
|
660
|
+
|
661
|
+
def demo(self, file, is_list, outfile):
    """Run the demo predictor on a file (or a list of files)."""
    model = self.runmgr.get_best_model()
    label_encoder = None
    try:
        label_encoder = self.label_encoder
    except AttributeError:
        # regression experiments have no label encoder
        pass
    Demo_predictor(
        model, file, is_list, self.feature_extractor, label_encoder, outfile
    ).run_demo()

def predict_test_and_save(self, result_name):
    """Predict the test set with the best model and store the result."""
    model = self.runmgr.get_best_model()
    model.set_testdata(self.df_test, self.feats_test)
    predictor = TestPredictor(model, self.df_test, self.label_encoder, result_name)
    return predictor.predict_and_store()
|
681
|
+
|
682
|
+
def load(self, filename):
    """Load a previously stored experiment from a pickle file.

    All unpickled attributes are merged into this instance's __dict__.
    """
    try:
        # context manager so the file is closed even if unpickling raises
        with open(filename, "rb") as f:
            tmp_dict = pickle.load(f)
    except EOFError as eof:
        self.util.error(f"can't open file (unknown): {eof}")
    self.__dict__.update(tmp_dict)
    glob_conf.set_labels(self.labels)

def save(self, filename):
    """Pickle the experiment state to disk.

    Trained torch models and some feature extraction models cannot be
    pickled; they are dropped (with a warning) before saving.
    """
    if self.runmgr.modelrunner.model.is_ann():
        self.runmgr.modelrunner.model = None
        self.util.warn(
            "Save experiment: Can't pickle the trained model so saving without it. (it should be stored anyway)"
        )
    try:
        # context managers replace the bare open/close pairs, which leaked
        # the file handle when pickle.dump raised
        with open(filename, "wb") as f:
            pickle.dump(self.__dict__, f)
    except (TypeError, AttributeError) as error:
        # the feature extraction model is not picklable: drop it and retry
        self.feature_extractor.feat_extractor.model = None
        with open(filename, "wb") as f:
            pickle.dump(self.__dict__, f)
        self.util.warn(
            "Save experiment: Can't pickle the feature extraction model so saving without it."
            + f"{type(error).__name__} {error}"
        )
    except RuntimeError as error:
        self.util.warn(
            "Save experiment: Can't pickle local object, NOT saving: "
            + f"{type(error).__name__} {error}"
        )

def save_onnx(self, filename):
    """Export the best model to ONNX (stub) and pickle the rest of the state."""
    # export the model to onnx
    model = self.runmgr.get_best_model()
    if model.is_ann():
        print("converting to onnx from torch")
    else:
        print("converting to onnx from sklearn")
    # save the rest
    with open(filename, "wb") as f:
        pickle.dump(self.__dict__, f)
|