nkululeko 0.81.4__py3-none-any.whl → 0.81.7__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their public registry.
Files changed (39)
  1. nkululeko/autopredict/estimate_snr.py +17 -6
  2. nkululeko/constants.py +1 -1
  3. nkululeko/data/dataset.py +9 -2
  4. nkululeko/demo.py +20 -5
  5. nkululeko/demo_predictor.py +6 -3
  6. nkululeko/experiment.py +1 -1
  7. nkululeko/explore.py +13 -8
  8. nkululeko/feat_extract/feats_agender.py +7 -8
  9. nkululeko/feat_extract/{feats_audmodel_dim.py → feats_auddim.py} +10 -7
  10. nkululeko/feat_extract/feats_audmodel.py +10 -7
  11. nkululeko/feat_extract/feats_clap.py +10 -6
  12. nkululeko/feat_extract/feats_hubert.py +3 -2
  13. nkululeko/feat_extract/feats_import.py +3 -3
  14. nkululeko/feat_extract/feats_mos.py +4 -3
  15. nkululeko/feat_extract/feats_opensmile.py +10 -24
  16. nkululeko/feat_extract/feats_oxbow.py +16 -11
  17. nkululeko/feat_extract/feats_praat.py +18 -13
  18. nkululeko/feat_extract/feats_snr.py +17 -9
  19. nkululeko/feat_extract/feats_spectra.py +3 -2
  20. nkululeko/feat_extract/feats_squim.py +15 -18
  21. nkululeko/feat_extract/feats_trill.py +10 -6
  22. nkululeko/feat_extract/feats_wav2vec2.py +16 -7
  23. nkululeko/feat_extract/feats_wavlm.py +1 -4
  24. nkululeko/feat_extract/feats_whisper.py +110 -0
  25. nkululeko/feat_extract/featureset.py +6 -3
  26. nkululeko/feature_extractor.py +83 -148
  27. nkululeko/multidb.py +18 -12
  28. nkululeko/predict.py +26 -8
  29. nkululeko/reporter.py +332 -0
  30. nkululeko/resample.py +12 -7
  31. nkululeko/runmanager.py +17 -8
  32. nkululeko/test.py +9 -6
  33. nkululeko/test_predictor.py +1 -0
  34. nkululeko/utils/stats.py +12 -5
  35. {nkululeko-0.81.4.dist-info → nkululeko-0.81.7.dist-info}/METADATA +16 -1
  36. {nkululeko-0.81.4.dist-info → nkululeko-0.81.7.dist-info}/RECORD +39 -37
  37. {nkululeko-0.81.4.dist-info → nkululeko-0.81.7.dist-info}/LICENSE +0 -0
  38. {nkululeko-0.81.4.dist-info → nkululeko-0.81.7.dist-info}/WHEEL +0 -0
  39. {nkululeko-0.81.4.dist-info → nkululeko-0.81.7.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,17 @@
1
- """ feats_snr.py
2
- Estimate snr (signal to noise ratio as acoustic features)
1
+ """ feats_snr.py is to estimate snr.
2
+
3
+ SNR (signal to noise ratio) is extracted as acoustic features.
3
4
  """
4
5
  import os
5
- from tqdm import tqdm
6
- import pandas as pd
6
+
7
7
  import audiofile
8
+ import pandas as pd
9
+ from tqdm import tqdm
10
+
8
11
  import nkululeko.glob_conf as glob_conf
9
- from nkululeko.utils.util import Util
10
- from nkululeko.feat_extract.featureset import Featureset
11
12
  from nkululeko.autopredict.estimate_snr import SNREstimator
13
+ from nkululeko.feat_extract.featureset import Featureset
14
+ from nkululeko.utils.util import Util
12
15
 
13
16
 
14
17
  class SNRSet(Featureset):
@@ -16,14 +19,17 @@ class SNRSet(Featureset):
16
19
 
17
20
  def __init__(self, name, data_df):
18
21
  """Constructor."""
22
+
19
23
  super().__init__(name, data_df)
20
24
 
21
25
  def extract(self):
22
26
  """Estimate the features or load them from disk if present."""
27
+
23
28
  store = self.util.get_path("store")
24
29
  store_format = self.util.config_val("FEATS", "store_format", "pkl")
25
30
  storage = f"{store}{self.name}.{store_format}"
26
- extract = self.util.config_val("FEATS", "needs_feature_extraction", False)
31
+ extract = self.util.config_val(
32
+ "FEATS", "needs_feature_extraction", False)
27
33
  no_reuse = eval(self.util.config_val("FEATS", "no_reuse", "False"))
28
34
  if extract or no_reuse or not os.path.isfile(storage):
29
35
  self.util.debug("estimating SNR, this might take a while...")
@@ -40,7 +46,8 @@ class SNRSet(Featureset):
40
46
  snr = self.get_snr(signal[0], sampling_rate)
41
47
  snr_series[idx] = snr
42
48
  print("")
43
- self.df = pd.DataFrame(snr_series.values.tolist(), index=self.data_df.index)
49
+ self.df = pd.DataFrame(
50
+ snr_series.values.tolist(), index=self.data_df.index)
44
51
  self.df.columns = ["snr"]
45
52
  self.util.write_store(self.df, storage, store_format)
46
53
  try:
@@ -53,10 +60,11 @@ class SNRSet(Featureset):
53
60
 
54
61
  def get_snr(self, signal, sampling_rate):
55
62
  r"""Estimate SNR from raw audio signal.
63
+
56
64
  Args:
57
65
  signal: audio signal
58
66
  sampling_rate: sample rate
59
- Returns
67
+ Returns:
60
68
  snr: estimated signal to noise ratio
61
69
  """
62
70
  snr_estimator = SNREstimator(signal, sampling_rate)
@@ -4,6 +4,7 @@ feats_spectra.py
4
4
  Inspired by code from Su Lei
5
5
 
6
6
  """
7
+
7
8
  import os
8
9
  import torchaudio
9
10
  import torchaudio.transforms as T
@@ -23,9 +24,9 @@ import nkululeko.glob_conf as glob_conf
23
24
 
24
25
 
25
26
  class Spectraloader(Featureset):
26
- def __init__(self, name, data_df):
27
+ def __init__(self, name, data_df, feat_type):
27
28
  """Constructor setting the name"""
28
- Featureset.__init__(self, name, data_df)
29
+ super().__init__(name, data_df, feat_type)
29
30
  self.sampling_rate = SAMPLING_RATE
30
31
  self.num_bands = int(self.util.config_val("FEATS", "fft_nbands", "64"))
31
32
  self.win_dur = int(self.util.config_val("FEATS", "fft_win_dur", "25"))
@@ -1,41 +1,38 @@
1
- """ feats_squim.py
2
- predict SQUIM ( SPEECH QUALITY AND INTELLIGIBILITY
3
- MEASURES) features
1
+ """Predict SQUIM ( SPEECH QUALITY AND INTELLIGIBILITY MEASURES) features.
4
2
 
5
-
6
- Wideband Perceptual Estimation of Speech Quality (PESQ) [2]
7
- Short-Time Objective Intelligibility (STOI) [3]
8
- Scale-Invariant Signal-to-Distortion Ratio (SI-SDR) [4]
9
-
10
-
11
- adapted from
3
+ Wideband Perceptual Estimation of Speech Quality (PESQ) [2].
4
+ Short-Time Objective Intelligibility (STOI) [3].
5
+ Scale-Invariant Signal-to-Distortion Ratio (SI-SDR) [4].
6
+ Adapted from
12
7
  from https://pytorch.org/audio/main/tutorials/squim_tutorial.html#sphx-glr-tutorials-squim-tutorial-py
13
- paper: https://arxiv.org/pdf/2304.01448.pdf
14
-
15
- needs
8
+ paper: https://arxiv.org/pdf/2304.01448.pdf.
9
+ Needs
16
10
  pip uninstall -y torch torchvision torchaudio
17
11
  pip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
18
12
 
19
13
  """
20
14
 
21
15
  import os
22
- from tqdm import tqdm
16
+
23
17
  import pandas as pd
24
18
  import torch
25
19
  import torchaudio
26
20
  from torchaudio.pipelines import SQUIM_OBJECTIVE
21
+ from tqdm import tqdm
22
+
27
23
  import audiofile
24
+
25
+ from nkululeko.feat_extract.featureset import Featureset
28
26
  import nkululeko.glob_conf as glob_conf
29
27
  from nkululeko.utils.util import Util
30
- from nkululeko.feat_extract.featureset import Featureset
31
28
 
32
29
 
33
- class SQUIMSet(Featureset):
30
+ class SquimSet(Featureset):
34
31
  """Class to predict SQUIM features"""
35
32
 
36
- def __init__(self, name, data_df):
33
+ def __init__(self, name, data_df, feats_type):
37
34
  """Constructor. is_train is needed to distinguish from test/dev sets, because they use the codebook from the training"""
38
- super().__init__(name, data_df)
35
+ super().__init__(name, data_df, feats_type)
39
36
  self.device = self.util.config_val("MODEL", "device", "cpu")
40
37
  self.model_initialized = False
41
38
 
@@ -1,4 +1,5 @@
1
1
  # feats_trill.py
2
+ import tensorflow_hub as hub
2
3
  import os
3
4
  import tensorflow as tf
4
5
  from numpy.core.numeric import tensordot
@@ -11,7 +12,6 @@ from nkululeko.feat_extract.featureset import Featureset
11
12
 
12
13
  # Import TF 2.X and make sure we're running eager.
13
14
  assert tf.executing_eagerly()
14
- import tensorflow_hub as hub
15
15
 
16
16
 
17
17
  class TRILLset(Featureset):
@@ -20,7 +20,7 @@ class TRILLset(Featureset):
20
20
  """https://ai.googleblog.com/2020/06/improving-speech-representations-and.html"""
21
21
 
22
22
  # Initialization of the class
23
- def __init__(self, name, data_df):
23
+ def __init__(self, name, data_df, feats_type):
24
24
  """
25
25
  Initialize the class with name, data and Util instance
26
26
  Also loads the model from hub
@@ -31,7 +31,7 @@ class TRILLset(Featureset):
31
31
  :type data_df: DataFrame
32
32
  :return: None
33
33
  """
34
- super().__init__(name, data_df)
34
+ super().__init__(name, data_df, feats_type)
35
35
  # Load the model from the configured path
36
36
  model_path = self.util.config_val(
37
37
  "FEATS",
@@ -39,20 +39,24 @@ class TRILLset(Featureset):
39
39
  "https://tfhub.dev/google/nonsemantic-speech-benchmark/trill/3",
40
40
  )
41
41
  self.module = hub.load(model_path)
42
+ self.feats_type = feats_type
42
43
 
43
44
  def extract(self):
44
45
  store = self.util.get_path("store")
45
46
  storage = f"{store}{self.name}.pkl"
46
- extract = self.util.config_val("FEATS", "needs_feature_extraction", False)
47
+ extract = self.util.config_val(
48
+ "FEATS", "needs_feature_extraction", False)
47
49
  no_reuse = eval(self.util.config_val("FEATS", "no_reuse", "False"))
48
50
  if extract or no_reuse or not os.path.isfile(storage):
49
- self.util.debug("extracting TRILL embeddings, this might take a while...")
51
+ self.util.debug(
52
+ "extracting TRILL embeddings, this might take a while...")
50
53
  emb_series = pd.Series(index=self.data_df.index, dtype=object)
51
54
  length = len(self.data_df.index)
52
55
  for idx, file in enumerate(tqdm(self.data_df.index.get_level_values(0))):
53
56
  emb = self.getEmbeddings(file)
54
57
  emb_series[idx] = emb
55
- self.df = pd.DataFrame(emb_series.values.tolist(), index=self.data_df.index)
58
+ self.df = pd.DataFrame(
59
+ emb_series.values.tolist(), index=self.data_df.index)
56
60
  self.df.to_pickle(storage)
57
61
  try:
58
62
  glob_conf.config["DATA"]["needs_feature_extraction"] = "false"
@@ -1,5 +1,11 @@
1
- # feats_wav2vec2.py
2
- # feat_types example = wav2vec2-large-robust-ft-swbd-300h
1
+ """ feats_wav2vec2.py
2
+ feat_types example = [wav2vec2-large-robust-ft-swbd-300h,
3
+ wav2vec2-xls-r-2b, wav2vec2-large, wav2vec2-large-xlsr-53, wav2vec2-base]
4
+
5
+ Complete list: https://huggingface.co/facebook?search_models=wav2vec2
6
+ Currently only supports wav2vec2
7
+ """
8
+
3
9
  import os
4
10
  from tqdm import tqdm
5
11
  import pandas as pd
@@ -16,11 +22,11 @@ class Wav2vec2(Featureset):
16
22
 
17
23
  def __init__(self, name, data_df, feat_type):
18
24
  """Constructor. is_train is needed to distinguish from test/dev sets, because they use the codebook from the training"""
19
- super().__init__(name, data_df)
25
+ super().__init__(name, data_df, feat_type)
20
26
  cuda = "cuda" if torch.cuda.is_available() else "cpu"
21
27
  self.device = self.util.config_val("MODEL", "device", cuda)
22
28
  self.model_initialized = False
23
- if feat_type == "wav2vec" or feat_type == "wav2vec2":
29
+ if feat_type == "wav2vec2":
24
30
  self.feat_type = "wav2vec2-large-robust-ft-swbd-300h"
25
31
  else:
26
32
  self.feat_type = feat_type
@@ -33,7 +39,8 @@ class Wav2vec2(Featureset):
33
39
  )
