nkululeko 0.88.7__py3-none-any.whl → 0.88.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nkululeko/constants.py CHANGED
@@ -1,2 +1,2 @@
1
- VERSION="0.88.7"
1
+ VERSION="0.88.9"
2
2
  SAMPLING_RATE = 16000
nkululeko/ensemble.py CHANGED
@@ -1,3 +1,20 @@
1
+ """
2
+ Ensemble predictions from multiple experiments.
3
+
4
+ Args:
5
+ config_files (list): List of configuration file paths.
6
+ method (str): Ensemble method to use. Options are 'majority_voting', 'mean', 'max', 'sum', 'uncertainty', 'uncertainty_weighted', 'confidence_weighted', or 'performance_weighted'.
7
+ threshold (float): Threshold for the 'uncertainty' ensemble method (default: 1.0, i.e. no threshold).
8
+ weights (list): Weights for the 'performance_weighted' ensemble method.
9
+ no_labels (bool): Flag indicating whether the predictions have labels or not.
10
+
11
+ Returns:
12
+ pandas.DataFrame: The ensemble predictions.
13
+
14
+ Raises:
15
+ ValueError: If an unknown ensemble method is provided.
16
+ AssertionError: If the number of config files is less than 2 for majority voting.
17
+ """
1
18
  #!/usr/bin/env python
2
19
  # -*- coding: utf-8 -*-
3
20
 
@@ -45,45 +62,7 @@ def sum_ensemble(ensemble_preds, labels):
45
62
  return ensemble_preds[labels].idxmax(axis=1)
46
63
 
47
64
 
48
- def uncertainty_ensemble(ensemble_preds):
49
- """Same as uncertainty_threshold with a threshold of 0.1"""
50
- final_predictions = []
51
- best_uncertainty = []
52
- for _, row in ensemble_preds.iterrows():
53
- uncertainties = row[["uncertainty"]].values
54
- min_uncertainty_idx = np.argmin(uncertainties)
55
- final_predictions.append(row["predicted"].iloc[min_uncertainty_idx])
56
- best_uncertainty.append(uncertainties[min_uncertainty_idx])
57
-
58
- return final_predictions, best_uncertainty
59
-
60
-
61
- def max_class_ensemble(ensemble_preds_ls, labels):
62
- """Compare the highest probabilites of all models across classes (instead of same class as in max_ensemble) and return the highest probability and the class"""
63
- final_preds = []
64
- final_probs = []
65
-
66
- for _, row in pd.concat(ensemble_preds_ls, axis=1).iterrows():
67
- max_probs = []
68
- max_classes = []
69
-
70
- for model_df in ensemble_preds_ls:
71
- model_probs = row[labels].astype(float)
72
- max_prob = model_probs.max()
73
- max_class = model_probs.idxmax()
74
-
75
- max_probs.append(max_prob)
76
- max_classes.append(max_class)
77
-
78
- best_model_index = np.argmax(max_probs)
79
-
80
- final_preds.append(max_classes[best_model_index])
81
- final_probs.append(max_probs[best_model_index])
82
-
83
- return pd.Series(final_preds), pd.Series(final_probs)
84
-
85
-
86
- def uncertainty_threshold_ensemble(ensemble_preds_ls, labels, threshold):
65
+ def uncertainty_ensemble(ensemble_preds_ls, labels, threshold):
87
66
  final_predictions = []
88
67
  final_uncertainties = []
89
68
 
@@ -173,8 +152,40 @@ def confidence_weighted_ensemble(ensemble_preds_ls, labels):
173
152
  return final_predictions, final_confidences
174
153
 
175
154
 
155
+ def performance_weighted_ensemble(ensemble_preds_ls, labels, weights):
156
+ """Weighted ensemble based on performances"""
157
+ final_predictions = []
158
+ final_confidences = []
159
+
160
+ # assert that all weights are decimals in the range 0-1
161
+ assert all(0 <= w <= 1 for w in weights), "Weights must be between 0 and 1"
162
+
163
+ # assert that the length of weights matches the number of models
164
+ assert len(weights) == len(ensemble_preds_ls), "Number of weights must match number of models"
165
+
166
+ # Normalize weights
167
+ total_weight = sum(weights)
168
+ weights = [weight / total_weight for weight in weights]
169
+
170
+ for idx in ensemble_preds_ls[0].index:
171
+ class_probabilities = {label: 0 for label in labels}
172
+
173
+ for df, weight in zip(ensemble_preds_ls, weights):
174
+ row = df.loc[idx]
175
+ for label in labels:
176
+ class_probabilities[label] += row[label] * weight
177
+
178
+ predicted_class = max(class_probabilities, key=class_probabilities.get)
179
+ final_predictions.append(predicted_class)
180
+ final_confidences.append(max(class_probabilities.values()))
181
+
182
+ return final_predictions, final_confidences
183
+
184
+
185
+
186
+
176
187
  def ensemble_predictions(
177
- config_files: List[str], method: str, threshold: float, no_labels: bool
188
+ config_files: List[str], method: str, threshold: float, weights: List[float], no_labels: bool
178
189
  ) -> pd.DataFrame:
