nkululeko 0.85.0__py3-none-any.whl → 0.85.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nkululeko/constants.py +1 -1
- nkululeko/modelrunner.py +2 -2
- nkululeko/models/model_tuned.py +34 -61
- {nkululeko-0.85.0.dist-info → nkululeko-0.85.1.dist-info}/METADATA +5 -1
- {nkululeko-0.85.0.dist-info → nkululeko-0.85.1.dist-info}/RECORD +8 -8
- {nkululeko-0.85.0.dist-info → nkululeko-0.85.1.dist-info}/LICENSE +0 -0
- {nkululeko-0.85.0.dist-info → nkululeko-0.85.1.dist-info}/WHEEL +0 -0
- {nkululeko-0.85.0.dist-info → nkululeko-0.85.1.dist-info}/top_level.txt +0 -0
nkululeko/constants.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
|
-
VERSION="0.85.
|
1
|
+
VERSION="0.85.1"
|
2
2
|
SAMPLING_RATE = 16000
|
nkululeko/modelrunner.py
CHANGED
@@ -151,9 +151,9 @@ class Modelrunner:
|
|
151
151
|
self.df_train, self.df_test, self.feats_train, self.feats_test
|
152
152
|
)
|
153
153
|
elif model_type == "finetune":
|
154
|
-
from nkululeko.models.model_tuned import
|
154
|
+
from nkululeko.models.model_tuned import TunedModel
|
155
155
|
|
156
|
-
self.model =
|
156
|
+
self.model = TunedModel(
|
157
157
|
self.df_train, self.df_test, self.feats_train, self.feats_test
|
158
158
|
)
|
159
159
|
elif model_type == "gmm":
|
nkululeko/models/model_tuned.py
CHANGED
@@ -1,48 +1,39 @@
|
|
1
1
|
"""
|
2
|
-
Code based on @jwagner
|
2
|
+
Code based on @jwagner.
|
3
3
|
"""
|
4
4
|
|
5
|
-
import
|
6
|
-
import audeer
|
7
|
-
import audmetric
|
8
|
-
import datasets
|
9
|
-
import pandas as pd
|
10
|
-
import transformers
|
11
|
-
from nkululeko.utils.util import Util
|
12
|
-
import nkululeko.glob_conf as glob_conf
|
13
|
-
from nkululeko.models.model import Model as BaseModel
|
14
|
-
|
15
|
-
# import nkululeko.models.finetune_model as fm
|
16
|
-
from nkululeko.reporting.reporter import Reporter
|
17
|
-
import torch
|
18
|
-
import ast
|
19
|
-
import numpy as np
|
20
|
-
from sklearn.metrics import recall_score
|
21
|
-
from collections import OrderedDict
|
22
|
-
import os
|
5
|
+
import dataclasses
|
23
6
|
import json
|
7
|
+
import os
|
24
8
|
import pickle
|
25
|
-
import dataclasses
|
26
9
|
import typing
|
27
10
|
|
11
|
+
import datasets
|
12
|
+
import numpy as np
|
13
|
+
import pandas as pd
|
28
14
|
import torch
|
29
15
|
import transformers
|
30
|
-
from transformers.models.wav2vec2.modeling_wav2vec2 import
|
31
|
-
|
32
|
-
Wav2Vec2Model,
|
33
|
-
)
|
16
|
+
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model
|
17
|
+
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel
|
34
18
|
|
19
|
+
import audeer
|
20
|
+
import audiofile
|
21
|
+
import audmetric
|
22
|
+
|
23
|
+
import nkululeko.glob_conf as glob_conf
|
24
|
+
from nkululeko.models.model import Model as BaseModel
|
25
|
+
from nkululeko.reporting.reporter import Reporter
|
35
26
|
|
36
|
-
|
27
|
+
|
28
|
+
class TunedModel(BaseModel):
|
37
29
|
|
38
30
|
is_classifier = True
|
39
31
|
|
40
32
|
def __init__(self, df_train, df_test, feats_train, feats_test):
|
41
|
-
"""Constructor taking the configuration and all dataframes"""
|
33
|
+
"""Constructor taking the configuration and all dataframes."""
|
42
34
|
super().__init__(df_train, df_test, feats_train, feats_test)
|
43
|
-
super().set_model_type("
|
35
|
+
super().set_model_type("finetuned")
|
44
36
|
self.name = "finetuned_wav2vec2"
|
45
|
-
self.model_type = "finetuned"
|
46
37
|
self.target = glob_conf.config["DATA"]["target"]
|
47
38
|
labels = glob_conf.labels
|
48
39
|
self.class_num = len(labels)
|
@@ -74,22 +65,11 @@ class Pretrained_model(BaseModel):
|
|
74
65
|
|
75
66
|
for split in ["train", "dev"]:
|
76
67
|
df = data_sources[split]
|
77
|
-
|
78
|
-
|
79
|
-
y = pd.Series(
|
80
|
-
data=df.itertuples(index=False, name=None),
|
81
|
-
index=df.index,
|
82
|
-
dtype=object,
|
83
|
-
name="labels",
|
84
|
-
)
|
85
|
-
|
68
|
+
y = df[target_name].astype("float")
|
86
69
|
y.name = "targets"
|
87
70
|
df = y.reset_index()
|
88
71
|
df.start = df.start.dt.total_seconds()
|
89
72
|
df.end = df.end.dt.total_seconds()
|
90
|
-
|
91
|
-
# print(f"{split}: {len(df)}")
|
92
|
-
|
93
73
|
ds = datasets.Dataset.from_pandas(df)
|
94
74
|
dataset[split] = ds
|
95
75
|
|
@@ -143,12 +123,6 @@ class Pretrained_model(BaseModel):
|
|
143
123
|
def set_model_type(self, type):
|
144
124
|
self.model_type = type
|
145
125
|
|
146
|
-
def is_ann(self):
|
147
|
-
if self.model_type == "ann":
|
148
|
-
return True
|
149
|
-
else:
|
150
|
-
return False
|
151
|
-
|
152
126
|
def set_testdata(self, data_df, feats_df):
|
153
127
|
self.df_test, self.feats_test = data_df, feats_df
|
154
128
|
|
@@ -207,7 +181,8 @@ class Pretrained_model(BaseModel):
|
|
207
181
|
"ACC": audmetric.accuracy,
|
208
182
|
}
|
209
183
|
|
210
|
-
truth = p.label_ids[:, 0].astype(int)
|
184
|
+
# truth = p.label_ids[:, 0].astype(int)
|
185
|
+
truth = p.label_ids
|
211
186
|
preds = p.predictions
|
212
187
|
preds = np.argmax(preds, axis=1)
|
213
188
|
scores = {}
|
@@ -216,8 +191,7 @@ class Pretrained_model(BaseModel):
|
|
216
191
|
return scores
|
217
192
|
|
218
193
|
def train(self):
|
219
|
-
"""Train the model"""
|
220
|
-
|
194
|
+
"""Train the model."""
|
221
195
|
model_root = self.util.get_path("model_dir")
|
222
196
|
log_root = os.path.join(self.util.get_exp_dir(), "log")
|
223
197
|
audeer.mkdir(log_root)
|
@@ -225,16 +199,17 @@ class Pretrained_model(BaseModel):
|
|
225
199
|
conf_file = os.path.join(self.torch_root, "config.json")
|
226
200
|
if os.path.isfile(conf_file):
|
227
201
|
self.util.debug(f"reusing finetuned model: {conf_file}")
|
228
|
-
self.load(self.run, self.
|
202
|
+
self.load(self.run, self.epoch_num)
|
229
203
|
return
|
230
204
|
targets = pd.DataFrame(self.dataset["train"]["targets"])
|
231
205
|
counts = targets[0].value_counts().sort_index()
|
232
206
|
train_weights = 1 / counts
|
233
207
|
train_weights /= train_weights.sum()
|
234
|
-
|
235
|
-
|
208
|
+
self.util.debug("train weights: {train_weights}")
|
209
|
+
criterion = torch.nn.CrossEntropyLoss(
|
236
210
|
weight=torch.Tensor(train_weights).to("cuda"),
|
237
211
|
)
|
212
|
+
# criterion = torch.nn.CrossEntropyLoss()
|
238
213
|
|
239
214
|
class Trainer(transformers.Trainer):
|
240
215
|
|
@@ -246,14 +221,12 @@ class Pretrained_model(BaseModel):
|
|
246
221
|
):
|
247
222
|
|
248
223
|
targets = inputs.pop("labels").squeeze()
|
249
|
-
|
224
|
+
targets = targets.type(torch.long)
|
250
225
|
|
251
226
|
outputs = model(**inputs)
|
252
|
-
|
227
|
+
logits = outputs[0].squeeze()
|
253
228
|
|
254
|
-
|
255
|
-
|
256
|
-
loss = loss_gender
|
229
|
+
loss = criterion(logits, targets)
|
257
230
|
|
258
231
|
return (loss, outputs) if return_outputs else loss
|
259
232
|
|
@@ -325,7 +298,7 @@ class Pretrained_model(BaseModel):
|
|
325
298
|
self.df_test[self.target].to_numpy().astype(float),
|
326
299
|
predictions,
|
327
300
|
self.run,
|
328
|
-
self.
|
301
|
+
self.epoch_num,
|
329
302
|
)
|
330
303
|
return report
|
331
304
|
|
@@ -371,13 +344,13 @@ class ModelOutput(transformers.file_utils.ModelOutput):
|
|
371
344
|
|
372
345
|
class ModelHead(torch.nn.Module):
|
373
346
|
|
374
|
-
def __init__(self, config
|
347
|
+
def __init__(self, config):
|
375
348
|
|
376
349
|
super().__init__()
|
377
350
|
|
378
351
|
self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
|
379
352
|
self.dropout = torch.nn.Dropout(config.final_dropout)
|
380
|
-
self.out_proj = torch.nn.Linear(config.hidden_size, num_labels)
|
353
|
+
self.out_proj = torch.nn.Linear(config.hidden_size, config.num_labels)
|
381
354
|
|
382
355
|
def forward(self, features, **kwargs):
|
383
356
|
|
@@ -398,7 +371,7 @@ class Model(Wav2Vec2PreTrainedModel):
|
|
398
371
|
super().__init__(config)
|
399
372
|
|
400
373
|
self.wav2vec2 = Wav2Vec2Model(config)
|
401
|
-
self.cat = ModelHead(config
|
374
|
+
self.cat = ModelHead(config)
|
402
375
|
self.init_weights()
|
403
376
|
|
404
377
|
def freeze_feature_extractor(self):
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: nkululeko
|
3
|
-
Version: 0.85.
|
3
|
+
Version: 0.85.1
|
4
4
|
Summary: Machine learning audio prediction experiments based on templates
|
5
5
|
Home-page: https://github.com/felixbur/nkululeko
|
6
6
|
Author: Felix Burkhardt
|
@@ -333,6 +333,10 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
|
|
333
333
|
Changelog
|
334
334
|
=========
|
335
335
|
|
336
|
+
Version 0.85.1
|
337
|
+
--------------
|
338
|
+
* fixed bug in model_finetuned that label_num was constant 2
|
339
|
+
|
336
340
|
Version 0.85.0
|
337
341
|
--------------
|
338
342
|
* first version with finetuning wav2vec2 layers
|
@@ -2,7 +2,7 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
|
|
2
2
|
nkululeko/aug_train.py,sha256=YhuZnS_WVWnun9G-M6g5n6rbRxoVREz6Zh7k6qprFNQ,3194
|
3
3
|
nkululeko/augment.py,sha256=4MG0apTAG5RgkuJrYEjGgDdbodZWi_HweSPNI1JJ5QA,3051
|
4
4
|
nkululeko/cacheddataset.py,sha256=lIJ6hUo5LoxSrzXtWV8mzwO7wRtUETWnOQ4ws2XfL1E,969
|
5
|
-
nkululeko/constants.py,sha256=
|
5
|
+
nkululeko/constants.py,sha256=WnTSXQjJmWE-IrXcNSEa5FFV_83-z0EOGXa9trq00uE,39
|
6
6
|
nkululeko/demo.py,sha256=8bl15Kitoesnz8oa8yrs52T6YCSOhWbbq9PnZ8Hj6D0,3232
|
7
7
|
nkululeko/demo_feats.py,sha256=sAeGFojhEj9WEDFtG3SzPBmyYJWLF2rkbpp65m8Ujo4,2025
|
8
8
|
nkululeko/demo_predictor.py,sha256=es56xbT8ifkS_vnrlb5NTZT54gNmeUtNlA4zVA_gnN8,4757
|
@@ -13,7 +13,7 @@ nkululeko/feature_extractor.py,sha256=8mssYKmo4LclVI-hiLmJEDZ0ZPyDavFG2YwtXcrGzw
|
|
13
13
|
nkululeko/file_checker.py,sha256=LoLnL8aHpW-axMQ46qbqrManTs5otG9ShpEZuz9iRSk,3474
|
14
14
|
nkululeko/filter_data.py,sha256=w-X2mhKdYr5DxDIz50E5yzO6Jmzk4jjDBoXsgOOVtcA,7222
|
15
15
|
nkululeko/glob_conf.py,sha256=KL9YJQTHvTztxo1vr25qRRgaPnx4NTg0XrdbovKGMmw,525
|
16
|
-
nkululeko/modelrunner.py,sha256=
|
16
|
+
nkululeko/modelrunner.py,sha256=iCmfJxsS2UafcikjRdUqPQuqQMOYA-Ctr3et3HeNR3c,10452
|
17
17
|
nkululeko/multidb.py,sha256=fG3VukEWP1vreVN4gB1IRXxwwg4jLftsSEYtu0o1f78,5634
|
18
18
|
nkululeko/nkuluflag.py,sha256=PGWSmZz-PiiHLgcZJAoGOI_Y-sZDVI1ksB8p5r7riWM,3725
|
19
19
|
nkululeko/nkululeko.py,sha256=Kn3s2E3yyH8cJ7z6lkMxrnqtCxTu7-qfe9Zr_ONTD5g,1968
|
@@ -89,7 +89,7 @@ nkululeko/models/model_svm.py,sha256=rsME3KvKvNG7bdE5lbvYUu85WZhaASZxxmdNDIVJRZ4
|
|
89
89
|
nkululeko/models/model_svr.py,sha256=_YZeksqB3eBENGlg3g9RwYFlk9rQQ-XCeNBKLlGGVoE,725
|
90
90
|
nkululeko/models/model_tree.py,sha256=rf16faUm4o2LJgkoYpeY998b8DQIvXZ73_m1IS3TnnE,417
|
91
91
|
nkululeko/models/model_tree_reg.py,sha256=IgQcPTE-304HQLYSKPF8Z4ot_Ur9dH01fZjS0nXke_M,428
|
92
|
-
nkululeko/models/model_tuned.py,sha256=
|
92
|
+
nkululeko/models/model_tuned.py,sha256=WJplfUK3CGLSd2mahUrPSjMvqjPfxLp99KFeZaz2AbU,15098
|
93
93
|
nkululeko/models/model_xgb.py,sha256=Thgx5ESdIok4v72mKh4plxpo4smGcKALWNCJTDScY0M,447
|
94
94
|
nkululeko/models/model_xgr.py,sha256=aGBtNGLWjOE_2rICGYGFxmT8DtnHYsIl1lIpMtghHsY,418
|
95
95
|
nkululeko/reporting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -106,8 +106,8 @@ nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
106
106
|
nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
|
107
107
|
nkululeko/utils/stats.py,sha256=1yUq0FTOyqkU8TwUocJRYdJaqMU5SlOBBRUun9STo2M,2829
|
108
108
|
nkululeko/utils/util.py,sha256=b1IHFucRNuF9Iyv5IJeK4AEg0Rga0xKG80UM5GWWdHA,13816
|
109
|
-
nkululeko-0.85.
|
110
|
-
nkululeko-0.85.
|
111
|
-
nkululeko-0.85.
|
112
|
-
nkululeko-0.85.
|
113
|
-
nkululeko-0.85.
|
109
|
+
nkululeko-0.85.1.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
|
110
|
+
nkululeko-0.85.1.dist-info/METADATA,sha256=RonY9PdKyHjwYsZ3T9TgEs1JNnY1qbMdDr-Sp6kcCW8,36591
|
111
|
+
nkululeko-0.85.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
112
|
+
nkululeko-0.85.1.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
|
113
|
+
nkululeko-0.85.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|