nkululeko 0.90.1__py3-none-any.whl → 0.90.2__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
nkululeko/constants.py CHANGED
@@ -1,2 +1,2 @@
1
- VERSION = "0.90.1"
1
+ VERSION = "0.90.2"
2
2
  SAMPLING_RATE = 16000
@@ -30,10 +30,16 @@ class TunedModel(BaseModel):
30
30
  """Constructor taking the configuration and all dataframes."""
31
31
  super().__init__(df_train, df_test, feats_train, feats_test)
32
32
  super().set_model_type("finetuned")
33
+ self.df_test, self.df_train, self.feats_test, self.feats_train = (
34
+ df_test,
35
+ df_train,
36
+ feats_test,
37
+ feats_train,
38
+ )
33
39
  self.name = "finetuned_wav2vec2"
34
40
  self.target = glob_conf.config["DATA"]["target"]
35
- labels = glob_conf.labels
36
- self.class_num = len(labels)
41
+ self.labels = glob_conf.labels
42
+ self.class_num = len(self.labels)
37
43
  device = self.util.config_val("MODEL", "device", False)
38
44
  if not device:
39
45
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -304,7 +310,7 @@ class TunedModel(BaseModel):
304
310
  else:
305
311
  self.util.error(f"criterion {criterion} not supported for classifier")
306
312
  else:
307
- self.criterion = self.util.config_val("MODEL", "loss", "ccc")
313
+ criterion = self.util.config_val("MODEL", "loss", "1-ccc")
308
314
  if criterion == "1-ccc":
309
315
  criterion = ConcordanceCorCoeff()
310
316
  elif criterion == "mse":
@@ -402,7 +408,7 @@ class TunedModel(BaseModel):
402
408
  self.load(self.run, self.epoch)
403
409
 
404
410
  def get_predictions(self):
405
- results = []
411
+ results = [[]].pop(0)
406
412
  for (file, start, end), _ in audeer.progress_bar(
407
413
  self.df_test.iterrows(),
408
414
  total=len(self.df_test),
@@ -415,18 +421,37 @@ class TunedModel(BaseModel):
415
421
  file, duration=end - start, offset=start, always_2d=True
416
422
  )
417
423
  assert sr == self.sampling_rate
418
- predictions = self.model.predict(signal)
419
- results.append(predictions.argmax())
420
- return results
424
+ prediction = self.model.predict(signal)
425
+ results.append(prediction)
426
+ # results.append(predictions.argmax())
427
+ predictions = np.asarray(results)
428
+ if self.util.exp_is_classification():
429
+ # make a dataframe for the class probabilities
430
+ proba_d = {}
431
+ for c in range(self.class_num):
432
+ proba_d[c] = []
433
+ # get the class probabilities
434
+ # predictions = self.clf.predict_proba(self.feats_test.to_numpy())
435
+ # pred = self.clf.predict(features)
436
+ for i in range(self.class_num):
437
+ proba_d[i] = list(predictions.T[i])
438
+ probas = pd.DataFrame(proba_d)
439
+ probas = probas.set_index(self.df_test.index)
440
+ predictions = probas.idxmax(axis=1).values
441
+ else:
442
+ predictions = predictions.flatten()
443
+ probas = None
444
+ return predictions, probas
421
445
 
422
446
  def predict(self):
423
447
  """Predict the whole eval feature set"""
424
- predictions = self.get_predictions()
448
+ predictions, probas = self.get_predictions()
425
449
  report = Reporter(
426
450
  self.df_test[self.target].to_numpy().astype(float),
427
451
  predictions,
428
452
  self.run,
429
453
  self.epoch_num,
454
+ probas=probas,
430
455
  )
431
456
  self._plot_epoch_progression(report)
432
457
  return report
@@ -438,6 +463,7 @@ class TunedModel(BaseModel):
438
463
  )
439
464
  with open(log_file, "r") as file:
440
465
  data = file.read()
466
+ data = data.strip().replace("nan", "0")
441
467
  list = ast.literal_eval(data)
442
468
  epochs, vals, loss = [], [], []
443
469
  for index, tp in enumerate(list):
nkululeko/utils/util.py CHANGED
@@ -155,10 +155,10 @@ class Util:
155
155
  return f"{store}/{self.get_exp_name()}.pkl"
156
156
 
157
157
  def get_pred_name(self):
158
- store = self.get_path("store")
158
+ results_dir = self.get_path("res_dir")
159
159
  target = self.get_target_name()
160
160
  pred_name = self.get_model_description()
161
- return f"{store}/pred_{target}_{pred_name}.csv"
161
+ return f"{results_dir}/pred_{target}_{pred_name}.csv"
162
162
 
163
163
  def is_categorical(self, pd_series):
164
164
  """Check if a dataframe column is categorical."""
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nkululeko
3
- Version: 0.90.1
3
+ Version: 0.90.2
4
4
  Summary: Machine learning audio prediction experiments based on templates