34
40
  config = transformers.AutoConfig.from_pretrained(model_path)
35
41
  layer_num = config.num_hidden_layers
36
- hidden_layer = int(self.util.config_val("FEATS", "wav2vec2.layer", "0"))
42
+ hidden_layer = int(self.util.config_val(
43
+ "FEATS", "wav2vec2.layer", "0"))
37
44
  config.num_hidden_layers = layer_num - hidden_layer
38
45
  self.util.debug(f"using hidden layer #{config.num_hidden_layers}")
39
46
  self.processor = Wav2Vec2FeatureExtractor.from_pretrained(model_path)
@@ -48,7 +55,8 @@ class Wav2vec2(Featureset):
48
55
  """Extract the features or load them from disk if present."""
49
56
  store = self.util.get_path("store")
50
57
  storage = f"{store}{self.name}.pkl"
51
- extract = self.util.config_val("FEATS", "needs_feature_extraction", False)
58
+ extract = self.util.config_val(
59
+ "FEATS", "needs_feature_extraction", False)
52
60
  no_reuse = eval(self.util.config_val("FEATS", "no_reuse", "False"))
53
61
  if extract or no_reuse or not os.path.isfile(storage):
54
62
  if not self.model_initialized:
@@ -69,7 +77,8 @@ class Wav2vec2(Featureset):
69
77
  emb = self.get_embeddings(signal, sampling_rate, file)
