nkululeko 0.89.1__py3-none-any.whl → 0.89.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nkululeko/constants.py CHANGED
@@ -1,2 +1,2 @@
1
- VERSION="0.89.1"
1
+ VERSION="0.89.2"
2
2
  SAMPLING_RATE = 16000
nkululeko/ensemble.py CHANGED
@@ -18,6 +18,7 @@ Raises:
18
18
  #!/usr/bin/env python
19
19
  # -*- coding: utf-8 -*-
20
20
 
21
+
21
22
  from typing import List
22
23
  import configparser
23
24
  import time
@@ -26,10 +27,15 @@ from pathlib import Path
26
27
 
27
28
  import numpy as np
28
29
  import pandas as pd
30
+ import matplotlib.pyplot as plt
31
+
29
32
  from sklearn.metrics import(
33
+ RocCurveDisplay,
30
34
  balanced_accuracy_score,
31
35
  classification_report,
32
- f1_score
36
+ auc,
37
+ roc_auc_score,
38
+ roc_curve
33
39
  )
34
40
 
35
41
  from nkululeko.constants import VERSION
@@ -289,9 +295,7 @@ def ensemble_predictions(
289
295
  uar = balanced_accuracy_score(truth, predicted)
290
296
  acc = (truth == predicted).mean()
291
297
  # print classification report
292
- Util("ensemble").debug(f"\n {classification_report(truth, predicted)}")
293
- # f1 = f1_score(truth, predicted, pos_label='p')
294
- # Util("ensemble").debug(f"F1: {f1:.3f}")
298
+ Util("ensemble").debug(f"\n {classification_report(truth, predicted, digits=4)}")
295
299
  Util("ensemble").debug(f"{method}: UAR: {uar:.3f}, ACC: {acc:.3f}")
296
300
 
297
301
  return ensemble_preds
nkululeko/explore.py CHANGED
@@ -91,7 +91,9 @@ def main(src_dir):
91
91
  # these investigations need features to explore
92
92
  expr.extract_feats()
93
93
  needs_feats = True
94
- # explore
94
+ # explore
95
+ expr.init_runmanager()
96
+ expr.runmgr.do_runs()
95
97
  expr.analyse_features(needs_feats)
96
98
  expr.store_report()
97
99
  print("DONE")
@@ -50,19 +50,32 @@ class FeatureAnalyser:
50
50
 
51
51
  name = "my_shap_values"
52
52
  if not self.util.exist_pickle(name):
53
-
53
+ # get model name
54
+ model_name = self.util.get_model_type()
55
+ if hasattr(model, "predict_shap"):
56
+ model_func = model.predict_shap
57
+ elif hasattr(model, "clf"):
58
+ model_func = model.clf.predict
59
+ else:
60
+ raise Exception("Model not supported for SHAP analysis")
61
+
62
+ self.util.debug(f"using SHAP explainer for {model_name} model")
63
+
54
64
  explainer = shap.Explainer(
55
- model.predict_shap,
65
+ model_func,
56
66
  self.features,
57
67
  output_names=glob_conf.labels,
58
68
  algorithm="permutation",
59
69
  npermutations=5,
60
70
  )
71
+
61
72
  self.util.debug("computing SHAP values...")
62
73
  shap_values = explainer(self.features)
63
74
  self.util.to_pickle(shap_values, name)
64
75
  else:
65
76
  shap_values = self.util.from_pickle(name)
77
+ # plt.figure()
78
+ plt.close('all')
66
79
  plt.tight_layout()
67
80
  shap.plots.bar(shap_values)
68
81
  fig_dir = self.util.get_path("fig_dir") + "../" # one up because of the runs
@@ -71,7 +84,8 @@ class FeatureAnalyser:
71
84
  filename = f"_SHAP_{model.name}"
72
85
  filename = f"{fig_dir}{exp_name}{filename}.{format}"
73
86
  plt.savefig(filename)
74
- self.util.debug(f"plotted SHAP feature importance tp {filename}")
87
+ plt.close()
88
+ self.util.debug(f"plotted SHAP feature importance to {filename}")
75
89
 
76
90
  def analyse(self):
77
91
  models = ast.literal_eval(self.util.config_val("EXPL", "model", "['log_reg']"))
nkululeko/modelrunner.py CHANGED
@@ -53,8 +53,8 @@ class Modelrunner:
53
53
  # epochs are handled by Huggingface API
54
54
  self.model.train()
55
55
  report = self.model.predict()
56
- # todo: findout the best epoch, no need
57
- # since oad_best_model_at_end is given in training args
56
+ # todo: find out the best epoch -> no need
57
+ # since load_best_model_at_end is given in training args
58
58
  epoch = epoch_num
59
59
  report.set_id(self.run, epoch)
60
60
  plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
@@ -402,7 +402,7 @@ class Reporter:
402
402
  )
403
403
  # print classification report in console
404
404
  self.util.debug(
405
- f"\n {classification_report(self.truths, self.preds, target_names=labels)}"
405
+ f"\n {classification_report(self.truths, self.preds, target_names=labels, digits=4)}"
406
406
  )
