nkululeko 0.83.0__py3-none-any.whl → 0.83.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nkululeko/constants.py +1 -1
- nkululeko/experiment.py +2 -1
- nkululeko/nkuluflag.py +19 -6
- nkululeko/test.py +20 -15
- nkululeko/test_predictor.py +3 -0
- {nkululeko-0.83.0.dist-info → nkululeko-0.83.1.dist-info}/METADATA +5 -1
- {nkululeko-0.83.0.dist-info → nkululeko-0.83.1.dist-info}/RECORD +10 -11
- nkululeko/reporter.py +0 -324
- {nkululeko-0.83.0.dist-info → nkululeko-0.83.1.dist-info}/LICENSE +0 -0
- {nkululeko-0.83.0.dist-info → nkululeko-0.83.1.dist-info}/WHEEL +0 -0
- {nkululeko-0.83.0.dist-info → nkululeko-0.83.1.dist-info}/top_level.txt +0 -0
nkululeko/constants.py
CHANGED
@@ -1,2 +1,2 @@
-VERSION="0.83.0"
+VERSION="0.83.1"
 SAMPLING_RATE = 16000
nkululeko/experiment.py
CHANGED
@@ -675,7 +675,8 @@ class Experiment:
         test_predictor = TestPredictor(
             model, self.df_test, self.label_encoder, result_name
         )
-        test_predictor.predict_and_store()
+        result = test_predictor.predict_and_store()
+        return result
 
     def load(self, filename):
         f = open(filename, "rb")
nkululeko/nkuluflag.py
CHANGED
@@ -2,13 +2,16 @@ import argparse
 import configparser
 import os
 import os.path
+import sys
 
 from nkululeko.nkululeko import doit as nkulu
+from nkululeko.test import do_it as test_mod
 
 
-def do_it(src_dir):
+def doit(cla):
     parser = argparse.ArgumentParser(description="Call the nkululeko framework.")
     parser.add_argument("--config", help="The base configuration")
