nkululeko 0.95.0__py3-none-any.whl → 0.95.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,507 @@
1
+ import os
2
+ import tempfile
3
+ from unittest.mock import MagicMock, Mock, patch
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import parselmouth
8
+ import pytest
9
+ from scipy.stats import lognorm
10
+
11
+ from nkululeko.feat_extract.feats_praat_core import (AudioFeatureExtractor,
12
+ add_derived_features,
13
+ compute_features,
14
+ get_speech_rate, run_pca,
15
+ speech_rate)
16
+
17
+
18
class TestAudioFeatureExtractor:
    """Unit tests for ``AudioFeatureExtractor``.

    The pitch and formant tests replace ``feats_praat_core.call`` with a mock,
    so no real Praat computation runs; the pause-distribution tests call the
    extractor directly on plain Python lists.
    """

    @pytest.fixture
    def extractor(self):
        """Return an extractor built with f0min=75 and f0max=300."""
        return AudioFeatureExtractor(f0min=75, f0max=300)

    @pytest.fixture
    def mock_sound(self):
        """Return a mock sound whose ``get_total_duration()`` yields 2.5."""
        sound = Mock()
        sound.get_total_duration.return_value = 2.5
        return sound

    def test_init(self):
        """Constructor arguments are stored on the instance unchanged."""
        extractor = AudioFeatureExtractor(f0min=50, f0max=400)
        assert extractor.f0min == 50
        assert extractor.f0max == 400

    def test_init_default_values(self):
        """Without arguments the pitch range defaults to 75-300."""
        extractor = AudioFeatureExtractor()
        assert extractor.f0min == 75
        assert extractor.f0max == 300

    @patch('nkululeko.feat_extract.feats_praat_core.call')
    def test_extract_pitch_features(self, mock_call, extractor, mock_sound):
        """Pitch, HNR, jitter and shimmer values are passed through to the result."""
        mock_pitch = Mock()
        mock_point_process = Mock()

        # side_effect hands out one value per invocation of the patched
        # ``call``, in list order.
        mock_call.side_effect = [
            150.0,   # mean_f0
            25.0,    # stdev_f0
            Mock(),  # harmonicity object
            0.8,     # hnr
            0.01,    # local_jitter
            0.05,    # localabsolute_jitter
            0.02,    # rap_jitter
            0.03,    # ppq5_jitter
            0.04,    # ddp_jitter
            0.1,     # local_shimmer
            0.5,     # localdb_shimmer
            0.15,    # apq3_shimmer
            0.2,     # apq5_shimmer
            0.25,    # apq11_shimmer
            0.3,     # dda_shimmer
        ]

        result = extractor._extract_pitch_features(
            mock_sound, mock_pitch, mock_point_process
        )

        # Exact comparison is intentional: these are the mocked values
        # themselves, not the result of arithmetic.
        assert result['meanF0Hz'] == 150.0
        assert result['stdevF0Hz'] == 25.0
        assert result['HNR'] == 0.8
        assert result['localJitter'] == 0.01
        assert len(result) == 14

    @patch('nkululeko.feat_extract.feats_praat_core.call')
    def test_extract_formant_features(self, mock_call, extractor, mock_sound):
        """Mean and median keys for F1-F4 are all present in the result."""
        mock_point_process = Mock()

        # One formants object, the point count, then (time, f1..f4) per point.
        mock_call.side_effect = [
            Mock(),  # formants object
            3,       # num_points
            0.5,     # time from index 1
            800.0,   # f1 at time 0.5
            1200.0,  # f2 at time 0.5
            2800.0,  # f3 at time 0.5
            3500.0,  # f4 at time 0.5
            1.0,     # time from index 2
            750.0,   # f1 at time 1.0
            1150.0,  # f2 at time 1.0
            2700.0,  # f3 at time 1.0
            3400.0,  # f4 at time 1.0
            1.5,     # time from index 3
            820.0,   # f1 at time 1.5
            1250.0,  # f2 at time 1.5
            2900.0,  # f3 at time 1.5
            3600.0,  # f4 at time 1.5
        ]

        result = extractor._extract_formant_features(mock_sound, mock_point_process)

        assert 'f1_mean' in result
        assert 'f2_mean' in result
        assert 'f3_mean' in result
        assert 'f4_mean' in result
        assert 'f1_median' in result
        assert 'f2_median' in result
        assert 'f3_median' in result
        assert 'f4_median' in result
        assert len(result) == 8

    @patch('nkululeko.feat_extract.feats_praat_core.call')
    def test_extract_formant_features_with_nan(self, mock_call, extractor, mock_sound):
        """NaN formant readings still yield the full set of eight statistics."""
        mock_point_process = Mock()

        # Same layout as the plain formant test, with some readings set to NaN.
        mock_call.side_effect = [
            Mock(),        # formants object
            2,             # num_points
            0.5,           # time from index 1
            float('nan'),  # f1 at time 0.5 (NaN)
            1200.0,        # f2 at time 0.5
            float('nan'),  # f3 at time 0.5 (NaN)
            3500.0,        # f4 at time 0.5
            1.0,           # time from index 2
            750.0,         # f1 at time 1.0
            1150.0,        # f2 at time 1.0
            2700.0,        # f3 at time 1.0
            float('nan'),  # f4 at time 1.0 (NaN)
        ]

        result = extractor._extract_formant_features(mock_sound, mock_point_process)

        # Should handle NaN values gracefully
        assert 'f1_mean' in result
        assert not np.isnan(result['f2_mean'])
        assert len(result) == 8

    def test_calculate_pause_distribution_empty_list(self, extractor):
        """With no pauses every reported statistic is NaN."""
        result = extractor._calculate_pause_distribution([])

        assert np.isnan(result['pause_lognorm_mu'])
        assert np.isnan(result['pause_lognorm_sigma'])
        assert np.isnan(result['pause_lognorm_ks_pvalue'])
        assert np.isnan(result['pause_mean_duration'])
        assert np.isnan(result['pause_std_duration'])
        assert np.isnan(result['pause_cv'])

    def test_calculate_pause_distribution_valid_data(self, extractor):
        """Five pause durations give finite statistics and a mean of 0.3."""
        pause_durations = [0.1, 0.2, 0.3, 0.4, 0.5]
        result = extractor._calculate_pause_distribution(pause_durations)

        assert not np.isnan(result['pause_mean_duration'])
        assert not np.isnan(result['pause_std_duration'])
        assert not np.isnan(result['pause_cv'])
        # The mean is computed from binary floats, so compare with a tolerance
        # rather than relying on the sum rounding to exactly 0.3.
        assert result['pause_mean_duration'] == pytest.approx(0.3)
        assert len(result) == 6
