nkululeko 0.90.1.tar.gz → 0.90.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. {nkululeko-0.90.1 → nkululeko-0.90.2}/CHANGELOG.md +5 -0
  2. {nkululeko-0.90.1/nkululeko.egg-info → nkululeko-0.90.2}/PKG-INFO +6 -1
  3. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/constants.py +1 -1
  4. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/models/model_tuned.py +34 -8
  5. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/utils/util.py +2 -2
  6. {nkululeko-0.90.1 → nkululeko-0.90.2/nkululeko.egg-info}/PKG-INFO +6 -1
  7. {nkululeko-0.90.1 → nkululeko-0.90.2}/LICENSE +0 -0
  8. {nkululeko-0.90.1 → nkululeko-0.90.2}/README.md +0 -0
  9. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/aesdd/process_database.py +0 -0
  10. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/androids/process_database.py +0 -0
  11. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/ased/process_database.py +0 -0
  12. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/asvp-esd/process_database.py +0 -0
  13. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/baved/process_database.py +0 -0
  14. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/cafe/process_database.py +0 -0
  15. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/clac/process_database.py +0 -0
  16. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/cmu-mosei/process_database.py +0 -0
  17. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/demos/process_database.py +0 -0
  18. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/ekorpus/process_database.py +0 -0
  19. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/emns/process_database.py +0 -0
  20. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/emofilm/convert_to_16k.py +0 -0
  21. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/emofilm/process_database.py +0 -0
  22. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/emorynlp/process_database.py +0 -0
  23. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/emov-db/process_database.py +0 -0
  24. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/emovo/process_database.py +0 -0
  25. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/emozionalmente/create.py +0 -0
  26. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/enterface/process_database.py +0 -0
  27. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/esd/process_database.py +0 -0
  28. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/gerparas/process_database.py +0 -0
  29. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/iemocap/process_database.py +0 -0
  30. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/jl/process_database.py +0 -0
  31. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/jtes/process_database.py +0 -0
  32. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/meld/process_database.py +0 -0
  33. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/mesd/process_database.py +0 -0
  34. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/mess/process_database.py +0 -0
  35. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/mlendsnd/process_database.py +0 -0
  36. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/msp-improv/process_database2.py +0 -0
  37. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/msp-podcast/process_database.py +0 -0
  38. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/oreau2/process_database.py +0 -0
  39. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/portuguese/process_database.py +0 -0
  40. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/ravdess/process_database.py +0 -0
  41. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/ravdess/process_database_speaker.py +0 -0
  42. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/savee/process_database.py +0 -0
  43. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/shemo/process_database.py +0 -0
  44. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/subesco/process_database.py +0 -0
  45. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/tess/process_database.py +0 -0
  46. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/thorsten-emotional/process_database.py +0 -0
  47. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/urdu/process_database.py +0 -0
  48. {nkululeko-0.90.1 → nkululeko-0.90.2}/data/vivae/process_database.py +0 -0
  49. {nkululeko-0.90.1 → nkululeko-0.90.2}/docs/source/conf.py +0 -0
  50. {nkululeko-0.90.1 → nkululeko-0.90.2}/meta/demos/demo_best_model.py +0 -0
  51. {nkululeko-0.90.1 → nkululeko-0.90.2}/meta/demos/my_experiment.py +0 -0
  52. {nkululeko-0.90.1 → nkululeko-0.90.2}/meta/demos/my_experiment_local.py +0 -0
  53. {nkululeko-0.90.1 → nkululeko-0.90.2}/meta/demos/plot_faster_anim.py +0 -0
  54. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/__init__.py +0 -0
  55. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/aug_train.py +0 -0
  56. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/augment.py +0 -0
  57. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/augmenting/__init__.py +0 -0
  58. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/augmenting/augmenter.py +0 -0
  59. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/augmenting/randomsplicer.py +0 -0
  60. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/augmenting/randomsplicing.py +0 -0
  61. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/augmenting/resampler.py +0 -0
  62. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/autopredict/__init__.py +0 -0
  63. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/autopredict/ap_age.py +0 -0
  64. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/autopredict/ap_arousal.py +0 -0
  65. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/autopredict/ap_dominance.py +0 -0
  66. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/autopredict/ap_gender.py +0 -0
  67. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/autopredict/ap_mos.py +0 -0
  68. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/autopredict/ap_pesq.py +0 -0
  69. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/autopredict/ap_sdr.py +0 -0
  70. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/autopredict/ap_snr.py +0 -0
  71. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/autopredict/ap_stoi.py +0 -0
  72. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/autopredict/ap_valence.py +0 -0
  73. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/autopredict/estimate_snr.py +0 -0
  74. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/cacheddataset.py +0 -0
  75. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/data/__init__.py +0 -0
  76. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/data/dataset.py +0 -0
  77. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/data/dataset_csv.py +0 -0
  78. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/demo-ft.py +0 -0
  79. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/demo.py +0 -0
  80. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/demo_feats.py +0 -0
  81. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/demo_predictor.py +0 -0
  82. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/ensemble.py +0 -0
  83. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/experiment.py +0 -0
  84. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/experiment_felix.py +0 -0
  85. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/explore.py +0 -0
  86. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/export.py +0 -0
  87. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/__init__.py +0 -0
  88. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_agender.py +0 -0
  89. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_agender_agender.py +0 -0
  90. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_analyser.py +0 -0
  91. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_ast.py +0 -0
  92. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_auddim.py +0 -0
  93. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_audmodel.py +0 -0
  94. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_clap.py +0 -0
  95. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_hubert.py +0 -0
  96. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_import.py +0 -0
  97. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_mld.py +0 -0
  98. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_mos.py +0 -0
  99. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_opensmile.py +0 -0
  100. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_oxbow.py +0 -0
  101. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_praat.py +0 -0
  102. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_snr.py +0 -0
  103. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_spectra.py +0 -0
  104. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_spkrec.py +0 -0
  105. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_squim.py +0 -0
  106. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_trill.py +0 -0
  107. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_wav2vec2.py +0 -0
  108. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_wavlm.py +0 -0
  109. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feats_whisper.py +0 -0
  110. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/featureset.py +0 -0
  111. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/feinberg_praat.py +0 -0
  112. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feat_extract/transformer_feature_extractor.py +0 -0
  113. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/feature_extractor.py +0 -0
  114. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/file_checker.py +0 -0
  115. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/filter_data.py +0 -0
  116. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/fixedsegment.py +0 -0
  117. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/glob_conf.py +0 -0
  118. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/losses/__init__.py +0 -0
  119. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/losses/loss_ccc.py +0 -0
  120. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/losses/loss_softf1loss.py +0 -0
  121. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/modelrunner.py +0 -0
  122. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/models/__init__.py +0 -0
  123. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/models/model.py +0 -0
  124. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/models/model_bayes.py +0 -0
  125. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/models/model_cnn.py +0 -0
  126. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/models/model_gmm.py +0 -0
  127. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/models/model_knn.py +0 -0
  128. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/models/model_knn_reg.py +0 -0
  129. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/models/model_lin_reg.py +0 -0
  130. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/models/model_mlp.py +0 -0
  131. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/models/model_mlp_regression.py +0 -0
  132. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/models/model_svm.py +0 -0
  133. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/models/model_svr.py +0 -0
  134. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/models/model_tree.py +0 -0
  135. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/models/model_tree_reg.py +0 -0
  136. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/models/model_xgb.py +0 -0
  137. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/models/model_xgr.py +0 -0
  138. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/multidb.py +0 -0
  139. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/nkuluflag.py +0 -0
  140. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/nkululeko.py +0 -0
  141. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/plots.py +0 -0
  142. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/predict.py +0 -0
  143. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/reporting/__init__.py +0 -0
  144. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/reporting/defines.py +0 -0
  145. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/reporting/latex_writer.py +0 -0
  146. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/reporting/report.py +0 -0
  147. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/reporting/report_item.py +0 -0
  148. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/reporting/reporter.py +0 -0
  149. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/reporting/result.py +0 -0
  150. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/resample.py +0 -0
  151. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/resample_cli.py +0 -0
  152. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/runmanager.py +0 -0
  153. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/scaler.py +0 -0
  154. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/segment.py +0 -0
  155. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/segmenting/__init__.py +0 -0
  156. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/segmenting/seg_inaspeechsegmenter.py +0 -0
  157. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/segmenting/seg_silero.py +0 -0
  158. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/syllable_nuclei.py +0 -0
  159. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/test.py +0 -0
  160. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/test_predictor.py +0 -0
  161. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/test_pretrain.py +0 -0
  162. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/utils/__init__.py +0 -0
  163. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/utils/files.py +0 -0
  164. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko/utils/stats.py +0 -0
  165. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko.egg-info/SOURCES.txt +0 -0
  166. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko.egg-info/dependency_links.txt +0 -0
  167. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko.egg-info/requires.txt +0 -0
  168. {nkululeko-0.90.1 → nkululeko-0.90.2}/nkululeko.egg-info/top_level.txt +0 -0
  169. {nkululeko-0.90.1 → nkululeko-0.90.2}/pyproject.toml +0 -0
  170. {nkululeko-0.90.1 → nkululeko-0.90.2}/setup.cfg +0 -0
  171. {nkululeko-0.90.1 → nkululeko-0.90.2}/setup.py +0 -0
  172. {nkululeko-0.90.1 → nkululeko-0.90.2}/venv/bin/activate_this.py +0 -0
