nkululeko 0.94.3__py3-none-any.whl → 0.95.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. nkululeko/augmenting/resampler.py +5 -2
  2. nkululeko/autopredict/ap_emotion.py +36 -0
  3. nkululeko/autopredict/ap_text.py +45 -0
  4. nkululeko/autopredict/tests/__init__.py +0 -0
  5. nkululeko/autopredict/tests/test_whisper_transcriber.py +122 -0
  6. nkululeko/autopredict/whisper_transcriber.py +81 -0
  7. nkululeko/balance.py +222 -0
  8. nkululeko/constants.py +1 -1
  9. nkululeko/experiment.py +53 -3
  10. nkululeko/explore.py +32 -13
  11. nkululeko/feat_extract/feats_analyser.py +45 -17
  12. nkululeko/feat_extract/feats_emotion2vec.py +51 -26
  13. nkululeko/feat_extract/feats_praat.py +3 -3
  14. nkululeko/feat_extract/feats_praat_core.py +769 -0
  15. nkululeko/feat_extract/tests/__init__.py +1 -0
  16. nkululeko/feat_extract/tests/test_feats_opensmile.py +162 -0
  17. nkululeko/feat_extract/tests/test_feats_praat_core.py +507 -0
  18. nkululeko/glob_conf.py +9 -0
  19. nkululeko/modelrunner.py +15 -39
  20. nkululeko/models/model.py +4 -42
  21. nkululeko/models/model_tuned.py +416 -84
  22. nkululeko/models/model_xgb.py +148 -2
  23. nkululeko/models/tests/test_model_knn.py +49 -0
  24. nkululeko/models/tests/test_model_mlp.py +153 -0
  25. nkululeko/models/tests/test_model_xgb.py +33 -0
  26. nkululeko/nkululeko.py +0 -9
  27. nkululeko/plots.py +25 -19
  28. nkululeko/predict.py +8 -6
  29. nkululeko/reporting/report.py +7 -5
  30. nkululeko/reporting/reporter.py +20 -5
  31. nkululeko/test_predictor.py +7 -1
  32. nkululeko/tests/__init__.py +1 -0
  33. nkululeko/tests/test_balancing.py +270 -0
  34. nkululeko/utils/util.py +38 -6
  35. {nkululeko-0.94.3.dist-info → nkululeko-0.95.1.dist-info}/METADATA +1 -1
  36. {nkululeko-0.94.3.dist-info → nkululeko-0.95.1.dist-info}/RECORD +40 -27
  37. nkululeko/feat_extract/feats_opensmile copy.py +0 -93
  38. nkululeko/feat_extract/feinberg_praat.py +0 -628
  39. {nkululeko-0.94.3.dist-info → nkululeko-0.95.1.dist-info}/WHEEL +0 -0
  40. {nkululeko-0.94.3.dist-info → nkululeko-0.95.1.dist-info}/entry_points.txt +0 -0
  41. {nkululeko-0.94.3.dist-info → nkululeko-0.95.1.dist-info}/licenses/LICENSE +0 -0
  42. {nkululeko-0.94.3.dist-info → nkululeko-0.95.1.dist-info}/top_level.txt +0 -0
@@ -68,7 +68,9 @@ class Resampler:
68
68
  self.df.index.set_levels(new_files, level="file")
69
69
  )
70
70
  if not self.not_testing:
71
- target_file = self.util.config_val("RESAMPLE", "target", "resampled.csv")
71
+ target_file = self.util.config_val(
72
+ "RESAMPLE", "target", "resampled.csv"
73
+ )
72
74
  # remove encoded labels
73
75
  target = self.util.config_val("DATA", "target", "emotion")
74
76
  if "class_label" in self.df.columns:
@@ -77,7 +79,8 @@ class Resampler:
77
79
  # save file
78
80
  self.df.to_csv(target_file)
79
81
  self.util.debug(
80
- "saved resampled list of files to" f" {os.path.abspath(target_file)}"
82
+ "saved resampled list of files to"
83
+ f" {os.path.abspath(target_file)}"
81
84
  )
82
85
  else:
83
86
  # When running from command line, save to simple resampled.csv