155
+
156
+
157
class TestRunPCA:
    """Tests for ``run_pca`` on tables of jitter and shimmer measures."""

    @staticmethod
    def _measures():
        """Return the three-row jitter/shimmer table shared by these tests."""
        return {
            'localJitter': [0.01, 0.02, 0.015],
            'localabsoluteJitter': [0.05, 0.06, 0.055],
            'rapJitter': [0.02, 0.03, 0.025],
            'ppq5Jitter': [0.03, 0.04, 0.035],
            'ddpJitter': [0.04, 0.05, 0.045],
            'localShimmer': [0.1, 0.2, 0.15],
            'localdbShimmer': [0.5, 0.6, 0.55],
            'apq3Shimmer': [0.15, 0.25, 0.2],
            'apq5Shimmer': [0.2, 0.3, 0.25],
            'apq11Shimmer': [0.25, 0.35, 0.3],
            'ddaShimmer': [0.3, 0.4, 0.35],
        }

    def test_run_pca_valid_data(self):
        """A complete table yields a frame with both PCA columns and all rows."""
        frame = pd.DataFrame(self._measures())

        result = run_pca(frame)

        assert isinstance(result, pd.DataFrame)
        assert 'JitterPCA' in result.columns
        assert 'ShimmerPCA' in result.columns
        assert len(result) == 3

    def test_run_pca_with_nan_values(self):
        """NaN entries in the jitter columns still produce both PCA columns."""
        columns = self._measures()
        # Replacing existing keys keeps the original column order.
        columns['localJitter'] = [0.01, np.nan, 0.015]
        columns['localabsoluteJitter'] = [0.05, 0.06, np.nan]
        frame = pd.DataFrame(columns)

        result = run_pca(frame)

        assert isinstance(result, pd.DataFrame)
        assert 'JitterPCA' in result.columns
        assert 'ShimmerPCA' in result.columns

    def test_run_pca_single_file(self):
        """A single-row table gives 0 for both PCA columns."""
        # Test with single file (should handle ValueError)
        first_row = {name: values[:1] for name, values in self._measures().items()}
        frame = pd.DataFrame(first_row)

        result = run_pca(frame)

        assert isinstance(result, pd.DataFrame)
        only_row = result.iloc[0]
        assert only_row['JitterPCA'] == 0
        assert only_row['ShimmerPCA'] == 0
