nkululeko 0.85.0__py3-none-any.whl → 0.85.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nkululeko/constants.py CHANGED
@@ -1,2 +1,2 @@
1
- VERSION="0.85.0"
1
+ VERSION="0.85.1"
2
2
  SAMPLING_RATE = 16000
nkululeko/modelrunner.py CHANGED
@@ -151,9 +151,9 @@ class Modelrunner:
151
151
  self.df_train, self.df_test, self.feats_train, self.feats_test
152
152
  )
153
153
  elif model_type == "finetune":
154
- from nkululeko.models.model_tuned import Pretrained_model
154
+ from nkululeko.models.model_tuned import TunedModel
155
155
 
156
- self.model = Pretrained_model(
156
+ self.model = TunedModel(
157
157
  self.df_train, self.df_test, self.feats_train, self.feats_test
158
158
  )
159
159
  elif model_type == "gmm":
@@ -1,48 +1,39 @@
1
1
  """
2
- Code based on @jwagner
2
+ Code based on @jwagner.
3
3
  """
4
4
 
5
- import audiofile
6
- import audeer
7
- import audmetric
8
- import datasets
9
- import pandas as pd
10
- import transformers
11
- from nkululeko.utils.util import Util
12
- import nkululeko.glob_conf as glob_conf
13
- from nkululeko.models.model import Model as BaseModel
14
-
15
- # import nkululeko.models.finetune_model as fm
16
- from nkululeko.reporting.reporter import Reporter
17
- import torch
18
- import ast
19
- import numpy as np
20
- from sklearn.metrics import recall_score
21
- from collections import OrderedDict
22
- import os
5
+ import dataclasses
23
6
  import json
7
+ import os
24
8
  import pickle
25
- import dataclasses
26
9
  import typing
27
10
 
11
+ import datasets
12
+ import numpy as np
13
+ import pandas as pd
28
14
  import torch
29
15
  import transformers
30
- from transformers.models.wav2vec2.modeling_wav2vec2 import (
31
- Wav2Vec2PreTrainedModel,
32
- Wav2Vec2Model,
33
- )
16
+ from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model
17
+ from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel
34
18
 
19
+ import audeer
20
+ import audiofile
21
+ import audmetric
22
+
23
+ import nkululeko.glob_conf as glob_conf
24
+ from nkululeko.models.model import Model as BaseModel
25
+ from nkululeko.reporting.reporter import Reporter
35
26
 
36
- class Pretrained_model(BaseModel):
27
+
28
+ class TunedModel(BaseModel):
37
29
 
38
30
  is_classifier = True
39
31
 
40
32
  def __init__(self, df_train, df_test, feats_train, feats_test):
41
- """Constructor taking the configuration and all dataframes"""
33
+ """Constructor taking the configuration and all dataframes."""
42
34
  super().__init__(df_train, df_test, feats_train, feats_test)
43
- super().set_model_type("ann")
35
+ super().set_model_type("finetuned")
44
36
  self.name = "finetuned_wav2vec2"
45
- self.model_type = "finetuned"
46
37
  self.target = glob_conf.config["DATA"]["target"]
47
38
  labels = glob_conf.labels
48
39
  self.class_num = len(labels)
@@ -74,22 +65,11 @@ class Pretrained_model(BaseModel):
74
65
 
75
66
  for split in ["train", "dev"]:
76
67
  df = data_sources[split]
77
- df[target_name] = df[target_name].astype("float")
78
-
79
- y = pd.Series(
80
- data=df.itertuples(index=False, name=None),
81
- index=df.index,
82
- dtype=object,
83
- name="labels",
84
- )
85
-
68
+ y = df[target_name].astype("float")
86
69
  y.name = "targets"
87
70
  df = y.reset_index()
88
71
  df.start = df.start.dt.total_seconds()
89
72
  df.end = df.end.dt.total_seconds()
90
-
91
- # print(f"{split}: {len(df)}")
92
-
93
73
  ds = datasets.Dataset.from_pandas(df)
94
74
  dataset[split] = ds
95
75
 
@@ -143,12 +123,6 @@ class Pretrained_model(BaseModel):
143
123
  def set_model_type(self, type):
144
124
  self.model_type = type
145
125
 
146
- def is_ann(self):
147
- if self.model_type == "ann":
148
- return True
149
- else:
150
- return False
151
-
152
126
  def set_testdata(self, data_df, feats_df):
153
127
  self.df_test, self.feats_test = data_df, feats_df
154
128
 
@@ -207,7 +181,8 @@ class Pretrained_model(BaseModel):
207
181
  "ACC": audmetric.accuracy,
208
182
  }
209
183
 
210
- truth = p.label_ids[:, 0].astype(int)
184
+ # truth = p.label_ids[:, 0].astype(int)
185
+ truth = p.label_ids
211
186
  preds = p.predictions
212
187
  preds = np.argmax(preds, axis=1)