@@ -0,0 +1,36 @@
1
+ """
2
+ A predictor for emotion classification.
3
+ Uses emotion2vec models for emotion prediction.
4
+ """
5
+
6
+ import ast
7
+
8
+ import nkululeko.glob_conf as glob_conf
9
+ from nkululeko.feature_extractor import FeatureExtractor
10
+ from nkululeko.utils.util import Util
11
+
12
+
13
class EmotionPredictor:
    """Auto-predictor that adds an ``emotion_pred`` column to a dataframe.

    The samples are run through the ``emotion2vec-large`` feature extractor.
    """

    def __init__(self, df):
        """Remember the dataframe to annotate and set up a logging helper.

        :param df: dataframe whose samples should receive a prediction.
        """
        self.util = Util("emotionPredictor")
        self.df = df

    def predict(self, split_selection):
        """Return a copy of the dataframe with an ``emotion_pred`` column.

        :param split_selection: name of the sample selection; it is logged
            and handed to the feature extractor.
        :return: copy of ``self.df`` with the added ``emotion_pred`` column.
        """
        self.util.debug(f"predicting emotion for {split_selection} samples")

        # The feature set is named after all configured databases.
        databases = ast.literal_eval(glob_conf.config["DATA"]["databases"])
        feats_name = "_".join(databases)

        self.feature_extractor = FeatureExtractor(
            self.df,
            ["emotion2vec-large"],
            feats_name,
            split_selection,
        )
        emotion_df = self.feature_extractor.extract()

        # NOTE(review): every sample is labelled "neutral"; the extracted
        # features only determine how many labels are produced. This looks
        # like a placeholder -- confirm the intended label mapping.
        n_samples = len(emotion_df)
        return_df = self.df.copy()
        return_df["emotion_pred"] = ["neutral" for _ in range(n_samples)]
        return return_df
@@ -0,0 +1,45 @@
1
+ """A predictor for text.
2
+
3
+ Currently based on whisper model.
4
+ """
5
+
6
+ import ast
7
+
8
+ import torch
9
+
10
+ from nkululeko.feature_extractor import FeatureExtractor
11
+ import nkululeko.glob_conf as glob_conf
12
+ from nkululeko.utils.util import Util
13
+
14
+
15
class TextPredictor:
    """Auto-predictor that adds a ``text`` column via a whisper transcriber."""

    def __init__(self, df, util=None):
        """Set up the transcriber used for prediction.

        :param df: dataframe whose index identifies the audio to transcribe.
        :param util: optional utility object giving access to the
            configuration; a fresh ``Util("textPredictor")`` is created
            when none is passed.
        """
        self.df = df
        self.util = util if util is not None else Util("textPredictor")

        # Imported here rather than at module level, as in the original.
        from nkululeko.autopredict.whisper_transcriber import Transcriber

        # Fall back to GPU-if-available unless [MODEL] device is configured.
        fallback_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.transcriber = Transcriber(
            device=self.util.config_val("MODEL", "device", fallback_device),
            language=self.util.config_val("EXP", "language", "en"),
            util=self.util,
        )

    def predict(self, split_selection):
        """Return a copy of the dataframe with the transcriptions added.

        :param split_selection: name of the sample selection (only logged).
        :return: copy of ``self.df`` with a ``text`` column.
        """
        self.util.debug(f"predicting text for {split_selection} samples")
        transcripts = self.transcriber.transcribe_index(self.df.index)
        return_df = self.df.copy()
        # Assign by position; the transcriber returns one row per index entry.
        return_df["text"] = transcripts["text"].values
        return return_df