228
+
229
+
230
class TestAddDerivedFeatures:
    """Tests for ``add_derived_features`` (PCA and vocal-tract columns)."""

    @staticmethod
    def _columns(**overrides):
        """Return the shared three-row input table.

        Keyword arguments replace whole columns by name; because the keys
        already exist, the column order is unchanged.
        """
        table = {
            'f1_median': [800, 750, 820],
            'f2_median': [1200, 1150, 1250],
            'f3_median': [2800, 2700, 2900],
            'f4_median': [3500, 3400, 3600],
            'localJitter': [0.01, 0.02, 0.015],
            'localabsoluteJitter': [0.05, 0.06, 0.055],
            'rapJitter': [0.02, 0.03, 0.025],
            'ppq5Jitter': [0.03, 0.04, 0.035],
            'ddpJitter': [0.04, 0.05, 0.045],
            'localShimmer': [0.1, 0.2, 0.15],
            'localdbShimmer': [0.5, 0.6, 0.55],
            'apq3Shimmer': [0.15, 0.25, 0.2],
            'apq5Shimmer': [0.2, 0.3, 0.25],
            'apq11Shimmer': [0.25, 0.35, 0.3],
            'ddaShimmer': [0.3, 0.4, 0.35],
        }
        table.update(overrides)
        return table

    def test_add_derived_features(self):
        """All PCA and vocal-tract columns are added to a complete table."""
        frame = pd.DataFrame(self._columns())

        result = add_derived_features(frame)

        # Check PCA columns are added
        for pca_column in ('JitterPCA', 'ShimmerPCA'):
            assert pca_column in result.columns

        # Check vocal tract features are added
        for tract_column in (
            'pF',
            'fdisp',
            'avgFormant',
            'mff',
            'fitch_vtl',
            'delta_f',
            'vtl_delta_f',
        ):
            assert tract_column in result.columns

    def test_add_derived_features_with_nan(self):
        """NaN formant medians do not raise and do not change the row count."""
        frame = pd.DataFrame(
            self._columns(
                f1_median=[np.nan, 750, 820],
                f2_median=[1200, np.nan, 1250],
                f3_median=[2800, 2700, np.nan],
            )
        )

        result = add_derived_features(frame)

        # Should handle NaN values without raising errors
        assert 'pF' in result.columns
        assert 'fdisp' in result.columns
        assert len(result) == len(frame)
295
+
296
+
297
class TestComputeFeatures:
    """Smoke test for the ``compute_features`` entry point."""

    def test_compute_features_function_exists(self):
        """``compute_features`` was imported and is a callable object."""
        is_callable = callable(compute_features)
        assert is_callable
302
+
303
+
304
class TestSpeechRate:
    """Smoke test for the ``speech_rate`` helper."""

    def test_speech_rate_function_exists(self):
        """``speech_rate`` was imported and is a callable object."""
        is_callable = callable(speech_rate)
        assert is_callable
309
+
310
+
311
class TestGetSpeechRate:
    """Smoke test for the ``get_speech_rate`` helper."""

    def test_get_speech_rate_function_exists(self):
        """``get_speech_rate`` was imported and is a callable object."""
        is_callable = callable(get_speech_rate)
        assert is_callable
