nkululeko 0.90.2__py3-none-any.whl → 0.90.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nkululeko/augment.py CHANGED
@@ -83,17 +83,13 @@ def doit(config_file):
83
83
  print("DONE")
84
84
 
85
85
 
86
- def main(src_dir):
86
+ def main():
87
87
  parser = argparse.ArgumentParser(description="Call the nkululeko framework.")
88
88
  parser.add_argument("--config", default="exp.ini", help="The base configuration")
89
89
  args = parser.parse_args()
90
- if args.config is not None:
91
- config_file = args.config
92
- else:
93
- config_file = f"{src_dir}/exp.ini"
90
+ config_file = args.config if args.config is not None else "exp.ini"
94
91
  doit(config_file)
95
92
 
96
93
 
97
94
  if __name__ == "__main__":
98
- cwd = os.path.dirname(os.path.abspath(__file__))
99
- main(cwd) # use this if you want to state the config file path on command line
95
+ main()
nkululeko/constants.py CHANGED
@@ -1,2 +1,2 @@
1
- VERSION = "0.90.2"
1
+ VERSION="0.90.4"
2
2
  SAMPLING_RATE = 16000
nkululeko/demo.py CHANGED
@@ -30,7 +30,7 @@ from nkululeko.experiment import Experiment
30
30
  from nkululeko.utils.util import Util
31
31
 
32
32
 
33
- def main(src_dir):
33
+ def main():
34
34
  parser = argparse.ArgumentParser(description="Call the nkululeko DEMO framework.")
35
35
  parser.add_argument("--config", default="exp.ini", help="The base configuration")