70
78
  emb_series[idx] = emb
71
79
  # print(f"emb_series shape: {emb_series.shape}")
72
- self.df = pd.DataFrame(emb_series.values.tolist(), index=self.data_df.index)
80
+ self.df = pd.DataFrame(
81
+ emb_series.values.tolist(), index=self.data_df.index)
73
82
  # print(f"df shape: {self.df.shape}")
74
83
  self.df.to_pickle(storage)
75
84
  try:
@@ -59,10 +59,7 @@ class Wavlm(Featureset):
59
59
  frame_offset=int(start.total_seconds() * 16000),
60
60
  num_frames=int((end - start).total_seconds() * 16000),
61
61
  )
62
- if sampling_rate != 16000:
63
- self.util.error(
64
- f"sampling rate should be 16000 but is {sampling_rate}"
65
- )
62
+ assert sampling_rate == 16000, f"sampling rate should be 16000 but is {sampling_rate}"
66
63
  emb = self.get_embeddings(signal, sampling_rate, file)
67
64
  emb_series.iloc[idx] = emb
68
65
  self.df = pd.DataFrame(emb_series.values.tolist(), index=self.data_df.index)
@@ -0,0 +1,110 @@
1
+ # feats_whisper.py
2
+ import os
3
+
4
+ import pandas as pd
5
+ import torch
6
+ from transformers import AutoFeatureExtractor
7
+ from transformers import WhisperModel
8
+
9
+ import audeer
10
+ import audiofile
11
+
12
+ from nkululeko.feat_extract.featureset import Featureset
13
+ import nkululeko.glob_conf as glob_conf
14
+
15
+
16
+ class Whisper(Featureset):
17
+ """Class to extract whisper embeddings."""
18
+
19
+ def __init__(self, name, data_df, feat_type):
20
+ super().__init__(name, data_df, feat_type)
21
+ cuda = "cuda" if torch.cuda.is_available() else "cpu"
22
+ self.device = self.util.config_val("MODEL", "device", cuda)
23
+ self.model_initialized = False
24
+ if feat_type == "whisper":
25
+ self.feat_type = "whisper-base"
26
+ else:
27
+ self.feat_type = feat_type
28
+
29
+ def init_model(self):
30
+ # load model
31
+ self.util.debug("loading whisper model...")
32
+ model_name = f"openai/{self.feat_type}"
33
+ self.model = WhisperModel.from_pretrained(model_name).to(self.device)
34
+ print(f"intialized Whisper model on {self.device}")
35
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
36
+ self.model_initialized = True
37
+
38
+ def extract(self):
39
+ """Extract the features or load them from disk if present."""
40
+ store = self.util.get_path("store")
41
+ storage = f"{store}{self.name}.pkl"
42
+ extract = self.util.config_val("FEATS", "needs_feature_extraction", False)
43
+ no_reuse = eval(self.util.config_val("FEATS", "no_reuse", "False"))
44
+ if extract or no_reuse or not os.path.isfile(storage):
45
+ if not self.model_initialized:
46
+ self.init_model()
47
+ self.util.debug("extracting whisper embeddings, this might take a while...")
48
+ emb_series = []
49
+ for (file, start, end), _ in audeer.progress_bar(
50
+ self.data_df.iterrows(),
51
+ total=len(self.data_df),
52
+ desc=f"Running whisper on {len(self.data_df)} audiofiles",
53
+ ):
54
+ if end == pd.NaT:
55
+ signal, sr = audiofile.read(file, offset=start)
56
+ else:
57
+ signal, sr = audiofile.read(
58
+ file, duration=end - start, offset=start
59
+ )
60
+ emb = self.get_embeddings(signal, sr, file)
61
+ emb_series.append(emb)
62
+ # print(f"emb_series shape: {emb_series.shape}")
63
+ self.df = pd.DataFrame(emb_series, index=self.data_df.index)
64
+ # print(f"df shape: {self.df.shape}")
65
+ self.df.to_pickle(storage)
66
+ try:
67
+ glob_conf.config["DATA"]["needs_feature_extraction"] = "false"
68
+ except KeyError:
69
+ pass
70
+ else:
71
+ self.util.debug("reusing extracted wav2vec2 embeddings")
72
+ self.df = pd.read_pickle(storage)
73
+ if self.df.isnull().values.any():
74
+ nanrows = self.df.columns[self.df.isna().any()].tolist()
75
+ # print(nanrows)
76
+ self.util.error(
77
+ f"got nan: {self.df.shape} {self.df.isnull().sum().sum()}"
78
+ )
79
+
80
+ def get_embeddings(self, signal, sampling_rate, file):
81
+ r"""Extract embeddings from raw audio signal."""
82
+ try:
83
+ with torch.no_grad():
84
+ embed_size = self.model.config.hidden_size
85
+ embed_columns = [f"whisper_{i}" for i in range(embed_size)]
86
+ inputs = self.feature_extractor(signal, sampling_rate=16000)[
87
+ "input_features"
88
+ ][0]
89
+ inputs = torch.from_numpy(inputs).to(self.device).unsqueeze(0)
90
+ decoder_input_ids = (
91
+ torch.tensor([[1, 1]]).to(self.device)
92
+ * self.model.config.decoder_start_token_id
93
+ )
94
+ full_outputs = self.model(
95
+ inputs,
96
+ decoder_input_ids=decoder_input_ids,
97
+ output_hidden_states=True,
98
+ )
99
+ outputs = full_outputs.encoder_last_hidden_state[0]
100
+ average_embeds = outputs.squeeze().mean(axis=0).cpu().detach().numpy()
101
+ except RuntimeError as re:
102
+ print(str(re))
103
+ self.util.error(f"couldn't extract file: {file}")
104
+ # print(f"y flattened shape: {y.ravel().shape}")
105
+ return average_embeds
106
+
107
+ def extract_sample(self, signal, sr):
108
+ self.init_model()
109
+ feats = self.get_embeddings(signal, sr, "no file")
110
+ return feats
@@ -7,13 +7,15 @@ import ast
7
7
 
