nkululeko 0.83.3__py3-none-any.whl → 0.84.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
@@ -12,16 +12,19 @@ from nkululeko.utils.util import Util
 
 
 class Resampler:
-    def __init__(self, df, not_testing=True):
+    def __init__(self, df, replace, not_testing=True):
         self.SAMPLING_RATE = 16000
         self.df = df
         self.util = Util("resampler", has_config=not_testing)
         self.util.warn(f"all files might be resampled to {self.SAMPLING_RATE}")
         self.not_testing = not_testing
+        self.replace = eval(self.util.config_val(
+            "RESAMPLE", "replace", "False")) if not not_testing else replace
 
     def resample(self):
         files = self.df.index.get_level_values(0).values
-        replace = eval(self.util.config_val("RESAMPLE", "replace", "False"))
+        # replace = eval(self.util.config_val("RESAMPLE", "replace", "False"))
+        replace = self.replace
         if self.not_testing:
             store = self.util.get_path("store")
         else:
@@ -42,7 +45,8 @@ class Resampler:
                 continue
             if org_sr != self.SAMPLING_RATE:
                 self.util.debug(f"resampling {f} (sr = {org_sr})")
-                resampler = torchaudio.transforms.Resample(org_sr, self.SAMPLING_RATE)
+                resampler = torchaudio.transforms.Resample(
+                    org_sr, self.SAMPLING_RATE)
                 signal = resampler(signal)
                 if replace:
                     torchaudio.save(
@@ -59,7 +63,8 @@ class Resampler:
         self.df = self.df.set_index(
             self.df.index.set_levels(new_files, level="file")
         )
-        target_file = self.util.config_val("RESAMPLE", "target", "resampled.csv")
+        target_file = self.util.config_val(
+            "RESAMPLE", "target", "resampled.csv")
         # remove encoded labels
         target = self.util.config_val("DATA", "target", "emotion")
         if "class_label" in self.df.columns:
nkululeko/constants.py CHANGED
@@ -1,2 +1,2 @@
-VERSION="0.83.3"
+VERSION="0.84.1"
 SAMPLING_RATE = 16000
nkululeko/demo.py CHANGED
@@ -2,8 +2,9 @@
 # Demonstration code to use the ML-experiment framework
 # Test the loading of a previously trained model and demo mode
 # needs the project config file to run before
-"""
-This script is used to test the loading of a previously trained model and run it in demo mode.
+"""This script is used to test the loading of a previously trained model.
+
+And run it in demo mode.
 It requires the project config file to be run before.
 
 Usage:
@@ -20,17 +21,15 @@ import argparse
 import configparser
 import os
 
-import nkululeko.glob_conf as glob_conf
 from nkululeko.constants import VERSION
 from nkululeko.experiment import Experiment
+import nkululeko.glob_conf as glob_conf
 from nkululeko.utils.util import Util
 
 
 def main(src_dir):
-    parser = argparse.ArgumentParser(
-        description="Call the nkululeko DEMO framework.")
-    parser.add_argument("--config", default="exp.ini",
-                        help="The base configuration")
+    parser = argparse.ArgumentParser(description="Call the nkululeko DEMO framework.")
+    parser.add_argument("--config", default="exp.ini", help="The base configuration")
     parser.add_argument(
         "--file", help="A file that should be processed (16kHz mono wav)"
     )
@@ -1,18 +1,19 @@
 # demo_predictor.py
 import os
 
-import audformat
-import audiofile
 import numpy as np
 import pandas as pd
 
+import audformat
+import audiofile
+
 import nkululeko.glob_conf as glob_conf
 from nkululeko.utils.util import Util
 
 
 class Demo_predictor:
     def __init__(self, model, file, is_list, feature_extractor, label_encoder, outfile):
-        """Constructor setting up name and configuration"""
+        """Constructor setting up name and configuration."""
         self.model = model
         self.feature_extractor = feature_extractor
         self.label_encoder = label_encoder
nkululeko/experiment.py CHANGED
@@ -5,20 +5,22 @@ import pickle
 import random
 import time
 
-import audeer
-import audformat
 import numpy as np
 import pandas as pd
 from sklearn.preprocessing import LabelEncoder
 
-import nkululeko.glob_conf as glob_conf
+import audeer
+import audformat
+
 from nkululeko.data.dataset import Dataset
 from nkululeko.data.dataset_csv import Dataset_CSV
 from nkululeko.demo_predictor import Demo_predictor
 from nkululeko.feat_extract.feats_analyser import FeatureAnalyser
 from nkululeko.feature_extractor import FeatureExtractor
 from nkululeko.file_checker import FileChecker
-from nkululeko.filter_data import DataFilter, filter_min_dur
+from nkululeko.filter_data import DataFilter
+from nkululeko.filter_data import filter_min_dur
+import nkululeko.glob_conf as glob_conf
 from nkululeko.plots import Plots
 from nkululeko.reporting.report import Report
 from nkululeko.runmanager import Runmanager
@@ -101,6 +103,7 @@ class Experiment:
                 self.got_speaker = True
             self.datasets.update({d: data})
         self.target = self.util.config_val("DATA", "target", "emotion")
+        glob_conf.set_target(self.target)
         # print target via debug
         self.util.debug(f"target: {self.target}")
         # print keys/column
@@ -487,11 +490,7 @@ class Experiment:
         return df_ret
 
     def analyse_features(self, needs_feats):
-        """
-        Do a feature exploration
-
-        """
-
+        """Do a feature exploration."""
         plot_feats = eval(
             self.util.config_val("EXPL", "feature_distributions", "False")
         )
@@ -511,7 +510,7 @@ class Experiment:
                 f"unknown sample selection specifier {sample_selection}, should"
                 " be [all | train | test]"
             )
-
+        self.util.debug(f"sampling selection: {sample_selection}")
         if self.util.config_val("EXPL", "value_counts", False):
             self.plot_distribution(df_labels)
 
@@ -537,9 +536,13 @@ class Experiment:
                 f"unknown sample selection specifier {sample_selection}, should"
                 " be [all | train | test]"
             )
+        feat_analyser = FeatureAnalyser(sample_selection, df_labels, df_feats)
+        # check if SHAP features should be analysed
+        shap = eval(self.util.config_val("EXPL", "shap", "False"))
+        if shap:
+            feat_analyser.analyse_shap(self.runmgr.get_best_model())
 
         if plot_feats:
-            feat_analyser = FeatureAnalyser(sample_selection, df_labels, df_feats)
             feat_analyser.analyse()
 
         # check if a scatterplot should be done
@@ -692,7 +695,7 @@ class Experiment:
         if self.runmgr.modelrunner.model.is_ann():
             self.runmgr.modelrunner.model = None
             self.util.warn(
-                "Save experiment: Can't pickle the learning model so saving without it."
+                "Save experiment: Can't pickle the trained model so saving without it. (it should be stored anyway)"
             )
         try:
             f = open(filename, "wb")
nkululeko/explore.py CHANGED
@@ -12,9 +12,9 @@ from nkululeko.utils.util import Util
 
 def main(src_dir):
     parser = argparse.ArgumentParser(
-        description="Call the nkululeko EXPLORE framework.")
-    parser.add_argument("--config", default="exp.ini",
-                        help="The base configuration")
+        description="Call the nkululeko EXPLORE framework."
+    )
+    parser.add_argument("--config", default="exp.ini", help="The base configuration")
     args = parser.parse_args()
     if args.config is not None:
         config_file = args.config
@@ -43,28 +43,34 @@ def main(src_dir):
     import warnings
 
     warnings.filterwarnings("ignore")
-
-    # load the data
-    expr.load_datasets()
-
-    # split into train and test
-    expr.fill_train_and_tests()
-    util.debug(
-        f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}")
-
-    plot_feats = eval(util.config_val(
-        "EXPL", "feature_distributions", "False"))
-    tsne = eval(util.config_val("EXPL", "tsne", "False"))
-    scatter = eval(util.config_val("EXPL", "scatter", "False"))
-    spotlight = eval(util.config_val("EXPL", "spotlight", "False"))
-    model_type = util.config_val("EXPL", "model", False)
-    plot_tree = eval(util.config_val("EXPL", "plot_tree", "False"))
     needs_feats = False
-    if plot_feats or tsne or scatter or model_type or plot_tree:
-        # these investigations need features to explore
-        expr.extract_feats()
+    try:
+        # load the experiment
+        expr.load(f"{util.get_save_name()}")
         needs_feats = True
-    # explore
+    except FileNotFoundError:
+        # first time: load the data
+        expr.load_datasets()
+
+        # split into train and test
+        expr.fill_train_and_tests()
+        util.debug(
+            f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}"
+        )
+
+        plot_feats = eval(util.config_val("EXPL", "feature_distributions", "False"))
+        tsne = eval(util.config_val("EXPL", "tsne", "False"))
+        scatter = eval(util.config_val("EXPL", "scatter", "False"))
+        spotlight = eval(util.config_val("EXPL", "spotlight", "False"))
+        shap = eval(util.config_val("EXPL", "shap", "False"))
+        model_type = util.config_val("EXPL", "model", False)
+        plot_tree = eval(util.config_val("EXPL", "plot_tree", "False"))
+        needs_feats = False
+        if plot_feats or tsne or scatter or model_type or plot_tree or shap:
+            # these investigations need features to explore
+            expr.extract_feats()
+            needs_feats = True
+        # explore
     expr.analyse_features(needs_feats)
     expr.store_report()
    print("DONE")
@@ -40,6 +40,39 @@ class FeatureAnalyser
             importance = model.feature_importances_
         return importance
 
+    def analyse_shap(self, model):
+        """Shap analysis.
+
+        Use the best model from a previous run and analyse feature importance with SHAP.
+        https://m.mage.ai/how-to-interpret-and-explain-your-machine-learning-models-using-shap-values-471c2635b78e.
+        """
+        import shap
+
+        name = "my_shap_values"
+        if not self.util.exist_pickle(name):
+
+            explainer = shap.Explainer(
+                model.predict_shap,
+                self.features,
+                output_names=glob_conf.labels,
+                algorithm="permutation",
+                npermutations=5,
+            )
+            self.util.debug("computing SHAP values...")
+            shap_values = explainer(self.features)
+            self.util.to_pickle(shap_values, name)
+        else:
+            shap_values = self.util.from_pickle(name)
+        plt.tight_layout()
+        shap.plots.bar(shap_values)
+        fig_dir = self.util.get_path("fig_dir") + "../"  # one up because of the runs
+        exp_name = self.util.get_exp_name(only_data=True)
+        format = self.util.config_val("PLOT", "format", "png")
+        filename = f"_SHAP_{model.name}"
+        filename = f"{fig_dir}{exp_name}{filename}.{format}"
+        plt.savefig(filename)
+        self.util.debug(f"plotted SHAP feature importance to {filename}")
+
     def analyse(self):
         models = ast.literal_eval(self.util.config_val("EXPL", "model", "['log_reg']"))
         model_name = "_".join(models)
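Note: `analyse_shap` wraps SHAP's permutation explainer around the model's new `predict_shap` callable and caches the computed values through the Util pickle helpers. A self-contained sketch of the same pattern with a scikit-learn stand-in (everything here is illustrative, not nkululeko API); note that SHAP needs numeric model outputs, which is why the sketch predicts integer classes:

```python
import numpy as np
import pandas as pd
import shap
from sklearn.linear_model import LogisticRegression

# toy feature frame standing in for self.features
X = pd.DataFrame(np.random.rand(100, 4), columns=["f1", "f2", "f3", "f4"])
y = np.random.randint(0, 2, 100)
clf = LogisticRegression().fit(X, y)

# like model.predict_shap above: a plain callable from a feature frame to predictions
explainer = shap.Explainer(clf.predict, X, algorithm="permutation")
shap_values = explainer(X)
shap.plots.bar(shap_values, show=False)  # the same bar plot the method saves to disk
```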
nkululeko/glob_conf.py CHANGED
@@ -29,3 +29,8 @@ def set_report(report_obj):
 def set_labels(labels_obj):
     global labels
     labels = labels_obj
+
+
+def set_target(target_obj):
+    global target
+    target = target_obj
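Note: `set_target` follows the same module-level-global pattern as `set_labels`: experiment.py stores the target once (the `glob_conf.set_target(self.target)` call above) and model code reads it back as `glob_conf.target`, as the CNN model does below. A tiny sketch of the pattern:

```python
import nkululeko.glob_conf as glob_conf

glob_conf.set_target("emotion")  # done once when the experiment loads its data
print(glob_conf.target)          # later read anywhere, e.g. in CNN_model.__init__
```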
@@ -0,0 +1,181 @@
+import dataclasses
+import typing
+
+import torch
+import transformers
+from transformers.models.wav2vec2.modeling_wav2vec2 import (
+    Wav2Vec2PreTrainedModel,
+    Wav2Vec2Model,
+)
+
+
+class ConcordanceCorCoeff(torch.nn.Module):
+
+    def __init__(self):
+
+        super().__init__()
+
+        self.mean = torch.mean
+        self.var = torch.var
+        self.sum = torch.sum
+        self.sqrt = torch.sqrt
+        self.std = torch.std
+
+    def forward(self, prediction, ground_truth):
+
+        mean_gt = self.mean(ground_truth, 0)
+        mean_pred = self.mean(prediction, 0)
+        var_gt = self.var(ground_truth, 0)
+        var_pred = self.var(prediction, 0)
+        v_pred = prediction - mean_pred
+        v_gt = ground_truth - mean_gt
+        cor = self.sum(v_pred * v_gt) / (
+            self.sqrt(self.sum(v_pred**2)) * self.sqrt(self.sum(v_gt**2))
+        )
+        sd_gt = self.std(ground_truth)
+        sd_pred = self.std(prediction)
+        numerator = 2 * cor * sd_gt * sd_pred
+        denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
+        ccc = numerator / denominator
+
+        return 1 - ccc
+
+
+@dataclasses.dataclass
+class ModelOutput(transformers.file_utils.ModelOutput):
+
+    logits_cat: torch.FloatTensor = None
+    hidden_states: typing.Tuple[torch.FloatTensor] = None
+    cnn_features: torch.FloatTensor = None
+
+
+class ModelHead(torch.nn.Module):
+
+    def __init__(self, config, num_labels):
+
+        super().__init__()
+
+        self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = torch.nn.Dropout(config.final_dropout)
+        self.out_proj = torch.nn.Linear(config.hidden_size, num_labels)
+
+    def forward(self, features, **kwargs):
+
+        x = features
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+
+        return x
+
+
+class Model(Wav2Vec2PreTrainedModel):
+
+    def __init__(self, config):
+
+        super().__init__(config)
+
+        self.wav2vec2 = Wav2Vec2Model(config)
+        self.cat = ModelHead(config, 2)
+        self.init_weights()
+
+    def freeze_feature_extractor(self):
+        self.wav2vec2.feature_extractor._freeze_parameters()
+
+    def pooling(
+        self,
+        hidden_states,
+        attention_mask,
+    ):
+
+        if attention_mask is None:  # For evaluation with batch_size==1
+            outputs = torch.mean(hidden_states, dim=1)
+        else:
+            attention_mask = self._get_feature_vector_attention_mask(
+                hidden_states.shape[1],
+                attention_mask,
+            )
+            hidden_states = hidden_states * torch.reshape(
+                attention_mask,
+                (-1, attention_mask.shape[-1], 1),
+            )
+            outputs = torch.sum(hidden_states, dim=1)
+            attention_sum = torch.sum(attention_mask, dim=1)
+            outputs = outputs / torch.reshape(attention_sum, (-1, 1))
+
+        return outputs
+
+    def forward(
+        self,
+        input_values,
+        attention_mask=None,
+        labels=None,
+        return_hidden=False,
+    ):
+
+        outputs = self.wav2vec2(
+            input_values,
+            attention_mask=attention_mask,
+        )
+
+        cnn_features = outputs.extract_features
+        hidden_states_framewise = outputs.last_hidden_state
+        hidden_states = self.pooling(
+            hidden_states_framewise,
+            attention_mask,
+        )
+        logits_cat = self.cat(hidden_states)
+
+        if not self.training:
+            logits_cat = torch.softmax(logits_cat, dim=1)
+
+        if return_hidden:
+
+            # make time last axis
+            cnn_features = torch.transpose(cnn_features, 1, 2)
+
+            return ModelOutput(
+                logits_cat=logits_cat,
+                hidden_states=hidden_states,
+                cnn_features=cnn_features,
+            )
+
+        else:
+
+            return ModelOutput(
+                logits_cat=logits_cat,
+            )
+
+
+class ModelWithPreProcessing(Model):
+
+    def __init__(self, config):
+        super().__init__(config)
+
+    def forward(
+        self,
+        input_values,
+    ):
+        # Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm():
+        # normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
+
+        mean = input_values.mean()
+
+        # var = input_values.var()
+        # raises: onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for the node ReduceProd_3:ReduceProd(11)
+
+        var = torch.square(input_values - mean).mean()
+        input_values = (input_values - mean) / torch.sqrt(var + 1e-7)
+
+        output = super().forward(
+            input_values,
+            return_hidden=True,
+        )
+
+        return (
+            output.hidden_states,
+            output.logits_cat,
+            output.cnn_features,
+        )
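Note: this new file adds a wav2vec2 model with a pooled classification head plus a CCC loss. `ConcordanceCorCoeff.forward` computes the concordance correlation coefficient, CCC = 2·ρ·σ_pred·σ_gt / (σ_pred² + σ_gt² + (μ_pred − μ_gt)²), and returns 1 − CCC, so perfect agreement yields a loss of 0. A minimal usage sketch (tensor values are illustrative; the class's module path is not shown in this diff):

```python
import torch

# ConcordanceCorCoeff as defined above
loss_fn = ConcordanceCorCoeff()
pred = torch.tensor([0.1, 0.4, 0.9])
gold = torch.tensor([0.0, 0.5, 1.0])
print(float(loss_fn(pred, gold)))  # near 0: the two series agree almost perfectly
```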
nkululeko/models/model.py CHANGED
@@ -20,6 +20,7 @@ class Model:
 
     def __init__(self, df_train, df_test, feats_train, feats_test):
         """Constructor taking the configuration and all dataframes."""
+        self.name = "undefined"
         self.df_train, self.df_test, self.feats_train, self.feats_test = (
             df_train,
             df_test,
@@ -12,3 +12,4 @@ class Bayes_model(Model):
     def __init__(self, df_train, df_test, feats_train, feats_test):
         super().__init__(df_train, df_test, feats_train, feats_test)
         self.clf = GaussianNB()  # set up the classifier
+        self.name = "bayes"
@@ -34,7 +34,8 @@ class CNN_model(Model):
         """Constructor taking the configuration and all dataframes"""
         super().__init__(df_train, df_test, feats_train, feats_test)
         super().set_model_type("ann")
-        self.target = glob_conf.config["DATA"]["target"]
+        self.name = "cnn"
+        self.target = glob_conf.target
         labels = glob_conf.labels
         self.class_num = len(labels)
         # set up loss criterion
@@ -86,8 +87,7 @@ class CNN_model(Model):
         train_set = self.Dataset_image(
             feats_train, df_train, self.target, transformations
         )
-        test_set = self.Dataset_image(
-            feats_test, df_test, self.target, transformations)
+        test_set = self.Dataset_image(feats_test, df_test, self.target, transformations)
         # Define data loaders
         self.trainloader = torch.utils.data.DataLoader(
             train_set,
@@ -140,8 +140,7 @@ class CNN_model(Model):
         losses = []
         for images, labels in self.trainloader:
             logits = self.model(images.to(self.device))
-            loss = self.criterion(logits, labels.to(
-                self.device, dtype=torch.int64))
+            loss = self.criterion(logits, labels.to(self.device, dtype=torch.int64))
             losses.append(loss.item())
             self.optimizer.zero_grad()
             loss.backward()
@@ -169,16 +168,14 @@ class CNN_model(Model):
 
         self.loss_eval = (np.asarray(losses)).mean()
         predictions = logits.argmax(dim=1)
-        uar = recall_score(
-            targets.numpy(), predictions.numpy(), average="macro")
+        uar = recall_score(targets.numpy(), predictions.numpy(), average="macro")
         return uar, targets, predictions
 
     def predict(self):
         _, truths, predictions = self.evaluate_model(
             self.model, self.testloader, self.device
         )
-        uar, _, _ = self.evaluate_model(
-            self.model, self.trainloader, self.device)
+        uar, _, _ = self.evaluate_model(self.model, self.trainloader, self.device)
         report = Reporter(truths, predictions, self.run, self.epoch)
         try:
             report.result.loss = self.loss
@@ -11,10 +11,9 @@ class GMM_model(Model):
 
     def __init__(self, df_train, df_test, feats_train, feats_test):
         super().__init__(df_train, df_test, feats_train, feats_test)
+        self.name = "gmm"
         n_components = int(self.util.config_val("MODEL", "GMM_components", "4"))
-        covariance_type = self.util.config_val(
-            "MODEL", "GMM_covariance_type", "full"
-        )
+        covariance_type = self.util.config_val("MODEL", "GMM_covariance_type", "full")
         self.clf = mixture.GaussianMixture(
             n_components=n_components, covariance_type=covariance_type
         )
@@ -11,6 +11,7 @@ class KNN_model(Model):
 
     def __init__(self, df_train, df_test, feats_train, feats_test):
         super().__init__(df_train, df_test, feats_train, feats_test)
+        self.name = "knn"
         method = self.util.config_val("MODEL", "KNN_weights", "uniform")
         k = int(self.util.config_val("MODEL", "K_val", "5"))
         self.clf = KNeighborsClassifier(
@@ -11,6 +11,7 @@ class KNN_reg_model(Model):
 
     def __init__(self, df_train, df_test, feats_train, feats_test):
         super().__init__(df_train, df_test, feats_train, feats_test)
+        self.name = "knn_reg"
         method = self.util.config_val("MODEL", "KNN_weights", "uniform")
         k = int(self.util.config_val("MODEL", "K_val", "5"))
         self.clf = KNeighborsRegressor(
@@ -11,4 +11,5 @@ class Lin_reg_model(Model):
 
     def __init__(self, df_train, df_test, feats_train, feats_test):
         super().__init__(df_train, df_test, feats_train, feats_test)
+        self.name = "lin_reg"
         self.clf = LinearRegression()  # set up the classifier
@@ -1,4 +1,6 @@
 # model_mlp.py
+import pandas as pd
+
 from nkululeko.utils.util import Util
 import nkululeko.glob_conf as glob_conf
 from nkululeko.models.model import Model
@@ -20,6 +22,7 @@ class MLP_model(Model):
         """Constructor taking the configuration and all dataframes"""
         super().__init__(df_train, df_test, feats_train, feats_test)
         super().set_model_type("ann")
+        self.name = "mlp"
         self.target = glob_conf.config["DATA"]["target"]
         labels = glob_conf.labels
         self.class_num = len(labels)
@@ -87,8 +90,7 @@ class MLP_model(Model):
         losses = []
         for features, labels in self.trainloader:
             logits = self.model(features.to(self.device))
-            loss = self.criterion(logits, labels.to(
-                self.device, dtype=torch.int64))
+            loss = self.criterion(logits, labels.to(self.device, dtype=torch.int64))
             losses.append(loss.item())
             self.optimizer.zero_grad()
             loss.backward()
@@ -116,16 +118,14 @@ class MLP_model(Model):
 
         self.loss_eval = (np.asarray(losses)).mean()
         predictions = logits.argmax(dim=1)
-        uar = recall_score(
-            targets.numpy(), predictions.numpy(), average="macro")
+        uar = recall_score(targets.numpy(), predictions.numpy(), average="macro")
         return uar, targets, predictions
 
     def predict(self):
         _, truths, predictions = self.evaluate_model(
             self.model, self.testloader, self.device
         )
-        uar, _, _ = self.evaluate_model(
-            self.model, self.trainloader, self.device)
+        uar, _, _ = self.evaluate_model(self.model, self.trainloader, self.device)
         report = Reporter(truths, predictions, self.run, self.epoch)
         try:
             report.result.loss = self.loss
@@ -176,8 +176,18 @@ class MLP_model(Model):
             x = x.squeeze(dim=1).float()
             return self.linear(x)
 
+    def predict_shap(self, features):
+        # predict outputs for all samples in SHAP format (pd. dataframe)
+        results = []
+        for index, row in features.iterrows():
+            feats = row.values
+            res_dict = self.predict_sample(feats)
+            class_key = max(res_dict, key=res_dict.get)
+            results.append(class_key)
+        return results
+
     def predict_sample(self, features):
-        """Predict one sample"""
+        """Predict one sample."""
         with torch.no_grad():
             features = torch.from_numpy(features)
             features = np.reshape(features, (-1, 1)).T
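Note: `predict_shap` gives SHAP a plain callable: it takes a pandas DataFrame of features (one row per sample), runs `predict_sample` on each row, and returns the label with the highest score per row, matching what `FeatureAnalyser.analyse_shap` passes to `shap.Explainer`. A hedged sketch of the contract, assuming `model` is a trained `MLP_model` and a 64-dimensional feature set (both illustrative):

```python
import numpy as np
import pandas as pd

feats = pd.DataFrame(np.random.rand(3, 64))  # 3 samples, 64 features (illustrative)
labels = model.predict_shap(feats)           # e.g. ["happy", "neutral", "happy"]
```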