+    parser.add_argument("--mod", default="nkulu", help="Which nkululeko module to call")
     parser.add_argument("--data", help="The databases", nargs="*", action="append")
     parser.add_argument(
         "--label", nargs="*", help="The labels for the target", action="append"
@@ -25,20 +28,23 @@ def do_it(src_dir):
     parser.add_argument("--model", default="xgb", help="The model type")
     parser.add_argument("--feat", default="['os']", help="The feature type")
     parser.add_argument("--set", help="The opensmile set")
-    parser.add_argument("--with_os", help="To add os features")
     parser.add_argument("--target", help="The target designation")
     parser.add_argument("--epochs", help="The number of epochs")
     parser.add_argument("--runs", help="The number of runs")
     parser.add_argument("--learning_rate", help="The learning rate")
     parser.add_argument("--drop", help="The dropout rate [0:1]")
 
-    args = parser.parse_args()
+    args = parser.parse_args(cla)
 
     if args.config is not None:
         config_file = args.config
     else:
         print("ERROR: need config file")
         quit(-1)
+
+    if args.mod is not None:
+        nkulu_mod = args.mod
+
     # test if config is there
     if not os.path.isfile(config_file):
         print(f"ERROR: no such file {config_file}")
@@ -86,10 +92,17 @@ def do_it(src_dir):
     with open(tmp_config, "w") as tmp_file:
         config.write(tmp_file)
 
-    result, last_epoch = nkulu(tmp_config)
+    result, last_epoch = 0, 0
+    if nkulu_mod == "nkulu":
+        result, last_epoch = nkulu(tmp_config)
+    elif nkulu_mod == "test":
+        result, last_epoch = test_mod(tmp_config, "test_results.csv")
+    else:
+        print(f"ERROR: unknown module: {nkulu_mod}, should be [nkulu | test]")
    return result, last_epoch
 
 
 if __name__ == "__main__":
-
-
+    cla = sys.argv
+    cla.pop(0)
+    doit(cla)  # sys.argv[1])
nkululeko/test.py
CHANGED
@@ -10,20 +10,7 @@ from nkululeko.experiment import Experiment
 from nkululeko.utils.util import Util
 
 
-def main(src_dir):
-    parser = argparse.ArgumentParser(
-        description="Call the nkululeko TEST framework.")
-    parser.add_argument("--config", default="exp.ini",
-                        help="The base configuration")
-    parser.add_argument(
-        "--outfile",
-        default="my_results.csv",
-        help="File name to store the predictions",
-    )
-
-    args = parser.parse_args()
-
-    config_file = args.config
+def do_it(config_file, outfile):
 
     # test if the configuration file exists
     if not os.path.isfile(config_file):
@@ -48,10 +35,28 @@ def main(src_dir):
     expr.load(f"{util.get_save_name()}")
     expr.fill_tests()
     expr.extract_test_feats()
-    expr.predict_test_and_save(args.outfile)
+    result = expr.predict_test_and_save(outfile)
 
     print("DONE")
 
+    return result, 0
+
+
+def main(src_dir):
+    parser = argparse.ArgumentParser(description="Call the nkululeko TEST framework.")
+    parser.add_argument("--config", default="exp.ini", help="The base configuration")
+    parser.add_argument(
+        "--outfile",
+        default="my_results.csv",
+        help="File name to store the predictions",
+    )
+    args = parser.parse_args()
+    if args.config is not None:
+        config_file = args.config
+    else:
+        config_file = f"{src_dir}/exp.ini"
+    do_it(config_file, args.outfile)
+
 
 if __name__ == "__main__":
     cwd = os.path.dirname(os.path.abspath(__file__))
nkululeko/test_predictor.py
CHANGED
@@ -29,6 +29,7 @@ class TestPredictor:
 
     def predict_and_store(self):
         label_data = self.util.config_val("DATA", "label_data", False)
+        result = 0
        if label_data:
             data = Dataset(label_data)
             data.load()
@@ -57,6 +58,7 @@ class TestPredictor:
         test_dbs_string = "_".join(test_dbs)
         predictions = self.model.get_predictions()
         report = self.model.predict()
+        result = report.result.get_result()
         report.set_filename_add(f"test-{test_dbs_string}")
         self.util.print_best_results([report])
         report.plot_confmatrix(self.util.get_plot_name(), 0)
@@ -74,3 +76,4 @@ class TestPredictor:
         df = df.rename(columns={"class_label": target})
         df.to_csv(self.name)
         self.util.debug(f"results stored in {self.name}")
+        return result
{nkululeko-0.83.0.dist-info → nkululeko-0.83.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nkululeko
-Version: 0.83.0
+Version: 0.83.1
 Summary: Machine learning audio prediction experiments based on templates
 Home-page: https://github.com/felixbur/nkululeko
 Author: Felix Burkhardt
@@ -333,6 +333,10 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schuller
 Changelog
 =========
 
+Version 0.83.1
+--------------
+* add test module to nkuluflag
+
 Version 0.83.0
 --------------
 * test module now prints out reports
{nkululeko-0.83.0.dist-info → nkululeko-0.83.1.dist-info}/RECORD
CHANGED
@@ -2,11 +2,11 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
 nkululeko/aug_train.py,sha256=YhuZnS_WVWnun9G-M6g5n6rbRxoVREz6Zh7k6qprFNQ,3194
 nkululeko/augment.py,sha256=4MG0apTAG5RgkuJrYEjGgDdbodZWi_HweSPNI1JJ5QA,3051
 nkululeko/cacheddataset.py,sha256=lIJ6hUo5LoxSrzXtWV8mzwO7wRtUETWnOQ4ws2XfL1E,969
-nkululeko/constants.py,sha256=
+nkululeko/constants.py,sha256=i6-Vtyje9xE8w8o3lG27IiJczQFyrNbsxiXs7b4-q28,39
 nkululeko/demo.py,sha256=55kNFA2helMhOxD4yZuKg1JWDtlUUpxm-6uAnroIydI,3264
 nkululeko/demo_feats.py,sha256=sAeGFojhEj9WEDFtG3SzPBmyYJWLF2rkbpp65m8Ujo4,2025
 nkululeko/demo_predictor.py,sha256=-ggSHc3DXxRzjzcGB4qFBOMvKsfUdTkkde50BDrS9dA,4755
-nkululeko/experiment.py,sha256=
+nkululeko/experiment.py,sha256=aueWoKJCQx8wU9daosh6n7ZDGhT2cfo_9Av5HIfN1_w,29605
 nkululeko/explore.py,sha256=2wdoGRqldvsN1zCiWk0quSDgHHHUoF2UZOWQ1r-2OLM,2310
 nkululeko/export.py,sha256=mHeEAAmtZuxdyebLlbSzPrHSi9OMgJHbk35d3DTxRBc,4632
 nkululeko/feature_extractor.py,sha256=8mssYKmo4LclVI-hiLmJEDZ0ZPyDavFG2YwtXcrGzwM,3976
@@ -15,18 +15,17 @@ nkululeko/filter_data.py,sha256=w-X2mhKdYr5DxDIz50E5yzO6Jmzk4jjDBoXsgOOVtcA,7222
 nkululeko/glob_conf.py,sha256=iHiVSxDYgmYwdx6z0HuGUMSWrfZfufPHxHb60q2dLRY,453
 nkululeko/modelrunner.py,sha256=GwDXcE2gDQXat4W0-HhHQ1BcUNCRBXMBQ4QycfHp_5c,9288
 nkululeko/multidb.py,sha256=fG3VukEWP1vreVN4gB1IRXxwwg4jLftsSEYtu0o1f78,5634
-nkululeko/nkuluflag.py,sha256=
+nkululeko/nkuluflag.py,sha256=PGWSmZz-PiiHLgcZJAoGOI_Y-sZDVI1ksB8p5r7riWM,3725
 nkululeko/nkululeko.py,sha256=Kn3s2E3yyH8cJ7z6lkMxrnqtCxTu7-qfe9Zr_ONTD5g,1968
 nkululeko/plots.py,sha256=K88ZRPFGX_r03BT742H06Dde20xZYdltv7dxjgUiAFA,23025
 nkululeko/predict.py,sha256=sF091sSSLnEWcISx9ZcULLie3tY5XeFsQJd6b3vrxFg,2409
-nkululeko/reporter.py,sha256=8mlIaKep4hM-tdRv8t98tK80rx3zOmVGXSORhiPc3as,12483
 nkululeko/resample.py,sha256=3WbxkwgyTe_fW38046Rjxk3knOkFdhqn2C4nfhbUurQ,2287
 nkululeko/runmanager.py,sha256=eTM1DNQKt1lxYhzt4vZyZluPXW9sWlIJHNQzex4lkJU,7624
 nkululeko/scaler.py,sha256=4nkIqoajkIkuTPK0Z02ifMN_awl6fP_i-GBYdoGYgGM,4101
 nkululeko/segment.py,sha256=YLKckX44tbvTb3LrdgYw9X4guzuF27sutl92z9DkpZU,4835
 nkululeko/syllable_nuclei.py,sha256=Sky-C__MeUDaxqHnDl2TGLLYOYvsahD35TUjWGeG31k,10047
-nkululeko/test.py,sha256=
-nkululeko/test_predictor.py,sha256=
+nkululeko/test.py,sha256=1w624vo5KTzmFC8BUStGlLDmIEAFuJUz7J0W-gp7AxI,1677
+nkululeko/test_predictor.py,sha256=_w5J8CxH6hmW3mLTKbdfmywl5QpdNAnW1Y8TE5GtlfE,3237
 nkululeko/augmenting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nkululeko/augmenting/augmenter.py,sha256=XAt0dpmlnKxqyysqCgV3rcz-pRIvOz7rU7dmGDCVAzs,2905
 nkululeko/augmenting/randomsplicer.py,sha256=Z5rxdKKUpuncLWuTS6xVfVKUeVbeiYU_dLRHQ5fcg4Y,2669
@@ -104,8 +103,8 @@ nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
 nkululeko/utils/stats.py,sha256=1yUq0FTOyqkU8TwUocJRYdJaqMU5SlOBBRUun9STo2M,2829
 nkululeko/utils/util.py,sha256=_Z6OMJ3f-8TdETW9eqJYY5hwNRS5XCt9azzRnqoTTZE,12330
-nkululeko-0.83.
-nkululeko-0.83.
-nkululeko-0.83.
-nkululeko-0.83.
-nkululeko-0.83.
+nkululeko-0.83.1.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
+nkululeko-0.83.1.dist-info/METADATA,sha256=EgPYOS_ELZQmEvPWlX-klt8gmo59suFFL_HDptU474w,36080
+nkululeko-0.83.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+nkululeko-0.83.1.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
+nkululeko-0.83.1.dist-info/RECORD,,
nkululeko/reporter.py
DELETED
@@ -1,324 +0,0 @@
-"""Reporter module.
-
-This module contains the Reporter class which is responsible for generating reports.
-"""
-
-import ast
-import glob
-import json
-import math
-
-import matplotlib.pyplot as plt
-import numpy as np
-from scipy.stats import pearsonr
-from sklearn.metrics import ConfusionMatrixDisplay
-from sklearn.metrics import accuracy_score
-from sklearn.metrics import classification_report
-from sklearn.metrics import confusion_matrix
-from sklearn.metrics import mean_absolute_error
-from sklearn.metrics import mean_squared_error
-from sklearn.metrics import r2_score
-from sklearn.metrics import recall_score
-from sklearn.utils import resample
-
-import nkululeko.glob_conf as glob_conf
-from nkululeko.reporting.defines import Header
-from nkululeko.reporting.report_item import ReportItem
-from nkululeko.result import Result
-from nkululeko.utils.util import Util
-
-
-class Reporter:
-    def __set_measure(self):
-        if self.util.exp_is_classification():
-            self.MEASURE = "UAR"
-            self.result.measure = self.MEASURE
-            self.is_classification = True
-        else:
-            self.is_classification = False
-            self.measure = self.util.config_val("MODEL", "measure", "mse")
-            if self.measure == "mse":
-                self.MEASURE = "MSE"
-                self.result.measure = self.MEASURE
-            elif self.measure == "mae":
-                self.MEASURE = "MAE"
-                self.result.measure = self.MEASURE
-            elif self.measure == "ccc":
-                self.MEASURE = "CCC"
-                self.result.measure = self.MEASURE
-
-    def __init__(self, truths, preds, run, epoch):
-        """Initialization with ground truth und predictions vector"""
-        self.util = Util("reporter")
-        self.format = self.util.config_val("PLOT", "format", "png")
-        self.truths = truths
-        self.preds = preds
-        self.result = Result(0, 0, 0, 0, "unknown")
-        self.run = run
-        self.epoch = epoch
-        self.__set_measure()
-        self.cont_to_cat = False
-        if len(self.truths) > 0 and len(self.preds) > 0:
-            if self.util.exp_is_classification():
-                self.result.test = recall_score(
-                    self.truths, self.preds, average="macro"
-                )
-                self.result.loss = 1 - accuracy_score(self.truths, self.preds)
-            else:
-                # regression experiment
-                if self.measure == "mse":
-                    self.result.test = mean_squared_error(self.truths, self.preds)
-                elif self.measure == "mae":
-                    self.result.test = mean_absolute_error(self.truths, self.preds)
-                elif self.measure == "ccc":
-                    self.result.test = self.ccc(self.truths, self.preds)
-                    if math.isnan(self.result.test):
-                        self.util.debug(f"Truth: {self.truths}")
-                        self.util.debug(f"Predict.: {self.preds}")
-                        self.util.debug(f"Result is NAN: setting to -1")
-                        self.result.test = -1
-                else:
-                    self.util.error(f"unknown measure: {self.measure}")
-
-        # train and loss are being set by the model
-
-    def set_id(self, run, epoch):
-        """Make the report identifiable with run and epoch index"""
-        self.run = run
-        self.epoch = epoch
-
-    def continuous_to_categorical(self):
-        if self.cont_to_cat:
-            return
-        self.cont_to_cat = True
-        bins = ast.literal_eval(glob_conf.config["DATA"]["bins"])
-        self.truths = np.digitize(self.truths, bins) - 1
-        self.preds = np.digitize(self.preds, bins) - 1
-
-    def plot_confmatrix(self, plot_name, epoch):
-        if not self.util.exp_is_classification():
-            self.continuous_to_categorical()
-        self._plot_confmat(self.truths, self.preds, plot_name, epoch)
-
-
-    def plot_per_speaker(self, result_df, plot_name, function):
-        """Plot a confusion matrix with the mode category per speakers.
-
-        This function creates a confusion matrix for each speaker in the result_df.
-        The result_df should contain the columns: preds, truths and speaker.
-
-        Args:
-            * result_df: a pandas dataframe with columns: preds, truths and speaker
-            * plot_name: a string with the name of the plot
-            * function: a string with the function to use for each speaker,
-              can be 'mode' or 'mean'
-
-        Returns:
-            * None
-        """
-        # Initialize empty arrays for predictions and truths
-        pred = np.zeros(0)
-        truth = np.zeros(0)
-
-        # Iterate over each speaker
-        for s in result_df.speaker.unique():
-            # Filter the dataframe for the current speaker
-            s_df = result_df[result_df.speaker == s]
-
-            # Get the mode or mean prediction for the current speaker
-            mode = s_df.pred.mode().iloc[-1]
-            mean = s_df.pred.mean()
-            if function == "mode":
-                s_df.pred = mode
-            elif function == "mean":
-                s_df.pred = mean
-            else:
-                self.util.error(f"unknown function {function}")
-
-            # Append the current speaker's predictions and truths to the arrays
-            pred = np.append(pred, s_df.pred.values)
-            truth = np.append(truth, s_df["truth"].values)
-
-        # If the experiment is not a classification or continuous to categorical conversion was performed,
-        # convert the truths and predictions to categorical
-        if not (self.is_classification or self.cont_to_cat):
-            bins = ast.literal_eval(glob_conf.config["DATA"]["bins"])
-            truth = np.digitize(truth, bins) - 1
-            pred = np.digitize(pred, bins) - 1
-
-        # Plot the confusion matrix for the speakers
-        self._plot_confmat(truth, pred.astype("int"), plot_name, 0)
-
-    def _plot_confmat(self, truths, preds, plot_name, epoch):
-        # print(truths)
-        # print(preds)
-        fig_dir = self.util.get_path("fig_dir")
-        labels = glob_conf.labels
-        fig = plt.figure()  # figsize=[5, 5]
-        uar = recall_score(truths, preds, average="macro")
-        acc = accuracy_score(truths, preds)
-        cm = confusion_matrix(
-            truths, preds, normalize=None
-        )  # normalize must be one of {'true', 'pred', 'all', None}
-        if cm.shape[0] != len(labels):
-            self.util.error(
-                f"mismatch between confmatrix dim ({cm.shape[0]}) and labels"
-                f" length ({len(labels)}: {labels})"
-            )
-        try:
-            disp = ConfusionMatrixDisplay(
-                confusion_matrix=cm, display_labels=labels
-            ).plot(cmap="Blues")
-        except ValueError:
-            disp = ConfusionMatrixDisplay(
-                confusion_matrix=cm,
-                display_labels=list(labels).remove("neutral"),
-            ).plot(cmap="Blues")
-
-        reg_res = ""
-        if not self.is_classification:
-            reg_res = f", {self.MEASURE}: {self.result.test:.3f}"
-
-        if epoch != 0:
-            plt.title(f"Confusion Matrix, UAR: {uar:.3f}{reg_res}, Epoch: {epoch}")
-        else:
-            plt.title(f"Confusion Matrix, UAR: {uar:.3f}{reg_res}")
-        img_path = f"{fig_dir}{plot_name}.{self.format}"
-        plt.savefig(img_path)
-        fig.clear()
-        plt.close(fig)
-        plt.savefig(img_path)
-        plt.close(fig)
-        glob_conf.report.add_item(
-            ReportItem(
-                Header.HEADER_RESULTS,
-                self.util.get_model_description(),
-                "Confusion matrix",
-                img_path,
-            )
-        )
-
-        res_dir = self.util.get_path("res_dir")
-        uar = int(uar * 1000) / 1000.0
-        acc = int(acc * 1000) / 1000.0
-        rpt = f"epoch: {epoch}, UAR: {uar}, ACC: {acc}"
-        # print(rpt)
-        self.util.debug(rpt)
-        file_name = f"{res_dir}{self.util.get_exp_name()}_conf.txt"
-        with open(file_name, "w") as text_file:
-            text_file.write(rpt)
-
-    def print_results(self, epoch):
-        """Print all evaluation values to text file"""
-        res_dir = self.util.get_path("res_dir")
-        file_name = f"{res_dir}{self.util.get_exp_name()}_{epoch}.txt"
-        if self.util.exp_is_classification():
-            labels = glob_conf.labels
-            try:
-                rpt = classification_report(
-                    self.truths,
-                    self.preds,
-                    target_names=labels,
-                    output_dict=True,
-                )
-            except ValueError as e:
-                self.util.debug(
-                    "Reporter: caught a ValueError when trying to get"
-                    " classification_report: " + e
-                )
-                rpt = self.result.to_string()
-            with open(file_name, "w") as text_file:
-                c_ress = list(range(len(labels)))
-                for i, l in enumerate(labels):
-                    c_res = rpt[l]["f1-score"]
-                    c_ress[i] = float(f"{c_res:.3f}")
-                self.util.debug(f"labels: {labels}")
-                f1_per_class = f"result per class (F1 score): {c_ress}"
-                self.util.debug(f1_per_class)
-                rpt_str = f"{json.dumps(rpt)}\n{f1_per_class}"
-                text_file.write(rpt_str)
-                glob_conf.report.add_item(
-                    ReportItem(
-                        Header.HEADER_RESULTS,
-                        f"Classification result {self.util.get_model_description()}",
-                        rpt_str,
-                    )
-                )
-
-        else:  # regression
-            result = self.result.test
-            r2 = r2_score(self.truths, self.preds)
-            pcc = pearsonr(self.truths, self.preds)[0]
-            measure = self.util.config_val("MODEL", "measure", "mse")
-            with open(file_name, "w") as text_file:
-                text_file.write(
-                    f"{measure}: {result:.3f}, r_2: {r2:.3f}, pcc {pcc:.3f}"
-                )
-
-    def make_conf_animation(self, out_name):
-        import imageio
-
-        fig_dir = self.util.get_path("fig_dir")
-        filenames = glob.glob(fig_dir + f"{self.util.get_plot_name()}*_?_???_cnf.png")
-        images = []
-        for filename in filenames:
-            images.append(imageio.imread(filename))
-        fps = self.util.config_val("PLOT", "fps", "1")
-        try:
-            imageio.mimsave(fig_dir + out_name, images, fps=int(fps))
-        except RuntimeError as e:
-            self.util.error("error writing anim gif: " + e)
-
-    def get_result(self):
-        return self.result
-
-    def plot_epoch_progression(self, reports, out_name):
-        fig_dir = self.util.get_path("fig_dir")
-        results, losses, train_results, losses_eval = [], [], [], []
-        for r in reports:
-            results.append(r.get_result().test)
-            losses.append(r.get_result().loss)
-            train_results.append(r.get_result().train)
-            losses_eval.append(r.get_result().loss_eval)
-
-        # do a plot per run
-        # scale the losses so they fit on the picture
-        losses, results, train_results, losses_eval = (
-            np.asarray(losses),
-            np.asarray(results),
-            np.asarray(train_results),
-            np.asarray(losses_eval),
-        )
-
-        if np.all((results > 1)):
-            # scale down values
-            results = results / 100.0
-            train_results = train_results / 100.0
-        # if np.all((losses < 1)):
-        # scale up values
-        plt.figure(dpi=200)
-        plt.plot(train_results, "green", label="train set")
-        plt.plot(results, "red", label="dev set")
-        plt.plot(losses, "black", label="losses")
-        plt.plot(losses_eval, "grey", label="losses_eval")
-        plt.xlabel("epochs")
-        plt.ylabel(f"{self.MEASURE}")
-        plt.legend()
-        plt.savefig(f"{fig_dir}{out_name}.{self.format}")
-        plt.close()
-
-    @staticmethod
-    def ccc(ground_truth, prediction):
-        mean_gt = np.mean(ground_truth, 0)
-        mean_pred = np.mean(prediction, 0)
-        var_gt = np.var(ground_truth, 0)
-        var_pred = np.var(prediction, 0)
-        v_pred = prediction - mean_pred
-        v_gt = ground_truth - mean_gt
-        cor = sum(v_pred * v_gt) / (np.sqrt(sum(v_pred**2)) * np.sqrt(sum(v_gt**2)))
-        sd_gt = np.std(ground_truth)
-        sd_pred = np.std(prediction)
-        numerator = 2 * cor * sd_gt * sd_pred
-        denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
-        ccc = numerator / denominator
-        return ccc
{nkululeko-0.83.0.dist-info → nkululeko-0.83.1.dist-info}/LICENSE
File without changes
{nkululeko-0.83.0.dist-info → nkululeko-0.83.1.dist-info}/WHEEL
File without changes
{nkululeko-0.83.0.dist-info → nkululeko-0.83.1.dist-info}/top_level.txt
File without changes