36
36
  parser.add_argument(
@@ -142,5 +142,4 @@ def main(src_dir):
142
142
 
143
143
 
144
144
  if __name__ == "__main__":
145
- cwd = os.path.dirname(os.path.abspath(__file__))
146
- main(cwd) # use this if you want to state the config file path on command line
145
+ main()
nkululeko/explore.py CHANGED
@@ -25,33 +25,27 @@ for an `exp.ini` file in the same directory as the script.
25
25
 
26
26
  import argparse
27
27
  import configparser
28
- import os
28
+ from pathlib import Path
29
29
 
30
30
  from nkululeko.constants import VERSION
31
31
  from nkululeko.experiment import Experiment
32
32
  from nkululeko.utils.util import Util
33
33
 
34
34
 
35
- def main(src_dir):
35
+ def main():
36
36
  parser = argparse.ArgumentParser(
37
37
  description="Call the nkululeko EXPLORE framework."
38
38
  )
39
39
  parser.add_argument("--config", default="exp.ini", help="The base configuration")
40
40
  args = parser.parse_args()
41
- if args.config is not None:
42
- config_file = args.config
43
- else:
44
- config_file = f"{src_dir}/exp.ini"
41
+ config_file = args.config if args.config is not None else "exp.ini"
45
42
 
46
- # test if the configuration file exists
47
- if not os.path.isfile(config_file):
43
+ if not Path(config_file).is_file():
48
44
  print(f"ERROR: no such file: {config_file}")
49
45
  exit()
50
46
 
51
- # load one configuration per experiment
52
47
  config = configparser.ConfigParser()
53
48
  config.read(config_file)
54
- # create a new experiment
55
49
  expr = Experiment(config)
56
50
  module = "explore"
57
51
  expr.set_module(module)
@@ -101,5 +95,4 @@ def main(src_dir):
101
95
 
102
96
 
103
97
  if __name__ == "__main__":
104
- cwd = os.path.dirname(os.path.abspath(__file__))
105
- main(cwd) # use this if you want to state the config file path on command line
98
+ main()
nkululeko/export.py CHANGED
@@ -15,24 +15,18 @@ from nkululeko.experiment import Experiment
15
15
  from nkululeko.utils.util import Util
16
16
 
17
17
 
18
- def main(src_dir):
18
+ def main():
19
19
  parser = argparse.ArgumentParser(description="Call the nkululeko framework.")
20
20
  parser.add_argument("--config", default="exp.ini", help="The base configuration")
21
21
  args = parser.parse_args()
22
- if args.config is not None:
23
- config_file = args.config
24
- else:
25
- config_file = f"{src_dir}/exp.ini"
22
+ config_file = args.config if args.config is not None else "exp.ini"
26
23
 
27
- # test if the configuration file exists
28
24
  if not os.path.isfile(config_file):
29
25
  print(f"ERROR: no such file: {config_file}")
30
26
  exit()
31
27
 
32
- # load one configuration per experiment
33
28
  config = configparser.ConfigParser()
34
29
  config.read(config_file)
35
- # create a new experiment
36
30
  expr = Experiment(config)
37
31
  util = Util("export")
38
32
  util.debug(
@@ -122,5 +116,4 @@ def main(src_dir):
122
116
 
123
117
 
124
118
  if __name__ == "__main__":
125
- cwd = os.path.dirname(os.path.abspath(__file__))
126
- main(cwd) # use this if you want to state the config file path on command line
119
+ main()
nkululeko/multidb.py CHANGED
@@ -115,7 +115,11 @@ def main(src_dir):
115
115
  print(repr(results))
116
116
  print(repr(last_epochs))
117
117
  root = os.path.join(config["EXP"]["root"], "")
118
- plot_name = f"{root}/heatmap.png"
118
+ try:
119
+ format = config["PLOT"]["format"]
120
+ plot_name = f"{root}/heatmap.{format}"
121
+ except KeyError:
122
+ plot_name = f"{root}/heatmap.png"
119
123
  plot_heatmap(results, last_epochs, datasets, plot_name, config, datasets)
120
124
 
121
125
 
nkululeko/nkululeko.py CHANGED
@@ -2,7 +2,7 @@
2
2
  # Entry script to do a Nkululeko experiment
3
3
  import argparse
4
4
  import configparser
5
- import os.path
5
+ from pathlib import Path
6
6
 
7
7
  import numpy as np
8
8
 
@@ -13,7 +13,7 @@ from nkululeko.utils.util import Util
13
13
 
14
14
  def doit(config_file):
15
15
  # test if the configuration file exists
16
- if not os.path.isfile(config_file):
16
+ if not Path(config_file).is_file():
17
17
  print(f"ERROR: no such file: {config_file}")
18
18
  exit()
19
19
 
@@ -57,17 +57,18 @@ def doit(config_file):
57
57
  return result, int(np.asarray(last_epochs).min())
58
58
 
59
59
 
60
- def main(src_dir):
60
+ def main():
61
+ cwd = Path(__file__).parent.absolute()
61
62
  parser = argparse.ArgumentParser(description="Call the nkululeko framework.")
63
+ parser.add_argument("--version", action="version", version=f"Nkululeko {VERSION}")
62
64
  parser.add_argument("--config", default="exp.ini", help="The base configuration")
63
65
  args = parser.parse_args()
64
66
  if args.config is not None:
65
67
  config_file = args.config
66
68
  else:
67
- config_file = f"{src_dir}/exp.ini"
69
+ config_file = cwd / "exp.ini"
68
70
  doit(config_file)
69
71
 
70
72
 
71
73
  if __name__ == "__main__":
72
- cwd = os.path.dirname(os.path.abspath(__file__))
73
- main(cwd) # use this if you want to state the config file path on command line
74
+ main() # use this if you want to state the config file path on command line
nkululeko/predict.py CHANGED
@@ -24,26 +24,20 @@ from nkululeko.experiment import Experiment
24
24
  from nkululeko.utils.util import Util
25
25
 
26
26
 
27
- def main(src_dir):
27
+ def main():
28
28
  parser = argparse.ArgumentParser(
29
29
  description="Call the nkululeko PREDICT framework."
30
30
  )
31
31
  parser.add_argument("--config", default="exp.ini", help="The base configuration")
32
32
  args = parser.parse_args()
33
- if args.config is not None:
34
- config_file = args.config
35
- else:
36
- config_file = f"{src_dir}/exp.ini"
33
+ config_file = args.config if args.config is not None else "exp.ini"
37
34
 
38
- # test if the configuration file exists
39
35
  if not os.path.isfile(config_file):
40
36
  print(f"ERROR: no such file: {config_file}")
41
37
  exit()
42
38
 
43
- # load one configuration per experiment
44
39
  config = configparser.ConfigParser()
45
40
  config.read(config_file)
46
- # create a new experiment
47
41
  expr = Experiment(config)
48
42
  module = "predict"
49
43
  expr.set_module(module)
@@ -73,5 +67,4 @@ def main(src_dir):
73
67
 
74
68
 
75
69
  if __name__ == "__main__":
76
- cwd = os.path.dirname(os.path.abspath(__file__))
77
- main(cwd) # use this if you want to state the config file path on command line
70
+ main()
nkululeko/resample.py CHANGED
@@ -1,3 +1,25 @@
1
+ """
2
+ Resample audio files or INI files (train, test, all) to change the sampling rate.
3
+
4
+ This script provides a command-line interface to resample audio files or INI files
5
+ containing train, test, and all data. It supports resampling a single file, a
6
+ directory of files, or all files specified in an INI configuration file.
7
+
8
+ The script uses the `Resampler` class from the `nkululeko.augmenting.resampler`
9
+ module to perform the resampling operation. It can optionally replace the original
10
+ audio files with the resampled versions.
11
+
12
+ The script supports the following command-line arguments:
13
+ - `--config`: The base configuration file (INI format) to use for resampling.
14
+ - `--file`: The input audio file to resample.
15
+ - `--folder`: The input directory containing audio files and subdirectories to resample.
16
+ - `--replace`: Whether to replace the original audio files with the resampled versions.
17
+
18
+ The script also supports loading configuration from an INI file, which can be used
19
+ to specify the sample selection (all, train, or test) and whether to replace the
20
+ original files.
21
+ """
22
+
1
23
  # resample.py
2
24
  # change the sampling rate for audio file or INI file (train, test, all)
3
25
 
@@ -15,7 +37,7 @@ from nkululeko.utils.files import find_files
15
37
  from nkululeko.utils.util import Util
16
38
 
17
39
 
18
- def main(src_dir):
40
+ def main():
19
41
  parser = argparse.ArgumentParser(
20
42
  description="Call the nkululeko RESAMPLE framework."
21
43
  )
@@ -118,5 +140,4 @@ def main(src_dir):
118
140
 
119
141
 
120
142
  if __name__ == "__main__":
121
- cwd = os.path.dirname(os.path.abspath(__file__))
122
- main(cwd)
143
+ main()
nkululeko/segment.py CHANGED
@@ -14,24 +14,18 @@ from nkululeko.reporting.report_item import ReportItem
14
14
  from nkululeko.utils.util import Util
15
15
 
16
16
 
17
- def main(src_dir):
17
+ def main():
18
18
  parser = argparse.ArgumentParser(description="Call the nkululeko framework.")
19
19
  parser.add_argument("--config", default="exp.ini", help="The base configuration")
20
20
  args = parser.parse_args()
21
- if args.config is not None:
22
- config_file = args.config
23
- else:
24
- config_file = f"{src_dir}/exp.ini"
21
+ config_file = args.config if args.config is not None else "exp.ini"
25
22
 
26
- # test if the configuration file exists
27
23
  if not os.path.isfile(config_file):
28
24
  print(f"ERROR: no such file: {config_file}")
29
25
  exit()
30
26
 
31
- # load one configuration per experiment
32
27
  config = configparser.ConfigParser()
33
28
  config.read(config_file)
34
- # create a new experiment
35
29
  expr = Experiment(config)
36
30
  module = "segment"
37
31
  expr.set_module(module)
@@ -153,5 +147,4 @@ def segment_dataframe(df):
153
147
 
154
148
 
155
149
  if __name__ == "__main__":
156
- cwd = os.path.dirname(os.path.abspath(__file__))
157
- main(cwd) # use this if you want to state the config file path on command line
150
+ main() # use this if you want to state the config file path on command line
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nkululeko
3
- Version: 0.90.2
3
+ Version: 0.90.4
4
4
  Summary: Machine learning audio prediction experiments based on templates
5
5
  Home-page: https://github.com/felixbur/nkululeko
6
6
  Author: Felix Burkhardt
@@ -68,7 +68,7 @@ A project to detect speaker characteristics by machine learning experiments with
68
68
 
69
69
  The idea is to have a framework (based on e.g. sklearn and torch) that can be used to rapidly and automatically analyse audio data and explore machine learning models based on that data.
70
70
 
71
- * NEW with nkululek: [Ensemble learning](http://blog.syntheticspeech.de/2024/06/25/nkululeko-ensemble-classifiers-with-late-fusion/)
71
+ * NEW with nkululeko: [Ensemble learning](http://blog.syntheticspeech.de/2024/06/25/nkululeko-ensemble-classifiers-with-late-fusion/)
72
72
  * NEW: [Finetune transformer-models](http://blog.syntheticspeech.de/2024/05/29/nkululeko-how-to-finetune-a-transformer-model/)
73
73
  * The latest features can be seen in [the ini-file](./ini_file.md) options that are used to control Nkululeko
74
74
  * Below is a [Hello World example](#helloworld) that should set you up fastly, also on [Google Colab](https://colab.research.google.com/drive/1GYNBd5cdZQ1QC3Jm58qoeMaJg3UuPhjw?usp=sharing#scrollTo=4G_SjuF9xeQf), and [with Kaggle](https://www.kaggle.com/felixburk/nkululeko-hello-world-example)
@@ -356,6 +356,14 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
356
356
  Changelog
357
357
  =========
358
358
 
359
+ Version 0.90.4
360
+ --------------
361
+ * added plot format for multidb
362
+
363
+ Version 0.90.3
364
+ --------------
365
+ * refactorings and documentations
366
+
359
367
  Version 0.90.2
360
368
  --------------
361
369
  * added probability output to finetuning classification models
@@ -1,33 +1,31 @@
1
1
  nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
2
2
  nkululeko/aug_train.py,sha256=FoMbBrfyOZd4QAw7oIHl3X6-UpsqAKWVDIolCA7qOWs,3196
3
- nkululeko/augment.py,sha256=sIXRg19Uz8dWKgQv2LBGH7jbd2pgcUTh0PIQ_62B0kA,3135
3
+ nkululeko/augment.py,sha256=3RzaxB3gRxovgJVjHXi0glprW01J7RaHhUkqotW2T3U,2955
4
4
  nkululeko/cacheddataset.py,sha256=XFpWZmbJRg0pvhnIgYf0TkclxllD-Fctu-Ol0PF_00c,969
5
- nkululeko/constants.py,sha256=RbyLuq3HuWP1QWBrcWXo-YcwlYf2qDk6H1ihR4_KqbY,41
5
+ nkululeko/constants.py,sha256=jZ8xPXzwC4olxRWBxh7QNAfDpWxH99Bim1eoRIcVwtY,39
6
6
  nkululeko/demo-ft.py,sha256=iD9Pzp9QjyAv31q1cDZ75vPez7Ve8A4Cfukv5yfZdrQ,770
7
- nkululeko/demo.py,sha256=bLuHkeEl5rOfm7ecGHCcWATiPK7-njNbtrGljxzNzFs,5088
7
+ nkululeko/demo.py,sha256=4Yzhg6pCPBYPGJrP7JX2TysVosl_R1llpVDKc2P_gUA,4955
8
8
  nkululeko/demo_feats.py,sha256=BvZjeNFTlERIRlq34OHM4Z96jdDQAhB01BGQAUcX9dM,2026
9
9
  nkululeko/demo_predictor.py,sha256=lDF-xOxRdEAclOmbepAYg-BQXQdGkHfq2n74PTIoop8,4872
10
10
  nkululeko/ensemble.py,sha256=QONr-1VwMr2D0I7wjWxwGjtYzWf4v9DoI3C-fFnar7E,12862
11
11
  nkululeko/experiment.py,sha256=octx5S4Y8-gAD0dXCRb6DFZwsXTYgzk06RBA3LX2SN0,31388
12
- nkululeko/experiment_felix.py,sha256=IBXtyXkQJP7IuFjZ4tCP3SAQ0g_Oqe3Pyzxz8DOeT-A,30134
13
- nkululeko/explore.py,sha256=lrMrbM2WFJDcfaD_uJFbxpK-cGX2ZVy2QRfWMLRiXjw,3941
14
- nkululeko/export.py,sha256=aqHnZPRv3dk69keY8HB5WJrhFl649X1PVbv_GlYmfH8,4634
12
+ nkululeko/explore.py,sha256=Y5lPPychnI-7fyP8zvwVb9P09fvprbUPOofOppuABYQ,3658
13
+ nkululeko/export.py,sha256=U-V4acxtuL6qKt6oAsVcM5TTeWogYUJ3GU-lA6rq6d4,4336
15
14
  nkululeko/feature_extractor.py,sha256=UnspIWz3XrNhKnBBhWZkH2bHvD-sROtrQVqB1JvkUyw,4088
16
15
  nkululeko/file_checker.py,sha256=xJY0Q6w47pnmgJVK5rcAKPYBrCpV7eBT4_3YBzTx-H8,3454
17
16
  nkululeko/filter_data.py,sha256=5AYDtqs_GWGr4V5CbbYQkVVgCD3kq2dpKu8rF3V87NI,7224
18
17
  nkululeko/fixedsegment.py,sha256=Tb92QiuiyMsOO3WRWwuGjZGibS8hbHHCrcWAXGk7g04,2868
19
18
  nkululeko/glob_conf.py,sha256=KL9YJQTHvTztxo1vr25qRRgaPnx4NTg0XrdbovKGMmw,525
20
19
  nkululeko/modelrunner.py,sha256=lJy-xM4QfDDWeL0dLTE_VIb4sYrnd_Z_yJRK3wwohQA,11199
21
- nkululeko/multidb.py,sha256=mDh2Zj4zDbM-wZxib-r8LaiGqfAbh7oihgWBODj76kU,6753
20
+ nkululeko/multidb.py,sha256=sO6OwJn8sn1-C-ig3thsIL8QMWHdV9SnJhDodKjeKrI,6876
22
21
  nkululeko/nkuluflag.py,sha256=PGWSmZz-PiiHLgcZJAoGOI_Y-sZDVI1ksB8p5r7riWM,3725
23
- nkululeko/nkululeko.py,sha256=n4KidI4sN3LwNyZoz-q2bLBjNn8lxYDya35qws55_ys,1968
22
+ nkululeko/nkululeko.py,sha256=M7baIq2nAoi6dEoBL4ATEuqAs5U1fvl_hyqAl5DybAQ,2040
24
23
  nkululeko/plots.py,sha256=p9YyN-xAtdGBKjcA305V0KOagAzG8VG6D_Ceoa9rae4,22964
25
- nkululeko/predict.py,sha256=ObFOxIgQ8JVYZLk2h0VFt8h7lYLMy8fXLUxU6eiePZc,2381
26
- nkululeko/resample.py,sha256=y2l7k1jKheO-ntBZio9bRFWLKGTihVFUV0fb8U69T0o,4185
27
- nkululeko/resample_cli.py,sha256=EJnN5t13qC4e0JVO3Rah3uJd4JRE3HM8GkoKyXsE49s,3211
24
+ nkululeko/predict.py,sha256=b35YOqovGb5PLDz0nDuhJGykEAPq2Y45R9lzxJZMuMU,2083
25
+ nkululeko/resample.py,sha256=akSAjJ3qn-O5NAyLJHVHdsK7MUZPGaZUvM2TwMSmj2M,5194
28
26
  nkululeko/runmanager.py,sha256=AswmORVUkCIH0gTx6zEyufvFATQBS8C5TXo2erSNdVg,7611
29
27
  nkululeko/scaler.py,sha256=7VOZ4sREMoQtahfETt9RyuR29Fb7PCwxlYVjBbdCVFc,4101
30
- nkululeko/segment.py,sha256=PPB8oSs_MLdEYoWh6_q3gm4mIUqPnCeGrB7FbX2AsBs,4799
28
+ nkululeko/segment.py,sha256=lSeI1i96HZTloSqdH75FhD7VyDQ16Do99-5mhI30To8,4571
31
29
  nkululeko/syllable_nuclei.py,sha256=5w_naKxNxz66a_qLkraemi2fggM-gWesiiBPS47iFcE,9931
32
30
  nkululeko/test.py,sha256=1w624vo5KTzmFC8BUStGlLDmIEAFuJUz7J0W-gp7AxI,1677
33
31
  nkululeko/test_predictor.py,sha256=DEHE_D3A6m6KJTrpDKceA1n655t_UZV3WQd57K4a3Ho,2863
@@ -112,8 +110,9 @@ nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
112
110
  nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
113
111
  nkululeko/utils/stats.py,sha256=vCRzhCR0Gx5SiJyAGbj1TIto8ocGz58CM5Pr3LltagA,2948
114
112
  nkululeko/utils/util.py,sha256=XFZdhCc_LM4EmoZ5tKKaBCQLXclcNmvHwhfT_CXB98c,16723
115
- nkululeko-0.90.2.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
116
- nkululeko-0.90.2.dist-info/METADATA,sha256=rJnGf71UEIyv0OBiNxrfu0l1e6o83v8q_UlIlmhtE_0,41113
117
- nkululeko-0.90.2.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
118
- nkululeko-0.90.2.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
119
- nkululeko-0.90.2.dist-info/RECORD,,
113
+ nkululeko-0.90.4.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
114
+ nkululeko-0.90.4.dist-info/METADATA,sha256=t64nFqxKkX3gaQ8J0PjpiRxc03LBS0yGO3i5wTR1bxc,41242
115
+ nkululeko-0.90.4.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
116
+ nkululeko-0.90.4.dist-info/entry_points.txt,sha256=KpQhz4HKBvYLrNooqLIc83hub76axRbYUgWzYkH3GnU,397
117
+ nkululeko-0.90.4.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
118
+ nkululeko-0.90.4.dist-info/RECORD,,
@@ -0,0 +1,10 @@
1
+ [console_scripts]
2
+ nkululeko.augment = nkululeko.augment:main
3
+ nkululeko.demo = nkululeko.demo:main
4
+ nkululeko.explore = nkululeko.explore:main
5
+ nkululeko.export = nkululeko.export:main
6
+ nkululeko.nkululeko = nkululeko.nkululeko:main
7
+ nkululeko.predict = nkululeko.predict:main
8
+ nkululeko.resample = nkululeko.resample:main
9
+ nkululeko.segment = nkululeko.segment:main
10
+ nkululeko.test = nkululeko.test:main
@@ -1,728 +0,0 @@
1
- # experiment.py: Main class for an experiment (nkululeko.nkululeko)
2
- import ast
3
- import os
4
- import pickle
5
- import random
6
- import time
7
-
8
- import audeer
9
- import audformat
10
- import numpy as np
11
- import pandas as pd
12
- from sklearn.preprocessing import LabelEncoder
13
-
14
- import nkululeko.glob_conf as glob_conf
15
- from nkululeko.data.dataset import Dataset
16
- from nkululeko.data.dataset_csv import Dataset_CSV
17
- from nkululeko.demo_predictor import Demo_predictor
18
- from nkululeko.feat_extract.feats_analyser import FeatureAnalyser
19
- from nkululeko.feature_extractor import FeatureExtractor
20
- from nkululeko.file_checker import FileChecker
21
- from nkululeko.filter_data import DataFilter
22
- from nkululeko.plots import Plots
23
- from nkululeko.reporting.report import Report
24
- from nkululeko.runmanager import Runmanager
25
- from nkululeko.scaler import Scaler
26
- from nkululeko.test_predictor import TestPredictor
27
- from nkululeko.utils.util import Util
28
-
29
-
30
- class Experiment:
31
- """Main class specifying an experiment"""
32
-
33
- def __init__(self, config_obj):
34
- """
35
- Parameters
36
- ----------
37
- config_obj : a config parser object that sets the experiment parameters and being set as a global object.
38
- """
39
-
40
- self.set_globals(config_obj)
41
- self.name = glob_conf.config["EXP"]["name"]
42
- self.root = os.path.join(glob_conf.config["EXP"]["root"], "")
43
- self.data_dir = os.path.join(self.root, self.name)
44
- audeer.mkdir(self.data_dir) # create the experiment directory
45
- self.util = Util("experiment")
46
- glob_conf.set_util(self.util)
47
- fresh_report = eval(self.util.config_val("REPORT", "fresh", "False"))
48
- if not fresh_report:
49
- try:
50
- with open(os.path.join(self.data_dir, "report.pkl"), "rb") as handle:
51
- self.report = pickle.load(handle)
52
- except FileNotFoundError:
53
- self.report = Report()
54
- else:
55
- self.util.debug("starting a fresh report")
56
- self.report = Report()
57
- glob_conf.set_report(self.report)
58
- self.loso = self.util.config_val("MODEL", "loso", False)
59
- self.logo = self.util.config_val("MODEL", "logo", False)
60
- self.xfoldx = self.util.config_val("MODEL", "k_fold_cross", False)
61
- self.start = time.process_time()
62
-
63
- def set_module(self, module):
64
- glob_conf.set_module(module)
65
-
66
- def store_report(self):
67
- with open(os.path.join(self.data_dir, "report.pkl"), "wb") as handle:
68
- pickle.dump(self.report, handle)
69
- if eval(self.util.config_val("REPORT", "show", "False")):
70
- self.report.print()
71
- if self.util.config_val("REPORT", "latex", False):
72
- self.report.export_latex()
73
-
74
- def get_name(self):
75
- return self.util.get_exp_name()
76
-
77
- def set_globals(self, config_obj):
78
- """install a config object in the global space"""
79
- glob_conf.init_config(config_obj)
80
-
81
- def load_datasets(self):
82
- """Load all databases specified in the configuration and map the labels"""
83
- ds = ast.literal_eval(glob_conf.config["DATA"]["databases"])
84
- self.datasets = {}
85
- self.got_speaker, self.got_gender, self.got_age = False, False, False
86
- for d in ds:
87
- ds_type = self.util.config_val_data(d, "type", "audformat")
88
- if ds_type == "audformat":
89
- data = Dataset(d)
90
- elif ds_type == "csv":
91
- data = Dataset_CSV(d)
92
- else:
93
- self.util.error(f"unknown data type: {ds_type}")
94
- data.load()
95
- data.prepare()
96
- if data.got_gender:
97
- self.got_gender = True
98
- if data.got_age:
99
- self.got_age = True
100
- if data.got_speaker:
101
- self.got_speaker = True
102
- self.datasets.update({d: data})
103
- self.target = self.util.config_val("DATA", "target", "emotion")
104
- glob_conf.set_target(self.target)
105
- # print target via debug
106
- self.util.debug(f"target: {self.target}")
107
- # print keys/column
108
- dbs = ",".join(list(self.datasets.keys()))
109
- labels = self.util.config_val("DATA", "labels", False)
110
- if labels:
111
- self.labels = ast.literal_eval(labels)
112
- self.util.debug(f"Target labels (from config): {labels}")
113
- else:
114
- self.labels = list(
115
- next(iter(self.datasets.values())).df[self.target].unique()
116
- )
117
- self.util.debug(f"Target labels (from database): {labels}")
118
- glob_conf.set_labels(self.labels)
119
- self.util.debug(f"loaded databases {dbs}")
120
-
121
- def _import_csv(self, storage):
122
- # df = pd.read_csv(storage, header=0, index_col=[0,1,2])
123
- # df.index.set_levels(pd.to_timedelta(df.index.levels[1]), level=1)
124
- # df.index.set_levels(pd.to_timedelta(df.index.levels[2]), level=2)
125
- df = audformat.utils.read_csv(storage)
126
- df.is_labeled = True if self.target in df else False
127
- # print(df.head())
128
- return df
129
-
130
- def fill_tests(self):
131
- """Only fill a new test set"""
132
-
133
- test_dbs = ast.literal_eval(glob_conf.config["DATA"]["tests"])
134
- self.df_test = pd.DataFrame()
135
- start_fresh = eval(self.util.config_val("DATA", "no_reuse", "False"))
136
- store = self.util.get_path("store")
137
- storage_test = f"{store}extra_testdf.csv"
138
- if os.path.isfile(storage_test) and not start_fresh:
139
- self.util.debug(f"reusing previously stored {storage_test}")
140
- self.df_test = self._import_csv(storage_test)
141
- else:
142
- for d in test_dbs:
143
- ds_type = self.util.config_val_data(d, "type", "audformat")
144
- if ds_type == "audformat":
145
- data = Dataset(d)
146
- elif ds_type == "csv":
147
- data = Dataset_CSV(d)
148
- else:
149
- self.util.error(f"unknown data type: {ds_type}")
150
- data.load()
151
- if data.got_gender:
152
- self.got_gender = True
153
- if data.got_age:
154
- self.got_age = True
155
- if data.got_speaker:
156
- self.got_speaker = True
157
- data.split()
158
- data.prepare_labels()
159
- self.df_test = pd.concat(
160
- [self.df_test, self.util.make_segmented_index(data.df_test)]
161
- )
162
- self.df_test.is_labeled = data.is_labeled
163
- self.df_test.got_gender = self.got_gender
164
- self.df_test.got_speaker = self.got_speaker
165
- # self.util.set_config_val('FEATS', 'needs_features_extraction', 'True')
166
- # self.util.set_config_val('FEATS', 'no_reuse', 'True')
167
- self.df_test["class_label"] = self.df_test[self.target]
168
- self.df_test[self.target] = self.label_encoder.transform(
169
- self.df_test[self.target]
170
- )
171
- self.df_test.to_csv(storage_test)
172
-
173
- def fill_train_and_tests(self):
174
- """Set up train and development sets. The method should be specified in the config."""
175
- store = self.util.get_path("store")
176
- storage_test = f"{store}testdf.csv"
177
- storage_train = f"{store}traindf.csv"
178
- start_fresh = eval(self.util.config_val("DATA", "no_reuse", "False"))
179
- if (
180
- os.path.isfile(storage_train)
181
- and os.path.isfile(storage_test)
182
- and not start_fresh
183
- ):
184
- self.util.debug(
185
- f"reusing previously stored {storage_test} and {storage_train}"
186
- )
187
- self.df_test = self._import_csv(storage_test)
188
- # print(f"df_test: {self.df_test}")
189
- self.df_train = self._import_csv(storage_train)
190
- # print(f"df_train: {self.df_train}")
191
- else:
192
- self.df_train, self.df_test = pd.DataFrame(), pd.DataFrame()
193
- for d in self.datasets.values():
194
- d.split()
195
- d.prepare_labels()
196
- if d.df_train.shape[0] == 0:
197
- self.util.debug(f"warn: {d.name} train empty")
198
- self.df_train = pd.concat([self.df_train, d.df_train])
199
- # print(f"df_train: {self.df_train}")
200
- self.util.copy_flags(d, self.df_train)
201
- if d.df_test.shape[0] == 0:
202
- self.util.debug(f"warn: {d.name} test empty")
203
- self.df_test = pd.concat([self.df_test, d.df_test])
204
- self.util.copy_flags(d, self.df_test)
205
- store = self.util.get_path("store")
206
- storage_test = f"{store}testdf.csv"
207
- storage_train = f"{store}traindf.csv"
208
- self.df_test.to_csv(storage_test)
209
- self.df_train.to_csv(storage_train)
210
-
211
- self.util.copy_flags(self, self.df_test)
212
- self.util.copy_flags(self, self.df_train)
213
- # Try data checks
214
- datachecker = FileChecker(self.df_train)
215
- self.df_train = datachecker.all_checks()
216
- datachecker.set_data(self.df_test)
217
- self.df_test = datachecker.all_checks()
218
-
219
- # Check for filters
220
- filter_sample_selection = self.util.config_val(
221
- "DATA", "filter.sample_selection", "all"
222
- )
223
- if filter_sample_selection == "all":
224
- datafilter = DataFilter(self.df_train)
225
- self.df_train = datafilter.all_filters()
226
- datafilter = DataFilter(self.df_test)
227
- self.df_test = datafilter.all_filters()
228
- elif filter_sample_selection == "train":
229
- datafilter = DataFilter(self.df_train)
230
- self.df_train = datafilter.all_filters()
231
- elif filter_sample_selection == "test":
232
- datafilter = DataFilter(self.df_test)
233
- self.df_test = datafilter.all_filters()
234
- else:
235
- self.util.error(
236
- "unkown filter sample selection specifier"
237
- f" {filter_sample_selection}, should be [all | train | test]"
238
- )
239
-
240
- # encode the labels
241
- if self.util.exp_is_classification():
242
- datatype = self.util.config_val("DATA", "type", "dummy")
243
- if datatype == "continuous":
244
- # if self.df_test.is_labeled:
245
- # # remember the target in case they get labelencoded later
246
- # self.df_test["class_label"] = self.df_test[self.target]
247
- test_cats = self.df_test["class_label"].unique()
248
- # else:
249
- # # if there is no target, copy a dummy label
250
- # self.df_test = self._add_random_target(self.df_test)
251
- # if self.df_train.is_labeled:
252
- # # remember the target in case they get labelencoded later
253
- # self.df_train["class_label"] = self.df_train[self.target]
254
- train_cats = self.df_train["class_label"].unique()
255
-
256
- else:
257
- if self.df_test.is_labeled:
258
- test_cats = self.df_test[self.target].unique()
259
- else:
260
- # if there is no target, copy a dummy label
261
- self.df_test = self._add_random_target(self.df_test).astype("str")
262
- train_cats = self.df_train[self.target].unique()
263
- # print(f"df_train: {pd.DataFrame(self.df_train[self.target])}")
264
- # print(f"train_cats with target {self.target}: {train_cats}")
265
- if self.df_test.is_labeled:
266
- if type(test_cats) == np.ndarray:
267
- self.util.debug(f"Categories test (nd.array): {test_cats}")
268
- else:
269
- self.util.debug(f"Categories test (list): {list(test_cats)}")
270
- if type(train_cats) == np.ndarray:
271
- self.util.debug(f"Categories train (nd.array): {train_cats}")
272
- else:
273
- self.util.debug(f"Categories train (list): {list(train_cats)}")
274
-
275
- # encode the labels as numbers
276
- self.label_encoder = LabelEncoder()
277
- self.df_train[self.target] = self.label_encoder.fit_transform(
278
- self.df_train[self.target]
279
- )
280
- self.df_test[self.target] = self.label_encoder.transform(
281
- self.df_test[self.target]
282
- )
283
- glob_conf.set_label_encoder(self.label_encoder)
284
- if self.got_speaker:
285
- self.util.debug(
286
- f"{self.df_test.speaker.nunique()} speakers in test and"
287
- f" {self.df_train.speaker.nunique()} speakers in train"
288
- )
289
-
290
- target_factor = self.util.config_val("DATA", "target_divide_by", False)
291
- if target_factor:
292
- self.df_test[self.target] = self.df_test[self.target] / float(target_factor)
293
- self.df_train[self.target] = self.df_train[self.target] / float(
294
- target_factor
295
- )
296
- if not self.util.exp_is_classification():
297
- self.df_test["class_label"] = self.df_test["class_label"] / float(
298
- target_factor
299
- )
300
- self.df_train["class_label"] = self.df_train["class_label"] / float(
301
- target_factor
302
- )
303
-
304
- def _add_random_target(self, df):
305
- labels = glob_conf.labels
306
- a = [None] * len(df)
307
- for i in range(0, len(df)):
308
- a[i] = random.choice(labels)
309
- df[self.target] = a
310
- return df
311
-
312
- def plot_distribution(self, df_labels):
313
- """Plot the distribution of samples and speaker per target class and biological sex"""
314
- plot = Plots()
315
- sample_selection = self.util.config_val("EXPL", "sample_selection", "all")
316
- plot.plot_distributions(df_labels)
317
- if self.got_speaker:
318
- plot.plot_distributions_speaker(df_labels)
319
-
320
- def extract_test_feats(self):
321
- self.feats_test = pd.DataFrame()
322
- feats_name = "_".join(ast.literal_eval(glob_conf.config["DATA"]["tests"]))
323
- feats_types = self.util.config_val_list("FEATS", "type", ["os"])
324
- self.feature_extractor = FeatureExtractor(
325
- self.df_test, feats_types, feats_name, "test"
326
- )
327
- self.feats_test = self.feature_extractor.extract()
328
- self.util.debug(f"Test features shape:{self.feats_test.shape}")
329
-
330
- def extract_feats(self):
331
- """Extract the features for train and dev sets.
332
-
333
- They will be stored on disk and need to be removed manually.
334
-
335
- The string FEATS.feats_type is read from the config, defaults to os.
336
-
337
- """
338
- df_train, df_test = self.df_train, self.df_test
339
- feats_name = "_".join(ast.literal_eval(glob_conf.config["DATA"]["databases"]))
340
- self.feats_test, self.feats_train = pd.DataFrame(), pd.DataFrame()
341
- feats_types = self.util.config_val_list("FEATS", "type", ["os"])
342
- self.feature_extractor = FeatureExtractor(
343
- df_train, feats_types, feats_name, "train"
344
- )
345
- self.feats_train = self.feature_extractor.extract()
346
- self.feature_extractor = FeatureExtractor(
347
- df_test, feats_types, feats_name, "test"
348
- )
349
- self.feats_test = self.feature_extractor.extract()
350
- self.util.debug(
351
- f"All features: train shape : {self.feats_train.shape}, test"
352
- f" shape:{self.feats_test.shape}"
353
- )
354
- if self.feats_train.shape[0] < self.df_train.shape[0]:
355
- self.util.warn(
356
- f"train feats ({self.feats_train.shape[0]}) != train labels"
357
- f" ({self.df_train.shape[0]})"
358
- )
359
- self.df_train = self.df_train[
360
- self.df_train.index.isin(self.feats_train.index)
361
- ]
362
- self.util.warn(f"new train labels shape: {self.df_train.shape[0]}")
363
- if self.feats_test.shape[0] < self.df_test.shape[0]:
364
- self.util.warn(
365
- f"test feats ({self.feats_test.shape[0]}) != test labels"
366
- f" ({self.df_test.shape[0]})"
367
- )
368
- self.df_test = self.df_test[self.df_test.index.isin(self.feats_test.index)]
369
- self.util.warn(f"mew test labels shape: {self.df_test.shape[0]}")
370
-
371
- self._check_scale()
372
-
373
- def augment(self):
374
- """
375
- Augment the selected samples
376
- """
377
- from nkululeko.augmenting.augmenter import Augmenter
378
-
379
- sample_selection = self.util.config_val("AUGMENT", "sample_selection", "all")
380
- if sample_selection == "all":
381
- df = pd.concat([self.df_train, self.df_test])
382
- elif sample_selection == "train":
383
- df = self.df_train
384
- elif sample_selection == "test":
385
- df = self.df_test
386
- else:
387
- self.util.error(
388
- f"unknown augmentation selection specifier {sample_selection},"
389
- " should be [all | train | test]"
390
- )
391
-
392
- augmenter = Augmenter(df)
393
- df_ret = augmenter.augment(sample_selection)
394
- return df_ret
395
-
396
- def autopredict(self):
397
- """
398
- Predict labels for samples with existing models and add to the dataframe.
399
- """
400
- sample_selection = self.util.config_val("PREDICT", "split", "all")
401
- if sample_selection == "all":
402
- df = pd.concat([self.df_train, self.df_test])
403
- elif sample_selection == "train":
404
- df = self.df_train
405
- elif sample_selection == "test":
406
- df = self.df_test
407
- else:
408
- self.util.error(
409
- f"unknown augmentation selection specifier {sample_selection},"
410
- " should be [all | train | test]"
411
- )
412
- targets = self.util.config_val_list("PREDICT", "targets", ["gender"])
413
- for target in targets:
414
- if target == "gender":
415
- from nkululeko.autopredict.ap_gender import GenderPredictor
416
-
417
- predictor = GenderPredictor(df)
418
- df = predictor.predict(sample_selection)
419
- elif target == "age":
420
- from nkululeko.autopredict.ap_age import AgePredictor
421
-
422
- predictor = AgePredictor(df)
423
- df = predictor.predict(sample_selection)
424
- elif target == "snr":
425
- from nkululeko.autopredict.ap_snr import SNRPredictor
426
-
427
- predictor = SNRPredictor(df)
428
- df = predictor.predict(sample_selection)
429
- elif target == "mos":
430
- from nkululeko.autopredict.ap_mos import MOSPredictor
431
-
432
- predictor = MOSPredictor(df)
433
- df = predictor.predict(sample_selection)
434
- elif target == "pesq":
435
- from nkululeko.autopredict.ap_pesq import PESQPredictor
436
-
437
- predictor = PESQPredictor(df)
438
- df = predictor.predict(sample_selection)
439
- elif target == "sdr":
440
- from nkululeko.autopredict.ap_sdr import SDRPredictor
441
-
442
- predictor = SDRPredictor(df)
443
- df = predictor.predict(sample_selection)
444
- elif target == "stoi":
445
- from nkululeko.autopredict.ap_stoi import STOIPredictor
446
-
447
- predictor = STOIPredictor(df)
448
- df = predictor.predict(sample_selection)
449
- elif target == "arousal":
450
- from nkululeko.autopredict.ap_arousal import ArousalPredictor
451
-
452
- predictor = ArousalPredictor(df)
453
- df = predictor.predict(sample_selection)
454
- elif target == "valence":
455
- from nkululeko.autopredict.ap_valence import ValencePredictor
456
-
457
- predictor = ValencePredictor(df)
458
- df = predictor.predict(sample_selection)
459
- elif target == "dominance":
460
- from nkululeko.autopredict.ap_dominance import DominancePredictor
461
-
462
- predictor = DominancePredictor(df)
463
- df = predictor.predict(sample_selection)
464
- else:
465
- self.util.error(f"unknown auto predict target: {target}")
466
- return df
467
-
468
- def random_splice(self):
469
- """
470
- Random-splice the selected samples
471
- """
472
- from nkululeko.augmenting.randomsplicer import Randomsplicer
473
-
474
- sample_selection = self.util.config_val("AUGMENT", "sample_selection", "all")
475
- if sample_selection == "all":
476
- df = pd.concat([self.df_train, self.df_test])
477
- elif sample_selection == "train":
478
- df = self.df_train
479
- elif sample_selection == "test":
480
- df = self.df_test
481
- else:
482
- self.util.error(
483
- f"unknown augmentation selection specifier {sample_selection},"
484
- " should be [all | train | test]"
485
- )
486
- randomsplicer = Randomsplicer(df)
487
- df_ret = randomsplicer.run(sample_selection)
488
- return df_ret
489
-
490
- def analyse_features(self, needs_feats):
491
- """Do a feature exploration."""
492
- plot_feats = eval(
493
- self.util.config_val("EXPL", "feature_distributions", "False")
494
- )
495
- sample_selection = self.util.config_val("EXPL", "sample_selection", "all")
496
- # get the data labels
497
- if sample_selection == "all":
498
- df_labels = pd.concat([self.df_train, self.df_test])
499
- self.util.copy_flags(self.df_train, df_labels)
500
- elif sample_selection == "train":
501
- df_labels = self.df_train
502
- self.util.copy_flags(self.df_train, df_labels)
503
- elif sample_selection == "test":
504
- df_labels = self.df_test
505
- self.util.copy_flags(self.df_test, df_labels)
506
- else:
507
- self.util.error(
508
- f"unknown sample selection specifier {sample_selection}, should"
509
- " be [all | train | test]"
510
- )
511
- self.util.debug(f"sampling selection: {sample_selection}")
512
- if self.util.config_val("EXPL", "value_counts", False):
513
- self.plot_distribution(df_labels)
514
-
515
- # check if data should be shown with the spotlight data visualizer
516
- spotlight = eval(self.util.config_val("EXPL", "spotlight", "False"))
517
- if spotlight:
518
- self.util.debug("opening spotlight tab in web browser")
519
- from renumics import spotlight
520
-
521
- spotlight.show(df_labels.reset_index())
522
-
523
- if not needs_feats:
524
- return
525
- # get the feature values
526
- if sample_selection == "all":
527
- df_feats = pd.concat([self.feats_train, self.feats_test])
528
- elif sample_selection == "train":
529
- df_feats = self.feats_train
530
- elif sample_selection == "test":
531
- df_feats = self.feats_test
532
- else:
533
- self.util.error(
534
- f"unknown sample selection specifier {sample_selection}, should"
535
- " be [all | train | test]"
536
- )
537
- feat_analyser = FeatureAnalyser(sample_selection, df_labels, df_feats)
538
- # check if SHAP features should be analysed
539
- shap = eval(self.util.config_val("EXPL", "shap", "False"))
540
- if shap:
541
- feat_analyser.analyse_shap(self.runmgr.get_best_model())
542
-
543
- if plot_feats:
544
- feat_analyser.analyse()
545
-
546
- # check if a scatterplot should be done
547
- scatter_var = eval(self.util.config_val("EXPL", "scatter", "False"))
548
- scatter_target = self.util.config_val(
549
- "EXPL", "scatter.target", "['class_label']"
550
- )
551
- if scatter_var:
552
- scatters = ast.literal_eval(glob_conf.config["EXPL"]["scatter"])
553
- scat_targets = ast.literal_eval(scatter_target)
554
- plots = Plots()
555
- for scat_target in scat_targets:
556
- if self.util.is_categorical(df_labels[scat_target]):
557
- for scatter in scatters:
558
- plots.scatter_plot(df_feats, df_labels, scat_target, scatter)
559
- else:
560
- self.util.debug(
561
- f"{self.name}: binning continuous variable to categories"
562
- )
563
- cat_vals = self.util.continuous_to_categorical(
564
- df_labels[scat_target]
565
- )
566
- df_labels[f"{scat_target}_bins"] = cat_vals.values
567
- for scatter in scatters:
568
- plots.scatter_plot(
569
- df_feats, df_labels, f"{scat_target}_bins", scatter
570
- )
571
-
572
- def _check_scale(self):
573
- scale_feats = self.util.config_val("FEATS", "scale", False)
574
- # print the scale
575
- self.util.debug(f"scaler: {scale_feats}")
576
- if scale_feats:
577
- self.scaler_feats = Scaler(
578
- self.df_train,
579
- self.df_test,
580
- self.feats_train,
581
- self.feats_test,
582
- scale_feats,
583
- )
584
- self.feats_train, self.feats_test = self.scaler_feats.scale()
585
- # store versions
586
- self.util.save_to_store(self.feats_train, "feats_train_scaled")
587
- self.util.save_to_store(self.feats_test, "feats_test_scaled")
588
-
589
- def init_runmanager(self):
590
- """Initialize the manager object for the runs."""
591
- self.runmgr = Runmanager(
592
- self.df_train, self.df_test, self.feats_train, self.feats_test
593
- )
594
-
595
- def run(self):
596
- """Do the runs."""
597
- self.runmgr.do_runs()
598
-
599
- # access the best results all runs
600
- self.reports = self.runmgr.best_results
601
- last_epochs = self.runmgr.last_epochs
602
- # try to save yourself
603
- save = self.util.config_val("EXP", "save", False)
604
- if save:
605
- # save the experiment for future use
606
- self.save(self.util.get_save_name())
607
- # self.save_onnx(self.util.get_save_name())
608
-
609
- # self.__collect_reports()
610
- self.util.print_best_results(self.reports)
611
-
612
- # check if the test predictions should be saved to disk
613
- test_pred_file = self.util.config_val("EXP", "save_test", False)
614
- if test_pred_file:
615
- self.predict_test_and_save(test_pred_file)
616
-
617
- # check if the majority voting for all speakers should be plotted
618
- conf_mat_per_speaker_function = self.util.config_val(
619
- "PLOT", "combine_per_speaker", False
620
- )
621
- if conf_mat_per_speaker_function:
622
- self.plot_confmat_per_speaker(conf_mat_per_speaker_function)
623
- used_time = time.process_time() - self.start
624
- self.util.debug(f"Done, used {used_time:.3f} seconds")
625
-
626
- # check if a test set should be labeled by the model:
627
- label_data = self.util.config_val("DATA", "label_data", False)
628
- label_result = self.util.config_val("DATA", "label_result", False)
629
- if label_data and label_result:
630
- self.predict_test_and_save(label_result)
631
-
632
- return self.reports, last_epochs
633
-
634
- def plot_confmat_per_speaker(self, function):
635
- if self.loso or self.logo or self.xfoldx:
636
- self.util.debug(
637
- "plot combined speaker predictions not possible for cross" " validation"
638
- )
639
- return
640
- best = self.get_best_report(self.reports)
641
- # if not best.is_classification:
642
- # best.continuous_to_categorical()
643
- truths = best.truths
644
- preds = best.preds
645
- speakers = self.df_test.speaker.values
646
- print(f"{len(truths)} {len(preds)} {len(speakers) }")
647
- df = pd.DataFrame(data={"truth": truths, "pred": preds, "speaker": speakers})
648
- plot_name = "result_combined_per_speaker"
649
- self.util.debug(
650
- f"plotting speaker combination ({function}) confusion matrix to"
651
- f" {plot_name}"
652
- )
653
- best.plot_per_speaker(df, plot_name, function)
654
-
655
- def get_best_report(self, reports):
656
- return self.runmgr.get_best_result(reports)
657
-
658
- def print_best_model(self):
659
- self.runmgr.print_best_result_runs()
660
-
661
- def demo(self, file, is_list, outfile):
662
- model = self.runmgr.get_best_model()
663
- labelEncoder = None
664
- try:
665
- labelEncoder = self.label_encoder
666
- except AttributeError:
667
- pass
668
- demo = Demo_predictor(
669
- model, file, is_list, self.feature_extractor, labelEncoder, outfile
670
- )
671
- demo.run_demo()
672
-
673
- def predict_test_and_save(self, result_name):
674
- model = self.runmgr.get_best_model()
675
- model.set_testdata(self.df_test, self.feats_test)
676
- test_predictor = TestPredictor(
677
- model, self.df_test, self.label_encoder, result_name
678
- )
679
- result = test_predictor.predict_and_store()
680
- return result
681
-
682
- def load(self, filename):
683
- try:
684
- f = open(filename, "rb")
685
- tmp_dict = pickle.load(f)
686
- f.close()
687
- except EOFError as eof:
688
- self.util.error(f"can't open file {filename}: {eof}")
689
- self.__dict__.update(tmp_dict)
690
- glob_conf.set_labels(self.labels)
691
-
692
- def save(self, filename):
693
- if self.runmgr.modelrunner.model.is_ann():
694
- self.runmgr.modelrunner.model = None
695
- self.util.warn(
696
- "Save experiment: Can't pickle the trained model so saving without it. (it should be stored anyway)"
697
- )
698
- try:
699
- f = open(filename, "wb")
700
- pickle.dump(self.__dict__, f)
701
- f.close()
702
- except (TypeError, AttributeError) as error:
703
- self.feature_extractor.feat_extractor.model = None
704
- f = open(filename, "wb")
705
- pickle.dump(self.__dict__, f)
706
- f.close()
707
- self.util.warn(
708
- "Save experiment: Can't pickle the feature extraction model so saving without it."
709
- + f"{type(error).__name__} {error}"
710
- )
711
- except RuntimeError as error:
712
- self.util.warn(
713
- "Save experiment: Can't pickle local object, NOT saving: "
714
- + f"{type(error).__name__} {error}"
715
- )
716
-
717
- def save_onnx(self, filename):
718
- # export the model to onnx
719
- model = self.runmgr.get_best_model()
720
- if model.is_ann():
721
- print("converting to onnx from torch")
722
- else:
723
-
724
- print("converting to onnx from sklearn")
725
- # save the rest
726
- f = open(filename, "wb")
727
- pickle.dump(self.__dict__, f)
728
- f.close()
nkululeko/resample_cli.py DELETED
@@ -1,99 +0,0 @@
1
- import argparse
2
- import configparser
3
- import os
4
-
5
- import audformat
6
- import pandas as pd
7
-
8
- from nkululeko.augmenting.resampler import Resampler
9
- from nkululeko.constants import VERSION
10
- from nkululeko.experiment import Experiment
11
- from nkululeko.utils.util import Util
12
-
13
-
14
- def main(src_dir):
15
- parser = argparse.ArgumentParser(
16
- description="Call the nkululeko RESAMPLE framework."
17
- )
18
- parser.add_argument("--config", default=None, help="The base configuration")
19
- parser.add_argument("--file", default=None, help="The input audio file to resample")
20
- parser.add_argument(
21
- "--replace", action="store_true", help="Replace the original audio file"
22
- )
23
-
24
- args = parser.parse_args()
25
-
26
- if args.file is None and args.config is None:
27
- print("ERROR: Either --file or --config argument must be provided.")
28
- exit()
29
-
30
- if args.file is not None:
31
- # Load the audio file into a DataFrame
32
- files = pd.Series([args.file])
33
- df_sample = pd.DataFrame(index=files)
34
- df_sample.index = audformat.utils.to_segmented_index(
35
- df_sample.index, allow_nat=False
36
- )
37
-
38
- # Resample the audio file
39
- util = Util("resampler", has_config=False)
40
- util.debug(f"Resampling audio file: {args.file}")
41
- rs = Resampler(df_sample, not_testing=True, replace=args.replace)
42
- rs.resample()
43
- else:
44
- # Existing code for handling INI file
45
- config_file = args.config
46
-
47
- # Test if the configuration file exists
48
- if not os.path.isfile(config_file):
49
- print(f"ERROR: no such file: {config_file}")
50
- exit()
51
-
52
- # Load one configuration per experiment
53
- config = configparser.ConfigParser()
54
- config.read(config_file)
55
- # Create a new experiment
56
- expr = Experiment(config)
57
- module = "resample"
58
- expr.set_module(module)
59
- util = Util(module)
60
- util.debug(
61
- f"running {expr.name} from config {config_file}, nkululeko version"
62
- f" {VERSION}"
63
- )
64
-
65
- if util.config_val("EXP", "no_warnings", False):
66
- import warnings
67
-
68
- warnings.filterwarnings("ignore")
69
-
70
- # Load the data
71
- expr.load_datasets()
72
-
73
- # Split into train and test
74
- expr.fill_train_and_tests()
75
- util.debug(
76
- f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}"
77
- )
78
-
79
- sample_selection = util.config_val("RESAMPLE", "sample_selection", "all")
80
- if sample_selection == "all":
81
- df = pd.concat([expr.df_train, expr.df_test])
82
- elif sample_selection == "train":
83
- df = expr.df_train
84
- elif sample_selection == "test":
85
- df = expr.df_test
86
- else:
87
- util.error(
88
- f"unknown selection specifier {sample_selection}, should be [all |"
89
- " train | test]"
90
- )
91
- util.debug(f"resampling {sample_selection}: {df.shape[0]} samples")
92
- replace = util.config_val("RESAMPLE", "replace", "False")
93
- rs = Resampler(df, replace=replace)
94
- rs.resample()
95
-
96
-
97
- if __name__ == "__main__":
98
- cwd = os.path.dirname(os.path.abspath(__file__))
99
- main(cwd)