179
190
  """
180
191
  Ensemble predictions from multiple experiments.
@@ -235,12 +246,8 @@ def ensemble_predictions(
235
246
  ensemble_preds["predicted"] = max_ensemble(ensemble_preds, labels)
236
247
  elif method == "sum":
237
248
  ensemble_preds["predicted"] = sum_ensemble(ensemble_preds, labels)
238
- elif method == "max_class":
239
- ensemble_preds["predicted"], ensemble_preds["max_probability"] = (
240
- max_class_ensemble(ensemble_preds_ls, labels)
241
- )
242
- elif method == "uncertainty_threshold":
243
- ensemble_preds["predicted"] = uncertainty_threshold_ensemble(
249
+ elif method == "uncertainty":
250
+ ensemble_preds["predicted"] = uncertainty_ensemble(
244
251
  ensemble_preds_ls, labels, threshold
245
252
  )
246
253
  elif method == "uncertainty_weighted":
@@ -251,6 +258,10 @@ def ensemble_predictions(
251
258
  ensemble_preds["predicted"], ensemble_preds["confidence"] = (
252
259
  confidence_weighted_ensemble(ensemble_preds_ls, labels)
253
260
  )
261
+ elif method == "performance_weighted":
262
+ ensemble_preds["predicted"], ensemble_preds["confidence"] = (
263
+ performance_weighted_ensemble(ensemble_preds_ls, labels, weights)
264
+ )
254
265
  else:
255
266
  raise ValueError(f"Unknown ensemble method: {method}")
256
267
 
@@ -269,7 +280,6 @@ def ensemble_predictions(
269
280
  ensemble_preds = ensemble_preds.iloc[:, : len(labels) + 3]
270
281
 
271
282
  # calculate UAR from predicted and truth columns
272
-
273
283
  truth = ensemble_preds["truth"]
274
284
  predicted = ensemble_preds["predicted"]
275
285
  uar = balanced_accuracy_score(truth, predicted)
@@ -285,7 +295,7 @@ def main(src_dir: Path) -> None:
285
295
  "configs",
286
296
  nargs="+",
287
297
  help="Paths to the configuration files of the experiments to ensemble. \
288
- Can be INI files for Nkululeko.nkululeo or CSV files from Nkululeko.demo.",
298
+ Can be INI files for Nkululeko.nkululeko or CSV files from Nkululeko.demo.",
289
299
  )
290
300
  parser.add_argument(
291
301
  "--method",
@@ -295,12 +305,13 @@ def main(src_dir: Path) -> None:
295
305
  "mean",
296
306
  "max",
297
307
  "sum",
298
- "max_class",
308
+ # "max_class",
299
309
  # "uncertainty_lowest",
300
310
  # "entropy",
301
- "uncertainty_threshold",
311
+ "uncertainty",
302
312
  "uncertainty_weighted",
303
313
  "confidence_weighted",
314
+ "performance_weighted",
304
315
  ],
305
316
  help=f"Ensemble method to use (default: {DEFAULT_METHOD})",
306
317
  )
@@ -316,6 +327,13 @@ def main(src_dir: Path) -> None:
316
327
  default=DEFAULT_OUTFILE,
317
328
  help=f"Output file path for the ensemble predictions (default: {DEFAULT_OUTFILE})",
318
329
  )
330
+ parser.add_argument(
331
+ "--weights",
332
+ default=None,
333
+ nargs="+",
334
+ type=float,
335
+ help="Weights for the ensemble method (default: None, e.g. 0.5 0.5)",
336
+ )
319
337
  parser.add_argument(
320
338
  "--no_labels",
321
339
  action="store_true",
@@ -327,7 +345,7 @@ def main(src_dir: Path) -> None:
327
345
  start = time.time()
328
346
 
329
347
  ensemble_preds = ensemble_predictions(
330
- args.configs, args.method, args.threshold, args.no_labels
348
+ args.configs, args.method, args.threshold, args.weights, args.no_labels
331
349
  )
332
350
 
333
351
  # save to csv
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nkululeko
3
- Version: 0.88.7
3
+ Version: 0.88.9
4
4
  Summary: Machine learning audio prediction experiments based on templates
5
5
  Home-page: https://github.com/felixbur/nkululeko
6
6
  Author: Felix Burkhardt
@@ -360,6 +360,14 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
360
360
  Changelog
361
361
  =========
362
362
 
363
+ Version 0.88.9
364
+ --------------
365
+ * added performance_weighted ensemble
366
+
367
+ Version 0.88.8
368
+ --------------
369
+ * some cosmetics
370
+
363
371
  Version 0.88.7
364
372
  --------------
365
373
  * added use_splits for multidb
@@ -2,11 +2,11 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
2
2
  nkululeko/aug_train.py,sha256=YhuZnS_WVWnun9G-M6g5n6rbRxoVREz6Zh7k6qprFNQ,3194
3
3
  nkululeko/augment.py,sha256=4MG0apTAG5RgkuJrYEjGgDdbodZWi_HweSPNI1JJ5QA,3051
4
4
  nkululeko/cacheddataset.py,sha256=lIJ6hUo5LoxSrzXtWV8mzwO7wRtUETWnOQ4ws2XfL1E,969
5
- nkululeko/constants.py,sha256=p-kvGUZX0J2JPXoROES9PcftVSZ1B1GfzkBt6d8MJhY,39
5
+ nkululeko/constants.py,sha256=tK1QIQ72lahwT47cOoEvhMfH2sH4BRnP3p6P7kdC_QQ,39
6
6
  nkululeko/demo.py,sha256=bLuHkeEl5rOfm7ecGHCcWATiPK7-njNbtrGljxzNzFs,5088
7
7
  nkululeko/demo_feats.py,sha256=sAeGFojhEj9WEDFtG3SzPBmyYJWLF2rkbpp65m8Ujo4,2025
8
8
  nkululeko/demo_predictor.py,sha256=zs1bjhpnKuNCPLJeiyDm19ME1NEDOQT3QNeyVKJq9Yc,4882
9
- nkululeko/ensemble.py,sha256=rUHg8YmD6L8Ktt2T5M6iwsWVWbpCnfiynhHdN22bLRQ,11873
9
+ nkululeko/ensemble.py,sha256=cVz8hWd2m7poyS0lTIfrsha0K8U-hd6eiBWMqDOAlt8,12669
10
10
  nkululeko/experiment.py,sha256=L4PzoScPLG2xTyniVy9evcBy_8CIe3RTeTEUVTqiuvQ,31186
11
11
  nkululeko/explore.py,sha256=lDzRoW_Taa5u4BBABZLD89BcQWnYlrftJR4jgt1yyj0,2609
12
12
  nkululeko/export.py,sha256=mHeEAAmtZuxdyebLlbSzPrHSi9OMgJHbk35d3DTxRBc,4632
@@ -107,8 +107,8 @@ nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
107
107
  nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
108
108
  nkululeko/utils/stats.py,sha256=eC9dMO-by6CDnGLHDBQu-2B4-BudZNJ0nnWGhKYdUMA,2968
109
109
  nkululeko/utils/util.py,sha256=KMxPzb0HN3XuNzAd7Kn3M3Nq91-0sDrAAEBgDKryCdo,16688
110
- nkululeko-0.88.7.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
111
- nkululeko-0.88.7.dist-info/METADATA,sha256=VKwlkHohr4PJezcmZ45fVykmKmh1T6d2LCDvjR8Ierw,40017
112
- nkululeko-0.88.7.dist-info/WHEEL,sha256=Wyh-_nZ0DJYolHNn1_hMa4lM7uDedD_RGVwbmTjyItk,91
113
- nkululeko-0.88.7.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
114
- nkululeko-0.88.7.dist-info/RECORD,,
110
+ nkululeko-0.88.9.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
111
+ nkululeko-0.88.9.dist-info/METADATA,sha256=2NTuv6JzIYo9FbjMFT2zP_SuxZcBuagowGZ9YneOcOA,40134
112
+ nkululeko-0.88.9.dist-info/WHEEL,sha256=Wyh-_nZ0DJYolHNn1_hMa4lM7uDedD_RGVwbmTjyItk,91
113
+ nkululeko-0.88.9.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
114
+ nkululeko-0.88.9.dist-info/RECORD,,