nkululeko 0.94.2__py3-none-any.whl → 0.94.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,7 +17,7 @@ class Resampler:
17
17
  def __init__(self, df, replace, not_testing=True):
18
18
  self.SAMPLING_RATE = 16000
19
19
  self.df = df
20
- self.util = Util("resampler", has_config=not_testing)
20
+ self.util = Util("resampler", has_config=not not_testing)
21
21
  self.util.warn(f"all files might be resampled to {self.SAMPLING_RATE}")
22
22
  self.not_testing = not_testing
23
23
  self.replace = (
@@ -30,7 +30,7 @@ class Resampler:
30
30
  files = self.df.index.get_level_values(0).values
31
31
  # replace = eval(self.util.config_val("RESAMPLE", "replace", "False"))
32
32
  replace = self.replace
33
- if self.not_testing:
33
+ if not self.not_testing:
34
34
  store = self.util.get_path("store")
35
35
  else:
36
36
  store = "./"
@@ -67,17 +67,25 @@ class Resampler:
67
67
  self.df = self.df.set_index(
68
68
  self.df.index.set_levels(new_files, level="file")
69
69
  )
70
- target_file = self.util.config_val("RESAMPLE", "target", "resampled.csv")
71
- # remove encoded labels
72
- target = self.util.config_val("DATA", "target", "emotion")
73
- if "class_label" in self.df.columns:
74
- self.df = self.df.drop(columns=[target])
75
- self.df = self.df.rename(columns={"class_label": target})
76
- # save file
77
- self.df.to_csv(target_file)
78
- self.util.debug(
79
- "saved resampled list of files to" f" {os.path.abspath(target_file)}"
80
- )
70
+ if not self.not_testing:
71
+ target_file = self.util.config_val("RESAMPLE", "target", "resampled.csv")
72
+ # remove encoded labels
73
+ target = self.util.config_val("DATA", "target", "emotion")
74
+ if "class_label" in self.df.columns:
75
+ self.df = self.df.drop(columns=[target])
76
+ self.df = self.df.rename(columns={"class_label": target})
77
+ # save file
78
+ self.df.to_csv(target_file)
79
+ self.util.debug(
80
+ "saved resampled list of files to" f" {os.path.abspath(target_file)}"
81
+ )
82
+ else:
83
+ # When running from command line, save to simple resampled.csv
84
+ target_file = "resampled.csv"
85
+ self.df.to_csv(target_file)
86
+ self.util.debug(
87
+ f"saved resampled list of files to {os.path.abspath(target_file)}"
88
+ )
81
89
  self.util.debug(f"resampled {succes} files, {error} errors")
82
90
 
83
91
 
@@ -91,7 +99,7 @@ def main():
91
99
  df_sample.index, allow_nat=False
92
100
  )
93
101
  df_sample.head(10)
94
- resampler = Resampler(df_sample, not_testing=False)
102
+ resampler = Resampler(df_sample, False, not_testing=False)
95
103
  resampler.resample()
96
104
  shutil.copyfile(testfile, "tmp.resample_result.wav")
97
105
  shutil.copyfile("tmp.wav", testfile)
nkululeko/constants.py CHANGED
@@ -1,2 +1,2 @@
1
- VERSION="0.94.2"
1
+ VERSION="0.94.3"
2
2
  SAMPLING_RATE = 16000
nkululeko/models/model.py CHANGED
@@ -3,11 +3,15 @@ import ast
3
3
  import pickle
4
4
  import random
5
5
 
6
+ from joblib import parallel_backend
6
7
  import numpy as np
7
8
  import pandas as pd
9
+ from sklearn.model_selection import GridSearchCV
10
+ from sklearn.model_selection import LeaveOneGroupOut
11
+ from sklearn.model_selection import StratifiedKFold
8
12
  import sklearn.utils
9
- from joblib import parallel_backend
10
- from sklearn.model_selection import GridSearchCV, LeaveOneGroupOut, StratifiedKFold
13
+
14
+ import audeer
11
15
 
12
16
  import nkululeko.glob_conf as glob_conf
13
17
  from nkululeko.reporting.reporter import Reporter
@@ -301,8 +305,15 @@ class Model:
301
305
  def get_type(self):
302
306
  return "generic"
303
307
 
304
- def predict_sample(self, features):
305
- """Predict one sample"""
308
+ def predict_sample(self, features: np.ndarray) -> dict | float:
309
+ """Predict a single sample using the trained model.
310
+
311
+ Args:
312
+ features (np.ndarray): The feature vector of the sample to predict.
313
+
314
+ Returns:
315
+ dict: A dictionary containing the predicted class probabilities or value.
316
+ """
306
317
  prediction = {}