8
8
  class Featureset:
9
9
  name = "" # designation
10
- df = None # pandas dataframe to store the features (and indexed with the data from the sets)
10
+ df = None # pandas dataframe to store the features
11
+ # (and indexed with the data from the sets)
11
12
  data_df = None # dataframe to get audio paths
12
13
 
13
- def __init__(self, name, data_df):
14
+ def __init__(self, name, data_df, feats_type):
14
15
  self.name = name
15
16
  self.data_df = data_df
16
17
  self.util = Util("featureset")
18
+ self.feats_types = feats_type
17
19
 
18
20
  def extract(self):
19
21
  pass
@@ -23,7 +25,8 @@ class Featureset:
23
25
  self.df = self.df[self.df.index.isin(self.data_df.index)]
24
26
  try:
25
27
  # use only some features
26
- selected_features = ast.literal_eval(glob_conf.config["FEATS"]["features"])
28
+ selected_features = ast.literal_eval(
29
+ glob_conf.config["FEATS"]["features"])
27
30
  self.util.debug(f"selecting features: {selected_features}")
28
31
  sel_feats_df = pd.DataFrame()
29
32
  hit = False
@@ -1,22 +1,25 @@
1
- """
2
- feature_extractor.py
3
-
4
- Helper class to encapsulate feature extraction methods
1
+ """Extract acoustic features from audio samples.
5
2
 