@@ -1,6 +1,11 @@
1
1
  Changelog
2
2
  =========
3
3
 
4
+ Version 0.90.2
5
+ --------------
6
+ * added probability output to finetuning classification models
7
+ * switched path to prob. output from "store" to "results"
8
+
4
9
  Version 0.90.1
5
10
  --------------
6
11
  * Add balancing for finetune and update data README
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nkululeko
3
- Version: 0.90.1
3
+ Version: 0.90.2
4
4
  Summary: Machine learning audio prediction experiments based on templates
5
5
  Home-page: https://github.com/felixbur/nkululeko
6
6
  Author: Felix Burkhardt
@@ -356,6 +356,11 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
356
356
  Changelog
357
357
  =========
358
358
 
359
+ Version 0.90.2
360
+ --------------
361
+ * added probability output to finetuning classification models
362
+ * switched path to prob. output from "store" to "results"
363
+
359
364
  Version 0.90.1
360
365
  --------------
361
366
  * Add balancing for finetune and update data README
@@ -1,2 +1,2 @@
1
- VERSION = "0.90.1"
1
+ VERSION = "0.90.2"
2
2
  SAMPLING_RATE = 16000
@@ -30,10 +30,16 @@ class TunedModel(BaseModel):
30
30
  """Constructor taking the configuration and all dataframes."""