307
318
  if self.util.exp_is_classification():
308
319
  # get the class probabilities
@@ -336,3 +347,30 @@ class Model:
336
347
  self.set_id(run, epoch)
337
348
  with open(path, "rb") as handle:
338
349
  self.clf = pickle.load(handle)
350
+
351
+ # next function exports the model to onnx
352
+ def export_onnx(self, onnx_path, input_shape=None):
353
+ """Export the trained sklearn model to ONNX format.
354
+
355
+ Args:
356
+ onnx_path (str): Path to save the ONNX model.
357
+ input_shape (tuple, optional): Shape of the input features. If None, inferred from feats_train.
358
+ """
359
+ import skl2onnx
360
+ from skl2onnx import convert_sklearn
361
+ from skl2onnx.common.data_types import FloatTensorType
362
+
363
+ if not hasattr(self, "clf"):
364
+ self.util.error("No trained model found to export.")
365
+ return
366
+
367
+ if input_shape is None:
368
+ n_features = self.feats_train.shape[1]
369
+ initial_type = [("input", FloatTensorType([None, n_features]))]
370
+ else:
371
+ initial_type = [("input", FloatTensorType(input_shape))]
372
+
373
+ onnx_model = convert_sklearn(self.clf, initial_types=initial_type)
374
+ with open(audeer.path(onnx_path), "wb") as f:
375
+ f.write(onnx_model.SerializeToString())
376
+ self.util.debug(f"Model exported to ONNX at {onnx_path}")
@@ -1,4 +1,4 @@
1
- # xgbmodel.py
1
+ # model_xgb.py
2
2
 
3
3
  from xgboost import XGBClassifier
4
4
 
nkululeko/nkululeko.py CHANGED
@@ -54,6 +54,15 @@ def doit(config_file):
54
54
  reports, last_epochs = expr.run()
55
55
  result = expr.get_best_report(reports).result.test
56
56
  expr.store_report()
57
+
58
+ # check if we want to export the model
59
+ o_path = util.config_val("EXP", "export_onnx", "False")
60
+ if eval(o_path):
61
+ print(f"Exporting ONNX model to {o_path}")
62
+ o_path = o_path.replace('"', '')
63
+ expr.runmgr.get_best_model().export_onnx(str(o_path))
64
+
65
+
57
66
  print("DONE")
58
67
  return result, int(np.asarray(last_epochs).min())
59
68
 
