nkululeko 0.94.3__py3-none-any.whl → 0.95.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nkululeko/augmenting/resampler.py +5 -2
- nkululeko/autopredict/ap_emotion.py +36 -0
- nkululeko/autopredict/ap_text.py +45 -0
- nkululeko/autopredict/whisper_transcriber.py +81 -0
- nkululeko/constants.py +1 -1
- nkululeko/experiment.py +53 -3
- nkululeko/explore.py +32 -13
- nkululeko/feat_extract/feats_analyser.py +45 -17
- nkululeko/feat_extract/feats_emotion2vec.py +51 -26
- nkululeko/feat_extract/feinberg_praat.py +515 -372
- nkululeko/glob_conf.py +9 -0
- nkululeko/modelrunner.py +15 -6
- nkululeko/models/model.py +4 -42
- nkululeko/models/model_tuned.py +416 -84
- nkululeko/models/model_xgb.py +148 -2
- nkululeko/nkululeko.py +0 -9
- nkululeko/plots.py +25 -19
- nkululeko/predict.py +6 -5
- nkululeko/reporting/report.py +7 -5
- nkululeko/reporting/reporter.py +8 -5
- nkululeko/utils/util.py +34 -2
- {nkululeko-0.94.3.dist-info → nkululeko-0.95.0.dist-info}/METADATA +1 -1
- {nkululeko-0.94.3.dist-info → nkululeko-0.95.0.dist-info}/RECORD +27 -24
- {nkululeko-0.94.3.dist-info → nkululeko-0.95.0.dist-info}/WHEEL +0 -0
- {nkululeko-0.94.3.dist-info → nkululeko-0.95.0.dist-info}/entry_points.txt +0 -0
- {nkululeko-0.94.3.dist-info → nkululeko-0.95.0.dist-info}/licenses/LICENSE +0 -0
- {nkululeko-0.94.3.dist-info → nkululeko-0.95.0.dist-info}/top_level.txt +0 -0
@@ -68,7 +68,9 @@ class Resampler:
|
|
68
68
|
self.df.index.set_levels(new_files, level="file")
|
69
69
|
)
|
70
70
|
if not self.not_testing:
|
71
|
-
target_file = self.util.config_val(
|
71
|
+
target_file = self.util.config_val(
|
72
|
+
"RESAMPLE", "target", "resampled.csv"
|
73
|
+
)
|
72
74
|
# remove encoded labels
|
73
75
|
target = self.util.config_val("DATA", "target", "emotion")
|
74
76
|
if "class_label" in self.df.columns:
|
@@ -77,7 +79,8 @@ class Resampler:
|
|
77
79
|
# save file
|
78
80
|
self.df.to_csv(target_file)
|
79
81
|
self.util.debug(
|
80
|
-
"saved resampled list of files to"
|
82
|
+
"saved resampled list of files to"
|
83
|
+
f" {os.path.abspath(target_file)}"
|
81
84
|
)
|
82
85
|
else:
|
83
86
|
# When running from command line, save to simple resampled.csv
|
@@ -0,0 +1,36 @@
|
|
1
|
+
"""
|
2
|
+
A predictor for emotion classification.
|
3
|
+
Uses emotion2vec models for emotion prediction.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import ast
|
7
|
+
|
8
|
+
import nkululeko.glob_conf as glob_conf
|
9
|
+
from nkululeko.feature_extractor import FeatureExtractor
|
10
|
+
from nkululeko.utils.util import Util
|
11
|
+
|
12
|
+
|
13
|
+
class EmotionPredictor:
|
14
|
+
"""
|
15
|
+
EmotionPredictor
|
16
|
+
predicting emotion with emotion2vec models
|
17
|
+
"""
|
18
|
+
|
19
|
+
def __init__(self, df):
|
20
|
+
self.df = df
|
21
|
+
self.util = Util("emotionPredictor")
|
22
|
+
|
23
|
+
def predict(self, split_selection):
|
24
|
+
self.util.debug(f"predicting emotion for {split_selection} samples")
|
25
|
+
feats_name = "_".join(ast.literal_eval(glob_conf.config["DATA"]["databases"]))
|
26
|
+
|
27
|
+
self.feature_extractor = FeatureExtractor(
|
28
|
+
self.df, ["emotion2vec-large"], feats_name, split_selection
|
29
|
+
)
|
30
|
+
emotion_df = self.feature_extractor.extract()
|
31
|
+
|
32
|
+
pred_emotion = ["neutral"] * len(emotion_df)
|
33
|
+
|
34
|
+
return_df = self.df.copy()
|
35
|
+
return_df["emotion_pred"] = pred_emotion
|
36
|
+
return return_df
|
@@ -0,0 +1,45 @@
|
|
1
|
+
"""A predictor for text.
|
2
|
+
|
3
|
+
Currently based on whisper model.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import ast
|
7
|
+
|
8
|
+
import torch
|
9
|
+
|
10
|
+
from nkululeko.feature_extractor import FeatureExtractor
|
11
|
+
import nkululeko.glob_conf as glob_conf
|
12
|
+
from nkululeko.utils.util import Util
|
13
|
+
|
14
|
+
|
15
|
+
class TextPredictor:
|
16
|
+
"""TextPredictor.
|
17
|
+
|
18
|
+
predicting text with the whisper model
|
19
|
+
"""
|
20
|
+
|
21
|
+
def __init__(self, df, util=None):
|
22
|
+
self.df = df
|
23
|
+
if util is not None:
|
24
|
+
self.util = util
|
25
|
+
else:
|
26
|
+
# create a new util instance
|
27
|
+
# this is needed to access the config and other utilities
|
28
|
+
# in the autopredict module
|
29
|
+
self.util = Util("textPredictor")
|
30
|
+
from nkululeko.autopredict.whisper_transcriber import Transcriber
|
31
|
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
32
|
+
device = self.util.config_val("MODEL", "device", device)
|
33
|
+
self.transcriber = Transcriber(
|
34
|
+
device=device,
|
35
|
+
language=self.util.config_val("EXP", "language", "en"),
|
36
|
+
util=self.util,
|
37
|
+
)
|
38
|
+
def predict(self, split_selection):
|
39
|
+
self.util.debug(f"predicting text for {split_selection} samples")
|
40
|
+
df = self.transcriber.transcribe_index(
|
41
|
+
self.df.index
|
42
|
+
)
|
43
|
+
return_df = self.df.copy()
|
44
|
+
return_df["text"] = df["text"].values
|
45
|
+
return return_df
|
@@ -0,0 +1,81 @@
|
|
1
|
+
import os
|
2
|
+
|
3
|
+
import pandas as pd
|
4
|
+
import torch
|
5
|
+
from tqdm import tqdm
|
6
|
+
import whisper
|
7
|
+
|
8
|
+
import audeer
|
9
|
+
import audiofile
|
10
|
+
|
11
|
+
from nkululeko.utils.util import Util
|
12
|
+
|
13
|
+
|
14
|
+
class Transcriber:
|
15
|
+
def __init__(self, model_name="turbo", device=None, language="en", util=None):
|
16
|
+
if device is None:
|
17
|
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
18
|
+
self.model = whisper.load_model(model_name, device=device)
|
19
|
+
self.language = language
|
20
|
+
self.util = util
|
21
|
+
|
22
|
+
def transcribe_file(self, audio_path):
|
23
|
+
"""Transcribe the audio file at the given path.
|
24
|
+
|
25
|
+
:param audio_path: Path to the audio file to transcribe.
|
26
|
+
:return: Transcription text.
|
27
|
+
"""
|
28
|
+
result = self.model.transcribe(
|
29
|
+
audio_path, language=self.language, without_timestamps=True)
|
30
|
+
result = result["text"].strip()
|
31
|
+
return result
|
32
|
+
|
33
|
+
def transcribe_array(self, signal, sampling_rate):
|
34
|
+
"""Transcribe the audio file at the given path.
|
35
|
+
|
36
|
+
:param audio_path: Path to the audio file to transcribe.
|
37
|
+
:return: Transcription text.
|
38
|
+
"""
|
39
|
+
tmporary_path = "temp.wav"
|
40
|
+
audiofile.write(
|
41
|
+
"temp.wav", signal, sampling_rate, format="wav")
|
42
|
+
result = self.transcribe_file(tmporary_path)
|
43
|
+
return result
|
44
|
+
|
45
|
+
def transcribe_index(self, index:pd.Index) -> pd.DataFrame:
|
46
|
+
"""Transcribe the audio files in the given index.
|
47
|
+
|
48
|
+
:param index: Index containing tuples of (file, start, end).
|
49
|
+
:return: DataFrame with transcriptions indexed by the original index.
|
50
|
+
:rtype: pd.DataFrame
|
51
|
+
"""
|
52
|
+
file_name = ""
|
53
|
+
seg_index = 0
|
54
|
+
transcriptions = []
|
55
|
+
transcriber_cache = audeer.mkdir(
|
56
|
+
audeer.path(self.util.get_path("cache"), "transcriptions"))
|
57
|
+
for idx, (file, start, end) in enumerate(
|
58
|
+
tqdm(index.to_list())
|
59
|
+
):
|
60
|
+
if file != file_name:
|
61
|
+
file_name = file
|
62
|
+
seg_index = 0
|
63
|
+
cache_name = audeer.basename_wo_ext(file)+str(seg_index)
|
64
|
+
cache_path = audeer.path(transcriber_cache, cache_name + ".json")
|
65
|
+
if os.path.isfile(cache_path):
|
66
|
+
transcription = self.util.read_json(cache_path)["transcription"]
|
67
|
+
else:
|
68
|
+
dur = end.total_seconds() - start.total_seconds()
|
69
|
+
y, sr = audiofile.read(file, offset=start, duration=dur)
|
70
|
+
transcription = self.transcribe_array(
|
71
|
+
y, sr)
|
72
|
+
self.util.save_json(cache_path,
|
73
|
+
{"transcription": transcription,
|
74
|
+
"file": file,
|
75
|
+
"start": start.total_seconds(),
|
76
|
+
"end": end.total_seconds()})
|
77
|
+
transcriptions.append(transcription)
|
78
|
+
seg_index += 1
|
79
|
+
|
80
|
+
df = pd.DataFrame({"text":transcriptions}, index=index)
|
81
|
+
return df
|
nkululeko/constants.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
|
-
VERSION="0.
|
1
|
+
VERSION="0.95.0"
|
2
2
|
SAMPLING_RATE = 16000
|
nkululeko/experiment.py
CHANGED
@@ -513,7 +513,7 @@ class Experiment:
|
|
513
513
|
|
514
514
|
def autopredict(self):
|
515
515
|
"""Predict labels for samples with existing models and add to the dataframe."""
|
516
|
-
sample_selection = self.util.config_val("PREDICT", "
|
516
|
+
sample_selection = self.util.config_val("PREDICT", "sample_selection", "all")
|
517
517
|
if sample_selection == "all":
|
518
518
|
df = pd.concat([self.df_train, self.df_test])
|
519
519
|
elif sample_selection == "train":
|
@@ -569,6 +569,11 @@ class Experiment:
|
|
569
569
|
|
570
570
|
predictor = STOIPredictor(df)
|
571
571
|
df = predictor.predict(sample_selection)
|
572
|
+
elif target == "text":
|
573
|
+
from nkululeko.autopredict.ap_text import TextPredictor
|
574
|
+
|
575
|
+
predictor = TextPredictor(df, self.util)
|
576
|
+
df = predictor.predict(sample_selection)
|
572
577
|
elif target == "arousal":
|
573
578
|
from nkululeko.autopredict.ap_arousal import ArousalPredictor
|
574
579
|
|
@@ -584,6 +589,11 @@ class Experiment:
|
|
584
589
|
|
585
590
|
predictor = DominancePredictor(df)
|
586
591
|
df = predictor.predict(sample_selection)
|
592
|
+
elif target == "emotion":
|
593
|
+
from nkululeko.autopredict.ap_emotion import EmotionPredictor
|
594
|
+
|
595
|
+
predictor = EmotionPredictor(df)
|
596
|
+
df = predictor.predict(sample_selection)
|
587
597
|
else:
|
588
598
|
self.util.error(f"unknown auto predict target: {target}")
|
589
599
|
return df
|
@@ -668,11 +678,27 @@ class Experiment:
|
|
668
678
|
|
669
679
|
# check if a scatterplot should be done
|
670
680
|
scatter_var = eval(self.util.config_val("EXPL", "scatter", "False"))
|
681
|
+
|
682
|
+
# Priority: use [EXPL][scatter.target] if available, otherwise use [DATA][target] value
|
683
|
+
if hasattr(self, "target") and self.target != "none":
|
684
|
+
default_scatter_target = f"['{self.target}']"
|
685
|
+
else:
|
686
|
+
default_scatter_target = "['class_label']"
|
687
|
+
|
671
688
|
scatter_target = self.util.config_val(
|
672
|
-
"EXPL", "scatter.target",
|
689
|
+
"EXPL", "scatter.target", default_scatter_target
|
673
690
|
)
|
691
|
+
|
692
|
+
if scatter_target == default_scatter_target:
|
693
|
+
self.util.debug(
|
694
|
+
f"scatter.target using default from [DATA][target]: {scatter_target}"
|
695
|
+
)
|
696
|
+
else:
|
697
|
+
self.util.debug(
|
698
|
+
f"scatter.target from [EXPL][scatter.target]: {scatter_target}"
|
699
|
+
)
|
674
700
|
if scatter_var:
|
675
|
-
scatters = ast.literal_eval(
|
701
|
+
scatters = ast.literal_eval(scatter_target)
|
676
702
|
scat_targets = ast.literal_eval(scatter_target)
|
677
703
|
plots = Plots()
|
678
704
|
for scat_target in scat_targets:
|
@@ -692,6 +718,30 @@ class Experiment:
|
|
692
718
|
df_feats, df_labels, f"{scat_target}_bins", scatter
|
693
719
|
)
|
694
720
|
|
721
|
+
# check if t-SNE plot should be generated
|
722
|
+
tsne = eval(self.util.config_val("EXPL", "tsne", "False"))
|
723
|
+
if tsne:
|
724
|
+
target_column = self.util.config_val("DATA", "target", "emotion")
|
725
|
+
plots = Plots()
|
726
|
+
self.util.debug("generating t-SNE plot...")
|
727
|
+
plots.scatter_plot(df_feats, df_labels, target_column, "tsne")
|
728
|
+
|
729
|
+
# check if UMAP plot should be generated
|
730
|
+
umap_plot = eval(self.util.config_val("EXPL", "umap", "False"))
|
731
|
+
if umap_plot:
|
732
|
+
target_column = self.util.config_val("DATA", "target", "emotion")
|
733
|
+
plots = Plots()
|
734
|
+
self.util.debug("generating UMAP plot...")
|
735
|
+
plots.scatter_plot(df_feats, df_labels, target_column, "umap")
|
736
|
+
|
737
|
+
# check if PCA plot should be generated
|
738
|
+
pca_plot = eval(self.util.config_val("EXPL", "pca", "False"))
|
739
|
+
if pca_plot:
|
740
|
+
target_column = self.util.config_val("DATA", "target", "emotion")
|
741
|
+
plots = Plots()
|
742
|
+
self.util.debug("generating PCA plot...")
|
743
|
+
plots.scatter_plot(df_feats, df_labels, target_column, "pca")
|
744
|
+
|
695
745
|
def _check_scale(self):
|
696
746
|
self.util.save_to_store(self.feats_train, "feats_train")
|
697
747
|
self.util.save_to_store(self.feats_test, "feats_test")
|
nkululeko/explore.py
CHANGED
@@ -8,6 +8,8 @@ The script supports the following configuration options:
|
|
8
8
|
- `no_warnings`: If set to `True`, it will ignore all warnings during the exploration.
|
9
9
|
- `feature_distributions`: If set to `True`, it will generate plots of the feature distributions.
|
10
10
|
- `tsne`: If set to `True`, it will generate a t-SNE plot of the feature space.
|
11
|
+
- `umap`: If set to `True`, it will generate a UMAP plot of the feature space.
|
12
|
+
- `pca`: If set to `True`, it will generate a PCA plot of the feature space.
|
11
13
|
- `scatter`: If set to `True`, it will generate a scatter plot of the feature space.
|
12
14
|
- `spotlight`: If set to `True`, it will generate a 'spotlight' plot of the feature space.
|
13
15
|
- `shap`: If set to `True`, it will generate SHAP feature importance plots.
|
@@ -59,10 +61,12 @@ def main():
|
|
59
61
|
|
60
62
|
warnings.filterwarnings("ignore")
|
61
63
|
needs_feats = False
|
64
|
+
experiment_loaded = False
|
62
65
|
try:
|
63
66
|
# load the experiment
|
64
67
|
expr.load(f"{util.get_save_name()}")
|
65
68
|
needs_feats = True
|
69
|
+
experiment_loaded = True
|
66
70
|
except FileNotFoundError:
|
67
71
|
# first time: load the data
|
68
72
|
expr.load_datasets()
|
@@ -73,20 +77,35 @@ def main():
|
|
73
77
|
f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}"
|
74
78
|
)
|
75
79
|
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
80
|
+
# Check exploration settings regardless of whether experiment was loaded or not
|
81
|
+
plot_feats = eval(util.config_val("EXPL", "feature_distributions", "False"))
|
82
|
+
tsne_plot = eval(util.config_val("EXPL", "tsne", "False"))
|
83
|
+
umap_plot = eval(util.config_val("EXPL", "umap", "False"))
|
84
|
+
pca_plot = eval(util.config_val("EXPL", "pca", "False"))
|
85
|
+
scatter = eval(util.config_val("EXPL", "scatter", "False"))
|
86
|
+
shap = eval(util.config_val("EXPL", "shap", "False"))
|
87
|
+
model_type = util.config_val("EXPL", "model", False)
|
88
|
+
plot_tree = eval(util.config_val("EXPL", "plot_tree", "False"))
|
89
|
+
|
90
|
+
if (
|
91
|
+
plot_feats
|
92
|
+
or tsne_plot
|
93
|
+
or umap_plot
|
94
|
+
or pca_plot
|
95
|
+
or scatter
|
96
|
+
or model_type
|
97
|
+
or plot_tree
|
98
|
+
or shap
|
99
|
+
):
|
100
|
+
# these investigations need features to explore
|
101
|
+
if not experiment_loaded or not needs_feats:
|
85
102
|
expr.extract_feats()
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
#
|
103
|
+
needs_feats = True
|
104
|
+
# explore
|
105
|
+
if shap:
|
106
|
+
# SHAP analysis requires a trained model
|
107
|
+
expr.init_runmanager()
|
108
|
+
expr.runmgr.do_runs()
|
90
109
|
expr.analyse_features(needs_feats)
|
91
110
|
expr.store_report()
|
92
111
|
print("DONE")
|
@@ -1,5 +1,6 @@
|
|
1
1
|
# feats_analyser.py
|
2
2
|
import ast
|
3
|
+
import os
|
3
4
|
|
4
5
|
import matplotlib.pyplot as plt
|
5
6
|
import pandas as pd
|
@@ -76,17 +77,37 @@ class FeatureAnalyser:
|
|
76
77
|
self.util.to_pickle(shap_values, name)
|
77
78
|
else:
|
78
79
|
shap_values = self.util.from_pickle(name)
|
79
|
-
#
|
80
|
-
plt.
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
exp_name = self.util.get_exp_name(only_data=True)
|
80
|
+
# Create SHAP summary plot instead
|
81
|
+
fig, ax = plt.subplots(figsize=(10, 6))
|
82
|
+
shap.plots.bar(shap_values, ax=ax, show=False)
|
83
|
+
fig_dir = os.path.join(self.util.get_path("fig_dir"), "..")
|
84
|
+
|
85
85
|
format = self.util.config_val("PLOT", "format", "png")
|
86
|
-
|
87
|
-
filename = f"{
|
88
|
-
|
89
|
-
|
86
|
+
feat_type = self.util.get_feattype_name()
|
87
|
+
filename = f"SHAP_{feat_type}_{model.name}.{format}"
|
88
|
+
filename = os.path.join(fig_dir, filename)
|
89
|
+
|
90
|
+
fig.savefig(filename, dpi=300, bbox_inches="tight")
|
91
|
+
plt.close(fig)
|
92
|
+
|
93
|
+
# print and save SHAP feature importance
|
94
|
+
max_feat_num = len(self.features.columns)
|
95
|
+
shap_importance_values = shap_values.abs.mean(0).values
|
96
|
+
|
97
|
+
feature_cols = self.features.columns
|
98
|
+
feature_importance = pd.DataFrame(
|
99
|
+
shap_importance_values[:max_feat_num],
|
100
|
+
index=feature_cols,
|
101
|
+
columns=["importance"],
|
102
|
+
).sort_values("importance", ascending=False)
|
103
|
+
|
104
|
+
self.util.debug(
|
105
|
+
f"SHAP analysis, features = {feature_importance.index.tolist()}"
|
106
|
+
)
|
107
|
+
# Save to CSV (save all features, not just top ones)
|
108
|
+
csv_filename = os.path.join(fig_dir, f"SHAP_{feat_type}_importance_{model.name}.csv")
|
109
|
+
feature_importance.to_csv(csv_filename)
|
110
|
+
self.util.debug(f"Saved SHAP feature importance to {csv_filename}")
|
90
111
|
self.util.debug(f"plotted SHAP feature importance to {filename}")
|
91
112
|
|
92
113
|
def analyse(self):
|
@@ -120,6 +141,12 @@ class FeatureAnalyser:
|
|
120
141
|
covariance_type = self.util.config_val(
|
121
142
|
"MODEL", "GMM_covariance_type", "full"
|
122
143
|
)
|
144
|
+
allowed_cov_types = ["full", "tied", "diag", "spherical"]
|
145
|
+
if covariance_type not in allowed_cov_types:
|
146
|
+
self.util.error(
|
147
|
+
f"Invalid covariance_type '{covariance_type}', must be one of {allowed_cov_types}. Using default 'full'."
|
148
|
+
)
|
149
|
+
covariance_type = "full"
|
123
150
|
model = mixture.GaussianMixture(
|
124
151
|
n_components=n_components, covariance_type=covariance_type
|
125
152
|
)
|
@@ -156,7 +183,7 @@ class FeatureAnalyser:
|
|
156
183
|
from sklearn.svm import SVC
|
157
184
|
|
158
185
|
c = float(self.util.config_val("MODEL", "C_val", "1.0"))
|
159
|
-
model = SVC(kernel="linear", C=c, gamma="scale")
|
186
|
+
model = SVC(kernel="linear", C=c, gamma="scale", random_state=42)
|
160
187
|
result_importances[model_s] = self._get_importance(
|
161
188
|
model, permutation
|
162
189
|
)
|
@@ -165,7 +192,7 @@ class FeatureAnalyser:
|
|
165
192
|
plots = Plots()
|
166
193
|
plots.plot_tree(model, self.features)
|
167
194
|
elif model_s == "tree":
|
168
|
-
model = DecisionTreeClassifier()
|
195
|
+
model = DecisionTreeClassifier(random_state=42)
|
169
196
|
result_importances[model_s] = self._get_importance(
|
170
197
|
model, permutation
|
171
198
|
)
|
@@ -176,7 +203,9 @@ class FeatureAnalyser:
|
|
176
203
|
elif model_s == "xgb":
|
177
204
|
from xgboost import XGBClassifier
|
178
205
|
|
179
|
-
model = XGBClassifier(
|
206
|
+
model = XGBClassifier(
|
207
|
+
enable_categorical=True, tree_method="hist", random_state=42
|
208
|
+
)
|
180
209
|
self.labels = self.labels.astype("category")
|
181
210
|
result_importances[model_s] = self._get_importance(
|
182
211
|
model, permutation
|
@@ -263,13 +292,12 @@ class FeatureAnalyser:
|
|
263
292
|
title += "\n based on feature permutation"
|
264
293
|
ax.set(title=title)
|
265
294
|
plt.tight_layout()
|
266
|
-
fig_dir = self.util.get_path("fig_dir")
|
267
|
-
exp_name = self.util.get_exp_name(only_data=True)
|
295
|
+
fig_dir = self.util.get_path("fig_dir")
|
268
296
|
format = self.util.config_val("PLOT", "format", "png")
|
269
|
-
filename = f"
|
297
|
+
filename = f"EXPL_{model_name}"
|
270
298
|
if permutation:
|
271
299
|
filename += "_perm"
|
272
|
-
filename = f"{fig_dir}{
|
300
|
+
filename = f"{fig_dir}{filename}.{format}"
|
273
301
|
plt.savefig(filename)
|
274
302
|
fig = ax.figure
|
275
303
|
fig.clear()
|
@@ -3,7 +3,6 @@
|
|
3
3
|
# choices for feat_type = "emotion2vec", "emotion2vec-large", "emotion2vec-base", "emotion2vec-seed"
|
4
4
|
|
5
5
|
# requirements:
|
6
|
-
# pip install "modelscope>=1.9.5,<2.0.0"
|
7
6
|
# pip install funasr
|
8
7
|
|
9
8
|
import os
|
@@ -43,27 +42,30 @@ class Emotion2vec(Featureset):
|
|
43
42
|
except ImportError:
|
44
43
|
self.util.error(
|
45
44
|
"FunASR is required for emotion2vec features. "
|
46
|
-
"Please install with: pip install funasr
|
45
|
+
"Please install with: pip install funasr"
|
47
46
|
)
|
48
47
|
|
49
|
-
# Map feat_type to model names
|
48
|
+
# Map feat_type to model names on HuggingFace
|
50
49
|
model_mapping = {
|
51
|
-
"emotion2vec": "
|
52
|
-
"emotion2vec-base": "
|
53
|
-
"emotion2vec-seed": "
|
54
|
-
"emotion2vec-large": "
|
50
|
+
"emotion2vec": "emotion2vec/emotion2vec_base",
|
51
|
+
"emotion2vec-base": "emotion2vec/emotion2vec_base",
|
52
|
+
"emotion2vec-seed": "emotion2vec/emotion2vec_plus_seed",
|
53
|
+
"emotion2vec-large": "emotion2vec/emotion2vec_plus_large",
|
55
54
|
}
|
56
55
|
|
57
56
|
# Get model path from config or use default mapping
|
58
57
|
model_path = self.util.config_val(
|
59
58
|
"FEATS",
|
60
59
|
"emotion2vec.model",
|
61
|
-
model_mapping.get(self.feat_type, "
|
60
|
+
model_mapping.get(self.feat_type, "emotion2vec/emotion2vec_base"),
|
62
61
|
)
|
63
62
|
|
64
63
|
try:
|
65
|
-
# Initialize the FunASR model for emotion2vec
|
66
|
-
self.model = AutoModel(
|
64
|
+
# Initialize the FunASR model for emotion2vec using HuggingFace Hub
|
65
|
+
self.model = AutoModel(
|
66
|
+
model=model_path,
|
67
|
+
hub="hf" # Use HuggingFace Hub instead of ModelScope
|
68
|
+
)
|
67
69
|
self.util.debug(f"initialized emotion2vec model: {model_path}")
|
68
70
|
self.model_initialized = True
|
69
71
|
except Exception as e:
|
@@ -131,7 +133,9 @@ class Emotion2vec(Featureset):
|
|
131
133
|
import tempfile
|
132
134
|
import soundfile as sf
|
133
135
|
|
134
|
-
with tempfile.NamedTemporaryFile(
|
136
|
+
with tempfile.NamedTemporaryFile(
|
137
|
+
suffix=".wav", delete=False
|
138
|
+
) as tmp_file:
|
135
139
|
sf.write(tmp_file.name, signal_np, sampling_rate)
|
136
140
|
audio_path = tmp_file.name
|
137
141
|
else:
|
@@ -152,11 +156,20 @@ class Emotion2vec(Featureset):
|
|
152
156
|
embeddings = np.array(embeddings)
|
153
157
|
return embeddings.flatten()
|
154
158
|
else:
|
155
|
-
# Fallback
|
156
|
-
|
159
|
+
# Fallback based on model type
|
160
|
+
if 'large' in self.feat_type.lower():
|
161
|
+
return np.array([0.0] * 1024)
|
162
|
+
else:
|
163
|
+
return np.array([0.0] * 768)
|
157
164
|
else:
|
158
|
-
self.util.error(
|
159
|
-
|
165
|
+
self.util.error(
|
166
|
+
f"No result from emotion2vec model for file: {file}"
|
167
|
+
)
|
168
|
+
# Fallback based on model type
|
169
|
+
if 'large' in self.feat_type.lower():
|
170
|
+
return np.array([0.0] * 1024)
|
171
|
+
else:
|
172
|
+
return np.array([0.0] * 768)
|
160
173
|
|
161
174
|
finally:
|
162
175
|
# Clean up temporary file if we created one
|
@@ -166,36 +179,40 @@ class Emotion2vec(Featureset):
|
|
166
179
|
except Exception as e:
|
167
180
|
print(f"Error processing {file}: {str(e)}")
|
168
181
|
self.util.error(f"couldn't extract file: {file}, error: {str(e)}")
|
169
|
-
|
182
|
+
# Return appropriate dimension based on model type
|
183
|
+
if 'large' in self.feat_type.lower():
|
184
|
+
return np.array([0.0] * 1024)
|
185
|
+
else:
|
186
|
+
return np.array([0.0] * 768)
|
170
187
|
|
171
188
|
def extract_sample(self, signal, sr):
|
172
189
|
"""Extract features from a single sample."""
|
173
190
|
if not self.model_initialized:
|
174
191
|
self.init_model()
|
175
|
-
|
192
|
+
|
176
193
|
# Save signal as temporary file for emotion2vec
|
177
194
|
import tempfile
|
178
195
|
import soundfile as sf
|
179
|
-
|
196
|
+
|
180
197
|
try:
|
181
198
|
# Convert tensor to numpy if needed
|
182
199
|
if torch.is_tensor(signal):
|
183
200
|
signal_np = signal.squeeze().numpy()
|
184
201
|
else:
|
185
202
|
signal_np = signal.squeeze()
|
186
|
-
|
203
|
+
|
187
204
|
# Handle multi-channel audio
|
188
205
|
if signal_np.ndim > 1:
|
189
206
|
signal_np = signal_np[0]
|
190
|
-
|
207
|
+
|
191
208
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
|
192
209
|
sf.write(tmp_file.name, signal_np, sr)
|
193
|
-
|
210
|
+
|
194
211
|
# Extract using the emotion2vec model
|
195
212
|
res = self.model.generate(
|
196
213
|
tmp_file.name, granularity="utterance", extract_embedding=True
|
197
214
|
)
|
198
|
-
|
215
|
+
|
199
216
|
# Get embeddings from result
|
200
217
|
if isinstance(res, list) and len(res) > 0:
|
201
218
|
embeddings = res[0].get("feats", None)
|
@@ -203,12 +220,20 @@ class Emotion2vec(Featureset):
|
|
203
220
|
if isinstance(embeddings, list):
|
204
221
|
embeddings = np.array(embeddings)
|
205
222
|
return embeddings.flatten()
|
206
|
-
|
207
|
-
|
208
|
-
|
223
|
+
|
224
|
+
# Fallback based on model type
|
225
|
+
if 'large' in self.feat_type.lower():
|
226
|
+
return np.array([0.0] * 1024)
|
227
|
+
else:
|
228
|
+
return np.array([0.0] * 768)
|
229
|
+
|
209
230
|
except Exception as e:
|
210
231
|
print(f"Error in extract_sample: {str(e)}")
|
211
|
-
|
232
|
+
# Return appropriate dimension based on model type
|
233
|
+
if 'large' in self.feat_type.lower():
|
234
|
+
return np.array([0.0] * 1024)
|
235
|
+
else:
|
236
|
+
return np.array([0.0] * 768)
|
212
237
|
finally:
|
213
238
|
# Clean up temporary file
|
214
239
|
if tmp_file is not None: # Check if tmp_file was created
|