213
188
  scores = {}
@@ -216,8 +191,7 @@ class Pretrained_model(BaseModel):
216
191
  return scores
217
192
 
218
193
  def train(self):
219
- """Train the model"""
220
-
194
+ """Train the model."""
221
195
  model_root = self.util.get_path("model_dir")
222
196
  log_root = os.path.join(self.util.get_exp_dir(), "log")
223
197
  audeer.mkdir(log_root)
@@ -225,16 +199,17 @@ class Pretrained_model(BaseModel):
225
199
  conf_file = os.path.join(self.torch_root, "config.json")
226
200
  if os.path.isfile(conf_file):
227
201
  self.util.debug(f"reusing finetuned model: {conf_file}")
228
- self.load(self.run, self.epoch)
202
+ self.load(self.run, self.epoch_num)
229
203
  return
230
204
  targets = pd.DataFrame(self.dataset["train"]["targets"])
231
205
  counts = targets[0].value_counts().sort_index()
232
206
  train_weights = 1 / counts
233
207
  train_weights /= train_weights.sum()
234
- # print(train_weights)
235
- criterion_gender = torch.nn.CrossEntropyLoss(
208
+ self.util.debug("train weights: {train_weights}")
209
+ criterion = torch.nn.CrossEntropyLoss(
236
210
  weight=torch.Tensor(train_weights).to("cuda"),
237
211
  )
212
+ # criterion = torch.nn.CrossEntropyLoss()
238
213
 
239
214
  class Trainer(transformers.Trainer):
240
215
 
@@ -246,14 +221,12 @@ class Pretrained_model(BaseModel):
246
221
  ):
247
222
 
248
223
  targets = inputs.pop("labels").squeeze()
249
- targets_gender = targets.type(torch.long)
224
+ targets = targets.type(torch.long)
250
225
 
251
226
  outputs = model(**inputs)
252
- logits_gender = outputs[0].squeeze()
227
+ logits = outputs[0].squeeze()
253
228
 
254
- loss_gender = criterion_gender(logits_gender, targets_gender)
255
-
256
- loss = loss_gender
229
+ loss = criterion(logits, targets)
257
230
 
258
231
  return (loss, outputs) if return_outputs else loss
259
232
 
@@ -325,7 +298,7 @@ class Pretrained_model(BaseModel):
325
298
  self.df_test[self.target].to_numpy().astype(float),
326
299
  predictions,
327
300
  self.run,
328
- self.epoch,
301
+ self.epoch_num,
329
302
  )
330
303
  return report
331
304
 
@@ -371,13 +344,13 @@ class ModelOutput(transformers.file_utils.ModelOutput):
371
344
 
372
345
  class ModelHead(torch.nn.Module):
373
346
 
374
- def __init__(self, config, num_labels):
347
+ def __init__(self, config):
375
348
 
376
349
  super().__init__()
377
350
 
378
351
  self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
379
352
  self.dropout = torch.nn.Dropout(config.final_dropout)
380
- self.out_proj = torch.nn.Linear(config.hidden_size, num_labels)
353
+ self.out_proj = torch.nn.Linear(config.hidden_size, config.num_labels)
381
354
 
382
355
  def forward(self, features, **kwargs):
383
356
 
@@ -398,7 +371,7 @@ class Model(Wav2Vec2PreTrainedModel):
398
371
  super().__init__(config)
399
372
 
400
373
  self.wav2vec2 = Wav2Vec2Model(config)
401
- self.cat = ModelHead(config, 2)
374
+ self.cat = ModelHead(config)
402
375
  self.init_weights()
403
376
 
404
377
  def freeze_feature_extractor(self):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nkululeko
3
- Version: 0.85.0
3
+ Version: 0.85.1
4
4
  Summary: Machine learning audio prediction experiments based on templates
5
5
  Home-page: https://github.com/felixbur/nkululeko
6
6
  Author: Felix Burkhardt
@@ -333,6 +333,10 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
333
333
  Changelog
334
334
  =========
335
335
 
336
+ Version 0.85.1
337
+ --------------
338
+ * fixed bug in model_finetuned where the number of labels was hard-coded to 2
339
+
336
340
  Version 0.85.0
337
341
  --------------
338
342
  * first version with finetuning wav2vec2 layers
@@ -2,7 +2,7 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
2
2
  nkululeko/aug_train.py,sha256=YhuZnS_WVWnun9G-M6g5n6rbRxoVREz6Zh7k6qprFNQ,3194
3
3
  nkululeko/augment.py,sha256=4MG0apTAG5RgkuJrYEjGgDdbodZWi_HweSPNI1JJ5QA,3051
4
4
  nkululeko/cacheddataset.py,sha256=lIJ6hUo5LoxSrzXtWV8mzwO7wRtUETWnOQ4ws2XfL1E,969