nkululeko/runmanager.py CHANGED
@@ -181,7 +181,7 @@ class Runmanager:
181
181
  """
182
182
  # self.load_model(report)
183
183
  # report = self.model.predict()
184
- self.util.debug(f"plotting conf matrix to {plot_name}")
184
+ self.util.debug(f"plotting conf matrix as {plot_name}")
185
185
  report.plot_confmatrix(plot_name, epoch=report.epoch)
186
186
  report.print_results(report.epoch, file_name=plot_name)
187
187
  report.print_probabilities(file_name=plot_name)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nkululeko
3
- Version: 0.94.2
3
+ Version: 0.94.3
4
4
  Summary: Machine learning audio prediction experiments based on templates
5
5
  Home-page: https://github.com/felixbur/nkululeko
6
6
  Author: Felix Burkhardt
@@ -3,7 +3,7 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
3
3
  nkululeko/aug_train.py,sha256=wpiHCJ7zsW38kumg3ypwXZe2HQrhUblAnv7P2QeJnAc,3525
4
4
  nkululeko/augment.py,sha256=3RzaxB3gRxovgJVjHXi0glprW01J7RaHhUkqotW2T3U,2955
5
5
  nkululeko/cacheddataset.py,sha256=XFpWZmbJRg0pvhnIgYf0TkclxllD-Fctu-Ol0PF_00c,969
6
- nkululeko/constants.py,sha256=1uXTovFvHp9PGrpkAyOYxPPZcMM5ojIHW2L7u92ADUo,39
6
+ nkululeko/constants.py,sha256=KCqkmtwj--gcAdaRwj_Zb44_ewVNp06Hfp8-YGDG8iI,39
7
7
  nkululeko/demo-ft.py,sha256=iD9Pzp9QjyAv31q1cDZ75vPez7Ve8A4Cfukv5yfZdrQ,770
8
8
  nkululeko/demo.py,sha256=tu7Al2l5MCLVegkDC-NE2wcuc_YE7NRbgOlPW3yhGEs,4940
9
9
  nkululeko/demo_feats.py,sha256=BvZjeNFTlERIRlq34OHM4Z96jdDQAhB01BGQAUcX9dM,2026
@@ -20,11 +20,11 @@ nkululeko/glob_conf.py,sha256=KL9YJQTHvTztxo1vr25qRRgaPnx4NTg0XrdbovKGMmw,525
20
20
  nkululeko/modelrunner.py,sha256=NpDgXfKkn8dOrQzhUiEfGI56Qrb1sOtWTD31II4Zgbk,11550
21
21
  nkululeko/multidb.py,sha256=sO6OwJn8sn1-C-ig3thsIL8QMWHdV9SnJhDodKjeKrI,6876
22
22
  nkululeko/nkuluflag.py,sha256=PGWSmZz-PiiHLgcZJAoGOI_Y-sZDVI1ksB8p5r7riWM,3725
23
- nkululeko/nkululeko.py,sha256=6ALPMMIz6l0O3IRaP0q4b59ZUxpfzNqLQUqZMf5t3Zo,1976
23
+ nkululeko/nkululeko.py,sha256=FaLimlbx47rJgWgDEd0ZROAiXy2cOypliVdqJn-Bvws,2257
24
24
  nkululeko/plots.py,sha256=i9VIkviBWLgncfnyK44TUMzg2Xa0_UhfL0LnMF1vHTw,27022
25
25
  nkululeko/predict.py,sha256=MLnHEyFmSiHLLs-HDczag8Vu3zKF5T1rXLKdZZJ6py8,2083
26
26
  nkululeko/resample.py,sha256=rn3-M1A-iwVGibfQNGyeYNa7briD24lIN9Szq_1uTJo,5194
27
- nkululeko/runmanager.py,sha256=-QI7pGLVnNAPMAIobcEip9zQNQdO2u0sp0Yd4XH4mBE,8902
27
+ nkululeko/runmanager.py,sha256=YtGQP0UyyQTKkilncB1XYM-T8oatzGcZEOcj5SorjJw,8902
28
28
  nkululeko/scaler.py,sha256=a4lKwWT436TV4VEvqtP1uQ58Yz67XVHr1HjO5gp3xLI,5109
29
29
  nkululeko/segment.py,sha256=7UrJEwdLmh9wDL5iBwpdJyJm9dwSxidHrHt-_D2qtxw,4949
30
30
  nkululeko/syllable_nuclei.py,sha256=5w_naKxNxz66a_qLkraemi2fggM-gWesiiBPS47iFcE,9931
@@ -35,7 +35,7 @@ nkululeko/augmenting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hS
35
35
  nkululeko/augmenting/augmenter.py,sha256=TUUznEz0pe9DSMC9r7LoBckuvsJTprvypeV5-8zLn20,2846
36
36
  nkululeko/augmenting/randomsplicer.py,sha256=TQTy4RBt6XbWiuUu5Ic913DMvmwTUwEufldBJjo7i1s,2801
37
37
  nkululeko/augmenting/randomsplicing.py,sha256=GXCpCDdOsOyWACDJ3ujmFZBVe6ISvkoQLefBNPgxxow,1750
38
- nkululeko/augmenting/resampler.py,sha256=gcjyyTD6QtJK6s_xoOQpsu5adpn0uSJwHxJTHMskfOM,3541
38
+ nkululeko/augmenting/resampler.py,sha256=j2yuB9h9UwGQHqwF8CZPSGqAfOiyQV3979WQjU2toVM,3962
39
39
  nkululeko/autopredict/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
40
40
  nkululeko/autopredict/ap_age.py,sha256=yzd8sF6gi0hnqNawyLBCIkt-pKgl9gYPlZHsrLGfz0U,1098
41
41
  nkululeko/autopredict/ap_arousal.py,sha256=lpv3jTSVEVCcR226JevNM6S7e0_uMZXHb_8Wpup1yj8,1027
@@ -84,7 +84,7 @@ nkululeko/losses/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
84
84
  nkululeko/losses/loss_ccc.py,sha256=NOK0y0fxKUnU161B5geap6Fmn8QzoPl2MqtPiV8IuJE,976
85
85
  nkululeko/losses/loss_softf1loss.py,sha256=5gW-PuiqeAZcRgfwjueIOQtMokOjZWgQnVIv59HKTCo,1309
86
86
  nkululeko/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
87
- nkululeko/models/model.py,sha256=2STBD3jtLKeNSk7arCFJdaV6FL-nuLR1qpsjvZ4W-9A,12975
87
+ nkululeko/models/model.py,sha256=0O6H-kME1yVHU-EKu5iOZVBB7fFNg3lfagvGgMrldxM,14426
88
88
  nkululeko/models/model_bayes.py,sha256=tQUXEsXoS6WnfapQjP78S_gxNBssTOqE78A2iG8SfLU,407
89
89
  nkululeko/models/model_cnn.py,sha256=TKj43865epsiK7a0COyfBDaFHKOYgWgnPpMVCPWUhCM,10497
90
90
  nkululeko/models/model_gmm.py,sha256=mhHFNtTzHuJvqYSA0h5YhvjA--KhnN6MTU_S0G3-d1c,1332
@@ -98,7 +98,7 @@ nkululeko/models/model_svr.py,sha256=FEwYRdgqwgGhZdkpRnT7Ef12lklWi6GZL28PyV99xWs
98
98
  nkululeko/models/model_tree.py,sha256=6L3PD3aIiiQz1RPWS6z3Edx4f0gnR7AOfBKOJzf0BNU,433
99
99
  nkululeko/models/model_tree_reg.py,sha256=IMaQpNImoRqP8Biw1CsJevxpV_PVpKblsKtYlMW5d_U,429
100
100
  nkululeko/models/model_tuned.py,sha256=VuRyNqw3XTpQ2eHsWOJN8X-V98AN8Wqiq7UgwT5BQRU,23763
101
- nkululeko/models/model_xgb.py,sha256=ytBaSHZH8r7VvRYdmrBrQnzRM6V4HyCJ8O-v20J8G_g,448
101
+ nkululeko/models/model_xgb.py,sha256=zfZM3lqH5uttVB18b1MRIhP9CCeCuIh1ycgOuFMcqUM,449
102
102
  nkululeko/models/model_xgr.py,sha256=H01FJCRgmX2unvambMs5TTCS9sI6VDB9ip9G6rVGt2c,419
103
103
  nkululeko/models/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
104
104
  nkululeko/models/tests/test_model_svm.py,sha256=spDlZmeBKBdK4EFBpOgEkaAfGeGH9kau6CqSWOY6Uag,1856
@@ -118,9 +118,9 @@ nkululeko/utils/files.py,sha256=SrrYaU7AB80MZHiV1jcB0h_zigvYLYgSVNTXV4ao38g,4593
118
118
  nkululeko/utils/stats.py,sha256=3Fyx8q8BSKYmiufT6OkRug9RATWmGrr9BaX_y8jziWo,3074
119
119
  nkululeko/utils/unzip.py,sha256=G68f5120TjwACZC3bQcneMniddnwubPbBdMc2L5KBOo,1206
120
120
  nkululeko/utils/util.py,sha256=6NDKhOx0fV5fKyhSoY4hem96p7OuPcmhCDQR9EzkQhw,17829
121
- nkululeko-0.94.2.dist-info/licenses/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
122
- nkululeko-0.94.2.dist-info/METADATA,sha256=YmvStXvJdODIXXgDRQSmzhmJzl_SuX7donWfg66KZBI,2874
123
- nkululeko-0.94.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
124
- nkululeko-0.94.2.dist-info/entry_points.txt,sha256=lNTkFEdh6Kjo5o95ZAWf_0Lq-4ztGoAoMVSDuPtuyS0,442
125
- nkululeko-0.94.2.dist-info/top_level.txt,sha256=bf1k1YKkqcXemNX_cUgoyKqQ3_GVErPqAY-53J36jkM,19
126
- nkululeko-0.94.2.dist-info/RECORD,,
121
+ nkululeko-0.94.3.dist-info/licenses/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
122
+ nkululeko-0.94.3.dist-info/METADATA,sha256=QeZ9ZMTqwgdDvwRTCvgFO7X55_J84AWZh7jVf9uV-6M,2874
123
+ nkululeko-0.94.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
124
+ nkululeko-0.94.3.dist-info/entry_points.txt,sha256=lNTkFEdh6Kjo5o95ZAWf_0Lq-4ztGoAoMVSDuPtuyS0,442
125
+ nkululeko-0.94.3.dist-info/top_level.txt,sha256=bf1k1YKkqcXemNX_cUgoyKqQ3_GVErPqAY-53J36jkM,19
126
+ nkululeko-0.94.3.dist-info/RECORD,,