3
+ Extract acoustic features using several feature extractors
4
+ (appends the features column-wise)
6
5
  """
6
+
7
7
  import pandas as pd
8
8
 
9
9
  from nkululeko.utils.util import Util
10
10
 
11
11
 
12
12
  class FeatureExtractor:
13
- """
14
- Extract acoustic features from audio samples, using several feature extractors (appends the features column-wise)
13
+ """Extract acoustic features from audio samples.
14
+
15
+ Extract acoustic features using several feature extractors (appends the features column-wise).
16
+
15
17
  Args:
16
18
  data_df (pandas.DataFrame): dataframe with audiofile paths as index
17
- feats_types (array of strings): designations of acoustic feature extractors to be used
18
- data_name (string): names of databases that are extracted (for the caching)
19
- feats_designation (string): the type of split (train/test), also is used for the cache name.
19
+ feats_types (List[str]): designations of acoustic feature extractors to be used
20
+ data_name (str): name of databases that are extracted (for caching)
21
+ feats_designation (str): the type of split (train/test), also is used for the cache name.
22
+
20
23
  Returns:
21
24
  df (pandas.DataFrame): dataframe with same index as data_df and acoustic features in columns
22
25
  """
@@ -25,7 +28,6 @@ class FeatureExtractor:
25
28
  df = None
