nkululeko 0.86.8__py3-none-any.whl → 0.88.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nkululeko/constants.py CHANGED
@@ -1,2 +1,2 @@
- VERSION="0.86.8"
+ VERSION="0.88.0"
  SAMPLING_RATE = 16000
nkululeko/data/dataset_csv.py CHANGED
@@ -23,6 +23,9 @@ class Dataset_CSV(Dataset):
  root = os.path.dirname(data_file)
  audio_path = self.util.config_val_data(self.name, "audio_path", "./")
  df = pd.read_csv(data_file)
+ # trim all string values
+ df_obj = df.select_dtypes("object")
+ df[df_obj.columns] = df_obj.apply(lambda x: x.str.strip())
  # special treatment for segmented dataframes with only one column:
  if "start" in df.columns and len(df.columns) == 4:
  index = audformat.segmented_index(
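The added lines strip stray leading/trailing whitespace from every string-typed column right after the CSV is read, so values like " happy" and "happy" no longer end up as distinct labels. A minimal standalone sketch of the same pandas idiom (column names are made up for illustration):

    import pandas as pd

    df = pd.DataFrame({"file": ["a.wav ", " b.wav"], "emotion": [" happy", "sad "]})
    df_obj = df.select_dtypes("object")                         # all string columns
    df[df_obj.columns] = df_obj.apply(lambda x: x.str.strip())  # trim in place
    print(df["emotion"].tolist())                               # ['happy', 'sad']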
@@ -49,8 +52,7 @@ class Dataset_CSV(Dataset):
  .map(lambda x: root + "/" + audio_path + "/" + x)
  .values
  )
- df = df.set_index(df.index.set_levels(
- file_index, level="file"))
+ df = df.set_index(df.index.set_levels(file_index, level="file"))
  else:
  if not isinstance(df, pd.DataFrame):
  df = pd.DataFrame(df)
@@ -59,27 +61,24 @@ class Dataset_CSV(Dataset):
  lambda x: root + "/" + audio_path + "/" + x
  )
  )
- else: # absolute path is True
+ else:  # absolute path is True
  if audformat.index_type(df.index) == "segmented":
  file_index = (
- df.index.levels[0]
- .map(lambda x: audio_path + "/" + x)
- .values
+ df.index.levels[0].map(lambda x: audio_path + "/" + x).values
  )
- df = df.set_index(df.index.set_levels(
- file_index, level="file"))
+ df = df.set_index(df.index.set_levels(file_index, level="file"))
  else:
  if not isinstance(df, pd.DataFrame):
  df = pd.DataFrame(df)
- df = df.set_index(df.index.to_series().apply(
- lambda x: audio_path + "/" + x ))
+ df = df.set_index(
+ df.index.to_series().apply(lambda x: audio_path + "/" + x)
+ )

  self.df = df
  self.db = None
  self.got_target = True
  self.is_labeled = self.got_target
- self.start_fresh = eval(
- self.util.config_val("DATA", "no_reuse", "False"))
+ self.start_fresh = eval(self.util.config_val("DATA", "no_reuse", "False"))
  is_index = False
  try:
  if self.is_labeled and not "class_label" in self.df.columns:
@@ -106,8 +105,7 @@ class Dataset_CSV(Dataset):
  f" {self.got_gender}, got age: {self.got_age}"
  )
  self.util.debug(r_string)
- glob_conf.report.add_item(ReportItem(
- "Data", "Loaded report", r_string))
+ glob_conf.report.add_item(ReportItem("Data", "Loaded report", r_string))

  def prepare(self):
  super().prepare()
nkululeko/demo.py CHANGED
@@ -20,20 +20,19 @@ Options: \n
  import argparse
  import configparser
  import os
+
  import pandas as pd
+ from transformers import pipeline

+ import nkululeko.glob_conf as glob_conf
  from nkululeko.constants import VERSION
  from nkululeko.experiment import Experiment
- import nkululeko.glob_conf as glob_conf
  from nkululeko.utils.util import Util
- from transformers import pipeline


  def main(src_dir):
- parser = argparse.ArgumentParser(
- description="Call the nkululeko DEMO framework.")
- parser.add_argument("--config", default="exp.ini",
- help="The base configuration")
+ parser = argparse.ArgumentParser(description="Call the nkululeko DEMO framework.")
+ parser.add_argument("--config", default="exp.ini", help="The base configuration")
  parser.add_argument(
  "--file", help="A file that should be processed (16kHz mono wav)"
  )
@@ -84,8 +83,7 @@ def main(src_dir):
  )

  def print_pipe(files, outfile):
- """
- Prints the pipeline output for a list of files, and optionally writes the results to an output file.
+ """Prints the pipeline output for a list of files, and optionally writes the results to an output file.

  Args:
  files (list): A list of file paths to process through the pipeline.
@@ -108,8 +106,7 @@ def main(src_dir):
  f.write("\n".join(results))

  if util.get_model_type() == "finetune":
- model_path = os.path.join(
- util.get_exp_dir(), "models", "run_0", "torch")
+ model_path = os.path.join(util.get_exp_dir(), "models", "run_0", "torch")
  pipe = pipeline("audio-classification", model=model_path)
  if args.file is not None:
  print_pipe([args.file], args.outfile)
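For finetuned models, demo.py now builds the checkpoint path from the experiment directory and loads it straight into the Hugging Face audio-classification pipeline. A rough sketch of that call (the checkpoint directory and wav file are placeholders, not paths produced by this diff):

    from transformers import pipeline

    # directory written by a nkululeko finetune run (placeholder path)
    pipe = pipeline("audio-classification", model="my_exp/models/run_0/torch")
    # returns a list of {"label": ..., "score": ...} dicts for the clip
    print(pipe("speech_16khz_mono.wav"))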
nkululeko/ensemble.py ADDED
@@ -0,0 +1,158 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+
+
+ import configparser
+ import time
+ from argparse import ArgumentParser
+ from pathlib import Path
+
+ import pandas as pd
+
+ from nkululeko.constants import VERSION
+ from nkululeko.experiment import Experiment
+ from nkululeko.utils.util import Util
+
+
+ def ensemble_predictions(config_files, method, no_labels):
+     """
+     Ensemble predictions from multiple experiments.
+
+     Args:
+         config_files (list): List of configuration file paths.
+         method (str): Ensemble method to use. Options are 'majority_voting', 'mean', 'max', or 'sum'.
+         no_labels (bool): Flag indicating whether the predictions have labels or not.
+
+     Returns:
+         pandas.DataFrame: The ensemble predictions.
+
+     Raises:
+         ValueError: If an unknown ensemble method is provided.
+         AssertionError: If the number of config files is less than 2 for majority voting.
+
+     """
+     ensemble_preds = []
+     # labels = []
+     for config_file in config_files:
+         if no_labels:
+             # for ensembling results from Nkululeko.demo
+             pred = pd.read_csv(config_file)
+             labels = pred.columns[1:-2]
+         else:
+             # for ensembling results from Nkululeko.nkululeko
+             config = configparser.ConfigParser()
+             config.read(config_file)
+             expr = Experiment(config)
+             module = "ensemble"
+             expr.set_module(module)
+             util = Util(module, has_config=True)
+             util.debug(
+                 f"running {expr.name} from config {config_file}, nkululeko version"
+                 f" {VERSION}"
+             )
+
+             # get labels
+             labels = expr.util.get_labels()
+             # load the experiment
+             # get CSV files of predictions
+             pred = expr.util.get_pred_name()
+             print(f"Loading predictions from {pred}")
+             preds = pd.read_csv(pred)
+
+         ensemble_preds.append(preds)
+
+     # pd concate
+     ensemble_preds = pd.concat(ensemble_preds, axis=1)
+
+     if method == "majority_voting":
+         # majority voting, get mode, works for odd number of models
+         # raise error when number of configs only two:
+         assert (
+             len(config_files) > 2
+         ), "Majority voting only works for more than two models"
+         ensemble_preds["predicted"] = ensemble_preds.mode(axis=1)[0]
+
+     elif method == "mean":
+         for label in labels:
+             ensemble_preds[label] = ensemble_preds[label].mean(axis=1)
+
+     elif method == "max":
+         for label in labels:
+             ensemble_preds[label] = ensemble_preds[label].max(axis=1)
+             # get max value from all labels to inver that labels
+
+     elif method == "sum":
+         for label in labels:
+             ensemble_preds[label] = ensemble_preds[label].sum(axis=1)
+
+     else:
+         raise ValueError(f"Unknown ensemble method: {method}")
+
+     # get the highest value from all labels to inver that labels
+     # replace the old first predicted column
+     ensemble_preds["predicted"] = ensemble_preds[labels].idxmax(axis=1)
+
+     if no_labels:
+         return ensemble_preds
+
+     # Drop start, end columns
+     ensemble_preds = ensemble_preds.drop(columns=["start", "end"])
+
+     # Drop other column except until truth
+     ensemble_preds = ensemble_preds.iloc[:, : len(labels) + 3]
+
+     # calculate UAR from predicted and truth columns
+
+     truth = ensemble_preds["truth"]
+     predicted = ensemble_preds["predicted"]
+     uar = (truth == predicted).mean()
+     Util("ensemble").debug(f"UAR: {uar:.3f}")
+
+     # only return until 'predicted' column
+     return ensemble_preds
+
+
+ def main(src_dir):
+     parser = ArgumentParser()
+     parser.add_argument(
+         "configs",
+         nargs="+",
+         help="Paths to the configuration files of the experiments to ensemble. \
+             Can be INI files for Nkululeko.nkululeo or CSV files from Nkululeko.demo.",
+     )
+     parser.add_argument(
+         "--method",
+         default="majority_voting",
+         choices=["majority_voting", "mean", "max", "sum"],
+         help="Ensemble method to use (default: majority_voting)",
+     )
+     parser.add_argument(
+         "--outfile",
+         default="ensemble_result.csv",
+         help="Output file path for the ensemble predictions (default: ensemble_predictions.csv)",
+     )
+
+     # add argument if true label is not available
+     parser.add_argument(
+         "--no_labels",
+         action="store_true",
+         help="True if true labels are not available. For Nkululeko.demo results.",
+     )
+
+     args = parser.parse_args()
+
+     start = time.time()
+
+     ensemble_preds = ensemble_predictions(args.configs, args.method, args.no_labels)
+
+     # save to csv
+     ensemble_preds.to_csv(args.outfile, index=False)
+     print(f"Ensemble predictions saved to: {args.outfile}")
+     print(f"Ensemble done, used {time.time()-start:.2f} seconds")
+
+     print("DONE")
+
+
+ if __name__ == "__main__":
+     cwd = Path(__file__).parent
+     main(cwd)
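The new ensemble module combines the per-class predictions of several finished experiments. A hypothetical call of the core function with the signature defined above (the INI file names are placeholders):

    from nkululeko.ensemble import ensemble_predictions

    # average the class probabilities of three experiments and store the result
    preds = ensemble_predictions(
        ["exp_os.ini", "exp_praat.ini", "exp_wav2vec2.ini"],
        method="mean",
        no_labels=False,
    )
    preds.to_csv("ensemble_result.csv", index=False)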
nkululeko/feat_extract/feats_ast.py ADDED
@@ -0,0 +1,118 @@
+ # feats_ast.py
+ import os
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn.functional as F
+ import torchaudio
+ from tqdm import tqdm
+ from transformers import AutoProcessor, ASTModel
+
+ import nkululeko.glob_conf as glob_conf
+ from nkululeko.feat_extract.featureset import Featureset
+
+
+ class Ast(Featureset):
+     """Class to extract AST (Audio Spectrogram Transformer) embeddings"""
+
+     def __init__(self, name, data_df, feat_type):
+         super().__init__(name, data_df, feat_type)
+         cuda = "cuda" if torch.cuda.is_available() else "cpu"
+         self.device = self.util.config_val("MODEL", "device", cuda)
+         self.model_initialized = False
+         self.feat_type = feat_type
+
+     def init_model(self):
+         self.util.debug("loading AST model...")
+         model_path = self.util.config_val(
+             "FEATS", "ast.model", "MIT/ast-finetuned-audioset-10-10-0.4593"
+         )
+         self.processor = AutoProcessor.from_pretrained(model_path)
+         self.model = ASTModel.from_pretrained(model_path).to(self.device)
+         print(f"initialized AST model on {self.device}")
+         self.model.eval()
+         self.model_initialized = True
+
+
+     def extract(self):
+         """Extract the features or load them from disk if present."""
+         store = self.util.get_path("store")
+         storage = f"{store}{self.name}.pkl"
+         extract = self.util.config_val("FEATS", "needs_feature_extraction", False)
+         no_reuse = eval(self.util.config_val("FEATS", "no_reuse", "False"))
+         if extract or no_reuse or not os.path.isfile(storage):
+             if not self.model_initialized:
+                 self.init_model()
+             self.util.debug("extracting wavlm embeddings, this might take a while...")
+             emb_series = pd.Series(index=self.data_df.index, dtype=object)
+             length = len(self.data_df.index)
+             for idx, (file, start, end) in enumerate(
+                 tqdm(self.data_df.index.to_list())
+             ):
+                 signal, sampling_rate = torchaudio.load(
+                     file,
+                     frame_offset=int(start.total_seconds() * 16000),
+                     num_frames=int((end - start).total_seconds() * 16000),
+                 )
+                 # make mono if stereo
+                 if signal.shape[0] == 2:
+                     signal = torch.mean(signal, dim=0, keepdim=True)
+
+                 assert (
+                     sampling_rate == 16000
+                 ), f"sampling rate should be 16000 but is {sampling_rate}"
+                 emb = self.get_embeddings(signal, sampling_rate, file)
+                 emb_series.iloc[idx] = emb
+             self.df = pd.DataFrame(emb_series.values.tolist(), index=self.data_df.index)
+             self.df.to_pickle(storage)
+             try:
+                 glob_conf.config["DATA"]["needs_feature_extraction"] = "false"
+             except KeyError:
+                 pass
+         else:
+             self.util.debug(f"reusing extracted {self.feat_type} embeddings")
+             self.df = pd.read_pickle(storage)
+             if self.df.isnull().values.any():
+                 # nanrows = self.df.columns[self.df.isna().any()].tolist()
+                 # print(nanrows)
+                 self.util.error(
+                     f"got nan: {self.df.shape} {self.df.isnull().sum().sum()}"
+                 )
+
+
+     def get_embeddings(self, signal, sampling_rate, file):
+         """Extract embeddings from raw audio signal."""
+         try:
+             inputs = self.processor(signal.numpy(), sampling_rate=sampling_rate, return_tensors="pt")
+
+             inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+             with torch.no_grad():
+                 # Get the hidden states
+                 outputs = self.model(**inputs)
+
+                 # Get the hidden states from the last layer
+                 last_hidden_state = outputs.last_hidden_state
+
+                 # print(f"last_hidden_state shape: {last_hidden_state.shape}")
+                 # Average pooling over the time dimension
+                 embeddings = torch.mean(last_hidden_state, dim=1)
+                 embeddings = embeddings.cpu().numpy()
+
+             # convert the same from (768,) to (1, 768)
+             # embeddings = embeddings.reshape(1, -1)
+             print(f"hs shape: {embeddings.shape}")
+
+
+         except Exception as e:
+             self.util.error(f"Error extracting embeddings for file {file}: {str(e)}, fill with")
+             return np.zeros(
+                 self.model.config.hidden_size
+             ) # Return zero vector on error
+         return embeddings.ravel()
+
+     def extract_sample(self, signal, sr):
+         self.init_model()
+         feats = self.get_embeddings(signal, sr, "no file")
+         return feats
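The new extractor mean-pools the AST encoder's last hidden states into one fixed-size vector per clip. A minimal sketch of that pooling outside the Featureset plumbing, using the default model name from the diff and a dummy one-second signal:

    import numpy as np
    import torch
    from transformers import ASTModel, AutoProcessor

    model_name = "MIT/ast-finetuned-audioset-10-10-0.4593"
    processor = AutoProcessor.from_pretrained(model_name)
    model = ASTModel.from_pretrained(model_name).eval()

    waveform = np.zeros(16000, dtype=np.float32)  # placeholder: 1 s of silence at 16 kHz
    inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state      # shape (1, frames, 768)
    embedding = hidden.mean(dim=1).squeeze(0).numpy()   # one 768-dimensional vector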
nkululeko/feat_extract/feats_wav2vec2.py CHANGED
@@ -47,9 +47,7 @@ class Wav2vec2(Featureset):
  config.num_hidden_layers = layer_num - hidden_layer
  self.util.debug(f"using hidden layer #{config.num_hidden_layers}")
  self.processor = Wav2Vec2FeatureExtractor.from_pretrained(model_path)
- self.model = Wav2Vec2Model.from_pretrained(model_path, config=config).to(
- self.device
- )
+ self.model = Wav2Vec2Model.from_pretrained(model_path, config=config).to(self.device)
  print(f"intialized Wav2vec model on {self.device}")
  self.model.eval()
  self.model_initialized = True
@@ -90,7 +88,7 @@ class Wav2vec2(Featureset):
  self.util.debug("reusing extracted wav2vec2 embeddings")
  self.df = pd.read_pickle(storage)
  if self.df.isnull().values.any():
- nanrows = self.df.columns[self.df.isna().any()].tolist()
+ # nanrows = self.df.columns[self.df.isna().any()].tolist()
  # print(nanrows)
  self.util.error(
  f"got nan: {self.df.shape} {self.df.isnull().sum().sum()}"
nkululeko/feat_extract/feats_wavlm.py CHANGED
@@ -79,8 +79,8 @@ class Wavlm(Featureset):
  self.util.debug(f"reusing extracted {self.feat_type} embeddings")
  self.df = pd.read_pickle(storage)
  if self.df.isnull().values.any():
- nanrows = self.df.columns[self.df.isna().any()].tolist()
- print(nanrows)
+ # nanrows = self.df.columns[self.df.isna().any()].tolist()
+ # print(nanrows)
  self.util.error(
  f"got nan: {self.df.shape} {self.df.isnull().sum().sum()}"
  )
@@ -104,11 +104,14 @@ class Wavlm(Featureset):
  # pool result and convert to numpy
  y = torch.mean(y, dim=1)
  y = y.detach().cpu().numpy()
+
+ # print(f"hs shape: {y.shape}")
+
  except RuntimeError as re:
  print(str(re))
- self.util.error(f"couldn't extract file: {file}")
+ self.util.error(f"Couldn't extract file: {file}")

- return y.flatten()
+ return y.ravel()

  def extract_sample(self, signal, sr):
  self.init_model()
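The return value changes from flatten() to ravel(); both give the same one-dimensional embedding, but ravel() avoids a copy when the pooled array is already contiguous. A tiny illustration (the shape is arbitrary):

    import numpy as np

    y = np.zeros((1, 768))
    assert y.flatten().shape == y.ravel().shape == (768,)
    assert y.ravel().base is y        # view, no copy
    assert y.flatten().base is not y  # flatten always copies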
nkululeko/feature_extractor.py CHANGED
@@ -39,12 +39,10 @@ class FeatureExtractor:
  self.feats = pd.DataFrame()
  for feats_type in self.feats_types:
  store_name = f"{self.data_name}_{feats_type}"
- self.feat_extractor = self._get_feat_extractor(
- store_name, feats_type)
+ self.feat_extractor = self._get_feat_extractor(store_name, feats_type)
  self.feat_extractor.extract()
  self.feat_extractor.filter()
- self.feats = pd.concat(
- [self.feats, self.feat_extractor.df], axis=1)
+ self.feats = pd.concat([self.feats, self.feat_extractor.df], axis=1)
  return self.feats

  def extract_sample(self, signal, sr):
@@ -77,7 +75,7 @@ class FeatureExtractor:
  return TRILLset

  elif feats_type.startswith(
- ("wav2vec2", "hubert", "wavlm", "spkrec", "whisper")
+ ("wav2vec2", "hubert", "wavlm", "spkrec", "whisper", "ast")
  ):
  return self._get_feat_extractor_by_prefix(feats_type)

@@ -107,15 +105,13 @@ class FeatureExtractor:
  prefix, _, ext = feats_type.partition("-")
  from importlib import import_module

- module = import_module(
- f"nkululeko.feat_extract.feats_{prefix.lower()}")
+ module = import_module(f"nkululeko.feat_extract.feats_{prefix.lower()}")
  class_name = f"{prefix.capitalize()}"
  return getattr(module, class_name)

  def _get_feat_extractor_by_name(self, feats_type):
  from importlib import import_module

- module = import_module(
- f"nkululeko.feat_extract.feats_{feats_type.lower()}")
+ module = import_module(f"nkululeko.feat_extract.feats_{feats_type.lower()}")
  class_name = f"{feats_type.capitalize()}Set"
  return getattr(module, class_name)
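With "ast" added to the prefix tuple, the new extractor is resolved through the same dynamic import as the other embedding types. A sketch of how _get_feat_extractor_by_prefix maps a feature-type string to a class (no running experiment is needed to follow it):

    from importlib import import_module

    def resolve_by_prefix(feats_type: str):
        # e.g. "ast" or "wav2vec2-large-robust" -> module feats_<prefix>, class <Prefix>
        prefix, _, _ = feats_type.partition("-")
        module = import_module(f"nkululeko.feat_extract.feats_{prefix.lower()}")
        return getattr(module, prefix.capitalize())

    # resolve_by_prefix("ast") returns the Ast class from the new feats_ast.py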
nkululeko/modelrunner.py CHANGED
@@ -85,7 +85,7 @@ class Modelrunner:
  f"run: {self.run} epoch: {epoch}: result: {test_score_metric}"
  )
  # print(f"performance: {performance.split(' ')[1]}")
- performance = float(test_score_metric.split(' ')[1])
+ performance = float(test_score_metric.split(" ")[1])
  if performance > self.best_performance:
  self.best_performance = performance
  self.best_epoch = epoch
@@ -204,15 +204,15 @@ class Modelrunner:
  self.df_train, self.df_test, self.feats_train, self.feats_test
  )
  elif model_type == "cnn":
- from nkululeko.models.model_cnn import CNN_model
+ from nkululeko.models.model_cnn import CNNModel

- self.model = CNN_model(
+ self.model = CNNModel(
  self.df_train, self.df_test, self.feats_train, self.feats_test
  )
  elif model_type == "mlp":
- from nkululeko.models.model_mlp import MLP_model
+ from nkululeko.models.model_mlp import MLPModel

- self.model = MLP_model(
+ self.model = MLPModel(
  self.df_train, self.df_test, self.feats_train, self.feats_test
  )
  elif model_type == "mlp_reg":
nkululeko/models/model.py CHANGED
@@ -247,8 +247,25 @@ class Model:
  self.clf.fit(feats, labels)

  def get_predictions(self):
- predictions = self.clf.predict(self.feats_test.to_numpy())
- return predictions
+ # predictions = self.clf.predict(self.feats_test.to_numpy())
+ if self.util.exp_is_classification():
+ # make a dataframe for the class probabilities
+ proba_d = {}
+ for c in self.clf.classes_:
+ proba_d[c] = []
+ # get the class probabilities
+ predictions = self.clf.predict_proba(self.feats_test.to_numpy())
+ # pred = self.clf.predict(features)
+ for i, c in enumerate(self.clf.classes_):
+ proba_d[c] = list(predictions.T[i])
+ probas = pd.DataFrame(proba_d)
+ probas = probas.set_index(self.feats_test.index)
+ predictions = probas.idxmax(axis=1).values
+ else:
+ predictions = self.clf.predict(self.feats_test.to_numpy())
+ probas = None
+
+ return predictions, probas

  def predict(self):
  if self.feats_test.isna().to_numpy().any():
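get_predictions now returns a tuple: the winning labels plus a per-class probability table for classification experiments (probas is None for regression). A standalone sketch of the same construction with a plain scikit-learn classifier and toy data:

    import pandas as pd
    from sklearn.linear_model import LogisticRegression

    X_train, y_train = [[0.0], [0.2], [2.8], [3.0]], ["neg", "neg", "pos", "pos"]
    X_test = [[0.1], [2.9]]

    clf = LogisticRegression().fit(X_train, y_train)
    probas = pd.DataFrame(clf.predict_proba(X_test), columns=clf.classes_)
    predictions = probas.idxmax(axis=1).values  # same winners as clf.predict(X_test)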
@@ -263,13 +280,16 @@ class Model:
  )
  return report
  """Predict the whole eval feature set"""
- predictions = self.get_predictions()
+ predictions, probas = self.get_predictions()
+
  report = Reporter(
  self.df_test[self.target].to_numpy().astype(float),
  predictions,
  self.run,
  self.epoch,
+ probas=probas,
  )
+ report.print_probabilities()
  return report

  def get_type(self):
nkululeko/models/model_cnn.py CHANGED
@@ -5,33 +5,40 @@ Inspired by code from Su Lei

  """

+ import ast
+ from collections import OrderedDict
+
+ import numpy as np
+ import pandas as pd
+ from PIL import Image
+ from sklearn.metrics import recall_score
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- import torchvision
- import torchvision.transforms as transforms
  from torch.utils.data import Dataset
- import ast
- import numpy as np
- from sklearn.metrics import recall_score
- from collections import OrderedDict
- from PIL import Image
- from traitlets import default
+ import torchvision.transforms as transforms

- from nkululeko.utils.util import Util
  import nkululeko.glob_conf as glob_conf
+ from nkululeko.losses.loss_softf1loss import SoftF1Loss
  from nkululeko.models.model import Model
  from nkululeko.reporting.reporter import Reporter
- from nkululeko.losses.loss_softf1loss import SoftF1Loss
+ from nkululeko.utils.util import Util


- class CNN_model(Model):
- """CNN = convolutional neural net"""
+ class CNNModel(Model):
+ """CNN = convolutional neural net."""

  is_classifier = True

  def __init__(self, df_train, df_test, feats_train, feats_test):
- """Constructor taking the configuration and all dataframes"""
+ """Constructor, taking all dataframes.
+
+ Args:
+ df_train (pd.DataFrame): The train labels.
+ df_test (pd.DataFrame): The test labels.
+ feats_train (pd.DataFrame): The train features.
+ feats_test (pd.DataFrame): The test features.
+ """
  super().__init__(df_train, df_test, feats_train, feats_test)
  super().set_model_type("ann")
  self.name = "cnn"
@@ -147,7 +154,20 @@ class CNN_model(Model):
  self.optimizer.step()
  self.loss = (np.asarray(losses)).mean()

- def evaluate_model(self, model, loader, device):
+ def get_probas(self, logits):
+ # make a dataframe for probabilites (logits)
+ proba_d = {}
+ classes = self.df_test[self.target].unique()
+ classes.sort()
+ for c in classes:
+ proba_d[c] = []
+ for i, c in enumerate(classes):
+ proba_d[c] = list(logits.numpy().T[i])
+ probas = pd.DataFrame(proba_d)
+ probas = probas.set_index(self.df_test.index)
+ return probas
+
+ def evaluate(self, model, loader, device):
  logits = torch.zeros(len(loader.dataset), self.class_num)
  targets = torch.zeros(len(loader.dataset))
  model.eval()
@@ -169,14 +189,15 @@ class CNN_model(Model):
  self.loss_eval = (np.asarray(losses)).mean()
  predictions = logits.argmax(dim=1)
  uar = recall_score(targets.numpy(), predictions.numpy(), average="macro")
- return uar, targets, predictions
+ return uar, targets, predictions, logits

  def predict(self):
- _, truths, predictions = self.evaluate_model(
+ _, truths, predictions, logits = self.evaluate(
  self.model, self.testloader, self.device
  )
- uar, _, _ = self.evaluate_model(self.model, self.trainloader, self.device)
- report = Reporter(truths, predictions, self.run, self.epoch)
+ uar, _, _, _ = self.evaluate(self.model, self.trainloader, self.device)
+ probas = self.get_probas(logits)
+ report = Reporter(truths, predictions, self.run, self.epoch, probas=probas)
  try:
  report.result.loss = self.loss
  except AttributeError:  # if the model was loaded from disk the loss is unknown
@@ -189,13 +210,11 @@ class CNN_model(Model):
  return report

  def get_predictions(self):
- _, truths, predictions = self.evaluate_model(
- self.model, self.testloader, self.device
- )
+ _, _, predictions, _ = self.evaluate(self.model, self.testloader, self.device)
  return predictions.numpy()

  def predict_sample(self, features):
- """Predict one sample"""
+ """Predict one sample."""
  with torch.no_grad():
  logits = self.model(torch.from_numpy(features).to(self.device))
  a = logits.numpy()