nkululeko 0.89.0__py3-none-any.whl → 0.89.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nkululeko/constants.py +1 -1
- nkululeko/ensemble.py +13 -1
- nkululeko/explore.py +3 -1
- nkululeko/feat_extract/feats_analyser.py +24 -7
- nkululeko/modelrunner.py +2 -2
- nkululeko/reporting/reporter.py +4 -3
- {nkululeko-0.89.0.dist-info → nkululeko-0.89.2.dist-info}/METADATA +9 -1
- {nkululeko-0.89.0.dist-info → nkululeko-0.89.2.dist-info}/RECORD +11 -11
- {nkululeko-0.89.0.dist-info → nkululeko-0.89.2.dist-info}/WHEEL +1 -1
- {nkululeko-0.89.0.dist-info → nkululeko-0.89.2.dist-info}/LICENSE +0 -0
- {nkululeko-0.89.0.dist-info → nkululeko-0.89.2.dist-info}/top_level.txt +0 -0
nkululeko/constants.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
|
-
VERSION="0.89.
|
1
|
+
VERSION="0.89.2"
|
2
2
|
SAMPLING_RATE = 16000
|
nkululeko/ensemble.py
CHANGED
@@ -18,6 +18,7 @@ Raises:
|
|
18
18
|
#!/usr/bin/env python
|
19
19
|
# -*- coding: utf-8 -*-
|
20
20
|
|
21
|
+
|
21
22
|
from typing import List
|
22
23
|
import configparser
|
23
24
|
import time
|
@@ -26,7 +27,16 @@ from pathlib import Path
|
|
26
27
|
|
27
28
|
import numpy as np
|
28
29
|
import pandas as pd
|
29
|
-
|
30
|
+
import matplotlib.pyplot as plt
|
31
|
+
|
32
|
+
from sklearn.metrics import(
|
33
|
+
RocCurveDisplay,
|
34
|
+
balanced_accuracy_score,
|
35
|
+
classification_report,
|
36
|
+
auc,
|
37
|
+
roc_auc_score,
|
38
|
+
roc_curve
|
39
|
+
)
|
30
40
|
|
31
41
|
from nkululeko.constants import VERSION
|
32
42
|
from nkululeko.experiment import Experiment
|
@@ -284,6 +294,8 @@ def ensemble_predictions(
|
|
284
294
|
predicted = ensemble_preds["predicted"]
|
285
295
|
uar = balanced_accuracy_score(truth, predicted)
|
286
296
|
acc = (truth == predicted).mean()
|
297
|
+
# print classification report
|
298
|
+
Util("ensemble").debug(f"\n {classification_report(truth, predicted, digits=4)}")
|
287
299
|
Util("ensemble").debug(f"{method}: UAR: {uar:.3f}, ACC: {acc:.3f}")
|
288
300
|
|
289
301
|
return ensemble_preds
|
nkululeko/explore.py
CHANGED
@@ -91,7 +91,9 @@ def main(src_dir):
|
|
91
91
|
# these investigations need features to explore
|
92
92
|
expr.extract_feats()
|
93
93
|
needs_feats = True
|
94
|
-
|
94
|
+
# explore
|
95
|
+
expr.init_runmanager()
|
96
|
+
expr.runmgr.do_runs()
|
95
97
|
expr.analyse_features(needs_feats)
|
96
98
|
expr.store_report()
|
97
99
|
print("DONE")
|
@@ -50,19 +50,32 @@ class FeatureAnalyser:
|
|
50
50
|
|
51
51
|
name = "my_shap_values"
|
52
52
|
if not self.util.exist_pickle(name):
|
53
|
-
|
53
|
+
# get model name
|
54
|
+
model_name = self.util.get_model_type()
|
55
|
+
if hasattr(model, "predict_shap"):
|
56
|
+
model_func = model.predict_shap
|
57
|
+
elif hasattr(model, "clf"):
|
58
|
+
model_func = model.clf.predict
|
59
|
+
else:
|
60
|
+
raise Exception("Model not supported for SHAP analysis")
|
61
|
+
|
62
|
+
self.util.debug(f"using SHAP explainer for {model_name} model")
|
63
|
+
|
54
64
|
explainer = shap.Explainer(
|
55
|
-
|
65
|
+
model_func,
|
56
66
|
self.features,
|
57
67
|
output_names=glob_conf.labels,
|
58
68
|
algorithm="permutation",
|
59
69
|
npermutations=5,
|
60
70
|
)
|
71
|
+
|
61
72
|
self.util.debug("computing SHAP values...")
|
62
73
|
shap_values = explainer(self.features)
|
63
74
|
self.util.to_pickle(shap_values, name)
|
64
75
|
else:
|
65
76
|
shap_values = self.util.from_pickle(name)
|
77
|
+
# plt.figure()
|
78
|
+
plt.close('all')
|
66
79
|
plt.tight_layout()
|
67
80
|
shap.plots.bar(shap_values)
|
68
81
|
fig_dir = self.util.get_path("fig_dir") + "../" # one up because of the runs
|
@@ -71,7 +84,8 @@ class FeatureAnalyser:
|
|
71
84
|
filename = f"_SHAP_{model.name}"
|
72
85
|
filename = f"{fig_dir}{exp_name}{filename}.{format}"
|
73
86
|
plt.savefig(filename)
|
74
|
-
|
87
|
+
plt.close()
|
88
|
+
self.util.debug(f"plotted SHAP feature importance to {filename}")
|
75
89
|
|
76
90
|
def analyse(self):
|
77
91
|
models = ast.literal_eval(self.util.config_val("EXPL", "model", "['log_reg']"))
|
@@ -139,7 +153,7 @@ class FeatureAnalyser:
|
|
139
153
|
elif model_s == "svm":
|
140
154
|
from sklearn.svm import SVC
|
141
155
|
|
142
|
-
c = float(self.util.config_val("MODEL", "C_val", "0
|
156
|
+
c = float(self.util.config_val("MODEL", "C_val", "1.0"))
|
143
157
|
model = SVC(kernel="linear", C=c, gamma="scale")
|
144
158
|
result_importances[model_s] = self._get_importance(
|
145
159
|
model, permutation
|
@@ -205,7 +219,7 @@ class FeatureAnalyser:
|
|
205
219
|
model, permutation
|
206
220
|
)
|
207
221
|
elif model_s == "xgr":
|
208
|
-
from xgboost import
|
222
|
+
from xgboost import XGBRegressor
|
209
223
|
|
210
224
|
model = XGBRegressor()
|
211
225
|
result_importances[model_s] = self._get_importance(
|
@@ -270,12 +284,14 @@ class FeatureAnalyser:
|
|
270
284
|
)
|
271
285
|
)
|
272
286
|
|
287
|
+
# print feature importance values to file and debug and save to result
|
288
|
+
self.util.debug(f"Importance features from {model_name}: features = \n{df_imp['feats'].values.tolist()}")
|
273
289
|
# result file
|
274
290
|
res_dir = self.util.get_path("res_dir")
|
275
291
|
filename = f"_EXPL_{model_name}"
|
276
292
|
if permutation:
|
277
293
|
filename += "_perm"
|
278
|
-
filename = f"{res_dir}{self.util.get_exp_name(only_data=True)}{filename}_{
|
294
|
+
filename = f"{res_dir}{self.util.get_exp_name(only_data=True)}{filename}_{max_feat_num}_fi.txt"
|
279
295
|
with open(filename, "w") as text_file:
|
280
296
|
text_file.write(
|
281
297
|
"features in order of decreasing importance according to model"
|
@@ -283,7 +299,8 @@ class FeatureAnalyser:
|
|
283
299
|
)
|
284
300
|
|
285
301
|
df_imp.to_csv(filename, mode="a")
|
286
|
-
|
302
|
+
self.util.debug(f"Saved feature importance values to {filename}")
|
303
|
+
|
287
304
|
# check if feature distributions should be plotted
|
288
305
|
plot_feats = self.util.config_val("EXPL", "feature_distributions", False)
|
289
306
|
if plot_feats:
|
nkululeko/modelrunner.py
CHANGED
@@ -53,8 +53,8 @@ class Modelrunner:
|
|
53
53
|
# epochs are handled by Huggingface API
|
54
54
|
self.model.train()
|
55
55
|
report = self.model.predict()
|
56
|
-
# todo: findout the best epoch
|
57
|
-
# since
|
56
|
+
# todo: findout the best epoch -> no need
|
57
|
+
# since load_best_model_at_end is given in training args
|
58
58
|
epoch = epoch_num
|
59
59
|
report.set_id(self.run, epoch)
|
60
60
|
plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
|
nkululeko/reporting/reporter.py
CHANGED
@@ -402,7 +402,7 @@ class Reporter:
|
|
402
402
|
)
|
403
403
|
# print classifcation report in console
|
404
404
|
self.util.debug(
|
405
|
-
f"\n {classification_report(self.truths, self.preds, target_names=labels)}"
|
405
|
+
f"\n {classification_report(self.truths, self.preds, target_names=labels, digits=4)}"
|
406
406
|
)
|
407
407
|
except ValueError as e:
|
408
408
|
self.util.debug(
|
@@ -422,16 +422,17 @@ class Reporter:
|
|
422
422
|
if len(np.unique(self.truths)) == 2:
|
423
423
|
fpr, tpr, _ = roc_curve(self.truths, self.preds)
|
424
424
|
auc_score = auc(fpr, tpr)
|
425
|
+
plot_path = f"{fig_dir}{self.util.get_exp_name()}_{epoch}{self.filenameadd}_roc.{self.format}"
|
426
|
+
plt.figure()
|
425
427
|
display = RocCurveDisplay(
|
426
428
|
fpr=fpr,
|
427
429
|
tpr=tpr,
|
428
430
|
roc_auc=auc_score,
|
429
431
|
estimator_name=f"{self.model_type} estimator",
|
430
432
|
)
|
431
|
-
# save plot
|
432
|
-
plot_path = f"{fig_dir}{self.util.get_exp_name()}_{epoch}{self.filenameadd}_roc.{self.format}"
|
433
433
|
display.plot(ax=None)
|
434
434
|
plt.savefig(plot_path)
|
435
|
+
plt.close()
|
435
436
|
self.util.debug(f"Saved ROC curve to {plot_path}")
|
436
437
|
pauc_score = roc_auc_score(self.truths, self.preds, max_fpr=0.1)
|
437
438
|
auc_pauc = f"auc: {auc_score:.3f}, pauc: {pauc_score:.3f} from epoch: {epoch}"
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: nkululeko
|
3
|
-
Version: 0.89.
|
3
|
+
Version: 0.89.2
|
4
4
|
Summary: Machine learning audio prediction experiments based on templates
|
5
5
|
Home-page: https://github.com/felixbur/nkululeko
|
6
6
|
Author: Felix Burkhardt
|
@@ -356,6 +356,14 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
|
|
356
356
|
Changelog
|
357
357
|
=========
|
358
358
|
|
359
|
+
Version 0.89.2
|
360
|
+
--------------
|
361
|
+
* fix shap value calculation
|
362
|
+
|
363
|
+
Version 0.89.1
|
364
|
+
--------------
|
365
|
+
* print and save result of feature importance
|
366
|
+
|
359
367
|
Version 0.89.0
|
360
368
|
--------------
|
361
369
|
* added Roc plots and classification report on Debug
|
@@ -2,19 +2,19 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
|
|
2
2
|
nkululeko/aug_train.py,sha256=YhuZnS_WVWnun9G-M6g5n6rbRxoVREz6Zh7k6qprFNQ,3194
|
3
3
|
nkululeko/augment.py,sha256=4MG0apTAG5RgkuJrYEjGgDdbodZWi_HweSPNI1JJ5QA,3051
|
4
4
|
nkululeko/cacheddataset.py,sha256=lIJ6hUo5LoxSrzXtWV8mzwO7wRtUETWnOQ4ws2XfL1E,969
|
5
|
-
nkululeko/constants.py,sha256=
|
5
|
+
nkululeko/constants.py,sha256=WFGVylIst9Be_eHBZ9GiR43Qi4CdRySmNUzyNox6aMM,39
|
6
6
|
nkululeko/demo.py,sha256=bLuHkeEl5rOfm7ecGHCcWATiPK7-njNbtrGljxzNzFs,5088
|
7
7
|
nkululeko/demo_feats.py,sha256=sAeGFojhEj9WEDFtG3SzPBmyYJWLF2rkbpp65m8Ujo4,2025
|
8
8
|
nkululeko/demo_predictor.py,sha256=zs1bjhpnKuNCPLJeiyDm19ME1NEDOQT3QNeyVKJq9Yc,4882
|
9
|
-
nkululeko/ensemble.py,sha256=
|
9
|
+
nkululeko/ensemble.py,sha256=MayHpngGH_FTvSxUsH3NdxJd6WBAosGRFQeQ7cMjIco,12922
|
10
10
|
nkululeko/experiment.py,sha256=L4PzoScPLG2xTyniVy9evcBy_8CIe3RTeTEUVTqiuvQ,31186
|
11
|
-
nkululeko/explore.py,sha256=
|
11
|
+
nkululeko/explore.py,sha256=AbTVDmuDIaLfALQGvDW1yndcw2ikaEVEZ_fJVuUS070,3940
|
12
12
|
nkululeko/export.py,sha256=mHeEAAmtZuxdyebLlbSzPrHSi9OMgJHbk35d3DTxRBc,4632
|
13
13
|
nkululeko/feature_extractor.py,sha256=UnspIWz3XrNhKnBBhWZkH2bHvD-sROtrQVqB1JvkUyw,4088
|
14
14
|
nkululeko/file_checker.py,sha256=LoLnL8aHpW-axMQ46qbqrManTs5otG9ShpEZuz9iRSk,3474
|
15
15
|
nkululeko/filter_data.py,sha256=w-X2mhKdYr5DxDIz50E5yzO6Jmzk4jjDBoXsgOOVtcA,7222
|
16
16
|
nkululeko/glob_conf.py,sha256=KL9YJQTHvTztxo1vr25qRRgaPnx4NTg0XrdbovKGMmw,525
|
17
|
-
nkululeko/modelrunner.py,sha256=
|
17
|
+
nkululeko/modelrunner.py,sha256=lJy-xM4QfDDWeL0dLTE_VIb4sYrnd_Z_yJRK3wwohQA,11199
|
18
18
|
nkululeko/multidb.py,sha256=CCjmVsZyvydgOztFlaeBvOJH8nsvU-sPQdFAw8-q0U4,6752
|
19
19
|
nkululeko/nkuluflag.py,sha256=PGWSmZz-PiiHLgcZJAoGOI_Y-sZDVI1ksB8p5r7riWM,3725
|
20
20
|
nkululeko/nkululeko.py,sha256=Kn3s2E3yyH8cJ7z6lkMxrnqtCxTu7-qfe9Zr_ONTD5g,1968
|
@@ -51,7 +51,7 @@ nkululeko/data/dataset_csv.py,sha256=UGEpi__eT2KFS6Fop6N4HkMrzO-u5VP71gt44kwZavo
|
|
51
51
|
nkululeko/feat_extract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
52
52
|
nkululeko/feat_extract/feats_agender.py,sha256=tMK3_qs8adylNNSR0CS1RjU9RxmpumLqmuyzmc2ZYjA,3184
|
53
53
|
nkululeko/feat_extract/feats_agender_agender.py,sha256=19NoRT0KJ8WoZ3EabTYexXymD7bDy58-H20jYmdqjD0,3498
|
54
|
-
nkululeko/feat_extract/feats_analyser.py,sha256=
|
54
|
+
nkululeko/feat_extract/feats_analyser.py,sha256=rSsN6kcDUv64DaTl2DvReXF3_g7CtSwiPKgMzbJPqVI,13516
|
55
55
|
nkululeko/feat_extract/feats_ast.py,sha256=ycJn5eSVOxcEpmeHVk0FPB8q5XiTC8VSKz61L9n0Wa4,4638
|
56
56
|
nkululeko/feat_extract/feats_auddim.py,sha256=ulP_o4SGeQDFTs8YYCGKgccARAo6-wcjPK6-hhGjmn8,3155
|
57
57
|
nkululeko/feat_extract/feats_audmodel.py,sha256=aRGTBDKdYaTT_9xDaFZqpuyPhzxSNN_3b1PJDUHtYW4,3180
|
@@ -98,7 +98,7 @@ nkululeko/reporting/defines.py,sha256=IsY1YgKRMaABpylVKjBJgJ5bNCEbGCVA_E6pivraqS
|
|
98
98
|
nkululeko/reporting/latex_writer.py,sha256=qiCRSmB4KOD_za4oHu5x-PhwjZohzfo8wecMOwlXZwc,1886
|
99
99
|
nkululeko/reporting/report.py,sha256=W0rcigDdjBvxZQ3pZja_gvToILYvaZ1BFtnN2qFRfYI,1060
|
100
100
|
nkululeko/reporting/report_item.py,sha256=siWeGNgo4bAE46YBMNcsdf3jTMTy76BO9Fi6DTvDig4,533
|
101
|
-
nkululeko/reporting/reporter.py,sha256=
|
101
|
+
nkululeko/reporting/reporter.py,sha256=oodLaNZXqPpfoRqVxTldYcx68oN35OGgy-vvbAuY-yI,20039
|
102
102
|
nkululeko/reporting/result.py,sha256=G63a2tHCwHhM6NBJgYzsWKWJm4Yu3r4hsCHA2Km7eHU,1073
|
103
103
|
nkululeko/segmenting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
104
104
|
nkululeko/segmenting/seg_inaspeechsegmenter.py,sha256=pmLHuXsaqvcdYxB4PSW9l1mbQWZZBJFhi_CGabqydas,1947
|
@@ -107,8 +107,8 @@ nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
107
107
|
nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
|
108
108
|
nkululeko/utils/stats.py,sha256=eC9dMO-by6CDnGLHDBQu-2B4-BudZNJ0nnWGhKYdUMA,2968
|
109
109
|
nkululeko/utils/util.py,sha256=363Lgmcg6fPKCGbroX0DDyW_zcYNx-Ayqv67qdpfYcw,16710
|
110
|
-
nkululeko-0.89.
|
111
|
-
nkululeko-0.89.
|
112
|
-
nkululeko-0.89.
|
113
|
-
nkululeko-0.89.
|
114
|
-
nkululeko-0.89.
|
110
|
+
nkululeko-0.89.2.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
|
111
|
+
nkululeko-0.89.2.dist-info/METADATA,sha256=00CLy_4Wm7IktJy7dAkKrXkCMi0f1HUXCoQYMNcp2kw,40729
|
112
|
+
nkululeko-0.89.2.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
113
|
+
nkululeko-0.89.2.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
|
114
|
+
nkululeko-0.89.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|