316
+
317
+
318
class TestPraatIntegration:
    """Integration tests for complete Praat feature extraction pipeline."""

    @staticmethod
    def _segment_index(paths, seconds):
        """Build the (file, start, end) index passed to ``compute_features``.

        Every path gets one segment running from 0 to ``seconds``.
        """
        import datetime

        begin = datetime.timedelta(seconds=0)
        finish = datetime.timedelta(seconds=seconds)
        table = pd.DataFrame(
            [(path, begin, finish) for path in paths],
            columns=['file', 'start', 'end'],
        )
        # Only the resulting index is handed on, not the (now empty) frame.
        return table.set_index(['file', 'start', 'end']).index

    def test_compute_features_with_real_audio_file(self):
        """Test that all 45 features can be extracted from a real audio file."""
        # Path is relative to the current working directory.
        audio_file = "./data/test/audio/debate_sample.wav"
        assert os.path.exists(audio_file), f"Test audio file not found: {audio_file}"

        features_df = compute_features(self._segment_index([audio_file], 5))

        assert isinstance(features_df, pd.DataFrame), "compute_features should return a DataFrame"

        # One input segment, so exactly one row.
        assert len(features_df) == 1, f"Expected 1 row, got {len(features_df)}"

        # Roughly 45 columns are expected; the bounds leave some tolerance.
        expected_min_features = 40
        expected_max_features = 50
        actual_features = len(features_df.columns)
        assert expected_min_features <= actual_features <= expected_max_features, (
            f"Expected ~45 features (range {expected_min_features}-{expected_max_features}), got {actual_features}. "
            f"Features: {list(features_df.columns)}"
        )

        expected_core_features = [
            'duration', 'meanF0Hz', 'stdevF0Hz', 'HNR',
            'f1_mean', 'f1_median', 'f2_mean', 'f2_median',
            'f3_mean', 'f3_median', 'f4_mean', 'f4_median',
            'localJitter', 'localabsoluteJitter', 'rapJitter', 'ppq5Jitter', 'ddpJitter',
            'localShimmer', 'localdbShimmer', 'apq3Shimmer', 'apq5Shimmer', 'apq11Shimmer', 'ddaShimmer',
            'JitterPCA', 'ShimmerPCA',  # From PCA
            'pF', 'fdisp', 'avgFormant', 'mff', 'fitch_vtl', 'delta_f', 'vtl_delta_f',  # Vocal tract
            'nsyll', 'npause', 'phonationtime_s', 'speechrate_nsyll_dur',
            'articulation_rate_nsyll_phonationtime', 'ASD_speakingtime_nsyll',  # Speech rate
        ]
        present = features_df.columns
        missing_features = [name for name in expected_core_features if name not in present]
        assert not missing_features, f"Missing expected features: {missing_features}"

        # At least 80% of the columns must hold a value for the single row.
        total_features = len(features_df.columns)
        non_nan_features = features_df.notna().sum(axis=1).iloc[0]
        min_valid_features = int(0.8 * total_features)
        assert non_nan_features >= min_valid_features, (
            f"Too many NaN features: {non_nan_features}/{total_features} are valid, "
            f"expected at least {min_valid_features}"
        )

        row = features_df.iloc[0]

        # The requested segment is 5 long; accept 3.0-7.0.
        assert 3.0 <= row['duration'] <= 7.0, f"Duration seems unreasonable: {row['duration']}"

        # Range checks below apply only when a value was actually detected.
        if not pd.isna(row['meanF0Hz']):
            assert 50 <= row['meanF0Hz'] <= 500, f"Mean F0 seems unreasonable: {row['meanF0Hz']}"

        for i in (1, 2, 3, 4):
            formant_mean = row[f'f{i}_mean']
            if pd.isna(formant_mean):
                continue
            assert 200 <= formant_mean <= 4000, f"Formant F{i} mean seems unreasonable: {formant_mean}"

        print(f"SUCCESS: Extracted {actual_features} features from real audio file")
        print(f"Feature names: {list(features_df.columns)}")
        print(f"Non-NaN features: {non_nan_features}/{total_features}")

    def test_feature_extraction_robustness_multiple_files(self):
        """Test feature extraction with multiple real audio files."""
        audio_dir = "./data/test/audio"
        available_files = [
            "debate_sample.wav",
            "03a01Fa.wav",
            "03a01Nc.wav"
        ]

        # Keep only the candidates that are present on disk.
        candidates = (os.path.join(audio_dir, fname) for fname in available_files)
        test_files = [fpath for fpath in candidates if os.path.exists(fpath)]

        assert len(test_files) >= 1, "Need at least one test audio file"

        features_df = compute_features(self._segment_index(test_files, 3))

        assert len(features_df) == len(test_files), f"Expected {len(test_files)} rows, got {len(features_df)}"

        # Per file only half of the columns need a value (more lenient than
        # the single-file test).
        total_features = len(features_df.columns)
        min_valid = int(0.5 * total_features)
        for position, test_file in enumerate(test_files):
            non_nan_count = features_df.iloc[position].notna().sum()
            assert non_nan_count >= min_valid, \
                f"File {test_file} has too few valid features: {non_nan_count}/{total_features}"

        print(f"SUCCESS: Extracted features from {len(test_files)} files")
        print(f"Total features per file: {len(features_df.columns)}")

    def test_expected_feature_count_matches_documentation(self):
        """Test that the actual feature count matches the documented count in the code."""
        audio_file = "./data/test/audio/debate_sample.wav"
        assert os.path.exists(audio_file), f"Test audio file not found: {audio_file}"

        features_df = compute_features(self._segment_index([audio_file], 2))
        all_columns = features_df.columns
        actual_count = len(all_columns)

        # Accept 42-47 columns (the original comment cites ~43-45 from the
        # docstring of the code under test).
        expected_range = (42, 47)
        lower, upper = expected_range
        assert lower <= actual_count <= upper, (
            f"Feature count {actual_count} is outside expected range {expected_range}. "
            "This may indicate changes to the feature extraction implementation."
        )

        vocal_tract_names = ['pF', 'fdisp', 'avgFormant', 'mff', 'fitch_vtl', 'delta_f', 'vtl_delta_f']
        speech_rate_names = ['nsyll', 'npause', 'phonationtime_s', 'speechrate_nsyll_dur', 'articulation_rate_nsyll_phonationtime', 'ASD_speakingtime_nsyll']

        # Grouping used only for the printed breakdown and the final check.
        feature_categories = {
            'basic': ['duration', 'meanF0Hz', 'stdevF0Hz', 'HNR'],
            'formants': [col for col in all_columns if col.startswith('f') and ('_mean' in col or '_median' in col)],
            'jitter': [col for col in all_columns if 'Jitter' in col],
            'shimmer': [col for col in all_columns if 'Shimmer' in col],
            'pca': [col for col in all_columns if 'PCA' in col],
            'vocal_tract': [col for col in all_columns if col in vocal_tract_names],
            'speech_rate': [col for col in all_columns if col in speech_rate_names],
            'pause_distribution': [col for col in all_columns if 'pause' in col.lower()],
            'other': []
        }

        # Whatever no category claimed ends up under 'other'.
        classified_features = set()
        for members in feature_categories.values():
            classified_features.update(members)
        feature_categories['other'] = [col for col in all_columns if col not in classified_features]

        print(f"\nFeature breakdown (total: {actual_count}):")
        for category, features in feature_categories.items():
            if features:
                print(f" {category}: {len(features)} features - {features}")

        for category in ('basic', 'formants', 'jitter', 'shimmer'):
            assert len(feature_categories[category]) > 0, f"No features found in {category} category"