26
29
  data_df = None # dataframe to get audio paths
27
30
 
28
- # def __init__
29
31
  def __init__(self, data_df, feats_types, data_name, feats_designation):
30
32
  self.data_df = data_df
31
33
  self.data_name = data_name
@@ -34,147 +36,80 @@ class FeatureExtractor:
34
36
  self.feats_designation = feats_designation
35
37
 
36
38
  def extract(self):
37
- # feats_types = self.util.config_val_list('FEATS', 'type', ['os'])
38
- self.featExtractor = None
39
39
  self.feats = pd.DataFrame()
40
- _scale = True
41
40
  for feats_type in self.feats_types:
42
41
  store_name = f"{self.data_name}_{feats_type}"
43
- if feats_type == "os":
44
- from nkululeko.feat_extract.feats_opensmile import Opensmileset
45
-
46
- self.featExtractor = Opensmileset(
47
- f"{store_name}_{self.feats_designation}", self.data_df
48
- )
49
- elif feats_type == "spectra":
50
- from nkululeko.feat_extract.feats_spectra import Spectraloader
51
-
52
- self.featExtractor = Spectraloader(
53
- f"{store_name}_{self.feats_designation}", self.data_df
54
- )
55
- elif feats_type == "trill":
56
- from nkululeko.feat_extract.feats_trill import TRILLset
57
-
58
- self.featExtractor = TRILLset(
59
- f"{store_name}_{self.feats_designation}", self.data_df
60
- )
61
- elif feats_type.startswith("wav2vec"):
62
- from nkululeko.feat_extract.feats_wav2vec2 import Wav2vec2
63
-
64
- self.featExtractor = Wav2vec2(
65
- f"{store_name}_{self.feats_designation}",
66
- self.data_df,
67
- feats_type,
68
- )
69
- elif feats_type.startswith("hubert"):
70
- from nkululeko.feat_extract.feats_hubert import Hubert
71
-
72
- self.featExtractor = Hubert(
73
- f"{store_name}_{self.feats_designation}",
74
- self.data_df,
75
- feats_type,
76
- )
77
-
78
- elif feats_type.startswith("wavlm"):
79
- from nkululeko.feat_extract.feats_wavlm import Wavlm
80
-
81
- self.featExtractor = Wavlm(
82
- f"{store_name}_{self.feats_designation}",
83
- self.data_df,
84
- feats_type,
85
- )
86
-
87
- elif feats_type.startswith("spkrec"):
88
- from nkululeko.feat_extract.feats_spkrec import Spkrec
89
-
90
- self.featExtractor = Spkrec(
91
- f"{store_name}_{self.feats_designation}",
92
- self.data_df,
93
- feats_type,
94
- )
95
- elif feats_type == "audmodel":
96
- from nkululeko.feat_extract.feats_audmodel import AudModelSet
97
-
98
- self.featExtractor = AudModelSet(
99
- f"{store_name}_{self.feats_designation}", self.data_df
100
- )
101
- elif feats_type == "auddim":
102
- from nkululeko.feat_extract.feats_audmodel_dim import (
103
- AudModelDimSet,
104
- )
105
-
106
- self.featExtractor = AudModelDimSet(
107
- f"{store_name}_{self.feats_designation}", self.data_df
108
- )
109
- elif feats_type == "agender":
110
- from nkululeko.feat_extract.feats_agender import (
111
- AudModelAgenderSet,
112
- )
113
-
114
- self.featExtractor = AudModelAgenderSet(
115
- f"{store_name}_{self.feats_designation}", self.data_df
116
- )
117
- elif feats_type == "agender_agender":
118
- from nkululeko.feat_extract.feats_agender_agender import (
119
- AgenderAgenderSet,
120
- )
121
-
122
- self.featExtractor = AgenderAgenderSet(
123
- f"{store_name}_{self.feats_designation}", self.data_df
124
- )
125
- elif feats_type == "snr":
126
- from nkululeko.feat_extract.feats_snr import SNRSet
127
-
128
- self.featExtractor = SNRSet(
129
- f"{store_name}_{self.feats_designation}", self.data_df
130
- )
131
- elif feats_type == "mos":
132
- from nkululeko.feat_extract.feats_mos import MOSSet
133
-
134
- self.featExtractor = MOSSet(
135
- f"{store_name}_{self.feats_designation}", self.data_df
136
- )
137
- elif feats_type == "squim":
138
- from nkululeko.feat_extract.feats_squim import SQUIMSet
139
-
140
- self.featExtractor = SQUIMSet(
141
- f"{store_name}_{self.feats_designation}", self.data_df
142
- )
143
- elif feats_type == "clap":
144
- from nkululeko.feat_extract.feats_clap import Clap
145
-
146
- self.featExtractor = Clap(
147
- f"{store_name}_{self.feats_designation}", self.data_df
148
- )
149
- elif feats_type == "praat":
150
- from nkululeko.feat_extract.feats_praat import Praatset
151
-
152
- self.featExtractor = Praatset(
153
- f"{store_name}_{self.feats_designation}", self.data_df
154
- )
155
- elif feats_type == "mld":
156
- from nkululeko.feat_extract.feats_mld import MLD_set
157
-
158
- self.featExtractor = MLD_set(
159
- f"{store_name}_{self.feats_designation}", self.data_df
160
- )
161
- elif feats_type == "import":
162
- from nkululeko.feat_extract.feats_import import Importset
163
-
164
- self.featExtractor = Importset(
165
- f"{store_name}_{self.feats_designation}", self.data_df
166
- )
167
- else:
168
- self.util.error(f"unknown feats_type: {feats_type}")
169
-
170
- self.featExtractor.extract()
171
- self.featExtractor.filter()
172
- # remove samples that were not extracted by MLD
173
- # self.df_test = self.df_test.loc[self.df_test.index.intersection(featExtractor_test.df.index)]
174
- # self.df_train = self.df_train.loc[self.df_train.index.intersection(featExtractor_train.df.index)]
175
- self.util.debug(f"{feats_type}: shape : {self.featExtractor.df.shape}")
176
- self.feats = pd.concat([self.feats, self.featExtractor.df], axis=1)
42
+ self.feat_extractor = self._get_feat_extractor(store_name, feats_type)
43
+ self.feat_extractor.extract()
44
+ self.feat_extractor.filter()
45
+ self.feats = pd.concat([self.feats, self.feat_extractor.df], axis=1)
177
46
  return self.feats
