nkululeko 0.89.2__py3-none-any.whl → 0.90.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99) hide show
  1. nkululeko/aug_train.py +6 -4
  2. nkululeko/augment.py +8 -6
  3. nkululeko/augmenting/augmenter.py +4 -4
  4. nkululeko/augmenting/randomsplicer.py +12 -9
  5. nkululeko/augmenting/randomsplicing.py +2 -3
  6. nkululeko/augmenting/resampler.py +9 -6
  7. nkululeko/autopredict/ap_age.py +4 -2
  8. nkululeko/autopredict/ap_arousal.py +4 -2
  9. nkululeko/autopredict/ap_dominance.py +3 -2
  10. nkululeko/autopredict/ap_gender.py +4 -2
  11. nkululeko/autopredict/ap_mos.py +5 -2
  12. nkululeko/autopredict/ap_pesq.py +5 -2
  13. nkululeko/autopredict/ap_sdr.py +5 -2
  14. nkululeko/autopredict/ap_snr.py +5 -2
  15. nkululeko/autopredict/ap_stoi.py +5 -2
  16. nkululeko/autopredict/ap_valence.py +4 -2
  17. nkululeko/autopredict/estimate_snr.py +10 -14
  18. nkululeko/cacheddataset.py +1 -1
  19. nkululeko/constants.py +1 -1
  20. nkululeko/data/dataset.py +19 -16
  21. nkululeko/data/dataset_csv.py +5 -3
  22. nkululeko/demo-ft.py +29 -0
  23. nkululeko/demo_feats.py +5 -4
  24. nkululeko/demo_predictor.py +3 -4
  25. nkululeko/ensemble.py +27 -28
  26. nkululeko/experiment.py +11 -7
  27. nkululeko/experiment_felix.py +728 -0
  28. nkululeko/explore.py +1 -0
  29. nkululeko/export.py +7 -5
  30. nkululeko/feat_extract/feats_agender.py +5 -4
  31. nkululeko/feat_extract/feats_agender_agender.py +7 -6
  32. nkululeko/feat_extract/feats_analyser.py +18 -16
  33. nkululeko/feat_extract/feats_ast.py +9 -8
  34. nkululeko/feat_extract/feats_auddim.py +3 -5
  35. nkululeko/feat_extract/feats_audmodel.py +2 -2
  36. nkululeko/feat_extract/feats_clap.py +9 -12
  37. nkululeko/feat_extract/feats_hubert.py +2 -3
  38. nkululeko/feat_extract/feats_import.py +5 -4
  39. nkululeko/feat_extract/feats_mld.py +3 -5
  40. nkululeko/feat_extract/feats_mos.py +4 -3
  41. nkululeko/feat_extract/feats_opensmile.py +4 -3
  42. nkululeko/feat_extract/feats_oxbow.py +5 -4
  43. nkululeko/feat_extract/feats_praat.py +4 -7
  44. nkululeko/feat_extract/feats_snr.py +3 -5
  45. nkululeko/feat_extract/feats_spectra.py +8 -9
  46. nkululeko/feat_extract/feats_spkrec.py +6 -11
  47. nkululeko/feat_extract/feats_squim.py +2 -4
  48. nkululeko/feat_extract/feats_trill.py +2 -5
  49. nkululeko/feat_extract/feats_wav2vec2.py +8 -4
  50. nkululeko/feat_extract/feats_wavlm.py +2 -3
  51. nkululeko/feat_extract/feats_whisper.py +4 -6
  52. nkululeko/feat_extract/featureset.py +4 -2
  53. nkululeko/feat_extract/feinberg_praat.py +1 -3
  54. nkululeko/feat_extract/transformer_feature_extractor.py +147 -0
  55. nkululeko/file_checker.py +3 -3
  56. nkululeko/filter_data.py +3 -1
  57. nkululeko/fixedsegment.py +83 -0
  58. nkululeko/models/model.py +3 -5
  59. nkululeko/models/model_bayes.py +1 -0
  60. nkululeko/models/model_cnn.py +4 -6
  61. nkululeko/models/model_gmm.py +13 -9
  62. nkululeko/models/model_knn.py +1 -0
  63. nkululeko/models/model_knn_reg.py +1 -0
  64. nkululeko/models/model_lin_reg.py +1 -0
  65. nkululeko/models/model_mlp.py +2 -3
  66. nkululeko/models/model_mlp_regression.py +1 -6
  67. nkululeko/models/model_svm.py +2 -2
  68. nkululeko/models/model_svr.py +1 -0
  69. nkululeko/models/model_tree.py +2 -3
  70. nkululeko/models/model_tree_reg.py +1 -0
  71. nkululeko/models/model_tuned.py +54 -33
  72. nkululeko/models/model_xgb.py +1 -0
  73. nkululeko/models/model_xgr.py +1 -0
  74. nkululeko/multidb.py +1 -0
  75. nkululeko/nkululeko.py +1 -1
  76. nkululeko/plots.py +1 -1
  77. nkululeko/predict.py +4 -5
  78. nkululeko/reporting/defines.py +6 -8
  79. nkululeko/reporting/latex_writer.py +3 -3
  80. nkululeko/reporting/report.py +2 -2
  81. nkululeko/reporting/report_item.py +1 -0
  82. nkululeko/reporting/reporter.py +20 -19
  83. nkululeko/resample.py +8 -12
  84. nkululeko/resample_cli.py +99 -0
  85. nkululeko/runmanager.py +3 -1
  86. nkululeko/scaler.py +1 -1
  87. nkululeko/segment.py +6 -5
  88. nkululeko/segmenting/seg_inaspeechsegmenter.py +3 -3
  89. nkululeko/segmenting/seg_silero.py +4 -4
  90. nkululeko/syllable_nuclei.py +9 -22
  91. nkululeko/test_pretrain.py +6 -7
  92. nkululeko/utils/stats.py +0 -1
  93. nkululeko/utils/util.py +2 -3
  94. {nkululeko-0.89.2.dist-info → nkululeko-0.90.1.dist-info}/METADATA +12 -2
  95. nkululeko-0.90.1.dist-info/RECORD +119 -0
  96. {nkululeko-0.89.2.dist-info → nkululeko-0.90.1.dist-info}/WHEEL +1 -1
  97. nkululeko-0.89.2.dist-info/RECORD +0 -114
  98. {nkululeko-0.89.2.dist-info → nkululeko-0.90.1.dist-info}/LICENSE +0 -0
  99. {nkululeko-0.89.2.dist-info → nkululeko-0.90.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,728 @@
1
+ # experiment.py: Main class for an experiment (nkululeko.nkululeko)
2
+ import ast
3
+ import os
4
+ import pickle
5
+ import random
6
+ import time
7
+
8
+ import audeer
9
+ import audformat
10
+ import numpy as np
11
+ import pandas as pd
12
+ from sklearn.preprocessing import LabelEncoder
13
+
14
+ import nkululeko.glob_conf as glob_conf
15
+ from nkululeko.data.dataset import Dataset
16
+ from nkululeko.data.dataset_csv import Dataset_CSV
17
+ from nkululeko.demo_predictor import Demo_predictor
18
+ from nkululeko.feat_extract.feats_analyser import FeatureAnalyser
19
+ from nkululeko.feature_extractor import FeatureExtractor
20
+ from nkululeko.file_checker import FileChecker
21
+ from nkululeko.filter_data import DataFilter
22
+ from nkululeko.plots import Plots
23
+ from nkululeko.reporting.report import Report
24
+ from nkululeko.runmanager import Runmanager
25
+ from nkululeko.scaler import Scaler
26
+ from nkululeko.test_predictor import TestPredictor
27
+ from nkululeko.utils.util import Util
28
+
29
+
30
+ class Experiment:
31
+ """Main class specifying an experiment"""
32
+
33
def __init__(self, config_obj):
    """Set up an experiment from a configuration object.

    Installs the configuration globally, creates the experiment directory,
    registers a ``Util`` helper and a report with ``glob_conf``, reads the
    cross-validation switches and starts the process timer.

    Parameters
    ----------
    config_obj : a config parser object that sets the experiment parameters
        and is installed as a global object.
    """
    self.set_globals(config_obj)
    exp_section = glob_conf.config["EXP"]
    self.name = exp_section["name"]
    # joining with "" guarantees a trailing path separator on the root
    self.root = os.path.join(exp_section["root"], "")
    self.data_dir = os.path.join(self.root, self.name)
    # create the experiment directory
    audeer.mkdir(self.data_dir)
    self.util = Util("experiment")
    glob_conf.set_util(self.util)

    report_file = os.path.join(self.data_dir, "report.pkl")
    # NOTE(review): eval() runs the configured string as Python code and
    # pickle.load() can execute code embedded in the file; both are only
    # safe with a trusted config and a trusted experiment directory.
    if eval(self.util.config_val("REPORT", "fresh", "False")):
        self.util.debug("starting a fresh report")
        self.report = Report()
    else:
        try:
            with open(report_file, "rb") as handle:
                self.report = pickle.load(handle)
        except FileNotFoundError:
            # no report stored yet: start with an empty one
            self.report = Report()
    glob_conf.set_report(self.report)

    # cross-validation switches, False when not configured
    self.loso = self.util.config_val("MODEL", "loso", False)
    self.logo = self.util.config_val("MODEL", "logo", False)
    self.xfoldx = self.util.config_val("MODEL", "k_fold_cross", False)
    self.start = time.process_time()
62
+
63
def set_module(self, module):
    """Pass ``module`` on to ``glob_conf.set_module``."""
    glob_conf.set_module(module)
65
+
66
def store_report(self):
    """Pickle the report to ``report.pkl`` in the experiment directory.

    Afterwards the report is printed when REPORT.show evaluates to true and
    exported to LaTeX when REPORT.latex is set.
    """
    report_path = os.path.join(self.data_dir, "report.pkl")
    with open(report_path, "wb") as handle:
        pickle.dump(self.report, handle)

    # NOTE(review): eval() runs the configured string as Python code
    show_report = eval(self.util.config_val("REPORT", "show", "False"))
    if show_report:
        self.report.print()

    wants_latex = self.util.config_val("REPORT", "latex", False)
    if wants_latex:
        self.report.export_latex()
73
+
74
def get_name(self):
    """Return the experiment name as given by ``self.util.get_exp_name()``."""
    exp_name = self.util.get_exp_name()
    return exp_name
76
+
77
def set_globals(self, config_obj):
    """Install ``config_obj`` in the global space via ``glob_conf.init_config``."""
    glob_conf.init_config(config_obj)
80
+
81
def load_datasets(self):
    """Load all databases specified in the configuration and map the labels.

    Every name in DATA.databases is loaded as an audformat or CSV dataset
    (per its ``type`` setting), prepared and stored in ``self.datasets``.
    The flags ``got_speaker``, ``got_gender`` and ``got_age`` become True as
    soon as one dataset reports them. The target name and the target labels
    are registered with ``glob_conf``; labels come from DATA.labels when
    configured, otherwise from the unique target values of the first
    dataset.
    """
    ds = ast.literal_eval(glob_conf.config["DATA"]["databases"])
    self.datasets = {}
    self.got_speaker, self.got_gender, self.got_age = False, False, False
    for d in ds:
        ds_type = self.util.config_val_data(d, "type", "audformat")
        if ds_type == "audformat":
            data = Dataset(d)
        elif ds_type == "csv":
            data = Dataset_CSV(d)
        else:
            self.util.error(f"unknown data type: {ds_type}")
        data.load()
        data.prepare()
        if data.got_gender:
            self.got_gender = True
        if data.got_age:
            self.got_age = True
        if data.got_speaker:
            self.got_speaker = True
        self.datasets.update({d: data})
    self.target = self.util.config_val("DATA", "target", "emotion")
    glob_conf.set_target(self.target)
    self.util.debug(f"target: {self.target}")
    dbs = ",".join(list(self.datasets.keys()))
    labels = self.util.config_val("DATA", "labels", False)
    if labels:
        self.labels = ast.literal_eval(labels)
        self.util.debug(f"Target labels (from config): {labels}")
    else:
        # no labels configured: take them from the first loaded dataset
        self.labels = list(
            next(iter(self.datasets.values())).df[self.target].unique()
        )
        # report the labels actually found; `labels` is falsy in this branch
        self.util.debug(f"Target labels (from database): {self.labels}")
    glob_conf.set_labels(self.labels)
    self.util.debug(f"loaded databases {dbs}")
120
+
121
def _import_csv(self, storage):
    """Read a stored dataframe with ``audformat.utils.read_csv``.

    The returned frame gets an ``is_labeled`` attribute telling whether the
    target column is present.
    """
    frame = audformat.utils.read_csv(storage)
    frame.is_labeled = self.target in frame
    return frame
129
+
130
def fill_tests(self):
    """Only fill a new test set.

    A previously stored ``extra_testdf.csv`` is reused unless DATA.no_reuse
    evaluates to true. Otherwise every database named in DATA.tests is
    loaded, split and label-prepared, the test splits are concatenated, the
    target is copied to ``class_label`` and encoded with
    ``self.label_encoder``, and the result is written to the store.
    """
    test_db_names = ast.literal_eval(glob_conf.config["DATA"]["tests"])
    self.df_test = pd.DataFrame()
    # NOTE(review): eval() runs the configured string as Python code
    start_fresh = eval(self.util.config_val("DATA", "no_reuse", "False"))
    store = self.util.get_path("store")
    storage_test = f"{store}extra_testdf.csv"

    if os.path.isfile(storage_test) and not start_fresh:
        self.util.debug(f"reusing previously stored {storage_test}")
        self.df_test = self._import_csv(storage_test)
        return

    for db_name in test_db_names:
        ds_type = self.util.config_val_data(db_name, "type", "audformat")
        if ds_type == "audformat":
            data = Dataset(db_name)
        elif ds_type == "csv":
            data = Dataset_CSV(db_name)
        else:
            self.util.error(f"unknown data type: {ds_type}")
        data.load()
        # remember which meta information is available in any database
        if data.got_gender:
            self.got_gender = True
        if data.got_age:
            self.got_age = True
        if data.got_speaker:
            self.got_speaker = True
        data.split()
        data.prepare_labels()
        segmented = self.util.make_segmented_index(data.df_test)
        self.df_test = pd.concat([self.df_test, segmented])
        self.df_test.is_labeled = data.is_labeled

    self.df_test.got_gender = self.got_gender
    self.df_test.got_speaker = self.got_speaker
    # keep the original labels before the target gets encoded
    self.df_test["class_label"] = self.df_test[self.target]
    encoded_target = self.label_encoder.transform(self.df_test[self.target])
    self.df_test[self.target] = encoded_target
    self.df_test.to_csv(storage_test)
172
+
173
def fill_train_and_tests(self):
    """Set up train and development sets. The method should be specified in the config.

    Steps, in order:

    1. Reuse ``traindf.csv``/``testdf.csv`` from the store when both exist
       and DATA.no_reuse is not set; otherwise split every loaded dataset,
       concatenate the splits and write them to the store.
    2. Copy the experiment flags onto both frames and run the file checks.
    3. Apply the data filters to the splits selected by
       DATA.filter.sample_selection (``all``, ``train`` or ``test``).
    4. For classification experiments, encode the target with a
       ``LabelEncoder`` fitted on the train split and register it globally.
    5. Divide the target (and, for non-classification experiments,
       ``class_label``) by DATA.target_divide_by when it is configured.
    """
    store = self.util.get_path("store")
    storage_test = f"{store}testdf.csv"
    storage_train = f"{store}traindf.csv"
    # NOTE(review): eval() runs the configured string as Python code
    start_fresh = eval(self.util.config_val("DATA", "no_reuse", "False"))
    if (
        os.path.isfile(storage_train)
        and os.path.isfile(storage_test)
        and not start_fresh
    ):
        self.util.debug(
            f"reusing previously stored {storage_test} and {storage_train}"
        )
        self.df_test = self._import_csv(storage_test)
        self.df_train = self._import_csv(storage_train)
    else:
        self.df_train, self.df_test = pd.DataFrame(), pd.DataFrame()
        for d in self.datasets.values():
            d.split()
            d.prepare_labels()
            if d.df_train.shape[0] == 0:
                self.util.debug(f"warn: {d.name} train empty")
            self.df_train = pd.concat([self.df_train, d.df_train])
            self.util.copy_flags(d, self.df_train)
            if d.df_test.shape[0] == 0:
                self.util.debug(f"warn: {d.name} test empty")
            self.df_test = pd.concat([self.df_test, d.df_test])
            self.util.copy_flags(d, self.df_test)
        # the storage paths computed at the top are still valid here
        self.df_test.to_csv(storage_test)
        self.df_train.to_csv(storage_train)

    self.util.copy_flags(self, self.df_test)
    self.util.copy_flags(self, self.df_train)
    # Try data checks
    datachecker = FileChecker(self.df_train)
    self.df_train = datachecker.all_checks()
    datachecker.set_data(self.df_test)
    self.df_test = datachecker.all_checks()

    # Check for filters
    filter_sample_selection = self.util.config_val(
        "DATA", "filter.sample_selection", "all"
    )
    if filter_sample_selection == "all":
        datafilter = DataFilter(self.df_train)
        self.df_train = datafilter.all_filters()
        datafilter = DataFilter(self.df_test)
        self.df_test = datafilter.all_filters()
    elif filter_sample_selection == "train":
        datafilter = DataFilter(self.df_train)
        self.df_train = datafilter.all_filters()
    elif filter_sample_selection == "test":
        datafilter = DataFilter(self.df_test)
        self.df_test = datafilter.all_filters()
    else:
        self.util.error(
            "unknown filter sample selection specifier"
            f" {filter_sample_selection}, should be [all | train | test]"
        )

    # encode the labels
    if self.util.exp_is_classification():
        datatype = self.util.config_val("DATA", "type", "dummy")
        if datatype == "continuous":
            test_cats = self.df_test["class_label"].unique()
            train_cats = self.df_train["class_label"].unique()
        else:
            if self.df_test.is_labeled:
                test_cats = self.df_test[self.target].unique()
            else:
                # if there is no target, copy a dummy label
                self.df_test = self._add_random_target(self.df_test).astype("str")
            train_cats = self.df_train[self.target].unique()
        # test_cats is only bound when the test split is labeled
        if self.df_test.is_labeled:
            if isinstance(test_cats, np.ndarray):
                self.util.debug(f"Categories test (nd.array): {test_cats}")
            else:
                self.util.debug(f"Categories test (list): {list(test_cats)}")
        if isinstance(train_cats, np.ndarray):
            self.util.debug(f"Categories train (nd.array): {train_cats}")
        else:
            self.util.debug(f"Categories train (list): {list(train_cats)}")

        # encode the labels as numbers
        self.label_encoder = LabelEncoder()
        self.df_train[self.target] = self.label_encoder.fit_transform(
            self.df_train[self.target]
        )
        self.df_test[self.target] = self.label_encoder.transform(
            self.df_test[self.target]
        )
        glob_conf.set_label_encoder(self.label_encoder)
    if self.got_speaker:
        self.util.debug(
            f"{self.df_test.speaker.nunique()} speakers in test and"
            f" {self.df_train.speaker.nunique()} speakers in train"
        )

    target_factor = self.util.config_val("DATA", "target_divide_by", False)
    if target_factor:
        divisor = float(target_factor)
        self.df_test[self.target] = self.df_test[self.target] / divisor
        self.df_train[self.target] = self.df_train[self.target] / divisor
        if not self.util.exp_is_classification():
            self.df_test["class_label"] = self.df_test["class_label"] / divisor
            self.df_train["class_label"] = self.df_train["class_label"] / divisor
303
+
304
def _add_random_target(self, df):
    """Fill the target column of ``df`` with random picks from ``glob_conf.labels``.

    The frame is modified in place and also returned.
    """
    candidates = glob_conf.labels
    df[self.target] = [random.choice(candidates) for _ in range(len(df))]
    return df
311
+
312
def plot_distribution(self, df_labels):
    """Plot the distribution of samples and speaker per target class and biological sex.

    The speaker plot is only drawn when speaker information is available
    (``self.got_speaker``).
    """
    plot = Plots()
    plot.plot_distributions(df_labels)
    if self.got_speaker:
        plot.plot_distributions_speaker(df_labels)
319
+
320
def extract_test_feats(self):
    """Extract features for ``self.df_test`` and store them in ``self.feats_test``.

    The feature set is named after the databases listed in DATA.tests; the
    feature types come from FEATS.type and default to ``["os"]``.
    """
    self.feats_test = pd.DataFrame()
    test_db_names = ast.literal_eval(glob_conf.config["DATA"]["tests"])
    feats_name = "_".join(test_db_names)
    feats_types = self.util.config_val_list("FEATS", "type", ["os"])
    extractor = FeatureExtractor(self.df_test, feats_types, feats_name, "test")
    self.feature_extractor = extractor
    self.feats_test = extractor.extract()
    self.util.debug(f"Test features shape:{self.feats_test.shape}")
329
+
330
def extract_feats(self):
    """Extract the features for train and dev sets.

    They will be stored on disk and need to be removed manually.

    The list FEATS.type is read from the config, defaults to ["os"].

    When fewer feature rows than label rows come back for a split, the
    labels are reduced to the samples that have features. Finally the
    features are scaled if configured (see ``_check_scale``).
    """
    df_train, df_test = self.df_train, self.df_test
    feats_name = "_".join(ast.literal_eval(glob_conf.config["DATA"]["databases"]))
    self.feats_test, self.feats_train = pd.DataFrame(), pd.DataFrame()
    feats_types = self.util.config_val_list("FEATS", "type", ["os"])
    self.feature_extractor = FeatureExtractor(
        df_train, feats_types, feats_name, "train"
    )
    self.feats_train = self.feature_extractor.extract()
    # the extractor for the test split stays in self.feature_extractor
    self.feature_extractor = FeatureExtractor(
        df_test, feats_types, feats_name, "test"
    )
    self.feats_test = self.feature_extractor.extract()
    self.util.debug(
        f"All features: train shape : {self.feats_train.shape}, test"
        f" shape:{self.feats_test.shape}"
    )
    if self.feats_train.shape[0] < self.df_train.shape[0]:
        self.util.warn(
            f"train feats ({self.feats_train.shape[0]}) != train labels"
            f" ({self.df_train.shape[0]})"
        )
        # keep only the labels for which features exist
        self.df_train = self.df_train[
            self.df_train.index.isin(self.feats_train.index)
        ]
        self.util.warn(f"new train labels shape: {self.df_train.shape[0]}")
    if self.feats_test.shape[0] < self.df_test.shape[0]:
        self.util.warn(
            f"test feats ({self.feats_test.shape[0]}) != test labels"
            f" ({self.df_test.shape[0]})"
        )
        self.df_test = self.df_test[self.df_test.index.isin(self.feats_test.index)]
        self.util.warn(f"new test labels shape: {self.df_test.shape[0]}")

    self._check_scale()
372
+
373
def augment(self):
    """Augment the selected samples.

    AUGMENT.sample_selection chooses the frame handed to the ``Augmenter``:
    ``all`` (train and test concatenated), ``train`` or ``test``.

    Returns
    -------
    whatever ``Augmenter.augment`` returns for the selection.
    """
    from nkululeko.augmenting.augmenter import Augmenter

    selection = self.util.config_val("AUGMENT", "sample_selection", "all")
    if selection == "train":
        df = self.df_train
    elif selection == "test":
        df = self.df_test
    elif selection == "all":
        df = pd.concat([self.df_train, self.df_test])
    else:
        self.util.error(
            f"unknown augmentation selection specifier {selection},"
            " should be [all | train | test]"
        )

    return Augmenter(df).augment(selection)
395
+
396
def autopredict(self):
    """Predict labels for samples with existing models and add to the dataframe.

    PREDICT.split selects the samples (``all``, ``train`` or ``test``) and
    PREDICT.targets lists what to predict (default ``["gender"]``). Each
    known target is handled by a predictor class from
    ``nkululeko.autopredict`` that is imported only when it is requested;
    an unknown target is reported via ``self.util.error``.

    Returns
    -------
    the dataframe returned by the last predictor (the selected samples when
    no target was processed).
    """
    import importlib

    # target name -> (module inside nkululeko.autopredict, predictor class)
    predictors = {
        "gender": ("ap_gender", "GenderPredictor"),
        "age": ("ap_age", "AgePredictor"),
        "snr": ("ap_snr", "SNRPredictor"),
        "mos": ("ap_mos", "MOSPredictor"),
        "pesq": ("ap_pesq", "PESQPredictor"),
        "sdr": ("ap_sdr", "SDRPredictor"),
        "stoi": ("ap_stoi", "STOIPredictor"),
        "arousal": ("ap_arousal", "ArousalPredictor"),
        "valence": ("ap_valence", "ValencePredictor"),
        "dominance": ("ap_dominance", "DominancePredictor"),
    }

    sample_selection = self.util.config_val("PREDICT", "split", "all")
    if sample_selection == "all":
        df = pd.concat([self.df_train, self.df_test])
    elif sample_selection == "train":
        df = self.df_train
    elif sample_selection == "test":
        df = self.df_test
    else:
        self.util.error(
            f"unknown prediction split specifier {sample_selection},"
            " should be [all | train | test]"
        )
    targets = self.util.config_val_list("PREDICT", "targets", ["gender"])
    for target in targets:
        if target not in predictors:
            self.util.error(f"unknown auto predict target: {target}")
            continue
        module_name, class_name = predictors[target]
        # import lazily so that only the requested predictors get loaded
        module = importlib.import_module(f"nkululeko.autopredict.{module_name}")
        predictor = getattr(module, class_name)(df)
        df = predictor.predict(sample_selection)
    return df
467
+
468
def random_splice(self):
    """Random-splice the selected samples.

    AUGMENT.sample_selection chooses the frame handed to the
    ``Randomsplicer``: ``all`` (train and test concatenated), ``train`` or
    ``test``.

    Returns
    -------
    whatever ``Randomsplicer.run`` returns for the selection.
    """
    from nkululeko.augmenting.randomsplicer import Randomsplicer

    selection = self.util.config_val("AUGMENT", "sample_selection", "all")
    if selection == "train":
        df = self.df_train
    elif selection == "test":
        df = self.df_test
    elif selection == "all":
        df = pd.concat([self.df_train, self.df_test])
    else:
        self.util.error(
            f"unknown augmentation selection specifier {selection},"
            " should be [all | train | test]"
        )
    return Randomsplicer(df).run(selection)
489
+
490
def analyse_features(self, needs_feats):
    """Do a feature exploration.

    The samples are chosen with EXPL.sample_selection (``all``, ``train`` or
    ``test``). Depending on the EXPL settings the label distributions are
    plotted and the labels are shown in the spotlight viewer. When
    ``needs_feats`` is true, the features are additionally analysed (SHAP,
    feature distributions) and scatter plots are drawn.

    Parameters
    ----------
    needs_feats : when false, return after the label-only steps.
    """
    # NOTE(review): every eval() below runs a configured string as Python code
    plot_feats = eval(
        self.util.config_val("EXPL", "feature_distributions", "False")
    )
    selection = self.util.config_val("EXPL", "sample_selection", "all")

    # get the data labels
    if selection == "all":
        df_labels = pd.concat([self.df_train, self.df_test])
        self.util.copy_flags(self.df_train, df_labels)
    elif selection == "train":
        df_labels = self.df_train
        self.util.copy_flags(self.df_train, df_labels)
    elif selection == "test":
        df_labels = self.df_test
        self.util.copy_flags(self.df_test, df_labels)
    else:
        self.util.error(
            f"unknown sample selection specifier {selection}, should"
            " be [all | train | test]"
        )
    self.util.debug(f"sampling selection: {selection}")
    if self.util.config_val("EXPL", "value_counts", False):
        self.plot_distribution(df_labels)

    # check if data should be shown with the spotlight data visualizer
    use_spotlight = eval(self.util.config_val("EXPL", "spotlight", "False"))
    if use_spotlight:
        self.util.debug("opening spotlight tab in web browser")
        from renumics import spotlight as spotlight_viewer

        spotlight_viewer.show(df_labels.reset_index())

    if not needs_feats:
        return

    # get the feature values
    if selection == "all":
        df_feats = pd.concat([self.feats_train, self.feats_test])
    elif selection == "train":
        df_feats = self.feats_train
    elif selection == "test":
        df_feats = self.feats_test
    else:
        self.util.error(
            f"unknown sample selection specifier {selection}, should"
            " be [all | train | test]"
        )
    feat_analyser = FeatureAnalyser(selection, df_labels, df_feats)

    # check if SHAP features should be analysed
    use_shap = eval(self.util.config_val("EXPL", "shap", "False"))
    if use_shap:
        feat_analyser.analyse_shap(self.runmgr.get_best_model())

    if plot_feats:
        feat_analyser.analyse()

    # check if a scatterplot should be done
    wants_scatter = eval(self.util.config_val("EXPL", "scatter", "False"))
    scatter_target = self.util.config_val(
        "EXPL", "scatter.target", "['class_label']"
    )
    if not wants_scatter:
        return

    scatter_methods = ast.literal_eval(glob_conf.config["EXPL"]["scatter"])
    target_columns = ast.literal_eval(scatter_target)
    plots = Plots()
    for column in target_columns:
        if self.util.is_categorical(df_labels[column]):
            plot_column = column
        else:
            self.util.debug(
                f"{self.name}: binning continuous variable to categories"
            )
            binned = self.util.continuous_to_categorical(df_labels[column])
            plot_column = f"{column}_bins"
            df_labels[plot_column] = binned.values
        for method in scatter_methods:
            plots.scatter_plot(df_feats, df_labels, plot_column, method)
571
+
572
def _check_scale(self):
    """Scale train and test features when FEATS.scale is configured.

    The scaled features replace ``self.feats_train``/``self.feats_test`` and
    are saved to the store.
    """
    scale_feats = self.util.config_val("FEATS", "scale", False)
    self.util.debug(f"scaler: {scale_feats}")
    if not scale_feats:
        return

    scaler = Scaler(
        self.df_train,
        self.df_test,
        self.feats_train,
        self.feats_test,
        scale_feats,
    )
    self.scaler_feats = scaler
    self.feats_train, self.feats_test = scaler.scale()
    # store versions
    for frame, store_name in (
        (self.feats_train, "feats_train_scaled"),
        (self.feats_test, "feats_test_scaled"),
    ):
        self.util.save_to_store(frame, store_name)
588
+
589
def init_runmanager(self):
    """Initialize the manager object for the runs.

    The ``Runmanager`` receives the train/test labels and features and is
    kept in ``self.runmgr``.
    """
    labels_and_feats = (
        self.df_train,
        self.df_test,
        self.feats_train,
        self.feats_test,
    )
    self.runmgr = Runmanager(*labels_and_feats)
594
+
595
def run(self):
    """Do the runs.

    After the run manager has finished, optionally saves the experiment
    (EXP.save), prints the best results, stores the test predictions
    (EXP.save_test), plots the per-speaker confusion matrix
    (PLOT.combine_per_speaker) and labels a test set
    (DATA.label_data together with DATA.label_result).

    Returns
    -------
    tuple of (best results of all runs, last epochs) as provided by the run
    manager.
    """
    self.runmgr.do_runs()

    # access the best results all runs
    self.reports = self.runmgr.best_results
    last_epochs = self.runmgr.last_epochs

    # try to save yourself
    if self.util.config_val("EXP", "save", False):
        # save the experiment for future use
        self.save(self.util.get_save_name())

    self.util.print_best_results(self.reports)

    # check if the test predictions should be saved to disk
    test_pred_file = self.util.config_val("EXP", "save_test", False)
    if test_pred_file:
        self.predict_test_and_save(test_pred_file)

    # check if the majority voting for all speakers should be plotted
    per_speaker_function = self.util.config_val(
        "PLOT", "combine_per_speaker", False
    )
    if per_speaker_function:
        self.plot_confmat_per_speaker(per_speaker_function)

    elapsed = time.process_time() - self.start
    self.util.debug(f"Done, used {elapsed:.3f} seconds")

    # check if a test set should be labeled by the model:
    label_data = self.util.config_val("DATA", "label_data", False)
    label_result = self.util.config_val("DATA", "label_result", False)
    if label_data and label_result:
        self.predict_test_and_save(label_result)

    return self.reports, last_epochs
633
+
634
def plot_confmat_per_speaker(self, function):
    """Plot a confusion matrix with the predictions combined per speaker.

    Nothing is plotted for cross-validation experiments (loso, logo or
    k-fold). Otherwise truths and predictions of the best report are joined
    with the speakers of the test split and handed to the report's
    ``plot_per_speaker``.

    Parameters
    ----------
    function : the combination function, passed through to
        ``plot_per_speaker``.
    """
    if self.loso or self.logo or self.xfoldx:
        self.util.debug(
            "plot combined speaker predictions not possible for cross validation"
        )
        return
    best = self.get_best_report(self.reports)
    truths = best.truths
    preds = best.preds
    speakers = self.df_test.speaker.values
    # the three sequences must have equal length to build the frame below
    self.util.debug(
        f"per speaker combination: {len(truths)} truths, {len(preds)}"
        f" predictions, {len(speakers)} speakers"
    )
    df = pd.DataFrame(data={"truth": truths, "pred": preds, "speaker": speakers})
    plot_name = "result_combined_per_speaker"
    self.util.debug(
        f"plotting speaker combination ({function}) confusion matrix to"
        f" {plot_name}"
    )
    best.plot_per_speaker(df, plot_name, function)
654
+
655
def get_best_report(self, reports):
    """Return the run manager's best result among ``reports``."""
    best_report = self.runmgr.get_best_result(reports)
    return best_report
657
+
658
def print_best_model(self):
    """Delegate to the run manager's ``print_best_result_runs``."""
    run_manager = self.runmgr
    run_manager.print_best_result_runs()
660
+
661
def demo(self, file, is_list, outfile):
    """Run the demo predictor with the best model.

    All three arguments are passed through to ``Demo_predictor`` together
    with the feature extractor and, when one exists, the label encoder.
    """
    best_model = self.runmgr.get_best_model()
    # the label encoder may never have been set on this experiment
    label_encoder = getattr(self, "label_encoder", None)
    predictor = Demo_predictor(
        best_model, file, is_list, self.feature_extractor, label_encoder, outfile
    )
    predictor.run_demo()
672
+
673
def predict_test_and_save(self, result_name):
    """Predict the test split with the best model and store the result.

    Parameters
    ----------
    result_name : passed to ``TestPredictor`` as the result name.

    Returns
    -------
    the value returned by ``TestPredictor.predict_and_store``.
    """
    best_model = self.runmgr.get_best_model()
    best_model.set_testdata(self.df_test, self.feats_test)
    predictor = TestPredictor(
        best_model, self.df_test, self.label_encoder, result_name
    )
    return predictor.predict_and_store()
681
+
682
def load(self, filename):
    """Restore the experiment state from a file written by ``save``.

    The unpickled dictionary is merged into ``self.__dict__`` and the labels
    are registered with ``glob_conf`` again.

    Parameters
    ----------
    filename : path of the pickle file to read.
    """
    try:
        # NOTE(review): pickle.load can execute code embedded in the file;
        # only load experiment files from a trusted source.
        # the context manager closes the file even when unpickling fails
        with open(filename, "rb") as f:
            tmp_dict = pickle.load(f)
    except EOFError as eof:
        self.util.error(f"can't open file {filename}: {eof}")
    self.__dict__.update(tmp_dict)
    glob_conf.set_labels(self.labels)
691
+
692
def save(self, filename):
    """Pickle the experiment state (``self.__dict__``) to ``filename``.

    A trained model that reports ``is_ann()`` is dropped before pickling.
    If pickling fails with a TypeError or AttributeError, the feature
    extraction model is dropped as well and the dump is retried. A
    RuntimeError is reported as a warning and nothing more is written.

    Parameters
    ----------
    filename : path of the pickle file to write.
    """
    if self.runmgr.modelrunner.model.is_ann():
        self.runmgr.modelrunner.model = None
        self.util.warn(
            "Save experiment: Can't pickle the trained model so saving without it. (it should be stored anyway)"
        )
    try:
        # the context manager closes the file even when pickling fails
        with open(filename, "wb") as f:
            pickle.dump(self.__dict__, f)
    except (TypeError, AttributeError) as error:
        self.feature_extractor.feat_extractor.model = None
        # reopening in "wb" mode discards the partial first attempt
        with open(filename, "wb") as f:
            pickle.dump(self.__dict__, f)
        self.util.warn(
            "Save experiment: Can't pickle the feature extraction model so saving without it."
            + f"{type(error).__name__} {error}"
        )
    except RuntimeError as error:
        self.util.warn(
            "Save experiment: Can't pickle local object, NOT saving: "
            + f"{type(error).__name__} {error}"
        )
716
+
717
def save_onnx(self, filename):
    """Announce the ONNX conversion path and pickle the experiment state.

    Only a message about the conversion source (torch for models reporting
    ``is_ann()``, sklearn otherwise) is printed; no ONNX export is done
    here. ``self.__dict__`` is then pickled to ``filename``.

    Parameters
    ----------
    filename : path of the pickle file to write.
    """
    # export the model to onnx
    model = self.runmgr.get_best_model()
    if model.is_ann():
        print("converting to onnx from torch")
    else:
        print("converting to onnx from sklearn")
    # save the rest; the context manager closes the file even on failure
    with open(filename, "wb") as f:
        pickle.dump(self.__dict__, f)