nkululeko/modelrunner.py CHANGED
@@ -4,6 +4,7 @@ import pandas as pd
4
4
 
5
5
  from nkululeko import glob_conf
6
6
  from nkululeko.utils.util import Util
7
+ from nkululeko.balance import DataBalancer
7
8
 
8
9
 
9
10
  class Modelrunner:
@@ -143,6 +144,7 @@ class Modelrunner:
143
144
 
144
145
  def _select_model(self, model_type):
145
146
  self._check_balancing()
147
+ self._check_feature_balancing()
146
148
 
147
149
  if model_type == "svm":
148
150
  from nkululeko.models.model_svm import SVM_model
@@ -243,54 +245,19 @@ class Modelrunner:
243
245
  )
244
246
  return self.model
245
247
 
246
- def _check_balancing(self):
248
def _check_feature_balancing(self):
    """Check and apply feature balancing using the dedicated DataBalancer class.

    Reads the ``balancing`` option from the ``FEATS`` config section. When it
    is falsy nothing happens; otherwise its value is passed as ``method`` to
    ``DataBalancer.balance_features`` and the returned pair replaces
    ``self.df_train`` and ``self.feats_train``.
    """
    balancing = self.util.config_val("FEATS", "balancing", False)
    if not balancing:
        return

    self.util.debug("Applying feature balancing using DataBalancer")

    # Fixed seed, as in the call this method was introduced with.
    balancer = DataBalancer(random_state=42)
    balanced = balancer.balance_features(
        df_train=self.df_train,
        feats_train=self.feats_train,
        target_column=self.target,
        method=balancing,
    )
    self.df_train, self.feats_train = balanced