178
47
 
179
48
  def extract_sample(self, signal, sr):
180
- return self.featExtractor.extract_sample(signal, sr)
49
+ return self.feat_extractor.extract_sample(signal, sr)
50
+
51
+ def _get_feat_extractor(self, store_name, feats_type):
52
+ feat_extractor_class = self._get_feat_extractor_class(feats_type)
53
+ if feat_extractor_class is None:
54
+ self.util.error(f"unknown feats_type: {feats_type}")
55
+ return feat_extractor_class(
56
+ f"{store_name}_{self.feats_designation}", self.data_df, feats_type
57
+ )
58
+
59
+ def _get_feat_extractor_class(self, feats_type):
60
+ if feats_type == "os":
61
+ from nkululeko.feat_extract.feats_opensmile import Opensmileset
62
+
63
+ return Opensmileset
64
+
65
+ elif feats_type == "spectra":
66
+ from nkululeko.feat_extract.feats_spectra import Spectraloader
67
+
68
+ return Spectraloader
69
+
70
+ elif feats_type == "trill":
71
+ from nkululeko.feat_extract.feats_trill import TRILLset
72
+
73
+ return TRILLset
74
+
75
+ elif feats_type.startswith(
76
+ ("wav2vec2", "hubert", "wavlm", "spkrec", "whisper")
77
+ ):
78
+ return self._get_feat_extractor_by_prefix(feats_type)
79
+
80
+ elif feats_type == "xbow":
81
+ from nkululeko.feat_extract.feats_oxbow import Openxbow
82
+
83
+ return Openxbow
84
+
85
+ elif feats_type in (
86
+ "audmodel",
87
+ "auddim",
88
+ "agender",
89
+ "agender_agender",
90
+ "snr",
91
+ "mos",
92
+ "squim",
93
+ "clap",
94
+ "praat",
95
+ "mld",
96
+ "import",
97
+ ):
98
+ return self._get_feat_extractor_by_name(feats_type)
99
+ else:
100
+ return None
101
+
102
+ def _get_feat_extractor_by_prefix(self, feats_type):
103
+ prefix, _, ext = feats_type.partition("-")
104
+ from importlib import import_module
105
+
106
+ module = import_module(f"nkululeko.feat_extract.feats_{prefix.lower()}")
107
+ class_name = f"{prefix.capitalize()}"
108
+ return getattr(module, class_name)
109
+
110
+ def _get_feat_extractor_by_name(self, feats_type):
111
+ from importlib import import_module
112
+
113
+ module = import_module(f"nkululeko.feat_extract.feats_{feats_type.lower()}")
114
+ class_name = f"{feats_type.capitalize()}Set"
115
+ return getattr(module, class_name)