nkululeko 0.93.15__py3-none-any.whl → 0.94.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nkululeko/aug_train.py +13 -2
- nkululeko/constants.py +1 -1
- nkululeko/data/dataset.py +287 -36
- nkululeko/experiment.py +121 -17
- nkululeko/feat_extract/feats_opensmile copy.py +93 -0
- nkululeko/feat_extract/feats_opensmile.py +207 -60
- nkululeko/feat_extract/feats_trill.py +2 -2
- nkululeko/modelrunner.py +23 -10
- nkululeko/models/model_mlp.py +2 -0
- nkululeko/nkululeko.py +0 -1
- nkululeko/plots.py +11 -2
- nkululeko/reporting/reporter.py +25 -39
- nkululeko/runmanager.py +53 -33
- nkululeko/scaler.py +41 -24
- nkululeko/utils/util.py +1 -1
- {nkululeko-0.93.15.dist-info → nkululeko-0.94.0.dist-info}/METADATA +3 -2
- {nkululeko-0.93.15.dist-info → nkululeko-0.94.0.dist-info}/RECORD +21 -20
- {nkululeko-0.93.15.dist-info → nkululeko-0.94.0.dist-info}/WHEEL +1 -1
- {nkululeko-0.93.15.dist-info → nkululeko-0.94.0.dist-info}/entry_points.txt +0 -0
- {nkululeko-0.93.15.dist-info → nkululeko-0.94.0.dist-info/licenses}/LICENSE +0 -0
- {nkululeko-0.93.15.dist-info → nkululeko-0.94.0.dist-info}/top_level.txt +0 -0
nkululeko/feat_extract/feats_opensmile.py
CHANGED

```diff
@@ -1,93 +1,240 @@
 # opensmileset.py
+"""Module for extracting OpenSMILE features from audio files.
+OpenSMILE is an audio feature extraction toolkit supporting various feature sets.
+"""
 import os
+import logging
+from typing import Optional, Union, List, Any, Dict
 
 import opensmile
 import pandas as pd
+import numpy as np
 
 import nkululeko.glob_conf as glob_conf
 from nkululeko.feat_extract.featureset import Featureset
 
 
 class Opensmileset(Featureset):
-
+    """Class for extracting OpenSMILE features from audio files.
+
+    This class provides methods to extract various OpenSMILE feature sets like eGeMAPSv02,
+    ComParE_2016, etc. at different feature levels (LowLevelDescriptors or Functionals).
+
+    Attributes:
+        featset (str): The OpenSMILE feature set to extract (e.g., 'eGeMAPSv02')
+        feature_set: The OpenSMILE feature set object
+        featlevel (str): The feature level ('LowLevelDescriptors' or 'Functionals')
+        feature_level: The OpenSMILE feature level object
+    """
+
+    # Available feature sets for validation
+    AVAILABLE_FEATURE_SETS = ["eGeMAPSv02", "ComParE_2016", "GeMAPSv01a", "eGeMAPSv01a"]
+
+    # Available feature levels for validation
+    AVAILABLE_FEATURE_LEVELS = ["LowLevelDescriptors", "Functionals"]
+
+    def __init__(
+        self,
+        name: str,
+        data_df: pd.DataFrame,
+        feats_type: Optional[str] = None,
+        config_file: Optional[str] = None,
+    ):
+        """Initialize the Opensmileset class.
+
+        Args:
+            name (str): Name of the feature set
+            data_df (pd.DataFrame): DataFrame containing audio file paths
+            feats_type (Optional[str]): Type of features to extract
+            config_file (Optional[str]): Configuration file path
+        """
         super().__init__(name, data_df, feats_type)
+
+        # Get feature set configuration
         self.featset = self.util.config_val("FEATS", "set", "eGeMAPSv02")
+
+        # Validate and set feature set
+        if self.featset not in self.AVAILABLE_FEATURE_SETS:
+            self.util.warning(
+                f"Feature set '{self.featset}' might not be supported. "
+                f"Available sets: {', '.join(self.AVAILABLE_FEATURE_SETS)}"
+            )
+
         try:
             self.feature_set = eval(f"opensmile.FeatureSet.{self.featset}")
-
-
-
+        except (AttributeError, SyntaxError) as e:
+            self.util.error(f"Invalid feature set: {self.featset}. Error: {str(e)}")
+            raise ValueError(f"Invalid feature set: {self.featset}")
+
+        # Get feature level configuration
         self.featlevel = self.util.config_val("FEATS", "level", "functionals")
+
+        # Convert shorthand names to full OpenSMILE names
+        if self.featlevel == "lld":
+            self.featlevel = "LowLevelDescriptors"
+        elif self.featlevel == "functionals":
+            self.featlevel = "Functionals"
+
+        # Validate and set feature level
+        if self.featlevel not in self.AVAILABLE_FEATURE_LEVELS:
+            self.util.warning(
+                f"Feature level '{self.featlevel}' might not be supported. "
+                f"Available levels: {', '.join(self.AVAILABLE_FEATURE_LEVELS)}"
+            )
+
         try:
-            self.featlevel = self.featlevel.replace("lld", "LowLevelDescriptors")
-            self.featlevel = self.featlevel.replace("functionals", "Functionals")
             self.feature_level = eval(f"opensmile.FeatureLevel.{self.featlevel}")
-        except AttributeError:
-            self.util.error(f"
+        except (AttributeError, SyntaxError) as e:
+            self.util.error(f"Invalid feature level: {self.featlevel}. Error: {str(e)}")
+            raise ValueError(f"Invalid feature level: {self.featlevel}")
+
+    def extract(self) -> pd.DataFrame:
+        """Extract the features based on the initialized dataset or load them from disk if available.
+
+        This method checks if features are already stored on disk and loads them if available,
+        otherwise it extracts features using OpenSMILE.
 
-
-
+        Returns:
+            pd.DataFrame: DataFrame containing the extracted features
+
+        Raises:
+            RuntimeError: If feature extraction fails
+        """
         store = self.util.get_path("store")
         store_format = self.util.config_val("FEATS", "store_format", "pkl")
         storage = f"{store}{self.name}.{store_format}"
+
+        # Check if we need to extract features or use existing ones
         extract = eval(
             self.util.config_val("FEATS", "needs_feature_extraction", "False")
         )
         no_reuse = eval(self.util.config_val("FEATS", "no_reuse", "False"))
+
         if extract or not os.path.isfile(storage) or no_reuse:
-            self.util.debug("
+            self.util.debug("Extracting OpenSMILE features, this might take a while...")
+
+            try:
+                smile = opensmile.Smile(
+                    feature_set=self.feature_set,
+                    feature_level=self.feature_level,
+                    num_workers=self.n_jobs,
+                    verbose=True,
+                )
+
+                # Extract features based on index type
+                if isinstance(self.data_df.index, pd.MultiIndex):
+                    self.df = smile.process_index(self.data_df.index)
+                    self.df = self.df.set_index(self.data_df.index)
+                else:
+                    self.df = smile.process_files(self.data_df.index)
+                    # Clean up the index
+                    if self.df.index.nlevels > 1:
+                        self.df.index = self.df.index.droplevel(1)
+                        self.df.index = self.df.index.droplevel(1)
+
+                # Save extracted features
+                self.util.write_store(self.df, storage, store_format)
+
+                # Update configuration to avoid re-extraction
+                try:
+                    glob_conf.config["DATA"]["needs_feature_extraction"] = "False"
+                except KeyError:
+                    pass
+
+            except Exception as e:
+                self.util.error(f"Feature extraction failed: {str(e)}")
+                raise RuntimeError(f"Feature extraction failed: {str(e)}")
+
+        else:
+            self.util.debug(f"Reusing extracted OpenSMILE features from: {storage}")
+            try:
+                self.df = self.util.get_store(storage, store_format)
+            except Exception as e:
+                self.util.error(f"Failed to load stored features: {str(e)}")
+                raise RuntimeError(f"Failed to load stored features: {str(e)}")
+
+        return self.df
+
+    def extract_sample(self, signal: np.ndarray, sr: int) -> np.ndarray:
+        """Extract features from a single audio sample.
+
+        Args:
+            signal (np.ndarray): Audio signal as numpy array
+            sr (int): Sample rate of the audio signal
+
+        Returns:
+            np.ndarray: Extracted features as numpy array
+
+        Raises:
+            ValueError: If signal or sample rate is invalid
+        """
+        if signal is None or len(signal) == 0:
+            raise ValueError("Empty or invalid audio signal provided")
+
+        if sr <= 0:
+            raise ValueError(f"Invalid sample rate: {sr}")
+
+        try:
             smile = opensmile.Smile(
                 feature_set=self.feature_set,
-                feature_level=
-                num_workers=self.n_jobs,
-                verbose=True,
+                feature_level=opensmile.FeatureLevel.Functionals,
             )
-
-
-
-
-
-
-
-
+            feats = smile.process_signal(signal, sr)
+            return feats.to_numpy()
+        except Exception as e:
+            self.util.error(f"Failed to extract features from sample: {str(e)}")
+            raise RuntimeError(f"Failed to extract features from sample: {str(e)}")
+
+    def filter_features(self, feature_list: List[str] = None) -> pd.DataFrame:
+        """Filter the extracted features to keep only the specified ones.
+
+        Args:
+            feature_list (List[str], optional): List of feature names to keep.
+                If None, uses the list from config.
+
+        Returns:
+            pd.DataFrame: Filtered features DataFrame
+        """
+        # First ensure we're only using features indexed in the target dataframes
+        self.df = self.df[self.df.index.isin(self.data_df.index)]
+
+        if feature_list is None:
+            try:
+                # Try to get feature list from config
+                import ast
+
+                feature_list = ast.literal_eval(
+                    glob_conf.config["FEATS"]["os.features"]
+                )
+            except (KeyError, ValueError, SyntaxError):
+                self.util.debug("No feature list specified, using all features")
+                return self.df
+
+        if not feature_list:
+            return self.df
+
+        self.util.debug(f"Selecting features from OpenSMILE: {feature_list}")
+        sel_feats_df = pd.DataFrame(index=self.df.index)
+        hit = False
+
+        for feat in feature_list:
             try:
-
+                sel_feats_df[feat] = self.df[feat]
+                hit = True
             except KeyError:
-
-        else:
-            self.util.debug(f"reusing extracted OS features: {storage}.")
-            self.df = self.util.get_store(storage, store_format)
+                self.util.warning(f"Feature '{feat}' not found in extracted features")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            # glob_conf.config["FEATS"]["os.features"]
-            # )
-            # self.util.debug(f"selecting features from opensmile: {selected_features}")
-            # sel_feats_df = pd.DataFrame()
-            # hit = False
-            # for feat in selected_features:
-            #     try:
-            #         sel_feats_df[feat] = self.df[feat]
-            #         hit = True
-            #     except KeyError:
-            #         pass
-            # if hit:
-            #     self.df = sel_feats_df
-            #     self.util.debug(
-            #         "new feats shape after selecting opensmile features:"
-            #         f" {self.df.shape}"
-            #     )
-            # except KeyError:
-            #     pass
+        if hit:
+            self.df = sel_feats_df
+            self.util.debug(f"New feature shape after selection: {self.df.shape}")
+
+        return self.df
+
+    @staticmethod
+    def get_available_feature_sets() -> List[str]:
+        """Get a list of available OpenSMILE feature sets.
+
+        Returns:
+            List[str]: List of available feature sets
+        """
+        return Opensmileset.AVAILABLE_FEATURE_SETS
```
nkululeko/modelrunner.py
CHANGED
```diff
@@ -53,8 +53,6 @@ class Modelrunner:
             # epochs are handled by Huggingface API
             self.model.train()
             report = self.model.predict()
-            # todo: findout the best epoch -> no need
-            # since load_best_model_at_end is given in training args
             epoch = epoch_num
             report.set_id(self.run, epoch)
             plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
@@ -118,16 +116,31 @@
                         f"reached patience ({str(patience)}): early stopping"
                     )
                     break
-        # After training, report the best performance and epoch
-        last_report = reports[-1]
-        # self.util.debug(f"Best score at epoch: {self.best_epoch}, UAR: {self.best_performance}") # move to reporter below
-
-        if not plot_epochs:
-            # Do at least one confusion matrix plot
-            self.util.debug(f"plotting last confusion matrix to {plot_name}")
-            last_report.plot_confmatrix(plot_name, epoch_index)
         return reports, epoch
 
+    def eval_last_model(self, df_test, feats_test):
+        self.model.reset_test(df_test, feats_test)
+        report = self.model.predict()
+        report.set_id(self.run, 0)
+        return report
+
+    def eval_specific_model(self, model, df_test, feats_test):
+        self.model = model
+        self.util.debug(f"evaluating model: {self.model.store_path}")
+        self.model.reset_test(df_test, feats_test)
+        report = self.model.predict()
+        report.set_id(self.run, 0)
+        return report
+
+    def _check_balancing(self):
+        if self.util.config_val("EXP", "balancing", False):
+            self.util.debug("balancing data")
+            self.df_train, self.df_test = self.util.balance_data(
+                self.df_train, self.df_test
+            )
+            self.util.debug(f"new train size: {self.df_train.shape}")
+            self.util.debug(f"new test size: {self.df_test.shape}")
+
     def _select_model(self, model_type):
         self._check_balancing()
 
```
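The two new `eval_*` methods let a trained model be re-scored on a freshly assembled test split. A hypothetical usage sketch, assuming `runner` is a `Modelrunner` whose model has already been trained (the helper name and the surrounding lifecycle are illustrative, not part of the diff):

```python
def evaluate_on_new_split(runner, df_test, feats_test, saved_model=None):
    """Re-evaluate a trained Modelrunner on a fresh test split.

    If `saved_model` is given, it is swapped in via eval_specific_model();
    otherwise the runner's current model is reused via eval_last_model().
    Both return a nkululeko report object with run/epoch id (run, 0) set.
    """
    if saved_model is not None:
        return runner.eval_specific_model(saved_model, df_test, feats_test)
    return runner.eval_last_model(df_test, feats_test)
```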
nkululeko/models/model_mlp.py
CHANGED
nkululeko/nkululeko.py
CHANGED
nkululeko/plots.py
CHANGED
```diff
@@ -303,6 +303,7 @@ class Plots:
             plot_df = plot_df.rename(columns={cont_col: self.target})
             cont_col = self.target
         dist_type = self.util.config_val("EXPL", "dist_type", "kde")
+        fill_areas = eval(self.util.config_val("PLOT", "fill_areas", "False"))
         max_cat, cat_str, effect_results = su.get_effect_size(
             plot_df, cat_col, cont_col
         )
@@ -324,7 +325,7 @@
                 x=cont_col,
                 hue=cat_col,
                 kind="kde",
-                fill=
+                fill=fill_areas,
                 warn_singular=False,
             )
             ax.set(xlabel=f"{cont_col}")
@@ -604,9 +605,17 @@
             df_plot = pd.DataFrame(
                 {label: df_labels[label], feature: df_features[feature]}
             )
+            p_val = ""
+            if df_labels[label].nunique() == 2:
+                label_1 = df_labels[label].unique()[0]
+                label_2 = df_labels[label].unique()[1]
+                vals_1 = df_plot[df_plot[label] == label_1][feature].values
+                vals_2 = df_plot[df_plot[label] == label_2][feature].values
+                r_stats = stats.mannwhitneyu(vals_1, vals_2, alternative="two-sided")
+                p_val = f", Mann-Whitney p-val: {r_stats.pvalue:.3f}"
             ax = sns.violinplot(data=df_plot, x=label, y=feature)
             label = self.util.config_val("DATA", "target", "class_label")
-            ax.set(title=f"{title} samples", xlabel=label)
+            ax.set(title=f"{title} samples {p_val}", xlabel=label)
         else:
             plot_df = pd.concat([df_labels, df_features], axis=1)
             ax, caption = self._plot2cont(plot_df, label, feature, feature)
```
nkululeko/reporting/reporter.py
CHANGED
```diff
@@ -138,7 +138,7 @@ class Reporter:
             self.util.error(f"unknown metric: {self.metric}")
         return test_result, upper, lower
 
-    def print_probabilities(self):
+    def print_probabilities(self, file_name = None):
         """Print the probabilities per class to a file in the store."""
         if (
             self.util.exp_is_classification()
@@ -168,11 +168,11 @@
             )
             probas["uncertainty"] = uncertainty
             probas["correct"] = probas.predicted == probas.truth
-
-
+            if file_name is None:
+                file_name = self.util.get_pred_name()+".csv"
             self.probas = probas
-            probas.to_csv(
-            self.util.debug(f"Saved probabilities to {
+            probas.to_csv(file_name)
+            self.util.debug(f"Saved probabilities to {file_name}")
             plots = Plots()
             ax, caption = plots.plotcatcont(
                 probas, "correct", "uncertainty", "uncertainty", "correct"
@@ -182,7 +182,7 @@
                 caption,
                 "Uncertainty",
                 "uncertainty_samples",
-
+                file_name,
             )
 
     def set_id(self, run, epoch):
@@ -368,7 +368,7 @@
 
         res_dir = self.util.get_path("res_dir")
         rpt = (
-            f"
+            f"Confusion matrix result for epoch: {epoch}, UAR: {uar_str}"
             + f", (+-{up_str}/{low_str}), ACC: {acc_str}"
         )
         # print(rpt)
@@ -392,13 +392,16 @@
             text_file.write(result_str)
         self.util.debug(result_str)
 
-    def print_results(self, epoch=None):
+    def print_results(self, epoch=None, file_name = None):
         if epoch is None:
             epoch = self.epoch
         """Print all evaluation values to text file."""
         res_dir = self.util.get_path("res_dir")
-
-
+        if file_name is None:
+            file_name = f"{res_dir}{self.util.get_exp_name()}_{epoch}{self.filenameadd}.txt"
+        else:
+            self.util.debug(f"####->{file_name}<-####")
+            file_name = f"{res_dir}{file_name}{self.filenameadd}.txt"
         if self.util.exp_is_classification():
             labels = glob_conf.labels
             try:
@@ -427,25 +430,6 @@
                 f1_per_class = (
                     f"result per class (F1 score): {c_ress} from epoch: {epoch}"
                 )
-                # the following auc is buggy, preds should be probabilities
-                # if len(np.unique(self.truths)) == 2:
-                #     fpr, tpr, _ = roc_curve(self.truths, self.preds)
-                #     auc_score = auc(fpr, tpr)
-                #     plot_path = f"{fig_dir}{self.util.get_exp_name()}_{epoch}{self.filenameadd}_roc.{self.format}"
-                #     plt.figure()
-                #     display = RocCurveDisplay(
-                #         fpr=fpr,
-                #         tpr=tpr,
-                #         roc_auc=auc_score,
-                #         estimator_name=f"{self.model_type} estimator",
-                #     )
-                #     display.plot(ax=None)
-                #     plt.savefig(plot_path)
-                #     plt.close()
-                #     self.util.debug(f"Saved ROC curve to {plot_path}")
-                #     pauc_score = roc_auc_score(self.truths, self.preds, max_fpr=0.1)
-                #     auc_pauc = f"auc: {auc_score:.3f}, pauc: {pauc_score:.3f} from epoch: {epoch}"
-                #     self.util.debug(auc_pauc)
                 self.util.debug(f1_per_class)
                 rpt_str = f"{json.dumps(rpt)}\n{f1_per_class}"
                 # rpt_str += f"\n{auc_auc}"
@@ -514,18 +498,12 @@
         # do a plot per run
         # scale the losses so they fit on the picture
         losses, results, train_results, losses_eval = (
-            np.asarray(losses),
-            np.asarray(results),
-            np.asarray(train_results),
-            np.asarray(losses_eval),
+            self._scaleresults(np.asarray(losses)),
+            self._scaleresults(np.asarray(results)),
+            self._scaleresults(np.asarray(train_results)),
+            self._scaleresults(np.asarray(losses_eval)),
         )
 
-        if np.all((results > 1)):
-            # scale down values
-            results = results / 100.0
-            train_results = train_results / 100.0
-        # if np.all((losses < 1)):
-            # scale up values
         plt.figure(dpi=200)
         plt.plot(train_results, "green", label="train set")
         plt.plot(results, "red", label="dev set")
@@ -536,3 +514,11 @@
         plt.legend()
         plt.savefig(f"{fig_dir}{out_name}.{self.format}")
         plt.close()
+
+    def _scaleresults(self, results:np.ndarray) -> np.ndarray:
+        results = results.copy()
+        """Scale results to fit on the plot."""
+        if np.any((results > 1)):
+            # scale down values
+            results = results / 100.0
+        return results
```
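Worth noting: the inline scaling became a helper applied to every series (losses included), and its trigger changed from `np.all` to `np.any`. A standalone sketch of the helper's effect, with the function mirrored from the diff and the sample values purely illustrative:

```python
import numpy as np

def scale_results(results: np.ndarray) -> np.ndarray:
    """Mirror of the new Reporter._scaleresults(): series containing values
    above 1 (e.g. percentages) are divided by 100 so that train/dev results
    and losses share one plot axis."""
    results = results.copy()
    if np.any(results > 1):
        results = results / 100.0
    return results

print(scale_results(np.array([55.0, 60.0, 72.5])))  # -> [0.55  0.6   0.725]
print(scale_results(np.array([0.4, 0.6])))          # unchanged: already <= 1
```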