nkululeko 0.85.2__py3-none-any.whl → 0.86.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registries.
nkululeko/constants.py CHANGED
@@ -1,2 +1,2 @@
- VERSION="0.85.2"
+ VERSION="0.86.1"
  SAMPLING_RATE = 16000
nkululeko/experiment.py CHANGED
@@ -30,15 +30,14 @@ from nkululeko.utils.util import Util


  class Experiment:
- """Main class specifying an experiment"""
+ """Main class specifying an experiment."""

  def __init__(self, config_obj):
- """
- Parameters
- ----------
- config_obj : a config parser object that sets the experiment parameters and being set as a global object.
- """
+ """Constructor.

+ Args:
+ - config_obj : a config parser object that sets the experiment parameters and being set as a global object.
+ """
  self.set_globals(config_obj)
  self.name = glob_conf.config["EXP"]["name"]
  self.root = os.path.join(glob_conf.config["EXP"]["root"], "")
@@ -73,8 +72,9 @@ class Experiment:
  if self.util.config_val("REPORT", "latex", False):
  self.report.export_latex()

- def get_name(self):
- return self.util.get_exp_name()
+ # moved to util
+ # def get_name(self):
+ # return self.util.get_exp_name()

  def set_globals(self, config_obj):
  """install a config object in the global space"""
@@ -109,15 +109,13 @@ class Experiment:
  # print keys/column
  dbs = ",".join(list(self.datasets.keys()))
  labels = self.util.config_val("DATA", "labels", False)
- auto_labels = list(
- next(iter(self.datasets.values())).df[self.target].unique()
- )
+ auto_labels = list(next(iter(self.datasets.values())).df[self.target].unique())
  if labels:
  self.labels = ast.literal_eval(labels)
  self.util.debug(f"Target labels (from config): {labels}")
  else:
  self.labels = auto_labels
- self.util.debug(f"Target labels (from database): {auto_labels}")
+ self.util.debug(f"Target labels (from database): {auto_labels}")
  glob_conf.set_labels(self.labels)
  self.util.debug(f"loaded databases {dbs}")

@@ -160,8 +158,7 @@ class Experiment:
  data.split()
  data.prepare_labels()
  self.df_test = pd.concat(
- [self.df_test, self.util.make_segmented_index(
- data.df_test)]
+ [self.df_test, self.util.make_segmented_index(data.df_test)]
  )
  self.df_test.is_labeled = data.is_labeled
  self.df_test.got_gender = self.got_gender
@@ -262,8 +259,7 @@ class Experiment:
  test_cats = self.df_test[self.target].unique()
  else:
  # if there is no target, copy a dummy label
- self.df_test = self._add_random_target(
- self.df_test).astype("str")
+ self.df_test = self._add_random_target(self.df_test).astype("str")
  train_cats = self.df_train[self.target].unique()
  # print(f"df_train: {pd.DataFrame(self.df_train[self.target])}")
  # print(f"train_cats with target {self.target}: {train_cats}")
@@ -271,8 +267,7 @@ class Experiment:
  if type(test_cats) == np.ndarray:
  self.util.debug(f"Categories test (nd.array): {test_cats}")
  else:
- self.util.debug(
- f"Categories test (list): {list(test_cats)}")
+ self.util.debug(f"Categories test (list): {list(test_cats)}")
  if type(train_cats) == np.ndarray:
  self.util.debug(f"Categories train (nd.array): {train_cats}")
  else:
@@ -295,8 +290,7 @@ class Experiment:

  target_factor = self.util.config_val("DATA", "target_divide_by", False)
  if target_factor:
- self.df_test[self.target] = self.df_test[self.target] / \
- float(target_factor)
+ self.df_test[self.target] = self.df_test[self.target] / float(target_factor)
  self.df_train[self.target] = self.df_train[self.target] / float(
  target_factor
  )
@@ -319,16 +313,14 @@ class Experiment:
  def plot_distribution(self, df_labels):
  """Plot the distribution of samples and speaker per target class and biological sex"""
  plot = Plots()
- sample_selection = self.util.config_val(
- "EXPL", "sample_selection", "all")
+ sample_selection = self.util.config_val("EXPL", "sample_selection", "all")
  plot.plot_distributions(df_labels)
  if self.got_speaker:
  plot.plot_distributions_speaker(df_labels)

  def extract_test_feats(self):
  self.feats_test = pd.DataFrame()
- feats_name = "_".join(ast.literal_eval(
- glob_conf.config["DATA"]["tests"]))
+ feats_name = "_".join(ast.literal_eval(glob_conf.config["DATA"]["tests"]))
  feats_types = self.util.config_val_list("FEATS", "type", ["os"])
  self.feature_extractor = FeatureExtractor(
  self.df_test, feats_types, feats_name, "test"
@@ -345,8 +337,7 @@ class Experiment:

  """
  df_train, df_test = self.df_train, self.df_test
- feats_name = "_".join(ast.literal_eval(
- glob_conf.config["DATA"]["databases"]))
+ feats_name = "_".join(ast.literal_eval(glob_conf.config["DATA"]["databases"]))
  self.feats_test, self.feats_train = pd.DataFrame(), pd.DataFrame()
  feats_types = self.util.config_val_list("FEATS", "type", [])
  # for some models no features are needed
@@ -380,20 +371,22 @@ class Experiment:
  f"test feats ({self.feats_test.shape[0]}) != test labels"
  f" ({self.df_test.shape[0]})"
  )
- self.df_test = self.df_test[self.df_test.index.isin(
- self.feats_test.index)]
- self.util.warn(f"mew test labels shape: {self.df_test.shape[0]}")
+ self.df_test = self.df_test[self.df_test.index.isin(self.feats_test.index)]
+ self.util.warn(f"new test labels shape: {self.df_test.shape[0]}")

  self._check_scale()
+ # store = self.util.get_path("store")
+ # store_format = self.util.config_val("FEATS", "store_format", "pkl")
+ # storage = f"{store}test_feats.{store_format}"
+ # self.util.write_store(self.feats_test, storage, store_format)
+ # storage = f"{store}train_feats.{store_format}"
+ # self.util.write_store(self.feats_train, storage, store_format)

  def augment(self):
- """
- Augment the selected samples
- """
+ """Augment the selected samples."""
  from nkululeko.augmenting.augmenter import Augmenter

- sample_selection = self.util.config_val(
- "AUGMENT", "sample_selection", "all")
+ sample_selection = self.util.config_val("AUGMENT", "sample_selection", "all")
  if sample_selection == "all":
  df = pd.concat([self.df_train, self.df_test])
  elif sample_selection == "train":
@@ -488,8 +481,7 @@ class Experiment:
  """
  from nkululeko.augmenting.randomsplicer import Randomsplicer

- sample_selection = self.util.config_val(
- "AUGMENT", "sample_selection", "all")
+ sample_selection = self.util.config_val("AUGMENT", "sample_selection", "all")
  if sample_selection == "all":
  df = pd.concat([self.df_train, self.df_test])
  elif sample_selection == "train":
@@ -510,8 +502,7 @@ class Experiment:
  plot_feats = eval(
  self.util.config_val("EXPL", "feature_distributions", "False")
  )
- sample_selection = self.util.config_val(
- "EXPL", "sample_selection", "all")
+ sample_selection = self.util.config_val("EXPL", "sample_selection", "all")
  # get the data labels
  if sample_selection == "all":
  df_labels = pd.concat([self.df_train, self.df_test])
@@ -574,8 +565,7 @@ class Experiment:
  for scat_target in scat_targets:
  if self.util.is_categorical(df_labels[scat_target]):
  for scatter in scatters:
- plots.scatter_plot(
- df_feats, df_labels, scat_target, scatter)
+ plots.scatter_plot(df_feats, df_labels, scat_target, scatter)
  else:
  self.util.debug(
  f"{self.name}: binning continuous variable to categories"
@@ -590,6 +580,8 @@ class Experiment:
  )

  def _check_scale(self):
+ self.util.save_to_store(self.feats_train, "feats_train")
+ self.util.save_to_store(self.feats_test, "feats_test")
  scale_feats = self.util.config_val("FEATS", "scale", False)
  # print the scale
  self.util.debug(f"scaler: {scale_feats}")
@@ -664,8 +656,7 @@ class Experiment:
  preds = best.preds
  speakers = self.df_test.speaker.values
  print(f"{len(truths)} {len(preds)} {len(speakers) }")
- df = pd.DataFrame(
- data={"truth": truths, "pred": preds, "speaker": speakers})
+ df = pd.DataFrame(data={"truth": truths, "pred": preds, "speaker": speakers})
  plot_name = "result_combined_per_speaker"
  self.util.debug(
  f"plotting speaker combination ({function}) confusion matrix to"
nkululeko/feat_extract/feats_opensmile.py CHANGED
@@ -65,28 +65,28 @@ class Opensmileset(Featureset):
  feats = smile.process_signal(signal, sr)
  return feats.to_numpy()

- def filter(self):
- # use only the features that are indexed in the target dataframes
- self.df = self.df[self.df.index.isin(self.data_df.index)]
- try:
- # use only some features
- selected_features = ast.literal_eval(
- glob_conf.config["FEATS"]["os.features"]
- )
- self.util.debug(f"selecting features from opensmile: {selected_features}")
- sel_feats_df = pd.DataFrame()
- hit = False
- for feat in selected_features:
- try:
- sel_feats_df[feat] = self.df[feat]
- hit = True
- except KeyError:
- pass
- if hit:
- self.df = sel_feats_df
- self.util.debug(
- "new feats shape after selecting opensmile features:"
- f" {self.df.shape}"
- )
- except KeyError:
- pass
+ # def filter(self):
+ # # use only the features that are indexed in the target dataframes
+ # self.df = self.df[self.df.index.isin(self.data_df.index)]
+ # try:
+ # # use only some features
+ # selected_features = ast.literal_eval(
+ # glob_conf.config["FEATS"]["os.features"]
+ # )
+ # self.util.debug(f"selecting features from opensmile: {selected_features}")
+ # sel_feats_df = pd.DataFrame()
+ # hit = False
+ # for feat in selected_features:
+ # try:
+ # sel_feats_df[feat] = self.df[feat]
+ # hit = True
+ # except KeyError:
+ # pass
+ # if hit:
+ # self.df = sel_feats_df
+ # self.util.debug(
+ # "new feats shape after selecting opensmile features:"
+ # f" {self.df.shape}"
+ # )
+ # except KeyError:
+ # pass
nkululeko/feat_extract/featureset.py CHANGED
@@ -15,7 +15,7 @@ class Featureset:
  self.name = name
  self.data_df = data_df
  self.util = Util("featureset")
- self.feats_types = feats_type
+ self.feats_type = feats_type

  def extract(self):
  pass
@@ -25,8 +25,7 @@ class Featureset:
  self.df = self.df[self.df.index.isin(self.data_df.index)]
  try:
  # use only some features
- selected_features = ast.literal_eval(
- glob_conf.config["FEATS"]["features"])
+ selected_features = ast.literal_eval(glob_conf.config["FEATS"]["features"])
  self.util.debug(f"selecting features: {selected_features}")
  sel_feats_df = pd.DataFrame()
  hit = False
@@ -35,11 +34,12 @@ class Featureset:
  sel_feats_df[feat] = self.df[feat]
  hit = True
  except KeyError:
+ self.util.warn(f"non existent feature in {self.feats_type}: {feat}")
  pass
  if hit:
  self.df = sel_feats_df
  self.util.debug(
- f"new feats shape after selecting features: {self.df.shape}"
+ f"new feats shape after selecting features for {self.feats_type}: {self.df.shape}"
  )
  except KeyError:
  pass
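Featureset.filter() above selects feature columns by evaluating a Python-literal list from the FEATS config section; with the Opensmileset.filter() override now removed, this generic path also covers opensmile features. A minimal sketch of a matching configuration, using configparser as nkululeko does; the two feature names are illustrative placeholders:

    # Sketch: a FEATS section that Featureset.filter() would act on.
    # Feature names are hypothetical placeholders.
    import configparser

    config = configparser.ConfigParser()
    config["FEATS"] = {
        "type": "['os']",
        "features": "['F0semitoneFrom27.5Hz_sma3nz_amean', 'loudness_sma3_amean']",
    }
    # filter() then runs ast.literal_eval(config["FEATS"]["features"]) and,
    # since 0.86, warns about any listed feature missing from the extracted set.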
nkululeko/models/model_tuned.py CHANGED
@@ -1,6 +1,4 @@
- """
- Code based on @jwagner.
- """
+ """Code based on @jwagner."""

  import dataclasses
  import json
@@ -27,8 +25,6 @@ from nkululeko.reporting.reporter import Reporter

  class TunedModel(BaseModel):

- is_classifier = True
-
  def __init__(self, df_train, df_test, feats_train, feats_test):
  """Constructor taking the configuration and all dataframes."""
  super().__init__(df_train, df_test, feats_train, feats_test)
@@ -37,25 +33,48 @@ class TunedModel(BaseModel):
  self.target = glob_conf.config["DATA"]["target"]
  labels = glob_conf.labels
  self.class_num = len(labels)
- device = self.util.config_val("MODEL", "device", "cpu")
- self.batch_size = int(self.util.config_val("MODEL", "batch_size", "8"))
- if device != "cpu":
- self.util.debug(f"running on device {device}")
+ device = self.util.config_val("MODEL", "device", False)
+ if not device:
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ else:
+ self.device = device
+ if self.device != "cpu":
  os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
- os.environ["CUDA_VISIBLE_DEVICES"] = device
+ os.environ["CUDA_VISIBLE_DEVICES"] = self.device
+ self.util.debug(f"running on device {self.device}")
+ self.is_classifier = self.util.exp_is_classification()
+ if self.is_classifier:
+ self.measure = "uar"
+ else:
+ self.measure = self.util.config_val("MODEL", "measure", "ccc")
+ self.util.debug(f"evaluation metrics: {self.measure}")
+ self.batch_size = int(self.util.config_val("MODEL", "batch_size", "8"))
+ self.util.debug(f"batch size: {self.batch_size}")
+ self.learning_rate = float(
+ self.util.config_val("MODEL", "learning_rate", 0.0001)
+ )
  self.df_train, self.df_test = df_train, df_test
  self.epoch_num = int(self.util.config_val("EXP", "epochs", 1))
-
+ drop = self.util.config_val("MODEL", "drop", False)
+ self.drop = 0.1
+ if drop:
+ self.drop = float(drop)
+ self.util.debug(f"init: training with dropout: {self.drop}")
  self._init_model()

  def _init_model(self):
  model_path = "facebook/wav2vec2-large-robust-ft-swbd-300h"
+ pretrained_model = self.util.config_val(
+ "MODEL", "pretrained_model", model_path)
  self.num_layers = None
  self.sampling_rate = 16000
  self.max_duration_sec = 8.0
  self.accumulation_steps = 4
- # create dataset

+ # print finetuning information via debug
+ self.util.debug(f"Finetuning from model: {pretrained_model}")
+
+ # create dataset
  dataset = {}
  target_name = glob_conf.target
  data_sources = {
@@ -76,22 +95,34 @@ class TunedModel(BaseModel):
  self.dataset = datasets.DatasetDict(dataset)

  # load pre-trained model
- le = glob_conf.label_encoder
- mapping = dict(zip(le.classes_, range(len(le.classes_))))
- target_mapping = {k: int(v) for k, v in mapping.items()}
- target_mapping_reverse = {value: key for key, value in target_mapping.items()}
-
- self.config = transformers.AutoConfig.from_pretrained(
- model_path,
- num_labels=len(target_mapping),
- label2id=target_mapping,
- id2label=target_mapping_reverse,
- finetuning_task=target_name,
- )
+ if self.is_classifier:
+ self.util.debug(f"Task is classification.")
+ le = glob_conf.label_encoder
+ mapping = dict(zip(le.classes_, range(len(le.classes_))))
+ target_mapping = {k: int(v) for k, v in mapping.items()}
+ target_mapping_reverse = {
+ value: key for key, value in target_mapping.items()
+ }
+ self.config = transformers.AutoConfig.from_pretrained(
+ pretrained_model,
+ num_labels=len(target_mapping),
+ label2id=target_mapping,
+ id2label=target_mapping_reverse,
+ finetuning_task=target_name,
+ )
+ else:
+ self.util.debug(f"Task is regression.")
+ self.config = transformers.AutoConfig.from_pretrained(
+ pretrained_model,
+ num_labels=1,
+ finetuning_task=target_name,
+ )
  if self.num_layers is not None:
  self.config.num_hidden_layers = self.num_layers
+ self.config.final_dropout = self.drop
  setattr(self.config, "sampling_rate", self.sampling_rate)
  setattr(self.config, "data", self.util.get_data_name())
+ setattr(self.config, "is_classifier", self.is_classifier)

  vocab_dict = {}
  with open("vocab.json", "w") as vocab_file:
@@ -113,7 +144,7 @@ class TunedModel(BaseModel):
  assert self.processor.feature_extractor.sampling_rate == self.sampling_rate

  self.model = Model.from_pretrained(
- model_path,
+ pretrained_model,
  config=self.config,
  )
  self.model.freeze_feature_extractor()
@@ -170,7 +201,7 @@ class TunedModel(BaseModel):
  return_tensors="pt",
  )

- batch["labels"] = torch.tensor(targets)
+ batch["labels"] = torch.Tensor(targets)

  return batch

@@ -180,14 +211,25 @@ class TunedModel(BaseModel):
  "UAR": audmetric.unweighted_average_recall,
  "ACC": audmetric.accuracy,
  }
+ metrics_reg = {
+ "PCC": audmetric.pearson_cc,
+ "CCC": audmetric.concordance_cc,
+ "MSE": audmetric.mean_squared_error,
+ "MAE": audmetric.mean_absolute_error,
+ }

  # truth = p.label_ids[:, 0].astype(int)
  truth = p.label_ids
  preds = p.predictions
  preds = np.argmax(preds, axis=1)
  scores = {}
- for name, metric in metrics.items():
- scores[f"{name}"] = metric(truth, preds)
+ if self.is_classifier:
+ for name, metric in metrics.items():
+ scores[f"{name}"] = metric(truth, preds)
+ else:
+ for name, metric in metrics_reg.items():
+ scores[f"{name}"] = metric(truth, preds)
+
  return scores

  def train(self):
@@ -203,23 +245,27 @@ class TunedModel(BaseModel):
  return
  targets = pd.DataFrame(self.dataset["train"]["targets"])
  counts = targets[0].value_counts().sort_index()
- train_weights = 1 / counts
- train_weights /= train_weights.sum()
- self.util.debug("train weights: {train_weights}")
- criterion = torch.nn.CrossEntropyLoss(
- weight=torch.Tensor(train_weights).to("cuda"),
- )
- # criterion = torch.nn.CrossEntropyLoss()

- class Trainer(transformers.Trainer):
+ if self.is_classifier:
+ train_weights = 1 / counts
+ train_weights /= train_weights.sum()
+ self.util.debug(f"train weights: {train_weights}")
+ criterion = torch.nn.CrossEntropyLoss(
+ weight=torch.Tensor(train_weights).to("cuda"),
+ )
+ else:
+ criterion = ConcordanceCorCoeff()
+
+ # set push_to_hub value, default false
+ push = self.util.config_val("MODEL", "push_to_hub", False)

+ class Trainer(transformers.Trainer):
  def compute_loss(
  self,
  model,
  inputs,
  return_outputs=False,
  ):
-
  targets = inputs.pop("labels").squeeze()
  targets = targets.type(torch.long)

@@ -236,7 +282,8 @@ class TunedModel(BaseModel):
  // 5
  )
  num_steps = max(1, num_steps)
- # print(num_steps)
+
+ metrics_for_best_model = self.measure.upper()

  training_args = transformers.TrainingArguments(
  output_dir=model_root,
@@ -246,17 +293,20 @@ class TunedModel(BaseModel):
  gradient_accumulation_steps=self.accumulation_steps,
  evaluation_strategy="steps",
  num_train_epochs=self.epoch_num,
- fp16=True,
+ fp16=self.device == "cuda",
  save_steps=num_steps,
  eval_steps=num_steps,
  logging_steps=num_steps,
- learning_rate=1e-4,
+ logging_strategy="epoch",
+ learning_rate=self.learning_rate,
  save_total_limit=2,
- metric_for_best_model="UAR",
+ metric_for_best_model=metrics_for_best_model,
  greater_is_better=True,
  load_best_model_at_end=True,
  remove_unused_columns=False,
  report_to="none",
+ push_to_hub=push,
+ hub_model_id=f"{self.util.get_name()}",
  )

  trainer = Trainer(
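Taken together, the hunks above move the finetuning hyperparameters into the configuration: device, evaluation metric, batch size, learning rate, dropout and hub upload are all read via config_val from the MODEL section. A sketch of these options; the option names are grounded in the config_val calls above, while the values are purely illustrative:

    # Sketch: MODEL options read by TunedModel in the hunks above.
    import configparser

    config = configparser.ConfigParser()
    config["MODEL"] = {
        "pretrained_model": "facebook/wav2vec2-large-robust-ft-swbd-300h",
        "device": "cuda",        # auto-detected via torch.cuda.is_available() if omitted
        "measure": "ccc",        # regression metric; classification always uses uar
        "batch_size": "8",
        "learning_rate": "0.0001",
        "drop": "0.1",           # becomes the model head's final_dropout
        "push_to_hub": "True",   # forwarded to transformers.TrainingArguments
    }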
@@ -271,6 +321,7 @@ class TunedModel(BaseModel):
  )
  trainer.train()
  trainer.save_model(self.torch_root)
+ self.util.debug(f"saved best model to {self.torch_root}")
  self.load(self.run, self.epoch)

  def get_predictions(self):
@@ -305,7 +356,7 @@ class TunedModel(BaseModel):
  def predict_sample(self, signal):
  """Predict one sample"""
  prediction = {}
- if self.util.exp_is_classification():
+ if self.is_classifier:
  # get the class probabilities
  predictions = self.model.predict(signal)
  # pred = self.clf.predict(features)
@@ -337,8 +388,19 @@ class TunedModel(BaseModel):
  @dataclasses.dataclass
  class ModelOutput(transformers.file_utils.ModelOutput):

- logits_cat: torch.FloatTensor = None
+ logits: torch.FloatTensor = None
+ hidden_states: typing.Tuple[torch.FloatTensor] = None
+ cnn_features: torch.FloatTensor = None
+
+
+ @dataclasses.dataclass
+ class ModelOutputReg(transformers.file_utils.ModelOutput):
+
+ logits: torch.FloatTensor
  hidden_states: typing.Tuple[torch.FloatTensor] = None
+ attentions: typing.Tuple[torch.FloatTensor] = None
+ logits_framewise: torch.FloatTensor = None
+ hidden_states_framewise: torch.FloatTensor = None
  cnn_features: torch.FloatTensor = None


@@ -368,10 +430,14 @@ class Model(Wav2Vec2PreTrainedModel):

  def __init__(self, config):

+ if not hasattr(config, "add_adapter"):
+ setattr(config, "add_adapter", False)
+
  super().__init__(config)

  self.wav2vec2 = Wav2Vec2Model(config)
- self.cat = ModelHead(config)
+ self.head = ModelHead(config)
+ self.is_classifier = config.is_classifier
  self.init_weights()

  def freeze_feature_extractor(self):
@@ -407,39 +473,44 @@ class Model(Wav2Vec2PreTrainedModel):
  labels=None,
  return_hidden=False,
  ):
-
  outputs = self.wav2vec2(
  input_values,
  attention_mask=attention_mask,
  )
-
  cnn_features = outputs.extract_features
  hidden_states_framewise = outputs.last_hidden_state
  hidden_states = self.pooling(
  hidden_states_framewise,
  attention_mask,
  )
- logits_cat = self.cat(hidden_states)
-
+ logits = self.head(hidden_states)
  if not self.training:
- logits_cat = torch.softmax(logits_cat, dim=1)
+ logits = torch.softmax(logits, dim=1)

  if return_hidden:
-
  # make time last axis
  cnn_features = torch.transpose(cnn_features, 1, 2)
-
- return ModelOutput(
- logits_cat=logits_cat,
- hidden_states=hidden_states,
- cnn_features=cnn_features,
- )
-
+ if self.is_classifier:
+ return ModelOutput(
+ logits=logits,
+ hidden_states=hidden_states,
+ cnn_features=cnn_features,
+ )
+ else:
+ return ModelOutputReg(
+ logits=logits,
+ hidden_states=hidden_states,
+ cnn_features=cnn_features,
+ )
  else:
-
- return ModelOutput(
- logits_cat=logits_cat,
- )
+ if self.is_classifier:
+ return ModelOutput(
+ logits=logits,
+ )
+ else:
+ return ModelOutputReg(
+ logits=logits,
+ )

  def predict(self, signal):
  result = self(torch.from_numpy(signal))
@@ -447,33 +518,31 @@ class Model(Wav2Vec2PreTrainedModel):
  return result


- class ModelWithPreProcessing(Model):
-
- def __init__(self, config):
- super().__init__(config)
+ class ConcordanceCorCoeff(torch.nn.Module):

- def forward(
- self,
- input_values,
- ):
- # Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm():
- # normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
-
- mean = input_values.mean()
-
- # var = input_values.var()
- # raises: onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for the node ReduceProd_3:ReduceProd(11)
-
- var = torch.square(input_values - mean).mean()
- input_values = (input_values - mean) / torch.sqrt(var + 1e-7)
-
- output = super().forward(
- input_values,
- return_hidden=True,
+ def __init__(self):
+ super().__init__()
+ self.mean = torch.mean
+ self.var = torch.var
+ self.sum = torch.sum
+ self.sqrt = torch.sqrt
+ self.std = torch.std
+
+ def forward(self, prediction, ground_truth):
+ ground_truth = ground_truth.float()
+ mean_gt = self.mean(ground_truth, 0)
+ mean_pred = self.mean(prediction, 0)
+ var_gt = self.var(ground_truth, 0)
+ var_pred = self.var(prediction, 0)
+ v_pred = prediction - mean_pred
+ v_gt = ground_truth - mean_gt
+ cor = self.sum(v_pred * v_gt) / (
+ self.sqrt(self.sum(v_pred**2)) * self.sqrt(self.sum(v_gt**2))
  )
+ sd_gt = self.std(ground_truth)
+ sd_pred = self.std(prediction)
+ numerator = 2 * cor * sd_gt * sd_pred
+ denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
+ ccc = numerator / denominator

- return (
- output.hidden_states,
- output.logits_cat,
- output.cnn_features,
- )
+ return 1 - ccc
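For reference, the ConcordanceCorCoeff module introduced above computes Lin's concordance correlation coefficient between prediction and ground truth and returns it as a loss:

    \mathrm{CCC} = \frac{2\,\rho\,\sigma_{y}\,\sigma_{\hat{y}}}{\sigma_{y}^{2} + \sigma_{\hat{y}}^{2} + (\mu_{y} - \mu_{\hat{y}})^{2}}, \qquad \mathcal{L} = 1 - \mathrm{CCC}

where \rho is the Pearson correlation (cor in the code), \sigma are the standard deviations (sd_gt, sd_pred), \mu the means, and the denominator uses the variances var_gt and var_pred.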
nkululeko/resample.py CHANGED
@@ -11,22 +11,32 @@ from nkululeko.utils.util import Util

  from nkululeko.constants import VERSION
  from nkululeko.experiment import Experiment
+ from nkululeko.utils.files import find_files


  def main(src_dir):
  parser = argparse.ArgumentParser(
- description="Call the nkululeko RESAMPLE framework.")
+ description="Call the nkululeko RESAMPLE framework."
+ )
  parser.add_argument("--config", default=None,
  help="The base configuration")
  parser.add_argument("--file", default=None,
  help="The input audio file to resample")
- parser.add_argument("--replace", action="store_true",
- help="Replace the original audio file")
+ parser.add_argument(
+ "--folder",
+ default=None,
+ help="The input directory containing audio files and subdirectories to resample",
+ )
+ parser.add_argument(
+ "--replace", action="store_true", help="Replace the original audio file"
+ )

  args = parser.parse_args()

- if args.file is None and args.config is None:
- print("ERROR: Either --file or --config argument must be provided.")
+ if args.file is None and args.folder is None and args.config is None:
+ print(
+ "ERROR: Either --file, --folder, or --config argument must be provided."
+ )
  exit()

  if args.file is not None:
@@ -42,6 +52,20 @@ def main(src_dir):
  util.debug(f"Resampling audio file: {args.file}")
  rs = Resampler(df_sample, not_testing=True, replace=args.replace)
  rs.resample()
+ elif args.folder is not None:
+ # Load all audio files in the directory and its subdirectories into a DataFrame
+ files = find_files(args.folder, relative=True, ext=["wav"])
+ files = pd.Series(files)
+ df_sample = pd.DataFrame(index=files)
+ df_sample.index = audformat.utils.to_segmented_index(
+ df_sample.index, allow_nat=False
+ )
+
+ # Resample the audio files
+ util = Util("resampler", has_config=False)
+ util.debug(f"Resampling audio files in directory: {args.folder}")
+ rs = Resampler(df_sample, not_testing=True, replace=args.replace)
+ rs.resample()
  else:
  # Existing code for handling INI file
  config_file = args.config
@@ -66,6 +90,7 @@ def main(src_dir):

  if util.config_val("EXP", "no_warnings", False):
  import warnings
+
  warnings.filterwarnings("ignore")

  # Load the data
@@ -74,7 +99,8 @@ def main(src_dir):
  # Split into train and test
  expr.fill_train_and_tests()
  util.debug(
- f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}")
+ f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}"
+ )

  sample_selection = util.config_val(
  "RESAMPLE", "sample_selection", "all")
nkululeko/utils/util.py CHANGED
@@ -134,6 +134,12 @@ class Util:
  pd_series.dtype, pd.CategoricalDtype
  )

+ def get_name(self):
+ """
+ Get the name of the experiment
+ """
+ return self.config["EXP"]["name"]
+
  def get_exp_dir(self):
  """
  Get the experiment directory
{nkululeko-0.85.2.dist-info → nkululeko-0.86.1.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nkululeko
- Version: 0.85.2
+ Version: 0.86.1
  Summary: Machine learning audio prediction experiments based on templates
  Home-page: https://github.com/felixbur/nkululeko
  Author: Felix Burkhardt
@@ -256,6 +256,7 @@ There's my [blog](http://blog.syntheticspeech.de/?s=nkululeko) with tutorials:
  * [Compare several databases](http://blog.syntheticspeech.de/2024/01/02/nkululeko-compare-several-databases/)
  * [Tweak the target variable for database comparison](http://blog.syntheticspeech.de/2024/03/13/nkululeko-how-to-tweak-the-target-variable-for-database-comparison/)
  * [How to run multiple experiments in one go](http://blog.syntheticspeech.de/2022/03/28/how-to-run-multiple-experiments-in-one-go-with-nkululeko/)
+ * [How to finetune a transformer-model](http://blog.syntheticspeech.de/2024/05/29/nkululeko-how-to-finetune-a-transformer-model/)

  ### <a name="helloworld">Hello World example</a>
  * NEW: [Here's a Google colab that runs this example out-of-the-box](https://colab.research.google.com/drive/1GYNBd5cdZQ1QC3Jm58qoeMaJg3UuPhjw?usp=sharing#scrollTo=4G_SjuF9xeQf), and here is the same [with Kaggle](https://www.kaggle.com/felixburk/nkululeko-hello-world-example)
@@ -333,6 +334,17 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schuller
  Changelog
  =========

+ Version 0.86.1
+ --------------
+ * functionality to push to hub
+ * fixed bug that prevented wavlm finetuning
+
+ Version 0.86.0
+ --------------
+ * added regression to finetuning
+ * added other transformer models to finetuning
+ * added output the train/dev features sets actually used by the model
+
  Version 0.85.2
  --------------
  * added data, and automatic task label detection
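Pushing a finetuned model to the Hugging Face hub (the push_to_hub option read in model_tuned.py above, with the experiment name as hub_model_id) requires an authenticated hub account. One way to supply a token before such a run, as a sketch; the token value is a placeholder:

    # Sketch: authenticate before a run with MODEL push_to_hub = True.
    # huggingface_hub is a dependency of transformers; the token is a placeholder.
    from huggingface_hub import login

    login(token="hf_xxx")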
{nkululeko-0.85.2.dist-info → nkululeko-0.86.1.dist-info}/RECORD CHANGED
@@ -2,11 +2,11 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
  nkululeko/aug_train.py,sha256=YhuZnS_WVWnun9G-M6g5n6rbRxoVREz6Zh7k6qprFNQ,3194
  nkululeko/augment.py,sha256=4MG0apTAG5RgkuJrYEjGgDdbodZWi_HweSPNI1JJ5QA,3051
  nkululeko/cacheddataset.py,sha256=lIJ6hUo5LoxSrzXtWV8mzwO7wRtUETWnOQ4ws2XfL1E,969
- nkululeko/constants.py,sha256=l15EMSj8vmejkCKCzQ6jMrgj5PuNrcHIREXt9kbSw7U,39
+ nkululeko/constants.py,sha256=pZ3DZYgXdEpxfaj-mnI6q21TyYMa2QQG_sKa6CBxCCA,39
  nkululeko/demo.py,sha256=8bl15Kitoesnz8oa8yrs52T6YCSOhWbbq9PnZ8Hj6D0,3232
  nkululeko/demo_feats.py,sha256=sAeGFojhEj9WEDFtG3SzPBmyYJWLF2rkbpp65m8Ujo4,2025
  nkululeko/demo_predictor.py,sha256=es56xbT8ifkS_vnrlb5NTZT54gNmeUtNlA4zVA_gnN8,4757
- nkululeko/experiment.py,sha256=ZsSWdasWUyIBF_4vxb4FxvHs42pytG7ErUOABA-WWTo,30722
+ nkululeko/experiment.py,sha256=24FmvF9_zNXE86fO6gzss1M-BjceOCiV6nyJAs0SM_Y,30986
  nkululeko/explore.py,sha256=lDzRoW_Taa5u4BBABZLD89BcQWnYlrftJR4jgt1yyj0,2609
  nkululeko/export.py,sha256=mHeEAAmtZuxdyebLlbSzPrHSi9OMgJHbk35d3DTxRBc,4632
  nkululeko/feature_extractor.py,sha256=8mssYKmo4LclVI-hiLmJEDZ0ZPyDavFG2YwtXcrGzwM,3976
@@ -19,7 +19,7 @@ nkululeko/nkuluflag.py,sha256=PGWSmZz-PiiHLgcZJAoGOI_Y-sZDVI1ksB8p5r7riWM,3725
  nkululeko/nkululeko.py,sha256=Kn3s2E3yyH8cJ7z6lkMxrnqtCxTu7-qfe9Zr_ONTD5g,1968
  nkululeko/plots.py,sha256=nd9tF_61DyAx7oGZF8gTrHXazkgFjFe4eClxu1nQ_XU,23276
  nkululeko/predict.py,sha256=sF091sSSLnEWcISx9ZcULLie3tY5XeFsQJd6b3vrxFg,2409
- nkululeko/resample.py,sha256=IPtYqU0nhZ-CqO_O1jJN0EvpfjxHZdFRwdTpEJOVuaQ,3354
+ nkululeko/resample.py,sha256=2d9eao_0sLrGZ_KSl8OVKsPor3BkFrlmMhrpB9WelIs,4267
  nkululeko/runmanager.py,sha256=eTM1DNQKt1lxYhzt4vZyZluPXW9sWlIJHNQzex4lkJU,7624
  nkululeko/scaler.py,sha256=4nkIqoajkIkuTPK0Z02ifMN_awl6fP_i-GBYdoGYgGM,4101
  nkululeko/segment.py,sha256=YLKckX44tbvTb3LrdgYw9X4guzuF27sutl92z9DkpZU,4835
@@ -58,7 +58,7 @@ nkululeko/feat_extract/feats_hubert.py,sha256=cLoUzSLjSYBkQnftjacSL7ES3O7Ysh_KrP
  nkululeko/feat_extract/feats_import.py,sha256=rj1p8lz19tCAC8hLzzZAwZ0M6gzwH3BzfabFUgal0yw,1622
  nkululeko/feat_extract/feats_mld.py,sha256=Vvu7GZOkn7Vda8eIOXqHjg78zegkFe3vTUaCXyVM0eA,2021
  nkululeko/feat_extract/feats_mos.py,sha256=KXNt7QYEfxkvr6UyVhig2aWQBaIvovlrR4gPuP03gmo,4174
- nkululeko/feat_extract/feats_opensmile.py,sha256=vLY8HCpeOj9NdJXzt_GVI3Vxwsjf9cEfcqJ3IHqlTQY,3978
+ nkululeko/feat_extract/feats_opensmile.py,sha256=g6ZsAxjjGGvGfrr5fngWC-NJ8E7CP1kYZwrlodZJzzU,4028
  nkululeko/feat_extract/feats_oxbow.py,sha256=CmIG9cbHTJTJVnzgCPdQpYpnlewWExpsr5ZcK8Malyo,4980
  nkululeko/feat_extract/feats_praat.py,sha256=kZrS6srzH7WoWEd2prp1Dxw6g9JklFQGTNq5zzPpHzg,3105
  nkululeko/feat_extract/feats_snr.py,sha256=9dqZ-4RpK98iJEssM3ttozNd18LWlZYM_QVXvp5xDcs,2829
@@ -69,13 +69,12 @@ nkululeko/feat_extract/feats_trill.py,sha256=K2ahhdpwpjgg3WZS1POg3UMP2U44i8cLZZv
  nkululeko/feat_extract/feats_wav2vec2.py,sha256=9WUMfyddB_3nx79g7mZoQrRynhM1uEBWuOotRq8bxoU,5268
  nkululeko/feat_extract/feats_wavlm.py,sha256=ulxpGjifUFx2ZgGmY32SmBJGIuvkYHoLb2n1LZ8KMwA,4703
  nkululeko/feat_extract/feats_whisper.py,sha256=0N7Vj65OVi2PNoB_NrDjWT5lP6xZNKxFOZZIoxkJvcA,4533
- nkululeko/feat_extract/featureset.py,sha256=HtgW2389rmlRAgFP3F1sSFzq2_iUVr2NhOfIXG9omt0,1448
+ nkululeko/feat_extract/featureset.py,sha256=ll7tyKAdr--TDShyOYJg0FB4I9NqBq0Ni1k_kUJ-2Vw,1541
  nkululeko/feat_extract/feinberg_praat.py,sha256=EP9pMALjlKdiYInLQdrZ7MmE499Mq-ISRCgqbqL3Rxc,21304
  nkululeko/losses/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  nkululeko/losses/loss_ccc.py,sha256=NOK0y0fxKUnU161B5geap6Fmn8QzoPl2MqtPiV8IuJE,976
  nkululeko/losses/loss_softf1loss.py,sha256=5gW-PuiqeAZcRgfwjueIOQtMokOjZWgQnVIv59HKTCo,1309
  nkululeko/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- nkululeko/models/finetune_model.py,sha256=OMlzDyUFNXZ2xSiqqH8tbzey_KzPJ4jsoYT-4KrWFKM,5091
  nkululeko/models/model.py,sha256=PUCqF2r_dEfmFsZn6Cgr1UIzYvxziLH6nSqZ5-vuN1o,11639
  nkululeko/models/model_bayes.py,sha256=WJFZ8wFKwWATz6MhmjeZIi1Pal1viU549WL_PjXDSy8,406
  nkululeko/models/model_cnn.py,sha256=bJxqwe6FnVR2hFeqN6EXexYGgvKYFED1VOhBXVlLWaE,9954
@@ -89,7 +88,7 @@ nkululeko/models/model_svm.py,sha256=rsME3KvKvNG7bdE5lbvYUu85WZhaASZxxmdNDIVJRZ4
  nkululeko/models/model_svr.py,sha256=_YZeksqB3eBENGlg3g9RwYFlk9rQQ-XCeNBKLlGGVoE,725
  nkululeko/models/model_tree.py,sha256=rf16faUm4o2LJgkoYpeY998b8DQIvXZ73_m1IS3TnnE,417
  nkululeko/models/model_tree_reg.py,sha256=IgQcPTE-304HQLYSKPF8Z4ot_Ur9dH01fZjS0nXke_M,428
- nkululeko/models/model_tuned.py,sha256=WJplfUK3CGLSd2mahUrPSjMvqjPfxLp99KFeZaz2AbU,15098
+ nkululeko/models/model_tuned.py,sha256=eiSKFmObn9_VNTqF1lZvWbyyWxvhy1PVjOiIcs3YiGA,18379
  nkululeko/models/model_xgb.py,sha256=Thgx5ESdIok4v72mKh4plxpo4smGcKALWNCJTDScY0M,447
  nkululeko/models/model_xgr.py,sha256=aGBtNGLWjOE_2rICGYGFxmT8DtnHYsIl1lIpMtghHsY,418
  nkululeko/reporting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -105,9 +104,9 @@ nkululeko/segmenting/seg_silero.py,sha256=lLytS38KzARS17omwv8VBw-zz60RVSXGSvZ5Ev
  nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
  nkululeko/utils/stats.py,sha256=1yUq0FTOyqkU8TwUocJRYdJaqMU5SlOBBRUun9STo2M,2829
- nkululeko/utils/util.py,sha256=b1IHFucRNuF9Iyv5IJeK4AEg0Rga0xKG80UM5GWWdHA,13816
- nkululeko-0.85.2.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
- nkululeko-0.85.2.dist-info/METADATA,sha256=RVGREhA1jakUtQ707C0ecklnUZwx4skVHV0UbPwEsn0,36671
- nkululeko-0.85.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- nkululeko-0.85.2.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
- nkululeko-0.85.2.dist-info/RECORD,,
+ nkululeko/utils/util.py,sha256=mK1MgO14NinrPhavJw72eR_2WN_kBKjVKiEJnzvdO1Q,13946
+ nkululeko-0.86.1.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
+ nkululeko-0.86.1.dist-info/METADATA,sha256=LXoMlzo5QBzABv0fpIDvf4nYDjCJkRCZL1XmffikrRc,37088
+ nkululeko-0.86.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ nkululeko-0.86.1.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
+ nkululeko-0.86.1.dist-info/RECORD,,
nkululeko/models/finetune_model.py DELETED
@@ -1,190 +0,0 @@
- """
- Code based on @jwagner
- """
-
- import dataclasses
- import typing
-
- import torch
- import transformers
- from transformers.models.wav2vec2.modeling_wav2vec2 import (
- Wav2Vec2PreTrainedModel,
- Wav2Vec2Model,
- )
-
-
- class ConcordanceCorCoeff(torch.nn.Module):
-
- def __init__(self):
-
- super().__init__()
-
- self.mean = torch.mean
- self.var = torch.var
- self.sum = torch.sum
- self.sqrt = torch.sqrt
- self.std = torch.std
-
- def forward(self, prediction, ground_truth):
-
- mean_gt = self.mean(ground_truth, 0)
- mean_pred = self.mean(prediction, 0)
- var_gt = self.var(ground_truth, 0)
- var_pred = self.var(prediction, 0)
- v_pred = prediction - mean_pred
- v_gt = ground_truth - mean_gt
- cor = self.sum(v_pred * v_gt) / (
- self.sqrt(self.sum(v_pred**2)) * self.sqrt(self.sum(v_gt**2))
- )
- sd_gt = self.std(ground_truth)
- sd_pred = self.std(prediction)
- numerator = 2 * cor * sd_gt * sd_pred
- denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
- ccc = numerator / denominator
-
- return 1 - ccc
-
-
- @dataclasses.dataclass
- class ModelOutput(transformers.file_utils.ModelOutput):
-
- logits_cat: torch.FloatTensor = None
- hidden_states: typing.Tuple[torch.FloatTensor] = None
- cnn_features: torch.FloatTensor = None
-
-
- class ModelHead(torch.nn.Module):
-
- def __init__(self, config, num_labels):
-
- super().__init__()
-
- self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
- self.dropout = torch.nn.Dropout(config.final_dropout)
- self.out_proj = torch.nn.Linear(config.hidden_size, num_labels)
-
- def forward(self, features, **kwargs):
-
- x = features
- x = self.dropout(x)
- x = self.dense(x)
- x = torch.tanh(x)
- x = self.dropout(x)
- x = self.out_proj(x)
-
- return x
-
-
- class Model(Wav2Vec2PreTrainedModel):
-
- def __init__(self, config):
-
- super().__init__(config)
-
- self.wav2vec2 = Wav2Vec2Model(config)
- self.cat = ModelHead(config, 2)
- self.init_weights()
-
- def freeze_feature_extractor(self):
- self.wav2vec2.feature_extractor._freeze_parameters()
-
- def pooling(
- self,
- hidden_states,
- attention_mask,
- ):
-
- if attention_mask is None: # For evaluation with batch_size==1
- outputs = torch.mean(hidden_states, dim=1)
- else:
- attention_mask = self._get_feature_vector_attention_mask(
- hidden_states.shape[1],
- attention_mask,
- )
- hidden_states = hidden_states * torch.reshape(
- attention_mask,
- (-1, attention_mask.shape[-1], 1),
- )
- outputs = torch.sum(hidden_states, dim=1)
- attention_sum = torch.sum(attention_mask, dim=1)
- outputs = outputs / torch.reshape(attention_sum, (-1, 1))
-
- return outputs
-
- def forward(
- self,
- input_values,
- attention_mask=None,
- labels=None,
- return_hidden=False,
- ):
-
- outputs = self.wav2vec2(
- input_values,
- attention_mask=attention_mask,
- )
-
- cnn_features = outputs.extract_features
- hidden_states_framewise = outputs.last_hidden_state
- hidden_states = self.pooling(
- hidden_states_framewise,
- attention_mask,
- )
- logits_cat = self.cat(hidden_states)
-
- if not self.training:
- logits_cat = torch.softmax(logits_cat, dim=1)
-
- if return_hidden:
-
- # make time last axis
- cnn_features = torch.transpose(cnn_features, 1, 2)
-
- return ModelOutput(
- logits_cat=logits_cat,
- hidden_states=hidden_states,
- cnn_features=cnn_features,
- )
-
- else:
-
- return ModelOutput(
- logits_cat=logits_cat,
- )
-
- def predict(self, signal):
- result = self(torch.from_numpy(signal))
- result = result[0].detach().numpy()[0]
- return result
-
-
- class ModelWithPreProcessing(Model):
-
- def __init__(self, config):
- super().__init__(config)
-
- def forward(
- self,
- input_values,
- ):
- # Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm():
- # normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
-
- mean = input_values.mean()
-
- # var = input_values.var()
- # raises: onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for the node ReduceProd_3:ReduceProd(11)
-
- var = torch.square(input_values - mean).mean()
- input_values = (input_values - mean) / torch.sqrt(var + 1e-7)
-
- output = super().forward(
- input_values,
- return_hidden=True,
- )
-
- return (
- output.hidden_states,
- output.logits_cat,
- output.cnn_features,
- )