nkululeko 0.90.1__py3-none-any.whl → 0.90.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nkululeko/constants.py +1 -1
- nkululeko/models/model_tuned.py +34 -8
- nkululeko/utils/util.py +2 -2
- {nkululeko-0.90.1.dist-info → nkululeko-0.90.2.dist-info}/METADATA +6 -1
- {nkululeko-0.90.1.dist-info → nkululeko-0.90.2.dist-info}/RECORD +8 -8
- {nkululeko-0.90.1.dist-info → nkululeko-0.90.2.dist-info}/LICENSE +0 -0
- {nkululeko-0.90.1.dist-info → nkululeko-0.90.2.dist-info}/WHEEL +0 -0
- {nkululeko-0.90.1.dist-info → nkululeko-0.90.2.dist-info}/top_level.txt +0 -0
nkululeko/constants.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
|
-
VERSION = "0.90.1"
|
1
|
+
VERSION = "0.90.2"
|
2
2
|
SAMPLING_RATE = 16000
|
nkululeko/models/model_tuned.py
CHANGED
@@ -30,10 +30,16 @@ class TunedModel(BaseModel):
|
|
30
30
|
"""Constructor taking the configuration and all dataframes."""
|
31
31
|
super().__init__(df_train, df_test, feats_train, feats_test)
|
32
32
|
super().set_model_type("finetuned")
|
33
|
+
self.df_test, self.df_train, self.feats_test, self.feats_train = (
|
34
|
+
df_test,
|
35
|
+
df_train,
|
36
|
+
feats_test,
|
37
|
+
feats_train,
|
38
|
+
)
|
33
39
|
self.name = "finetuned_wav2vec2"
|
34
40
|
self.target = glob_conf.config["DATA"]["target"]
|
35
|
-
labels = glob_conf.labels
|
36
|
-
self.class_num = len(labels)
|
41
|
+
self.labels = glob_conf.labels
|
42
|
+
self.class_num = len(self.labels)
|
37
43
|
device = self.util.config_val("MODEL", "device", False)
|
38
44
|
if not device:
|
39
45
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -304,7 +310,7 @@ class TunedModel(BaseModel):
|
|
304
310
|
else:
|
305
311
|
self.util.error(f"criterion {criterion} not supported for classifier")
|
306
312
|
else:
|
307
|
-
|
313
|
+
criterion = self.util.config_val("MODEL", "loss", "1-ccc")
|
308
314
|
if criterion == "1-ccc":
|
309
315
|
criterion = ConcordanceCorCoeff()
|
310
316
|
elif criterion == "mse":
|
@@ -402,7 +408,7 @@ class TunedModel(BaseModel):
|
|
402
408
|
self.load(self.run, self.epoch)
|
403
409
|
|
404
410
|
def get_predictions(self):
|
405
|
-
results = []
|
411
|
+
results = [[]].pop(0)
|
406
412
|
for (file, start, end), _ in audeer.progress_bar(
|
407
413
|
self.df_test.iterrows(),
|
408
414
|
total=len(self.df_test),
|
@@ -415,18 +421,37 @@ class TunedModel(BaseModel):
|
|
415
421
|
file, duration=end - start, offset=start, always_2d=True
|
416
422
|
)
|
417
423
|
assert sr == self.sampling_rate
|
418
|
-
|
419
|
-
results.append(
|
420
|
-
|
424
|
+
prediction = self.model.predict(signal)
|
425
|
+
results.append(prediction)
|
426
|
+
# results.append(predictions.argmax())
|
427
|
+
predictions = np.asarray(results)
|
428
|
+
if self.util.exp_is_classification():
|
429
|
+
# make a dataframe for the class probabilities
|
430
|
+
proba_d = {}
|
431
|
+
for c in range(self.class_num):
|
432
|
+
proba_d[c] = []
|
433
|
+
# get the class probabilities
|
434
|
+
# predictions = self.clf.predict_proba(self.feats_test.to_numpy())
|
435
|
+
# pred = self.clf.predict(features)
|
436
|
+
for i in range(self.class_num):
|
437
|
+
proba_d[i] = list(predictions.T[i])
|
438
|
+
probas = pd.DataFrame(proba_d)
|
439
|
+
probas = probas.set_index(self.df_test.index)
|
440
|
+
predictions = probas.idxmax(axis=1).values
|
441
|
+
else:
|
442
|
+
predictions = predictions.flatten()
|
443
|
+
probas = None
|
444
|
+
return predictions, probas
|
421
445
|
|
422
446
|
def predict(self):
|
423
447
|
"""Predict the whole eval feature set"""
|
424
|
-
predictions = self.get_predictions()
|
448
|
+
predictions, probas = self.get_predictions()
|
425
449
|
report = Reporter(
|
426
450
|
self.df_test[self.target].to_numpy().astype(float),
|
427
451
|
predictions,
|
428
452
|
self.run,
|
429
453
|
self.epoch_num,
|
454
|
+
probas=probas,
|
430
455
|
)
|
431
456
|
self._plot_epoch_progression(report)
|
432
457
|
return report
|
@@ -438,6 +463,7 @@ class TunedModel(BaseModel):
|
|
438
463
|
)
|
439
464
|
with open(log_file, "r") as file:
|
440
465
|
data = file.read()
|
466
|
+
data = data.strip().replace("nan", "0")
|
441
467
|
list = ast.literal_eval(data)
|
442
468
|
epochs, vals, loss = [], [], []
|
443
469
|
for index, tp in enumerate(list):
|
nkululeko/utils/util.py
CHANGED
@@ -155,10 +155,10 @@ class Util:
|
|
155
155
|
return f"{store}/{self.get_exp_name()}.pkl"
|
156
156
|
|
157
157
|
def get_pred_name(self):
|
158
|
-
|
158
|
+
results_dir = self.get_path("res_dir")
|
159
159
|
target = self.get_target_name()
|
160
160
|
pred_name = self.get_model_description()
|
161
|
-
return f"{
|
161
|
+
return f"{results_dir}/pred_{target}_{pred_name}.csv"
|
162
162
|
|
163
163
|
def is_categorical(self, pd_series):
|
164
164
|
"""Check if a dataframe column is categorical."""
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: nkululeko
|
3
|
-
Version: 0.90.1
|
3
|
+
Version: 0.90.2
|
4
4
|
Summary: Machine learning audio prediction experiments based on templates
|
5
5
|
Home-page: https://github.com/felixbur/nkululeko
|
6
6
|
Author: Felix Burkhardt
|
@@ -356,6 +356,11 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
|
|
356
356
|
Changelog
|
357
357
|
=========
|
358
358
|
|
359
|
+
Version 0.90.2
|
360
|
+
--------------
|
361
|
+
* added probability output to finetuning classification models
|
362
|
+
* switched path to prob. output from "store" to "results"
|
363
|
+
|
359
364
|
Version 0.90.1
|
360
365
|
--------------
|
361
366
|
* Add balancing for finetune and update data README
|
@@ -2,7 +2,7 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
|
|
2
2
|
nkululeko/aug_train.py,sha256=FoMbBrfyOZd4QAw7oIHl3X6-UpsqAKWVDIolCA7qOWs,3196
|
3
3
|
nkululeko/augment.py,sha256=sIXRg19Uz8dWKgQv2LBGH7jbd2pgcUTh0PIQ_62B0kA,3135
|
4
4
|
nkululeko/cacheddataset.py,sha256=XFpWZmbJRg0pvhnIgYf0TkclxllD-Fctu-Ol0PF_00c,969
|
5
|
-
nkululeko/constants.py,sha256=
|
5
|
+
nkululeko/constants.py,sha256=RbyLuq3HuWP1QWBrcWXo-YcwlYf2qDk6H1ihR4_KqbY,41
|
6
6
|
nkululeko/demo-ft.py,sha256=iD9Pzp9QjyAv31q1cDZ75vPez7Ve8A4Cfukv5yfZdrQ,770
|
7
7
|
nkululeko/demo.py,sha256=bLuHkeEl5rOfm7ecGHCcWATiPK7-njNbtrGljxzNzFs,5088
|
8
8
|
nkululeko/demo_feats.py,sha256=BvZjeNFTlERIRlq34OHM4Z96jdDQAhB01BGQAUcX9dM,2026
|
@@ -95,7 +95,7 @@ nkululeko/models/model_svm.py,sha256=zP8ykLhCZTYvwSqw06XHuzq9qMBtsiYpxjUpWDAnMyA
|
|
95
95
|
nkululeko/models/model_svr.py,sha256=FEwYRdgqwgGhZdkpRnT7Ef12lklWi6GZL28PyV99xWs,726
|
96
96
|
nkululeko/models/model_tree.py,sha256=6L3PD3aIiiQz1RPWS6z3Edx4f0gnR7AOfBKOJzf0BNU,433
|
97
97
|
nkululeko/models/model_tree_reg.py,sha256=IMaQpNImoRqP8Biw1CsJevxpV_PVpKblsKtYlMW5d_U,429
|
98
|
-
nkululeko/models/model_tuned.py,sha256=
|
98
|
+
nkululeko/models/model_tuned.py,sha256=VuRyNqw3XTpQ2eHsWOJN8X-V98AN8Wqiq7UgwT5BQRU,23763
|
99
99
|
nkululeko/models/model_xgb.py,sha256=ytBaSHZH8r7VvRYdmrBrQnzRM6V4HyCJ8O-v20J8G_g,448
|
100
100
|
nkululeko/models/model_xgr.py,sha256=H01FJCRgmX2unvambMs5TTCS9sI6VDB9ip9G6rVGt2c,419
|
101
101
|
nkululeko/reporting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -111,9 +111,9 @@ nkululeko/segmenting/seg_silero.py,sha256=CnhjKGTW5OXf-bmw4YsSJeN2yUwkY5m3xnulM_
|
|
111
111
|
nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
112
112
|
nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
|
113
113
|
nkululeko/utils/stats.py,sha256=vCRzhCR0Gx5SiJyAGbj1TIto8ocGz58CM5Pr3LltagA,2948
|
114
|
-
nkululeko/utils/util.py,sha256=
|
115
|
-
nkululeko-0.90.
|
116
|
-
nkululeko-0.90.
|
117
|
-
nkululeko-0.90.
|
118
|
-
nkululeko-0.90.
|
119
|
-
nkululeko-0.90.
|
114
|
+
nkululeko/utils/util.py,sha256=XFZdhCc_LM4EmoZ5tKKaBCQLXclcNmvHwhfT_CXB98c,16723
|
115
|
+
nkululeko-0.90.2.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
|
116
|
+
nkululeko-0.90.2.dist-info/METADATA,sha256=rJnGf71UEIyv0OBiNxrfu0l1e6o83v8q_UlIlmhtE_0,41113
|
117
|
+
nkululeko-0.90.2.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
118
|
+
nkululeko-0.90.2.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
|
119
|
+
nkululeko-0.90.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|