31
31
  super().__init__(df_train, df_test, feats_train, feats_test)
32
32
  super().set_model_type("finetuned")
33
+ self.df_test, self.df_train, self.feats_test, self.feats_train = (
34
+ df_test,
35
+ df_train,
36
+ feats_test,
37
+ feats_train,
38
+ )
33
39
  self.name = "finetuned_wav2vec2"
34
40
  self.target = glob_conf.config["DATA"]["target"]
35
- labels = glob_conf.labels
36
- self.class_num = len(labels)
41
+ self.labels = glob_conf.labels
42
+ self.class_num = len(self.labels)
37
43
  device = self.util.config_val("MODEL", "device", False)
38
44
  if not device:
39
45
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -304,7 +310,7 @@ class TunedModel(BaseModel):
304
310
  else:
305
311
  self.util.error(f"criterion {criterion} not supported for classifier")
306
312
  else:
307
- self.criterion = self.util.config_val("MODEL", "loss", "ccc")
313
+ criterion = self.util.config_val("MODEL", "loss", "1-ccc")
308
314
  if criterion == "1-ccc":
309
315
  criterion = ConcordanceCorCoeff()
310
316
  elif criterion == "mse":
@@ -402,7 +408,7 @@ class TunedModel(BaseModel):
402
408
  self.load(self.run, self.epoch)
