nkululeko 0.84.0__py3-none-any.whl → 0.85.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- nkululeko/augmenting/resampler.py +9 -4
- nkululeko/constants.py +1 -1
- nkululeko/experiment.py +6 -1
- nkululeko/feat_extract/feats_whisper.py +3 -6
- nkululeko/modelrunner.py +56 -33
- nkululeko/models/finetune_model.py +190 -0
- nkululeko/models/model.py +1 -1
- nkululeko/models/model_tuned.py +506 -0
- nkululeko/resample.py +76 -54
- nkululeko/test_pretrain.py +200 -11
- nkululeko/utils/util.py +53 -32
- {nkululeko-0.84.0.dist-info → nkululeko-0.85.0.dist-info}/METADATA +9 -1
- {nkululeko-0.84.0.dist-info → nkululeko-0.85.0.dist-info}/RECORD +16 -14
- {nkululeko-0.84.0.dist-info → nkululeko-0.85.0.dist-info}/LICENSE +0 -0
- {nkululeko-0.84.0.dist-info → nkululeko-0.85.0.dist-info}/WHEEL +0 -0
- {nkululeko-0.84.0.dist-info → nkululeko-0.85.0.dist-info}/top_level.txt +0 -0
nkululeko/augmenting/resampler.py
CHANGED
@@ -12,16 +12,19 @@ from nkululeko.utils.util import Util
 
 
 class Resampler:
-    def __init__(self, df, not_testing=True):
+    def __init__(self, df, replace, not_testing=True):
         self.SAMPLING_RATE = 16000
         self.df = df
         self.util = Util("resampler", has_config=not_testing)
         self.util.warn(f"all files might be resampled to {self.SAMPLING_RATE}")
         self.not_testing = not_testing
+        self.replace = eval(self.util.config_val(
+            "RESAMPLE", "replace", "False")) if not not_testing else replace
 
     def resample(self):
         files = self.df.index.get_level_values(0).values
-        replace = eval(self.util.config_val("RESAMPLE", "replace", "False"))
+        # replace = eval(self.util.config_val("RESAMPLE", "replace", "False"))
+        replace = self.replace
         if self.not_testing:
             store = self.util.get_path("store")
         else:
@@ -42,7 +45,8 @@ class Resampler:
                 continue
             if org_sr != self.SAMPLING_RATE:
                 self.util.debug(f"resampling {f} (sr = {org_sr})")
-                resampler = torchaudio.transforms.Resample(org_sr, self.SAMPLING_RATE)
+                resampler = torchaudio.transforms.Resample(
+                    org_sr, self.SAMPLING_RATE)
                 signal = resampler(signal)
                 if replace:
                     torchaudio.save(
@@ -59,7 +63,8 @@ class Resampler:
            self.df = self.df.set_index(
                self.df.index.set_levels(new_files, level="file")
            )
-            target_file = self.util.config_val("RESAMPLE", "target", "resampled.csv")
+            target_file = self.util.config_val(
+                "RESAMPLE", "target", "resampled.csv")
            # remove encoded labels
            target = self.util.config_val("DATA", "target", "emotion")
            if "class_label" in self.df.columns:
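
The replace flag now travels through the constructor instead of being read inside resample(). A minimal usage sketch, assuming nkululeko's usual segmented (file, start, end) index and a loaded experiment config (the path and index values are hypothetical):

import pandas as pd
from nkululeko.augmenting.resampler import Resampler

# hypothetical segmented index in nkululeko's (file, start, end) layout
df = pd.DataFrame(
    index=pd.MultiIndex.from_tuples(
        [("audio/sample.wav", 0.0, 1.0)], names=["file", "start", "end"]
    )
)
# with the default not_testing=True the flag passed here is used; files
# whose sample rate differs from 16 kHz are then overwritten in place
resampler = Resampler(df, replace=True)
resampler.resample()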
nkululeko/constants.py
CHANGED
@@ -1,2 +1,2 @@
-VERSION="0.84.0"
+VERSION="0.85.0"
 SAMPLING_RATE = 16000
nkululeko/experiment.py
CHANGED
@@ -340,7 +340,12 @@ class Experiment:
         df_train, df_test = self.df_train, self.df_test
         feats_name = "_".join(ast.literal_eval(glob_conf.config["DATA"]["databases"]))
         self.feats_test, self.feats_train = pd.DataFrame(), pd.DataFrame()
-        feats_types = self.util.config_val_list("FEATS", "type", [
+        feats_types = self.util.config_val_list("FEATS", "type", [])
+        # for some models no features are needed
+        if len(feats_types) == 0:
+            self.util.debug("no feature extractor specified.")
+            self.feats_train, self.feats_test = pd.DataFrame(), pd.DataFrame()
+            return
         self.feature_extractor = FeatureExtractor(
             df_train, feats_types, feats_name, "train"
         )
nkululeko/feat_extract/feats_whisper.py
CHANGED
@@ -32,22 +32,19 @@ class Whisper(Featureset):
         model_name = f"openai/{self.feat_type}"
         self.model = WhisperModel.from_pretrained(model_name).to(self.device)
         print(f"intialized Whisper model on {self.device}")
-        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
-            model_name)
+        self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
         self.model_initialized = True
 
     def extract(self):
         """Extract the features or load them from disk if present."""
         store = self.util.get_path("store")
         storage = f"{store}{self.name}.pkl"
-        extract = self.util.config_val(
-            "FEATS", "needs_feature_extraction", False)
+        extract = self.util.config_val("FEATS", "needs_feature_extraction", False)
         no_reuse = eval(self.util.config_val("FEATS", "no_reuse", "False"))
         if extract or no_reuse or not os.path.isfile(storage):
             if not self.model_initialized:
                 self.init_model()
-            self.util.debug(
-                "extracting whisper embeddings, this might take a while...")
+            self.util.debug("extracting whisper embeddings, this might take a while...")
             emb_series = []
             for (file, start, end), _ in audeer.progress_bar(
                 self.data_df.iterrows(),
nkululeko/modelrunner.py
CHANGED
@@ -47,16 +47,12 @@ class Modelrunner:
             highest = 0
         else:
             highest = 100000
-        # for all epochs
-        for epoch in range(epoch_num):
-            if only_test:
-                self.model.load(self.run, epoch)
-                self.util.debug(f"reusing model: {self.model.store_path}")
-                self.model.reset_test(self.df_test, self.feats_test)
-            else:
-                self.model.set_id(self.run, epoch)
-                self.model.train()
+        if self.model.model_type == "finetuned":
+            # epochs are handled by Huggingface API
+            self.model.train()
             report = self.model.predict()
+            # todo: findout the best epoch
+            epoch = epoch_num
             report.set_id(self.run, epoch)
             plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
             reports.append(report)
@@ -67,32 +63,53 @@ class Modelrunner:
             if plot_epochs:
                 self.util.debug(f"plotting conf matrix to {plot_name}")
                 report.plot_confmatrix(plot_name, epoch)
-            store_models = self.util.config_val("EXP", "save", False)
-            plot_best_model = self.util.config_val("PLOT", "best_model", False)
-            if (store_models or plot_best_model) and (
-                not only_test
-            ):  # in any case the model needs to be stored to disk.
-                self.model.store()
-            if patience:
-                patience = int(patience)
-                result = report.result.get_result()
-                if self.util.high_is_good():
-                    if result > highest:
-                        highest = result
-                        patience_counter = 0
-                    else:
-                        patience_counter += 1
+        else:
+            # for all epochs
+            for epoch in range(epoch_num):
+                if only_test:
+                    self.model.load(self.run, epoch)
+                    self.util.debug(f"reusing model: {self.model.store_path}")
+                    self.model.reset_test(self.df_test, self.feats_test)
                 else:
-                    if result < highest:
-                        highest = result
-                        patience_counter = 0
+                    self.model.set_id(self.run, epoch)
+                    self.model.train()
+                report = self.model.predict()
+                report.set_id(self.run, epoch)
+                plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
+                reports.append(report)
+                self.util.debug(
+                    f"run: {self.run} epoch: {epoch}: result: "
+                    f"{reports[-1].get_result().get_test_result()}"
+                )
+                if plot_epochs:
+                    self.util.debug(f"plotting conf matrix to {plot_name}")
+                    report.plot_confmatrix(plot_name, epoch)
+                store_models = self.util.config_val("EXP", "save", False)
+                plot_best_model = self.util.config_val("PLOT", "best_model", False)
+                if (store_models or plot_best_model) and (
+                    not only_test
+                ):  # in any case the model needs to be stored to disk.
+                    self.model.store()
+                if patience:
+                    patience = int(patience)
+                    result = report.result.get_result()
+                    if self.util.high_is_good():
+                        if result > highest:
+                            highest = result
+                            patience_counter = 0
+                        else:
+                            patience_counter += 1
                     else:
-                        patience_counter += 1
-                if patience_counter >= patience:
-                    self.util.debug(
-                        f"reached patience ({str(patience)}): early stopping"
-                    )
-                    break
+                        if result < highest:
+                            highest = result
+                            patience_counter = 0
+                        else:
+                            patience_counter += 1
+                    if patience_counter >= patience:
+                        self.util.debug(
+                            f"reached patience ({str(patience)}): early stopping"
+                        )
+                        break
 
     if not plot_epochs:
         # Do at least one confusion matrix plot
@@ -133,6 +150,12 @@ class Modelrunner:
             self.model = Bayes_model(
                 self.df_train, self.df_test, self.feats_train, self.feats_test
             )
+        elif model_type == "finetune":
+            from nkululeko.models.model_tuned import Pretrained_model
+
+            self.model = Pretrained_model(
+                self.df_train, self.df_test, self.feats_train, self.feats_test
+            )
         elif model_type == "gmm":
             from nkululeko.models.model_gmm import GMM_model
 
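
Together with the extract_feats() guard in experiment.py above, this registration lets a fine-tuning experiment omit the [FEATS] section entirely. A hypothetical minimal configuration, held in a Python string for a quick check; only [MODEL] type = finetune and the [DATA] keys appear in this diff, the remaining keys follow nkululeko's usual INI layout and are assumptions:

import configparser

# hypothetical minimal config for the new finetune model type;
# note the absent [FEATS] section, now tolerated by extract_feats()
INI = """
[EXP]
name = finetune_sketch
[DATA]
databases = ['emodb']
target = emotion
[MODEL]
type = finetune
"""

config = configparser.ConfigParser()
config.read_string(INI)
assert config["MODEL"]["type"] == "finetune"  # dispatches to Pretrained_model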
nkululeko/models/finetune_model.py
ADDED
@@ -0,0 +1,190 @@
+"""
+Code based on @jwagner
+"""
+
+import dataclasses
+import typing
+
+import torch
+import transformers
+from transformers.models.wav2vec2.modeling_wav2vec2 import (
+    Wav2Vec2PreTrainedModel,
+    Wav2Vec2Model,
+)
+
+
+class ConcordanceCorCoeff(torch.nn.Module):
+
+    def __init__(self):
+
+        super().__init__()
+
+        self.mean = torch.mean
+        self.var = torch.var
+        self.sum = torch.sum
+        self.sqrt = torch.sqrt
+        self.std = torch.std
+
+    def forward(self, prediction, ground_truth):
+
+        mean_gt = self.mean(ground_truth, 0)
+        mean_pred = self.mean(prediction, 0)
+        var_gt = self.var(ground_truth, 0)
+        var_pred = self.var(prediction, 0)
+        v_pred = prediction - mean_pred
+        v_gt = ground_truth - mean_gt
+        cor = self.sum(v_pred * v_gt) / (
+            self.sqrt(self.sum(v_pred**2)) * self.sqrt(self.sum(v_gt**2))
+        )
+        sd_gt = self.std(ground_truth)
+        sd_pred = self.std(prediction)
+        numerator = 2 * cor * sd_gt * sd_pred
+        denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
+        ccc = numerator / denominator
+
+        return 1 - ccc
+
+
+@dataclasses.dataclass
+class ModelOutput(transformers.file_utils.ModelOutput):
+
+    logits_cat: torch.FloatTensor = None
+    hidden_states: typing.Tuple[torch.FloatTensor] = None
+    cnn_features: torch.FloatTensor = None
+
+
+class ModelHead(torch.nn.Module):
+
+    def __init__(self, config, num_labels):
+
+        super().__init__()
+
+        self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = torch.nn.Dropout(config.final_dropout)
+        self.out_proj = torch.nn.Linear(config.hidden_size, num_labels)
+
+    def forward(self, features, **kwargs):
+
+        x = features
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+
+        return x
+
+
+class Model(Wav2Vec2PreTrainedModel):
+
+    def __init__(self, config):
+
+        super().__init__(config)
+
+        self.wav2vec2 = Wav2Vec2Model(config)
+        self.cat = ModelHead(config, 2)
+        self.init_weights()
+
+    def freeze_feature_extractor(self):
+        self.wav2vec2.feature_extractor._freeze_parameters()
+
+    def pooling(
+        self,
+        hidden_states,
+        attention_mask,
+    ):
+
+        if attention_mask is None:  # For evaluation with batch_size==1
+            outputs = torch.mean(hidden_states, dim=1)
+        else:
+            attention_mask = self._get_feature_vector_attention_mask(
+                hidden_states.shape[1],
+                attention_mask,
+            )
+            hidden_states = hidden_states * torch.reshape(
+                attention_mask,
+                (-1, attention_mask.shape[-1], 1),
+            )
+            outputs = torch.sum(hidden_states, dim=1)
+            attention_sum = torch.sum(attention_mask, dim=1)
+            outputs = outputs / torch.reshape(attention_sum, (-1, 1))
+
+        return outputs
+
+    def forward(
+        self,
+        input_values,
+        attention_mask=None,
+        labels=None,
+        return_hidden=False,
+    ):
+
+        outputs = self.wav2vec2(
+            input_values,
+            attention_mask=attention_mask,
+        )
+
+        cnn_features = outputs.extract_features
+        hidden_states_framewise = outputs.last_hidden_state
+        hidden_states = self.pooling(
+            hidden_states_framewise,
+            attention_mask,
+        )
+        logits_cat = self.cat(hidden_states)
+
+        if not self.training:
+            logits_cat = torch.softmax(logits_cat, dim=1)
+
+        if return_hidden:
+
+            # make time last axis
+            cnn_features = torch.transpose(cnn_features, 1, 2)
+
+            return ModelOutput(
+                logits_cat=logits_cat,
+                hidden_states=hidden_states,
+                cnn_features=cnn_features,
+            )
+
+        else:
+
+            return ModelOutput(
+                logits_cat=logits_cat,
+            )
+
+    def predict(self, signal):
+        result = self(torch.from_numpy(signal))
+        result = result[0].detach().numpy()[0]
+        return result
+
+
+class ModelWithPreProcessing(Model):
+
+    def __init__(self, config):
+        super().__init__(config)
+
+    def forward(
+        self,
+        input_values,
+    ):
+        # Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm():
+        # normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
+
+        mean = input_values.mean()
+
+        # var = input_values.var()
+        # raises: onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for the node ReduceProd_3:ReduceProd(11)
+
+        var = torch.square(input_values - mean).mean()
+        input_values = (input_values - mean) / torch.sqrt(var + 1e-7)
+
+        output = super().forward(
+            input_values,
+            return_hidden=True,
+        )
+
+        return (
+            output.hidden_states,
+            output.logits_cat,
+            output.cnn_features,
+        )