File without changes
@@ -0,0 +1,122 @@
1
+ import os
2
+ import tempfile
3
+ from datetime import timedelta
4
+ from unittest.mock import MagicMock, Mock, patch
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import pytest
9
+
10
+ from nkululeko.autopredict.whisper_transcriber import Transcriber
11
+
12
+
13
class TestTranscriber:
    """Unit tests for ``Transcriber``.

    Every test patches ``whisper.load_model`` so that constructing a
    ``Transcriber`` never loads (or downloads) a real Whisper model.
    """

    # Dotted path of the module under test, shared by all patch targets.
    _MOD = "nkululeko.autopredict.whisper_transcriber"

    @patch(f"{_MOD}.whisper.load_model")
    @patch(f"{_MOD}.torch.cuda.is_available")
    def test_init_default_device(self, mock_cuda, mock_load_model):
        """Without an explicit device, CUDA is chosen when available."""
        mock_cuda.return_value = True
        mock_model = Mock()
        mock_load_model.return_value = mock_model

        transcriber = Transcriber()

        mock_load_model.assert_called_once_with("turbo", device="cuda")
        assert transcriber.language == "en"
        assert transcriber.model == mock_model

    @patch(f"{_MOD}.whisper.load_model")
    def test_init_custom_params(self, mock_load_model):
        """Explicit constructor arguments are passed through unchanged."""
        mock_model = Mock()
        mock_load_model.return_value = mock_model
        mock_util = Mock()

        transcriber = Transcriber(
            model_name="base", device="cpu", language="es", util=mock_util
        )

        mock_load_model.assert_called_once_with("base", device="cpu")
        assert transcriber.language == "es"
        assert transcriber.util == mock_util

    @patch(f"{_MOD}.whisper.load_model")
    def test_transcribe_file(self, mock_load_model):
        """The model output text is returned with whitespace stripped."""
        mock_model = Mock()
        mock_model.transcribe.return_value = {"text": " Hello world "}
        mock_load_model.return_value = mock_model

        transcriber = Transcriber()

        result = transcriber.transcribe_file("test.wav")

        mock_model.transcribe.assert_called_once_with(
            "test.wav", language="en", without_timestamps=True
        )
        assert result == "Hello world"

    # Decorators apply bottom-up: the lowest patch is the first mock argument.
    @patch(f"{_MOD}.whisper.load_model")
    @patch(f"{_MOD}.audiofile.write")
    def test_transcribe_array(self, mock_write, mock_load_model):
        """The signal is written to a wav file which is then transcribed."""
        mock_load_model.return_value = Mock()
        transcriber = Transcriber()
        transcriber.transcribe_file = Mock(return_value="transcribed text")

        signal = np.array([0.1, 0.2, 0.3])
        sampling_rate = 16000

        result = transcriber.transcribe_array(signal, sampling_rate)

        mock_write.assert_called_once_with(
            "temp.wav", signal, sampling_rate, format="wav"
        )
        transcriber.transcribe_file.assert_called_once_with("temp.wav")
        assert result == "transcribed text"

    @patch(f"{_MOD}.whisper.load_model")
    @patch(f"{_MOD}.audiofile.read")
    @patch(f"{_MOD}.audeer.mkdir")
    @patch(f"{_MOD}.audeer.path")
    @patch(f"{_MOD}.audeer.basename_wo_ext")
    @patch(f"{_MOD}.os.path.isfile")
    def test_transcribe_index_with_cache(
        self,
        mock_isfile,
        mock_basename,
        mock_path,
        mock_mkdir,
        mock_read,
        mock_load_model,
    ):
        """A cache hit returns the stored text without reading any audio."""
        mock_load_model.return_value = Mock()
        mock_util = Mock()
        mock_util.get_path.return_value = "/cache"
        mock_util.read_json.return_value = {"transcription": "cached text"}

        mock_mkdir.return_value = "/cache/transcriptions"
        mock_path.side_effect = lambda *args: "/".join(args)
        mock_basename.return_value = "file1"
        mock_isfile.return_value = True

        transcriber = Transcriber(util=mock_util)

        index = pd.Index(
            [("file1.wav", timedelta(seconds=0), timedelta(seconds=1))]
        )

        result = transcriber.transcribe_index(index)

        # On a cache hit neither the audio nor the cache file is touched.
        mock_read.assert_not_called()
        mock_util.save_json.assert_not_called()
        assert isinstance(result, pd.DataFrame)
        assert len(result) == 1
        assert result.iloc[0]["text"] == "cached text"

    @patch(f"{_MOD}.whisper.load_model")
    @patch(f"{_MOD}.audiofile.read")
    @patch(f"{_MOD}.audeer.mkdir")
    @patch(f"{_MOD}.audeer.path")
    @patch(f"{_MOD}.audeer.basename_wo_ext")
    @patch(f"{_MOD}.os.path.isfile")
    def test_transcribe_index_without_cache(
        self,
        mock_isfile,
        mock_basename,
        mock_path,
        mock_mkdir,
        mock_audioread,
        mock_load_model,
    ):
        """A cache miss transcribes the segment and stores the result."""
        mock_util = Mock()
        mock_util.get_path.return_value = "/cache"

        mock_mkdir.return_value = "/cache/transcriptions"
        mock_path.side_effect = lambda *args: "/".join(args)
        mock_basename.return_value = "file1"
        mock_isfile.return_value = False
        mock_audioread.return_value = (np.array([0.1, 0.2]), 16000)
        mock_load_model.return_value = Mock()

        transcriber = Transcriber(util=mock_util)
        transcriber.transcribe_array = Mock(return_value="new transcription")

        index = pd.Index(
            [("file1.wav", timedelta(seconds=0), timedelta(seconds=1))]
        )

        result = transcriber.transcribe_index(index)

        mock_util.save_json.assert_called_once()
        assert isinstance(result, pd.DataFrame)
        assert len(result) == 1
        assert result.iloc[0]["text"] == "new transcription"