403
409
 
404
410
  def get_predictions(self):
405
- results = []
411
+ results = [[]].pop(0)
406
412
  for (file, start, end), _ in audeer.progress_bar(
407
413
  self.df_test.iterrows(),
408
414
  total=len(self.df_test),
@@ -415,18 +421,37 @@ class TunedModel(BaseModel):
415
421
  file, duration=end - start, offset=start, always_2d=True
416
422
  )
417
423
  assert sr == self.sampling_rate
418
- predictions = self.model.predict(signal)
419
- results.append(predictions.argmax())
420
- return results
424
+ prediction = self.model.predict(signal)
425
+ results.append(prediction)
426
+ # results.append(predictions.argmax())
427
+ predictions = np.asarray(results)
428
+ if self.util.exp_is_classification():
429
+ # make a dataframe for the class probabilities
430
+ proba_d = {}
431
+ for c in range(self.class_num):
432
+ proba_d[c] = []
433
+ # get the class probabilities
434
+ # predictions = self.clf.predict_proba(self.feats_test.to_numpy())
435
+ # pred = self.clf.predict(features)
436
+ for i in range(self.class_num):
437
+ proba_d[i] = list(predictions.T[i])
438
+ probas = pd.DataFrame(proba_d)
439
+ probas = probas.set_index(self.df_test.index)
440
+ predictions = probas.idxmax(axis=1).values
441
+ else:
442
+ predictions = predictions.flatten()
443
+ probas = None
444
+ return predictions, probas
421
445
 
422
446
  def predict(self):
423
447
  """Predict the whole eval feature set"""
424
- predictions = self.get_predictions()
448
+ predictions, probas = self.get_predictions()
425
449
  report = Reporter(
426
450
  self.df_test[self.target].to_numpy().astype(float),
427
451
  predictions,
428
452
  self.run,
429
453
  self.epoch_num,
454
+ probas=probas,
430
455
  )
431
456
  self._plot_epoch_progression(report)
432
457
  return report
@@ -438,6 +463,7 @@ class TunedModel(BaseModel):
438
463
  )
439
464
  with open(log_file, "r") as file:
440
465
  data = file.read()
466
+ data = data.strip().replace("nan", "0")
441
467
  list = ast.literal_eval(data)
442
468
  epochs, vals, loss = [], [], []
443
469
  for index, tp in enumerate(list):
@@ -155,10 +155,10 @@ class Util:
155
155
  return f"{store}/{self.get_exp_name()}.pkl"
156
156
 
157
157
  def get_pred_name(self):
158
- store = self.get_path("store")
158
+ results_dir = self.get_path("res_dir")
159
159
  target = self.get_target_name()
160
160
  pred_name = self.get_model_description()
161
- return f"{store}/pred_{target}_{pred_name}.csv"
161
+ return f"{results_dir}/pred_{target}_{pred_name}.csv"
162
162
 
163
163
  def is_categorical(self, pd_series):
164
164
  """Check if a dataframe column is categorical."""
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nkululeko
3
- Version: 0.90.1
3
+ Version: 0.90.2
4
4
  Summary: Machine learning audio prediction experiments based on templates
5
5
  Home-page: https://github.com/felixbur/nkululeko
6
6
  Author: Felix Burkhardt
@@ -356,6 +356,11 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
356
356
  Changelog
357
357
  =========
358
358
 
359
+ Version 0.90.2
360
+ --------------
361
+ * added probability output to finetuning classification models
362
+ * switched path to prob. output from "store" to "results"
363
+
359
364
  Version 0.90.1
360
365
  --------------
361
366
  * Add balancing for finetune and update data README
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes