nkululeko 0.84.1__tar.gz → 0.85.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. {nkululeko-0.84.1 → nkululeko-0.85.0}/CHANGELOG.md +4 -0
  2. {nkululeko-0.84.1/nkululeko.egg-info → nkululeko-0.85.0}/PKG-INFO +5 -1
  3. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/constants.py +1 -1
  4. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/experiment.py +6 -1
  5. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_whisper.py +3 -6
  6. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/modelrunner.py +56 -33
  7. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/models/finetune_model.py +9 -0
  8. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/models/model.py +1 -1
  9. nkululeko-0.85.0/nkululeko/models/model_tuned.py +506 -0
  10. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/test_pretrain.py +16 -4
  11. {nkululeko-0.84.1 → nkululeko-0.85.0/nkululeko.egg-info}/PKG-INFO +5 -1
  12. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko.egg-info/SOURCES.txt +1 -0
  13. {nkululeko-0.84.1 → nkululeko-0.85.0}/LICENSE +0 -0
  14. {nkululeko-0.84.1 → nkululeko-0.85.0}/README.md +0 -0
  15. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/aesdd/process_database.py +0 -0
  16. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/androids/process_database.py +0 -0
  17. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/androids_orig/process_database.py +0 -0
  18. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/androids_test/process_database.py +0 -0
  19. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/ased/process_database.py +0 -0
  20. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/asvp-esd/process_database.py +0 -0
  21. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/baved/process_database.py +0 -0
  22. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/cafe/process_database.py +0 -0
  23. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/clac/process_database.py +0 -0
  24. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/cmu-mosei/process_database.py +0 -0
  25. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/demos/process_database.py +0 -0
  26. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/ekorpus/process_database.py +0 -0
  27. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/emns/process_database.py +0 -0
  28. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/emofilm/convert_to_16k.py +0 -0
  29. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/emofilm/process_database.py +0 -0
  30. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/emorynlp/process_database.py +0 -0
  31. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/emov-db/process_database.py +0 -0
  32. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/emovo/process_database.py +0 -0
  33. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/emozionalmente/create.py +0 -0
  34. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/enterface/process_database.py +0 -0
  35. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/esd/process_database.py +0 -0
  36. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/gerparas/process_database.py +0 -0
  37. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/iemocap/process_database.py +0 -0
  38. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/jl/process_database.py +0 -0
  39. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/jtes/process_database.py +0 -0
  40. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/meld/process_database.py +0 -0
  41. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/mesd/process_database.py +0 -0
  42. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/mess/process_database.py +0 -0
  43. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/mlendsnd/process_database.py +0 -0
  44. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/msp-improv/process_database2.py +0 -0
  45. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/msp-podcast/process_database.py +0 -0
  46. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/oreau2/process_database.py +0 -0
  47. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/portuguese/process_database.py +0 -0
  48. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/ravdess/process_database.py +0 -0
  49. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/ravdess/process_database_speaker.py +0 -0
  50. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/savee/process_database.py +0 -0
  51. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/shemo/process_database.py +0 -0
  52. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/subesco/process_database.py +0 -0
  53. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/tess/process_database.py +0 -0
  54. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/thorsten-emotional/process_database.py +0 -0
  55. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/urdu/process_database.py +0 -0
  56. {nkululeko-0.84.1 → nkululeko-0.85.0}/data/vivae/process_database.py +0 -0
  57. {nkululeko-0.84.1 → nkululeko-0.85.0}/docs/source/conf.py +0 -0
  58. {nkululeko-0.84.1 → nkululeko-0.85.0}/meta/demos/demo_best_model.py +0 -0
  59. {nkululeko-0.84.1 → nkululeko-0.85.0}/meta/demos/my_experiment.py +0 -0
  60. {nkululeko-0.84.1 → nkululeko-0.85.0}/meta/demos/my_experiment_local.py +0 -0
  61. {nkululeko-0.84.1 → nkululeko-0.85.0}/meta/demos/plot_faster_anim.py +0 -0
  62. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/__init__.py +0 -0
  63. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/aug_train.py +0 -0
  64. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/augment.py +0 -0
  65. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/augmenting/__init__.py +0 -0
  66. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/augmenting/augmenter.py +0 -0
  67. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/augmenting/randomsplicer.py +0 -0
  68. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/augmenting/randomsplicing.py +0 -0
  69. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/augmenting/resampler.py +0 -0
  70. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/autopredict/__init__.py +0 -0
  71. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/autopredict/ap_age.py +0 -0
  72. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/autopredict/ap_arousal.py +0 -0
  73. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/autopredict/ap_dominance.py +0 -0
  74. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/autopredict/ap_gender.py +0 -0
  75. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/autopredict/ap_mos.py +0 -0
  76. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/autopredict/ap_pesq.py +0 -0
  77. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/autopredict/ap_sdr.py +0 -0
  78. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/autopredict/ap_snr.py +0 -0
  79. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/autopredict/ap_stoi.py +0 -0
  80. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/autopredict/ap_valence.py +0 -0
  81. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/autopredict/estimate_snr.py +0 -0
  82. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/cacheddataset.py +0 -0
  83. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/data/__init__.py +0 -0
  84. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/data/dataset.py +0 -0
  85. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/data/dataset_csv.py +0 -0
  86. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/demo.py +0 -0
  87. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/demo_feats.py +0 -0
  88. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/demo_predictor.py +0 -0
  89. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/explore.py +0 -0
  90. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/export.py +0 -0
  91. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/__init__.py +0 -0
  92. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_agender.py +0 -0
  93. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_agender_agender.py +0 -0
  94. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_analyser.py +0 -0
  95. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_auddim.py +0 -0
  96. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_audmodel.py +0 -0
  97. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_clap.py +0 -0
  98. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_hubert.py +0 -0
  99. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_import.py +0 -0
  100. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_mld.py +0 -0
  101. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_mos.py +0 -0
  102. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_opensmile.py +0 -0
  103. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_oxbow.py +0 -0
  104. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_praat.py +0 -0
  105. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_snr.py +0 -0
  106. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_spectra.py +0 -0
  107. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_spkrec.py +0 -0
  108. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_squim.py +0 -0
  109. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_trill.py +0 -0
  110. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_wav2vec2.py +0 -0
  111. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_wavlm.py +0 -0
  112. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/featureset.py +0 -0
  113. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feat_extract/feinberg_praat.py +0 -0
  114. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/feature_extractor.py +0 -0
  115. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/file_checker.py +0 -0
  116. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/filter_data.py +0 -0
  117. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/glob_conf.py +0 -0
  118. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/losses/__init__.py +0 -0
  119. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/losses/loss_ccc.py +0 -0
  120. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/losses/loss_softf1loss.py +0 -0
  121. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/models/__init__.py +0 -0
  122. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/models/model_bayes.py +0 -0
  123. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/models/model_cnn.py +0 -0
  124. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/models/model_gmm.py +0 -0
  125. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/models/model_knn.py +0 -0
  126. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/models/model_knn_reg.py +0 -0
  127. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/models/model_lin_reg.py +0 -0
  128. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/models/model_mlp.py +0 -0
  129. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/models/model_mlp_regression.py +0 -0
  130. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/models/model_svm.py +0 -0
  131. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/models/model_svr.py +0 -0
  132. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/models/model_tree.py +0 -0
  133. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/models/model_tree_reg.py +0 -0
  134. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/models/model_xgb.py +0 -0
  135. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/models/model_xgr.py +0 -0
  136. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/multidb.py +0 -0
  137. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/nkuluflag.py +0 -0
  138. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/nkululeko.py +0 -0
  139. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/plots.py +0 -0
  140. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/predict.py +0 -0
  141. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/reporting/__init__.py +0 -0
  142. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/reporting/defines.py +0 -0
  143. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/reporting/latex_writer.py +0 -0
  144. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/reporting/report.py +0 -0
  145. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/reporting/report_item.py +0 -0
  146. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/reporting/reporter.py +0 -0
  147. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/reporting/result.py +0 -0
  148. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/resample.py +0 -0
  149. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/runmanager.py +0 -0
  150. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/scaler.py +0 -0
  151. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/segment.py +0 -0
  152. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/segmenting/__init__.py +0 -0
  153. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/segmenting/seg_inaspeechsegmenter.py +0 -0
  154. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/segmenting/seg_silero.py +0 -0
  155. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/syllable_nuclei.py +0 -0
  156. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/test.py +0 -0
  157. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/test_predictor.py +0 -0
  158. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/utils/__init__.py +0 -0
  159. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/utils/files.py +0 -0
  160. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/utils/stats.py +0 -0
  161. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko/utils/util.py +0 -0
  162. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko.egg-info/dependency_links.txt +0 -0
  163. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko.egg-info/requires.txt +0 -0
  164. {nkululeko-0.84.1 → nkululeko-0.85.0}/nkululeko.egg-info/top_level.txt +0 -0
  165. {nkululeko-0.84.1 → nkululeko-0.85.0}/pyproject.toml +0 -0
  166. {nkululeko-0.84.1 → nkululeko-0.85.0}/setup.cfg +0 -0
  167. {nkululeko-0.84.1 → nkululeko-0.85.0}/setup.py +0 -0
  168. {nkululeko-0.84.1 → nkululeko-0.85.0}/venv/bin/activate_this.py +0 -0
@@ -1,6 +1,10 @@
1
1
  Changelog
2
2
  =========
3
3
 
4
+ Version 0.85.0
5
+ --------------
6
+ * first version with finetuning wav2vec2 layers
7
+
4
8
  Version 0.84.1
5
9
  --------------
6
10
  * made resample independent of config file
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nkululeko
3
- Version: 0.84.1
3
+ Version: 0.85.0
4
4
  Summary: Machine learning audio prediction experiments based on templates
5
5
  Home-page: https://github.com/felixbur/nkululeko
6
6
  Author: Felix Burkhardt
@@ -333,6 +333,10 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
333
333
  Changelog
334
334
  =========
335
335
 
336
+ Version 0.85.0
337
+ --------------
338
+ * first version with finetuning wav2vec2 layers
339
+
336
340
  Version 0.84.1
337
341
  --------------
338
342
  * made resample independent of config file
@@ -1,2 +1,2 @@
1
- VERSION="0.84.1"
1
+ VERSION="0.85.0"
2
2
  SAMPLING_RATE = 16000
@@ -340,7 +340,12 @@ class Experiment:
340
340
  df_train, df_test = self.df_train, self.df_test
341
341
  feats_name = "_".join(ast.literal_eval(glob_conf.config["DATA"]["databases"]))
342
342
  self.feats_test, self.feats_train = pd.DataFrame(), pd.DataFrame()
343
- feats_types = self.util.config_val_list("FEATS", "type", ["os"])
343
+ feats_types = self.util.config_val_list("FEATS", "type", [])
344
+ # for some models no features are needed
345
+ if len(feats_types) == 0:
346
+ self.util.debug("no feature extractor specified.")
347
+ self.feats_train, self.feats_test = pd.DataFrame(), pd.DataFrame()
348
+ return
344
349
  self.feature_extractor = FeatureExtractor(
345
350
  df_train, feats_types, feats_name, "train"
346
351
  )
@@ -32,22 +32,19 @@ class Whisper(Featureset):
32
32
  model_name = f"openai/{self.feat_type}"
33
33
  self.model = WhisperModel.from_pretrained(model_name).to(self.device)
34
34
  print(f"intialized Whisper model on {self.device}")
35
- self.feature_extractor = AutoFeatureExtractor.from_pretrained(
36
- model_name)
35
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
37
36
  self.model_initialized = True
38
37
 
39
38
  def extract(self):
40
39
  """Extract the features or load them from disk if present."""
41
40
  store = self.util.get_path("store")
42
41
  storage = f"{store}{self.name}.pkl"
43
- extract = self.util.config_val(
44
- "FEATS", "needs_feature_extraction", False)
42
+ extract = self.util.config_val("FEATS", "needs_feature_extraction", False)
45
43
  no_reuse = eval(self.util.config_val("FEATS", "no_reuse", "False"))
46
44
  if extract or no_reuse or not os.path.isfile(storage):
47
45
  if not self.model_initialized:
48
46
  self.init_model()
49
- self.util.debug(
50
- "extracting whisper embeddings, this might take a while...")
47
+ self.util.debug("extracting whisper embeddings, this might take a while...")
51
48
  emb_series = []
52
49
  for (file, start, end), _ in audeer.progress_bar(
53
50
  self.data_df.iterrows(),
@@ -47,16 +47,12 @@ class Modelrunner:
47
47
  highest = 0
48
48
  else:
49
49
  highest = 100000
50
- # for all epochs
51
- for epoch in range(epoch_num):
52
- if only_test:
53
- self.model.load(self.run, epoch)
54
- self.util.debug(f"reusing model: {self.model.store_path}")
55
- self.model.reset_test(self.df_test, self.feats_test)
56
- else:
57
- self.model.set_id(self.run, epoch)
58
- self.model.train()
50
+ if self.model.model_type == "finetuned":
51
+ # epochs are handled by Huggingface API
52
+ self.model.train()
59
53
  report = self.model.predict()
54
+ # todo: findout the best epoch
55
+ epoch = epoch_num
60
56
  report.set_id(self.run, epoch)
61
57
  plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
62
58
  reports.append(report)
@@ -67,32 +63,53 @@ class Modelrunner:
67
63
  if plot_epochs:
68
64
  self.util.debug(f"plotting conf matrix to {plot_name}")
69
65
  report.plot_confmatrix(plot_name, epoch)
70
- store_models = self.util.config_val("EXP", "save", False)
71
- plot_best_model = self.util.config_val("PLOT", "best_model", False)
72
- if (store_models or plot_best_model) and (
73
- not only_test
74
- ): # in any case the model needs to be stored to disk.
75
- self.model.store()
76
- if patience:
77
- patience = int(patience)
78
- result = report.result.get_result()
79
- if self.util.high_is_good():
80
- if result > highest:
81
- highest = result
82
- patience_counter = 0
83
- else:
84
- patience_counter += 1
66
+ else:
67
+ # for all epochs
68
+ for epoch in range(epoch_num):
69
+ if only_test:
70
+ self.model.load(self.run, epoch)
71
+ self.util.debug(f"reusing model: {self.model.store_path}")
72
+ self.model.reset_test(self.df_test, self.feats_test)
85
73
  else:
86
- if result < highest:
87
- highest = result
88
- patience_counter = 0
74
+ self.model.set_id(self.run, epoch)
75
+ self.model.train()
76
+ report = self.model.predict()
77
+ report.set_id(self.run, epoch)
78
+ plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
79
+ reports.append(report)
80
+ self.util.debug(
81
+ f"run: {self.run} epoch: {epoch}: result: "
82
+ f"{reports[-1].get_result().get_test_result()}"
83
+ )
84
+ if plot_epochs:
85
+ self.util.debug(f"plotting conf matrix to {plot_name}")
86
+ report.plot_confmatrix(plot_name, epoch)
87
+ store_models = self.util.config_val("EXP", "save", False)
88
+ plot_best_model = self.util.config_val("PLOT", "best_model", False)
89
+ if (store_models or plot_best_model) and (
90
+ not only_test
91
+ ): # in any case the model needs to be stored to disk.
92
+ self.model.store()
93
+ if patience:
94
+ patience = int(patience)
95
+ result = report.result.get_result()
96
+ if self.util.high_is_good():
97
+ if result > highest:
98
+ highest = result
99
+ patience_counter = 0
100
+ else:
101
+ patience_counter += 1
89
102
  else:
90
- patience_counter += 1
91
- if patience_counter >= patience:
92
- self.util.debug(
93
- f"reached patience ({str(patience)}): early stopping"
94
- )
95
- break
103
+ if result < highest:
104
+ highest = result
105
+ patience_counter = 0
106
+ else:
107
+ patience_counter += 1
108
+ if patience_counter >= patience:
109
+ self.util.debug(
110
+ f"reached patience ({str(patience)}): early stopping"
111
+ )
112
+ break
96
113
 
97
114
  if not plot_epochs:
98
115
  # Do at least one confusion matrix plot
@@ -133,6 +150,12 @@ class Modelrunner:
133
150
  self.model = Bayes_model(
134
151
  self.df_train, self.df_test, self.feats_train, self.feats_test
135
152
  )
153
+ elif model_type == "finetune":
154
+ from nkululeko.models.model_tuned import Pretrained_model
155
+
156
+ self.model = Pretrained_model(
157
+ self.df_train, self.df_test, self.feats_train, self.feats_test
158
+ )
136
159
  elif model_type == "gmm":
137
160
  from nkululeko.models.model_gmm import GMM_model
138
161
 
@@ -1,3 +1,7 @@
1
+ """
2
+ Code based on @jwagner
3
+ """
4
+
1
5
  import dataclasses
2
6
  import typing
3
7
 
@@ -148,6 +152,11 @@ class Model(Wav2Vec2PreTrainedModel):
148
152
  logits_cat=logits_cat,
149
153
  )
150
154
 
155
+ def predict(self, signal):
156
+ result = self(torch.from_numpy(signal))
157
+ result = result[0].detach().numpy()[0]
158
+ return result
159
+
151
160
 
152
161
  class ModelWithPreProcessing(Model):
153
162
 
@@ -39,7 +39,7 @@ class Model:
39
39
  self.model_type = type
40
40
 
41
41
  def is_ann(self):
42
- if self.model_type == "ann":
42
+ if (self.model_type == "ann") or (self.model_type == "finetuned"):
43
43
  return True
44
44
  else:
45
45
  return False