@@ -0,0 +1,49 @@
1
+ from unittest.mock import MagicMock, patch
2
+
3
+ import pytest
4
+
5
+ from nkululeko.models.model_knn import KNN_model
6
+
7
+
8
@pytest.fixture
def mock_util():
    """Return a mocked util whose ``config_val`` answers two fixed lookups.

    The lookup key is the full ``(section, key, default)`` triple; any other
    combination raises ``KeyError``.
    """
    def lookup(section, key, default):
        answers = {
            ("MODEL", "KNN_weights", "uniform"): "distance",
            ("MODEL", "K_val", "5"): "3",
        }
        return answers[(section, key, default)]

    util = MagicMock()
    util.config_val.side_effect = lookup
    return util
16
+
17
@pytest.fixture
def dummy_data():
    """Return four independent mocks as (df_train, df_test, feats_train, feats_test)."""
    placeholders = [MagicMock() for _ in range(4)]
    df_train, df_test, feats_train, feats_test = placeholders
    return df_train, df_test, feats_train, feats_test
24
+
25
def test_knn_model_initialization(monkeypatch, mock_util, dummy_data):
    """Check the attributes of a KNN model set up with k=3 and distance weights.

    ``KNN_model.__init__`` is patched out, so every attribute asserted below
    is one this test assigns itself.
    """
    from sklearn.neighbors import KNeighborsClassifier

    with patch.object(KNN_model, "__init__", return_value=None):
        model = KNN_model(*dummy_data)

    model.util = mock_util
    model.name = "knn"
    model.clf = KNeighborsClassifier(n_neighbors=3, weights="distance")
    model.is_classifier = True

    params = model.clf.get_params()
    assert model.name == "knn"
    assert params["n_neighbors"] == 3
    assert params["weights"] == "distance"
    assert model.is_classifier is True
37
+
38
def test_knn_model_default_params(monkeypatch, dummy_data):
    """Check the attributes of a KNN model set up with k=5 and uniform weights.

    ``KNN_model.__init__`` is patched out; the classifier is built directly
    in this test.
    """
    from sklearn.neighbors import KNeighborsClassifier

    # This util echoes back whatever default it is asked with.
    mock_util = MagicMock()
    mock_util.config_val.side_effect = lambda section, key, default: default

    with patch.object(KNN_model, "__init__", return_value=None):
        model = KNN_model(*dummy_data)

    model.util = mock_util
    model.name = "knn"
    model.clf = KNeighborsClassifier(n_neighbors=5, weights="uniform")
    model.is_classifier = True

    params = model.clf.get_params()
    assert params["n_neighbors"] == 5
    assert params["weights"] == "uniform"