nkululeko 0.86.8__py3-none-any.whl → 0.88.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nkululeko/constants.py +1 -1
- nkululeko/data/dataset_csv.py +12 -14
- nkululeko/demo.py +7 -10
- nkululeko/ensemble.py +158 -0
- nkululeko/feat_extract/feats_ast.py +118 -0
- nkululeko/feat_extract/feats_wav2vec2.py +2 -4
- nkululeko/feat_extract/feats_wavlm.py +7 -4
- nkululeko/feature_extractor.py +5 -9
- nkululeko/modelrunner.py +5 -5
- nkululeko/models/model.py +23 -3
- nkululeko/models/model_cnn.py +41 -22
- nkululeko/models/model_mlp.py +37 -17
- nkululeko/models/model_mlp_regression.py +3 -1
- nkululeko/plots.py +25 -37
- nkululeko/reporting/reporter.py +69 -6
- nkululeko/runmanager.py +8 -11
- nkululeko/test_predictor.py +2 -9
- nkululeko/utils/stats.py +11 -7
- nkululeko/utils/util.py +24 -19
- {nkululeko-0.86.8.dist-info → nkululeko-0.88.0.dist-info}/METADATA +22 -1
- {nkululeko-0.86.8.dist-info → nkululeko-0.88.0.dist-info}/RECORD +24 -22
- {nkululeko-0.86.8.dist-info → nkululeko-0.88.0.dist-info}/WHEEL +1 -1
- {nkululeko-0.86.8.dist-info → nkululeko-0.88.0.dist-info}/LICENSE +0 -0
- {nkululeko-0.86.8.dist-info → nkululeko-0.88.0.dist-info}/top_level.txt +0 -0
nkululeko/constants.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
|
-
VERSION="0.
|
1
|
+
VERSION="0.88.0"
|
2
2
|
SAMPLING_RATE = 16000
|
nkululeko/data/dataset_csv.py
CHANGED
@@ -23,6 +23,9 @@ class Dataset_CSV(Dataset):
|
|
23
23
|
root = os.path.dirname(data_file)
|
24
24
|
audio_path = self.util.config_val_data(self.name, "audio_path", "./")
|
25
25
|
df = pd.read_csv(data_file)
|
26
|
+
# trim all string values
|
27
|
+
df_obj = df.select_dtypes("object")
|
28
|
+
df[df_obj.columns] = df_obj.apply(lambda x: x.str.strip())
|
26
29
|
# special treatment for segmented dataframes with only one column:
|
27
30
|
if "start" in df.columns and len(df.columns) == 4:
|
28
31
|
index = audformat.segmented_index(
|
@@ -49,8 +52,7 @@ class Dataset_CSV(Dataset):
|
|
49
52
|
.map(lambda x: root + "/" + audio_path + "/" + x)
|
50
53
|
.values
|
51
54
|
)
|
52
|
-
df = df.set_index(df.index.set_levels(
|
53
|
-
file_index, level="file"))
|
55
|
+
df = df.set_index(df.index.set_levels(file_index, level="file"))
|
54
56
|
else:
|
55
57
|
if not isinstance(df, pd.DataFrame):
|
56
58
|
df = pd.DataFrame(df)
|
@@ -59,27 +61,24 @@ class Dataset_CSV(Dataset):
|
|
59
61
|
lambda x: root + "/" + audio_path + "/" + x
|
60
62
|
)
|
61
63
|
)
|
62
|
-
else:
|
64
|
+
else: # absolute path is True
|
63
65
|
if audformat.index_type(df.index) == "segmented":
|
64
66
|
file_index = (
|
65
|
-
df.index.levels[0]
|
66
|
-
.map(lambda x: audio_path + "/" + x)
|
67
|
-
.values
|
67
|
+
df.index.levels[0].map(lambda x: audio_path + "/" + x).values
|
68
68
|
)
|
69
|
-
df = df.set_index(df.index.set_levels(
|
70
|
-
file_index, level="file"))
|
69
|
+
df = df.set_index(df.index.set_levels(file_index, level="file"))
|
71
70
|
else:
|
72
71
|
if not isinstance(df, pd.DataFrame):
|
73
72
|
df = pd.DataFrame(df)
|
74
|
-
df = df.set_index(
|
75
|
-
lambda x: audio_path + "/" + x
|
73
|
+
df = df.set_index(
|
74
|
+
df.index.to_series().apply(lambda x: audio_path + "/" + x)
|
75
|
+
)
|
76
76
|
|
77
77
|
self.df = df
|
78
78
|
self.db = None
|
79
79
|
self.got_target = True
|
80
80
|
self.is_labeled = self.got_target
|
81
|
-
self.start_fresh = eval(
|
82
|
-
self.util.config_val("DATA", "no_reuse", "False"))
|
81
|
+
self.start_fresh = eval(self.util.config_val("DATA", "no_reuse", "False"))
|
83
82
|
is_index = False
|
84
83
|
try:
|
85
84
|
if self.is_labeled and not "class_label" in self.df.columns:
|
@@ -106,8 +105,7 @@ class Dataset_CSV(Dataset):
|
|
106
105
|
f" {self.got_gender}, got age: {self.got_age}"
|
107
106
|
)
|
108
107
|
self.util.debug(r_string)
|
109
|
-
glob_conf.report.add_item(ReportItem(
|
110
|
-
"Data", "Loaded report", r_string))
|
108
|
+
glob_conf.report.add_item(ReportItem("Data", "Loaded report", r_string))
|
111
109
|
|
112
110
|
def prepare(self):
|
113
111
|
super().prepare()
|
nkululeko/demo.py
CHANGED
@@ -20,20 +20,19 @@ Options: \n
|
|
20
20
|
import argparse
|
21
21
|
import configparser
|
22
22
|
import os
|
23
|
+
|
23
24
|
import pandas as pd
|
25
|
+
from transformers import pipeline
|
24
26
|
|
27
|
+
import nkululeko.glob_conf as glob_conf
|
25
28
|
from nkululeko.constants import VERSION
|
26
29
|
from nkululeko.experiment import Experiment
|
27
|
-
import nkululeko.glob_conf as glob_conf
|
28
30
|
from nkululeko.utils.util import Util
|
29
|
-
from transformers import pipeline
|
30
31
|
|
31
32
|
|
32
33
|
def main(src_dir):
|
33
|
-
parser = argparse.ArgumentParser(
|
34
|
-
|
35
|
-
parser.add_argument("--config", default="exp.ini",
|
36
|
-
help="The base configuration")
|
34
|
+
parser = argparse.ArgumentParser(description="Call the nkululeko DEMO framework.")
|
35
|
+
parser.add_argument("--config", default="exp.ini", help="The base configuration")
|
37
36
|
parser.add_argument(
|
38
37
|
"--file", help="A file that should be processed (16kHz mono wav)"
|
39
38
|
)
|
@@ -84,8 +83,7 @@ def main(src_dir):
|
|
84
83
|
)
|
85
84
|
|
86
85
|
def print_pipe(files, outfile):
|
87
|
-
"""
|
88
|
-
Prints the pipeline output for a list of files, and optionally writes the results to an output file.
|
86
|
+
"""Prints the pipeline output for a list of files, and optionally writes the results to an output file.
|
89
87
|
|
90
88
|
Args:
|
91
89
|
files (list): A list of file paths to process through the pipeline.
|
@@ -108,8 +106,7 @@ def main(src_dir):
|
|
108
106
|
f.write("\n".join(results))
|
109
107
|
|
110
108
|
if util.get_model_type() == "finetune":
|
111
|
-
model_path = os.path.join(
|
112
|
-
util.get_exp_dir(), "models", "run_0", "torch")
|
109
|
+
model_path = os.path.join(util.get_exp_dir(), "models", "run_0", "torch")
|
113
110
|
pipe = pipeline("audio-classification", model=model_path)
|
114
111
|
if args.file is not None:
|
115
112
|
print_pipe([args.file], args.outfile)
|
nkululeko/ensemble.py
ADDED
@@ -0,0 +1,158 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
|
4
|
+
|
5
|
+
import configparser
|
6
|
+
import time
|
7
|
+
from argparse import ArgumentParser
|
8
|
+
from pathlib import Path
|
9
|
+
|
10
|
+
import pandas as pd
|
11
|
+
|
12
|
+
from nkululeko.constants import VERSION
|
13
|
+
from nkululeko.experiment import Experiment
|
14
|
+
from nkululeko.utils.util import Util
|
15
|
+
|
16
|
+
|
17
|
+
def ensemble_predictions(config_files, method, no_labels):
|
18
|
+
"""
|
19
|
+
Ensemble predictions from multiple experiments.
|
20
|
+
|
21
|
+
Args:
|
22
|
+
config_files (list): List of configuration file paths.
|
23
|
+
method (str): Ensemble method to use. Options are 'majority_voting', 'mean', 'max', or 'sum'.
|
24
|
+
no_labels (bool): Flag indicating whether the predictions have labels or not.
|
25
|
+
|
26
|
+
Returns:
|
27
|
+
pandas.DataFrame: The ensemble predictions.
|
28
|
+
|
29
|
+
Raises:
|
30
|
+
ValueError: If an unknown ensemble method is provided.
|
31
|
+
AssertionError: If the number of config files is less than 2 for majority voting.
|
32
|
+
|
33
|
+
"""
|
34
|
+
ensemble_preds = []
|
35
|
+
# labels = []
|
36
|
+
for config_file in config_files:
|
37
|
+
if no_labels:
|
38
|
+
# for ensembling results from Nkululeko.demo
|
39
|
+
pred = pd.read_csv(config_file)
|
40
|
+
labels = pred.columns[1:-2]
|
41
|
+
else:
|
42
|
+
# for ensembling results from Nkululeko.nkululeko
|
43
|
+
config = configparser.ConfigParser()
|
44
|
+
config.read(config_file)
|
45
|
+
expr = Experiment(config)
|
46
|
+
module = "ensemble"
|
47
|
+
expr.set_module(module)
|
48
|
+
util = Util(module, has_config=True)
|
49
|
+
util.debug(
|
50
|
+
f"running {expr.name} from config {config_file}, nkululeko version"
|
51
|
+
f" {VERSION}"
|
52
|
+
)
|
53
|
+
|
54
|
+
# get labels
|
55
|
+
labels = expr.util.get_labels()
|
56
|
+
# load the experiment
|
57
|
+
# get CSV files of predictions
|
58
|
+
pred = expr.util.get_pred_name()
|
59
|
+
print(f"Loading predictions from {pred}")
|
60
|
+
preds = pd.read_csv(pred)
|
61
|
+
|
62
|
+
ensemble_preds.append(preds)
|
63
|
+
|
64
|
+
# pd concate
|
65
|
+
ensemble_preds = pd.concat(ensemble_preds, axis=1)
|
66
|
+
|
67
|
+
if method == "majority_voting":
|
68
|
+
# majority voting, get mode, works for odd number of models
|
69
|
+
# raise error when number of configs only two:
|
70
|
+
assert (
|
71
|
+
len(config_files) > 2
|
72
|
+
), "Majority voting only works for more than two models"
|
73
|
+
ensemble_preds["predicted"] = ensemble_preds.mode(axis=1)[0]
|
74
|
+
|
75
|
+
elif method == "mean":
|
76
|
+
for label in labels:
|
77
|
+
ensemble_preds[label] = ensemble_preds[label].mean(axis=1)
|
78
|
+
|
79
|
+
elif method == "max":
|
80
|
+
for label in labels:
|
81
|
+
ensemble_preds[label] = ensemble_preds[label].max(axis=1)
|
82
|
+
# get max value from all labels to inver that labels
|
83
|
+
|
84
|
+
elif method == "sum":
|
85
|
+
for label in labels:
|
86
|
+
ensemble_preds[label] = ensemble_preds[label].sum(axis=1)
|
87
|
+
|
88
|
+
else:
|
89
|
+
raise ValueError(f"Unknown ensemble method: {method}")
|
90
|
+
|
91
|
+
# get the highest value from all labels to inver that labels
|
92
|
+
# replace the old first predicted column
|
93
|
+
ensemble_preds["predicted"] = ensemble_preds[labels].idxmax(axis=1)
|
94
|
+
|
95
|
+
if no_labels:
|
96
|
+
return ensemble_preds
|
97
|
+
|
98
|
+
# Drop start, end columns
|
99
|
+
ensemble_preds = ensemble_preds.drop(columns=["start", "end"])
|
100
|
+
|
101
|
+
# Drop other column except until truth
|
102
|
+
ensemble_preds = ensemble_preds.iloc[:, : len(labels) + 3]
|
103
|
+
|
104
|
+
# calculate UAR from predicted and truth columns
|
105
|
+
|
106
|
+
truth = ensemble_preds["truth"]
|
107
|
+
predicted = ensemble_preds["predicted"]
|
108
|
+
uar = (truth == predicted).mean()
|
109
|
+
Util("ensemble").debug(f"UAR: {uar:.3f}")
|
110
|
+
|
111
|
+
# only return until 'predicted' column
|
112
|
+
return ensemble_preds
|
113
|
+
|
114
|
+
|
115
|
+
def main(src_dir):
    """Parse the command line, ensemble the given predictions and save them.

    Args:
        src_dir: Directory handed over by the script entry point; it is not
            used by this function.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "configs",
        nargs="+",
        help=(
            "Paths to the configuration files of the experiments to ensemble."
            " Can be INI files for Nkululeko.nkululeko or CSV files from"
            " Nkululeko.demo."
        ),
    )
    parser.add_argument(
        "--method",
        default="majority_voting",
        choices=["majority_voting", "mean", "max", "sum"],
        # %(default)s keeps the help text in sync with the actual default
        help="Ensemble method to use (default: %(default)s)",
    )
    parser.add_argument(
        "--outfile",
        default="ensemble_result.csv",
        help="Output file path for the ensemble predictions (default: %(default)s)",
    )

    # add argument if true label is not available
    parser.add_argument(
        "--no_labels",
        action="store_true",
        help="True if true labels are not available. For Nkululeko.demo results.",
    )

    args = parser.parse_args()

    start = time.time()

    ensemble_preds = ensemble_predictions(args.configs, args.method, args.no_labels)

    # save to csv
    ensemble_preds.to_csv(args.outfile, index=False)
    print(f"Ensemble predictions saved to: {args.outfile}")
    print(f"Ensemble done, used {time.time()-start:.2f} seconds")

    print("DONE")
|
154
|
+
|
155
|
+
|
156
|
+
if __name__ == "__main__":
    # Script entry point: pass the directory containing this module to main().
    main(Path(__file__).parent)
|
@@ -0,0 +1,118 @@
|
|
1
|
+
# feats_ast.py
|
2
|
+
import os
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
import pandas as pd
|
6
|
+
import torch
|
7
|
+
import torch.nn.functional as F
|
8
|
+
import torchaudio
|
9
|
+
from tqdm import tqdm
|
10
|
+
from transformers import AutoProcessor, ASTModel
|
11
|
+
|
12
|
+
import nkululeko.glob_conf as glob_conf
|
13
|
+
from nkululeko.feat_extract.featureset import Featureset
|
14
|
+
|
15
|
+
|
16
|
+
class Ast(Featureset):
    """Class to extract AST (Audio Spectrogram Transformer) embeddings."""

    def __init__(self, name, data_df, feat_type):
        """Store the settings; the pretrained model is loaded lazily.

        Args:
            name: Name of the feature set, used for the storage file name.
            data_df: Data frame whose index lists the (file, start, end)
                segments to process.
            feat_type: Feature type string as given in the configuration.
        """
        super().__init__(name, data_df, feat_type)
        cuda = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = self.util.config_val("MODEL", "device", cuda)
        self.model_initialized = False
        self.feat_type = feat_type

    def init_model(self):
        """Load the AST processor and model and move the model to the device."""
        self.util.debug("loading AST model...")
        model_path = self.util.config_val(
            "FEATS", "ast.model", "MIT/ast-finetuned-audioset-10-10-0.4593"
        )
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = ASTModel.from_pretrained(model_path).to(self.device)
        print(f"initialized AST model on {self.device}")
        self.model.eval()
        self.model_initialized = True

    def extract(self):
        """Extract the features or load them from disk if present."""
        store = self.util.get_path("store")
        storage = f"{store}{self.name}.pkl"
        extract = self.util.config_val("FEATS", "needs_feature_extraction", False)
        # NOTE(review): eval() on a configuration string; acceptable only as
        # long as the configuration file is trusted.
        no_reuse = eval(self.util.config_val("FEATS", "no_reuse", "False"))
        if extract or no_reuse or not os.path.isfile(storage):
            if not self.model_initialized:
                self.init_model()
            self.util.debug("extracting AST embeddings, this might take a while...")
            expected_sr = 16000  # offsets below are computed for 16 kHz audio
            embeddings = []
            for file, start, end in tqdm(self.data_df.index.to_list()):
                signal, sampling_rate = torchaudio.load(
                    file,
                    frame_offset=int(start.total_seconds() * expected_sr),
                    num_frames=int((end - start).total_seconds() * expected_sr),
                )
                # mix down to mono if there is more than one channel
                if signal.shape[0] > 1:
                    signal = torch.mean(signal, dim=0, keepdim=True)

                assert (
                    sampling_rate == expected_sr
                ), f"sampling rate should be 16000 but is {sampling_rate}"
                embeddings.append(self.get_embeddings(signal, sampling_rate, file))
            # one row per segment, one column per embedding dimension
            self.df = pd.DataFrame(embeddings, index=self.data_df.index)
            self.df.to_pickle(storage)
            try:
                # NOTE(review): the flag is read from section FEATS above but
                # reset in section DATA here -- confirm this is intended.
                glob_conf.config["DATA"]["needs_feature_extraction"] = "false"
            except KeyError:
                pass
        else:
            self.util.debug(f"reusing extracted {self.feat_type} embeddings")
            self.df = pd.read_pickle(storage)
            if self.df.isnull().values.any():
                self.util.error(
                    f"got nan: {self.df.shape} {self.df.isnull().sum().sum()}"
                )

    def get_embeddings(self, signal, sampling_rate, file):
        """Extract embeddings from raw audio signal.

        Args:
            signal: Audio tensor as returned by ``torchaudio.load``.
            sampling_rate: Sampling rate of ``signal``.
            file: File name, only used for the error message.

        Returns:
            numpy.ndarray: Flattened mean of the model's last hidden state, or
            a zero vector of the model's hidden size if the extraction failed
            and the error handler returned.
        """
        try:
            inputs = self.processor(
                signal.numpy(), sampling_rate=sampling_rate, return_tensors="pt"
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model(**inputs)
                # hidden states of the last layer
                last_hidden_state = outputs.last_hidden_state
                # average pooling over dimension 1 of the hidden states
                embeddings = torch.mean(last_hidden_state, dim=1)
                embeddings = embeddings.cpu().numpy()

        except Exception as e:
            # Deliberate best effort: report the failure and fall back to a
            # zero vector so one bad file does not invalidate the whole set.
            self.util.error(
                f"Error extracting embeddings for file {file}: {e},"
                " filling with zeros"
            )
            return np.zeros(self.model.config.hidden_size)
        return embeddings.ravel()

    def extract_sample(self, signal, sr):
        """Return the embeddings for a single in-memory signal."""
        # load the model only once, not for every sample
        if not self.model_initialized:
            self.init_model()
        feats = self.get_embeddings(signal, sr, "no file")
        return feats
|
@@ -47,9 +47,7 @@ class Wav2vec2(Featureset):
|
|
47
47
|
config.num_hidden_layers = layer_num - hidden_layer
|
48
48
|
self.util.debug(f"using hidden layer #{config.num_hidden_layers}")
|
49
49
|
self.processor = Wav2Vec2FeatureExtractor.from_pretrained(model_path)
|
50
|
-
self.model = Wav2Vec2Model.from_pretrained(model_path, config=config).to(
|
51
|
-
self.device
|
52
|
-
)
|
50
|
+
self.model = Wav2Vec2Model.from_pretrained(model_path, config=config).to(self.device)
|
53
51
|
print(f"intialized Wav2vec model on {self.device}")
|
54
52
|
self.model.eval()
|
55
53
|
self.model_initialized = True
|
@@ -90,7 +88,7 @@ class Wav2vec2(Featureset):
|
|
90
88
|
self.util.debug("reusing extracted wav2vec2 embeddings")
|
91
89
|
self.df = pd.read_pickle(storage)
|
92
90
|
if self.df.isnull().values.any():
|
93
|
-
nanrows = self.df.columns[self.df.isna().any()].tolist()
|
91
|
+
# nanrows = self.df.columns[self.df.isna().any()].tolist()
|
94
92
|
# print(nanrows)
|
95
93
|
self.util.error(
|
96
94
|
f"got nan: {self.df.shape} {self.df.isnull().sum().sum()}"
|
@@ -79,8 +79,8 @@ class Wavlm(Featureset):
|
|
79
79
|
self.util.debug(f"reusing extracted {self.feat_type} embeddings")
|
80
80
|
self.df = pd.read_pickle(storage)
|
81
81
|
if self.df.isnull().values.any():
|
82
|
-
nanrows = self.df.columns[self.df.isna().any()].tolist()
|
83
|
-
print(nanrows)
|
82
|
+
# nanrows = self.df.columns[self.df.isna().any()].tolist()
|
83
|
+
# print(nanrows)
|
84
84
|
self.util.error(
|
85
85
|
f"got nan: {self.df.shape} {self.df.isnull().sum().sum()}"
|
86
86
|
)
|
@@ -104,11 +104,14 @@ class Wavlm(Featureset):
|
|
104
104
|
# pool result and convert to numpy
|
105
105
|
y = torch.mean(y, dim=1)
|
106
106
|
y = y.detach().cpu().numpy()
|
107
|
+
|
108
|
+
# print(f"hs shape: {y.shape}")
|
109
|
+
|
107
110
|
except RuntimeError as re:
|
108
111
|
print(str(re))
|
109
|
-
self.util.error(f"
|
112
|
+
self.util.error(f"Couldn't extract file: {file}")
|
110
113
|
|
111
|
-
return y.
|
114
|
+
return y.ravel()
|
112
115
|
|
113
116
|
def extract_sample(self, signal, sr):
|
114
117
|
self.init_model()
|
nkululeko/feature_extractor.py
CHANGED
@@ -39,12 +39,10 @@ class FeatureExtractor:
|
|
39
39
|
self.feats = pd.DataFrame()
|
40
40
|
for feats_type in self.feats_types:
|
41
41
|
store_name = f"{self.data_name}_{feats_type}"
|
42
|
-
self.feat_extractor = self._get_feat_extractor(
|
43
|
-
store_name, feats_type)
|
42
|
+
self.feat_extractor = self._get_feat_extractor(store_name, feats_type)
|
44
43
|
self.feat_extractor.extract()
|
45
44
|
self.feat_extractor.filter()
|
46
|
-
self.feats = pd.concat(
|
47
|
-
[self.feats, self.feat_extractor.df], axis=1)
|
45
|
+
self.feats = pd.concat([self.feats, self.feat_extractor.df], axis=1)
|
48
46
|
return self.feats
|
49
47
|
|
50
48
|
def extract_sample(self, signal, sr):
|
@@ -77,7 +75,7 @@ class FeatureExtractor:
|
|
77
75
|
return TRILLset
|
78
76
|
|
79
77
|
elif feats_type.startswith(
|
80
|
-
("wav2vec2", "hubert", "wavlm", "spkrec", "whisper")
|
78
|
+
("wav2vec2", "hubert", "wavlm", "spkrec", "whisper", "ast")
|
81
79
|
):
|
82
80
|
return self._get_feat_extractor_by_prefix(feats_type)
|
83
81
|
|
@@ -107,15 +105,13 @@ class FeatureExtractor:
|
|
107
105
|
prefix, _, ext = feats_type.partition("-")
|
108
106
|
from importlib import import_module
|
109
107
|
|
110
|
-
module = import_module(
|
111
|
-
f"nkululeko.feat_extract.feats_{prefix.lower()}")
|
108
|
+
module = import_module(f"nkululeko.feat_extract.feats_{prefix.lower()}")
|
112
109
|
class_name = f"{prefix.capitalize()}"
|
113
110
|
return getattr(module, class_name)
|
114
111
|
|
115
112
|
def _get_feat_extractor_by_name(self, feats_type):
|
116
113
|
from importlib import import_module
|
117
114
|
|
118
|
-
module = import_module(
|
119
|
-
f"nkululeko.feat_extract.feats_{feats_type.lower()}")
|
115
|
+
module = import_module(f"nkululeko.feat_extract.feats_{feats_type.lower()}")
|
120
116
|
class_name = f"{feats_type.capitalize()}Set"
|
121
117
|
return getattr(module, class_name)
|
nkululeko/modelrunner.py
CHANGED
@@ -85,7 +85,7 @@ class Modelrunner:
|
|
85
85
|
f"run: {self.run} epoch: {epoch}: result: {test_score_metric}"
|
86
86
|
)
|
87
87
|
# print(f"performance: {performance.split(' ')[1]}")
|
88
|
-
performance = float(test_score_metric.split(
|
88
|
+
performance = float(test_score_metric.split(" ")[1])
|
89
89
|
if performance > self.best_performance:
|
90
90
|
self.best_performance = performance
|
91
91
|
self.best_epoch = epoch
|
@@ -204,15 +204,15 @@ class Modelrunner:
|
|
204
204
|
self.df_train, self.df_test, self.feats_train, self.feats_test
|
205
205
|
)
|
206
206
|
elif model_type == "cnn":
|
207
|
-
from nkululeko.models.model_cnn import
|
207
|
+
from nkululeko.models.model_cnn import CNNModel
|
208
208
|
|
209
|
-
self.model =
|
209
|
+
self.model = CNNModel(
|
210
210
|
self.df_train, self.df_test, self.feats_train, self.feats_test
|
211
211
|
)
|
212
212
|
elif model_type == "mlp":
|
213
|
-
from nkululeko.models.model_mlp import
|
213
|
+
from nkululeko.models.model_mlp import MLPModel
|
214
214
|
|
215
|
-
self.model =
|
215
|
+
self.model = MLPModel(
|
216
216
|
self.df_train, self.df_test, self.feats_train, self.feats_test
|
217
217
|
)
|
218
218
|
elif model_type == "mlp_reg":
|
nkululeko/models/model.py
CHANGED
@@ -247,8 +247,25 @@ class Model:
|
|
247
247
|
self.clf.fit(feats, labels)
|
248
248
|
|
249
249
|
def get_predictions(self):
|
250
|
-
predictions = self.clf.predict(self.feats_test.to_numpy())
|
251
|
-
|
250
|
+
# predictions = self.clf.predict(self.feats_test.to_numpy())
|
251
|
+
if self.util.exp_is_classification():
|
252
|
+
# make a dataframe for the class probabilities
|
253
|
+
proba_d = {}
|
254
|
+
for c in self.clf.classes_:
|
255
|
+
proba_d[c] = []
|
256
|
+
# get the class probabilities
|
257
|
+
predictions = self.clf.predict_proba(self.feats_test.to_numpy())
|
258
|
+
# pred = self.clf.predict(features)
|
259
|
+
for i, c in enumerate(self.clf.classes_):
|
260
|
+
proba_d[c] = list(predictions.T[i])
|
261
|
+
probas = pd.DataFrame(proba_d)
|
262
|
+
probas = probas.set_index(self.feats_test.index)
|
263
|
+
predictions = probas.idxmax(axis=1).values
|
264
|
+
else:
|
265
|
+
predictions = self.clf.predict(self.feats_test.to_numpy())
|
266
|
+
probas = None
|
267
|
+
|
268
|
+
return predictions, probas
|
252
269
|
|
253
270
|
def predict(self):
|
254
271
|
if self.feats_test.isna().to_numpy().any():
|
@@ -263,13 +280,16 @@ class Model:
|
|
263
280
|
)
|
264
281
|
return report
|
265
282
|
"""Predict the whole eval feature set"""
|
266
|
-
predictions = self.get_predictions()
|
283
|
+
predictions, probas = self.get_predictions()
|
284
|
+
|
267
285
|
report = Reporter(
|
268
286
|
self.df_test[self.target].to_numpy().astype(float),
|
269
287
|
predictions,
|
270
288
|
self.run,
|
271
289
|
self.epoch,
|
290
|
+
probas=probas,
|
272
291
|
)
|
292
|
+
report.print_probabilities()
|
273
293
|
return report
|
274
294
|
|
275
295
|
def get_type(self):
|
nkululeko/models/model_cnn.py
CHANGED
@@ -5,33 +5,40 @@ Inspired by code from Su Lei
|
|
5
5
|
|
6
6
|
"""
|
7
7
|
|
8
|
+
import ast
|
9
|
+
from collections import OrderedDict
|
10
|
+
|
11
|
+
import numpy as np
|
12
|
+
import pandas as pd
|
13
|
+
from PIL import Image
|
14
|
+
from sklearn.metrics import recall_score
|
8
15
|
import torch
|
9
16
|
import torch.nn as nn
|
10
17
|
import torch.nn.functional as F
|
11
|
-
import torchvision
|
12
|
-
import torchvision.transforms as transforms
|
13
18
|
from torch.utils.data import Dataset
|
14
|
-
import
|
15
|
-
import numpy as np
|
16
|
-
from sklearn.metrics import recall_score
|
17
|
-
from collections import OrderedDict
|
18
|
-
from PIL import Image
|
19
|
-
from traitlets import default
|
19
|
+
import torchvision.transforms as transforms
|
20
20
|
|
21
|
-
from nkululeko.utils.util import Util
|
22
21
|
import nkululeko.glob_conf as glob_conf
|
22
|
+
from nkululeko.losses.loss_softf1loss import SoftF1Loss
|
23
23
|
from nkululeko.models.model import Model
|
24
24
|
from nkululeko.reporting.reporter import Reporter
|
25
|
-
from nkululeko.
|
25
|
+
from nkululeko.utils.util import Util
|
26
26
|
|
27
27
|
|
28
|
-
class
|
29
|
-
"""CNN = convolutional neural net"""
|
28
|
+
class CNNModel(Model):
|
29
|
+
"""CNN = convolutional neural net."""
|
30
30
|
|
31
31
|
is_classifier = True
|
32
32
|
|
33
33
|
def __init__(self, df_train, df_test, feats_train, feats_test):
|
34
|
-
"""Constructor taking
|
34
|
+
"""Constructor, taking all dataframes.
|
35
|
+
|
36
|
+
Args:
|
37
|
+
df_train (pd.DataFrame): The train labels.
|
38
|
+
df_test (pd.DataFrame): The test labels.
|
39
|
+
feats_train (pd.DataFrame): The train features.
|
40
|
+
feats_test (pd.DataFrame): The test features.
|
41
|
+
"""
|
35
42
|
super().__init__(df_train, df_test, feats_train, feats_test)
|
36
43
|
super().set_model_type("ann")
|
37
44
|
self.name = "cnn"
|
@@ -147,7 +154,20 @@ class CNN_model(Model):
|
|
147
154
|
self.optimizer.step()
|
148
155
|
self.loss = (np.asarray(losses)).mean()
|
149
156
|
|
150
|
-
def
|
157
|
+
def get_probas(self, logits):
|
158
|
+
# make a dataframe for probabilites (logits)
|
159
|
+
proba_d = {}
|
160
|
+
classes = self.df_test[self.target].unique()
|
161
|
+
classes.sort()
|
162
|
+
for c in classes:
|
163
|
+
proba_d[c] = []
|
164
|
+
for i, c in enumerate(classes):
|
165
|
+
proba_d[c] = list(logits.numpy().T[i])
|
166
|
+
probas = pd.DataFrame(proba_d)
|
167
|
+
probas = probas.set_index(self.df_test.index)
|
168
|
+
return probas
|
169
|
+
|
170
|
+
def evaluate(self, model, loader, device):
|
151
171
|
logits = torch.zeros(len(loader.dataset), self.class_num)
|
152
172
|
targets = torch.zeros(len(loader.dataset))
|
153
173
|
model.eval()
|
@@ -169,14 +189,15 @@ class CNN_model(Model):
|
|
169
189
|
self.loss_eval = (np.asarray(losses)).mean()
|
170
190
|
predictions = logits.argmax(dim=1)
|
171
191
|
uar = recall_score(targets.numpy(), predictions.numpy(), average="macro")
|
172
|
-
return uar, targets, predictions
|
192
|
+
return uar, targets, predictions, logits
|
173
193
|
|
174
194
|
def predict(self):
|
175
|
-
_, truths, predictions = self.
|
195
|
+
_, truths, predictions, logits = self.evaluate(
|
176
196
|
self.model, self.testloader, self.device
|
177
197
|
)
|
178
|
-
uar, _, _ = self.
|
179
|
-
|
198
|
+
uar, _, _, _ = self.evaluate(self.model, self.trainloader, self.device)
|
199
|
+
probas = self.get_probas(logits)
|
200
|
+
report = Reporter(truths, predictions, self.run, self.epoch, probas=probas)
|
180
201
|
try:
|
181
202
|
report.result.loss = self.loss
|
182
203
|
except AttributeError: # if the model was loaded from disk the loss is unknown
|
@@ -189,13 +210,11 @@ class CNN_model(Model):
|
|
189
210
|
return report
|
190
211
|
|
191
212
|
def get_predictions(self):
|
192
|
-
_,
|
193
|
-
self.model, self.testloader, self.device
|
194
|
-
)
|
213
|
+
_, _, predictions, _ = self.evaluate(self.model, self.testloader, self.device)
|
195
214
|
return predictions.numpy()
|
196
215
|
|
197
216
|
def predict_sample(self, features):
|
198
|
-
"""Predict one sample"""
|
217
|
+
"""Predict one sample."""
|
199
218
|
with torch.no_grad():
|
200
219
|
logits = self.model(torch.from_numpy(features).to(self.device))
|
201
220
|
a = logits.numpy()
|