@@ -0,0 +1,81 @@
1
+ import os
2
+
3
+ import pandas as pd
4
+ import torch
5
+ from tqdm import tqdm
6
+ import whisper
7
+
8
+ import audeer
9
+ import audiofile
10
+
11
+ from nkululeko.utils.util import Util
12
+
13
+
14
class Transcriber:
    """Speech-to-text helper around a Whisper model with a JSON result cache."""

    def __init__(self, model_name="turbo", device=None, language="en", util=None):
        """Load the Whisper model.

        :param model_name: name of the Whisper model to load.
        :param device: device to load the model on; when ``None``, ``"cuda"``
            is used if available, otherwise ``"cpu"``.
        :param language: language code passed to the model when transcribing.
        :param util: utility object providing ``get_path``, ``read_json`` and
            ``save_json``; required by :meth:`transcribe_index`.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = whisper.load_model(model_name, device=device)
        self.language = language
        self.util = util

    def transcribe_file(self, audio_path):
        """Transcribe the audio file at the given path.

        :param audio_path: Path to the audio file to transcribe.
        :return: Transcription text with surrounding whitespace removed.
        """
        result = self.model.transcribe(
            audio_path, language=self.language, without_timestamps=True
        )
        return result["text"].strip()

    def transcribe_array(self, signal, sampling_rate):
        """Transcribe an in-memory audio signal.

        The signal is written to a wav file which is then transcribed.

        :param signal: audio samples to transcribe.
        :param sampling_rate: sampling rate of ``signal`` in Hz.
        :return: Transcription text.
        """
        # NOTE(review): fixed file name in the working directory; it is not
        # removed afterwards and is not safe for concurrent runs. Kept as is
        # because existing tests pin this exact path.
        temporary_path = "temp.wav"
        audiofile.write(temporary_path, signal, sampling_rate, format="wav")
        return self.transcribe_file(temporary_path)

    @staticmethod
    def _cache_name(file, start, end):
        """Build a cache file stem that is unique per (file, start, end).

        The previous scheme (file basename + running segment counter) gave
        the same name to equally named files in different directories, was
        ambiguous ("a1" + "0" vs. "a" + "10"), and depended on the order of
        the index. Hashing the full segment identity avoids all three.

        :param file: path of the audio file.
        :param start: segment start (object with ``total_seconds()``).
        :param end: segment end (object with ``total_seconds()``).
        :return: ``"<basename>_<digest>"`` without extension.
        """
        import hashlib

        stem = os.path.splitext(os.path.basename(file))[0]
        identity = f"{file}|{start.total_seconds()}|{end.total_seconds()}"
        digest = hashlib.sha256(identity.encode("utf-8")).hexdigest()[:16]
        return f"{stem}_{digest}"

    def transcribe_index(self, index: pd.Index) -> pd.DataFrame:
        """Transcribe the audio files in the given index.

        Results are cached as JSON files below the ``cache`` path reported by
        ``self.util``; segments with an existing cache file are not
        transcribed again.

        :param index: Index containing tuples of (file, start, end).
        :return: DataFrame with transcriptions indexed by the original index.
        :rtype: pd.DataFrame
        """
        transcriber_cache = audeer.mkdir(
            audeer.path(self.util.get_path("cache"), "transcriptions")
        )
        transcriptions = []
        for file, start, end in tqdm(index.to_list()):
            cache_path = audeer.path(
                transcriber_cache, self._cache_name(file, start, end) + ".json"
            )
            if os.path.isfile(cache_path):
                transcription = self.util.read_json(cache_path)["transcription"]
            else:
                dur = end.total_seconds() - start.total_seconds()
                y, sr = audiofile.read(file, offset=start, duration=dur)
                transcription = self.transcribe_array(y, sr)
                self.util.save_json(
                    cache_path,
                    {
                        "transcription": transcription,
                        "file": file,
                        "start": start.total_seconds(),
                        "end": end.total_seconds(),
                    },
                )
            transcriptions.append(transcription)

        return pd.DataFrame({"text": transcriptions}, index=index)
nkululeko/balance.py ADDED
@@ -0,0 +1,222 @@
1
+ # balance.py
2
+ """
3
+ Data and feature balancing module for imbalanced datasets.
4
+
5
+ This module provides a unified interface for various balancing techniques
6
+ including over-sampling, under-sampling, and combination methods.
7
+ """
8
+
9
+ import pandas as pd
10
+ import numpy as np
11
+ from nkululeko.utils.util import Util
12
+ import nkululeko.glob_conf as glob_conf
13
+
14
+
15
class DataBalancer:
    """Class to handle data and feature balancing operations."""

    # Single source of truth for sampler construction:
    # method name -> (imblearn module, sampler class, accepts random_state)
    _SAMPLERS = {
        # Over-sampling methods
        "ros": ("imblearn.over_sampling", "RandomOverSampler", True),
        "smote": ("imblearn.over_sampling", "SMOTE", True),
        "adasyn": ("imblearn.over_sampling", "ADASYN", True),
        "borderlinesmote": ("imblearn.over_sampling", "BorderlineSMOTE", True),
        "svmsmote": ("imblearn.over_sampling", "SVMSMOTE", True),
        # Under-sampling methods
        "clustercentroids": ("imblearn.under_sampling", "ClusterCentroids", True),
        "randomundersampler": (
            "imblearn.under_sampling",
            "RandomUnderSampler",
            True,
        ),
        "editednearestneighbours": (
            "imblearn.under_sampling",
            "EditedNearestNeighbours",
            False,
        ),
        "tomeklinks": ("imblearn.under_sampling", "TomekLinks", False),
        # Combination methods
        "smoteenn": ("imblearn.combine", "SMOTEENN", True),
        "smotetomek": ("imblearn.combine", "SMOTETomek", True),
    }

    def __init__(self, random_state=42):
        """
        Initialize the DataBalancer.

        Args:
            random_state (int): Random state for reproducible results
        """
        self.util = Util("data_balancer")
        self.random_state = random_state

        # Supported balancing algorithms
        self.oversampling_methods = [
            "ros",  # RandomOverSampler
            "smote",  # SMOTE
            "adasyn",  # ADASYN
            "borderlinesmote",  # BorderlineSMOTE
            "svmsmote",  # SVMSMOTE
        ]

        self.undersampling_methods = [
            "clustercentroids",  # ClusterCentroids
            "randomundersampler",  # RandomUnderSampler
            "editednearestneighbours",  # EditedNearestNeighbours
            "tomeklinks",  # TomekLinks
        ]

        self.combination_methods = [
            "smoteenn",  # SMOTEENN
            "smotetomek",  # SMOTETomek
        ]

    def _all_methods(self):
        """Return every supported method name as one list."""
        return (
            self.oversampling_methods
            + self.undersampling_methods
            + self.combination_methods
        )

    def get_supported_methods(self):
        """Get all supported balancing methods, grouped by category."""
        return {
            "oversampling": self.oversampling_methods,
            "undersampling": self.undersampling_methods,
            "combination": self.combination_methods,
        }

    def is_valid_method(self, method):
        """Check if a balancing method is supported (case-insensitive).

        Non-string input (e.g. ``None``) is reported as unsupported instead
        of raising.
        """
        return isinstance(method, str) and method.lower() in self._all_methods()

    def balance_features(self, df_train, feats_train, target_column, method):
        """
        Balance features using the specified method.

        Args:
            df_train (pd.DataFrame): Training dataframe with target labels
            feats_train (np.ndarray or pd.DataFrame): Training features
            target_column (str): Name of the target column
            method (str): Balancing method to use

        Returns:
            tuple: (balanced_df, balanced_features); the unchanged inputs are
            returned when the method is unknown or balancing fails.
        """
        if not self.is_valid_method(method):
            self.util.error(
                f"Unknown balancing algorithm: {method}. "
                f"Available methods: {self._all_methods()}"
            )
            return df_train, feats_train

        orig_size = len(df_train)
        self.util.debug(f"Balancing features with: {method}")
        self.util.debug(f"Original dataset size: {orig_size}")

        # Get original class distribution
        orig_dist = df_train[target_column].value_counts().to_dict()
        self.util.debug(f"Original class distribution: {orig_dist}")

        try:
            # Apply the specified balancing method
            X_res, y_res = self._apply_balancing_method(
                feats_train, df_train[target_column], method
            )

            # Create new balanced dataframe
            balanced_df = pd.DataFrame({target_column: y_res})

            # Keep labels aligned with the features when they carry an index.
            # (A plain hasattr(..., "index") check would also match lists.)
            if isinstance(X_res, (pd.DataFrame, pd.Series)):
                balanced_df.index = X_res.index

            new_size = len(balanced_df)
            new_dist = balanced_df[target_column].value_counts().to_dict()

            self.util.debug(f"Balanced dataset size: {new_size} (was {orig_size})")
            self.util.debug(f"New class distribution: {new_dist}")

            # Log class distribution with label names if encoder is available
            self._log_class_distribution(y_res, method)

            return balanced_df, X_res

        except Exception as e:
            # Deliberate best-effort: report the failure and hand back the
            # original (unbalanced) data instead of aborting.
            self.util.debug(f"Error applying {method} balancing: {e}")
            return df_train, feats_train

    def _apply_balancing_method(self, features, targets, method):
        """Apply the specific balancing method.

        The imblearn sampler is imported on demand so that imblearn is only
        needed when balancing is actually requested.

        Raises:
            ValueError: if ``method`` is not a known sampler name.
        """
        import importlib

        method = method.lower()
        try:
            module_name, class_name, seeded = self._SAMPLERS[method]
        except KeyError:
            raise ValueError(f"Unsupported balancing method: {method}") from None

        sampler_cls = getattr(importlib.import_module(module_name), class_name)
        if seeded:
            sampler = sampler_cls(random_state=self.random_state)
        else:
            sampler = sampler_cls()

        # Apply the balancing
        X_res, y_res = sampler.fit_resample(features, targets)
        return X_res, y_res

    def _log_class_distribution(self, y_res, method):
        """Log class distribution with label names if possible."""
        counts = pd.Series(y_res).value_counts()
        le = getattr(glob_conf, "label_encoder", None)

        # Without a label encoder only the raw (encoded) labels can be shown.
        if le is None:
            self.util.debug(
                f"Label encoder not available. "
                f"Class distribution after {method} balancing: {counts.to_dict()}"
            )
            return

        try:
            resd = {
                le.inverse_transform([label_idx])[0]: count
                for label_idx, count in zip(counts.index.values, counts.values)
            }
            self.util.debug(f"Class distribution after {method} balancing: {resd}")
        except Exception as e:
            self.util.debug(
                f"Could not decode class labels: {e}. "
                f"Showing numeric distribution: {counts.to_dict()}"
            )
205
+
206
+
207
class LegacyDataBalancer:
    """Backward-compatibility shim for the former data balancing API."""

    def __init__(self):
        """Create the logging helper used by this balancer."""
        self.util = Util("legacy_data_balancer")

    def balance_data(self, df_train, df_test):
        """Return both dataframes unchanged.

        Kept only so that older call sites keep working; new code should use
        the ``DataBalancer`` class. No balancing is implemented here yet.

        :param df_train: training dataframe.
        :param df_test: test dataframe.
        :return: tuple ``(df_train, df_test)`` with the very same objects.
        """
        self.util.debug("Using legacy data balancing method")
        unchanged = (df_train, df_test)
        return unchanged
nkululeko/constants.py CHANGED
@@ -1,2 +1,2 @@
1
- VERSION="0.94.3"
1
+ VERSION="0.95.1"
2
2
  SAMPLING_RATE = 16000
nkululeko/experiment.py CHANGED
@@ -513,7 +513,7 @@ class Experiment:
513
513
 
514
514
  def autopredict(self):
515
515
  """Predict labels for samples with existing models and add to the dataframe."""
516
- sample_selection = self.util.config_val("PREDICT", "split", "all")
516
+ sample_selection = self.util.config_val("PREDICT", "sample_selection", "all")
517
517
  if sample_selection == "all":
518
518
  df = pd.concat([self.df_train, self.df_test])
519
519
  elif sample_selection == "train":
@@ -569,6 +569,11 @@ class Experiment:
569
569
 
570
570
  predictor = STOIPredictor(df)
571
571
  df = predictor.predict(sample_selection)
572
+ elif target == "text":
573
+ from nkululeko.autopredict.ap_text import TextPredictor
574
+
575
+ predictor = TextPredictor(df, self.util)
576
+ df = predictor.predict(sample_selection)
572
577
  elif target == "arousal":
573
578
  from nkululeko.autopredict.ap_arousal import ArousalPredictor
574
579
 
@@ -584,6 +589,11 @@ class Experiment:
584
589
 
585
590
  predictor = DominancePredictor(df)
586
591
  df = predictor.predict(sample_selection)
592
+ elif target == "emotion":
593
+ from nkululeko.autopredict.ap_emotion import EmotionPredictor
594
+
595
+ predictor = EmotionPredictor(df)
596
+ df = predictor.predict(sample_selection)
587
597
  else:
588
598
  self.util.error(f"unknown auto predict target: {target}")
589
599
  return df
@@ -668,11 +678,27 @@ class Experiment:
668
678
 
669
679
  # check if a scatterplot should be done
670
680
  scatter_var = eval(self.util.config_val("EXPL", "scatter", "False"))
681
+
682
+ # Priority: use [EXPL][scatter.target] if available, otherwise use [DATA][target] value
683
+ if hasattr(self, "target") and self.target != "none":
684
+ default_scatter_target = f"['{self.target}']"
685
+ else:
686
+ default_scatter_target = "['class_label']"
687
+
671
688
  scatter_target = self.util.config_val(
672
- "EXPL", "scatter.target", "['class_label']"
689
+ "EXPL", "scatter.target", default_scatter_target
673
690
  )
691
+
692
+ if scatter_target == default_scatter_target:
693
+ self.util.debug(
694
+ f"scatter.target using default from [DATA][target]: {scatter_target}"
695
+ )
696
+ else:
697
+ self.util.debug(
698
+ f"scatter.target from [EXPL][scatter.target]: {scatter_target}"
699
+ )
674
700
  if scatter_var:
675
- scatters = ast.literal_eval(glob_conf.config["EXPL"]["scatter"])
701
+ scatters = ast.literal_eval(scatter_target)
676
702
  scat_targets = ast.literal_eval(scatter_target)
677
703
  plots = Plots()
678
704
  for scat_target in scat_targets:
@@ -692,6 +718,30 @@ class Experiment:
692
718
  df_feats, df_labels, f"{scat_target}_bins", scatter
693
719
  )
694
720
 
721
+ # check if t-SNE plot should be generated
722
+ tsne = eval(self.util.config_val("EXPL", "tsne", "False"))
723
+ if tsne:
724
+ target_column = self.util.config_val("DATA", "target", "emotion")
725
+ plots = Plots()
726
+ self.util.debug("generating t-SNE plot...")
727
+ plots.scatter_plot(df_feats, df_labels, target_column, "tsne")
728
+
729
+ # check if UMAP plot should be generated
730
+ umap_plot = eval(self.util.config_val("EXPL", "umap", "False"))
731
+ if umap_plot:
732
+ target_column = self.util.config_val("DATA", "target", "emotion")
733
+ plots = Plots()
734
+ self.util.debug("generating UMAP plot...")
735
+ plots.scatter_plot(df_feats, df_labels, target_column, "umap")
736
+
737
+ # check if PCA plot should be generated
738
+ pca_plot = eval(self.util.config_val("EXPL", "pca", "False"))
739
+ if pca_plot:
740
+ target_column = self.util.config_val("DATA", "target", "emotion")
741
+ plots = Plots()
742
+ self.util.debug("generating PCA plot...")
743
+ plots.scatter_plot(df_feats, df_labels, target_column, "pca")
744
+
695
745
  def _check_scale(self):
696
746
  self.util.save_to_store(self.feats_train, "feats_train")
697
747
  self.util.save_to_store(self.feats_test, "feats_test")