nkululeko-0.84.0-py3-none-any.whl → nkululeko-0.85.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,16 +12,19 @@ from nkululeko.utils.util import Util
 
 
 class Resampler:
-    def __init__(self, df, not_testing=True):
+    def __init__(self, df, replace, not_testing=True):
         self.SAMPLING_RATE = 16000
         self.df = df
         self.util = Util("resampler", has_config=not_testing)
         self.util.warn(f"all files might be resampled to {self.SAMPLING_RATE}")
         self.not_testing = not_testing
+        self.replace = eval(self.util.config_val(
+            "RESAMPLE", "replace", "False")) if not not_testing else replace
 
     def resample(self):
         files = self.df.index.get_level_values(0).values
-        replace = eval(self.util.config_val("RESAMPLE", "replace", "False"))
+        # replace = eval(self.util.config_val("RESAMPLE", "replace", "False"))
+        replace = self.replace
         if self.not_testing:
             store = self.util.get_path("store")
         else:
@@ -42,7 +45,8 @@ class Resampler:
                 continue
             if org_sr != self.SAMPLING_RATE:
                 self.util.debug(f"resampling {f} (sr = {org_sr})")
-                resampler = torchaudio.transforms.Resample(org_sr, self.SAMPLING_RATE)
+                resampler = torchaudio.transforms.Resample(
+                    org_sr, self.SAMPLING_RATE)
                 signal = resampler(signal)
                 if replace:
                     torchaudio.save(
@@ -59,7 +63,8 @@ class Resampler:
         self.df = self.df.set_index(
             self.df.index.set_levels(new_files, level="file")
         )
-        target_file = self.util.config_val("RESAMPLE", "target", "resampled.csv")
+        target_file = self.util.config_val(
+            "RESAMPLE", "target", "resampled.csv")
         # remove encoded labels
         target = self.util.config_val("DATA", "target", "emotion")
         if "class_label" in self.df.columns:
nkululeko/constants.py CHANGED
@@ -1,2 +1,2 @@
-VERSION="0.84.0"
+VERSION="0.85.0"
 SAMPLING_RATE = 16000
nkululeko/experiment.py CHANGED
@@ -340,7 +340,12 @@ class Experiment:
         df_train, df_test = self.df_train, self.df_test
         feats_name = "_".join(ast.literal_eval(glob_conf.config["DATA"]["databases"]))
         self.feats_test, self.feats_train = pd.DataFrame(), pd.DataFrame()
-        feats_types = self.util.config_val_list("FEATS", "type", ["os"])
+        feats_types = self.util.config_val_list("FEATS", "type", [])
+        # for some models no features are needed
+        if len(feats_types) == 0:
+            self.util.debug("no feature extractor specified.")
+            self.feats_train, self.feats_test = pd.DataFrame(), pd.DataFrame()
+            return
         self.feature_extractor = FeatureExtractor(
             df_train, feats_types, feats_name, "train"
         )
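
The default feature type is no longer hard-wired to opensmile ("os"); with no [FEATS] type entry, extraction is skipped and empty feature frames are returned, which the raw-audio finetuning path added below depends on. A hedged INI sketch of that case:

    [FEATS]
    # no "type" option: extract_feats() logs "no feature extractor specified."
    # and returns with empty train/test feature frames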
@@ -32,22 +32,19 @@ class Whisper(Featureset):
         model_name = f"openai/{self.feat_type}"
         self.model = WhisperModel.from_pretrained(model_name).to(self.device)
         print(f"intialized Whisper model on {self.device}")
-        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
-            model_name)
+        self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
         self.model_initialized = True
 
     def extract(self):
         """Extract the features or load them from disk if present."""
         store = self.util.get_path("store")
         storage = f"{store}{self.name}.pkl"
-        extract = self.util.config_val(
-            "FEATS", "needs_feature_extraction", False)
+        extract = self.util.config_val("FEATS", "needs_feature_extraction", False)
         no_reuse = eval(self.util.config_val("FEATS", "no_reuse", "False"))
         if extract or no_reuse or not os.path.isfile(storage):
             if not self.model_initialized:
                 self.init_model()
-            self.util.debug(
-                "extracting whisper embeddings, this might take a while...")
+            self.util.debug("extracting whisper embeddings, this might take a while...")
             emb_series = []
             for (file, start, end), _ in audeer.progress_bar(
                 self.data_df.iterrows(),
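
The Whisper extractor hunks are formatting-only, but they show the cache controls that extract() reads. A hedged sketch of those options (names and defaults taken from the config_val calls above):

    [FEATS]
    # force re-extraction even if {store}{self.name}.pkl already exists
    needs_feature_extraction = True
    # eval()'d to a bool; True likewise ignores the cached pickle
    no_reuse = True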
nkululeko/modelrunner.py CHANGED
@@ -47,16 +47,12 @@ class Modelrunner:
             highest = 0
         else:
             highest = 100000
-        # for all epochs
-        for epoch in range(epoch_num):
-            if only_test:
-                self.model.load(self.run, epoch)
-                self.util.debug(f"reusing model: {self.model.store_path}")
-                self.model.reset_test(self.df_test, self.feats_test)
-            else:
-                self.model.set_id(self.run, epoch)
-                self.model.train()
+        if self.model.model_type == "finetuned":
+            # epochs are handled by Huggingface API
+            self.model.train()
             report = self.model.predict()
+            # todo: findout the best epoch
+            epoch = epoch_num
             report.set_id(self.run, epoch)
             plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
             reports.append(report)
@@ -67,32 +63,53 @@ class Modelrunner:
             if plot_epochs:
                 self.util.debug(f"plotting conf matrix to {plot_name}")
                 report.plot_confmatrix(plot_name, epoch)
-            store_models = self.util.config_val("EXP", "save", False)
-            plot_best_model = self.util.config_val("PLOT", "best_model", False)
-            if (store_models or plot_best_model) and (
-                not only_test
-            ):  # in any case the model needs to be stored to disk.
-                self.model.store()
-            if patience:
-                patience = int(patience)
-                result = report.result.get_result()
-                if self.util.high_is_good():
-                    if result > highest:
-                        highest = result
-                        patience_counter = 0
-                    else:
-                        patience_counter += 1
+        else:
+            # for all epochs
+            for epoch in range(epoch_num):
+                if only_test:
+                    self.model.load(self.run, epoch)
+                    self.util.debug(f"reusing model: {self.model.store_path}")
+                    self.model.reset_test(self.df_test, self.feats_test)
                 else:
-                    if result < highest:
-                        highest = result
-                        patience_counter = 0
+                    self.model.set_id(self.run, epoch)
+                    self.model.train()
+                report = self.model.predict()
+                report.set_id(self.run, epoch)
+                plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
+                reports.append(report)
+                self.util.debug(
+                    f"run: {self.run} epoch: {epoch}: result: "
+                    f"{reports[-1].get_result().get_test_result()}"
+                )
+                if plot_epochs:
+                    self.util.debug(f"plotting conf matrix to {plot_name}")
+                    report.plot_confmatrix(plot_name, epoch)
+                store_models = self.util.config_val("EXP", "save", False)
+                plot_best_model = self.util.config_val("PLOT", "best_model", False)
+                if (store_models or plot_best_model) and (
+                    not only_test
+                ):  # in any case the model needs to be stored to disk.
+                    self.model.store()
+                if patience:
+                    patience = int(patience)
+                    result = report.result.get_result()
+                    if self.util.high_is_good():
+                        if result > highest:
+                            highest = result
+                            patience_counter = 0
+                        else:
+                            patience_counter += 1
                     else:
-                        patience_counter += 1
-                if patience_counter >= patience:
-                    self.util.debug(
-                        f"reached patience ({str(patience)}): early stopping"
-                    )
-                    break
+                        if result < highest:
+                            highest = result
+                            patience_counter = 0
+                        else:
+                            patience_counter += 1
+                    if patience_counter >= patience:
+                        self.util.debug(
+                            f"reached patience ({str(patience)}): early stopping"
+                        )
+                        break
 
     if not plot_epochs:
         # Do at least one confusion matrix plot
@@ -133,6 +150,12 @@ class Modelrunner:
             self.model = Bayes_model(
                 self.df_train, self.df_test, self.feats_train, self.feats_test
             )
+        elif model_type == "finetune":
+            from nkululeko.models.model_tuned import Pretrained_model
+
+            self.model = Pretrained_model(
+                self.df_train, self.df_test, self.feats_train, self.feats_test
+            )
         elif model_type == "gmm":
             from nkululeko.models.model_gmm import GMM_model
 
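With this registration, transformer finetuning is selected like any other model type. A hedged config sketch (the value finetune matches the elif above; pairing it with an empty [FEATS] section follows from the experiment.py change):

    [MODEL]
    type = finetune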
@@ -0,0 +1,190 @@
+"""
+Code based on @jwagner
+"""
+
+import dataclasses
+import typing
+
+import torch
+import transformers
+from transformers.models.wav2vec2.modeling_wav2vec2 import (
+    Wav2Vec2PreTrainedModel,
+    Wav2Vec2Model,
+)
+
+
+class ConcordanceCorCoeff(torch.nn.Module):
+
+    def __init__(self):
+
+        super().__init__()
+
+        self.mean = torch.mean
+        self.var = torch.var
+        self.sum = torch.sum
+        self.sqrt = torch.sqrt
+        self.std = torch.std
+
+    def forward(self, prediction, ground_truth):
+
+        mean_gt = self.mean(ground_truth, 0)
+        mean_pred = self.mean(prediction, 0)
+        var_gt = self.var(ground_truth, 0)
+        var_pred = self.var(prediction, 0)
+        v_pred = prediction - mean_pred
+        v_gt = ground_truth - mean_gt
+        cor = self.sum(v_pred * v_gt) / (
+            self.sqrt(self.sum(v_pred**2)) * self.sqrt(self.sum(v_gt**2))
+        )
+        sd_gt = self.std(ground_truth)
+        sd_pred = self.std(prediction)
+        numerator = 2 * cor * sd_gt * sd_pred
+        denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
+        ccc = numerator / denominator
+
+        return 1 - ccc
+
+
+@dataclasses.dataclass
+class ModelOutput(transformers.file_utils.ModelOutput):
+
+    logits_cat: torch.FloatTensor = None
+    hidden_states: typing.Tuple[torch.FloatTensor] = None
+    cnn_features: torch.FloatTensor = None
+
+
+class ModelHead(torch.nn.Module):
+
+    def __init__(self, config, num_labels):
+
+        super().__init__()
+
+        self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = torch.nn.Dropout(config.final_dropout)
+        self.out_proj = torch.nn.Linear(config.hidden_size, num_labels)
+
+    def forward(self, features, **kwargs):
+
+        x = features
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+
+        return x
+
+
+class Model(Wav2Vec2PreTrainedModel):
+
+    def __init__(self, config):
+
+        super().__init__(config)
+
+        self.wav2vec2 = Wav2Vec2Model(config)
+        self.cat = ModelHead(config, 2)
+        self.init_weights()
+
+    def freeze_feature_extractor(self):
+        self.wav2vec2.feature_extractor._freeze_parameters()
+
+    def pooling(
+        self,
+        hidden_states,
+        attention_mask,
+    ):
+
+        if attention_mask is None:  # For evaluation with batch_size==1
+            outputs = torch.mean(hidden_states, dim=1)
+        else:
+            attention_mask = self._get_feature_vector_attention_mask(
+                hidden_states.shape[1],
+                attention_mask,
+            )
+            hidden_states = hidden_states * torch.reshape(
+                attention_mask,
+                (-1, attention_mask.shape[-1], 1),
+            )
+            outputs = torch.sum(hidden_states, dim=1)
+            attention_sum = torch.sum(attention_mask, dim=1)
+            outputs = outputs / torch.reshape(attention_sum, (-1, 1))
+
+        return outputs
+
+    def forward(
+        self,
+        input_values,
+        attention_mask=None,
+        labels=None,
+        return_hidden=False,
+    ):
+
+        outputs = self.wav2vec2(
+            input_values,
+            attention_mask=attention_mask,
+        )
+
+        cnn_features = outputs.extract_features
+        hidden_states_framewise = outputs.last_hidden_state
+        hidden_states = self.pooling(
+            hidden_states_framewise,
+            attention_mask,
+        )
+        logits_cat = self.cat(hidden_states)
+
+        if not self.training:
+            logits_cat = torch.softmax(logits_cat, dim=1)
+
+        if return_hidden:
+
+            # make time last axis
+            cnn_features = torch.transpose(cnn_features, 1, 2)
+
+            return ModelOutput(
+                logits_cat=logits_cat,
+                hidden_states=hidden_states,
+                cnn_features=cnn_features,
+            )
+
+        else:
+
+            return ModelOutput(
+                logits_cat=logits_cat,
+            )
+
+    def predict(self, signal):
+        result = self(torch.from_numpy(signal))
+        result = result[0].detach().numpy()[0]
+        return result
+
+
+class ModelWithPreProcessing(Model):
+
+    def __init__(self, config):
+        super().__init__(config)
+
+    def forward(
+        self,
+        input_values,
+    ):
+        # Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm():
+        # normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
+
+        mean = input_values.mean()
+
+        # var = input_values.var()
+        # raises: onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for the node ReduceProd_3:ReduceProd(11)
+
+        var = torch.square(input_values - mean).mean()
+        input_values = (input_values - mean) / torch.sqrt(var + 1e-7)
+
+        output = super().forward(
+            input_values,
+            return_hidden=True,
+        )
+
+        return (
+            output.hidden_states,
+            output.logits_cat,
+            output.cnn_features,
+        )
nkululeko/models/model.py CHANGED
@@ -39,7 +39,7 @@ class Model:
         self.model_type = type
 
     def is_ann(self):
-        if self.model_type == "ann":
+        if (self.model_type == "ann") or (self.model_type == "finetuned"):
             return True
         else:
             return False
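
For reference, the extended check is a plain membership test; an equivalent, slightly tighter sketch:

    def is_ann(self):
        # "finetuned" models share the neural-net handling elsewhere in the code
        return self.model_type in ("ann", "finetuned")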