nkululeko 0.95.0__py3-none-any.whl → 0.95.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nkululeko/autopredict/tests/__init__.py +0 -0
- nkululeko/autopredict/tests/test_whisper_transcriber.py +122 -0
- nkululeko/balance.py +222 -0
- nkululeko/constants.py +1 -1
- nkululeko/feat_extract/feats_praat.py +3 -3
- nkululeko/feat_extract/{feinberg_praat.py → feats_praat_core.py} +0 -2
- nkululeko/feat_extract/tests/__init__.py +1 -0
- nkululeko/feat_extract/tests/test_feats_opensmile.py +162 -0
- nkululeko/feat_extract/tests/test_feats_praat_core.py +507 -0
- nkululeko/modelrunner.py +15 -48
- nkululeko/models/tests/test_model_knn.py +49 -0
- nkululeko/models/tests/test_model_mlp.py +153 -0
- nkululeko/models/tests/test_model_xgb.py +33 -0
- nkululeko/predict.py +3 -2
- nkululeko/reporting/reporter.py +12 -0
- nkululeko/test_predictor.py +7 -1
- nkululeko/tests/__init__.py +1 -0
- nkululeko/tests/test_balancing.py +270 -0
- nkululeko/utils/util.py +5 -5
- {nkululeko-0.95.0.dist-info → nkululeko-0.95.1.dist-info}/METADATA +1 -1
- {nkululeko-0.95.0.dist-info → nkululeko-0.95.1.dist-info}/RECORD +25 -15
- nkululeko/feat_extract/feats_opensmile copy.py +0 -93
- {nkululeko-0.95.0.dist-info → nkululeko-0.95.1.dist-info}/WHEEL +0 -0
- {nkululeko-0.95.0.dist-info → nkululeko-0.95.1.dist-info}/entry_points.txt +0 -0
- {nkululeko-0.95.0.dist-info → nkululeko-0.95.1.dist-info}/licenses/LICENSE +0 -0
- {nkululeko-0.95.0.dist-info → nkululeko-0.95.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,153 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import pandas as pd
|
3
|
+
import pytest
|
4
|
+
import torch
|
5
|
+
from unittest.mock import patch
|
6
|
+
|
7
|
+
from nkululeko.models.model_mlp import MLPModel
|
8
|
+
|
9
|
+
|
10
|
+
class DummyUtil:
    """Minimal stand-in for nkululeko's Util used by the MLP tests.

    Only the methods these tests touch are provided. Config lookups are
    answered from a fixed table (the section argument is ignored) and fall
    back to the caller-supplied default for unknown keys.
    """

    # Keys the MLP model queries, mapped to the values these tests need.
    _CONFIG = {
        "manual_seed": True,
        "loss": "cross",
        "device": "cpu",
        "learning_rate": 0.001,
        "batch_size": 2,
        "drop": False,
    }

    def config_val(self, section, key, default=None):
        """Return the canned value for *key*, or *default* if it is not known."""
        return self._CONFIG.get(key, default)

    def debug(self, msg):
        """Discard debug output."""
        return None

    def error(self, msg):
        """Fail loudly with *msg* as a plain Exception."""
        raise Exception(msg)

    def get_path(self, key):
        """Every path resolves to the current directory."""
        return "./"

    def get_exp_name(self, only_train=False):
        """Fixed experiment name."""
        return "exp"
|
30
|
+
|
31
|
+
@pytest.fixture(autouse=True)
def patch_globals(monkeypatch):
    """Install a minimal global config and label set for every test in this module.

    The values are applied through ``monkeypatch`` so they are rolled back
    after each test. Plain assignment would leave them in
    ``nkululeko.glob_conf`` and leak into other test modules that read it.
    """
    import nkululeko.glob_conf as glob_conf

    # raising=False tolerates the attribute not existing yet; monkeypatch
    # then removes it again on teardown.
    monkeypatch.setattr(
        glob_conf,
        "config",
        {
            "DATA": {"target": "label"},
            "MODEL": {"layers": "{'a': 8, 'b': 4}"},
        },
        raising=False,
    )
    monkeypatch.setattr(glob_conf, "labels", [0, 1], raising=False)
    yield
|
41
|
+
|
42
|
+
@pytest.fixture
def dummy_data():
    """Return a tiny binary split: ``(df_train, df_test, feats_train, feats_test)``.

    4 training and 2 test samples with 3 feature columns each. The features
    come from a fixed-seed local generator, so runs are reproducible and
    NumPy's global random state is left untouched.
    """
    rng = np.random.default_rng(0)
    columns = ["f1", "f2", "f3"]
    feats_train = pd.DataFrame(rng.random((4, 3)), columns=columns)
    feats_test = pd.DataFrame(rng.random((2, 3)), columns=columns)
    df_train = pd.DataFrame({"label": [0, 1, 0, 1]})
    df_test = pd.DataFrame({"label": [1, 0]})
    return df_train, df_test, feats_train, feats_test
|
50
|
+
|
51
|
+
@pytest.fixture
def mlp_model(dummy_data, monkeypatch, tmp_path):
    """Return an ``MLPModel`` assembled by hand for fast CPU-only tests.

    ``MLPModel.__init__`` is patched out during construction, so every
    attribute the tested methods rely on is assigned explicitly below.
    The network's input and output sizes are derived from the fixture data
    rather than hard-coded. ``store_path`` points into pytest's per-test
    ``tmp_path`` instead of a fixed ``/tmp`` file, which is POSIX-only and
    shared between concurrent test runs.
    """
    df_train, df_test, feats_train, feats_test = dummy_data
    with patch.object(MLPModel, "__init__", return_value=None):
        model = MLPModel(df_train, df_test, feats_train, feats_test)

    model.util = DummyUtil()
    model.n_jobs = 1
    model.target = "label"
    model.class_num = 2
    model.criterion = torch.nn.CrossEntropyLoss()
    model.device = "cpu"
    model.learning_rate = 0.001
    model.batch_size = 2
    model.num_workers = 1
    model.loss = 0.0
    model.loss_eval = 0.0
    model.run = 0
    model.epoch = 0
    model.df_test = df_test
    model.feats_test = feats_test
    model.feats_train = feats_train

    # Create a simple MLP sized to the fixture data.
    n_features = feats_train.shape[1]
    model.model = MLPModel.MLP(
        n_features, {"a": 8, "b": 4}, model.class_num, False
    ).to(model.device)
    model.optimizer = torch.optim.Adam(
        model.model.parameters(), lr=model.learning_rate
    )

    # Create data loaders for the train and test splits.
    model.trainloader = model.get_loader(feats_train, df_train, True)
    model.testloader = model.get_loader(feats_test, df_test, False)
    model.store_path = str(tmp_path / "test_model.pt")

    return model
|
83
|
+
|
84
|
+
def test_mlpmodel_init(mlp_model):
    """The fixture yields a model with a network and both data loaders attached."""
    for attribute in ("model", "trainloader", "testloader"):
        assert hasattr(mlp_model, attribute)
    assert mlp_model.model is not None
|
89
|
+
|
90
|
+
def test_train_and_predict(mlp_model):
    """Training followed by predict() yields a report whose result has a train entry."""
    mlp_model.train()
    report = mlp_model.predict()
    assert hasattr(report, "result")
    result = report.result
    assert hasattr(result, "train")
|
95
|
+
|
96
|
+
def test_get_predictions(mlp_model):
    """get_predictions() returns a NumPy array with one entry per test sample."""
    mlp_model.train()
    preds = mlp_model.get_predictions()
    assert isinstance(preds, np.ndarray)
    # Compare against the test split itself instead of a hard-coded size,
    # so the test keeps working if the fixture data changes.
    assert preds.shape[0] == len(mlp_model.df_test)
|
101
|
+
|
102
|
+
def test_get_probas(mlp_model):
    """get_probas() turns evaluation logits into a DataFrame with one column per class."""
    mlp_model.train()
    net, loader, device = mlp_model.model, mlp_model.testloader, mlp_model.device
    _, _, _, logits = mlp_model.evaluate(net, loader, device)
    probas = mlp_model.get_probas(logits)
    assert isinstance(probas, pd.DataFrame)
    assert set(probas.columns) == {0, 1}
|
108
|
+
|
109
|
+
def test_predict_sample(mlp_model):
    """predict_sample() on one 3-feature vector returns a dict keyed by class label."""
    mlp_model.train()
    sample = np.random.rand(3)
    result = mlp_model.predict_sample(sample)
    assert isinstance(result, dict)
    assert set(result) == {0, 1}
|
115
|
+
|
116
|
+
def test_predict_shap(mlp_model):
    """predict_shap() yields one result per input row."""
    mlp_model.train()
    n_rows = 2
    results = mlp_model.predict_shap(pd.DataFrame(np.random.rand(n_rows, 3)))
    assert len(results) == n_rows
|
121
|
+
|
122
|
+
def test_store_and_load(tmp_path, mlp_model, monkeypatch):
    """A trained model can be stored and then re-loaded with load(0, 0)."""
    mlp_model.train()

    model_dir = str(tmp_path) + "/"

    # load() builds its path from util.get_path() and util.get_exp_name(),
    # so point the model directory at tmp_path and fix the experiment name.
    monkeypatch.setattr(
        mlp_model.util,
        "get_path",
        lambda key: model_dir if key == "model_dir" else "./",
    )
    monkeypatch.setattr(
        mlp_model.util, "get_exp_name", lambda only_train=False: "model"
    )

    # Store under the same path load(0, 0) is expected to construct.
    mlp_model.store_path = f"{tmp_path}/model_0_000.model"
    mlp_model.store()

    mlp_model.load(0, 0)
    assert mlp_model.model is not None
|
144
|
+
|
145
|
+
def test_set_testdata(mlp_model, dummy_data):
    """After set_testdata() the model still has a (non-None) test loader."""
    df_test, feats_test = dummy_data[1], dummy_data[3]
    mlp_model.set_testdata(df_test, feats_test)
    assert mlp_model.testloader is not None
|
149
|
+
|
150
|
+
def test_reset_test(mlp_model, dummy_data):
    """After reset_test() the model still has a (non-None) test loader."""
    df_test, feats_test = dummy_data[1], dummy_data[3]
    mlp_model.reset_test(df_test, feats_test)
    assert mlp_model.testloader is not None
|
@@ -0,0 +1,33 @@
|
|
1
|
+
import pandas as pd
|
2
|
+
import pytest
|
3
|
+
|
4
|
+
from ..model_xgb import XGB_model
|
5
|
+
|
6
|
+
|
7
|
+
class DummyUtil:
    """Stub for nkululeko's Util: every config lookup echoes the caller's default."""

    def config_val(self, section, key, default):
        """Ignore *section* and *key*; hand back *default* unchanged."""
        return default

    def debug(self, msg):
        """Swallow debug messages."""
        return None
|
12
|
+
|
13
|
+
class DummyModel(XGB_model):
    """XGB_model wired to a stub util and a fixed "label" target for tests."""

    def __init__(self, df_train, df_test, feats_train, feats_test):
        # Stubs go in before the parent constructor runs, so anything it
        # reads early already sees them. They are installed again afterwards,
        # presumably because the parent constructor replaces them.
        self._install_test_doubles()
        super().__init__(df_train, df_test, feats_train, feats_test)
        self._install_test_doubles()

    def _install_test_doubles(self):
        """Attach a fresh stub util and the test target column name."""
        self.util = DummyUtil()
        self.target = "label"
|
21
|
+
|
22
|
+
@pytest.fixture
def dummy_data():
    """Two-row train/test split with a single feature column "f1".

    Returns ``(df_train, df_test, feats_train, feats_test)``; the feature
    frames are the "f1" column selected from the corresponding data frame.
    """
    labels = [0, 1]
    df_train = pd.DataFrame({"label": labels, "f1": [1.0, 2.0]})
    df_test = pd.DataFrame({"label": labels, "f1": [1.5, 2.5]})
    return df_train, df_test, df_train[["f1"]], df_test[["f1"]]
|
29
|
+
|
30
|
+
def test_get_type_returns_xgb(dummy_data):
    """The XGB model reports its type as "xgb"."""
    model = DummyModel(*dummy_data)
    assert model.get_type() == "xgb"
|
nkululeko/predict.py
CHANGED
@@ -62,8 +62,9 @@ def main():
|
|
62
62
|
df = df.rename(columns={"class_label": target})
|
63
63
|
sample_selection = util.config_val("PREDICT", "sample_selection", "all")
|
64
64
|
name = f"{sample_selection}_predicted"
|
65
|
-
|
66
|
-
|
65
|
+
res_dir = util.get_res_dir()
|
66
|
+
df.to_csv(os.path.join(res_dir, f"{name}.csv"))
|
67
|
+
util.debug(f"saved {os.path.join(res_dir, name)}.csv")
|
67
68
|
print("DONE")
|
68
69
|
|
69
70
|
|
nkululeko/reporting/reporter.py
CHANGED
@@ -2,6 +2,7 @@ import ast
|
|
2
2
|
import glob
|
3
3
|
import json
|
4
4
|
import math
|
5
|
+
import os
|
5
6
|
|
6
7
|
# import os
|
7
8
|
from confidence_intervals import evaluate_with_conf_int
|
@@ -173,6 +174,17 @@ class Reporter:
|
|
173
174
|
probas["correct"] = probas.predicted == probas.truth
|
174
175
|
if file_name is None:
|
175
176
|
file_name = self.util.get_pred_name() + ".csv"
|
177
|
+
else:
|
178
|
+
# Ensure the file_name goes to the results directory
|
179
|
+
if not os.path.isabs(file_name):
|
180
|
+
res_dir = self.util.get_res_dir()
|
181
|
+
if not file_name.endswith(".csv"):
|
182
|
+
file_name = os.path.join(res_dir, file_name + ".csv")
|
183
|
+
else:
|
184
|
+
file_name = os.path.join(res_dir, file_name)
|
185
|
+
else:
|
186
|
+
if not file_name.endswith(".csv"):
|
187
|
+
file_name = file_name + ".csv"
|
176
188
|
self.probas = probas
|
177
189
|
probas.to_csv(file_name)
|
178
190
|
self.util.debug(f"Saved probabilities to {file_name}")
|
nkululeko/test_predictor.py
CHANGED
@@ -5,6 +5,7 @@ Predict targets from a model and save as csv file.
|
|
5
5
|
"""
|
6
6
|
|
7
7
|
import ast
|
8
|
+
import os
|
8
9
|
|
9
10
|
import pandas as pd
|
10
11
|
from sklearn.preprocessing import LabelEncoder
|
@@ -24,7 +25,12 @@ class TestPredictor:
|
|
24
25
|
self.label_encoder = labenc
|
25
26
|
self.target = glob_conf.config["DATA"]["target"]
|
26
27
|
self.util = Util("test_predictor")
|
27
|
-
|
28
|
+
# Construct full path to results directory
|
29
|
+
res_dir = self.util.get_res_dir()
|
30
|
+
if os.path.isabs(name):
|
31
|
+
self.name = name
|
32
|
+
else:
|
33
|
+
self.name = os.path.join(res_dir, name)
|
28
34
|
|
29
35
|
def predict_and_store(self):
|
30
36
|
label_data = self.util.config_val("DATA", "label_data", False)
|
@@ -0,0 +1 @@
|
|
1
|
+
# Tests package for nkululeko
|
@@ -0,0 +1,270 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
"""
|
3
|
+
Simple and comprehensive test suite for all balancing methods in DataBalancer.
|
4
|
+
|
5
|
+
Tests all 11 balancing methods from balance.py:
|
6
|
+
|
7
|
+
Oversampling (5): ros, smote, adasyn, borderlinesmote, svmsmote
|
8
|
+
Undersampling (4): clustercentroids, randomundersampler, editednearestneighbours, tomeklinks
|
9
|
+
Combination (2): smoteenn, smotetomek
|
10
|
+
|
11
|
+
Run with: pytest nkululeko/tests/test_balancing.py -v
|
12
|
+
"""
|
13
|
+
|
14
|
+
import numpy as np
|
15
|
+
import pandas as pd
|
16
|
+
import pytest
|
17
|
+
from nkululeko.balance import DataBalancer
|
18
|
+
import nkululeko.glob_conf as glob_conf
|
19
|
+
|
20
|
+
|
21
|
+
@pytest.fixture
def sample_data():
    """Create sample imbalanced data that works with all methods.

    Returns ``(df_train, feats_train)``:
    - ``df_train`` is a DataFrame with a single "target" column: 100 samples
      of class 0 followed by 25 of class 1.
    - ``feats_train`` is a (125, 10) NumPy feature matrix.

    A private ``RandomState(42)`` draws exactly the numbers that
    ``np.random.seed(42)`` followed by ``np.random.randn`` produced, so the
    data is unchanged. NumPy's global random state, however, is no longer
    reseeded as a side effect of requesting this fixture.
    """
    rng = np.random.RandomState(42)

    # Majority class: 100 samples, minority class: 25 samples.
    # The minority is shifted by +3 in every feature for good separation.
    majority_features = rng.randn(100, 10)
    minority_features = rng.randn(25, 10) + 3

    feats_train = np.vstack([majority_features, minority_features])
    labels = np.array([0] * 100 + [1] * 25)
    df_train = pd.DataFrame({"target": labels})

    return df_train, feats_train
|
38
|
+
|
39
|
+
|
40
|
+
@pytest.fixture
def mock_config(monkeypatch):
    """Install a minimal global configuration for the balancing tests.

    Yields the installed config dict. The value is applied through
    ``monkeypatch``, so the previous ``glob_conf.config`` is restored on
    teardown in every case. That includes a previous value of ``None`` and
    an unset attribute; restoring only non-None values would leak this test
    config into later tests.
    """
    monkeypatch.setattr(
        glob_conf,
        "config",
        {
            "FEATS": {"balancing": "smote"},
            "DATA": {"target": "target"},
            "MODEL": {"type": "mlp"},
        },
        raising=False,
    )
    yield glob_conf.config
|
55
|
+
|
56
|
+
|
57
|
+
class TestDataBalancer:
    """Simple test suite for DataBalancer - tests all 11 methods"""

    @staticmethod
    def _balance_and_check(balancer, df_train, feats_train, method):
        """Balance with one method and assert the invariants every method must keep.

        Checks that the balanced labels and features have equal length and
        that the feature dimension is unchanged.

        Returns:
            tuple: ``(balanced_df, balanced_features)`` from ``balance_features``.
        """
        balanced_df, balanced_features = balancer.balance_features(
            df_train=df_train,
            feats_train=feats_train,
            target_column="target",
            method=method,
        )
        assert len(balanced_df) == len(balanced_features), f"{method} length mismatch"
        assert (
            balanced_features.shape[1] == feats_train.shape[1]
        ), f"{method} feature dim changed"
        return balanced_df, balanced_features

    def test_initialization(self):
        """Test 1: DataBalancer can be initialized"""
        balancer = DataBalancer(random_state=42)
        assert balancer is not None
        assert balancer.random_state == 42

    def test_get_all_supported_methods(self):
        """Test 2: All 11 methods are reported as supported"""
        balancer = DataBalancer()
        methods = balancer.get_supported_methods()

        expected_counts = {"oversampling": 5, "undersampling": 4, "combination": 2}

        # Check we have all 3 categories
        for category in expected_counts:
            assert category in methods, f"missing category {category}"

        # Check exact counts
        for category, count in expected_counts.items():
            assert len(methods[category]) == count, category

        # Total should be 11
        total = sum(len(methods[category]) for category in expected_counts)
        assert total == 11

    def test_method_validation(self):
        """Test 3: Method validation works correctly"""
        balancer = DataBalancer()

        # Valid methods
        for method in ("ros", "smote", "clustercentroids", "smoteenn"):
            assert balancer.is_valid_method(method), f"{method} should be valid"

        # Invalid methods
        for method in ("invalid", ""):
            assert not balancer.is_valid_method(method), f"{method!r} should be invalid"

    def test_all_oversampling_methods(self, sample_data, mock_config):
        """Test 4: All 5 oversampling methods work"""
        df_train, feats_train = sample_data
        balancer = DataBalancer(random_state=42)

        oversampling_methods = ["ros", "smote", "adasyn", "borderlinesmote", "svmsmote"]

        for method in oversampling_methods:
            print(f"Testing oversampling: {method}")
            balanced_df, _ = self._balance_and_check(
                balancer, df_train, feats_train, method
            )
            assert len(balanced_df) >= len(df_train), f"{method} should increase/maintain size"
            print(f"✓ {method} passed")

    def test_all_undersampling_methods(self, sample_data, mock_config):
        """Test 5: All 4 undersampling methods work"""
        df_train, feats_train = sample_data
        balancer = DataBalancer(random_state=42)

        undersampling_methods = [
            "clustercentroids",
            "randomundersampler",
            "editednearestneighbours",
            "tomeklinks",
        ]

        for method in undersampling_methods:
            print(f"Testing undersampling: {method}")
            balanced_df, _ = self._balance_and_check(
                balancer, df_train, feats_train, method
            )
            assert len(balanced_df) <= len(df_train), f"{method} should decrease/maintain size"
            print(f"✓ {method} passed")

    def test_all_combination_methods(self, sample_data, mock_config):
        """Test 6: All 2 combination methods work"""
        df_train, feats_train = sample_data
        balancer = DataBalancer(random_state=42)

        combination_methods = ["smoteenn", "smotetomek"]

        for method in combination_methods:
            print(f"Testing combination: {method}")
            balanced_df, _ = self._balance_and_check(
                balancer, df_train, feats_train, method
            )
            assert len(balanced_df) > 0, f"{method} resulted in empty dataset"
            print(f"✓ {method} passed")

    def test_all_11_methods_comprehensive(self, sample_data, mock_config):
        """Test 7: All 11 methods work in one comprehensive test"""
        df_train, feats_train = sample_data
        balancer = DataBalancer(random_state=42)

        # Get all methods from the balancer itself
        all_methods = balancer.get_supported_methods()

        successful_methods = []
        failed_methods = []

        print("Testing all 11 balancing methods...")

        # Failures are collected instead of raised immediately, so the final
        # assertion reports every broken method rather than only the first.
        for methods in all_methods.values():
            for method in methods:
                try:
                    balanced_df, _ = self._balance_and_check(
                        balancer, df_train, feats_train, method
                    )
                    assert len(balanced_df) > 0
                except Exception as e:
                    failed_methods.append((method, str(e)))
                    print(f"✗ {method} failed: {str(e)}")
                else:
                    successful_methods.append(method)
                    print(f"✓ {method} succeeded")

        print(f"\nResults: {len(successful_methods)}/11 methods successful")
        print(f"Successful: {successful_methods}")
        if failed_methods:
            print(f"Failed: {[m[0] for m in failed_methods]}")

        # All 11 methods should work
        assert len(successful_methods) == 11, f"Expected 11 successful methods, got {len(successful_methods)}"
        assert len(failed_methods) == 0, f"Some methods failed: {failed_methods}"

    def test_invalid_method_handling(self, mock_config):
        """Test 8: Invalid methods are handled correctly"""
        balancer = DataBalancer(random_state=42)

        # Test that invalid methods are detected by validation
        for method in ("invalid_method", "nonexistent", ""):
            assert not balancer.is_valid_method(method), f"{method!r} should be invalid"

        # Note: The actual balance_features() with invalid method calls sys.exit()
        # This is expected behavior in the current implementation
        print("✓ Invalid method validation works correctly")
|
231
|
+
|
232
|
+
|
233
|
+
def test_simple_integration():
    """Test 9: Simple integration test without fixtures"""
    print("Simple integration test...")

    # Create simple data. A private RandomState(42) reproduces exactly the
    # values that np.random.seed(42) + np.random.randn would give, without
    # reseeding NumPy's global generator for every test that runs afterwards.
    rng = np.random.RandomState(42)
    features = rng.randn(60, 5)
    labels = np.array([0] * 40 + [1] * 20)  # 40 vs 20 imbalance

    df_train = pd.DataFrame({"target": labels})

    # Test a few key methods
    balancer = DataBalancer(random_state=42)
    key_methods = ["ros", "smote", "clustercentroids", "randomundersampler"]

    for method in key_methods:
        balanced_df, balanced_features = balancer.balance_features(
            df_train=df_train,
            feats_train=features,
            target_column="target",
            method=method,
        )

        assert len(balanced_df) == len(balanced_features), f"{method} length mismatch"
        # Same invariant the fixture-based tests check: feature dim is kept.
        assert balanced_features.shape[1] == features.shape[1], f"{method} feature dim changed"
        print(f"✓ {method} integration test passed")

    print("✓ Integration test completed")
|
260
|
+
|
261
|
+
|
262
|
+
if __name__ == "__main__":
    separator = "=" * 50
    print("Running simple balancing tests...")
    print(separator)

    # Only the fixture-free integration test can be run without pytest.
    test_simple_integration()

    print(separator)
    print("Direct test completed! Run 'pytest test_balancing.py -v' for full tests")
|
nkululeko/utils/util.py
CHANGED
@@ -106,15 +106,15 @@ class Util:
|
|
106
106
|
except KeyError:
|
107
107
|
# some default values
|
108
108
|
if entry == "fig_dir":
|
109
|
-
entryn = "
|
109
|
+
entryn = "images/"
|
110
110
|
elif entry == "res_dir":
|
111
|
-
entryn = "
|
111
|
+
entryn = "results/"
|
112
112
|
elif entry == "model_dir":
|
113
|
-
entryn = "
|
113
|
+
entryn = "models/"
|
114
114
|
elif entry == "cache":
|
115
|
-
entryn = "
|
115
|
+
entryn = "cache/"
|
116
116
|
else:
|
117
|
-
entryn = "
|
117
|
+
entryn = "store/"
|
118
118
|
|
119
119
|
# Expand image, model and result directories with run index
|
120
120
|
if entry == "fig_dir" or entry == "res_dir" or entry == "model_dir":
|
@@ -2,8 +2,9 @@ examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
|
3
3
|
nkululeko/aug_train.py,sha256=wpiHCJ7zsW38kumg3ypwXZe2HQrhUblAnv7P2QeJnAc,3525
|
4
4
|
nkululeko/augment.py,sha256=3RzaxB3gRxovgJVjHXi0glprW01J7RaHhUkqotW2T3U,2955
|
5
|
+
nkululeko/balance.py,sha256=r7opXbrqAipm2euPPaOmLlA5J10p2bHQgO5kWk2x9ro,8702
|
5
6
|
nkululeko/cacheddataset.py,sha256=XFpWZmbJRg0pvhnIgYf0TkclxllD-Fctu-Ol0PF_00c,969
|
6
|
-
nkululeko/constants.py,sha256=
|
7
|
+
nkululeko/constants.py,sha256=9E1ltDzIxGnwuxdRBW6OUWwJB8Im9_c4dnOUwjcDDr8,39
|
7
8
|
nkululeko/demo-ft.py,sha256=iD9Pzp9QjyAv31q1cDZ75vPez7Ve8A4Cfukv5yfZdrQ,770
|
8
9
|
nkululeko/demo.py,sha256=tu7Al2l5MCLVegkDC-NE2wcuc_YE7NRbgOlPW3yhGEs,4940
|
9
10
|
nkululeko/demo_feats.py,sha256=BvZjeNFTlERIRlq34OHM4Z96jdDQAhB01BGQAUcX9dM,2026
|
@@ -17,19 +18,19 @@ nkululeko/file_checker.py,sha256=xJY0Q6w47pnmgJVK5rcAKPYBrCpV7eBT4_3YBzTx-H8,345
|
|
17
18
|
nkululeko/filter_data.py,sha256=4sGrKvMZ_hLnJPrHm_CqjDPKIRV8REWoT7nfSYGXbwo,7305
|
18
19
|
nkululeko/fixedsegment.py,sha256=Tb92QiuiyMsOO3WRWwuGjZGibS8hbHHCrcWAXGk7g04,2868
|
19
20
|
nkululeko/glob_conf.py,sha256=NLFh-1_I0Wdfo2EnSq1Oppx23AX6jAUpgFbk2zqZJ24,659
|
20
|
-
nkululeko/modelrunner.py,sha256=
|
21
|
+
nkululeko/modelrunner.py,sha256=OFN18uG84iJyjNVWjcvDpqbcBrmylziXCakUTNE2-ZQ,10530
|
21
22
|
nkululeko/multidb.py,sha256=sO6OwJn8sn1-C-ig3thsIL8QMWHdV9SnJhDodKjeKrI,6876
|
22
23
|
nkululeko/nkuluflag.py,sha256=PGWSmZz-PiiHLgcZJAoGOI_Y-sZDVI1ksB8p5r7riWM,3725
|
23
24
|
nkululeko/nkululeko.py,sha256=6ALPMMIz6l0O3IRaP0q4b59ZUxpfzNqLQUqZMf5t3Zo,1976
|
24
25
|
nkululeko/plots.py,sha256=lUxgyoriYTwdpHZvBBQ4e41v77deQrt0PcRDLJWijys,27503
|
25
|
-
nkululeko/predict.py,sha256=
|
26
|
+
nkululeko/predict.py,sha256=PWv1Pc39lrxqqIWrYszVk5SL37dDL93CHgcruItNID8,2211
|
26
27
|
nkululeko/resample.py,sha256=rn3-M1A-iwVGibfQNGyeYNa7briD24lIN9Szq_1uTJo,5194
|
27
28
|
nkululeko/runmanager.py,sha256=YtGQP0UyyQTKkilncB1XYM-T8oatzGcZEOcj5SorjJw,8902
|
28
29
|
nkululeko/scaler.py,sha256=a4lKwWT436TV4VEvqtP1uQ58Yz67XVHr1HjO5gp3xLI,5109
|
29
30
|
nkululeko/segment.py,sha256=7UrJEwdLmh9wDL5iBwpdJyJm9dwSxidHrHt-_D2qtxw,4949
|
30
31
|
nkululeko/syllable_nuclei.py,sha256=5w_naKxNxz66a_qLkraemi2fggM-gWesiiBPS47iFcE,9931
|
31
32
|
nkululeko/test.py,sha256=1w624vo5KTzmFC8BUStGlLDmIEAFuJUz7J0W-gp7AxI,1677
|
32
|
-
nkululeko/test_predictor.py,sha256=
|
33
|
+
nkululeko/test_predictor.py,sha256=i8vSaB8OOrdELoDttQVMs2Bc-fUOi2C5ANqnt32K3Zk,3064
|
33
34
|
nkululeko/test_pretrain.py,sha256=6FZeETlWzg9Cq_sn3BFKhfH91jW26nAIDm1bJkInNNA,8463
|
34
35
|
nkululeko/augmenting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
35
36
|
nkululeko/augmenting/augmenter.py,sha256=TUUznEz0pe9DSMC9r7LoBckuvsJTprvypeV5-8zLn20,2846
|
@@ -52,6 +53,8 @@ nkululeko/autopredict/ap_text.py,sha256=zaz9qIg90-ghZhBe1ka0HoUnap6s6RyopUKoCptt
|
|
52
53
|
nkululeko/autopredict/ap_valence.py,sha256=9S06SpO_zXKSpkf0InHYYXZcD9HDGoCJ6UPkn__eBAg,1027
|
53
54
|
nkululeko/autopredict/estimate_snr.py,sha256=1k9-XadABudnsNOeFZD_Fg0E64-GUQVS7JEp82MLQS4,4995
|
54
55
|
nkululeko/autopredict/whisper_transcriber.py,sha256=DWDvpRaV5KmUF18ojPEvxnVXm_h_nWyY-TfW2Ngd5N8,2941
|
56
|
+
nkululeko/autopredict/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
57
|
+
nkululeko/autopredict/tests/test_whisper_transcriber.py,sha256=ilas6j3OUvq_xnQCRZgytQCtyrpNU6tvG5a8kPvVKBQ,5085
|
55
58
|
nkululeko/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
56
59
|
nkululeko/data/dataset.py,sha256=JLbBYGniUrjwxs-HtbIyhqO3Cv-ELfpmlq7jzij4dBc,41759
|
57
60
|
nkululeko/data/dataset_csv.py,sha256=AIbtB6pGk5BSQGIgfokZ7tEGFjmuOq5w2XumRSimVWs,4833
|
@@ -68,10 +71,10 @@ nkululeko/feat_extract/feats_hubert.py,sha256=F3vrPCkx8EimJjFWYCZ7Yg9uo1G3NjYt4U
|
|
68
71
|
nkululeko/feat_extract/feats_import.py,sha256=cPi4XRuRs71npB8YGXr7rYOvkeTU_oZEl3GrGncdiqY,2222
|
69
72
|
nkululeko/feat_extract/feats_mld.py,sha256=5aRoYiGDm5ApoFntxAMQYPjEelXHHRBHZcAJR9dxaeI,1945
|
70
73
|
nkululeko/feat_extract/feats_mos.py,sha256=vkH1FdXtduoU0-yjBtVccC2b_p_eyH8laRnwlL7QTVM,4136
|
71
|
-
nkululeko/feat_extract/feats_opensmile copy.py,sha256=BLj5sUaBPz7vLPfNlt9LdQurSypmViqgSpPK-6aXGhQ,4029
|
72
74
|
nkululeko/feat_extract/feats_opensmile.py,sha256=HwbGs0EaPxZ7DznQZFem8RYgyQWz02oya77uVY7KhZE,9203
|
73
75
|
nkululeko/feat_extract/feats_oxbow.py,sha256=TRoEJx5EKZiqoPoPRibHc0vkBMoZcKlGoGNq4NbyHZw,4895
|
74
|
-
nkululeko/feat_extract/feats_praat.py,sha256=
|
76
|
+
nkululeko/feat_extract/feats_praat.py,sha256=3j1xySKqW74USjk8DweWAajHeTcuszKCFY1htQhe1cY,3070
|
77
|
+
nkululeko/feat_extract/feats_praat_core.py,sha256=Q0OVuo5h38a860yflzRtUpy0J0w7WCg0aBLrDhIskFc,28524
|
75
78
|
nkululeko/feat_extract/feats_snr.py,sha256=Zxwo78HLleNsziYLOj34RQUnp9I7r1yMXqjYipDOjZw,2761
|
76
79
|
nkululeko/feat_extract/feats_spectra.py,sha256=6WhFUpB0WTutg7OFMlAw9lSwVU5OBYCDcPRxaiH-Qn8,3621
|
77
80
|
nkululeko/feat_extract/feats_spkrec.py,sha256=o_6bdU4lIkj64S5Kdjf1iyuo1VASeYxE4XdxV94a8gE,4732
|
@@ -81,8 +84,10 @@ nkululeko/feat_extract/feats_wav2vec2.py,sha256=q1QzMD3KbhF2SOmxdwI7CiViRmhlFRyg
|
|
81
84
|
nkululeko/feat_extract/feats_wavlm.py,sha256=O9cfc39VF5aPJRRATKb37pHT4W11i2cu5O1mY9LOjIA,4755
|
82
85
|
nkululeko/feat_extract/feats_whisper.py,sha256=n3ESZtva7wshs8E8diBlQYa9xCH_P0UY1DncSrxz-FY,4508
|
83
86
|
nkululeko/feat_extract/featureset.py,sha256=clcBv9rzBRW-bfw7JC_FYTjU5uUS-c0UE1XtQLYYRiE,1615
|
84
|
-
nkululeko/feat_extract/feinberg_praat.py,sha256=mMin5V-Kmx24oYJT_miNFN4t-tEVEF3Cd0969xiVV0E,28573
|
85
87
|
nkululeko/feat_extract/transformer_feature_extractor.py,sha256=LaXuW-AJZ931ttLis0J5h9N3RtiiE51BnkxJR-bubfY,5837
|
88
|
+
nkululeko/feat_extract/tests/__init__.py,sha256=pzjkYs1PNo7107jIXKa_xwdBR2SKxzkg53a9W3bvbpw,32
|
89
|
+
nkululeko/feat_extract/tests/test_feats_opensmile.py,sha256=eYjGBsH6UkuRleKzGZHNv2cXRZz2xPCw0dkTfXw5S9s,5761
|
90
|
+
nkululeko/feat_extract/tests/test_feats_praat_core.py,sha256=ntbpIrehr4D-lOvaE0hNCe-og5sN4syBGBUTuNGZpDo,20916
|
86
91
|
nkululeko/losses/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
87
92
|
nkululeko/losses/loss_ccc.py,sha256=NOK0y0fxKUnU161B5geap6Fmn8QzoPl2MqtPiV8IuJE,976
|
88
93
|
nkululeko/losses/loss_softf1loss.py,sha256=5gW-PuiqeAZcRgfwjueIOQtMokOjZWgQnVIv59HKTCo,1309
|
@@ -104,26 +109,31 @@ nkululeko/models/model_tuned.py,sha256=74c_pQUtpx_x8bM3r5ufuqhaaQxfy6KRUqirdzSac
|
|
104
109
|
nkululeko/models/model_xgb.py,sha256=_VxFFP1QcoyxrwvJSrzdIwwDt85IulUWvg1BxXBgN1Y,6616
|
105
110
|
nkululeko/models/model_xgr.py,sha256=H01FJCRgmX2unvambMs5TTCS9sI6VDB9ip9G6rVGt2c,419
|
106
111
|
nkululeko/models/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
112
|
+
nkululeko/models/tests/test_model_knn.py,sha256=hFCJ0C0taQO-fwA7j8HcFrwCSluSb6Vg4NCQQ_zL4bc,1793
|
113
|
+
nkululeko/models/tests/test_model_mlp.py,sha256=XVvniKAtroxLRKyYGW-ew1mHuRo3_cWk4nGnXQ5aDEk,4977
|
107
114
|
nkululeko/models/tests/test_model_svm.py,sha256=spDlZmeBKBdK4EFBpOgEkaAfGeGH9kau6CqSWOY6Uag,1856
|
115
|
+
nkululeko/models/tests/test_model_xgb.py,sha256=-Rz5YTeqUJ4Kwdh5ny31c3zxsUJXTypR4L3ItoOU7yU,1036
|
108
116
|
nkululeko/reporting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
109
117
|
nkululeko/reporting/defines.py,sha256=0vh-Tlx4fAPpk1o6mP_4x3EkIoqzYMr38IZnj-JM5z4,641
|
110
118
|
nkululeko/reporting/latex_writer.py,sha256=NGwSIfd4nfslDkNUOSZSdqY_VDLA8634thyhe-vj1bY,1824
|
111
119
|
nkululeko/reporting/report.py,sha256=B5eoIKMz46VKDBsi7M9u_iegzAD-E3eGCmolzSFjZ3c,1118
|
112
120
|
nkululeko/reporting/report_item.py,sha256=drkknsyFhGviaPJNmPQtCXJmRhTSSfjNcJt0Bls6JAA,533
|
113
|
-
nkululeko/reporting/reporter.py,sha256=
|
121
|
+
nkululeko/reporting/reporter.py,sha256=e-piNtnv0QUWKs9Ha_d4CzgqJxPBG9XBm3Ru8y0ot-U,20896
|
114
122
|
nkululeko/reporting/result.py,sha256=G63a2tHCwHhM6NBJgYzsWKWJm4Yu3r4hsCHA2Km7eHU,1073
|
115
123
|
nkululeko/segmenting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
116
124
|
nkululeko/segmenting/seg_inaspeechsegmenter.py,sha256=b3t0zdpJYofKWMyKRMtMMX91xeR-k8d5pbnNaQHcsOE,1902
|
117
125
|
nkululeko/segmenting/seg_pyannote.py,sha256=6IPbgjnGOz9juzEKDTZN3PSipX4t6Mz-DILAx3rp5do,4216
|
118
126
|
nkululeko/segmenting/seg_silero.py,sha256=ulodnvtRq5MLHDxy_RmAK4tJg6h1d-mPq-uCPFkGVKg,4258
|
127
|
+
nkululeko/tests/__init__.py,sha256=XzD6C-ZuewsccUwx7KzEUtUxJrRx2d7sPFViscjf1O0,30
|
128
|
+
nkululeko/tests/test_balancing.py,sha256=21110R77iTcSWKiSTxYDkJ26lxPFTlZf_ZwVjeiSh4w,10164
|
119
129
|
nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
120
130
|
nkululeko/utils/files.py,sha256=SrrYaU7AB80MZHiV1jcB0h_zigvYLYgSVNTXV4ao38g,4593
|
121
131
|
nkululeko/utils/stats.py,sha256=3Fyx8q8BSKYmiufT6OkRug9RATWmGrr9BaX_y8jziWo,3074
|
122
132
|
nkululeko/utils/unzip.py,sha256=G68f5120TjwACZC3bQcneMniddnwubPbBdMc2L5KBOo,1206
|
123
|
-
nkululeko/utils/util.py,sha256=
|
124
|
-
nkululeko-0.95.
|
125
|
-
nkululeko-0.95.
|
126
|
-
nkululeko-0.95.
|
127
|
-
nkululeko-0.95.
|
128
|
-
nkululeko-0.95.
|
129
|
-
nkululeko-0.95.
|
133
|
+
nkululeko/utils/util.py,sha256=o62TZRcxO1VflINai6ojEzSmcbXIFInNLGogSbqJgiA,18561
|
134
|
+
nkululeko-0.95.1.dist-info/licenses/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
|
135
|
+
nkululeko-0.95.1.dist-info/METADATA,sha256=KhJ1JPenNsZGUIhdeYGvNKrM1H-ioqONAh06LpxdnMQ,2874
|
136
|
+
nkululeko-0.95.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
137
|
+
nkululeko-0.95.1.dist-info/entry_points.txt,sha256=lNTkFEdh6Kjo5o95ZAWf_0Lq-4ztGoAoMVSDuPtuyS0,442
|
138
|
+
nkululeko-0.95.1.dist-info/top_level.txt,sha256=bf1k1YKkqcXemNX_cUgoyKqQ3_GVErPqAY-53J36jkM,19
|
139
|
+
nkululeko-0.95.1.dist-info/RECORD,,
|