nkululeko 0.94.2__py3-none-any.whl → 0.94.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nkululeko/augmenting/resampler.py +22 -14
- nkululeko/constants.py +1 -1
- nkululeko/models/model.py +42 -4
- nkululeko/models/model_xgb.py +1 -1
- nkululeko/nkululeko.py +9 -0
- nkululeko/runmanager.py +1 -1
- {nkululeko-0.94.2.dist-info → nkululeko-0.94.3.dist-info}/METADATA +1 -1
- {nkululeko-0.94.2.dist-info → nkululeko-0.94.3.dist-info}/RECORD +12 -12
- {nkululeko-0.94.2.dist-info → nkululeko-0.94.3.dist-info}/WHEEL +0 -0
- {nkululeko-0.94.2.dist-info → nkululeko-0.94.3.dist-info}/entry_points.txt +0 -0
- {nkululeko-0.94.2.dist-info → nkululeko-0.94.3.dist-info}/licenses/LICENSE +0 -0
- {nkululeko-0.94.2.dist-info → nkululeko-0.94.3.dist-info}/top_level.txt +0 -0
@@ -17,7 +17,7 @@ class Resampler:
|
|
17
17
|
def __init__(self, df, replace, not_testing=True):
|
18
18
|
self.SAMPLING_RATE = 16000
|
19
19
|
self.df = df
|
20
|
-
self.util = Util("resampler", has_config=not_testing)
|
20
|
+
self.util = Util("resampler", has_config=not not_testing)
|
21
21
|
self.util.warn(f"all files might be resampled to {self.SAMPLING_RATE}")
|
22
22
|
self.not_testing = not_testing
|
23
23
|
self.replace = (
|
@@ -30,7 +30,7 @@ class Resampler:
|
|
30
30
|
files = self.df.index.get_level_values(0).values
|
31
31
|
# replace = eval(self.util.config_val("RESAMPLE", "replace", "False"))
|
32
32
|
replace = self.replace
|
33
|
-
if self.not_testing:
|
33
|
+
if not self.not_testing:
|
34
34
|
store = self.util.get_path("store")
|
35
35
|
else:
|
36
36
|
store = "./"
|
@@ -67,17 +67,25 @@ class Resampler:
|
|
67
67
|
self.df = self.df.set_index(
|
68
68
|
self.df.index.set_levels(new_files, level="file")
|
69
69
|
)
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
70
|
+
if not self.not_testing:
|
71
|
+
target_file = self.util.config_val("RESAMPLE", "target", "resampled.csv")
|
72
|
+
# remove encoded labels
|
73
|
+
target = self.util.config_val("DATA", "target", "emotion")
|
74
|
+
if "class_label" in self.df.columns:
|
75
|
+
self.df = self.df.drop(columns=[target])
|
76
|
+
self.df = self.df.rename(columns={"class_label": target})
|
77
|
+
# save file
|
78
|
+
self.df.to_csv(target_file)
|
79
|
+
self.util.debug(
|
80
|
+
"saved resampled list of files to" f" {os.path.abspath(target_file)}"
|
81
|
+
)
|
82
|
+
else:
|
83
|
+
# When running from command line, save to simple resampled.csv
|
84
|
+
target_file = "resampled.csv"
|
85
|
+
self.df.to_csv(target_file)
|
86
|
+
self.util.debug(
|
87
|
+
f"saved resampled list of files to {os.path.abspath(target_file)}"
|
88
|
+
)
|
81
89
|
self.util.debug(f"resampled {succes} files, {error} errors")
|
82
90
|
|
83
91
|
|
@@ -91,7 +99,7 @@ def main():
|
|
91
99
|
df_sample.index, allow_nat=False
|
92
100
|
)
|
93
101
|
df_sample.head(10)
|
94
|
-
resampler = Resampler(df_sample, not_testing=False)
|
102
|
+
resampler = Resampler(df_sample, False, not_testing=False)
|
95
103
|
resampler.resample()
|
96
104
|
shutil.copyfile(testfile, "tmp.resample_result.wav")
|
97
105
|
shutil.copyfile("tmp.wav", testfile)
|
nkululeko/constants.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
|
-
VERSION="0.94.2"
|
1
|
+
VERSION="0.94.3"
|
2
2
|
SAMPLING_RATE = 16000
|
nkululeko/models/model.py
CHANGED
@@ -3,11 +3,15 @@ import ast
|
|
3
3
|
import pickle
|
4
4
|
import random
|
5
5
|
|
6
|
+
from joblib import parallel_backend
|
6
7
|
import numpy as np
|
7
8
|
import pandas as pd
|
9
|
+
from sklearn.model_selection import GridSearchCV
|
10
|
+
from sklearn.model_selection import LeaveOneGroupOut
|
11
|
+
from sklearn.model_selection import StratifiedKFold
|
8
12
|
import sklearn.utils
|
9
|
-
|
10
|
-
|
13
|
+
|
14
|
+
import audeer
|
11
15
|
|
12
16
|
import nkululeko.glob_conf as glob_conf
|
13
17
|
from nkululeko.reporting.reporter import Reporter
|
@@ -301,8 +305,15 @@ class Model:
|
|
301
305
|
def get_type(self):
|
302
306
|
return "generic"
|
303
307
|
|
304
|
-
def predict_sample(self, features):
|
305
|
-
"""Predict
|
308
|
+
def predict_sample(self, features: np.ndarray) -> dict | float:
|
309
|
+
"""Predict a single sample using the trained model.
|
310
|
+
|
311
|
+
Args:
|
312
|
+
features (np.ndarray): The feature vector of the sample to predict.
|
313
|
+
|
314
|
+
Returns:
|
315
|
+
dict: A dictionary containing the predicted class probabilities or value.
|
316
|
+
"""
|
306
317
|
prediction = {}
|
307
318
|
if self.util.exp_is_classification():
|
308
319
|
# get the class probabilities
|
@@ -336,3 +347,30 @@ class Model:
|
|
336
347
|
self.set_id(run, epoch)
|
337
348
|
with open(path, "rb") as handle:
|
338
349
|
self.clf = pickle.load(handle)
|
350
|
+
|
351
|
+
# next function exports the model to onnx
|
352
|
+
def export_onnx(self, onnx_path, input_shape=None):
|
353
|
+
"""Export the trained sklearn model to ONNX format.
|
354
|
+
|
355
|
+
Args:
|
356
|
+
onnx_path (str): Path to save the ONNX model.
|
357
|
+
input_shape (tuple, optional): Shape of the input features. If None, inferred from feats_train.
|
358
|
+
"""
|
359
|
+
import skl2onnx
|
360
|
+
from skl2onnx import convert_sklearn
|
361
|
+
from skl2onnx.common.data_types import FloatTensorType
|
362
|
+
|
363
|
+
if not hasattr(self, "clf"):
|
364
|
+
self.util.error("No trained model found to export.")
|
365
|
+
return
|
366
|
+
|
367
|
+
if input_shape is None:
|
368
|
+
n_features = self.feats_train.shape[1]
|
369
|
+
initial_type = [("input", FloatTensorType([None, n_features]))]
|
370
|
+
else:
|
371
|
+
initial_type = [("input", FloatTensorType(input_shape))]
|
372
|
+
|
373
|
+
onnx_model = convert_sklearn(self.clf, initial_types=initial_type)
|
374
|
+
with open(audeer.path(onnx_path), "wb") as f:
|
375
|
+
f.write(onnx_model.SerializeToString())
|
376
|
+
self.util.debug(f"Model exported to ONNX at {onnx_path}")
|
nkululeko/models/model_xgb.py
CHANGED
nkululeko/nkululeko.py
CHANGED
@@ -54,6 +54,15 @@ def doit(config_file):
|
|
54
54
|
reports, last_epochs = expr.run()
|
55
55
|
result = expr.get_best_report(reports).result.test
|
56
56
|
expr.store_report()
|
57
|
+
|
58
|
+
# check if we want to export the model
|
59
|
+
o_path = util.config_val("EXP", "export_onnx", "False")
|
60
|
+
if eval(o_path):
|
61
|
+
print(f"Exporting ONNX model to {o_path}")
|
62
|
+
o_path = o_path.replace('"', '')
|
63
|
+
expr.runmgr.get_best_model().export_onnx(str(o_path))
|
64
|
+
|
65
|
+
|
57
66
|
print("DONE")
|
58
67
|
return result, int(np.asarray(last_epochs).min())
|
59
68
|
|
nkululeko/runmanager.py
CHANGED
@@ -181,7 +181,7 @@ class Runmanager:
|
|
181
181
|
"""
|
182
182
|
# self.load_model(report)
|
183
183
|
# report = self.model.predict()
|
184
|
-
self.util.debug(f"plotting conf matrix
|
184
|
+
self.util.debug(f"plotting conf matrix as {plot_name}")
|
185
185
|
report.plot_confmatrix(plot_name, epoch=report.epoch)
|
186
186
|
report.print_results(report.epoch, file_name=plot_name)
|
187
187
|
report.print_probabilities(file_name=plot_name)
|
@@ -3,7 +3,7 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
|
|
3
3
|
nkululeko/aug_train.py,sha256=wpiHCJ7zsW38kumg3ypwXZe2HQrhUblAnv7P2QeJnAc,3525
|
4
4
|
nkululeko/augment.py,sha256=3RzaxB3gRxovgJVjHXi0glprW01J7RaHhUkqotW2T3U,2955
|
5
5
|
nkululeko/cacheddataset.py,sha256=XFpWZmbJRg0pvhnIgYf0TkclxllD-Fctu-Ol0PF_00c,969
|
6
|
-
nkululeko/constants.py,sha256=
|
6
|
+
nkululeko/constants.py,sha256=KCqkmtwj--gcAdaRwj_Zb44_ewVNp06Hfp8-YGDG8iI,39
|
7
7
|
nkululeko/demo-ft.py,sha256=iD9Pzp9QjyAv31q1cDZ75vPez7Ve8A4Cfukv5yfZdrQ,770
|
8
8
|
nkululeko/demo.py,sha256=tu7Al2l5MCLVegkDC-NE2wcuc_YE7NRbgOlPW3yhGEs,4940
|
9
9
|
nkululeko/demo_feats.py,sha256=BvZjeNFTlERIRlq34OHM4Z96jdDQAhB01BGQAUcX9dM,2026
|
@@ -20,11 +20,11 @@ nkululeko/glob_conf.py,sha256=KL9YJQTHvTztxo1vr25qRRgaPnx4NTg0XrdbovKGMmw,525
|
|
20
20
|
nkululeko/modelrunner.py,sha256=NpDgXfKkn8dOrQzhUiEfGI56Qrb1sOtWTD31II4Zgbk,11550
|
21
21
|
nkululeko/multidb.py,sha256=sO6OwJn8sn1-C-ig3thsIL8QMWHdV9SnJhDodKjeKrI,6876
|
22
22
|
nkululeko/nkuluflag.py,sha256=PGWSmZz-PiiHLgcZJAoGOI_Y-sZDVI1ksB8p5r7riWM,3725
|
23
|
-
nkululeko/nkululeko.py,sha256=
|
23
|
+
nkululeko/nkululeko.py,sha256=FaLimlbx47rJgWgDEd0ZROAiXy2cOypliVdqJn-Bvws,2257
|
24
24
|
nkululeko/plots.py,sha256=i9VIkviBWLgncfnyK44TUMzg2Xa0_UhfL0LnMF1vHTw,27022
|
25
25
|
nkululeko/predict.py,sha256=MLnHEyFmSiHLLs-HDczag8Vu3zKF5T1rXLKdZZJ6py8,2083
|
26
26
|
nkululeko/resample.py,sha256=rn3-M1A-iwVGibfQNGyeYNa7briD24lIN9Szq_1uTJo,5194
|
27
|
-
nkululeko/runmanager.py,sha256
|
27
|
+
nkululeko/runmanager.py,sha256=YtGQP0UyyQTKkilncB1XYM-T8oatzGcZEOcj5SorjJw,8902
|
28
28
|
nkululeko/scaler.py,sha256=a4lKwWT436TV4VEvqtP1uQ58Yz67XVHr1HjO5gp3xLI,5109
|
29
29
|
nkululeko/segment.py,sha256=7UrJEwdLmh9wDL5iBwpdJyJm9dwSxidHrHt-_D2qtxw,4949
|
30
30
|
nkululeko/syllable_nuclei.py,sha256=5w_naKxNxz66a_qLkraemi2fggM-gWesiiBPS47iFcE,9931
|
@@ -35,7 +35,7 @@ nkululeko/augmenting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hS
|
|
35
35
|
nkululeko/augmenting/augmenter.py,sha256=TUUznEz0pe9DSMC9r7LoBckuvsJTprvypeV5-8zLn20,2846
|
36
36
|
nkululeko/augmenting/randomsplicer.py,sha256=TQTy4RBt6XbWiuUu5Ic913DMvmwTUwEufldBJjo7i1s,2801
|
37
37
|
nkululeko/augmenting/randomsplicing.py,sha256=GXCpCDdOsOyWACDJ3ujmFZBVe6ISvkoQLefBNPgxxow,1750
|
38
|
-
nkululeko/augmenting/resampler.py,sha256=
|
38
|
+
nkululeko/augmenting/resampler.py,sha256=j2yuB9h9UwGQHqwF8CZPSGqAfOiyQV3979WQjU2toVM,3962
|
39
39
|
nkululeko/autopredict/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
40
40
|
nkululeko/autopredict/ap_age.py,sha256=yzd8sF6gi0hnqNawyLBCIkt-pKgl9gYPlZHsrLGfz0U,1098
|
41
41
|
nkululeko/autopredict/ap_arousal.py,sha256=lpv3jTSVEVCcR226JevNM6S7e0_uMZXHb_8Wpup1yj8,1027
|
@@ -84,7 +84,7 @@ nkululeko/losses/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
|
|
84
84
|
nkululeko/losses/loss_ccc.py,sha256=NOK0y0fxKUnU161B5geap6Fmn8QzoPl2MqtPiV8IuJE,976
|
85
85
|
nkululeko/losses/loss_softf1loss.py,sha256=5gW-PuiqeAZcRgfwjueIOQtMokOjZWgQnVIv59HKTCo,1309
|
86
86
|
nkululeko/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
87
|
-
nkululeko/models/model.py,sha256=
|
87
|
+
nkululeko/models/model.py,sha256=0O6H-kME1yVHU-EKu5iOZVBB7fFNg3lfagvGgMrldxM,14426
|
88
88
|
nkululeko/models/model_bayes.py,sha256=tQUXEsXoS6WnfapQjP78S_gxNBssTOqE78A2iG8SfLU,407
|
89
89
|
nkululeko/models/model_cnn.py,sha256=TKj43865epsiK7a0COyfBDaFHKOYgWgnPpMVCPWUhCM,10497
|
90
90
|
nkululeko/models/model_gmm.py,sha256=mhHFNtTzHuJvqYSA0h5YhvjA--KhnN6MTU_S0G3-d1c,1332
|
@@ -98,7 +98,7 @@ nkululeko/models/model_svr.py,sha256=FEwYRdgqwgGhZdkpRnT7Ef12lklWi6GZL28PyV99xWs
|
|
98
98
|
nkululeko/models/model_tree.py,sha256=6L3PD3aIiiQz1RPWS6z3Edx4f0gnR7AOfBKOJzf0BNU,433
|
99
99
|
nkululeko/models/model_tree_reg.py,sha256=IMaQpNImoRqP8Biw1CsJevxpV_PVpKblsKtYlMW5d_U,429
|
100
100
|
nkululeko/models/model_tuned.py,sha256=VuRyNqw3XTpQ2eHsWOJN8X-V98AN8Wqiq7UgwT5BQRU,23763
|
101
|
-
nkululeko/models/model_xgb.py,sha256=
|
101
|
+
nkululeko/models/model_xgb.py,sha256=zfZM3lqH5uttVB18b1MRIhP9CCeCuIh1ycgOuFMcqUM,449
|
102
102
|
nkululeko/models/model_xgr.py,sha256=H01FJCRgmX2unvambMs5TTCS9sI6VDB9ip9G6rVGt2c,419
|
103
103
|
nkululeko/models/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
104
104
|
nkululeko/models/tests/test_model_svm.py,sha256=spDlZmeBKBdK4EFBpOgEkaAfGeGH9kau6CqSWOY6Uag,1856
|
@@ -118,9 +118,9 @@ nkululeko/utils/files.py,sha256=SrrYaU7AB80MZHiV1jcB0h_zigvYLYgSVNTXV4ao38g,4593
|
|
118
118
|
nkululeko/utils/stats.py,sha256=3Fyx8q8BSKYmiufT6OkRug9RATWmGrr9BaX_y8jziWo,3074
|
119
119
|
nkululeko/utils/unzip.py,sha256=G68f5120TjwACZC3bQcneMniddnwubPbBdMc2L5KBOo,1206
|
120
120
|
nkululeko/utils/util.py,sha256=6NDKhOx0fV5fKyhSoY4hem96p7OuPcmhCDQR9EzkQhw,17829
|
121
|
-
nkululeko-0.94.
|
122
|
-
nkululeko-0.94.
|
123
|
-
nkululeko-0.94.
|
124
|
-
nkululeko-0.94.
|
125
|
-
nkululeko-0.94.
|
126
|
-
nkululeko-0.94.
|
121
|
+
nkululeko-0.94.3.dist-info/licenses/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
|
122
|
+
nkululeko-0.94.3.dist-info/METADATA,sha256=QeZ9ZMTqwgdDvwRTCvgFO7X55_J84AWZh7jVf9uV-6M,2874
|
123
|
+
nkululeko-0.94.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
124
|
+
nkululeko-0.94.3.dist-info/entry_points.txt,sha256=lNTkFEdh6Kjo5o95ZAWf_0Lq-4ztGoAoMVSDuPtuyS0,442
|
125
|
+
nkululeko-0.94.3.dist-info/top_level.txt,sha256=bf1k1YKkqcXemNX_cUgoyKqQ3_GVErPqAY-53J36jkM,19
|
126
|
+
nkululeko-0.94.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|