nkululeko 0.82.4__py3-none-any.whl → 0.83.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nkululeko/constants.py +1 -1
- nkululeko/experiment.py +4 -3
- nkululeko/modelrunner.py +4 -6
- nkululeko/nkuluflag.py +19 -6
- nkululeko/reporting/reporter.py +7 -3
- nkululeko/test.py +20 -15
- nkululeko/test_predictor.py +21 -7
- {nkululeko-0.82.4.dist-info → nkululeko-0.83.1.dist-info}/METADATA +9 -1
- {nkululeko-0.82.4.dist-info → nkululeko-0.83.1.dist-info}/RECORD +12 -13
- nkululeko/reporter.py +0 -332
- {nkululeko-0.82.4.dist-info → nkululeko-0.83.1.dist-info}/LICENSE +0 -0
- {nkululeko-0.82.4.dist-info → nkululeko-0.83.1.dist-info}/WHEEL +0 -0
- {nkululeko-0.82.4.dist-info → nkululeko-0.83.1.dist-info}/top_level.txt +0 -0
nkululeko/constants.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
|
-
VERSION="0.
|
1
|
+
VERSION="0.83.1"
|
2
2
|
SAMPLING_RATE = 16000
|
nkululeko/experiment.py
CHANGED
@@ -23,7 +23,7 @@ from nkululeko.plots import Plots
|
|
23
23
|
from nkululeko.reporting.report import Report
|
24
24
|
from nkululeko.runmanager import Runmanager
|
25
25
|
from nkululeko.scaler import Scaler
|
26
|
-
from nkululeko.test_predictor import
|
26
|
+
from nkululeko.test_predictor import TestPredictor
|
27
27
|
from nkululeko.utils.util import Util
|
28
28
|
|
29
29
|
|
@@ -672,10 +672,11 @@ class Experiment:
|
|
672
672
|
def predict_test_and_save(self, result_name):
|
673
673
|
model = self.runmgr.get_best_model()
|
674
674
|
model.set_testdata(self.df_test, self.feats_test)
|
675
|
-
test_predictor =
|
675
|
+
test_predictor = TestPredictor(
|
676
676
|
model, self.df_test, self.label_encoder, result_name
|
677
677
|
)
|
678
|
-
test_predictor.predict_and_store()
|
678
|
+
result = test_predictor.predict_and_store()
|
679
|
+
return result
|
679
680
|
|
680
681
|
def load(self, filename):
|
681
682
|
f = open(filename, "rb")
|
nkululeko/modelrunner.py
CHANGED
@@ -2,18 +2,16 @@
|
|
2
2
|
|
3
3
|
import pandas as pd
|
4
4
|
|
5
|
-
from nkululeko.utils.util import Util
|
6
5
|
from nkululeko import glob_conf
|
7
|
-
|
6
|
+
from nkululeko.utils.util import Util
|
8
7
|
|
9
8
|
|
10
9
|
class Modelrunner:
|
11
|
-
"""
|
12
|
-
Class to model one run
|
13
|
-
"""
|
10
|
+
"""Class to model one run."""
|
14
11
|
|
15
12
|
def __init__(self, df_train, df_test, feats_train, feats_test, run):
|
16
|
-
"""Constructor setting up the dataframes
|
13
|
+
"""Constructor setting up the dataframes.
|
14
|
+
|
17
15
|
Args:
|
18
16
|
df_train: train dataframe
|
19
17
|
df_test: test dataframe
|
nkululeko/nkuluflag.py
CHANGED
@@ -2,13 +2,16 @@ import argparse
|
|
2
2
|
import configparser
|
3
3
|
import os
|
4
4
|
import os.path
|
5
|
+
import sys
|
5
6
|
|
6
7
|
from nkululeko.nkululeko import doit as nkulu
|
8
|
+
from nkululeko.test import do_it as test_mod
|
7
9
|
|
8
10
|
|
9
|
-
def
|
11
|
+
def doit(cla):
|
10
12
|
parser = argparse.ArgumentParser(description="Call the nkululeko framework.")
|
11
13
|
parser.add_argument("--config", help="The base configuration")
|
14
|
+
parser.add_argument("--mod", default="nkulu", help="Which nkululeko module to call")
|
12
15
|
parser.add_argument("--data", help="The databases", nargs="*", action="append")
|
13
16
|
parser.add_argument(
|
14
17
|
"--label", nargs="*", help="The labels for the target", action="append"
|
@@ -25,20 +28,23 @@ def do_it(src_dir):
|
|
25
28
|
parser.add_argument("--model", default="xgb", help="The model type")
|
26
29
|
parser.add_argument("--feat", default="['os']", help="The feature type")
|
27
30
|
parser.add_argument("--set", help="The opensmile set")
|
28
|
-
parser.add_argument("--with_os", help="To add os features")
|
29
31
|
parser.add_argument("--target", help="The target designation")
|
30
32
|
parser.add_argument("--epochs", help="The number of epochs")
|
31
33
|
parser.add_argument("--runs", help="The number of runs")
|
32
34
|
parser.add_argument("--learning_rate", help="The learning rate")
|
33
35
|
parser.add_argument("--drop", help="The dropout rate [0:1]")
|
34
36
|
|
35
|
-
args = parser.parse_args()
|
37
|
+
args = parser.parse_args(cla)
|
36
38
|
|
37
39
|
if args.config is not None:
|
38
40
|
config_file = args.config
|
39
41
|
else:
|
40
42
|
print("ERROR: need config file")
|
41
43
|
quit(-1)
|
44
|
+
|
45
|
+
if args.mod is not None:
|
46
|
+
nkulu_mod = args.mod
|
47
|
+
|
42
48
|
# test if config is there
|
43
49
|
if not os.path.isfile(config_file):
|
44
50
|
print(f"ERROR: no such file {config_file}")
|
@@ -86,10 +92,17 @@ def do_it(src_dir):
|
|
86
92
|
with open(tmp_config, "w") as tmp_file:
|
87
93
|
config.write(tmp_file)
|
88
94
|
|
89
|
-
result, last_epoch =
|
95
|
+
result, last_epoch = 0, 0
|
96
|
+
if nkulu_mod == "nkulu":
|
97
|
+
result, last_epoch = nkulu(tmp_config)
|
98
|
+
elif nkulu_mod == "test":
|
99
|
+
result, last_epoch = test_mod(tmp_config, "test_results.csv")
|
100
|
+
else:
|
101
|
+
print(f"ERROR: unknown module: {nkulu_mod}, should be [nkulu | test]")
|
90
102
|
return result, last_epoch
|
91
103
|
|
92
104
|
|
93
105
|
if __name__ == "__main__":
|
94
|
-
|
95
|
-
|
106
|
+
cla = sys.argv
|
107
|
+
cla.pop(0)
|
108
|
+
doit(cla) # sys.argv[1])
|
nkululeko/reporting/reporter.py
CHANGED
@@ -55,6 +55,7 @@ class Reporter:
|
|
55
55
|
self.run = run
|
56
56
|
self.epoch = epoch
|
57
57
|
self.__set_measure()
|
58
|
+
self.filenameadd = ""
|
58
59
|
self.cont_to_cat = False
|
59
60
|
if len(self.truths) > 0 and len(self.preds) > 0:
|
60
61
|
if self.util.exp_is_classification():
|
@@ -206,7 +207,7 @@ class Reporter:
|
|
206
207
|
f"Confusion Matrix, UAR: {uar_str} "
|
207
208
|
+ f"(+-{up_str}/{low_str}) {reg_res}"
|
208
209
|
)
|
209
|
-
img_path = f"{fig_dir}{plot_name}.{self.format}"
|
210
|
+
img_path = f"{fig_dir}{plot_name}{self.filenameadd}.{self.format}"
|
210
211
|
plt.savefig(img_path)
|
211
212
|
fig.clear()
|
212
213
|
plt.close(fig)
|
@@ -228,14 +229,17 @@ class Reporter:
|
|
228
229
|
)
|
229
230
|
# print(rpt)
|
230
231
|
self.util.debug(rpt)
|
231
|
-
file_name = f"{res_dir}{self.util.get_exp_name()}_conf.txt"
|
232
|
+
file_name = f"{res_dir}{self.util.get_exp_name()}{self.filenameadd}_conf.txt"
|
232
233
|
with open(file_name, "w") as text_file:
|
233
234
|
text_file.write(rpt)
|
234
235
|
|
236
|
+
def set_filename_add(self, my_string):
|
237
|
+
self.filenameadd = f"_{my_string}"
|
238
|
+
|
235
239
|
def print_results(self, epoch):
|
236
240
|
"""Print all evaluation values to text file."""
|
237
241
|
res_dir = self.util.get_path("res_dir")
|
238
|
-
file_name = f"{res_dir}{self.util.get_exp_name()}_{epoch}.txt"
|
242
|
+
file_name = f"{res_dir}{self.util.get_exp_name()}_{epoch}{self.filenameadd}.txt"
|
239
243
|
if self.util.exp_is_classification():
|
240
244
|
labels = glob_conf.labels
|
241
245
|
try:
|
nkululeko/test.py
CHANGED
@@ -10,20 +10,7 @@ from nkululeko.experiment import Experiment
|
|
10
10
|
from nkululeko.utils.util import Util
|
11
11
|
|
12
12
|
|
13
|
-
def
|
14
|
-
parser = argparse.ArgumentParser(
|
15
|
-
description="Call the nkululeko TEST framework.")
|
16
|
-
parser.add_argument("--config", default="exp.ini",
|
17
|
-
help="The base configuration")
|
18
|
-
parser.add_argument(
|
19
|
-
"--outfile",
|
20
|
-
default="my_results.csv",
|
21
|
-
help="File name to store the predictions",
|
22
|
-
)
|
23
|
-
|
24
|
-
args = parser.parse_args()
|
25
|
-
|
26
|
-
config_file = args.config
|
13
|
+
def do_it(config_file, outfile):
|
27
14
|
|
28
15
|
# test if the configuration file exists
|
29
16
|
if not os.path.isfile(config_file):
|
@@ -48,10 +35,28 @@ def main(src_dir):
|
|
48
35
|
expr.load(f"{util.get_save_name()}")
|
49
36
|
expr.fill_tests()
|
50
37
|
expr.extract_test_feats()
|
51
|
-
expr.predict_test_and_save(
|
38
|
+
result = expr.predict_test_and_save(outfile)
|
52
39
|
|
53
40
|
print("DONE")
|
54
41
|
|
42
|
+
return result, 0
|
43
|
+
|
44
|
+
|
45
|
+
def main(src_dir):
|
46
|
+
parser = argparse.ArgumentParser(description="Call the nkululeko TEST framework.")
|
47
|
+
parser.add_argument("--config", default="exp.ini", help="The base configuration")
|
48
|
+
parser.add_argument(
|
49
|
+
"--outfile",
|
50
|
+
default="my_results.csv",
|
51
|
+
help="File name to store the predictions",
|
52
|
+
)
|
53
|
+
args = parser.parse_args()
|
54
|
+
if args.config is not None:
|
55
|
+
config_file = args.config
|
56
|
+
else:
|
57
|
+
config_file = f"{src_dir}/exp.ini"
|
58
|
+
do_it(config_file, args.outfile)
|
59
|
+
|
55
60
|
|
56
61
|
if __name__ == "__main__":
|
57
62
|
cwd = os.path.dirname(os.path.abspath(__file__))
|
nkululeko/test_predictor.py
CHANGED
@@ -1,21 +1,25 @@
|
|
1
|
-
"""
|
1
|
+
"""test_predictor.py.
|
2
|
+
|
2
3
|
Predict targets from a model and save as csv file.
|
3
4
|
|
4
5
|
"""
|
5
6
|
|
6
|
-
import
|
7
|
-
|
7
|
+
import ast
|
8
|
+
|
9
|
+
import numpy as np
|
8
10
|
import pandas as pd
|
11
|
+
from sklearn.preprocessing import LabelEncoder
|
12
|
+
|
9
13
|
from nkululeko.data.dataset import Dataset
|
10
14
|
from nkululeko.feature_extractor import FeatureExtractor
|
15
|
+
import nkululeko.glob_conf as glob_conf
|
11
16
|
from nkululeko.scaler import Scaler
|
12
|
-
|
13
|
-
from sklearn.preprocessing import LabelEncoder
|
17
|
+
from nkululeko.utils.util import Util
|
14
18
|
|
15
19
|
|
16
|
-
class
|
20
|
+
class TestPredictor:
|
17
21
|
def __init__(self, model, orig_df, labenc, name):
|
18
|
-
"""Constructor setting up name and configuration"""
|
22
|
+
"""Constructor setting up name and configuration."""
|
19
23
|
self.model = model
|
20
24
|
self.orig_df = orig_df
|
21
25
|
self.label_encoder = labenc
|
@@ -25,6 +29,7 @@ class Test_predictor:
|
|
25
29
|
|
26
30
|
def predict_and_store(self):
|
27
31
|
label_data = self.util.config_val("DATA", "label_data", False)
|
32
|
+
result = 0
|
28
33
|
if label_data:
|
29
34
|
data = Dataset(label_data)
|
30
35
|
data.load()
|
@@ -49,7 +54,15 @@ class Test_predictor:
|
|
49
54
|
df[self.target] = labelenc.inverse_transform(predictions.tolist())
|
50
55
|
df.to_csv(self.name)
|
51
56
|
else:
|
57
|
+
test_dbs = ast.literal_eval(glob_conf.config["DATA"]["tests"])
|
58
|
+
test_dbs_string = "_".join(test_dbs)
|
52
59
|
predictions = self.model.get_predictions()
|
60
|
+
report = self.model.predict()
|
61
|
+
result = report.result.get_result()
|
62
|
+
report.set_filename_add(f"test-{test_dbs_string}")
|
63
|
+
self.util.print_best_results([report])
|
64
|
+
report.plot_confmatrix(self.util.get_plot_name(), 0)
|
65
|
+
report.print_results(0)
|
53
66
|
# print(predictions)
|
54
67
|
# df = pd.DataFrame(index=self.orig_df.index)
|
55
68
|
# df["speaker"] = self.orig_df["speaker"]
|
@@ -63,3 +76,4 @@ class Test_predictor:
|
|
63
76
|
df = df.rename(columns={"class_label": target})
|
64
77
|
df.to_csv(self.name)
|
65
78
|
self.util.debug(f"results stored in {self.name}")
|
79
|
+
return result
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: nkululeko
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.83.1
|
4
4
|
Summary: Machine learning audio prediction experiments based on templates
|
5
5
|
Home-page: https://github.com/felixbur/nkululeko
|
6
6
|
Author: Felix Burkhardt
|
@@ -333,6 +333,14 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
|
|
333
333
|
Changelog
|
334
334
|
=========
|
335
335
|
|
336
|
+
Version 0.83.1
|
337
|
+
--------------
|
338
|
+
* add test module to nkuluflag
|
339
|
+
|
340
|
+
Version 0.83.0
|
341
|
+
--------------
|
342
|
+
* test module now prints out reports
|
343
|
+
|
336
344
|
Version 0.82.4
|
337
345
|
--------------
|
338
346
|
* fixed bug in wavlm
|
@@ -2,31 +2,30 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
|
|
2
2
|
nkululeko/aug_train.py,sha256=YhuZnS_WVWnun9G-M6g5n6rbRxoVREz6Zh7k6qprFNQ,3194
|
3
3
|
nkululeko/augment.py,sha256=4MG0apTAG5RgkuJrYEjGgDdbodZWi_HweSPNI1JJ5QA,3051
|
4
4
|
nkululeko/cacheddataset.py,sha256=lIJ6hUo5LoxSrzXtWV8mzwO7wRtUETWnOQ4ws2XfL1E,969
|
5
|
-
nkululeko/constants.py,sha256=
|
5
|
+
nkululeko/constants.py,sha256=i6-Vtyje9xE8w8o3lG27IiJczQFyrNbsxiXs7b4-q28,39
|
6
6
|
nkululeko/demo.py,sha256=55kNFA2helMhOxD4yZuKg1JWDtlUUpxm-6uAnroIydI,3264
|
7
7
|
nkululeko/demo_feats.py,sha256=sAeGFojhEj9WEDFtG3SzPBmyYJWLF2rkbpp65m8Ujo4,2025
|
8
8
|
nkululeko/demo_predictor.py,sha256=-ggSHc3DXxRzjzcGB4qFBOMvKsfUdTkkde50BDrS9dA,4755
|
9
|
-
nkululeko/experiment.py,sha256=
|
9
|
+
nkululeko/experiment.py,sha256=aueWoKJCQx8wU9daosh6n7ZDGhT2cfo_9Av5HIfN1_w,29605
|
10
10
|
nkululeko/explore.py,sha256=2wdoGRqldvsN1zCiWk0quSDgHHHUoF2UZOWQ1r-2OLM,2310
|
11
11
|
nkululeko/export.py,sha256=mHeEAAmtZuxdyebLlbSzPrHSi9OMgJHbk35d3DTxRBc,4632
|
12
12
|
nkululeko/feature_extractor.py,sha256=8mssYKmo4LclVI-hiLmJEDZ0ZPyDavFG2YwtXcrGzwM,3976
|
13
13
|
nkululeko/file_checker.py,sha256=LoLnL8aHpW-axMQ46qbqrManTs5otG9ShpEZuz9iRSk,3474
|
14
14
|
nkululeko/filter_data.py,sha256=w-X2mhKdYr5DxDIz50E5yzO6Jmzk4jjDBoXsgOOVtcA,7222
|
15
15
|
nkululeko/glob_conf.py,sha256=iHiVSxDYgmYwdx6z0HuGUMSWrfZfufPHxHb60q2dLRY,453
|
16
|
-
nkululeko/modelrunner.py,sha256=
|
16
|
+
nkululeko/modelrunner.py,sha256=GwDXcE2gDQXat4W0-HhHQ1BcUNCRBXMBQ4QycfHp_5c,9288
|
17
17
|
nkululeko/multidb.py,sha256=fG3VukEWP1vreVN4gB1IRXxwwg4jLftsSEYtu0o1f78,5634
|
18
|
-
nkululeko/nkuluflag.py,sha256=
|
18
|
+
nkululeko/nkuluflag.py,sha256=PGWSmZz-PiiHLgcZJAoGOI_Y-sZDVI1ksB8p5r7riWM,3725
|
19
19
|
nkululeko/nkululeko.py,sha256=Kn3s2E3yyH8cJ7z6lkMxrnqtCxTu7-qfe9Zr_ONTD5g,1968
|
20
20
|
nkululeko/plots.py,sha256=K88ZRPFGX_r03BT742H06Dde20xZYdltv7dxjgUiAFA,23025
|
21
21
|
nkululeko/predict.py,sha256=sF091sSSLnEWcISx9ZcULLie3tY5XeFsQJd6b3vrxFg,2409
|
22
|
-
nkululeko/reporter.py,sha256=mCy8er8z4e5hJ7XbOyy6BgZYZM6Lz-EKXHh4zlT0Zc8,12427
|
23
22
|
nkululeko/resample.py,sha256=3WbxkwgyTe_fW38046Rjxk3knOkFdhqn2C4nfhbUurQ,2287
|
24
23
|
nkululeko/runmanager.py,sha256=eTM1DNQKt1lxYhzt4vZyZluPXW9sWlIJHNQzex4lkJU,7624
|
25
24
|
nkululeko/scaler.py,sha256=4nkIqoajkIkuTPK0Z02ifMN_awl6fP_i-GBYdoGYgGM,4101
|
26
25
|
nkululeko/segment.py,sha256=YLKckX44tbvTb3LrdgYw9X4guzuF27sutl92z9DkpZU,4835
|
27
26
|
nkululeko/syllable_nuclei.py,sha256=Sky-C__MeUDaxqHnDl2TGLLYOYvsahD35TUjWGeG31k,10047
|
28
|
-
nkululeko/test.py,sha256=
|
29
|
-
nkululeko/test_predictor.py,sha256=
|
27
|
+
nkululeko/test.py,sha256=1w624vo5KTzmFC8BUStGlLDmIEAFuJUz7J0W-gp7AxI,1677
|
28
|
+
nkululeko/test_predictor.py,sha256=_w5J8CxH6hmW3mLTKbdfmywl5QpdNAnW1Y8TE5GtlfE,3237
|
30
29
|
nkululeko/augmenting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
31
30
|
nkululeko/augmenting/augmenter.py,sha256=XAt0dpmlnKxqyysqCgV3rcz-pRIvOz7rU7dmGDCVAzs,2905
|
32
31
|
nkululeko/augmenting/randomsplicer.py,sha256=Z5rxdKKUpuncLWuTS6xVfVKUeVbeiYU_dLRHQ5fcg4Y,2669
|
@@ -95,7 +94,7 @@ nkululeko/reporting/defines.py,sha256=IsY1YgKRMaABpylVKjBJgJ5bNCEbGCVA_E6pivraqS
|
|
95
94
|
nkululeko/reporting/latex_writer.py,sha256=qiCRSmB4KOD_za4oHu5x-PhwjZohzfo8wecMOwlXZwc,1886
|
96
95
|
nkululeko/reporting/report.py,sha256=W0rcigDdjBvxZQ3pZja_gvToILYvaZ1BFtnN2qFRfYI,1060
|
97
96
|
nkululeko/reporting/report_item.py,sha256=siWeGNgo4bAE46YBMNcsdf3jTMTy76BO9Fi6DTvDig4,533
|
98
|
-
nkululeko/reporting/reporter.py,sha256=
|
97
|
+
nkululeko/reporting/reporter.py,sha256=eLqwKEUTQ7v5CedzhZP2617qmXGcvi0rjyyFLOBdxtQ,12841
|
99
98
|
nkululeko/reporting/result.py,sha256=nSN5or-Py2GPRWHkWpGRh7UCi1W0er7WLEHz8fYLk-A,742
|
100
99
|
nkululeko/segmenting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
101
100
|
nkululeko/segmenting/seg_inaspeechsegmenter.py,sha256=pmLHuXsaqvcdYxB4PSW9l1mbQWZZBJFhi_CGabqydas,1947
|
@@ -104,8 +103,8 @@ nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
104
103
|
nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
|
105
104
|
nkululeko/utils/stats.py,sha256=1yUq0FTOyqkU8TwUocJRYdJaqMU5SlOBBRUun9STo2M,2829
|
106
105
|
nkululeko/utils/util.py,sha256=_Z6OMJ3f-8TdETW9eqJYY5hwNRS5XCt9azzRnqoTTZE,12330
|
107
|
-
nkululeko-0.
|
108
|
-
nkululeko-0.
|
109
|
-
nkululeko-0.
|
110
|
-
nkululeko-0.
|
111
|
-
nkululeko-0.
|
106
|
+
nkululeko-0.83.1.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
|
107
|
+
nkululeko-0.83.1.dist-info/METADATA,sha256=EgPYOS_ELZQmEvPWlX-klt8gmo59suFFL_HDptU474w,36080
|
108
|
+
nkululeko-0.83.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
109
|
+
nkululeko-0.83.1.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
|
110
|
+
nkululeko-0.83.1.dist-info/RECORD,,
|
nkululeko/reporter.py
DELETED
@@ -1,332 +0,0 @@
|
|
1
|
-
"""Reporter module.
|
2
|
-
|
3
|
-
This module contains the Reporter class which is responsible for generating reports.
|
4
|
-
"""
|
5
|
-
|
6
|
-
|
7
|
-
import ast
|
8
|
-
import glob
|
9
|
-
import json
|
10
|
-
import math
|
11
|
-
|
12
|
-
import matplotlib.pyplot as plt
|
13
|
-
import numpy as np
|
14
|
-
from scipy.stats import pearsonr
|
15
|
-
from sklearn.metrics import (
|
16
|
-
ConfusionMatrixDisplay,
|
17
|
-
accuracy_score,
|
18
|
-
classification_report,
|
19
|
-
confusion_matrix,
|
20
|
-
mean_absolute_error,
|
21
|
-
mean_squared_error,
|
22
|
-
r2_score,
|
23
|
-
recall_score,
|
24
|
-
)
|
25
|
-
from sklearn.utils import resample
|
26
|
-
|
27
|
-
import nkululeko.glob_conf as glob_conf
|
28
|
-
from nkululeko.reporting.defines import Header
|
29
|
-
from nkululeko.reporting.report_item import ReportItem
|
30
|
-
from nkululeko.result import Result
|
31
|
-
from nkululeko.utils.util import Util
|
32
|
-
|
33
|
-
|
34
|
-
class Reporter:
|
35
|
-
def __set_measure(self):
|
36
|
-
if self.util.exp_is_classification():
|
37
|
-
self.MEASURE = "UAR"
|
38
|
-
self.result.measure = self.MEASURE
|
39
|
-
self.is_classification = True
|
40
|
-
else:
|
41
|
-
self.is_classification = False
|
42
|
-
self.measure = self.util.config_val("MODEL", "measure", "mse")
|
43
|
-
if self.measure == "mse":
|
44
|
-
self.MEASURE = "MSE"
|
45
|
-
self.result.measure = self.MEASURE
|
46
|
-
elif self.measure == "mae":
|
47
|
-
self.MEASURE = "MAE"
|
48
|
-
self.result.measure = self.MEASURE
|
49
|
-
elif self.measure == "ccc":
|
50
|
-
self.MEASURE = "CCC"
|
51
|
-
self.result.measure = self.MEASURE
|
52
|
-
|
53
|
-
def __init__(self, truths, preds, run, epoch):
|
54
|
-
"""Initialization with ground truth und predictions vector"""
|
55
|
-
self.util = Util("reporter")
|
56
|
-
self.format = self.util.config_val("PLOT", "format", "png")
|
57
|
-
self.truths = truths
|
58
|
-
self.preds = preds
|
59
|
-
self.result = Result(0, 0, 0, 0, "unknown")
|
60
|
-
self.run = run
|
61
|
-
self.epoch = epoch
|
62
|
-
self.__set_measure()
|
63
|
-
self.cont_to_cat = False
|
64
|
-
if len(self.truths) > 0 and len(self.preds) > 0:
|
65
|
-
if self.util.exp_is_classification():
|
66
|
-
self.result.test = recall_score(
|
67
|
-
self.truths, self.preds, average="macro"
|
68
|
-
)
|
69
|
-
self.result.loss = 1 - accuracy_score(self.truths, self.preds)
|
70
|
-
else:
|
71
|
-
# regression experiment
|
72
|
-
if self.measure == "mse":
|
73
|
-
self.result.test = mean_squared_error(
|
74
|
-
self.truths, self.preds)
|
75
|
-
elif self.measure == "mae":
|
76
|
-
self.result.test = mean_absolute_error(
|
77
|
-
self.truths, self.preds)
|
78
|
-
elif self.measure == "ccc":
|
79
|
-
self.result.test = self.ccc(self.truths, self.preds)
|
80
|
-
if math.isnan(self.result.test):
|
81
|
-
self.util.debug(f"Truth: {self.truths}")
|
82
|
-
self.util.debug(f"Predict.: {self.preds}")
|
83
|
-
self.util.debug(f"Result is NAN: setting to -1")
|
84
|
-
self.result.test = -1
|
85
|
-
else:
|
86
|
-
self.util.error(f"unknown measure: {self.measure}")
|
87
|
-
|
88
|
-
# train and loss are being set by the model
|
89
|
-
|
90
|
-
def set_id(self, run, epoch):
|
91
|
-
"""Make the report identifiable with run and epoch index"""
|
92
|
-
self.run = run
|
93
|
-
self.epoch = epoch
|
94
|
-
|
95
|
-
def continuous_to_categorical(self):
|
96
|
-
if self.cont_to_cat:
|
97
|
-
return
|
98
|
-
self.cont_to_cat = True
|
99
|
-
bins = ast.literal_eval(glob_conf.config["DATA"]["bins"])
|
100
|
-
self.truths = np.digitize(self.truths, bins) - 1
|
101
|
-
self.preds = np.digitize(self.preds, bins) - 1
|
102
|
-
|
103
|
-
def plot_confmatrix(self, plot_name, epoch):
|
104
|
-
if not self.util.exp_is_classification():
|
105
|
-
self.continuous_to_categorical()
|
106
|
-
self._plot_confmat(self.truths, self.preds, plot_name, epoch)
|
107
|
-
|
108
|
-
|
109
|
-
def plot_per_speaker(self, result_df, plot_name, function):
|
110
|
-
"""Plot a confusion matrix with the mode category per speakers.
|
111
|
-
|
112
|
-
This function creates a confusion matrix for each speaker in the result_df.
|
113
|
-
The result_df should contain the columns: preds, truths and speaker.
|
114
|
-
|
115
|
-
Args:
|
116
|
-
* result_df: a pandas dataframe with columns: preds, truths and speaker
|
117
|
-
* plot_name: a string with the name of the plot
|
118
|
-
* function: a string with the function to use for each speaker,
|
119
|
-
can be 'mode' or 'mean'
|
120
|
-
|
121
|
-
Returns:
|
122
|
-
* None
|
123
|
-
"""
|
124
|
-
# Initialize empty arrays for predictions and truths
|
125
|
-
pred = np.zeros(0)
|
126
|
-
truth = np.zeros(0)
|
127
|
-
|
128
|
-
# Iterate over each speaker
|
129
|
-
for s in result_df.speaker.unique():
|
130
|
-
# Filter the dataframe for the current speaker
|
131
|
-
s_df = result_df[result_df.speaker == s]
|
132
|
-
|
133
|
-
# Get the mode or mean prediction for the current speaker
|
134
|
-
mode = s_df.pred.mode().iloc[-1]
|
135
|
-
mean = s_df.pred.mean()
|
136
|
-
if function == "mode":
|
137
|
-
s_df.pred = mode
|
138
|
-
elif function == "mean":
|
139
|
-
s_df.pred = mean
|
140
|
-
else:
|
141
|
-
self.util.error(f"unknown function {function}")
|
142
|
-
|
143
|
-
# Append the current speaker's predictions and truths to the arrays
|
144
|
-
pred = np.append(pred, s_df.pred.values)
|
145
|
-
truth = np.append(truth, s_df["truth"].values)
|
146
|
-
|
147
|
-
# If the experiment is not a classification or continuous to categorical conversion was performed,
|
148
|
-
# convert the truths and predictions to categorical
|
149
|
-
if not (self.is_classification or self.cont_to_cat):
|
150
|
-
bins = ast.literal_eval(glob_conf.config["DATA"]["bins"])
|
151
|
-
truth = np.digitize(truth, bins) - 1
|
152
|
-
pred = np.digitize(pred, bins) - 1
|
153
|
-
|
154
|
-
# Plot the confusion matrix for the speakers
|
155
|
-
self._plot_confmat(truth, pred.astype("int"), plot_name, 0)
|
156
|
-
|
157
|
-
def _plot_confmat(self, truths, preds, plot_name, epoch):
|
158
|
-
# print(truths)
|
159
|
-
# print(preds)
|
160
|
-
fig_dir = self.util.get_path("fig_dir")
|
161
|
-
labels = glob_conf.labels
|
162
|
-
fig = plt.figure() # figsize=[5, 5]
|
163
|
-
uar = recall_score(truths, preds, average="macro")
|
164
|
-
acc = accuracy_score(truths, preds)
|
165
|
-
cm = confusion_matrix(
|
166
|
-
truths, preds, normalize=None
|
167
|
-
) # normalize must be one of {'true', 'pred', 'all', None}
|
168
|
-
if cm.shape[0] != len(labels):
|
169
|
-
self.util.error(
|
170
|
-
f"mismatch between confmatrix dim ({cm.shape[0]}) and labels"
|
171
|
-
f" length ({len(labels)}: {labels})"
|
172
|
-
)
|
173
|
-
try:
|
174
|
-
disp = ConfusionMatrixDisplay(
|
175
|
-
confusion_matrix=cm, display_labels=labels
|
176
|
-
).plot(cmap="Blues")
|
177
|
-
except ValueError:
|
178
|
-
disp = ConfusionMatrixDisplay(
|
179
|
-
confusion_matrix=cm,
|
180
|
-
display_labels=list(labels).remove("neutral"),
|
181
|
-
).plot(cmap="Blues")
|
182
|
-
|
183
|
-
reg_res = ""
|
184
|
-
if not self.is_classification:
|
185
|
-
reg_res = f", {self.MEASURE}: {self.result.test:.3f}"
|
186
|
-
|
187
|
-
if epoch != 0:
|
188
|
-
plt.title(
|
189
|
-
f"Confusion Matrix, UAR: {uar:.3f}{reg_res}, Epoch: {epoch}")
|
190
|
-
else:
|
191
|
-
plt.title(f"Confusion Matrix, UAR: {uar:.3f}{reg_res}")
|
192
|
-
img_path = f"{fig_dir}{plot_name}.{self.format}"
|
193
|
-
plt.savefig(img_path)
|
194
|
-
fig.clear()
|
195
|
-
plt.close(fig)
|
196
|
-
plt.savefig(img_path)
|
197
|
-
plt.close(fig)
|
198
|
-
glob_conf.report.add_item(
|
199
|
-
ReportItem(
|
200
|
-
Header.HEADER_RESULTS,
|
201
|
-
self.util.get_model_description(),
|
202
|
-
"Confusion matrix",
|
203
|
-
img_path,
|
204
|
-
)
|
205
|
-
)
|
206
|
-
|
207
|
-
res_dir = self.util.get_path("res_dir")
|
208
|
-
uar = int(uar * 1000) / 1000.0
|
209
|
-
acc = int(acc * 1000) / 1000.0
|
210
|
-
rpt = f"epoch: {epoch}, UAR: {uar}, ACC: {acc}"
|
211
|
-
# print(rpt)
|
212
|
-
self.util.debug(rpt)
|
213
|
-
file_name = f"{res_dir}{self.util.get_exp_name()}_conf.txt"
|
214
|
-
with open(file_name, "w") as text_file:
|
215
|
-
text_file.write(rpt)
|
216
|
-
|
217
|
-
def print_results(self, epoch):
|
218
|
-
"""Print all evaluation values to text file"""
|
219
|
-
res_dir = self.util.get_path("res_dir")
|
220
|
-
file_name = f"{res_dir}{self.util.get_exp_name()}_{epoch}.txt"
|
221
|
-
if self.util.exp_is_classification():
|
222
|
-
labels = glob_conf.labels
|
223
|
-
try:
|
224
|
-
rpt = classification_report(
|
225
|
-
self.truths,
|
226
|
-
self.preds,
|
227
|
-
target_names=labels,
|
228
|
-
output_dict=True,
|
229
|
-
)
|
230
|
-
except ValueError as e:
|
231
|
-
self.util.debug(
|
232
|
-
"Reporter: caught a ValueError when trying to get"
|
233
|
-
" classification_report: " + e
|
234
|
-
)
|
235
|
-
rpt = self.result.to_string()
|
236
|
-
with open(file_name, "w") as text_file:
|
237
|
-
c_ress = list(range(len(labels)))
|
238
|
-
for i, l in enumerate(labels):
|
239
|
-
c_res = rpt[l]["f1-score"]
|
240
|
-
c_ress[i] = float(f"{c_res:.3f}")
|
241
|
-
self.util.debug(f"labels: {labels}")
|
242
|
-
f1_per_class = f"result per class (F1 score): {c_ress}"
|
243
|
-
self.util.debug(f1_per_class)
|
244
|
-
rpt_str = f"{json.dumps(rpt)}\n{f1_per_class}"
|
245
|
-
text_file.write(rpt_str)
|
246
|
-
glob_conf.report.add_item(
|
247
|
-
ReportItem(
|
248
|
-
Header.HEADER_RESULTS,
|
249
|
-
f"Classification result {self.util.get_model_description()}",
|
250
|
-
rpt_str,
|
251
|
-
)
|
252
|
-
)
|
253
|
-
|
254
|
-
else: # regression
|
255
|
-
result = self.result.test
|
256
|
-
r2 = r2_score(self.truths, self.preds)
|
257
|
-
pcc = pearsonr(self.truths, self.preds)[0]
|
258
|
-
measure = self.util.config_val("MODEL", "measure", "mse")
|
259
|
-
with open(file_name, "w") as text_file:
|
260
|
-
text_file.write(
|
261
|
-
f"{measure}: {result:.3f}, r_2: {r2:.3f}, pcc {pcc:.3f}"
|
262
|
-
)
|
263
|
-
|
264
|
-
def make_conf_animation(self, out_name):
|
265
|
-
import imageio
|
266
|
-
|
267
|
-
fig_dir = self.util.get_path("fig_dir")
|
268
|
-
filenames = glob.glob(
|
269
|
-
fig_dir + f"{self.util.get_plot_name()}*_?_???_cnf.png")
|
270
|
-
images = []
|
271
|
-
for filename in filenames:
|
272
|
-
images.append(imageio.imread(filename))
|
273
|
-
fps = self.util.config_val("PLOT", "fps", "1")
|
274
|
-
try:
|
275
|
-
imageio.mimsave(fig_dir + out_name, images, fps=int(fps))
|
276
|
-
except RuntimeError as e:
|
277
|
-
self.util.error("error writing anim gif: " + e)
|
278
|
-
|
279
|
-
def get_result(self):
|
280
|
-
return self.result
|
281
|
-
|
282
|
-
def plot_epoch_progression(self, reports, out_name):
|
283
|
-
fig_dir = self.util.get_path("fig_dir")
|
284
|
-
results, losses, train_results, losses_eval = [], [], [], []
|
285
|
-
for r in reports:
|
286
|
-
results.append(r.get_result().test)
|
287
|
-
losses.append(r.get_result().loss)
|
288
|
-
train_results.append(r.get_result().train)
|
289
|
-
losses_eval.append(r.get_result().loss_eval)
|
290
|
-
|
291
|
-
# do a plot per run
|
292
|
-
# scale the losses so they fit on the picture
|
293
|
-
losses, results, train_results, losses_eval = (
|
294
|
-
np.asarray(losses),
|
295
|
-
np.asarray(results),
|
296
|
-
np.asarray(train_results),
|
297
|
-
np.asarray(losses_eval),
|
298
|
-
)
|
299
|
-
|
300
|
-
if np.all((results > 1)):
|
301
|
-
# scale down values
|
302
|
-
results = results / 100.0
|
303
|
-
train_results = train_results / 100.0
|
304
|
-
# if np.all((losses < 1)):
|
305
|
-
# scale up values
|
306
|
-
plt.figure(dpi=200)
|
307
|
-
plt.plot(train_results, "green", label="train set")
|
308
|
-
plt.plot(results, "red", label="dev set")
|
309
|
-
plt.plot(losses, "black", label="losses")
|
310
|
-
plt.plot(losses_eval, "grey", label="losses_eval")
|
311
|
-
plt.xlabel("epochs")
|
312
|
-
plt.ylabel(f"{self.MEASURE}")
|
313
|
-
plt.legend()
|
314
|
-
plt.savefig(f"{fig_dir}{out_name}.{self.format}")
|
315
|
-
plt.close()
|
316
|
-
|
317
|
-
@staticmethod
|
318
|
-
def ccc(ground_truth, prediction):
|
319
|
-
mean_gt = np.mean(ground_truth, 0)
|
320
|
-
mean_pred = np.mean(prediction, 0)
|
321
|
-
var_gt = np.var(ground_truth, 0)
|
322
|
-
var_pred = np.var(prediction, 0)
|
323
|
-
v_pred = prediction - mean_pred
|
324
|
-
v_gt = ground_truth - mean_gt
|
325
|
-
cor = sum(v_pred * v_gt) / \
|
326
|
-
(np.sqrt(sum(v_pred**2)) * np.sqrt(sum(v_gt**2)))
|
327
|
-
sd_gt = np.std(ground_truth)
|
328
|
-
sd_pred = np.std(prediction)
|
329
|
-
numerator = 2 * cor * sd_gt * sd_pred
|
330
|
-
denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
|
331
|
-
ccc = numerator / denominator
|
332
|
-
return ccc
|
File without changes
|
File without changes
|
File without changes
|