5
- nkululeko/constants.py,sha256=flWSUNQs4r0X0SgoR1I72Mk49cRUdpBN8Zng8sySFBE,39
5
+ nkululeko/constants.py,sha256=WnTSXQjJmWE-IrXcNSEa5FFV_83-z0EOGXa9trq00uE,39
6
6
  nkululeko/demo.py,sha256=8bl15Kitoesnz8oa8yrs52T6YCSOhWbbq9PnZ8Hj6D0,3232
7
7
  nkululeko/demo_feats.py,sha256=sAeGFojhEj9WEDFtG3SzPBmyYJWLF2rkbpp65m8Ujo4,2025
8
8
  nkululeko/demo_predictor.py,sha256=es56xbT8ifkS_vnrlb5NTZT54gNmeUtNlA4zVA_gnN8,4757
@@ -13,7 +13,7 @@ nkululeko/feature_extractor.py,sha256=8mssYKmo4LclVI-hiLmJEDZ0ZPyDavFG2YwtXcrGzw
13
13
  nkululeko/file_checker.py,sha256=LoLnL8aHpW-axMQ46qbqrManTs5otG9ShpEZuz9iRSk,3474
14
14
  nkululeko/filter_data.py,sha256=w-X2mhKdYr5DxDIz50E5yzO6Jmzk4jjDBoXsgOOVtcA,7222
15
15
  nkululeko/glob_conf.py,sha256=KL9YJQTHvTztxo1vr25qRRgaPnx4NTg0XrdbovKGMmw,525
16
- nkululeko/modelrunner.py,sha256=pPhvTh1rIrFQg5Ox9T1KoFJ4wRcLCmJl7LFud2DA41w,10464
16
+ nkululeko/modelrunner.py,sha256=iCmfJxsS2UafcikjRdUqPQuqQMOYA-Ctr3et3HeNR3c,10452
17
17
  nkululeko/multidb.py,sha256=fG3VukEWP1vreVN4gB1IRXxwwg4jLftsSEYtu0o1f78,5634
18
18
  nkululeko/nkuluflag.py,sha256=PGWSmZz-PiiHLgcZJAoGOI_Y-sZDVI1ksB8p5r7riWM,3725
19
19
  nkululeko/nkululeko.py,sha256=Kn3s2E3yyH8cJ7z6lkMxrnqtCxTu7-qfe9Zr_ONTD5g,1968
@@ -89,7 +89,7 @@ nkululeko/models/model_svm.py,sha256=rsME3KvKvNG7bdE5lbvYUu85WZhaASZxxmdNDIVJRZ4
89
89
  nkululeko/models/model_svr.py,sha256=_YZeksqB3eBENGlg3g9RwYFlk9rQQ-XCeNBKLlGGVoE,725
90
90
  nkululeko/models/model_tree.py,sha256=rf16faUm4o2LJgkoYpeY998b8DQIvXZ73_m1IS3TnnE,417
91
91
  nkululeko/models/model_tree_reg.py,sha256=IgQcPTE-304HQLYSKPF8Z4ot_Ur9dH01fZjS0nXke_M,428
92
- nkululeko/models/model_tuned.py,sha256=zmagIE3QHP67_XJCx5r7ZXBojsp6SC8IS-L3XRWmCEk,15650
92
+ nkululeko/models/model_tuned.py,sha256=WJplfUK3CGLSd2mahUrPSjMvqjPfxLp99KFeZaz2AbU,15098
93
93
  nkululeko/models/model_xgb.py,sha256=Thgx5ESdIok4v72mKh4plxpo4smGcKALWNCJTDScY0M,447
94
94
  nkululeko/models/model_xgr.py,sha256=aGBtNGLWjOE_2rICGYGFxmT8DtnHYsIl1lIpMtghHsY,418
95
95
  nkululeko/reporting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -106,8 +106,8 @@ nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
106
106
  nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
107
107
  nkululeko/utils/stats.py,sha256=1yUq0FTOyqkU8TwUocJRYdJaqMU5SlOBBRUun9STo2M,2829
108
108
  nkululeko/utils/util.py,sha256=b1IHFucRNuF9Iyv5IJeK4AEg0Rga0xKG80UM5GWWdHA,13816
109
- nkululeko-0.85.0.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
110
- nkululeko-0.85.0.dist-info/METADATA,sha256=Zt3H0FmIXOJvzyLOI0aC8VfvjrdIkd4uNvb937luo_k,36499
111
- nkululeko-0.85.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
112
- nkululeko-0.85.0.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
113
- nkululeko-0.85.0.dist-info/RECORD,,
109
+ nkululeko-0.85.1.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
110
+ nkululeko-0.85.1.dist-info/METADATA,sha256=RonY9PdKyHjwYsZ3T9TgEs1JNnY1qbMdDr-Sp6kcCW8,36591
111
+ nkululeko-0.85.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
112
+ nkululeko-0.85.1.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
113
+ nkululeko-0.85.1.dist-info/RECORD,,