nkululeko 0.89.1__py3-none-any.whl → 0.89.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nkululeko/constants.py +1 -1
- nkululeko/ensemble.py +8 -4
- nkululeko/explore.py +3 -1
- nkululeko/feat_extract/feats_analyser.py +17 -3
- nkululeko/modelrunner.py +2 -2
- nkululeko/reporting/reporter.py +4 -3
- {nkululeko-0.89.1.dist-info → nkululeko-0.89.2.dist-info}/METADATA +5 -1
- {nkululeko-0.89.1.dist-info → nkululeko-0.89.2.dist-info}/RECORD +11 -11
- {nkululeko-0.89.1.dist-info → nkululeko-0.89.2.dist-info}/WHEEL +1 -1
- {nkululeko-0.89.1.dist-info → nkululeko-0.89.2.dist-info}/LICENSE +0 -0
- {nkululeko-0.89.1.dist-info → nkululeko-0.89.2.dist-info}/top_level.txt +0 -0
nkululeko/constants.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
|
-
VERSION="0.89.1"
|
1
|
+
VERSION="0.89.2"
|
2
2
|
SAMPLING_RATE = 16000
|
nkululeko/ensemble.py
CHANGED
@@ -18,6 +18,7 @@ Raises:
|
|
18
18
|
#!/usr/bin/env python
|
19
19
|
# -*- coding: utf-8 -*-
|
20
20
|
|
21
|
+
|
21
22
|
from typing import List
|
22
23
|
import configparser
|
23
24
|
import time
|
@@ -26,10 +27,15 @@ from pathlib import Path
|
|
26
27
|
|
27
28
|
import numpy as np
|
28
29
|
import pandas as pd
|
30
|
+
import matplotlib.pyplot as plt
|
31
|
+
|
29
32
|
from sklearn.metrics import(
|
33
|
+
RocCurveDisplay,
|
30
34
|
balanced_accuracy_score,
|
31
35
|
classification_report,
|
32
|
-
|
36
|
+
auc,
|
37
|
+
roc_auc_score,
|
38
|
+
roc_curve
|
33
39
|
)
|
34
40
|
|
35
41
|
from nkululeko.constants import VERSION
|
@@ -289,9 +295,7 @@ def ensemble_predictions(
|
|
289
295
|
uar = balanced_accuracy_score(truth, predicted)
|
290
296
|
acc = (truth == predicted).mean()
|
291
297
|
# print classification report
|
292
|
-
Util("ensemble").debug(f"\n {classification_report(truth, predicted)}")
|
293
|
-
# f1 = f1_score(truth, predicted, pos_label='p')
|
294
|
-
# Util("ensemble").debug(f"F1: {f1:.3f}")
|
298
|
+
Util("ensemble").debug(f"\n {classification_report(truth, predicted, digits=4)}")
|
295
299
|
Util("ensemble").debug(f"{method}: UAR: {uar:.3f}, ACC: {acc:.3f}")
|
296
300
|
|
297
301
|
return ensemble_preds
|
nkululeko/explore.py
CHANGED
@@ -91,7 +91,9 @@ def main(src_dir):
|
|
91
91
|
# these investigations need features to explore
|
92
92
|
expr.extract_feats()
|
93
93
|
needs_feats = True
|
94
|
-
|
94
|
+
# explore
|
95
|
+
expr.init_runmanager()
|
96
|
+
expr.runmgr.do_runs()
|
95
97
|
expr.analyse_features(needs_feats)
|
96
98
|
expr.store_report()
|
97
99
|
print("DONE")
|
nkululeko/feat_extract/feats_analyser.py
CHANGED
@@ -50,19 +50,32 @@ class FeatureAnalyser:
|
|
50
50
|
|
51
51
|
name = "my_shap_values"
|
52
52
|
if not self.util.exist_pickle(name):
|
53
|
-
|
53
|
+
# get model name
|
54
|
+
model_name = self.util.get_model_type()
|
55
|
+
if hasattr(model, "predict_shap"):
|
56
|
+
model_func = model.predict_shap
|
57
|
+
elif hasattr(model, "clf"):
|
58
|
+
model_func = model.clf.predict
|
59
|
+
else:
|
60
|
+
raise Exception("Model not supported for SHAP analysis")
|
61
|
+
|
62
|
+
self.util.debug(f"using SHAP explainer for {model_name} model")
|
63
|
+
|
54
64
|
explainer = shap.Explainer(
|
55
|
-
|
65
|
+
model_func,
|
56
66
|
self.features,
|
57
67
|
output_names=glob_conf.labels,
|
58
68
|
algorithm="permutation",
|
59
69
|
npermutations=5,
|
60
70
|
)
|
71
|
+
|
61
72
|
self.util.debug("computing SHAP values...")
|
62
73
|
shap_values = explainer(self.features)
|
63
74
|
self.util.to_pickle(shap_values, name)
|
64
75
|
else:
|
65
76
|
shap_values = self.util.from_pickle(name)
|
77
|
+
# plt.figure()
|
78
|
+
plt.close('all')
|
66
79
|
plt.tight_layout()
|
67
80
|
shap.plots.bar(shap_values)
|
68
81
|
fig_dir = self.util.get_path("fig_dir") + "../" # one up because of the runs
|
@@ -71,7 +84,8 @@ class FeatureAnalyser:
|
|
71
84
|
filename = f"_SHAP_{model.name}"
|
72
85
|
filename = f"{fig_dir}{exp_name}{filename}.{format}"
|
73
86
|
plt.savefig(filename)
|
74
|
-
|
87
|
+
plt.close()
|
88
|
+
self.util.debug(f"plotted SHAP feature importance to {filename}")
|
75
89
|
|
76
90
|
def analyse(self):
|
77
91
|
models = ast.literal_eval(self.util.config_val("EXPL", "model", "['log_reg']"))
|
nkululeko/modelrunner.py
CHANGED
@@ -53,8 +53,8 @@ class Modelrunner:
|
|
53
53
|
# epochs are handled by Huggingface API
|
54
54
|
self.model.train()
|
55
55
|
report = self.model.predict()
|
56
|
-
# todo: findout the best epoch
|
57
|
-
# since
|
56
|
+
# todo: findout the best epoch -> no need
|
57
|
+
# since load_best_model_at_end is given in training args
|
58
58
|
epoch = epoch_num
|
59
59
|
report.set_id(self.run, epoch)
|
60
60
|
plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
|
nkululeko/reporting/reporter.py
CHANGED
@@ -402,7 +402,7 @@ class Reporter:
|
|
402
402
|
)
|
403
403
|
# print classifcation report in console
|
404
404
|
self.util.debug(
|
405
|
-
f"\n {classification_report(self.truths, self.preds, target_names=labels)}"
|
405
|
+
f"\n {classification_report(self.truths, self.preds, target_names=labels, digits=4)}"
|
406
406
|
)
|
407
407
|
except ValueError as e:
|
408
408
|
self.util.debug(
|
@@ -422,16 +422,17 @@ class Reporter:
|
|
422
422
|
if len(np.unique(self.truths)) == 2:
|
423
423
|
fpr, tpr, _ = roc_curve(self.truths, self.preds)
|
424
424
|
auc_score = auc(fpr, tpr)
|
425
|
+
plot_path = f"{fig_dir}{self.util.get_exp_name()}_{epoch}{self.filenameadd}_roc.{self.format}"
|
426
|
+
plt.figure()
|
425
427
|
display = RocCurveDisplay(
|
426
428
|
fpr=fpr,
|
427
429
|
tpr=tpr,
|
428
430
|
roc_auc=auc_score,
|
429
431
|
estimator_name=f"{self.model_type} estimator",
|
430
432
|
)
|
431
|
-
# save plot
|
432
|
-
plot_path = f"{fig_dir}{self.util.get_exp_name()}_{epoch}{self.filenameadd}_roc.{self.format}"
|
433
433
|
display.plot(ax=None)
|
434
434
|
plt.savefig(plot_path)
|
435
|
+
plt.close()
|
435
436
|
self.util.debug(f"Saved ROC curve to {plot_path}")
|
436
437
|
pauc_score = roc_auc_score(self.truths, self.preds, max_fpr=0.1)
|
437
438
|
auc_pauc = f"auc: {auc_score:.3f}, pauc: {pauc_score:.3f} from epoch: {epoch}"
|
{nkululeko-0.89.1.dist-info → nkululeko-0.89.2.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: nkululeko
|
3
|
-
Version: 0.89.1
|
3
|
+
Version: 0.89.2
|
4
4
|
Summary: Machine learning audio prediction experiments based on templates
|
5
5
|
Home-page: https://github.com/felixbur/nkululeko
|
6
6
|
Author: Felix Burkhardt
|
@@ -356,6 +356,10 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
|
|
356
356
|
Changelog
|
357
357
|
=========
|
358
358
|
|
359
|
+
Version 0.89.2
|
360
|
+
--------------
|
361
|
+
* fix shap value calculation
|
362
|
+
|
359
363
|
Version 0.89.1
|
360
364
|
--------------
|
361
365
|
* print and save result of feature importance
|
{nkululeko-0.89.1.dist-info → nkululeko-0.89.2.dist-info}/RECORD
CHANGED
@@ -2,19 +2,19 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
|
|
2
2
|
nkululeko/aug_train.py,sha256=YhuZnS_WVWnun9G-M6g5n6rbRxoVREz6Zh7k6qprFNQ,3194
|
3
3
|
nkululeko/augment.py,sha256=4MG0apTAG5RgkuJrYEjGgDdbodZWi_HweSPNI1JJ5QA,3051
|
4
4
|
nkululeko/cacheddataset.py,sha256=lIJ6hUo5LoxSrzXtWV8mzwO7wRtUETWnOQ4ws2XfL1E,969
|
5
|
-
nkululeko/constants.py,sha256=
|
5
|
+
nkululeko/constants.py,sha256=WFGVylIst9Be_eHBZ9GiR43Qi4CdRySmNUzyNox6aMM,39
|
6
6
|
nkululeko/demo.py,sha256=bLuHkeEl5rOfm7ecGHCcWATiPK7-njNbtrGljxzNzFs,5088
|
7
7
|
nkululeko/demo_feats.py,sha256=sAeGFojhEj9WEDFtG3SzPBmyYJWLF2rkbpp65m8Ujo4,2025
|
8
8
|
nkululeko/demo_predictor.py,sha256=zs1bjhpnKuNCPLJeiyDm19ME1NEDOQT3QNeyVKJq9Yc,4882
|
9
|
-
nkululeko/ensemble.py,sha256=
|
9
|
+
nkululeko/ensemble.py,sha256=MayHpngGH_FTvSxUsH3NdxJd6WBAosGRFQeQ7cMjIco,12922
|
10
10
|
nkululeko/experiment.py,sha256=L4PzoScPLG2xTyniVy9evcBy_8CIe3RTeTEUVTqiuvQ,31186
|
11
|
-
nkululeko/explore.py,sha256=
|
11
|
+
nkululeko/explore.py,sha256=AbTVDmuDIaLfALQGvDW1yndcw2ikaEVEZ_fJVuUS070,3940
|
12
12
|
nkululeko/export.py,sha256=mHeEAAmtZuxdyebLlbSzPrHSi9OMgJHbk35d3DTxRBc,4632
|
13
13
|
nkululeko/feature_extractor.py,sha256=UnspIWz3XrNhKnBBhWZkH2bHvD-sROtrQVqB1JvkUyw,4088
|
14
14
|
nkululeko/file_checker.py,sha256=LoLnL8aHpW-axMQ46qbqrManTs5otG9ShpEZuz9iRSk,3474
|
15
15
|
nkululeko/filter_data.py,sha256=w-X2mhKdYr5DxDIz50E5yzO6Jmzk4jjDBoXsgOOVtcA,7222
|
16
16
|
nkululeko/glob_conf.py,sha256=KL9YJQTHvTztxo1vr25qRRgaPnx4NTg0XrdbovKGMmw,525
|
17
|
-
nkululeko/modelrunner.py,sha256=
|
17
|
+
nkululeko/modelrunner.py,sha256=lJy-xM4QfDDWeL0dLTE_VIb4sYrnd_Z_yJRK3wwohQA,11199
|
18
18
|
nkululeko/multidb.py,sha256=CCjmVsZyvydgOztFlaeBvOJH8nsvU-sPQdFAw8-q0U4,6752
|
19
19
|
nkululeko/nkuluflag.py,sha256=PGWSmZz-PiiHLgcZJAoGOI_Y-sZDVI1ksB8p5r7riWM,3725
|
20
20
|
nkululeko/nkululeko.py,sha256=Kn3s2E3yyH8cJ7z6lkMxrnqtCxTu7-qfe9Zr_ONTD5g,1968
|
@@ -51,7 +51,7 @@ nkululeko/data/dataset_csv.py,sha256=UGEpi__eT2KFS6Fop6N4HkMrzO-u5VP71gt44kwZavo
|
|
51
51
|
nkululeko/feat_extract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
52
52
|
nkululeko/feat_extract/feats_agender.py,sha256=tMK3_qs8adylNNSR0CS1RjU9RxmpumLqmuyzmc2ZYjA,3184
|
53
53
|
nkululeko/feat_extract/feats_agender_agender.py,sha256=19NoRT0KJ8WoZ3EabTYexXymD7bDy58-H20jYmdqjD0,3498
|
54
|
-
nkululeko/feat_extract/feats_analyser.py,sha256=
|
54
|
+
nkululeko/feat_extract/feats_analyser.py,sha256=rSsN6kcDUv64DaTl2DvReXF3_g7CtSwiPKgMzbJPqVI,13516
|
55
55
|
nkululeko/feat_extract/feats_ast.py,sha256=ycJn5eSVOxcEpmeHVk0FPB8q5XiTC8VSKz61L9n0Wa4,4638
|
56
56
|
nkululeko/feat_extract/feats_auddim.py,sha256=ulP_o4SGeQDFTs8YYCGKgccARAo6-wcjPK6-hhGjmn8,3155
|
57
57
|
nkululeko/feat_extract/feats_audmodel.py,sha256=aRGTBDKdYaTT_9xDaFZqpuyPhzxSNN_3b1PJDUHtYW4,3180
|
@@ -98,7 +98,7 @@ nkululeko/reporting/defines.py,sha256=IsY1YgKRMaABpylVKjBJgJ5bNCEbGCVA_E6pivraqS
|
|
98
98
|
nkululeko/reporting/latex_writer.py,sha256=qiCRSmB4KOD_za4oHu5x-PhwjZohzfo8wecMOwlXZwc,1886
|
99
99
|
nkululeko/reporting/report.py,sha256=W0rcigDdjBvxZQ3pZja_gvToILYvaZ1BFtnN2qFRfYI,1060
|
100
100
|
nkululeko/reporting/report_item.py,sha256=siWeGNgo4bAE46YBMNcsdf3jTMTy76BO9Fi6DTvDig4,533
|
101
|
-
nkululeko/reporting/reporter.py,sha256=
|
101
|
+
nkululeko/reporting/reporter.py,sha256=oodLaNZXqPpfoRqVxTldYcx68oN35OGgy-vvbAuY-yI,20039
|
102
102
|
nkululeko/reporting/result.py,sha256=G63a2tHCwHhM6NBJgYzsWKWJm4Yu3r4hsCHA2Km7eHU,1073
|
103
103
|
nkululeko/segmenting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
104
104
|
nkululeko/segmenting/seg_inaspeechsegmenter.py,sha256=pmLHuXsaqvcdYxB4PSW9l1mbQWZZBJFhi_CGabqydas,1947
|
@@ -107,8 +107,8 @@ nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
107
107
|
nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
|
108
108
|
nkululeko/utils/stats.py,sha256=eC9dMO-by6CDnGLHDBQu-2B4-BudZNJ0nnWGhKYdUMA,2968
|
109
109
|
nkululeko/utils/util.py,sha256=363Lgmcg6fPKCGbroX0DDyW_zcYNx-Ayqv67qdpfYcw,16710
|
110
|
-
nkululeko-0.89.
|
111
|
-
nkululeko-0.89.
|
112
|
-
nkululeko-0.89.
|
113
|
-
nkululeko-0.89.
|
114
|
-
nkululeko-0.89.
|
110
|
+
nkululeko-0.89.2.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
|
111
|
+
nkululeko-0.89.2.dist-info/METADATA,sha256=00CLy_4Wm7IktJy7dAkKrXkCMi0f1HUXCoQYMNcp2kw,40729
|
112
|
+
nkululeko-0.89.2.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
113
|
+
nkululeko-0.89.2.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
|
114
|
+
nkululeko-0.89.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|