5
5
  Home-page: https://github.com/felixbur/nkululeko
6
6
  Author: Felix Burkhardt
@@ -356,6 +356,11 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
356
356
  Changelog
357
357
  =========
358
358
 
359
+ Version 0.90.2
360
+ --------------
361
+ * added probability output to finetuning classification models
362
+ * switched path to prob. output from "store" to "results"
363
+
359
364
  Version 0.90.1
360
365
  --------------
361
366
  * Add balancing for finetune and update data README
@@ -2,7 +2,7 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
2
2
  nkululeko/aug_train.py,sha256=FoMbBrfyOZd4QAw7oIHl3X6-UpsqAKWVDIolCA7qOWs,3196
3
3
  nkululeko/augment.py,sha256=sIXRg19Uz8dWKgQv2LBGH7jbd2pgcUTh0PIQ_62B0kA,3135
4
4
  nkululeko/cacheddataset.py,sha256=XFpWZmbJRg0pvhnIgYf0TkclxllD-Fctu-Ol0PF_00c,969
5
- nkululeko/constants.py,sha256=TmPPFi_-OUMYF2mfBNMLxBQl0vwneI1opUPN0vK2XPY,41
5
+ nkululeko/constants.py,sha256=RbyLuq3HuWP1QWBrcWXo-YcwlYf2qDk6H1ihR4_KqbY,41
6
6
  nkululeko/demo-ft.py,sha256=iD9Pzp9QjyAv31q1cDZ75vPez7Ve8A4Cfukv5yfZdrQ,770
7
7
  nkululeko/demo.py,sha256=bLuHkeEl5rOfm7ecGHCcWATiPK7-njNbtrGljxzNzFs,5088
8
8
  nkululeko/demo_feats.py,sha256=BvZjeNFTlERIRlq34OHM4Z96jdDQAhB01BGQAUcX9dM,2026
@@ -95,7 +95,7 @@ nkululeko/models/model_svm.py,sha256=zP8ykLhCZTYvwSqw06XHuzq9qMBtsiYpxjUpWDAnMyA
95
95
  nkululeko/models/model_svr.py,sha256=FEwYRdgqwgGhZdkpRnT7Ef12lklWi6GZL28PyV99xWs,726
96
96
  nkululeko/models/model_tree.py,sha256=6L3PD3aIiiQz1RPWS6z3Edx4f0gnR7AOfBKOJzf0BNU,433
97
97
  nkululeko/models/model_tree_reg.py,sha256=IMaQpNImoRqP8Biw1CsJevxpV_PVpKblsKtYlMW5d_U,429
98
- nkululeko/models/model_tuned.py,sha256=k6c8dPKy2BeFMKABrNTMSwQuiKa9VrZ7oeJdfNYoYAo,22678
98
+ nkululeko/models/model_tuned.py,sha256=VuRyNqw3XTpQ2eHsWOJN8X-V98AN8Wqiq7UgwT5BQRU,23763
99
99
  nkululeko/models/model_xgb.py,sha256=ytBaSHZH8r7VvRYdmrBrQnzRM6V4HyCJ8O-v20J8G_g,448
100
100
  nkululeko/models/model_xgr.py,sha256=H01FJCRgmX2unvambMs5TTCS9sI6VDB9ip9G6rVGt2c,419
101
101
  nkululeko/reporting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -111,9 +111,9 @@ nkululeko/segmenting/seg_silero.py,sha256=CnhjKGTW5OXf-bmw4YsSJeN2yUwkY5m3xnulM_
111
111
  nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
112
112
  nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
113
113
  nkululeko/utils/stats.py,sha256=vCRzhCR0Gx5SiJyAGbj1TIto8ocGz58CM5Pr3LltagA,2948
114
- nkululeko/utils/util.py,sha256=a9fs5swVkv_k0CfJRwDhEx1ChZv7rs7K4oQDYspiQWY,16709
115
- nkululeko-0.90.1.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
116
- nkululeko-0.90.1.dist-info/METADATA,sha256=unqq8xrL0bfP178Q3fKBaGyry4SJvHxPGJCR3figOpQ,40961
117
- nkululeko-0.90.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
118
- nkululeko-0.90.1.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
119
- nkululeko-0.90.1.dist-info/RECORD,,
114
+ nkululeko/utils/util.py,sha256=XFZdhCc_LM4EmoZ5tKKaBCQLXclcNmvHwhfT_CXB98c,16723
115
+ nkululeko-0.90.2.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
116
+ nkululeko-0.90.2.dist-info/METADATA,sha256=rJnGf71UEIyv0OBiNxrfu0l1e6o83v8q_UlIlmhtE_0,41113
117
+ nkululeko-0.90.2.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
118
+ nkululeko-0.90.2.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
119
+ nkululeko-0.90.2.dist-info/RECORD,,