nkululeko 0.94.2__py3-none-any.whl → 0.95.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,7 +17,7 @@ class Resampler:
     def __init__(self, df, replace, not_testing=True):
         self.SAMPLING_RATE = 16000
         self.df = df
-        self.util = Util("resampler", has_config=not_testing)
+        self.util = Util("resampler", has_config=not not_testing)
         self.util.warn(f"all files might be resampled to {self.SAMPLING_RATE}")
         self.not_testing = not_testing
         self.replace = (
@@ -30,7 +30,7 @@ class Resampler:
         files = self.df.index.get_level_values(0).values
         # replace = eval(self.util.config_val("RESAMPLE", "replace", "False"))
         replace = self.replace
-        if self.not_testing:
+        if not self.not_testing:
             store = self.util.get_path("store")
         else:
             store = "./"
@@ -67,17 +67,28 @@ class Resampler:
         self.df = self.df.set_index(
             self.df.index.set_levels(new_files, level="file")
         )
-        target_file = self.util.config_val("RESAMPLE", "target", "resampled.csv")
-        # remove encoded labels
-        target = self.util.config_val("DATA", "target", "emotion")
-        if "class_label" in self.df.columns:
-            self.df = self.df.drop(columns=[target])
-            self.df = self.df.rename(columns={"class_label": target})
-        # save file
-        self.df.to_csv(target_file)
-        self.util.debug(
-            "saved resampled list of files to" f" {os.path.abspath(target_file)}"
-        )
+        if not self.not_testing:
+            target_file = self.util.config_val(
+                "RESAMPLE", "target", "resampled.csv"
+            )
+            # remove encoded labels
+            target = self.util.config_val("DATA", "target", "emotion")
+            if "class_label" in self.df.columns:
+                self.df = self.df.drop(columns=[target])
+                self.df = self.df.rename(columns={"class_label": target})
+            # save file
+            self.df.to_csv(target_file)
+            self.util.debug(
+                "saved resampled list of files to"
+                f" {os.path.abspath(target_file)}"
+            )
+        else:
+            # When running from command line, save to simple resampled.csv
+            target_file = "resampled.csv"
+            self.df.to_csv(target_file)
+            self.util.debug(
+                f"saved resampled list of files to {os.path.abspath(target_file)}"
+            )
         self.util.debug(f"resampled {succes} files, {error} errors")


@@ -91,7 +102,7 @@ def main():
         df_sample.index, allow_nat=False
     )
     df_sample.head(10)
-    resampler = Resampler(df_sample, not_testing=False)
+    resampler = Resampler(df_sample, False, not_testing=False)
     resampler.resample()
     shutil.copyfile(testfile, "tmp.resample_result.wav")
     shutil.copyfile("tmp.wav", testfile)
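Note on the `main()` fix above: `Resampler.__init__` takes `replace` as a required positional argument, so the old call `Resampler(df_sample, not_testing=False)` raised a `TypeError`. A minimal sketch of the corrected test-mode invocation (the module path and the one-row sample index are assumptions for illustration):

```python
import pandas as pd

from nkululeko.augmenting.resampler import Resampler  # module path assumed

# Hypothetical segmented index of (file, start, end), as nkululeko uses.
df_sample = pd.DataFrame(
    index=pd.MultiIndex.from_tuples(
        [("tmp.wav", pd.Timedelta(0), pd.Timedelta(seconds=1))],
        names=["file", "start", "end"],
    )
)

# replace=False keeps the originals; not_testing=False now takes the new
# command-line branch that writes a plain ./resampled.csv.
resampler = Resampler(df_sample, False, not_testing=False)
resampler.resample()
```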
nkululeko/autopredict/ap_emotion.py ADDED
@@ -0,0 +1,36 @@
+"""
+A predictor for emotion classification.
+Uses emotion2vec models for emotion prediction.
+"""
+
+import ast
+
+import nkululeko.glob_conf as glob_conf
+from nkululeko.feature_extractor import FeatureExtractor
+from nkululeko.utils.util import Util
+
+
+class EmotionPredictor:
+    """
+    EmotionPredictor
+    predicting emotion with emotion2vec models
+    """
+
+    def __init__(self, df):
+        self.df = df
+        self.util = Util("emotionPredictor")
+
+    def predict(self, split_selection):
+        self.util.debug(f"predicting emotion for {split_selection} samples")
+        feats_name = "_".join(ast.literal_eval(glob_conf.config["DATA"]["databases"]))
+
+        self.feature_extractor = FeatureExtractor(
+            self.df, ["emotion2vec-large"], feats_name, split_selection
+        )
+        emotion_df = self.feature_extractor.extract()
+
+        pred_emotion = ["neutral"] * len(emotion_df)
+
+        return_df = self.df.copy()
+        return_df["emotion_pred"] = pred_emotion
+        return return_df
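The new `EmotionPredictor` extracts emotion2vec-large features but, as the code shows, still returns a constant `"neutral"` placeholder for every sample. A minimal usage sketch, assuming a fully initialized experiment (so `glob_conf.config` carries a `[DATA]` section with `databases`) and `df` as a segmented sample DataFrame:

```python
from nkululeko.autopredict.ap_emotion import EmotionPredictor

# df: segmented sample DataFrame from a running experiment (assumed).
predictor = EmotionPredictor(df)
df_pred = predictor.predict("all")  # adds an "emotion_pred" column

# For now, every prediction is the placeholder label:
assert set(df_pred["emotion_pred"]) == {"neutral"}
```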
nkululeko/autopredict/ap_text.py ADDED
@@ -0,0 +1,45 @@
+"""A predictor for text.
+
+Currently based on whisper model.
+"""
+
+import ast
+
+import torch
+
+from nkululeko.feature_extractor import FeatureExtractor
+import nkululeko.glob_conf as glob_conf
+from nkululeko.utils.util import Util
+
+
+class TextPredictor:
+    """TextPredictor.
+
+    predicting text with the whisper model
+    """
+
+    def __init__(self, df, util=None):
+        self.df = df
+        if util is not None:
+            self.util = util
+        else:
+            # create a new util instance
+            # this is needed to access the config and other utilities
+            # in the autopredict module
+            self.util = Util("textPredictor")
+        from nkululeko.autopredict.whisper_transcriber import Transcriber
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        device = self.util.config_val("MODEL", "device", device)
+        self.transcriber = Transcriber(
+            device=device,
+            language=self.util.config_val("EXP", "language", "en"),
+            util=self.util,
+        )
+    def predict(self, split_selection):
+        self.util.debug(f"predicting text for {split_selection} samples")
+        df = self.transcriber.transcribe_index(
+            self.df.index
+        )
+        return_df = self.df.copy()
+        return_df["text"] = df["text"].values
+        return return_df
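`TextPredictor` wires the new `Transcriber` into the autopredict flow: the device and language are read from `[MODEL][device]` and `[EXP][language]` (default `"en"`). A usage sketch, assuming an initialized experiment and an existing `Util` instance:

```python
from nkululeko.autopredict.ap_text import TextPredictor

# util is optional; if omitted, the predictor creates its own Util instance.
predictor = TextPredictor(df, util)
df_text = predictor.predict("all")  # adds a "text" column of transcriptions
print(df_text["text"].head())
```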
nkululeko/autopredict/whisper_transcriber.py ADDED
@@ -0,0 +1,81 @@
+import os
+
+import pandas as pd
+import torch
+from tqdm import tqdm
+import whisper
+
+import audeer
+import audiofile
+
+from nkululeko.utils.util import Util
+
+
+class Transcriber:
+    def __init__(self, model_name="turbo", device=None, language="en", util=None):
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model = whisper.load_model(model_name, device=device)
+        self.language = language
+        self.util = util
+
+    def transcribe_file(self, audio_path):
+        """Transcribe the audio file at the given path.
+
+        :param audio_path: Path to the audio file to transcribe.
+        :return: Transcription text.
+        """
+        result = self.model.transcribe(
+            audio_path, language=self.language, without_timestamps=True)
+        result = result["text"].strip()
+        return result
+
+    def transcribe_array(self, signal, sampling_rate):
+        """Transcribe the audio file at the given path.
+
+        :param audio_path: Path to the audio file to transcribe.
+        :return: Transcription text.
+        """
+        tmporary_path = "temp.wav"
+        audiofile.write(
+            "temp.wav", signal, sampling_rate, format="wav")
+        result = self.transcribe_file(tmporary_path)
+        return result
+
+    def transcribe_index(self, index:pd.Index) -> pd.DataFrame:
+        """Transcribe the audio files in the given index.
+
+        :param index: Index containing tuples of (file, start, end).
+        :return: DataFrame with transcriptions indexed by the original index.
+        :rtype: pd.DataFrame
+        """
+        file_name = ""
+        seg_index = 0
+        transcriptions = []
+        transcriber_cache = audeer.mkdir(
+            audeer.path(self.util.get_path("cache"), "transcriptions"))
+        for idx, (file, start, end) in enumerate(
+            tqdm(index.to_list())
+        ):
+            if file != file_name:
+                file_name = file
+                seg_index = 0
+            cache_name = audeer.basename_wo_ext(file)+str(seg_index)
+            cache_path = audeer.path(transcriber_cache, cache_name + ".json")
+            if os.path.isfile(cache_path):
+                transcription = self.util.read_json(cache_path)["transcription"]
+            else:
+                dur = end.total_seconds() - start.total_seconds()
+                y, sr = audiofile.read(file, offset=start, duration=dur)
+                transcription = self.transcribe_array(
+                    y, sr)
+                self.util.save_json(cache_path,
+                    {"transcription": transcription,
+                     "file": file,
+                     "start": start.total_seconds(),
+                     "end": end.total_seconds()})
+            transcriptions.append(transcription)
+            seg_index += 1
+
+        df = pd.DataFrame({"text":transcriptions}, index=index)
+        return df
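`transcribe_index` expects a segmented index of `(file, start, end)` tuples with pandas `Timedelta` offsets and caches one JSON per segment under the experiment's cache path. A sketch with a hypothetical one-segment index (the file name is a placeholder, and `util` must be a nkululeko `Util` providing `get_path("cache")`, `read_json`, and `save_json`):

```python
import pandas as pd

from nkululeko.autopredict.whisper_transcriber import Transcriber

index = pd.MultiIndex.from_tuples(
    [("test.wav", pd.Timedelta(0), pd.Timedelta(seconds=2))],
    names=["file", "start", "end"],
)
transcriber = Transcriber(model_name="turbo", language="en", util=util)
df = transcriber.transcribe_index(index)  # DataFrame with a "text" column
print(df["text"].iloc[0])
```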
nkululeko/constants.py CHANGED
@@ -1,2 +1,2 @@
-VERSION="0.94.2"
+VERSION="0.95.0"
 SAMPLING_RATE = 16000
nkululeko/experiment.py CHANGED
@@ -513,7 +513,7 @@ class Experiment:

     def autopredict(self):
         """Predict labels for samples with existing models and add to the dataframe."""
-        sample_selection = self.util.config_val("PREDICT", "split", "all")
+        sample_selection = self.util.config_val("PREDICT", "sample_selection", "all")
         if sample_selection == "all":
             df = pd.concat([self.df_train, self.df_test])
         elif sample_selection == "train":
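This renames the `[PREDICT]` option from `split` to `sample_selection`; the shown context branches on `all` (the default) and `train`. A sketch of the new key in an INI config, loaded here with `configparser` purely for illustration:

```python
import configparser

# The old key [PREDICT][split] is now read as [PREDICT][sample_selection].
config = configparser.ConfigParser()
config.read_string("""
[PREDICT]
sample_selection = train
""")
print(config["PREDICT"]["sample_selection"])  # -> "train"
```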
@@ -569,6 +569,11 @@

             predictor = STOIPredictor(df)
             df = predictor.predict(sample_selection)
+        elif target == "text":
+            from nkululeko.autopredict.ap_text import TextPredictor
+
+            predictor = TextPredictor(df, self.util)
+            df = predictor.predict(sample_selection)
         elif target == "arousal":
             from nkululeko.autopredict.ap_arousal import ArousalPredictor

@@ -584,6 +589,11 @@

             predictor = DominancePredictor(df)
             df = predictor.predict(sample_selection)
+        elif target == "emotion":
+            from nkululeko.autopredict.ap_emotion import EmotionPredictor
+
+            predictor = EmotionPredictor(df)
+            df = predictor.predict(sample_selection)
         else:
             self.util.error(f"unknown auto predict target: {target}")
         return df
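With the two new branches, `text` and `emotion` become dispatchable autopredict targets alongside the existing ones (arousal, dominance, etc.); anything else still raises the `unknown auto predict target` error. A sketch, assuming the targets are configured as a Python-literal list under a `targets` key (the key name is an assumption; only the target values appear in this diff):

```python
import ast
import configparser

config = configparser.ConfigParser()
config.read_string("""
[PREDICT]
targets = ['text', 'emotion']
""")
# nkululeko parses such list-valued options with ast.literal_eval:
targets = ast.literal_eval(config["PREDICT"]["targets"])
print(targets)  # ['text', 'emotion']
```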
@@ -668,11 +678,27 @@

         # check if a scatterplot should be done
         scatter_var = eval(self.util.config_val("EXPL", "scatter", "False"))
+
+        # Priority: use [EXPL][scatter.target] if available, otherwise use [DATA][target] value
+        if hasattr(self, "target") and self.target != "none":
+            default_scatter_target = f"['{self.target}']"
+        else:
+            default_scatter_target = "['class_label']"
+
         scatter_target = self.util.config_val(
-            "EXPL", "scatter.target", "['class_label']"
+            "EXPL", "scatter.target", default_scatter_target
         )
+
+        if scatter_target == default_scatter_target:
+            self.util.debug(
+                f"scatter.target using default from [DATA][target]: {scatter_target}"
+            )
+        else:
+            self.util.debug(
+                f"scatter.target from [EXPL][scatter.target]: {scatter_target}"
+            )
         if scatter_var:
-            scatters = ast.literal_eval(glob_conf.config["EXPL"]["scatter"])
+            scatters = ast.literal_eval(scatter_target)
             scat_targets = ast.literal_eval(scatter_target)
             plots = Plots()
             for scat_target in scat_targets:
@@ -692,6 +718,30 @@
                 df_feats, df_labels, f"{scat_target}_bins", scatter
             )

+        # check if t-SNE plot should be generated
+        tsne = eval(self.util.config_val("EXPL", "tsne", "False"))
+        if tsne:
+            target_column = self.util.config_val("DATA", "target", "emotion")
+            plots = Plots()
+            self.util.debug("generating t-SNE plot...")
+            plots.scatter_plot(df_feats, df_labels, target_column, "tsne")
+
+        # check if UMAP plot should be generated
+        umap_plot = eval(self.util.config_val("EXPL", "umap", "False"))
+        if umap_plot:
+            target_column = self.util.config_val("DATA", "target", "emotion")
+            plots = Plots()
+            self.util.debug("generating UMAP plot...")
+            plots.scatter_plot(df_feats, df_labels, target_column, "umap")
+
+        # check if PCA plot should be generated
+        pca_plot = eval(self.util.config_val("EXPL", "pca", "False"))
+        if pca_plot:
+            target_column = self.util.config_val("DATA", "target", "emotion")
+            plots = Plots()
+            self.util.debug("generating PCA plot...")
+            plots.scatter_plot(df_feats, df_labels, target_column, "pca")
+
     def _check_scale(self):
         self.util.save_to_store(self.feats_train, "feats_train")
         self.util.save_to_store(self.feats_test, "feats_test")
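The three new `[EXPL]` switches each funnel into the same `Plots.scatter_plot` call with a different reduction-method string. A sketch of enabling them in a config (section layout beyond these keys is assumed):

```python
import configparser

config = configparser.ConfigParser()
config.read_string("""
[EXPL]
tsne = True
umap = True
pca = True
""")
# Each flag that eval()s to True yields one
# scatter_plot(df_feats, df_labels, target_column, method) call,
# with method in {"tsne", "umap", "pca"}.
```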
nkululeko/explore.py CHANGED
@@ -8,6 +8,8 @@ The script supports the following configuration options:
 - `no_warnings`: If set to `True`, it will ignore all warnings during the exploration.
 - `feature_distributions`: If set to `True`, it will generate plots of the feature distributions.
 - `tsne`: If set to `True`, it will generate a t-SNE plot of the feature space.
+- `umap`: If set to `True`, it will generate a UMAP plot of the feature space.
+- `pca`: If set to `True`, it will generate a PCA plot of the feature space.
 - `scatter`: If set to `True`, it will generate a scatter plot of the feature space.
 - `spotlight`: If set to `True`, it will generate a 'spotlight' plot of the feature space.
 - `shap`: If set to `True`, it will generate SHAP feature importance plots.
@@ -59,10 +61,12 @@ def main():

     warnings.filterwarnings("ignore")
     needs_feats = False
+    experiment_loaded = False
     try:
         # load the experiment
         expr.load(f"{util.get_save_name()}")
         needs_feats = True
+        experiment_loaded = True
     except FileNotFoundError:
         # first time: load the data
         expr.load_datasets()
@@ -73,20 +77,35 @@
         f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}"
     )

-    plot_feats = eval(util.config_val("EXPL", "feature_distributions", "False"))
-    tsne = eval(util.config_val("EXPL", "tsne", "False"))
-    scatter = eval(util.config_val("EXPL", "scatter", "False"))
-    shap = eval(util.config_val("EXPL", "shap", "False"))
-    model_type = util.config_val("EXPL", "model", False)
-    plot_tree = eval(util.config_val("EXPL", "plot_tree", "False"))
-    needs_feats = False
-    if plot_feats or tsne or scatter or model_type or plot_tree or shap:
-        # these investigations need features to explore
+    # Check exploration settings regardless of whether experiment was loaded or not
+    plot_feats = eval(util.config_val("EXPL", "feature_distributions", "False"))
+    tsne_plot = eval(util.config_val("EXPL", "tsne", "False"))
+    umap_plot = eval(util.config_val("EXPL", "umap", "False"))
+    pca_plot = eval(util.config_val("EXPL", "pca", "False"))
+    scatter = eval(util.config_val("EXPL", "scatter", "False"))
+    shap = eval(util.config_val("EXPL", "shap", "False"))
+    model_type = util.config_val("EXPL", "model", False)
+    plot_tree = eval(util.config_val("EXPL", "plot_tree", "False"))
+
+    if (
+        plot_feats
+        or tsne_plot
+        or umap_plot
+        or pca_plot
+        or scatter
+        or model_type
+        or plot_tree
+        or shap
+    ):
+        # these investigations need features to explore
+        if not experiment_loaded or not needs_feats:
             expr.extract_feats()
-        needs_feats = True
-    # explore
-    # expr.init_runmanager()
-    # expr.runmgr.do_runs()
+            needs_feats = True
+    # explore
+    if shap:
+        # SHAP analysis requires a trained model
+        expr.init_runmanager()
+        expr.runmgr.do_runs()
     expr.analyse_features(needs_feats)
     expr.store_report()
     print("DONE")
@@ -1,5 +1,6 @@
 # feats_analyser.py
 import ast
+import os

 import matplotlib.pyplot as plt
 import pandas as pd
@@ -76,17 +77,37 @@ class FeatureAnalyser:
             self.util.to_pickle(shap_values, name)
         else:
             shap_values = self.util.from_pickle(name)
-        # plt.figure()
-        plt.close("all")
-        plt.tight_layout()
-        shap.plots.bar(shap_values)
-        fig_dir = self.util.get_path("fig_dir") + "../"  # one up because of the runs
-        exp_name = self.util.get_exp_name(only_data=True)
+        # Create SHAP summary plot instead
+        fig, ax = plt.subplots(figsize=(10, 6))
+        shap.plots.bar(shap_values, ax=ax, show=False)
+        fig_dir = os.path.join(self.util.get_path("fig_dir"), "..")
+
         format = self.util.config_val("PLOT", "format", "png")
-        filename = f"_SHAP_{model.name}"
-        filename = f"{fig_dir}{exp_name}{filename}.{format}"
-        plt.savefig(filename)
-        plt.close()
+        feat_type = self.util.get_feattype_name()
+        filename = f"SHAP_{feat_type}_{model.name}.{format}"
+        filename = os.path.join(fig_dir, filename)
+
+        fig.savefig(filename, dpi=300, bbox_inches="tight")
+        plt.close(fig)
+
+        # print and save SHAP feature importance
+        max_feat_num = len(self.features.columns)
+        shap_importance_values = shap_values.abs.mean(0).values
+
+        feature_cols = self.features.columns
+        feature_importance = pd.DataFrame(
+            shap_importance_values[:max_feat_num],
+            index=feature_cols,
+            columns=["importance"],
+        ).sort_values("importance", ascending=False)
+
+        self.util.debug(
+            f"SHAP analysis, features = {feature_importance.index.tolist()}"
+        )
+        # Save to CSV (save all features, not just top ones)
+        csv_filename = os.path.join(fig_dir, f"SHAP_{feat_type}_importance_{model.name}.csv")
+        feature_importance.to_csv(csv_filename)
+        self.util.debug(f"Saved SHAP feature importance to {csv_filename}")
         self.util.debug(f"plotted SHAP feature importance to {filename}")

     def analyse(self):
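Besides the bar plot, SHAP importances are now persisted as a CSV sorted by mean absolute SHAP value, so they can be inspected without rerunning the analysis. A reading sketch (the file name follows the `SHAP_{feat_type}_importance_{model.name}.csv` pattern above; `os` and `xgb` are placeholder values for feature type and model name):

```python
import pandas as pd

imp = pd.read_csv("SHAP_os_importance_xgb.csv", index_col=0)
print(imp.head(10))  # top-10 features by mean |SHAP| importance
```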
@@ -120,6 +141,12 @@
             covariance_type = self.util.config_val(
                 "MODEL", "GMM_covariance_type", "full"
             )
+            allowed_cov_types = ["full", "tied", "diag", "spherical"]
+            if covariance_type not in allowed_cov_types:
+                self.util.error(
+                    f"Invalid covariance_type '{covariance_type}', must be one of {allowed_cov_types}. Using default 'full'."
+                )
+                covariance_type = "full"
             model = mixture.GaussianMixture(
                 n_components=n_components, covariance_type=covariance_type
             )
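The validation guards against typos in `GMM_covariance_type` before the model is built; these four strings are exactly the values scikit-learn's `GaussianMixture` accepts. A quick standalone check:

```python
import numpy as np
from sklearn import mixture

X = np.random.default_rng(42).normal(size=(200, 4))
for cov in ["full", "tied", "diag", "spherical"]:
    gmm = mixture.GaussianMixture(n_components=2, covariance_type=cov).fit(X)
    print(cov, gmm.converged_)
```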
@@ -156,7 +183,7 @@
             from sklearn.svm import SVC

             c = float(self.util.config_val("MODEL", "C_val", "1.0"))
-            model = SVC(kernel="linear", C=c, gamma="scale")
+            model = SVC(kernel="linear", C=c, gamma="scale", random_state=42)
             result_importances[model_s] = self._get_importance(
                 model, permutation
             )
@@ -165,7 +192,7 @@
             plots = Plots()
             plots.plot_tree(model, self.features)
         elif model_s == "tree":
-            model = DecisionTreeClassifier()
+            model = DecisionTreeClassifier(random_state=42)
             result_importances[model_s] = self._get_importance(
                 model, permutation
             )
@@ -176,7 +203,9 @@
         elif model_s == "xgb":
             from xgboost import XGBClassifier

-            model = XGBClassifier(enable_categorical=True, tree_method="hist")
+            model = XGBClassifier(
+                enable_categorical=True, tree_method="hist", random_state=42
+            )
             self.labels = self.labels.astype("category")
             result_importances[model_s] = self._get_importance(
                 model, permutation
@@ -263,13 +292,12 @@
             title += "\n based on feature permutation"
         ax.set(title=title)
         plt.tight_layout()
-        fig_dir = self.util.get_path("fig_dir") + "../"  # one up because of the runs
-        exp_name = self.util.get_exp_name(only_data=True)
+        fig_dir = self.util.get_path("fig_dir")
         format = self.util.config_val("PLOT", "format", "png")
-        filename = f"_EXPL_{model_name}"
+        filename = f"EXPL_{model_name}"
         if permutation:
             filename += "_perm"
-        filename = f"{fig_dir}{exp_name}{filename}.{format}"
+        filename = f"{fig_dir}{filename}.{format}"
         plt.savefig(filename)
         fig = ax.figure
         fig.clear()
@@ -3,7 +3,6 @@
 # choices for feat_type = "emotion2vec", "emotion2vec-large", "emotion2vec-base", "emotion2vec-seed"

 # requirements:
-# pip install "modelscope>=1.9.5,<2.0.0"
 # pip install funasr

 import os
@@ -43,27 +42,30 @@ class Emotion2vec(Featureset):
         except ImportError:
             self.util.error(
                 "FunASR is required for emotion2vec features. "
-                "Please install with: pip install funasr modelscope"
+                "Please install with: pip install funasr"
             )

-        # Map feat_type to model names
+        # Map feat_type to model names on HuggingFace
         model_mapping = {
-            "emotion2vec": "iic/emotion2vec_base",
-            "emotion2vec-base": "iic/emotion2vec_base_finetuned",
-            "emotion2vec-seed": "iic/emotion2vec_plus_seed",
-            "emotion2vec-large": "iic/emotion2vec_plus_large",
+            "emotion2vec": "emotion2vec/emotion2vec_base",
+            "emotion2vec-base": "emotion2vec/emotion2vec_base",
+            "emotion2vec-seed": "emotion2vec/emotion2vec_plus_seed",
+            "emotion2vec-large": "emotion2vec/emotion2vec_plus_large",
         }

         # Get model path from config or use default mapping
         model_path = self.util.config_val(
             "FEATS",
             "emotion2vec.model",
-            model_mapping.get(self.feat_type, "iic/emotion2vec_base"),
+            model_mapping.get(self.feat_type, "emotion2vec/emotion2vec_base"),
         )

         try:
-            # Initialize the FunASR model for emotion2vec
-            self.model = AutoModel(model=model_path)
+            # Initialize the FunASR model for emotion2vec using HuggingFace Hub
+            self.model = AutoModel(
+                model=model_path,
+                hub="hf"  # Use HuggingFace Hub instead of ModelScope
+            )
             self.util.debug(f"initialized emotion2vec model: {model_path}")
             self.model_initialized = True
         except Exception as e:
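The model IDs move from ModelScope's `iic/` namespace to the `emotion2vec/` organization on HuggingFace, selected via FunASR's `hub="hf"` switch. A minimal sketch mirroring the calls in this file (`test.wav` is a placeholder):

```python
from funasr import AutoModel

model = AutoModel(model="emotion2vec/emotion2vec_plus_large", hub="hf")
res = model.generate("test.wav", granularity="utterance", extract_embedding=True)
embedding = res[0]["feats"]  # utterance-level embedding; 1024-dim for "plus_large"
```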
@@ -131,7 +133,9 @@ class Emotion2vec(Featureset):
             import tempfile
             import soundfile as sf

-            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
+            with tempfile.NamedTemporaryFile(
+                suffix=".wav", delete=False
+            ) as tmp_file:
                 sf.write(tmp_file.name, signal_np, sampling_rate)
                 audio_path = tmp_file.name
         else:
@@ -152,11 +156,20 @@ class Emotion2vec(Featureset):
                         embeddings = np.array(embeddings)
                     return embeddings.flatten()
                 else:
-                    # Fallback to create default embedding
-                    return np.array([0.0] * 768)
+                    # Fallback based on model type
+                    if 'large' in self.feat_type.lower():
+                        return np.array([0.0] * 1024)
+                    else:
+                        return np.array([0.0] * 768)
             else:
-                self.util.error(f"No result from emotion2vec model for file: {file}")
-                return np.array([0.0] * 768)
+                self.util.error(
+                    f"No result from emotion2vec model for file: {file}"
+                )
+                # Fallback based on model type
+                if 'large' in self.feat_type.lower():
+                    return np.array([0.0] * 1024)
+                else:
+                    return np.array([0.0] * 768)

         finally:
             # Clean up temporary file if we created one
@@ -166,36 +179,40 @@ class Emotion2vec(Featureset):
         except Exception as e:
             print(f"Error processing {file}: {str(e)}")
             self.util.error(f"couldn't extract file: {file}, error: {str(e)}")
-            return np.array([0.0] * 768)
+            # Return appropriate dimension based on model type
+            if 'large' in self.feat_type.lower():
+                return np.array([0.0] * 1024)
+            else:
+                return np.array([0.0] * 768)

     def extract_sample(self, signal, sr):
         """Extract features from a single sample."""
         if not self.model_initialized:
             self.init_model()
-
+
         # Save signal as temporary file for emotion2vec
         import tempfile
         import soundfile as sf
-
+
         try:
             # Convert tensor to numpy if needed
             if torch.is_tensor(signal):
                 signal_np = signal.squeeze().numpy()
             else:
                 signal_np = signal.squeeze()
-
+
             # Handle multi-channel audio
             if signal_np.ndim > 1:
                 signal_np = signal_np[0]
-
+
             with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                 sf.write(tmp_file.name, signal_np, sr)
-
+
                 # Extract using the emotion2vec model
                 res = self.model.generate(
                     tmp_file.name, granularity="utterance", extract_embedding=True
                 )
-
+
                 # Get embeddings from result
                 if isinstance(res, list) and len(res) > 0:
                     embeddings = res[0].get("feats", None)
@@ -203,12 +220,20 @@ class Emotion2vec(Featureset):
                     if isinstance(embeddings, list):
                         embeddings = np.array(embeddings)
                     return embeddings.flatten()
-
-            return np.array([0.0] * 768)
-
+
+            # Fallback based on model type
+            if 'large' in self.feat_type.lower():
+                return np.array([0.0] * 1024)
+            else:
+                return np.array([0.0] * 768)
+
         except Exception as e:
             print(f"Error in extract_sample: {str(e)}")
-            return np.array([0.0] * 768)
+            # Return appropriate dimension based on model type
+            if 'large' in self.feat_type.lower():
+                return np.array([0.0] * 1024)
+            else:
+                return np.array([0.0] * 768)
         finally:
             # Clean up temporary file
             if tmp_file is not None:  # Check if tmp_file was created