nkululeko-0.84.0-py3-none-any.whl → nkululeko-0.84.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nkululeko/augmenting/resampler.py +9 -4
- nkululeko/constants.py +1 -1
- nkululeko/models/finetune_model.py +181 -0
- nkululeko/resample.py +76 -54
- nkululeko/test_pretrain.py +188 -11
- nkululeko/utils/util.py +53 -32
- {nkululeko-0.84.0.dist-info → nkululeko-0.84.1.dist-info}/METADATA +5 -1
- {nkululeko-0.84.0.dist-info → nkululeko-0.84.1.dist-info}/RECORD +11 -10
- {nkululeko-0.84.0.dist-info → nkululeko-0.84.1.dist-info}/LICENSE +0 -0
- {nkululeko-0.84.0.dist-info → nkululeko-0.84.1.dist-info}/WHEEL +0 -0
- {nkululeko-0.84.0.dist-info → nkululeko-0.84.1.dist-info}/top_level.txt +0 -0
nkululeko/augmenting/resampler.py
CHANGED

@@ -12,16 +12,19 @@ from nkululeko.utils.util import Util
 
 
 class Resampler:
-    def __init__(self, df, not_testing=True):
+    def __init__(self, df, replace, not_testing=True):
         self.SAMPLING_RATE = 16000
         self.df = df
         self.util = Util("resampler", has_config=not_testing)
         self.util.warn(f"all files might be resampled to {self.SAMPLING_RATE}")
         self.not_testing = not_testing
+        self.replace = eval(self.util.config_val(
+            "RESAMPLE", "replace", "False")) if not not_testing else replace
 
     def resample(self):
         files = self.df.index.get_level_values(0).values
-        replace = eval(self.util.config_val("RESAMPLE", "replace", "False"))
+        # replace = eval(self.util.config_val("RESAMPLE", "replace", "False"))
+        replace = self.replace
         if self.not_testing:
             store = self.util.get_path("store")
         else:

@@ -42,7 +45,8 @@ class Resampler:
                 continue
             if org_sr != self.SAMPLING_RATE:
                 self.util.debug(f"resampling {f} (sr = {org_sr})")
-                resampler = torchaudio.transforms.Resample(org_sr, self.SAMPLING_RATE)
+                resampler = torchaudio.transforms.Resample(
+                    org_sr, self.SAMPLING_RATE)
                 signal = resampler(signal)
             if replace:
                 torchaudio.save(

@@ -59,7 +63,8 @@ class Resampler:
             self.df = self.df.set_index(
                 self.df.index.set_levels(new_files, level="file")
             )
-            target_file = self.util.config_val("RESAMPLE", "target", "resampled.csv")
+            target_file = self.util.config_val(
+                "RESAMPLE", "target", "resampled.csv")
             # remove encoded labels
             target = self.util.config_val("DATA", "target", "emotion")
             if "class_label" in self.df.columns:
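Usage note: the constructor now takes the replace decision as an explicit argument instead of always reading it from the config. A minimal sketch of the new call, mirroring how the updated resample.py (below) builds a single-file DataFrame; "my_audio.wav" is a hypothetical file that must exist on disk:

    import pandas as pd
    import audformat
    from nkululeko.augmenting.resampler import Resampler

    # build a segmented audformat index for one (hypothetical) file
    files = pd.Series(["my_audio.wav"])
    df = pd.DataFrame(index=files)
    df.index = audformat.utils.to_segmented_index(df.index, allow_nat=False)

    # replace=True overwrites the original file with the 16 kHz version
    rs = Resampler(df, replace=True, not_testing=True)
    rs.resample()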
nkululeko/constants.py
CHANGED
@@ -1,2 +1,2 @@
-VERSION="0.84.0"
+VERSION="0.84.1"
 SAMPLING_RATE = 16000
nkululeko/models/finetune_model.py
ADDED

@@ -0,0 +1,181 @@
+import dataclasses
+import typing
+
+import torch
+import transformers
+from transformers.models.wav2vec2.modeling_wav2vec2 import (
+    Wav2Vec2PreTrainedModel,
+    Wav2Vec2Model,
+)
+
+
+class ConcordanceCorCoeff(torch.nn.Module):
+
+    def __init__(self):
+
+        super().__init__()
+
+        self.mean = torch.mean
+        self.var = torch.var
+        self.sum = torch.sum
+        self.sqrt = torch.sqrt
+        self.std = torch.std
+
+    def forward(self, prediction, ground_truth):
+
+        mean_gt = self.mean(ground_truth, 0)
+        mean_pred = self.mean(prediction, 0)
+        var_gt = self.var(ground_truth, 0)
+        var_pred = self.var(prediction, 0)
+        v_pred = prediction - mean_pred
+        v_gt = ground_truth - mean_gt
+        cor = self.sum(v_pred * v_gt) / (
+            self.sqrt(self.sum(v_pred**2)) * self.sqrt(self.sum(v_gt**2))
+        )
+        sd_gt = self.std(ground_truth)
+        sd_pred = self.std(prediction)
+        numerator = 2 * cor * sd_gt * sd_pred
+        denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
+        ccc = numerator / denominator
+
+        return 1 - ccc
+
+
+@dataclasses.dataclass
+class ModelOutput(transformers.file_utils.ModelOutput):
+
+    logits_cat: torch.FloatTensor = None
+    hidden_states: typing.Tuple[torch.FloatTensor] = None
+    cnn_features: torch.FloatTensor = None
+
+
+class ModelHead(torch.nn.Module):
+
+    def __init__(self, config, num_labels):
+
+        super().__init__()
+
+        self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = torch.nn.Dropout(config.final_dropout)
+        self.out_proj = torch.nn.Linear(config.hidden_size, num_labels)
+
+    def forward(self, features, **kwargs):
+
+        x = features
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+
+        return x
+
+
+class Model(Wav2Vec2PreTrainedModel):
+
+    def __init__(self, config):
+
+        super().__init__(config)
+
+        self.wav2vec2 = Wav2Vec2Model(config)
+        self.cat = ModelHead(config, 2)
+        self.init_weights()
+
+    def freeze_feature_extractor(self):
+        self.wav2vec2.feature_extractor._freeze_parameters()
+
+    def pooling(
+        self,
+        hidden_states,
+        attention_mask,
+    ):
+
+        if attention_mask is None:  # For evaluation with batch_size==1
+            outputs = torch.mean(hidden_states, dim=1)
+        else:
+            attention_mask = self._get_feature_vector_attention_mask(
+                hidden_states.shape[1],
+                attention_mask,
+            )
+            hidden_states = hidden_states * torch.reshape(
+                attention_mask,
+                (-1, attention_mask.shape[-1], 1),
+            )
+            outputs = torch.sum(hidden_states, dim=1)
+            attention_sum = torch.sum(attention_mask, dim=1)
+            outputs = outputs / torch.reshape(attention_sum, (-1, 1))
+
+        return outputs
+
+    def forward(
+        self,
+        input_values,
+        attention_mask=None,
+        labels=None,
+        return_hidden=False,
+    ):
+
+        outputs = self.wav2vec2(
+            input_values,
+            attention_mask=attention_mask,
+        )
+
+        cnn_features = outputs.extract_features
+        hidden_states_framewise = outputs.last_hidden_state
+        hidden_states = self.pooling(
+            hidden_states_framewise,
+            attention_mask,
+        )
+        logits_cat = self.cat(hidden_states)
+
+        if not self.training:
+            logits_cat = torch.softmax(logits_cat, dim=1)
+
+        if return_hidden:
+
+            # make time last axis
+            cnn_features = torch.transpose(cnn_features, 1, 2)
+
+            return ModelOutput(
+                logits_cat=logits_cat,
+                hidden_states=hidden_states,
+                cnn_features=cnn_features,
+            )
+
+        else:
+
+            return ModelOutput(
+                logits_cat=logits_cat,
+            )
+
+
+class ModelWithPreProcessing(Model):
+
+    def __init__(self, config):
+        super().__init__(config)
+
+    def forward(
+        self,
+        input_values,
+    ):
+        # Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm():
+        # normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
+
+        mean = input_values.mean()
+
+        # var = input_values.var()
+        # raises: onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for the node ReduceProd_3:ReduceProd(11)
+
+        var = torch.square(input_values - mean).mean()
+        input_values = (input_values - mean) / torch.sqrt(var + 1e-7)
+
+        output = super().forward(
+            input_values,
+            return_hidden=True,
+        )
+
+        return (
+            output.hidden_states,
+            output.logits_cat,
+            output.cnn_features,
+        )
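The new Model class loads like any transformers checkpoint. A minimal sketch, assuming the wav2vec2 checkpoint used in test_pretrain.py is available; with attention_mask=None the forward pass mean-pools over time, and in eval mode logits_cat holds softmax scores from the two-unit head:

    import torch
    import transformers
    import nkululeko.models.finetune_model as fm

    model_path = "facebook/wav2vec2-large-robust-ft-swbd-300h"
    config = transformers.AutoConfig.from_pretrained(model_path)
    model = fm.Model.from_pretrained(model_path, config=config)
    model.eval()

    signal = torch.zeros(1, 16000)  # one second of silence at 16 kHz
    with torch.no_grad():
        out = model(signal)         # attention_mask=None -> mean pooling
    print(out.logits_cat.shape)     # torch.Size([1, 2])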
nkululeko/resample.py
CHANGED
@@ -1,78 +1,100 @@
 # resample.py
-# change the sampling rate for train
+# change the sampling rate for audio file or INI file (train, test, all)
 
 import argparse
 import configparser
 import os
-
 import pandas as pd
-
+import audformat
 from nkululeko.augmenting.resampler import Resampler
+from nkululeko.utils.util import Util
+
 from nkululeko.constants import VERSION
 from nkululeko.experiment import Experiment
-from nkululeko.utils.util import Util
 
 
 def main(src_dir):
     parser = argparse.ArgumentParser(
-        description="Call the nkululeko RESAMPLE
-    parser.add_argument("--config", default=
+        description="Call the nkululeko RESAMPLE framework.")
+    parser.add_argument("--config", default=None,
                         help="The base configuration")
+    parser.add_argument("--file", default=None,
+                        help="The input audio file to resample")
+    parser.add_argument("--replace", action="store_true",
+                        help="Replace the original audio file")
+
     args = parser.parse_args()
-    if args.config is not None:
-        config_file = args.config
-    else:
-        config_file = f"{src_dir}/exp.ini"
 
-    # test if the configuration file exists
-    if not os.path.isfile(config_file):
-        print(f"ERROR: no such file: {config_file}")
+    if args.file is None and args.config is None:
+        print("ERROR: Either --file or --config argument must be provided.")
         exit()
 
-    # load one configuration per experiment
-    config = configparser.ConfigParser()
-    config.read(config_file)
-    # create a new experiment
-    expr = Experiment(config)
-    module = "resample"
-    expr.set_module(module)
-    util = Util(module)
-    util.debug(
-        f"running {expr.name} from config {config_file}, nkululeko version"
-        f" {VERSION}"
-    )
-
-    if util.config_val("EXP", "no_warnings", False):
-        import warnings
-
-        warnings.filterwarnings("ignore")
-
-    # load the data
-    expr.load_datasets()
-
-    # split into train and test
-    expr.fill_train_and_tests()
-    util.debug(
-        f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}")
-
-    sample_selection = util.config_val("RESAMPLE", "sample_selection", "all")
-    if sample_selection == "all":
-        df = pd.concat([expr.df_train, expr.df_test])
-    elif sample_selection == "train":
-        df = expr.df_train
-    elif sample_selection == "test":
-        df = expr.df_test
+    if args.file is not None:
+        # Load the audio file into a DataFrame
+        files = pd.Series([args.file])
+        df_sample = pd.DataFrame(index=files)
+        df_sample.index = audformat.utils.to_segmented_index(
+            df_sample.index, allow_nat=False
+        )
+
+        # Resample the audio file
+        util = Util("resampler", has_config=False)
+        util.debug(f"Resampling audio file: {args.file}")
+        rs = Resampler(df_sample, not_testing=True, replace=args.replace)
+        rs.resample()
     else:
-        util.error(
-            f"unknown selection specifier {sample_selection}, should be [all |"
-            " train | test]"
+        # Existing code for handling INI file
+        config_file = args.config
+
+        # Test if the configuration file exists
+        if not os.path.isfile(config_file):
+            print(f"ERROR: no such file: {config_file}")
+            exit()
+
+        # Load one configuration per experiment
+        config = configparser.ConfigParser()
+        config.read(config_file)
+        # Create a new experiment
+        expr = Experiment(config)
+        module = "resample"
+        expr.set_module(module)
+        util = Util(module)
+        util.debug(
+            f"running {expr.name} from config {config_file}, nkululeko version"
+            f" {VERSION}"
         )
-
-    util.debug(f"resampling {sample_selection}: {df.shape[0]} samples")
-    rs = Resampler(df)
-    rs.resample()
+
+        if util.config_val("EXP", "no_warnings", False):
+            import warnings
+            warnings.filterwarnings("ignore")
+
+        # Load the data
+        expr.load_datasets()
+
+        # Split into train and test
+        expr.fill_train_and_tests()
+        util.debug(
+            f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}")
+
+        sample_selection = util.config_val(
+            "RESAMPLE", "sample_selection", "all")
+        if sample_selection == "all":
+            df = pd.concat([expr.df_train, expr.df_test])
+        elif sample_selection == "train":
+            df = expr.df_train
+        elif sample_selection == "test":
+            df = expr.df_test
+        else:
+            util.error(
+                f"unknown selection specifier {sample_selection}, should be [all |"
+                " train | test]"
+            )
+        util.debug(f"resampling {sample_selection}: {df.shape[0]} samples")
+        replace = util.config_val("RESAMPLE", "replace", "False")
+        rs = Resampler(df, replace=replace)
+        rs.resample()
 
 
 if __name__ == "__main__":
    cwd = os.path.dirname(os.path.abspath(__file__))
-    main(cwd)
+    main(cwd)
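The module can now be invoked in two ways; a sketch of both calls, where exp.ini and my_audio.wav are hypothetical paths:

    # resample all data of an experiment, as before
    python -m nkululeko.resample --config exp.ini

    # resample a single audio file, no configuration needed
    python -m nkululeko.resample --file my_audio.wav --replace

Without --replace, the resampler writes new files and stores the updated index (see target_file in the resampler diff above) rather than overwriting the originals.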
nkululeko/test_pretrain.py
CHANGED
@@ -11,11 +11,14 @@ import transformers
 
 import audeer
 import audiofile
+import audmetric
 
 from nkululeko.constants import VERSION
 import nkululeko.experiment as exp
+import nkululeko.models.finetune_model as fm
 import nkululeko.glob_conf as glob_conf
 from nkululeko.utils.util import Util
+import json
 
 
 def doit(config_file):

@@ -50,28 +53,42 @@ def doit(config_file):
     expr.fill_train_and_tests()
     util.debug(f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}")
 
+    log_root = audeer.mkdir("log")
+    model_root = audeer.mkdir("model")
+    torch_root = audeer.path(model_root, "torch")
+
+    metrics_gender = {
+        "UAR": audmetric.unweighted_average_recall,
+        "ACC": audmetric.accuracy,
+    }
+
     sampling_rate = 16000
     max_duration_sec = 8.0
 
     model_path = "facebook/wav2vec2-large-robust-ft-swbd-300h"
     num_layers = None
 
+    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+    os.environ["CUDA_VISIBLE_DEVICES"] = "3"
+
     batch_size = 16
     accumulation_steps = 4
-
     # create dataset
 
     dataset = {}
+    target_name = glob_conf.target
     data_sources = {
-        "train": pd.DataFrame(expr.df_train[
-        "dev": pd.DataFrame(expr.df_test[
+        "train": pd.DataFrame(expr.df_train[target_name]),
+        "dev": pd.DataFrame(expr.df_test[target_name]),
     }
 
     for split in ["train", "dev"]:
+        df = data_sources[split]
+        df[target_name] = df[target_name].astype("float")
 
         y = pd.Series(
-            data=
-            index=
+            data=df.itertuples(index=False, name=None),
+            index=df.index,
             dtype=object,
             name="labels",
         )

@@ -80,23 +97,183 @@ def doit(config_file):
         df = y.reset_index()
         df.start = df.start.dt.total_seconds()
         df.end = df.end.dt.total_seconds()
+
         print(f"{split}: {len(df)}")
+
         ds = datasets.Dataset.from_pandas(df)
         dataset[split] = ds
 
-
+    dataset = datasets.DatasetDict(dataset)
+
+    # load pre-trained model
+    le = glob_conf.label_encoder
+    mapping = dict(zip(le.classes_, range(len(le.classes_))))
+    target_mapping = {k: int(v) for k, v in mapping.items()}
+    target_mapping_reverse = {value: key for key, value in target_mapping.items()}
 
     config = transformers.AutoConfig.from_pretrained(
         model_path,
-        num_labels=len(
-        label2id=
-        id2label=
-        finetuning_task=
+        num_labels=len(target_mapping),
+        label2id=target_mapping,
+        id2label=target_mapping_reverse,
+        finetuning_task=target_name,
     )
     if num_layers is not None:
         config.num_hidden_layers = num_layers
     setattr(config, "sampling_rate", sampling_rate)
-    setattr(config, "data",
+    setattr(config, "data", util.get_data_name())
+
+    vocab_dict = {}
+    with open("vocab.json", "w") as vocab_file:
+        json.dump(vocab_dict, vocab_file)
+    tokenizer = transformers.Wav2Vec2CTCTokenizer("./vocab.json")
+    tokenizer.save_pretrained(".")
+
+    feature_extractor = transformers.Wav2Vec2FeatureExtractor(
+        feature_size=1,
+        sampling_rate=16000,
+        padding_value=0.0,
+        do_normalize=True,
+        return_attention_mask=True,
+    )
+    processor = transformers.Wav2Vec2Processor(
+        feature_extractor=feature_extractor,
+        tokenizer=tokenizer,
+    )
+    assert processor.feature_extractor.sampling_rate == sampling_rate
+
+    model = fm.Model.from_pretrained(
+        model_path,
+        config=config,
+    )
+    model.freeze_feature_extractor()
+    model.train()
+
+    # training
+
+    def data_collator(data):
+
+        files = [d["file"] for d in data]
+        starts = [d["start"] for d in data]
+        ends = [d["end"] for d in data]
+        targets = [d["targets"] for d in data]
+
+        signals = []
+        for file, start, end in zip(
+            files,
+            starts,
+            ends,
+        ):
+            offset = start
+            duration = end - offset
+            if max_duration_sec is not None:
+                duration = min(duration, max_duration_sec)
+            signal, _ = audiofile.read(
+                file,
+                offset=offset,
+                duration=duration,
+            )
+            signals.append(signal.squeeze())
+
+        input_values = processor(
+            signals,
+            sampling_rate=sampling_rate,
+            padding=True,
+        )
+        batch = processor.pad(
+            input_values,
+            padding=True,
+            return_tensors="pt",
+        )
+
+        batch["labels"] = torch.tensor(targets)
+
+        return batch
+
+    def compute_metrics(p: transformers.EvalPrediction):
+
+        truth_gender = p.label_ids[:, 0].astype(int)
+        preds = p.predictions
+        preds_gender = np.argmax(preds, axis=1)
+
+        scores = {}
+
+        for name, metric in metrics_gender.items():
+            scores[f"gender-{name}"] = metric(truth_gender, preds_gender)
+
+        scores["combined"] = scores["gender-UAR"]
+
+        return scores
+
+    targets = pd.DataFrame(dataset["train"]["targets"])
+    counts = targets[0].value_counts().sort_index()
+    train_weights = 1 / counts
+    train_weights /= train_weights.sum()
+
+    print(train_weights)
+
+    criterion_gender = torch.nn.CrossEntropyLoss(
+        weight=torch.Tensor(train_weights).to("cuda"),
+    )
+
+    class Trainer(transformers.Trainer):
+
+        def compute_loss(
+            self,
+            model,
+            inputs,
+            return_outputs=False,
+        ):
+
+            targets = inputs.pop("labels").squeeze()
+            targets_gender = targets.type(torch.long)
+
+            outputs = model(**inputs)
+            logits_gender = outputs[0].squeeze()
+
+            loss_gender = criterion_gender(logits_gender, targets_gender)
+
+            loss = loss_gender
+
+            return (loss, outputs) if return_outputs else loss
+
+    num_steps = len(dataset["train"]) // (batch_size * accumulation_steps) // 5
+    num_steps = max(1, num_steps)
+    print(num_steps)
+
+    training_args = transformers.TrainingArguments(
+        output_dir=model_root,
+        logging_dir=log_root,
+        per_device_train_batch_size=batch_size,
+        per_device_eval_batch_size=batch_size,
+        gradient_accumulation_steps=accumulation_steps,
+        evaluation_strategy="steps",
+        num_train_epochs=5.0,
+        fp16=True,
+        save_steps=num_steps,
+        eval_steps=num_steps,
+        logging_steps=num_steps,
+        learning_rate=1e-4,
+        save_total_limit=2,
+        metric_for_best_model="combined",
+        greater_is_better=True,
+        load_best_model_at_end=True,
+        remove_unused_columns=False,
+    )
+
+    trainer = Trainer(
+        model=model,
+        data_collator=data_collator,
+        args=training_args,
+        compute_metrics=compute_metrics,
+        train_dataset=dataset["train"],
+        eval_dataset=dataset["dev"],
+        tokenizer=processor.feature_extractor,
+        callbacks=[transformers.integrations.TensorBoardCallback()],
+    )
+
+    trainer.train()
+    trainer.save_model(torch_root)
 
     print("DONE")
 
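The inverse-frequency class weighting used above can be checked with made-up numbers (not from the diff): with 30 samples of class 0 and 10 of class 1, the rarer class gets three times the weight:

    import pandas as pd

    counts = pd.Series([30, 10])          # targets[0].value_counts().sort_index()
    train_weights = 1 / counts            # [0.0333..., 0.1]
    train_weights /= train_weights.sum()  # [0.25, 0.75], sums to 1
    print(train_weights.tolist())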
nkululeko/utils/util.py
CHANGED
@@ -33,43 +33,58 @@ class Util:
         else:
             self.caller = ""
         if has_config:
-            import nkululeko.glob_conf as glob_conf
-
-            self.config = glob_conf.config
-            self.got_data_roots = self.config_val("DATA", "root_folders", False)
-            if self.got_data_roots:
-                # if there is a global data rootfolder file, read from there
-                if not os.path.isfile(self.got_data_roots):
-                    self.error(f"no such file: {self.got_data_roots}")
-                self.data_roots = configparser.ConfigParser()
-                self.data_roots.read(self.got_data_roots)
+            try:
+                import nkululeko.glob_conf as glob_conf
+                self.config = glob_conf.config
+                self.got_data_roots = self.config_val(
+                    "DATA", "root_folders", False)
+                if self.got_data_roots:
+                    # if there is a global data rootfolder file, read from there
+                    if not os.path.isfile(self.got_data_roots):
+                        self.error(f"no such file: {self.got_data_roots}")
+                    self.data_roots = configparser.ConfigParser()
+                    self.data_roots.read(self.got_data_roots)
+            except (ModuleNotFoundError, AttributeError):
+                self.config = None
+                self.got_data_roots = False
 
     def get_path(self, entry):
         """
         This method allows the user to get the directory path for the given argument.
         """
-        root = os.path.join(self.config["EXP"]["root"], "")
-        name = self.config["EXP"]["name"]
-        try:
-            entryn = self.config["EXP"][entry]
-        except KeyError:
-            # some default values
+        if self.config is None:
+            # If no configuration file is provided, use default paths
             if entry == "fig_dir":
-                entryn = "./images/"
+                dir_name = "./images/"
             elif entry == "res_dir":
-                entryn = "./results/"
+                dir_name = "./results/"
             elif entry == "model_dir":
-                entryn = "./models/"
+                dir_name = "./models/"
             else:
-                entryn = "./store/"
-
-        # Expand image, model and result directories with run index
-        if entry == "fig_dir" or entry == "res_dir" or entry == "model_dir":
-            run = self.config_val("EXP", "run", 0)
-            entryn = entryn + f"run_{run}/"
+                dir_name = "./store/"
+        else:
+            root = os.path.join(self.config["EXP"]["root"], "")
+            name = self.config["EXP"]["name"]
+            try:
+                entryn = self.config["EXP"][entry]
+            except KeyError:
+                # some default values
+                if entry == "fig_dir":
+                    entryn = "./images/"
+                elif entry == "res_dir":
+                    entryn = "./results/"
+                elif entry == "model_dir":
+                    entryn = "./models/"
+                else:
+                    entryn = "./store/"
 
-        dir_name = f"{root}{name}/{entryn}"
+            # Expand image, model and result directories with run index
+            if entry == "fig_dir" or entry == "res_dir" or entry == "model_dir":
+                run = self.config_val("EXP", "run", 0)
+                entryn = entryn + f"run_{run}/"
+
+            dir_name = f"{root}{name}/{entryn}"
 
         audeer.mkdir(dir_name)
         return dir_name

@@ -101,7 +116,8 @@ class Util:
             )
             return default
         if not default in self.stopvals:
-            self.debug(f"value for {key} not found, using default: {default}")
+            self.debug(
+                f"value for {key} not found, using default: {default}")
         return default
 
     def set_config(self, config):

@@ -138,7 +154,8 @@ class Util:
         if len(df) == 0:
             return df
         if not isinstance(df.index, pd.MultiIndex):
-            df.index = audformat.utils.to_segmented_index(df.index, allow_nat=False)
+            df.index = audformat.utils.to_segmented_index(
+                df.index, allow_nat=False)
         return df
 
     def _get_value_descript(self, section, name):

@@ -243,11 +260,14 @@ class Util:
         print(df.head(1))
 
     def config_val(self, section, key, default):
+        if self.config is None:
+            return default
         try:
             return self.config[section][key]
         except KeyError:
-            if not default in self.stopvals:
-                self.debug(
+            if default not in self.stopvals:
+                self.debug(
+                    f"value for {key} not found, using default: {default}")
             return default
 
     def config_val_list(self, section, key, default):

@@ -255,7 +275,8 @@ class Util:
             return ast.literal_eval(self.config[section][key])
         except KeyError:
             if not default in self.stopvals:
-                self.debug(f"value for {key} not found, using default: {default}")
+                self.debug(
+                    f"value for {key} not found, using default: {default}")
             return default
 
     def continuous_to_categorical(self, series):
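A minimal sketch of the new fallback path: when no experiment configuration has been initialised, reading glob_conf.config raises AttributeError, self.config ends up None, and config_val simply returns the supplied default:

    from nkululeko.utils.util import Util

    util = Util("example")  # no nkululeko experiment configured
    print(util.config_val("RESAMPLE", "replace", "False"))  # prints "False"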
{nkululeko-0.84.0.dist-info → nkululeko-0.84.1.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nkululeko
-Version: 0.84.0
+Version: 0.84.1
 Summary: Machine learning audio prediction experiments based on templates
 Home-page: https://github.com/felixbur/nkululeko
 Author: Felix Burkhardt

@@ -333,6 +333,10 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
 Changelog
 =========
 
+Version 0.84.1
+--------------
+* made resample independent of config file
+
 Version 0.84.0
 --------------
 * added SHAP analysis
{nkululeko-0.84.0.dist-info → nkululeko-0.84.1.dist-info}/RECORD
CHANGED

@@ -2,7 +2,7 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
 nkululeko/aug_train.py,sha256=YhuZnS_WVWnun9G-M6g5n6rbRxoVREz6Zh7k6qprFNQ,3194
 nkululeko/augment.py,sha256=4MG0apTAG5RgkuJrYEjGgDdbodZWi_HweSPNI1JJ5QA,3051
 nkululeko/cacheddataset.py,sha256=lIJ6hUo5LoxSrzXtWV8mzwO7wRtUETWnOQ4ws2XfL1E,969
-nkululeko/constants.py,sha256=
+nkululeko/constants.py,sha256=31GQXyAN-nrfQCNIt6_aSkBVeE_J3GO-PklTEy6EgBg,39
 nkululeko/demo.py,sha256=8bl15Kitoesnz8oa8yrs52T6YCSOhWbbq9PnZ8Hj6D0,3232
 nkululeko/demo_feats.py,sha256=sAeGFojhEj9WEDFtG3SzPBmyYJWLF2rkbpp65m8Ujo4,2025
 nkululeko/demo_predictor.py,sha256=es56xbT8ifkS_vnrlb5NTZT54gNmeUtNlA4zVA_gnN8,4757

@@ -19,19 +19,19 @@ nkululeko/nkuluflag.py,sha256=PGWSmZz-PiiHLgcZJAoGOI_Y-sZDVI1ksB8p5r7riWM,3725
 nkululeko/nkululeko.py,sha256=Kn3s2E3yyH8cJ7z6lkMxrnqtCxTu7-qfe9Zr_ONTD5g,1968
 nkululeko/plots.py,sha256=nd9tF_61DyAx7oGZF8gTrHXazkgFjFe4eClxu1nQ_XU,23276
 nkululeko/predict.py,sha256=sF091sSSLnEWcISx9ZcULLie3tY5XeFsQJd6b3vrxFg,2409
-nkululeko/resample.py,sha256=
+nkululeko/resample.py,sha256=IPtYqU0nhZ-CqO_O1jJN0EvpfjxHZdFRwdTpEJOVuaQ,3354
 nkululeko/runmanager.py,sha256=eTM1DNQKt1lxYhzt4vZyZluPXW9sWlIJHNQzex4lkJU,7624
 nkululeko/scaler.py,sha256=4nkIqoajkIkuTPK0Z02ifMN_awl6fP_i-GBYdoGYgGM,4101
 nkululeko/segment.py,sha256=YLKckX44tbvTb3LrdgYw9X4guzuF27sutl92z9DkpZU,4835
 nkululeko/syllable_nuclei.py,sha256=Sky-C__MeUDaxqHnDl2TGLLYOYvsahD35TUjWGeG31k,10047
 nkululeko/test.py,sha256=1w624vo5KTzmFC8BUStGlLDmIEAFuJUz7J0W-gp7AxI,1677
 nkululeko/test_predictor.py,sha256=_w5J8CxH6hmW3mLTKbdfmywl5QpdNAnW1Y8TE5GtlfE,3237
-nkululeko/test_pretrain.py,sha256=
+nkululeko/test_pretrain.py,sha256=4b_39l01dySei_e0ys2NKo9Gipf1Fukp1GvhQllFHt8,8131
 nkululeko/augmenting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nkululeko/augmenting/augmenter.py,sha256=XAt0dpmlnKxqyysqCgV3rcz-pRIvOz7rU7dmGDCVAzs,2905
 nkululeko/augmenting/randomsplicer.py,sha256=Z5rxdKKUpuncLWuTS6xVfVKUeVbeiYU_dLRHQ5fcg4Y,2669
 nkululeko/augmenting/randomsplicing.py,sha256=ldym9vZNsZIU5BAAaJVaOmAgmVHNs4a5i5K3bW-WAQU,1791
-nkululeko/augmenting/resampler.py,sha256=
+nkululeko/augmenting/resampler.py,sha256=nOBsiQpX6p4jXsP7x6wak78F3B5YYYRmC_iHX8iuOXs,3542
 nkululeko/autopredict/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nkululeko/autopredict/ap_age.py,sha256=2Wn5E-Jd49sTn40WqaMcYtUEl4zEq3OY75XmjOpdxsA,1095
 nkululeko/autopredict/ap_arousal.py,sha256=ymt0diu4v1osw3VxJbSglsVKDAJYRzebQ2TTfFMKKxk,1024

@@ -75,6 +75,7 @@ nkululeko/losses/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nkululeko/losses/loss_ccc.py,sha256=NOK0y0fxKUnU161B5geap6Fmn8QzoPl2MqtPiV8IuJE,976
 nkululeko/losses/loss_softf1loss.py,sha256=5gW-PuiqeAZcRgfwjueIOQtMokOjZWgQnVIv59HKTCo,1309
 nkululeko/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+nkululeko/models/finetune_model.py,sha256=bx9NsFpEqf_mBohcrf-9lWjrC4AtOIJ7holNXwaFo2Y,4910
 nkululeko/models/model.py,sha256=fL6LB6I9Oqo_OWUIptqiu6abuxVYYv8bW2a3m4XSLqU,11601
 nkululeko/models/model_bayes.py,sha256=WJFZ8wFKwWATz6MhmjeZIi1Pal1viU549WL_PjXDSy8,406
 nkululeko/models/model_cnn.py,sha256=bJxqwe6FnVR2hFeqN6EXexYGgvKYFED1VOhBXVlLWaE,9954

@@ -103,9 +104,9 @@ nkululeko/segmenting/seg_silero.py,sha256=lLytS38KzARS17omwv8VBw-zz60RVSXGSvZ5Ev
 nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
 nkululeko/utils/stats.py,sha256=1yUq0FTOyqkU8TwUocJRYdJaqMU5SlOBBRUun9STo2M,2829
-nkululeko/utils/util.py,sha256=
-nkululeko-0.84.0.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
-nkululeko-0.84.0.dist-info/METADATA,sha256=
-nkululeko-0.84.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-nkululeko-0.84.0.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
-nkululeko-0.84.0.dist-info/RECORD,,
+nkululeko/utils/util.py,sha256=b1IHFucRNuF9Iyv5IJeK4AEg0Rga0xKG80UM5GWWdHA,13816
+nkululeko-0.84.1.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
+nkululeko-0.84.1.dist-info/METADATA,sha256=Y647w-vkRjPG7fssLTEF_Aa_pP74aN-WPCGv6r0_NcE,36420
+nkululeko-0.84.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+nkululeko-0.84.1.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
+nkululeko-0.84.1.dist-info/RECORD,,
{nkululeko-0.84.0.dist-info → nkululeko-0.84.1.dist-info}/LICENSE
File without changes
{nkululeko-0.84.0.dist-info → nkululeko-0.84.1.dist-info}/WHEEL
File without changes
{nkululeko-0.84.0.dist-info → nkululeko-0.84.1.dist-info}/top_level.txt
File without changes