nkululeko-0.83.0-py3-none-any.whl → nkululeko-0.83.1-py3-none-any.whl

This diff shows the changes between package versions as they appear in their respective public registries. It is provided for informational purposes only.
nkululeko/constants.py CHANGED
@@ -1,2 +1,2 @@
- VERSION="0.83.0"
+ VERSION="0.83.1"
  SAMPLING_RATE = 16000
nkululeko/experiment.py CHANGED
@@ -675,7 +675,8 @@ class Experiment:
          test_predictor = TestPredictor(
              model, self.df_test, self.label_encoder, result_name
          )
-         test_predictor.predict_and_store()
+         result = test_predictor.predict_and_store()
+         return result

      def load(self, filename):
          f = open(filename, "rb")
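Note: with this change the enclosing method (called as `predict_test_and_save` from nkululeko/test.py below) passes the value computed by `TestPredictor.predict_and_store` back to its caller instead of discarding it. A minimal sketch of consuming that value, assuming an Experiment that has already been loaded and had test features extracted, as in nkululeko/test.py:

    # Hedged sketch, not part of the diff: `expr` is a prepared Experiment
    result = expr.predict_test_and_save("my_results.csv")
    print(f"stored predictions; evaluation result: {result}")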
nkululeko/nkuluflag.py CHANGED
@@ -2,13 +2,16 @@ import argparse
  import configparser
  import os
  import os.path
+ import sys

  from nkululeko.nkululeko import doit as nkulu
+ from nkululeko.test import do_it as test_mod


- def do_it(src_dir):
+ def doit(cla):
      parser = argparse.ArgumentParser(description="Call the nkululeko framework.")
      parser.add_argument("--config", help="The base configuration")
+     parser.add_argument("--mod", default="nkulu", help="Which nkululeko module to call")
      parser.add_argument("--data", help="The databases", nargs="*", action="append")
      parser.add_argument(
          "--label", nargs="*", help="The labels for the target", action="append"
@@ -25,20 +28,23 @@ def do_it(src_dir):
      parser.add_argument("--model", default="xgb", help="The model type")
      parser.add_argument("--feat", default="['os']", help="The feature type")
      parser.add_argument("--set", help="The opensmile set")
-     parser.add_argument("--with_os", help="To add os features")
      parser.add_argument("--target", help="The target designation")
      parser.add_argument("--epochs", help="The number of epochs")
      parser.add_argument("--runs", help="The number of runs")
      parser.add_argument("--learning_rate", help="The learning rate")
      parser.add_argument("--drop", help="The dropout rate [0:1]")

-     args = parser.parse_args()
+     args = parser.parse_args(cla)

      if args.config is not None:
          config_file = args.config
      else:
          print("ERROR: need config file")
          quit(-1)
+
+     if args.mod is not None:
+         nkulu_mod = args.mod
+
      # test if config is there
      if not os.path.isfile(config_file):
          print(f"ERROR: no such file {config_file}")
@@ -86,10 +92,17 @@ def do_it(src_dir):
      with open(tmp_config, "w") as tmp_file:
          config.write(tmp_file)

-     result, last_epoch = nkulu(tmp_config)
+     result, last_epoch = 0, 0
+     if nkulu_mod == "nkulu":
+         result, last_epoch = nkulu(tmp_config)
+     elif nkulu_mod == "test":
+         result, last_epoch = test_mod(tmp_config, "test_results.csv")
+     else:
+         print(f"ERROR: unknown module: {nkulu_mod}, should be [nkulu | test]")
      return result, last_epoch


  if __name__ == "__main__":
-     cwd = os.path.dirname(os.path.abspath(__file__))
-     do_it(cwd) # sys.argv[1])
+     cla = sys.argv
+     cla.pop(0)
+     doit(cla) # sys.argv[1])
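Note: `doit` now receives the argument vector explicitly and dispatches on the new `--mod` flag (`nkulu` for a regular experiment, `test` for the test module). A usage sketch; the config path is a placeholder:

    # Programmatic call, mirroring what the __main__ block now does
    from nkululeko.nkuluflag import doit

    result, last_epoch = doit(["--config", "exp.ini", "--mod", "test"])

    # Equivalent command line (the module has a __main__ guard):
    #   python -m nkululeko.nkuluflag --config exp.ini --mod test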
nkululeko/test.py CHANGED
@@ -10,20 +10,7 @@ from nkululeko.experiment import Experiment
  from nkululeko.utils.util import Util


- def main(src_dir):
-     parser = argparse.ArgumentParser(
-         description="Call the nkululeko TEST framework.")
-     parser.add_argument("--config", default="exp.ini",
-                         help="The base configuration")
-     parser.add_argument(
-         "--outfile",
-         default="my_results.csv",
-         help="File name to store the predictions",
-     )
-
-     args = parser.parse_args()
-
-     config_file = args.config
+ def do_it(config_file, outfile):

      # test if the configuration file exists
      if not os.path.isfile(config_file):
@@ -48,10 +35,28 @@ def main(src_dir):
      expr.load(f"{util.get_save_name()}")
      expr.fill_tests()
      expr.extract_test_feats()
-     expr.predict_test_and_save(args.outfile)
+     result = expr.predict_test_and_save(outfile)

      print("DONE")

+     return result, 0
+
+
+ def main(src_dir):
+     parser = argparse.ArgumentParser(description="Call the nkululeko TEST framework.")
+     parser.add_argument("--config", default="exp.ini", help="The base configuration")
+     parser.add_argument(
+         "--outfile",
+         default="my_results.csv",
+         help="File name to store the predictions",
+     )
+     args = parser.parse_args()
+     if args.config is not None:
+         config_file = args.config
+     else:
+         config_file = f"{src_dir}/exp.ini"
+     do_it(config_file, args.outfile)
+

  if __name__ == "__main__":
      cwd = os.path.dirname(os.path.abspath(__file__))
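Note: argument parsing has moved into a thin `main`, leaving `do_it(config_file, outfile)` importable from other modules; this is the function nkuluflag imports as `test_mod` above. It returns a `(result, last_epoch)` pair with `last_epoch` fixed at 0. A minimal sketch:

    from nkululeko.test import do_it

    # assumes exp.ini describes an experiment that was already trained and saved
    result, _ = do_it("exp.ini", "my_results.csv")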
nkululeko/test_predictor.py CHANGED
@@ -29,6 +29,7 @@ class TestPredictor:

      def predict_and_store(self):
          label_data = self.util.config_val("DATA", "label_data", False)
+         result = 0
          if label_data:
              data = Dataset(label_data)
              data.load()
@@ -57,6 +58,7 @@ class TestPredictor:
          test_dbs_string = "_".join(test_dbs)
          predictions = self.model.get_predictions()
          report = self.model.predict()
+         result = report.result.get_result()
          report.set_filename_add(f"test-{test_dbs_string}")
          self.util.print_best_results([report])
          report.plot_confmatrix(self.util.get_plot_name(), 0)
@@ -74,3 +76,4 @@ class TestPredictor:
          df = df.rename(columns={"class_label": target})
          df.to_csv(self.name)
          self.util.debug(f"results stored in {self.name}")
+         return result
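Note: `predict_and_store` initializes `result` to 0 and only overwrites it with `report.result.get_result()` on the path shown in the second hunk, so the `label_data` branch appears to return the default 0. A hedged guard on the consumer side, with `test_predictor` an existing TestPredictor instance:

    # Sketch: treat the default 0 as "no evaluation result produced"
    result = test_predictor.predict_and_store()
    if result == 0:
        print("warning: predictions stored, but no evaluation result available")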
{nkululeko-0.83.0.dist-info → nkululeko-0.83.1.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nkululeko
- Version: 0.83.0
+ Version: 0.83.1
  Summary: Machine learning audio prediction experiments based on templates
  Home-page: https://github.com/felixbur/nkululeko
  Author: Felix Burkhardt
@@ -333,6 +333,10 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
  Changelog
  =========

+ Version 0.83.1
+ --------------
+ * add test module to nkuluflag
+
  Version 0.83.0
  --------------
  * test module now prints out reports
{nkululeko-0.83.0.dist-info → nkululeko-0.83.1.dist-info}/RECORD RENAMED
@@ -2,11 +2,11 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
  nkululeko/aug_train.py,sha256=YhuZnS_WVWnun9G-M6g5n6rbRxoVREz6Zh7k6qprFNQ,3194
  nkululeko/augment.py,sha256=4MG0apTAG5RgkuJrYEjGgDdbodZWi_HweSPNI1JJ5QA,3051
  nkululeko/cacheddataset.py,sha256=lIJ6hUo5LoxSrzXtWV8mzwO7wRtUETWnOQ4ws2XfL1E,969
- nkululeko/constants.py,sha256=NNx53OyRpXv780Ycj6Cdw4bDJfdvEn180CaN2PcmQkY,39
+ nkululeko/constants.py,sha256=i6-Vtyje9xE8w8o3lG27IiJczQFyrNbsxiXs7b4-q28,39
  nkululeko/demo.py,sha256=55kNFA2helMhOxD4yZuKg1JWDtlUUpxm-6uAnroIydI,3264
  nkululeko/demo_feats.py,sha256=sAeGFojhEj9WEDFtG3SzPBmyYJWLF2rkbpp65m8Ujo4,2025
  nkululeko/demo_predictor.py,sha256=-ggSHc3DXxRzjzcGB4qFBOMvKsfUdTkkde50BDrS9dA,4755
- nkululeko/experiment.py,sha256=SRcB0ni0XLK910NSWTyRAe-Eoa6fVSKDCJlDJKyCzMc,29574
+ nkululeko/experiment.py,sha256=aueWoKJCQx8wU9daosh6n7ZDGhT2cfo_9Av5HIfN1_w,29605
  nkululeko/explore.py,sha256=2wdoGRqldvsN1zCiWk0quSDgHHHUoF2UZOWQ1r-2OLM,2310
  nkululeko/export.py,sha256=mHeEAAmtZuxdyebLlbSzPrHSi9OMgJHbk35d3DTxRBc,4632
  nkululeko/feature_extractor.py,sha256=8mssYKmo4LclVI-hiLmJEDZ0ZPyDavFG2YwtXcrGzwM,3976
@@ -15,18 +15,17 @@ nkululeko/filter_data.py,sha256=w-X2mhKdYr5DxDIz50E5yzO6Jmzk4jjDBoXsgOOVtcA,7222
  nkululeko/glob_conf.py,sha256=iHiVSxDYgmYwdx6z0HuGUMSWrfZfufPHxHb60q2dLRY,453
  nkululeko/modelrunner.py,sha256=GwDXcE2gDQXat4W0-HhHQ1BcUNCRBXMBQ4QycfHp_5c,9288
  nkululeko/multidb.py,sha256=fG3VukEWP1vreVN4gB1IRXxwwg4jLftsSEYtu0o1f78,5634
- nkululeko/nkuluflag.py,sha256=FCetTfgH69u4AwENgeCKVi3vBIR10Di67SfbupGQqfc,3354
+ nkululeko/nkuluflag.py,sha256=PGWSmZz-PiiHLgcZJAoGOI_Y-sZDVI1ksB8p5r7riWM,3725
  nkululeko/nkululeko.py,sha256=Kn3s2E3yyH8cJ7z6lkMxrnqtCxTu7-qfe9Zr_ONTD5g,1968
  nkululeko/plots.py,sha256=K88ZRPFGX_r03BT742H06Dde20xZYdltv7dxjgUiAFA,23025
  nkululeko/predict.py,sha256=sF091sSSLnEWcISx9ZcULLie3tY5XeFsQJd6b3vrxFg,2409
- nkululeko/reporter.py,sha256=8mlIaKep4hM-tdRv8t98tK80rx3zOmVGXSORhiPc3as,12483
  nkululeko/resample.py,sha256=3WbxkwgyTe_fW38046Rjxk3knOkFdhqn2C4nfhbUurQ,2287
  nkululeko/runmanager.py,sha256=eTM1DNQKt1lxYhzt4vZyZluPXW9sWlIJHNQzex4lkJU,7624
  nkululeko/scaler.py,sha256=4nkIqoajkIkuTPK0Z02ifMN_awl6fP_i-GBYdoGYgGM,4101
  nkululeko/segment.py,sha256=YLKckX44tbvTb3LrdgYw9X4guzuF27sutl92z9DkpZU,4835
  nkululeko/syllable_nuclei.py,sha256=Sky-C__MeUDaxqHnDl2TGLLYOYvsahD35TUjWGeG31k,10047
- nkululeko/test.py,sha256=JRoLgqQJEhAIGetw-qlOUihSTTQ7O8DYafB0FlQESIQ,1525
- nkululeko/test_predictor.py,sha256=L8XKrIweTf-oKeaGuDw_ZhtvzRUxFuWmOhva6jgf7-s,3148
+ nkululeko/test.py,sha256=1w624vo5KTzmFC8BUStGlLDmIEAFuJUz7J0W-gp7AxI,1677
+ nkululeko/test_predictor.py,sha256=_w5J8CxH6hmW3mLTKbdfmywl5QpdNAnW1Y8TE5GtlfE,3237
  nkululeko/augmenting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  nkululeko/augmenting/augmenter.py,sha256=XAt0dpmlnKxqyysqCgV3rcz-pRIvOz7rU7dmGDCVAzs,2905
  nkululeko/augmenting/randomsplicer.py,sha256=Z5rxdKKUpuncLWuTS6xVfVKUeVbeiYU_dLRHQ5fcg4Y,2669
@@ -104,8 +103,8 @@ nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
  nkululeko/utils/stats.py,sha256=1yUq0FTOyqkU8TwUocJRYdJaqMU5SlOBBRUun9STo2M,2829
  nkululeko/utils/util.py,sha256=_Z6OMJ3f-8TdETW9eqJYY5hwNRS5XCt9azzRnqoTTZE,12330
- nkululeko-0.83.0.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
- nkululeko-0.83.0.dist-info/METADATA,sha256=20S7IpMbLE7irV0ikdaFNfdqdBEEywH7jjlJwur8smA,36018
- nkululeko-0.83.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- nkululeko-0.83.0.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
- nkululeko-0.83.0.dist-info/RECORD,,
+ nkululeko-0.83.1.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
+ nkululeko-0.83.1.dist-info/METADATA,sha256=EgPYOS_ELZQmEvPWlX-klt8gmo59suFFL_HDptU474w,36080
+ nkululeko-0.83.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ nkululeko-0.83.1.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
+ nkululeko-0.83.1.dist-info/RECORD,,
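Note: RECORD lines follow the standard wheel format: file path, `sha256=` plus the URL-safe base64 digest with `=` padding stripped, and the file size in bytes. A sketch for recomputing one entry to verify it (the path is a placeholder):

    import base64
    import hashlib

    def record_entry(path: str) -> str:
        # path,sha256=<urlsafe b64 digest without padding>,<size in bytes>
        data = open(path, "rb").read()
        digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
        return f"{path},sha256={digest.decode()},{len(data)}"

    print(record_entry("nkululeko/constants.py"))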
nkululeko/reporter.py DELETED
@@ -1,324 +0,0 @@
- """Reporter module.
-
- This module contains the Reporter class which is responsible for generating reports.
- """
-
- import ast
- import glob
- import json
- import math
-
- import matplotlib.pyplot as plt
- import numpy as np
- from scipy.stats import pearsonr
- from sklearn.metrics import ConfusionMatrixDisplay
- from sklearn.metrics import accuracy_score
- from sklearn.metrics import classification_report
- from sklearn.metrics import confusion_matrix
- from sklearn.metrics import mean_absolute_error
- from sklearn.metrics import mean_squared_error
- from sklearn.metrics import r2_score
- from sklearn.metrics import recall_score
- from sklearn.utils import resample
-
- import nkululeko.glob_conf as glob_conf
- from nkululeko.reporting.defines import Header
- from nkululeko.reporting.report_item import ReportItem
- from nkululeko.result import Result
- from nkululeko.utils.util import Util
-
-
- class Reporter:
-     def __set_measure(self):
-         if self.util.exp_is_classification():
-             self.MEASURE = "UAR"
-             self.result.measure = self.MEASURE
-             self.is_classification = True
-         else:
-             self.is_classification = False
-             self.measure = self.util.config_val("MODEL", "measure", "mse")
-             if self.measure == "mse":
-                 self.MEASURE = "MSE"
-                 self.result.measure = self.MEASURE
-             elif self.measure == "mae":
-                 self.MEASURE = "MAE"
-                 self.result.measure = self.MEASURE
-             elif self.measure == "ccc":
-                 self.MEASURE = "CCC"
-                 self.result.measure = self.MEASURE
-
-     def __init__(self, truths, preds, run, epoch):
-         """Initialization with ground truth und predictions vector"""
-         self.util = Util("reporter")
-         self.format = self.util.config_val("PLOT", "format", "png")
-         self.truths = truths
-         self.preds = preds
-         self.result = Result(0, 0, 0, 0, "unknown")
-         self.run = run
-         self.epoch = epoch
-         self.__set_measure()
-         self.cont_to_cat = False
-         if len(self.truths) > 0 and len(self.preds) > 0:
-             if self.util.exp_is_classification():
-                 self.result.test = recall_score(
-                     self.truths, self.preds, average="macro"
-                 )
-                 self.result.loss = 1 - accuracy_score(self.truths, self.preds)
-             else:
-                 # regression experiment
-                 if self.measure == "mse":
-                     self.result.test = mean_squared_error(self.truths, self.preds)
-                 elif self.measure == "mae":
-                     self.result.test = mean_absolute_error(self.truths, self.preds)
-                 elif self.measure == "ccc":
-                     self.result.test = self.ccc(self.truths, self.preds)
-                     if math.isnan(self.result.test):
-                         self.util.debug(f"Truth: {self.truths}")
-                         self.util.debug(f"Predict.: {self.preds}")
-                         self.util.debug(f"Result is NAN: setting to -1")
-                         self.result.test = -1
-                 else:
-                     self.util.error(f"unknown measure: {self.measure}")
-
-         # train and loss are being set by the model
-
-     def set_id(self, run, epoch):
-         """Make the report identifiable with run and epoch index"""
-         self.run = run
-         self.epoch = epoch
-
-     def continuous_to_categorical(self):
-         if self.cont_to_cat:
-             return
-         self.cont_to_cat = True
-         bins = ast.literal_eval(glob_conf.config["DATA"]["bins"])
-         self.truths = np.digitize(self.truths, bins) - 1
-         self.preds = np.digitize(self.preds, bins) - 1
-
-     def plot_confmatrix(self, plot_name, epoch):
-         if not self.util.exp_is_classification():
-             self.continuous_to_categorical()
-         self._plot_confmat(self.truths, self.preds, plot_name, epoch)
-
-
-     def plot_per_speaker(self, result_df, plot_name, function):
-         """Plot a confusion matrix with the mode category per speakers.
-
-         This function creates a confusion matrix for each speaker in the result_df.
-         The result_df should contain the columns: preds, truths and speaker.
-
-         Args:
-             * result_df: a pandas dataframe with columns: preds, truths and speaker
-             * plot_name: a string with the name of the plot
-             * function: a string with the function to use for each speaker,
-               can be 'mode' or 'mean'
-
-         Returns:
-             * None
-         """
-         # Initialize empty arrays for predictions and truths
-         pred = np.zeros(0)
-         truth = np.zeros(0)
-
-         # Iterate over each speaker
-         for s in result_df.speaker.unique():
-             # Filter the dataframe for the current speaker
-             s_df = result_df[result_df.speaker == s]
-
-             # Get the mode or mean prediction for the current speaker
-             mode = s_df.pred.mode().iloc[-1]
-             mean = s_df.pred.mean()
-             if function == "mode":
-                 s_df.pred = mode
-             elif function == "mean":
-                 s_df.pred = mean
-             else:
-                 self.util.error(f"unknown function {function}")
-
-             # Append the current speaker's predictions and truths to the arrays
-             pred = np.append(pred, s_df.pred.values)
-             truth = np.append(truth, s_df["truth"].values)
-
-         # If the experiment is not a classification or continuous to categorical conversion was performed,
-         # convert the truths and predictions to categorical
-         if not (self.is_classification or self.cont_to_cat):
-             bins = ast.literal_eval(glob_conf.config["DATA"]["bins"])
-             truth = np.digitize(truth, bins) - 1
-             pred = np.digitize(pred, bins) - 1
-
-         # Plot the confusion matrix for the speakers
-         self._plot_confmat(truth, pred.astype("int"), plot_name, 0)
-
-     def _plot_confmat(self, truths, preds, plot_name, epoch):
-         # print(truths)
-         # print(preds)
-         fig_dir = self.util.get_path("fig_dir")
-         labels = glob_conf.labels
-         fig = plt.figure() # figsize=[5, 5]
-         uar = recall_score(truths, preds, average="macro")
-         acc = accuracy_score(truths, preds)
-         cm = confusion_matrix(
-             truths, preds, normalize=None
-         ) # normalize must be one of {'true', 'pred', 'all', None}
-         if cm.shape[0] != len(labels):
-             self.util.error(
-                 f"mismatch between confmatrix dim ({cm.shape[0]}) and labels"
-                 f" length ({len(labels)}: {labels})"
-             )
-         try:
-             disp = ConfusionMatrixDisplay(
-                 confusion_matrix=cm, display_labels=labels
-             ).plot(cmap="Blues")
-         except ValueError:
-             disp = ConfusionMatrixDisplay(
-                 confusion_matrix=cm,
-                 display_labels=list(labels).remove("neutral"),
-             ).plot(cmap="Blues")
-
-         reg_res = ""
-         if not self.is_classification:
-             reg_res = f", {self.MEASURE}: {self.result.test:.3f}"
-
-         if epoch != 0:
-             plt.title(f"Confusion Matrix, UAR: {uar:.3f}{reg_res}, Epoch: {epoch}")
-         else:
-             plt.title(f"Confusion Matrix, UAR: {uar:.3f}{reg_res}")
-         img_path = f"{fig_dir}{plot_name}.{self.format}"
-         plt.savefig(img_path)
-         fig.clear()
-         plt.close(fig)
-         plt.savefig(img_path)
-         plt.close(fig)
-         glob_conf.report.add_item(
-             ReportItem(
-                 Header.HEADER_RESULTS,
-                 self.util.get_model_description(),
-                 "Confusion matrix",
-                 img_path,
-             )
-         )
-
-         res_dir = self.util.get_path("res_dir")
-         uar = int(uar * 1000) / 1000.0
-         acc = int(acc * 1000) / 1000.0
-         rpt = f"epoch: {epoch}, UAR: {uar}, ACC: {acc}"
-         # print(rpt)
-         self.util.debug(rpt)
-         file_name = f"{res_dir}{self.util.get_exp_name()}_conf.txt"
-         with open(file_name, "w") as text_file:
-             text_file.write(rpt)
-
-     def print_results(self, epoch):
-         """Print all evaluation values to text file"""
-         res_dir = self.util.get_path("res_dir")
-         file_name = f"{res_dir}{self.util.get_exp_name()}_{epoch}.txt"
-         if self.util.exp_is_classification():
-             labels = glob_conf.labels
-             try:
-                 rpt = classification_report(
-                     self.truths,
-                     self.preds,
-                     target_names=labels,
-                     output_dict=True,
-                 )
-             except ValueError as e:
-                 self.util.debug(
-                     "Reporter: caught a ValueError when trying to get"
-                     " classification_report: " + e
-                 )
-                 rpt = self.result.to_string()
-             with open(file_name, "w") as text_file:
-                 c_ress = list(range(len(labels)))
-                 for i, l in enumerate(labels):
-                     c_res = rpt[l]["f1-score"]
-                     c_ress[i] = float(f"{c_res:.3f}")
-                 self.util.debug(f"labels: {labels}")
-                 f1_per_class = f"result per class (F1 score): {c_ress}"
-                 self.util.debug(f1_per_class)
-                 rpt_str = f"{json.dumps(rpt)}\n{f1_per_class}"
-                 text_file.write(rpt_str)
-                 glob_conf.report.add_item(
-                     ReportItem(
-                         Header.HEADER_RESULTS,
-                         f"Classification result {self.util.get_model_description()}",
-                         rpt_str,
-                     )
-                 )
-
-         else: # regression
-             result = self.result.test
-             r2 = r2_score(self.truths, self.preds)
-             pcc = pearsonr(self.truths, self.preds)[0]
-             measure = self.util.config_val("MODEL", "measure", "mse")
-             with open(file_name, "w") as text_file:
-                 text_file.write(
-                     f"{measure}: {result:.3f}, r_2: {r2:.3f}, pcc {pcc:.3f}"
-                 )
-
-     def make_conf_animation(self, out_name):
-         import imageio
-
-         fig_dir = self.util.get_path("fig_dir")
-         filenames = glob.glob(fig_dir + f"{self.util.get_plot_name()}*_?_???_cnf.png")
-         images = []
-         for filename in filenames:
-             images.append(imageio.imread(filename))
-         fps = self.util.config_val("PLOT", "fps", "1")
-         try:
-             imageio.mimsave(fig_dir + out_name, images, fps=int(fps))
-         except RuntimeError as e:
-             self.util.error("error writing anim gif: " + e)
-
-     def get_result(self):
-         return self.result
-
-     def plot_epoch_progression(self, reports, out_name):
-         fig_dir = self.util.get_path("fig_dir")
-         results, losses, train_results, losses_eval = [], [], [], []
-         for r in reports:
-             results.append(r.get_result().test)
-             losses.append(r.get_result().loss)
-             train_results.append(r.get_result().train)
-             losses_eval.append(r.get_result().loss_eval)
-
-         # do a plot per run
-         # scale the losses so they fit on the picture
-         losses, results, train_results, losses_eval = (
-             np.asarray(losses),
-             np.asarray(results),
-             np.asarray(train_results),
-             np.asarray(losses_eval),
-         )
-
-         if np.all((results > 1)):
-             # scale down values
-             results = results / 100.0
-             train_results = train_results / 100.0
-         # if np.all((losses < 1)):
-         # scale up values
-         plt.figure(dpi=200)
-         plt.plot(train_results, "green", label="train set")
-         plt.plot(results, "red", label="dev set")
-         plt.plot(losses, "black", label="losses")
-         plt.plot(losses_eval, "grey", label="losses_eval")
-         plt.xlabel("epochs")
-         plt.ylabel(f"{self.MEASURE}")
-         plt.legend()
-         plt.savefig(f"{fig_dir}{out_name}.{self.format}")
-         plt.close()
-
-     @staticmethod
-     def ccc(ground_truth, prediction):
-         mean_gt = np.mean(ground_truth, 0)
-         mean_pred = np.mean(prediction, 0)
-         var_gt = np.var(ground_truth, 0)
-         var_pred = np.var(prediction, 0)
-         v_pred = prediction - mean_pred
-         v_gt = ground_truth - mean_gt
-         cor = sum(v_pred * v_gt) / (np.sqrt(sum(v_pred**2)) * np.sqrt(sum(v_gt**2)))
-         sd_gt = np.std(ground_truth)
-         sd_pred = np.std(prediction)
-         numerator = 2 * cor * sd_gt * sd_pred
-         denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
-         ccc = numerator / denominator
-         return ccc
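Note: the `ccc` static method of the removed Reporter implements the concordance correlation coefficient. In the notation of the code, with $\rho$ the Pearson correlation between ground truth and prediction:

$$\mathrm{CCC} = \frac{2\,\rho\,\sigma_{gt}\,\sigma_{pred}}{\sigma_{gt}^2 + \sigma_{pred}^2 + (\mu_{gt} - \mu_{pred})^2}$$

which matches the code's `numerator / (var_gt + var_pred + (mean_gt - mean_pred) ** 2)`.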