nkululeko 0.85.2__py3-none-any.whl → 0.86.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nkululeko/constants.py +1 -1
- nkululeko/experiment.py +30 -40
- nkululeko/feat_extract/feats_opensmile.py +25 -25
- nkululeko/feat_extract/featureset.py +4 -4
- nkululeko/models/model_tuned.py +149 -88
- {nkululeko-0.85.2.dist-info → nkululeko-0.86.0.dist-info}/METADATA +7 -1
- {nkululeko-0.85.2.dist-info → nkululeko-0.86.0.dist-info}/RECORD +10 -11
- nkululeko/models/finetune_model.py +0 -190
- {nkululeko-0.85.2.dist-info → nkululeko-0.86.0.dist-info}/LICENSE +0 -0
- {nkululeko-0.85.2.dist-info → nkululeko-0.86.0.dist-info}/WHEEL +0 -0
- {nkululeko-0.85.2.dist-info → nkululeko-0.86.0.dist-info}/top_level.txt +0 -0
nkululeko/constants.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
|
-
VERSION="0.85.2"
|
1
|
+
VERSION="0.86.0"
|
2
2
|
SAMPLING_RATE = 16000
|
nkululeko/experiment.py
CHANGED
@@ -30,15 +30,14 @@ from nkululeko.utils.util import Util
|
|
30
30
|
|
31
31
|
|
32
32
|
class Experiment:
|
33
|
-
"""Main class specifying an experiment"""
|
33
|
+
"""Main class specifying an experiment."""
|
34
34
|
|
35
35
|
def __init__(self, config_obj):
|
36
|
-
"""
|
37
|
-
Parameters
|
38
|
-
----------
|
39
|
-
config_obj : a config parser object that sets the experiment parameters and being set as a global object.
|
40
|
-
"""
|
36
|
+
"""Constructor.
|
41
37
|
|
38
|
+
Args:
|
39
|
+
- config_obj : a config parser object that sets the experiment parameters and being set as a global object.
|
40
|
+
"""
|
42
41
|
self.set_globals(config_obj)
|
43
42
|
self.name = glob_conf.config["EXP"]["name"]
|
44
43
|
self.root = os.path.join(glob_conf.config["EXP"]["root"], "")
|
@@ -109,15 +108,13 @@ class Experiment:
|
|
109
108
|
# print keys/column
|
110
109
|
dbs = ",".join(list(self.datasets.keys()))
|
111
110
|
labels = self.util.config_val("DATA", "labels", False)
|
112
|
-
auto_labels = list(
|
113
|
-
next(iter(self.datasets.values())).df[self.target].unique()
|
114
|
-
)
|
111
|
+
auto_labels = list(next(iter(self.datasets.values())).df[self.target].unique())
|
115
112
|
if labels:
|
116
113
|
self.labels = ast.literal_eval(labels)
|
117
114
|
self.util.debug(f"Target labels (from config): {labels}")
|
118
115
|
else:
|
119
116
|
self.labels = auto_labels
|
120
|
-
|
117
|
+
self.util.debug(f"Target labels (from database): {auto_labels}")
|
121
118
|
glob_conf.set_labels(self.labels)
|
122
119
|
self.util.debug(f"loaded databases {dbs}")
|
123
120
|
|
@@ -160,8 +157,7 @@ class Experiment:
|
|
160
157
|
data.split()
|
161
158
|
data.prepare_labels()
|
162
159
|
self.df_test = pd.concat(
|
163
|
-
[self.df_test, self.util.make_segmented_index(
|
164
|
-
data.df_test)]
|
160
|
+
[self.df_test, self.util.make_segmented_index(data.df_test)]
|
165
161
|
)
|
166
162
|
self.df_test.is_labeled = data.is_labeled
|
167
163
|
self.df_test.got_gender = self.got_gender
|
@@ -262,8 +258,7 @@ class Experiment:
|
|
262
258
|
test_cats = self.df_test[self.target].unique()
|
263
259
|
else:
|
264
260
|
# if there is no target, copy a dummy label
|
265
|
-
self.df_test = self._add_random_target(
|
266
|
-
self.df_test).astype("str")
|
261
|
+
self.df_test = self._add_random_target(self.df_test).astype("str")
|
267
262
|
train_cats = self.df_train[self.target].unique()
|
268
263
|
# print(f"df_train: {pd.DataFrame(self.df_train[self.target])}")
|
269
264
|
# print(f"train_cats with target {self.target}: {train_cats}")
|
@@ -271,8 +266,7 @@ class Experiment:
|
|
271
266
|
if type(test_cats) == np.ndarray:
|
272
267
|
self.util.debug(f"Categories test (nd.array): {test_cats}")
|
273
268
|
else:
|
274
|
-
self.util.debug(
|
275
|
-
f"Categories test (list): {list(test_cats)}")
|
269
|
+
self.util.debug(f"Categories test (list): {list(test_cats)}")
|
276
270
|
if type(train_cats) == np.ndarray:
|
277
271
|
self.util.debug(f"Categories train (nd.array): {train_cats}")
|
278
272
|
else:
|
@@ -295,8 +289,7 @@ class Experiment:
|
|
295
289
|
|
296
290
|
target_factor = self.util.config_val("DATA", "target_divide_by", False)
|
297
291
|
if target_factor:
|
298
|
-
self.df_test[self.target] = self.df_test[self.target] /
|
299
|
-
float(target_factor)
|
292
|
+
self.df_test[self.target] = self.df_test[self.target] / float(target_factor)
|
300
293
|
self.df_train[self.target] = self.df_train[self.target] / float(
|
301
294
|
target_factor
|
302
295
|
)
|
@@ -319,16 +312,14 @@ class Experiment:
|
|
319
312
|
def plot_distribution(self, df_labels):
|
320
313
|
"""Plot the distribution of samples and speaker per target class and biological sex"""
|
321
314
|
plot = Plots()
|
322
|
-
sample_selection = self.util.config_val(
|
323
|
-
"EXPL", "sample_selection", "all")
|
315
|
+
sample_selection = self.util.config_val("EXPL", "sample_selection", "all")
|
324
316
|
plot.plot_distributions(df_labels)
|
325
317
|
if self.got_speaker:
|
326
318
|
plot.plot_distributions_speaker(df_labels)
|
327
319
|
|
328
320
|
def extract_test_feats(self):
|
329
321
|
self.feats_test = pd.DataFrame()
|
330
|
-
feats_name = "_".join(ast.literal_eval(
|
331
|
-
glob_conf.config["DATA"]["tests"]))
|
322
|
+
feats_name = "_".join(ast.literal_eval(glob_conf.config["DATA"]["tests"]))
|
332
323
|
feats_types = self.util.config_val_list("FEATS", "type", ["os"])
|
333
324
|
self.feature_extractor = FeatureExtractor(
|
334
325
|
self.df_test, feats_types, feats_name, "test"
|
@@ -345,8 +336,7 @@ class Experiment:
|
|
345
336
|
|
346
337
|
"""
|
347
338
|
df_train, df_test = self.df_train, self.df_test
|
348
|
-
feats_name = "_".join(ast.literal_eval(
|
349
|
-
glob_conf.config["DATA"]["databases"]))
|
339
|
+
feats_name = "_".join(ast.literal_eval(glob_conf.config["DATA"]["databases"]))
|
350
340
|
self.feats_test, self.feats_train = pd.DataFrame(), pd.DataFrame()
|
351
341
|
feats_types = self.util.config_val_list("FEATS", "type", [])
|
352
342
|
# for some models no features are needed
|
@@ -380,20 +370,22 @@ class Experiment:
|
|
380
370
|
f"test feats ({self.feats_test.shape[0]}) != test labels"
|
381
371
|
f" ({self.df_test.shape[0]})"
|
382
372
|
)
|
383
|
-
self.df_test = self.df_test[self.df_test.index.isin(
|
384
|
-
|
385
|
-
self.util.warn(f"mew test labels shape: {self.df_test.shape[0]}")
|
373
|
+
self.df_test = self.df_test[self.df_test.index.isin(self.feats_test.index)]
|
374
|
+
self.util.warn(f"new test labels shape: {self.df_test.shape[0]}")
|
386
375
|
|
387
376
|
self._check_scale()
|
377
|
+
# store = self.util.get_path("store")
|
378
|
+
# store_format = self.util.config_val("FEATS", "store_format", "pkl")
|
379
|
+
# storage = f"{store}test_feats.{store_format}"
|
380
|
+
# self.util.write_store(self.feats_test, storage, store_format)
|
381
|
+
# storage = f"{store}train_feats.{store_format}"
|
382
|
+
# self.util.write_store(self.feats_train, storage, store_format)
|
388
383
|
|
389
384
|
def augment(self):
|
390
|
-
"""
|
391
|
-
Augment the selected samples
|
392
|
-
"""
|
385
|
+
"""Augment the selected samples."""
|
393
386
|
from nkululeko.augmenting.augmenter import Augmenter
|
394
387
|
|
395
|
-
sample_selection = self.util.config_val(
|
396
|
-
"AUGMENT", "sample_selection", "all")
|
388
|
+
sample_selection = self.util.config_val("AUGMENT", "sample_selection", "all")
|
397
389
|
if sample_selection == "all":
|
398
390
|
df = pd.concat([self.df_train, self.df_test])
|
399
391
|
elif sample_selection == "train":
|
@@ -488,8 +480,7 @@ class Experiment:
|
|
488
480
|
"""
|
489
481
|
from nkululeko.augmenting.randomsplicer import Randomsplicer
|
490
482
|
|
491
|
-
sample_selection = self.util.config_val(
|
492
|
-
"AUGMENT", "sample_selection", "all")
|
483
|
+
sample_selection = self.util.config_val("AUGMENT", "sample_selection", "all")
|
493
484
|
if sample_selection == "all":
|
494
485
|
df = pd.concat([self.df_train, self.df_test])
|
495
486
|
elif sample_selection == "train":
|
@@ -510,8 +501,7 @@ class Experiment:
|
|
510
501
|
plot_feats = eval(
|
511
502
|
self.util.config_val("EXPL", "feature_distributions", "False")
|
512
503
|
)
|
513
|
-
sample_selection = self.util.config_val(
|
514
|
-
"EXPL", "sample_selection", "all")
|
504
|
+
sample_selection = self.util.config_val("EXPL", "sample_selection", "all")
|
515
505
|
# get the data labels
|
516
506
|
if sample_selection == "all":
|
517
507
|
df_labels = pd.concat([self.df_train, self.df_test])
|
@@ -574,8 +564,7 @@ class Experiment:
|
|
574
564
|
for scat_target in scat_targets:
|
575
565
|
if self.util.is_categorical(df_labels[scat_target]):
|
576
566
|
for scatter in scatters:
|
577
|
-
plots.scatter_plot(
|
578
|
-
df_feats, df_labels, scat_target, scatter)
|
567
|
+
plots.scatter_plot(df_feats, df_labels, scat_target, scatter)
|
579
568
|
else:
|
580
569
|
self.util.debug(
|
581
570
|
f"{self.name}: binning continuous variable to categories"
|
@@ -590,6 +579,8 @@ class Experiment:
|
|
590
579
|
)
|
591
580
|
|
592
581
|
def _check_scale(self):
|
582
|
+
self.util.save_to_store(self.feats_train, "feats_train")
|
583
|
+
self.util.save_to_store(self.feats_test, "feats_test")
|
593
584
|
scale_feats = self.util.config_val("FEATS", "scale", False)
|
594
585
|
# print the scale
|
595
586
|
self.util.debug(f"scaler: {scale_feats}")
|
@@ -664,8 +655,7 @@ class Experiment:
|
|
664
655
|
preds = best.preds
|
665
656
|
speakers = self.df_test.speaker.values
|
666
657
|
print(f"{len(truths)} {len(preds)} {len(speakers) }")
|
667
|
-
df = pd.DataFrame(
|
668
|
-
data={"truth": truths, "pred": preds, "speaker": speakers})
|
658
|
+
df = pd.DataFrame(data={"truth": truths, "pred": preds, "speaker": speakers})
|
669
659
|
plot_name = "result_combined_per_speaker"
|
670
660
|
self.util.debug(
|
671
661
|
f"plotting speaker combination ({function}) confusion matrix to"
|
@@ -65,28 +65,28 @@ class Opensmileset(Featureset):
|
|
65
65
|
feats = smile.process_signal(signal, sr)
|
66
66
|
return feats.to_numpy()
|
67
67
|
|
68
|
-
def filter(self):
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
68
|
+
# def filter(self):
|
69
|
+
# # use only the features that are indexed in the target dataframes
|
70
|
+
# self.df = self.df[self.df.index.isin(self.data_df.index)]
|
71
|
+
# try:
|
72
|
+
# # use only some features
|
73
|
+
# selected_features = ast.literal_eval(
|
74
|
+
# glob_conf.config["FEATS"]["os.features"]
|
75
|
+
# )
|
76
|
+
# self.util.debug(f"selecting features from opensmile: {selected_features}")
|
77
|
+
# sel_feats_df = pd.DataFrame()
|
78
|
+
# hit = False
|
79
|
+
# for feat in selected_features:
|
80
|
+
# try:
|
81
|
+
# sel_feats_df[feat] = self.df[feat]
|
82
|
+
# hit = True
|
83
|
+
# except KeyError:
|
84
|
+
# pass
|
85
|
+
# if hit:
|
86
|
+
# self.df = sel_feats_df
|
87
|
+
# self.util.debug(
|
88
|
+
# "new feats shape after selecting opensmile features:"
|
89
|
+
# f" {self.df.shape}"
|
90
|
+
# )
|
91
|
+
# except KeyError:
|
92
|
+
# pass
|
@@ -15,7 +15,7 @@ class Featureset:
|
|
15
15
|
self.name = name
|
16
16
|
self.data_df = data_df
|
17
17
|
self.util = Util("featureset")
|
18
|
-
self.
|
18
|
+
self.feats_type = feats_type
|
19
19
|
|
20
20
|
def extract(self):
|
21
21
|
pass
|
@@ -25,8 +25,7 @@ class Featureset:
|
|
25
25
|
self.df = self.df[self.df.index.isin(self.data_df.index)]
|
26
26
|
try:
|
27
27
|
# use only some features
|
28
|
-
selected_features = ast.literal_eval(
|
29
|
-
glob_conf.config["FEATS"]["features"])
|
28
|
+
selected_features = ast.literal_eval(glob_conf.config["FEATS"]["features"])
|
30
29
|
self.util.debug(f"selecting features: {selected_features}")
|
31
30
|
sel_feats_df = pd.DataFrame()
|
32
31
|
hit = False
|
@@ -35,11 +34,12 @@ class Featureset:
|
|
35
34
|
sel_feats_df[feat] = self.df[feat]
|
36
35
|
hit = True
|
37
36
|
except KeyError:
|
37
|
+
self.util.warn(f"non existent feature in {self.feats_type}: {feat}")
|
38
38
|
pass
|
39
39
|
if hit:
|
40
40
|
self.df = sel_feats_df
|
41
41
|
self.util.debug(
|
42
|
-
f"new feats shape after selecting features: {self.df.shape}"
|
42
|
+
f"new feats shape after selecting features for {self.feats_type}: {self.df.shape}"
|
43
43
|
)
|
44
44
|
except KeyError:
|
45
45
|
pass
|
nkululeko/models/model_tuned.py
CHANGED
@@ -1,6 +1,4 @@
|
|
1
|
-
"""
|
2
|
-
Code based on @jwagner.
|
3
|
-
"""
|
1
|
+
"""Code based on @jwagner."""
|
4
2
|
|
5
3
|
import dataclasses
|
6
4
|
import json
|
@@ -27,8 +25,6 @@ from nkululeko.reporting.reporter import Reporter
|
|
27
25
|
|
28
26
|
class TunedModel(BaseModel):
|
29
27
|
|
30
|
-
is_classifier = True
|
31
|
-
|
32
28
|
def __init__(self, df_train, df_test, feats_train, feats_test):
|
33
29
|
"""Constructor taking the configuration and all dataframes."""
|
34
30
|
super().__init__(df_train, df_test, feats_train, feats_test)
|
@@ -37,25 +33,47 @@ class TunedModel(BaseModel):
|
|
37
33
|
self.target = glob_conf.config["DATA"]["target"]
|
38
34
|
labels = glob_conf.labels
|
39
35
|
self.class_num = len(labels)
|
40
|
-
device = self.util.config_val("MODEL", "device",
|
41
|
-
|
42
|
-
|
43
|
-
|
36
|
+
device = self.util.config_val("MODEL", "device", False)
|
37
|
+
if not device:
|
38
|
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
39
|
+
else:
|
40
|
+
self.device = device
|
41
|
+
if self.device != "cpu":
|
44
42
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
45
|
-
os.environ["CUDA_VISIBLE_DEVICES"] = device
|
43
|
+
os.environ["CUDA_VISIBLE_DEVICES"] = self.device
|
44
|
+
self.util.debug(f"running on device {self.device}")
|
45
|
+
self.is_classifier = self.util.exp_is_classification()
|
46
|
+
if self.is_classifier:
|
47
|
+
self.measure = "uar"
|
48
|
+
else:
|
49
|
+
self.measure = self.util.config_val("MODEL", "measure", "ccc")
|
50
|
+
self.util.debug(f"evaluation metrics: {self.measure}")
|
51
|
+
self.batch_size = int(self.util.config_val("MODEL", "batch_size", "8"))
|
52
|
+
self.util.debug(f"batch size: {self.batch_size}")
|
53
|
+
self.learning_rate = float(
|
54
|
+
self.util.config_val("MODEL", "learning_rate", 0.0001)
|
55
|
+
)
|
46
56
|
self.df_train, self.df_test = df_train, df_test
|
47
57
|
self.epoch_num = int(self.util.config_val("EXP", "epochs", 1))
|
48
|
-
|
58
|
+
drop = self.util.config_val("MODEL", "drop", False)
|
59
|
+
self.drop = 0.1
|
60
|
+
if drop:
|
61
|
+
self.drop = float(drop)
|
62
|
+
self.util.debug(f"init: training with dropout: {self.drop}")
|
49
63
|
self._init_model()
|
50
64
|
|
51
65
|
def _init_model(self):
|
52
66
|
model_path = "facebook/wav2vec2-large-robust-ft-swbd-300h"
|
67
|
+
pretrained_model = self.util.config_val("MODEL", "pretrained_model", model_path)
|
53
68
|
self.num_layers = None
|
54
69
|
self.sampling_rate = 16000
|
55
70
|
self.max_duration_sec = 8.0
|
56
71
|
self.accumulation_steps = 4
|
57
|
-
# create dataset
|
58
72
|
|
73
|
+
# print finetuning information via debug
|
74
|
+
self.util.debug(f"Finetuning from model: {pretrained_model}")
|
75
|
+
|
76
|
+
# create dataset
|
59
77
|
dataset = {}
|
60
78
|
target_name = glob_conf.target
|
61
79
|
data_sources = {
|
@@ -76,22 +94,32 @@ class TunedModel(BaseModel):
|
|
76
94
|
self.dataset = datasets.DatasetDict(dataset)
|
77
95
|
|
78
96
|
# load pre-trained model
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
97
|
+
if self.is_classifier:
|
98
|
+
le = glob_conf.label_encoder
|
99
|
+
mapping = dict(zip(le.classes_, range(len(le.classes_))))
|
100
|
+
target_mapping = {k: int(v) for k, v in mapping.items()}
|
101
|
+
target_mapping_reverse = {
|
102
|
+
value: key for key, value in target_mapping.items()
|
103
|
+
}
|
104
|
+
self.config = transformers.AutoConfig.from_pretrained(
|
105
|
+
model_path,
|
106
|
+
num_labels=len(target_mapping),
|
107
|
+
label2id=target_mapping,
|
108
|
+
id2label=target_mapping_reverse,
|
109
|
+
finetuning_task=target_name,
|
110
|
+
)
|
111
|
+
else:
|
112
|
+
self.config = transformers.AutoConfig.from_pretrained(
|
113
|
+
model_path,
|
114
|
+
num_labels=1,
|
115
|
+
finetuning_task=target_name,
|
116
|
+
)
|
91
117
|
if self.num_layers is not None:
|
92
118
|
self.config.num_hidden_layers = self.num_layers
|
119
|
+
self.config.final_dropout = self.drop
|
93
120
|
setattr(self.config, "sampling_rate", self.sampling_rate)
|
94
121
|
setattr(self.config, "data", self.util.get_data_name())
|
122
|
+
setattr(self.config, "is_classifier", self.is_classifier)
|
95
123
|
|
96
124
|
vocab_dict = {}
|
97
125
|
with open("vocab.json", "w") as vocab_file:
|
@@ -113,7 +141,7 @@ class TunedModel(BaseModel):
|
|
113
141
|
assert self.processor.feature_extractor.sampling_rate == self.sampling_rate
|
114
142
|
|
115
143
|
self.model = Model.from_pretrained(
|
116
|
-
|
144
|
+
pretrained_model,
|
117
145
|
config=self.config,
|
118
146
|
)
|
119
147
|
self.model.freeze_feature_extractor()
|
@@ -170,7 +198,7 @@ class TunedModel(BaseModel):
|
|
170
198
|
return_tensors="pt",
|
171
199
|
)
|
172
200
|
|
173
|
-
batch["labels"] = torch.
|
201
|
+
batch["labels"] = torch.Tensor(targets)
|
174
202
|
|
175
203
|
return batch
|
176
204
|
|
@@ -180,14 +208,25 @@ class TunedModel(BaseModel):
|
|
180
208
|
"UAR": audmetric.unweighted_average_recall,
|
181
209
|
"ACC": audmetric.accuracy,
|
182
210
|
}
|
211
|
+
metrics_reg = {
|
212
|
+
"PCC": audmetric.pearson_cc,
|
213
|
+
"CCC": audmetric.concordance_cc,
|
214
|
+
"MSE": audmetric.mean_squared_error,
|
215
|
+
"MAE": audmetric.mean_absolute_error,
|
216
|
+
}
|
183
217
|
|
184
218
|
# truth = p.label_ids[:, 0].astype(int)
|
185
219
|
truth = p.label_ids
|
186
220
|
preds = p.predictions
|
187
221
|
preds = np.argmax(preds, axis=1)
|
188
222
|
scores = {}
|
189
|
-
|
190
|
-
|
223
|
+
if self.is_classifier:
|
224
|
+
for name, metric in metrics.items():
|
225
|
+
scores[f"{name}"] = metric(truth, preds)
|
226
|
+
else:
|
227
|
+
for name, metric in metrics_reg.items():
|
228
|
+
scores[f"{name}"] = metric(truth, preds)
|
229
|
+
|
191
230
|
return scores
|
192
231
|
|
193
232
|
def train(self):
|
@@ -203,23 +242,24 @@ class TunedModel(BaseModel):
|
|
203
242
|
return
|
204
243
|
targets = pd.DataFrame(self.dataset["train"]["targets"])
|
205
244
|
counts = targets[0].value_counts().sort_index()
|
206
|
-
train_weights = 1 / counts
|
207
|
-
train_weights /= train_weights.sum()
|
208
|
-
self.util.debug("train weights: {train_weights}")
|
209
|
-
criterion = torch.nn.CrossEntropyLoss(
|
210
|
-
weight=torch.Tensor(train_weights).to("cuda"),
|
211
|
-
)
|
212
|
-
# criterion = torch.nn.CrossEntropyLoss()
|
213
245
|
|
214
|
-
|
246
|
+
if self.is_classifier:
|
247
|
+
train_weights = 1 / counts
|
248
|
+
train_weights /= train_weights.sum()
|
249
|
+
self.util.debug(f"train weights: {train_weights}")
|
250
|
+
criterion = torch.nn.CrossEntropyLoss(
|
251
|
+
weight=torch.Tensor(train_weights).to("cuda"),
|
252
|
+
)
|
253
|
+
else:
|
254
|
+
criterion = ConcordanceCorCoeff()
|
215
255
|
|
256
|
+
class Trainer(transformers.Trainer):
|
216
257
|
def compute_loss(
|
217
258
|
self,
|
218
259
|
model,
|
219
260
|
inputs,
|
220
261
|
return_outputs=False,
|
221
262
|
):
|
222
|
-
|
223
263
|
targets = inputs.pop("labels").squeeze()
|
224
264
|
targets = targets.type(torch.long)
|
225
265
|
|
@@ -236,7 +276,8 @@ class TunedModel(BaseModel):
|
|
236
276
|
// 5
|
237
277
|
)
|
238
278
|
num_steps = max(1, num_steps)
|
239
|
-
|
279
|
+
|
280
|
+
metrics_for_best_model = self.measure.upper()
|
240
281
|
|
241
282
|
training_args = transformers.TrainingArguments(
|
242
283
|
output_dir=model_root,
|
@@ -246,13 +287,14 @@ class TunedModel(BaseModel):
|
|
246
287
|
gradient_accumulation_steps=self.accumulation_steps,
|
247
288
|
evaluation_strategy="steps",
|
248
289
|
num_train_epochs=self.epoch_num,
|
249
|
-
fp16=
|
290
|
+
fp16=self.device == "cuda",
|
250
291
|
save_steps=num_steps,
|
251
292
|
eval_steps=num_steps,
|
252
293
|
logging_steps=num_steps,
|
253
|
-
|
294
|
+
logging_strategy="epoch",
|
295
|
+
learning_rate=self.learning_rate,
|
254
296
|
save_total_limit=2,
|
255
|
-
metric_for_best_model=
|
297
|
+
metric_for_best_model=metrics_for_best_model,
|
256
298
|
greater_is_better=True,
|
257
299
|
load_best_model_at_end=True,
|
258
300
|
remove_unused_columns=False,
|
@@ -271,6 +313,7 @@ class TunedModel(BaseModel):
|
|
271
313
|
)
|
272
314
|
trainer.train()
|
273
315
|
trainer.save_model(self.torch_root)
|
316
|
+
self.util.debug(f"saved best model to {self.torch_root}")
|
274
317
|
self.load(self.run, self.epoch)
|
275
318
|
|
276
319
|
def get_predictions(self):
|
@@ -305,7 +348,7 @@ class TunedModel(BaseModel):
|
|
305
348
|
def predict_sample(self, signal):
|
306
349
|
"""Predict one sample"""
|
307
350
|
prediction = {}
|
308
|
-
if self.
|
351
|
+
if self.is_classifier:
|
309
352
|
# get the class probabilities
|
310
353
|
predictions = self.model.predict(signal)
|
311
354
|
# pred = self.clf.predict(features)
|
@@ -337,8 +380,19 @@ class TunedModel(BaseModel):
|
|
337
380
|
@dataclasses.dataclass
|
338
381
|
class ModelOutput(transformers.file_utils.ModelOutput):
|
339
382
|
|
340
|
-
|
383
|
+
logits: torch.FloatTensor = None
|
384
|
+
hidden_states: typing.Tuple[torch.FloatTensor] = None
|
385
|
+
cnn_features: torch.FloatTensor = None
|
386
|
+
|
387
|
+
|
388
|
+
@dataclasses.dataclass
|
389
|
+
class ModelOutputReg(transformers.file_utils.ModelOutput):
|
390
|
+
|
391
|
+
logits: torch.FloatTensor
|
341
392
|
hidden_states: typing.Tuple[torch.FloatTensor] = None
|
393
|
+
attentions: typing.Tuple[torch.FloatTensor] = None
|
394
|
+
logits_framewise: torch.FloatTensor = None
|
395
|
+
hidden_states_framewise: torch.FloatTensor = None
|
342
396
|
cnn_features: torch.FloatTensor = None
|
343
397
|
|
344
398
|
|
@@ -368,10 +422,14 @@ class Model(Wav2Vec2PreTrainedModel):
|
|
368
422
|
|
369
423
|
def __init__(self, config):
|
370
424
|
|
425
|
+
if not hasattr(config, "add_adapter"):
|
426
|
+
setattr(config, "add_adapter", False)
|
427
|
+
|
371
428
|
super().__init__(config)
|
372
429
|
|
373
430
|
self.wav2vec2 = Wav2Vec2Model(config)
|
374
|
-
self.
|
431
|
+
self.head = ModelHead(config)
|
432
|
+
self.is_classifier = config.is_classifier
|
375
433
|
self.init_weights()
|
376
434
|
|
377
435
|
def freeze_feature_extractor(self):
|
@@ -407,39 +465,44 @@ class Model(Wav2Vec2PreTrainedModel):
|
|
407
465
|
labels=None,
|
408
466
|
return_hidden=False,
|
409
467
|
):
|
410
|
-
|
411
468
|
outputs = self.wav2vec2(
|
412
469
|
input_values,
|
413
470
|
attention_mask=attention_mask,
|
414
471
|
)
|
415
|
-
|
416
472
|
cnn_features = outputs.extract_features
|
417
473
|
hidden_states_framewise = outputs.last_hidden_state
|
418
474
|
hidden_states = self.pooling(
|
419
475
|
hidden_states_framewise,
|
420
476
|
attention_mask,
|
421
477
|
)
|
422
|
-
|
423
|
-
|
478
|
+
logits = self.head(hidden_states)
|
424
479
|
if not self.training:
|
425
|
-
|
480
|
+
logits = torch.softmax(logits, dim=1)
|
426
481
|
|
427
482
|
if return_hidden:
|
428
|
-
|
429
483
|
# make time last axis
|
430
484
|
cnn_features = torch.transpose(cnn_features, 1, 2)
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
485
|
+
if self.is_classifier:
|
486
|
+
return ModelOutput(
|
487
|
+
logits=logits,
|
488
|
+
hidden_states=hidden_states,
|
489
|
+
cnn_features=cnn_features,
|
490
|
+
)
|
491
|
+
else:
|
492
|
+
return ModelOutputReg(
|
493
|
+
logits=logits,
|
494
|
+
hidden_states=hidden_states,
|
495
|
+
cnn_features=cnn_features,
|
496
|
+
)
|
438
497
|
else:
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
498
|
+
if self.is_classifier:
|
499
|
+
return ModelOutput(
|
500
|
+
logits=logits,
|
501
|
+
)
|
502
|
+
else:
|
503
|
+
return ModelOutputReg(
|
504
|
+
logits=logits,
|
505
|
+
)
|
443
506
|
|
444
507
|
def predict(self, signal):
|
445
508
|
result = self(torch.from_numpy(signal))
|
@@ -447,33 +510,31 @@ class Model(Wav2Vec2PreTrainedModel):
|
|
447
510
|
return result
|
448
511
|
|
449
512
|
|
450
|
-
class
|
451
|
-
|
452
|
-
def __init__(self, config):
|
453
|
-
super().__init__(config)
|
513
|
+
class ConcordanceCorCoeff(torch.nn.Module):
|
454
514
|
|
455
|
-
def
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
515
|
+
def __init__(self):
|
516
|
+
super().__init__()
|
517
|
+
self.mean = torch.mean
|
518
|
+
self.var = torch.var
|
519
|
+
self.sum = torch.sum
|
520
|
+
self.sqrt = torch.sqrt
|
521
|
+
self.std = torch.std
|
522
|
+
|
523
|
+
def forward(self, prediction, ground_truth):
|
524
|
+
ground_truth = ground_truth.float()
|
525
|
+
mean_gt = self.mean(ground_truth, 0)
|
526
|
+
mean_pred = self.mean(prediction, 0)
|
527
|
+
var_gt = self.var(ground_truth, 0)
|
528
|
+
var_pred = self.var(prediction, 0)
|
529
|
+
v_pred = prediction - mean_pred
|
530
|
+
v_gt = ground_truth - mean_gt
|
531
|
+
cor = self.sum(v_pred * v_gt) / (
|
532
|
+
self.sqrt(self.sum(v_pred**2)) * self.sqrt(self.sum(v_gt**2))
|
473
533
|
)
|
534
|
+
sd_gt = self.std(ground_truth)
|
535
|
+
sd_pred = self.std(prediction)
|
536
|
+
numerator = 2 * cor * sd_gt * sd_pred
|
537
|
+
denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
|
538
|
+
ccc = numerator / denominator
|
474
539
|
|
475
|
-
return
|
476
|
-
output.hidden_states,
|
477
|
-
output.logits_cat,
|
478
|
-
output.cnn_features,
|
479
|
-
)
|
540
|
+
return 1 - ccc
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: nkululeko
|
3
|
-
Version: 0.85.2
|
3
|
+
Version: 0.86.0
|
4
4
|
Summary: Machine learning audio prediction experiments based on templates
|
5
5
|
Home-page: https://github.com/felixbur/nkululeko
|
6
6
|
Author: Felix Burkhardt
|
@@ -333,6 +333,12 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
|
|
333
333
|
Changelog
|
334
334
|
=========
|
335
335
|
|
336
|
+
Version 0.86.0
|
337
|
+
--------------
|
338
|
+
* added regression to finetuning
|
339
|
+
* added other transformer models to finetuning
|
340
|
+
* added output the train/dev features sets actually used by the model
|
341
|
+
|
336
342
|
Version 0.85.2
|
337
343
|
--------------
|
338
344
|
* added data, and automatic task label detection
|
@@ -2,11 +2,11 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
|
|
2
2
|
nkululeko/aug_train.py,sha256=YhuZnS_WVWnun9G-M6g5n6rbRxoVREz6Zh7k6qprFNQ,3194
|
3
3
|
nkululeko/augment.py,sha256=4MG0apTAG5RgkuJrYEjGgDdbodZWi_HweSPNI1JJ5QA,3051
|
4
4
|
nkululeko/cacheddataset.py,sha256=lIJ6hUo5LoxSrzXtWV8mzwO7wRtUETWnOQ4ws2XfL1E,969
|
5
|
-
nkululeko/constants.py,sha256=
|
5
|
+
nkululeko/constants.py,sha256=hvi1X27m7vcqkB_Rgl7alourAusZB1mjPxdW4ChdVyU,39
|
6
6
|
nkululeko/demo.py,sha256=8bl15Kitoesnz8oa8yrs52T6YCSOhWbbq9PnZ8Hj6D0,3232
|
7
7
|
nkululeko/demo_feats.py,sha256=sAeGFojhEj9WEDFtG3SzPBmyYJWLF2rkbpp65m8Ujo4,2025
|
8
8
|
nkululeko/demo_predictor.py,sha256=es56xbT8ifkS_vnrlb5NTZT54gNmeUtNlA4zVA_gnN8,4757
|
9
|
-
nkululeko/experiment.py,sha256=
|
9
|
+
nkululeko/experiment.py,sha256=gUJsBMWuadqxEVzuPVToQzFHC9FRUadptP49kTcBiGs,30962
|
10
10
|
nkululeko/explore.py,sha256=lDzRoW_Taa5u4BBABZLD89BcQWnYlrftJR4jgt1yyj0,2609
|
11
11
|
nkululeko/export.py,sha256=mHeEAAmtZuxdyebLlbSzPrHSi9OMgJHbk35d3DTxRBc,4632
|
12
12
|
nkululeko/feature_extractor.py,sha256=8mssYKmo4LclVI-hiLmJEDZ0ZPyDavFG2YwtXcrGzwM,3976
|
@@ -58,7 +58,7 @@ nkululeko/feat_extract/feats_hubert.py,sha256=cLoUzSLjSYBkQnftjacSL7ES3O7Ysh_KrP
|
|
58
58
|
nkululeko/feat_extract/feats_import.py,sha256=rj1p8lz19tCAC8hLzzZAwZ0M6gzwH3BzfabFUgal0yw,1622
|
59
59
|
nkululeko/feat_extract/feats_mld.py,sha256=Vvu7GZOkn7Vda8eIOXqHjg78zegkFe3vTUaCXyVM0eA,2021
|
60
60
|
nkululeko/feat_extract/feats_mos.py,sha256=KXNt7QYEfxkvr6UyVhig2aWQBaIvovlrR4gPuP03gmo,4174
|
61
|
-
nkululeko/feat_extract/feats_opensmile.py,sha256=
|
61
|
+
nkululeko/feat_extract/feats_opensmile.py,sha256=g6ZsAxjjGGvGfrr5fngWC-NJ8E7CP1kYZwrlodZJzzU,4028
|
62
62
|
nkululeko/feat_extract/feats_oxbow.py,sha256=CmIG9cbHTJTJVnzgCPdQpYpnlewWExpsr5ZcK8Malyo,4980
|
63
63
|
nkululeko/feat_extract/feats_praat.py,sha256=kZrS6srzH7WoWEd2prp1Dxw6g9JklFQGTNq5zzPpHzg,3105
|
64
64
|
nkululeko/feat_extract/feats_snr.py,sha256=9dqZ-4RpK98iJEssM3ttozNd18LWlZYM_QVXvp5xDcs,2829
|
@@ -69,13 +69,12 @@ nkululeko/feat_extract/feats_trill.py,sha256=K2ahhdpwpjgg3WZS1POg3UMP2U44i8cLZZv
|
|
69
69
|
nkululeko/feat_extract/feats_wav2vec2.py,sha256=9WUMfyddB_3nx79g7mZoQrRynhM1uEBWuOotRq8bxoU,5268
|
70
70
|
nkululeko/feat_extract/feats_wavlm.py,sha256=ulxpGjifUFx2ZgGmY32SmBJGIuvkYHoLb2n1LZ8KMwA,4703
|
71
71
|
nkululeko/feat_extract/feats_whisper.py,sha256=0N7Vj65OVi2PNoB_NrDjWT5lP6xZNKxFOZZIoxkJvcA,4533
|
72
|
-
nkululeko/feat_extract/featureset.py,sha256=
|
72
|
+
nkululeko/feat_extract/featureset.py,sha256=ll7tyKAdr--TDShyOYJg0FB4I9NqBq0Ni1k_kUJ-2Vw,1541
|
73
73
|
nkululeko/feat_extract/feinberg_praat.py,sha256=EP9pMALjlKdiYInLQdrZ7MmE499Mq-ISRCgqbqL3Rxc,21304
|
74
74
|
nkululeko/losses/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
75
75
|
nkululeko/losses/loss_ccc.py,sha256=NOK0y0fxKUnU161B5geap6Fmn8QzoPl2MqtPiV8IuJE,976
|
76
76
|
nkululeko/losses/loss_softf1loss.py,sha256=5gW-PuiqeAZcRgfwjueIOQtMokOjZWgQnVIv59HKTCo,1309
|
77
77
|
nkululeko/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
78
|
-
nkululeko/models/finetune_model.py,sha256=OMlzDyUFNXZ2xSiqqH8tbzey_KzPJ4jsoYT-4KrWFKM,5091
|
79
78
|
nkululeko/models/model.py,sha256=PUCqF2r_dEfmFsZn6Cgr1UIzYvxziLH6nSqZ5-vuN1o,11639
|
80
79
|
nkululeko/models/model_bayes.py,sha256=WJFZ8wFKwWATz6MhmjeZIi1Pal1viU549WL_PjXDSy8,406
|
81
80
|
nkululeko/models/model_cnn.py,sha256=bJxqwe6FnVR2hFeqN6EXexYGgvKYFED1VOhBXVlLWaE,9954
|
@@ -89,7 +88,7 @@ nkululeko/models/model_svm.py,sha256=rsME3KvKvNG7bdE5lbvYUu85WZhaASZxxmdNDIVJRZ4
|
|
89
88
|
nkululeko/models/model_svr.py,sha256=_YZeksqB3eBENGlg3g9RwYFlk9rQQ-XCeNBKLlGGVoE,725
|
90
89
|
nkululeko/models/model_tree.py,sha256=rf16faUm4o2LJgkoYpeY998b8DQIvXZ73_m1IS3TnnE,417
|
91
90
|
nkululeko/models/model_tree_reg.py,sha256=IgQcPTE-304HQLYSKPF8Z4ot_Ur9dH01fZjS0nXke_M,428
|
92
|
-
nkululeko/models/model_tuned.py,sha256=
|
91
|
+
nkululeko/models/model_tuned.py,sha256=J5CemIAW_WhZIQgppFgPChrsMJvGYzJlCvJC8O62l9M,18049
|
93
92
|
nkululeko/models/model_xgb.py,sha256=Thgx5ESdIok4v72mKh4plxpo4smGcKALWNCJTDScY0M,447
|
94
93
|
nkululeko/models/model_xgr.py,sha256=aGBtNGLWjOE_2rICGYGFxmT8DtnHYsIl1lIpMtghHsY,418
|
95
94
|
nkululeko/reporting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -106,8 +105,8 @@ nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
106
105
|
nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
|
107
106
|
nkululeko/utils/stats.py,sha256=1yUq0FTOyqkU8TwUocJRYdJaqMU5SlOBBRUun9STo2M,2829
|
108
107
|
nkululeko/utils/util.py,sha256=b1IHFucRNuF9Iyv5IJeK4AEg0Rga0xKG80UM5GWWdHA,13816
|
109
|
-
nkululeko-0.
|
110
|
-
nkululeko-0.
|
111
|
-
nkululeko-0.
|
112
|
-
nkululeko-0.
|
113
|
-
nkululeko-0.
|
108
|
+
nkululeko-0.86.0.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
|
109
|
+
nkululeko-0.86.0.dist-info/METADATA,sha256=KrHrjQ6rc4oGxN4EJ_TuZ0dVGGI-qIxw8dY1RBTCnLo,36852
|
110
|
+
nkululeko-0.86.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
111
|
+
nkululeko-0.86.0.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
|
112
|
+
nkululeko-0.86.0.dist-info/RECORD,,
|
@@ -1,190 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
Code based on @jwagner
|
3
|
-
"""
|
4
|
-
|
5
|
-
import dataclasses
|
6
|
-
import typing
|
7
|
-
|
8
|
-
import torch
|
9
|
-
import transformers
|
10
|
-
from transformers.models.wav2vec2.modeling_wav2vec2 import (
|
11
|
-
Wav2Vec2PreTrainedModel,
|
12
|
-
Wav2Vec2Model,
|
13
|
-
)
|
14
|
-
|
15
|
-
|
16
|
-
class ConcordanceCorCoeff(torch.nn.Module):
|
17
|
-
|
18
|
-
def __init__(self):
|
19
|
-
|
20
|
-
super().__init__()
|
21
|
-
|
22
|
-
self.mean = torch.mean
|
23
|
-
self.var = torch.var
|
24
|
-
self.sum = torch.sum
|
25
|
-
self.sqrt = torch.sqrt
|
26
|
-
self.std = torch.std
|
27
|
-
|
28
|
-
def forward(self, prediction, ground_truth):
|
29
|
-
|
30
|
-
mean_gt = self.mean(ground_truth, 0)
|
31
|
-
mean_pred = self.mean(prediction, 0)
|
32
|
-
var_gt = self.var(ground_truth, 0)
|
33
|
-
var_pred = self.var(prediction, 0)
|
34
|
-
v_pred = prediction - mean_pred
|
35
|
-
v_gt = ground_truth - mean_gt
|
36
|
-
cor = self.sum(v_pred * v_gt) / (
|
37
|
-
self.sqrt(self.sum(v_pred**2)) * self.sqrt(self.sum(v_gt**2))
|
38
|
-
)
|
39
|
-
sd_gt = self.std(ground_truth)
|
40
|
-
sd_pred = self.std(prediction)
|
41
|
-
numerator = 2 * cor * sd_gt * sd_pred
|
42
|
-
denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
|
43
|
-
ccc = numerator / denominator
|
44
|
-
|
45
|
-
return 1 - ccc
|
46
|
-
|
47
|
-
|
48
|
-
@dataclasses.dataclass
|
49
|
-
class ModelOutput(transformers.file_utils.ModelOutput):
|
50
|
-
|
51
|
-
logits_cat: torch.FloatTensor = None
|
52
|
-
hidden_states: typing.Tuple[torch.FloatTensor] = None
|
53
|
-
cnn_features: torch.FloatTensor = None
|
54
|
-
|
55
|
-
|
56
|
-
class ModelHead(torch.nn.Module):
|
57
|
-
|
58
|
-
def __init__(self, config, num_labels):
|
59
|
-
|
60
|
-
super().__init__()
|
61
|
-
|
62
|
-
self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
|
63
|
-
self.dropout = torch.nn.Dropout(config.final_dropout)
|
64
|
-
self.out_proj = torch.nn.Linear(config.hidden_size, num_labels)
|
65
|
-
|
66
|
-
def forward(self, features, **kwargs):
|
67
|
-
|
68
|
-
x = features
|
69
|
-
x = self.dropout(x)
|
70
|
-
x = self.dense(x)
|
71
|
-
x = torch.tanh(x)
|
72
|
-
x = self.dropout(x)
|
73
|
-
x = self.out_proj(x)
|
74
|
-
|
75
|
-
return x
|
76
|
-
|
77
|
-
|
78
|
-
class Model(Wav2Vec2PreTrainedModel):
|
79
|
-
|
80
|
-
def __init__(self, config):
|
81
|
-
|
82
|
-
super().__init__(config)
|
83
|
-
|
84
|
-
self.wav2vec2 = Wav2Vec2Model(config)
|
85
|
-
self.cat = ModelHead(config, 2)
|
86
|
-
self.init_weights()
|
87
|
-
|
88
|
-
def freeze_feature_extractor(self):
|
89
|
-
self.wav2vec2.feature_extractor._freeze_parameters()
|
90
|
-
|
91
|
-
def pooling(
|
92
|
-
self,
|
93
|
-
hidden_states,
|
94
|
-
attention_mask,
|
95
|
-
):
|
96
|
-
|
97
|
-
if attention_mask is None: # For evaluation with batch_size==1
|
98
|
-
outputs = torch.mean(hidden_states, dim=1)
|
99
|
-
else:
|
100
|
-
attention_mask = self._get_feature_vector_attention_mask(
|
101
|
-
hidden_states.shape[1],
|
102
|
-
attention_mask,
|
103
|
-
)
|
104
|
-
hidden_states = hidden_states * torch.reshape(
|
105
|
-
attention_mask,
|
106
|
-
(-1, attention_mask.shape[-1], 1),
|
107
|
-
)
|
108
|
-
outputs = torch.sum(hidden_states, dim=1)
|
109
|
-
attention_sum = torch.sum(attention_mask, dim=1)
|
110
|
-
outputs = outputs / torch.reshape(attention_sum, (-1, 1))
|
111
|
-
|
112
|
-
return outputs
|
113
|
-
|
114
|
-
def forward(
|
115
|
-
self,
|
116
|
-
input_values,
|
117
|
-
attention_mask=None,
|
118
|
-
labels=None,
|
119
|
-
return_hidden=False,
|
120
|
-
):
|
121
|
-
|
122
|
-
outputs = self.wav2vec2(
|
123
|
-
input_values,
|
124
|
-
attention_mask=attention_mask,
|
125
|
-
)
|
126
|
-
|
127
|
-
cnn_features = outputs.extract_features
|
128
|
-
hidden_states_framewise = outputs.last_hidden_state
|
129
|
-
hidden_states = self.pooling(
|
130
|
-
hidden_states_framewise,
|
131
|
-
attention_mask,
|
132
|
-
)
|
133
|
-
logits_cat = self.cat(hidden_states)
|
134
|
-
|
135
|
-
if not self.training:
|
136
|
-
logits_cat = torch.softmax(logits_cat, dim=1)
|
137
|
-
|
138
|
-
if return_hidden:
|
139
|
-
|
140
|
-
# make time last axis
|
141
|
-
cnn_features = torch.transpose(cnn_features, 1, 2)
|
142
|
-
|
143
|
-
return ModelOutput(
|
144
|
-
logits_cat=logits_cat,
|
145
|
-
hidden_states=hidden_states,
|
146
|
-
cnn_features=cnn_features,
|
147
|
-
)
|
148
|
-
|
149
|
-
else:
|
150
|
-
|
151
|
-
return ModelOutput(
|
152
|
-
logits_cat=logits_cat,
|
153
|
-
)
|
154
|
-
|
155
|
-
def predict(self, signal):
|
156
|
-
result = self(torch.from_numpy(signal))
|
157
|
-
result = result[0].detach().numpy()[0]
|
158
|
-
return result
|
159
|
-
|
160
|
-
|
161
|
-
class ModelWithPreProcessing(Model):
|
162
|
-
|
163
|
-
def __init__(self, config):
|
164
|
-
super().__init__(config)
|
165
|
-
|
166
|
-
def forward(
|
167
|
-
self,
|
168
|
-
input_values,
|
169
|
-
):
|
170
|
-
# Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm():
|
171
|
-
# normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
|
172
|
-
|
173
|
-
mean = input_values.mean()
|
174
|
-
|
175
|
-
# var = input_values.var()
|
176
|
-
# raises: onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for the node ReduceProd_3:ReduceProd(11)
|
177
|
-
|
178
|
-
var = torch.square(input_values - mean).mean()
|
179
|
-
input_values = (input_values - mean) / torch.sqrt(var + 1e-7)
|
180
|
-
|
181
|
-
output = super().forward(
|
182
|
-
input_values,
|
183
|
-
return_hidden=True,
|
184
|
-
)
|
185
|
-
|
186
|
-
return (
|
187
|
-
output.hidden_states,
|
188
|
-
output.logits_cat,
|
189
|
-
output.cnn_features,
|
190
|
-
)
|
File without changes
|
File without changes
|
File without changes
|