407
407
  except ValueError as e:
408
408
  self.util.debug(
@@ -422,16 +422,17 @@ class Reporter:
422
422
  if len(np.unique(self.truths)) == 2:
423
423
  fpr, tpr, _ = roc_curve(self.truths, self.preds)
424
424
  auc_score = auc(fpr, tpr)
425
+ plot_path = f"{fig_dir}{self.util.get_exp_name()}_{epoch}{self.filenameadd}_roc.{self.format}"
426
+ plt.figure()
425
427
  display = RocCurveDisplay(
426
428
  fpr=fpr,
427
429
  tpr=tpr,
428
430
  roc_auc=auc_score,
429
431
  estimator_name=f"{self.model_type} estimator",
430
432
  )
431
- # save plot
432
- plot_path = f"{fig_dir}{self.util.get_exp_name()}_{epoch}{self.filenameadd}_roc.{self.format}"
433
433
  display.plot(ax=None)
434
434
  plt.savefig(plot_path)
435
+ plt.close()
435
436
  self.util.debug(f"Saved ROC curve to {plot_path}")
436
437
  pauc_score = roc_auc_score(self.truths, self.preds, max_fpr=0.1)
437
438
  auc_pauc = f"auc: {auc_score:.3f}, pauc: {pauc_score:.3f} from epoch: {epoch}"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nkululeko
3
- Version: 0.89.1
3
+ Version: 0.89.2
4
4
  Summary: Machine learning audio prediction experiments based on templates
5
5
  Home-page: https://github.com/felixbur/nkululeko
6
6
  Author: Felix Burkhardt
@@ -356,6 +356,10 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
356
356
  Changelog
357
357
  =========
358
358
 
359
+ Version 0.89.2
360
+ --------------
361
+ * fix shap value calculation
362
+
359
363
  Version 0.89.1
360
364
  --------------
361
365
  * print and save result of feature importance
@@ -2,19 +2,19 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
2
2
  nkululeko/aug_train.py,sha256=YhuZnS_WVWnun9G-M6g5n6rbRxoVREz6Zh7k6qprFNQ,3194
3
3
  nkululeko/augment.py,sha256=4MG0apTAG5RgkuJrYEjGgDdbodZWi_HweSPNI1JJ5QA,3051
4
4
  nkululeko/cacheddataset.py,sha256=lIJ6hUo5LoxSrzXtWV8mzwO7wRtUETWnOQ4ws2XfL1E,969
5
- nkululeko/constants.py,sha256=nRA0bWrvi-5tXm8QWv4dzDE-3sujMiz26U4QgSVuck0,39
5
+ nkululeko/constants.py,sha256=WFGVylIst9Be_eHBZ9GiR43Qi4CdRySmNUzyNox6aMM,39
6
6
  nkululeko/demo.py,sha256=bLuHkeEl5rOfm7ecGHCcWATiPK7-njNbtrGljxzNzFs,5088
7
7
  nkululeko/demo_feats.py,sha256=sAeGFojhEj9WEDFtG3SzPBmyYJWLF2rkbpp65m8Ujo4,2025
8
8
  nkululeko/demo_predictor.py,sha256=zs1bjhpnKuNCPLJeiyDm19ME1NEDOQT3QNeyVKJq9Yc,4882
9
- nkululeko/ensemble.py,sha256=egtOFxEp7gjuM5cKBfETnhTn1-7_4zWBPEah65K1C3U,12927
9
+ nkululeko/ensemble.py,sha256=MayHpngGH_FTvSxUsH3NdxJd6WBAosGRFQeQ7cMjIco,12922
10
10
  nkululeko/experiment.py,sha256=L4PzoScPLG2xTyniVy9evcBy_8CIe3RTeTEUVTqiuvQ,31186
11
- nkululeko/explore.py,sha256=_GOgcRaPvh2xBbKPAkSJjYzgHhD_xb3ZCB6M1MPA6ao,3867
11
+ nkululeko/explore.py,sha256=AbTVDmuDIaLfALQGvDW1yndcw2ikaEVEZ_fJVuUS070,3940
12
12
  nkululeko/export.py,sha256=mHeEAAmtZuxdyebLlbSzPrHSi9OMgJHbk35d3DTxRBc,4632
13
13
  nkululeko/feature_extractor.py,sha256=UnspIWz3XrNhKnBBhWZkH2bHvD-sROtrQVqB1JvkUyw,4088
14
14
  nkululeko/file_checker.py,sha256=LoLnL8aHpW-axMQ46qbqrManTs5otG9ShpEZuz9iRSk,3474
15
15
  nkululeko/filter_data.py,sha256=w-X2mhKdYr5DxDIz50E5yzO6Jmzk4jjDBoXsgOOVtcA,7222
16
16
  nkululeko/glob_conf.py,sha256=KL9YJQTHvTztxo1vr25qRRgaPnx4NTg0XrdbovKGMmw,525
17
- nkululeko/modelrunner.py,sha256=cKYD9a7MRoBxfqUy3X8kf6rGTYho-33In8I9YkzMOo8,11196
17
+ nkululeko/modelrunner.py,sha256=lJy-xM4QfDDWeL0dLTE_VIb4sYrnd_Z_yJRK3wwohQA,11199
18
18
  nkululeko/multidb.py,sha256=CCjmVsZyvydgOztFlaeBvOJH8nsvU-sPQdFAw8-q0U4,6752
19
19
  nkululeko/nkuluflag.py,sha256=PGWSmZz-PiiHLgcZJAoGOI_Y-sZDVI1ksB8p5r7riWM,3725
20
20
  nkululeko/nkululeko.py,sha256=Kn3s2E3yyH8cJ7z6lkMxrnqtCxTu7-qfe9Zr_ONTD5g,1968
@@ -51,7 +51,7 @@ nkululeko/data/dataset_csv.py,sha256=UGEpi__eT2KFS6Fop6N4HkMrzO-u5VP71gt44kwZavo
51
51
  nkululeko/feat_extract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
52
52
  nkululeko/feat_extract/feats_agender.py,sha256=tMK3_qs8adylNNSR0CS1RjU9RxmpumLqmuyzmc2ZYjA,3184
53
53
  nkululeko/feat_extract/feats_agender_agender.py,sha256=19NoRT0KJ8WoZ3EabTYexXymD7bDy58-H20jYmdqjD0,3498
54
- nkululeko/feat_extract/feats_analyser.py,sha256=eW0v7Boybfj2gXi77MPjaLyHUQ1C42mx9hgoQeDwNac,12999
54
+ nkululeko/feat_extract/feats_analyser.py,sha256=rSsN6kcDUv64DaTl2DvReXF3_g7CtSwiPKgMzbJPqVI,13516
55
55
  nkululeko/feat_extract/feats_ast.py,sha256=ycJn5eSVOxcEpmeHVk0FPB8q5XiTC8VSKz61L9n0Wa4,4638
56
56
  nkululeko/feat_extract/feats_auddim.py,sha256=ulP_o4SGeQDFTs8YYCGKgccARAo6-wcjPK6-hhGjmn8,3155
57
57
  nkululeko/feat_extract/feats_audmodel.py,sha256=aRGTBDKdYaTT_9xDaFZqpuyPhzxSNN_3b1PJDUHtYW4,3180
@@ -98,7 +98,7 @@ nkululeko/reporting/defines.py,sha256=IsY1YgKRMaABpylVKjBJgJ5bNCEbGCVA_E6pivraqS
98
98
  nkululeko/reporting/latex_writer.py,sha256=qiCRSmB4KOD_za4oHu5x-PhwjZohzfo8wecMOwlXZwc,1886
99
99
  nkululeko/reporting/report.py,sha256=W0rcigDdjBvxZQ3pZja_gvToILYvaZ1BFtnN2qFRfYI,1060
100
100
  nkululeko/reporting/report_item.py,sha256=siWeGNgo4bAE46YBMNcsdf3jTMTy76BO9Fi6DTvDig4,533
101
- nkululeko/reporting/reporter.py,sha256=xFyGj6gQ8T1WB3w3tJ0awlgQcq1e3IKXEIfl_DvOngg,19996
101
+ nkululeko/reporting/reporter.py,sha256=oodLaNZXqPpfoRqVxTldYcx68oN35OGgy-vvbAuY-yI,20039
102
102
  nkululeko/reporting/result.py,sha256=G63a2tHCwHhM6NBJgYzsWKWJm4Yu3r4hsCHA2Km7eHU,1073
103
103
  nkululeko/segmenting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
104
104
  nkululeko/segmenting/seg_inaspeechsegmenter.py,sha256=pmLHuXsaqvcdYxB4PSW9l1mbQWZZBJFhi_CGabqydas,1947
@@ -107,8 +107,8 @@ nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
107
107
  nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
108
108
  nkululeko/utils/stats.py,sha256=eC9dMO-by6CDnGLHDBQu-2B4-BudZNJ0nnWGhKYdUMA,2968
109
109
  nkululeko/utils/util.py,sha256=363Lgmcg6fPKCGbroX0DDyW_zcYNx-Ayqv67qdpfYcw,16710
110
- nkululeko-0.89.1.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
111
- nkululeko-0.89.1.dist-info/METADATA,sha256=AuVssWNRMXlseH5xSzcls--AAYLFSeEbFtHbAFT2o_o,40667
112
- nkululeko-0.89.1.dist-info/WHEEL,sha256=UvcQYKBHoFqaQd6LKyqHw9fxEolWLQnlzP0h_LgJAfI,91
113
- nkululeko-0.89.1.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
114
- nkululeko-0.89.1.dist-info/RECORD,,
110
+ nkululeko-0.89.2.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
111
+ nkululeko-0.89.2.dist-info/METADATA,sha256=00CLy_4Wm7IktJy7dAkKrXkCMi0f1HUXCoQYMNcp2kw,40729
112
+ nkululeko-0.89.2.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
113
+ nkululeko-0.89.2.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
114
+ nkululeko-0.89.2.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (74.0.0)
2
+ Generator: setuptools (74.1.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5