nkululeko 0.84.0__tar.gz → 0.84.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169) hide show
  1. {nkululeko-0.84.0 → nkululeko-0.84.1}/CHANGELOG.md +4 -0
  2. {nkululeko-0.84.0/nkululeko.egg-info → nkululeko-0.84.1}/PKG-INFO +5 -1
  3. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/augmenting/resampler.py +9 -4
  4. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/constants.py +1 -1
  5. nkululeko-0.84.1/nkululeko/models/finetune_model.py +181 -0
  6. nkululeko-0.84.1/nkululeko/resample.py +100 -0
  7. nkululeko-0.84.1/nkululeko/test_pretrain.py +294 -0
  8. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/utils/util.py +53 -32
  9. {nkululeko-0.84.0 → nkululeko-0.84.1/nkululeko.egg-info}/PKG-INFO +5 -1
  10. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko.egg-info/SOURCES.txt +1 -0
  11. nkululeko-0.84.0/nkululeko/resample.py +0 -78
  12. nkululeko-0.84.0/nkululeko/test_pretrain.py +0 -117
  13. {nkululeko-0.84.0 → nkululeko-0.84.1}/LICENSE +0 -0
  14. {nkululeko-0.84.0 → nkululeko-0.84.1}/README.md +0 -0
  15. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/aesdd/process_database.py +0 -0
  16. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/androids/process_database.py +0 -0
  17. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/androids_orig/process_database.py +0 -0
  18. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/androids_test/process_database.py +0 -0
  19. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/ased/process_database.py +0 -0
  20. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/asvp-esd/process_database.py +0 -0
  21. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/baved/process_database.py +0 -0
  22. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/cafe/process_database.py +0 -0
  23. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/clac/process_database.py +0 -0
  24. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/cmu-mosei/process_database.py +0 -0
  25. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/demos/process_database.py +0 -0
  26. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/ekorpus/process_database.py +0 -0
  27. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/emns/process_database.py +0 -0
  28. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/emofilm/convert_to_16k.py +0 -0
  29. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/emofilm/process_database.py +0 -0
  30. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/emorynlp/process_database.py +0 -0
  31. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/emov-db/process_database.py +0 -0
  32. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/emovo/process_database.py +0 -0
  33. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/emozionalmente/create.py +0 -0
  34. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/enterface/process_database.py +0 -0
  35. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/esd/process_database.py +0 -0
  36. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/gerparas/process_database.py +0 -0
  37. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/iemocap/process_database.py +0 -0
  38. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/jl/process_database.py +0 -0
  39. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/jtes/process_database.py +0 -0
  40. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/meld/process_database.py +0 -0
  41. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/mesd/process_database.py +0 -0
  42. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/mess/process_database.py +0 -0
  43. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/mlendsnd/process_database.py +0 -0
  44. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/msp-improv/process_database2.py +0 -0
  45. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/msp-podcast/process_database.py +0 -0
  46. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/oreau2/process_database.py +0 -0
  47. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/portuguese/process_database.py +0 -0
  48. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/ravdess/process_database.py +0 -0
  49. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/ravdess/process_database_speaker.py +0 -0
  50. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/savee/process_database.py +0 -0
  51. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/shemo/process_database.py +0 -0
  52. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/subesco/process_database.py +0 -0
  53. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/tess/process_database.py +0 -0
  54. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/thorsten-emotional/process_database.py +0 -0
  55. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/urdu/process_database.py +0 -0
  56. {nkululeko-0.84.0 → nkululeko-0.84.1}/data/vivae/process_database.py +0 -0
  57. {nkululeko-0.84.0 → nkululeko-0.84.1}/docs/source/conf.py +0 -0
  58. {nkululeko-0.84.0 → nkululeko-0.84.1}/meta/demos/demo_best_model.py +0 -0
  59. {nkululeko-0.84.0 → nkululeko-0.84.1}/meta/demos/my_experiment.py +0 -0
  60. {nkululeko-0.84.0 → nkululeko-0.84.1}/meta/demos/my_experiment_local.py +0 -0
  61. {nkululeko-0.84.0 → nkululeko-0.84.1}/meta/demos/plot_faster_anim.py +0 -0
  62. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/__init__.py +0 -0
  63. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/aug_train.py +0 -0
  64. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/augment.py +0 -0
  65. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/augmenting/__init__.py +0 -0
  66. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/augmenting/augmenter.py +0 -0
  67. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/augmenting/randomsplicer.py +0 -0
  68. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/augmenting/randomsplicing.py +0 -0
  69. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/__init__.py +0 -0
  70. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/ap_age.py +0 -0
  71. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/ap_arousal.py +0 -0
  72. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/ap_dominance.py +0 -0
  73. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/ap_gender.py +0 -0
  74. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/ap_mos.py +0 -0
  75. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/ap_pesq.py +0 -0
  76. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/ap_sdr.py +0 -0
  77. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/ap_snr.py +0 -0
  78. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/ap_stoi.py +0 -0
  79. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/ap_valence.py +0 -0
  80. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/estimate_snr.py +0 -0
  81. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/cacheddataset.py +0 -0
  82. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/data/__init__.py +0 -0
  83. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/data/dataset.py +0 -0
  84. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/data/dataset_csv.py +0 -0
  85. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/demo.py +0 -0
  86. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/demo_feats.py +0 -0
  87. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/demo_predictor.py +0 -0
  88. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/experiment.py +0 -0
  89. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/explore.py +0 -0
  90. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/export.py +0 -0
  91. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/__init__.py +0 -0
  92. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_agender.py +0 -0
  93. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_agender_agender.py +0 -0
  94. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_analyser.py +0 -0
  95. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_auddim.py +0 -0
  96. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_audmodel.py +0 -0
  97. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_clap.py +0 -0
  98. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_hubert.py +0 -0
  99. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_import.py +0 -0
  100. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_mld.py +0 -0
  101. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_mos.py +0 -0
  102. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_opensmile.py +0 -0
  103. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_oxbow.py +0 -0
  104. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_praat.py +0 -0
  105. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_snr.py +0 -0
  106. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_spectra.py +0 -0
  107. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_spkrec.py +0 -0
  108. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_squim.py +0 -0
  109. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_trill.py +0 -0
  110. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_wav2vec2.py +0 -0
  111. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_wavlm.py +0 -0
  112. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_whisper.py +0 -0
  113. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/featureset.py +0 -0
  114. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feinberg_praat.py +0 -0
  115. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feature_extractor.py +0 -0
  116. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/file_checker.py +0 -0
  117. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/filter_data.py +0 -0
  118. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/glob_conf.py +0 -0
  119. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/losses/__init__.py +0 -0
  120. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/losses/loss_ccc.py +0 -0
  121. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/losses/loss_softf1loss.py +0 -0
  122. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/modelrunner.py +0 -0
  123. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/__init__.py +0 -0
  124. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model.py +0 -0
  125. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_bayes.py +0 -0
  126. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_cnn.py +0 -0
  127. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_gmm.py +0 -0
  128. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_knn.py +0 -0
  129. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_knn_reg.py +0 -0
  130. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_lin_reg.py +0 -0
  131. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_mlp.py +0 -0
  132. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_mlp_regression.py +0 -0
  133. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_svm.py +0 -0
  134. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_svr.py +0 -0
  135. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_tree.py +0 -0
  136. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_tree_reg.py +0 -0
  137. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_xgb.py +0 -0
  138. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_xgr.py +0 -0
  139. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/multidb.py +0 -0
  140. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/nkuluflag.py +0 -0
  141. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/nkululeko.py +0 -0
  142. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/plots.py +0 -0
  143. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/predict.py +0 -0
  144. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/reporting/__init__.py +0 -0
  145. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/reporting/defines.py +0 -0
  146. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/reporting/latex_writer.py +0 -0
  147. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/reporting/report.py +0 -0
  148. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/reporting/report_item.py +0 -0
  149. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/reporting/reporter.py +0 -0
  150. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/reporting/result.py +0 -0
  151. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/runmanager.py +0 -0
  152. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/scaler.py +0 -0
  153. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/segment.py +0 -0
  154. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/segmenting/__init__.py +0 -0
  155. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/segmenting/seg_inaspeechsegmenter.py +0 -0
  156. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/segmenting/seg_silero.py +0 -0
  157. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/syllable_nuclei.py +0 -0
  158. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/test.py +0 -0
  159. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/test_predictor.py +0 -0
  160. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/utils/__init__.py +0 -0
  161. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/utils/files.py +0 -0
  162. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/utils/stats.py +0 -0
  163. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko.egg-info/dependency_links.txt +0 -0
  164. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko.egg-info/requires.txt +0 -0
  165. {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko.egg-info/top_level.txt +0 -0
  166. {nkululeko-0.84.0 → nkululeko-0.84.1}/pyproject.toml +0 -0
  167. {nkululeko-0.84.0 → nkululeko-0.84.1}/setup.cfg +0 -0
  168. {nkululeko-0.84.0 → nkululeko-0.84.1}/setup.py +0 -0
  169. {nkululeko-0.84.0 → nkululeko-0.84.1}/venv/bin/activate_this.py +0 -0
@@ -1,6 +1,10 @@
1
1
  Changelog
2
2
  =========
3
3
 
4
+ Version 0.84.1
5
+ --------------
6
+ * made resample independent of config file
7
+
4
8
  Version 0.84.0
5
9
  --------------
6
10
  * added SHAP analysis
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nkululeko
3
- Version: 0.84.0
3
+ Version: 0.84.1
4
4
  Summary: Machine learning audio prediction experiments based on templates
5
5
  Home-page: https://github.com/felixbur/nkululeko
6
6
  Author: Felix Burkhardt
@@ -333,6 +333,10 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
333
333
  Changelog
334
334
  =========
335
335
 
336
+ Version 0.84.1
337
+ --------------
338
+ * made resample independent of config file
339
+
336
340
  Version 0.84.0
337
341
  --------------
338
342
  * added SHAP analysis
@@ -12,16 +12,19 @@ from nkululeko.utils.util import Util
12
12
 
13
13
 
14
14
  class Resampler:
15
- def __init__(self, df, not_testing=True):
15
+ def __init__(self, df, replace, not_testing=True):
16
16
  self.SAMPLING_RATE = 16000
17
17
  self.df = df
18
18
  self.util = Util("resampler", has_config=not_testing)
19
19
  self.util.warn(f"all files might be resampled to {self.SAMPLING_RATE}")
20
20
  self.not_testing = not_testing
21
+ self.replace = eval(self.util.config_val(
22
+ "RESAMPLE", "replace", "False")) if not not_testing else replace
21
23
 
22
24
  def resample(self):
23
25
  files = self.df.index.get_level_values(0).values
24
- replace = eval(self.util.config_val("RESAMPLE", "replace", "False"))
26
+ # replace = eval(self.util.config_val("RESAMPLE", "replace", "False"))
27
+ replace = self.replace
25
28
  if self.not_testing:
26
29
  store = self.util.get_path("store")
27
30
  else:
@@ -42,7 +45,8 @@ class Resampler:
42
45
  continue
43
46
  if org_sr != self.SAMPLING_RATE:
44
47
  self.util.debug(f"resampling {f} (sr = {org_sr})")
45
- resampler = torchaudio.transforms.Resample(org_sr, self.SAMPLING_RATE)
48
+ resampler = torchaudio.transforms.Resample(
49
+ org_sr, self.SAMPLING_RATE)
46
50
  signal = resampler(signal)
47
51
  if replace:
48
52
  torchaudio.save(
@@ -59,7 +63,8 @@ class Resampler:
59
63
  self.df = self.df.set_index(
60
64
  self.df.index.set_levels(new_files, level="file")
61
65
  )
62
- target_file = self.util.config_val("RESAMPLE", "target", "resampled.csv")
66
+ target_file = self.util.config_val(
67
+ "RESAMPLE", "target", "resampled.csv")
63
68
  # remove encoded labels
64
69
  target = self.util.config_val("DATA", "target", "emotion")
65
70
  if "class_label" in self.df.columns:
@@ -1,2 +1,2 @@
1
- VERSION="0.84.0"
1
+ VERSION="0.84.1"
2
2
  SAMPLING_RATE = 16000
@@ -0,0 +1,181 @@
1
+ import dataclasses
2
+ import typing
3
+
4
+ import torch
5
+ import transformers
6
+ from transformers.models.wav2vec2.modeling_wav2vec2 import (
7
+ Wav2Vec2PreTrainedModel,
8
+ Wav2Vec2Model,
9
+ )
10
+
11
+
12
class ConcordanceCorCoeff(torch.nn.Module):
    """Loss module computing 1 - CCC (concordance correlation coefficient).

    CCC is 1 for perfect agreement, so identical prediction and ground
    truth yield a loss of 0.
    """

    def __init__(self):
        super().__init__()
        # Bind the torch reductions once so forward() reads uniformly.
        self.mean = torch.mean
        self.var = torch.var
        self.sum = torch.sum
        self.sqrt = torch.sqrt
        self.std = torch.std

    def forward(self, prediction, ground_truth):
        """Return ``1 - CCC`` computed over dim 0 of both tensors."""
        gt_mean = self.mean(ground_truth, 0)
        pred_mean = self.mean(prediction, 0)
        gt_var = self.var(ground_truth, 0)
        pred_var = self.var(prediction, 0)

        # Pearson correlation between the mean-centered series.
        pred_centered = prediction - pred_mean
        gt_centered = ground_truth - gt_mean
        norm = self.sqrt(self.sum(pred_centered**2)) * self.sqrt(
            self.sum(gt_centered**2)
        )
        cor = self.sum(pred_centered * gt_centered) / norm

        # CCC = 2*rho*sd_gt*sd_pred / (var_gt + var_pred + (mean diff)^2)
        ccc = (2 * cor * self.std(ground_truth) * self.std(prediction)) / (
            gt_var + pred_var + (gt_mean - pred_mean) ** 2
        )
        return 1 - ccc
42
+
43
+
44
@dataclasses.dataclass
class ModelOutput(transformers.file_utils.ModelOutput):
    """Container for Model.forward() results.

    All fields default to ``None`` so the forward pass may fill only a
    subset (e.g. just ``logits_cat`` when ``return_hidden`` is False).
    """

    # Category logits (softmaxed by Model.forward when not training).
    logits_cat: torch.FloatTensor = None
    # Pooled encoder hidden states, one vector per input.
    hidden_states: typing.Tuple[torch.FloatTensor] = None
    # CNN feature-extractor output; time is the last axis on return.
    cnn_features: torch.FloatTensor = None
50
+
51
+
52
class ModelHead(torch.nn.Module):
    """Classification head: dropout -> dense -> tanh -> dropout -> proj."""

    def __init__(self, config, num_labels):
        super().__init__()
        width = config.hidden_size
        self.dense = torch.nn.Linear(width, width)
        self.dropout = torch.nn.Dropout(config.final_dropout)
        self.out_proj = torch.nn.Linear(width, num_labels)

    def forward(self, features, **kwargs):
        """Map pooled features to ``num_labels`` logits."""
        hidden = self.dropout(features)
        hidden = torch.tanh(self.dense(hidden))
        hidden = self.dropout(hidden)
        return self.out_proj(hidden)
72
+
73
+
74
class Model(Wav2Vec2PreTrainedModel):
    """Wav2vec2 encoder with a pooled classification head.

    The head width now follows ``config.num_labels`` instead of the
    previous hard-coded 2. ``transformers.PretrainedConfig`` defaults
    ``num_labels`` to 2, so existing 2-class behavior is preserved,
    while callers that configure more labels (as test_pretrain does via
    ``AutoConfig.from_pretrained(num_labels=...)``) get a matching head.
    """

    def __init__(self, config):
        super().__init__(config)
        self.wav2vec2 = Wav2Vec2Model(config)
        # Fix: was ModelHead(config, 2), silently ignoring config.num_labels.
        self.cat = ModelHead(config, config.num_labels)
        self.init_weights()

    def freeze_feature_extractor(self):
        """Freeze the CNN feature extractor (typical for fine-tuning)."""
        self.wav2vec2.feature_extractor._freeze_parameters()

    def pooling(
        self,
        hidden_states,
        attention_mask,
    ):
        """Average frame-wise hidden states over valid frames.

        With ``attention_mask=None`` (evaluation with batch_size == 1) a
        plain mean over time is used; otherwise the sample-level mask is
        projected to feature-frame resolution and a masked mean is taken.
        """
        if attention_mask is None:  # For evaluation with batch_size==1
            return torch.mean(hidden_states, dim=1)

        attention_mask = self._get_feature_vector_attention_mask(
            hidden_states.shape[1],
            attention_mask,
        )
        hidden_states = hidden_states * torch.reshape(
            attention_mask,
            (-1, attention_mask.shape[-1], 1),
        )
        outputs = torch.sum(hidden_states, dim=1)
        attention_sum = torch.sum(attention_mask, dim=1)
        return outputs / torch.reshape(attention_sum, (-1, 1))

    def forward(
        self,
        input_values,
        attention_mask=None,
        labels=None,
        return_hidden=False,
    ):
        """Encode, pool, and classify.

        ``labels`` is accepted for HF Trainer compatibility but unused
        here (the loss is computed externally). Logits are softmaxed
        when the module is in eval mode. Returns a ``ModelOutput``; with
        ``return_hidden=True`` it also carries pooled hidden states and
        CNN features (transposed so time is the last axis).
        """
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
        )

        cnn_features = outputs.extract_features
        hidden_states = self.pooling(
            outputs.last_hidden_state,
            attention_mask,
        )
        logits_cat = self.cat(hidden_states)

        if not self.training:
            logits_cat = torch.softmax(logits_cat, dim=1)

        if return_hidden:
            # make time last axis
            cnn_features = torch.transpose(cnn_features, 1, 2)
            return ModelOutput(
                logits_cat=logits_cat,
                hidden_states=hidden_states,
                cnn_features=cnn_features,
            )

        return ModelOutput(logits_cat=logits_cat)
150
+
151
+
152
class ModelWithPreProcessing(Model):
    """Model variant that normalizes raw audio in-graph (ONNX export)."""

    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        input_values,
    ):
        """Zero-mean / unit-variance normalize, then run the base model.

        Mirrors Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm():
        normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
        Returns (hidden_states, logits_cat, cnn_features).
        """
        mean = input_values.mean()

        # torch.var() lowers to a ReduceProd node that onnxruntime cannot
        # execute, so compute the (biased) variance manually instead.
        var = torch.square(input_values - mean).mean()
        input_values = (input_values - mean) / torch.sqrt(var + 1e-7)

        output = super().forward(
            input_values,
            return_hidden=True,
        )
        return (
            output.hidden_states,
            output.logits_cat,
            output.cnn_features,
        )
@@ -0,0 +1,100 @@
1
+ # resample.py
2
+ # change the sampling rate for audio file or INI file (train, test, all)
3
+
4
+ import argparse
5
+ import configparser
6
+ import os
7
+ import pandas as pd
8
+ import audformat
9
+ from nkululeko.augmenting.resampler import Resampler
10
+ from nkululeko.utils.util import Util
11
+
12
+ from nkululeko.constants import VERSION
13
+ from nkululeko.experiment import Experiment
14
+
15
+
16
def main(src_dir):
    """Resample audio to 16 kHz, driven by --file or an INI --config.

    Exactly one input mode is required: ``--file`` resamples a single
    audio file, ``--config`` runs over an nkululeko experiment's data.
    ``--replace`` overwrites the original files in --file mode.
    """
    parser = argparse.ArgumentParser(
        description="Call the nkululeko RESAMPLE framework.")
    parser.add_argument("--config", default=None,
                        help="The base configuration")
    parser.add_argument("--file", default=None,
                        help="The input audio file to resample")
    parser.add_argument("--replace", action="store_true",
                        help="Replace the original audio file")

    args = parser.parse_args()

    if args.file is None and args.config is None:
        print("ERROR: Either --file or --config argument must be provided.")
        exit()

    if args.file is not None:
        # Load the audio file into a segmented-index DataFrame.
        files = pd.Series([args.file])
        df_sample = pd.DataFrame(index=files)
        df_sample.index = audformat.utils.to_segmented_index(
            df_sample.index, allow_nat=False
        )

        # Resample the audio file
        util = Util("resampler", has_config=False)
        util.debug(f"Resampling audio file: {args.file}")
        # NOTE(review): not_testing=True makes Resampler build its own Util
        # with has_config=True although no config exists here -- confirm.
        rs = Resampler(df_sample, not_testing=True, replace=args.replace)
        rs.resample()
    else:
        # Existing code for handling INI file
        config_file = args.config

        # Test if the configuration file exists
        if not os.path.isfile(config_file):
            print(f"ERROR: no such file: {config_file}")
            exit()

        # Load one configuration per experiment
        config = configparser.ConfigParser()
        config.read(config_file)
        # Create a new experiment
        expr = Experiment(config)
        module = "resample"
        expr.set_module(module)
        util = Util(module)
        util.debug(
            f"running {expr.name} from config {config_file}, nkululeko version"
            f" {VERSION}"
        )

        if util.config_val("EXP", "no_warnings", False):
            import warnings
            warnings.filterwarnings("ignore")

        # Load the data
        expr.load_datasets()

        # Split into train and test
        expr.fill_train_and_tests()
        util.debug(
            f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}")

        sample_selection = util.config_val(
            "RESAMPLE", "sample_selection", "all")
        if sample_selection == "all":
            df = pd.concat([expr.df_train, expr.df_test])
        elif sample_selection == "train":
            df = expr.df_train
        elif sample_selection == "test":
            df = expr.df_test
        else:
            util.error(
                f"unknown selection specifier {sample_selection}, should be [all |"
                " train | test]"
            )
        util.debug(f"resampling {sample_selection}: {df.shape[0]} samples")
        # BUG FIX: config_val returns a string, and the string "False" is
        # truthy -- passed straight to Resampler(replace=...) it caused
        # originals to be overwritten even with replace disabled. Parse it
        # into a real boolean instead.
        replace = str(
            util.config_val("RESAMPLE", "replace", "False")
        ).strip().lower() == "true"
        rs = Resampler(df, replace=replace)
        rs.resample()
96
+
97
+
98
if __name__ == "__main__":
    # Pass this file's directory as the source dir.
    main(os.path.dirname(os.path.abspath(__file__)))
@@ -0,0 +1,294 @@
1
+ # test_pretrain.py
2
+ import argparse
3
+ import configparser
4
+ import os.path
5
+
6
+ import datasets
7
+ import numpy as np
8
+ import pandas as pd
9
+ import torch
10
+ import transformers
11
+
12
+ import audeer
13
+ import audiofile
14
+ import audmetric
15
+
16
+ from nkululeko.constants import VERSION
17
+ import nkululeko.experiment as exp
18
+ import nkululeko.models.finetune_model as fm
19
+ import nkululeko.glob_conf as glob_conf
20
+ from nkululeko.utils.util import Util
21
+ import json
22
+
23
+
24
def doit(config_file):
    """Fine-tune a wav2vec2 model on an nkululeko experiment's data.

    Loads the experiment from *config_file*, builds HF datasets from the
    train/test splits, fine-tunes facebook/wav2vec2-large-robust with a
    class-weighted cross-entropy loss, and saves the best model.

    NOTE(review): experimental script -- GPU device index and several
    hyperparameters are hard-coded below; confirm before general use.
    """
    # test if the configuration file exists
    if not os.path.isfile(config_file):
        print(f"ERROR: no such file: {config_file}")
        exit()

    # load one configuration per experiment
    config = configparser.ConfigParser()
    config.read(config_file)

    # create a new experiment
    expr = exp.Experiment(config)
    module = "test_pretrain"
    expr.set_module(module)
    util = Util(module)
    util.debug(
        f"running {expr.name} from config {config_file}, nkululeko version"
        f" {VERSION}"
    )

    if util.config_val("EXP", "no_warnings", False):
        import warnings

        warnings.filterwarnings("ignore")

    # load the data
    expr.load_datasets()

    # split into train and test
    expr.fill_train_and_tests()
    util.debug(f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}")

    # output directories for tensorboard logs and checkpoints
    log_root = audeer.mkdir("log")
    model_root = audeer.mkdir("model")
    torch_root = audeer.path(model_root, "torch")

    # evaluation metrics (names say "gender" but they apply to whatever
    # target the experiment configures)
    metrics_gender = {
        "UAR": audmetric.unweighted_average_recall,
        "ACC": audmetric.accuracy,
    }

    sampling_rate = 16000
    max_duration_sec = 8.0  # clip training segments to at most 8 s

    model_path = "facebook/wav2vec2-large-robust-ft-swbd-300h"
    num_layers = None  # optionally truncate the encoder depth

    # NOTE(review): hard-coded to GPU 3 -- adjust for the local machine.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "3"

    batch_size = 16
    accumulation_steps = 4
    # create dataset

    dataset = {}
    target_name = glob_conf.target
    data_sources = {
        "train": pd.DataFrame(expr.df_train[target_name]),
        "dev": pd.DataFrame(expr.df_test[target_name]),
    }

    for split in ["train", "dev"]:
        df = data_sources[split]
        df[target_name] = df[target_name].astype("float")

        # pack the target column(s) into one tuple-valued series
        y = pd.Series(
            data=df.itertuples(index=False, name=None),
            index=df.index,
            dtype=object,
            name="labels",
        )

        y.name = "targets"
        df = y.reset_index()
        # assumes a segmented audformat index (file, start, end) with
        # timedelta start/end -- convert to seconds for audiofile.read
        df.start = df.start.dt.total_seconds()
        df.end = df.end.dt.total_seconds()

        print(f"{split}: {len(df)}")

        ds = datasets.Dataset.from_pandas(df)
        dataset[split] = ds

    dataset = datasets.DatasetDict(dataset)

    # load pre-trained model
    le = glob_conf.label_encoder
    mapping = dict(zip(le.classes_, range(len(le.classes_))))
    target_mapping = {k: int(v) for k, v in mapping.items()}
    target_mapping_reverse = {value: key for key, value in target_mapping.items()}

    config = transformers.AutoConfig.from_pretrained(
        model_path,
        num_labels=len(target_mapping),
        label2id=target_mapping,
        id2label=target_mapping_reverse,
        finetuning_task=target_name,
    )
    if num_layers is not None:
        config.num_hidden_layers = num_layers
    setattr(config, "sampling_rate", sampling_rate)
    setattr(config, "data", util.get_data_name())

    # Wav2Vec2Processor needs a tokenizer; an empty vocab suffices here
    # since we only use the feature extractor side.
    vocab_dict = {}
    with open("vocab.json", "w") as vocab_file:
        json.dump(vocab_dict, vocab_file)
    tokenizer = transformers.Wav2Vec2CTCTokenizer("./vocab.json")
    tokenizer.save_pretrained(".")

    feature_extractor = transformers.Wav2Vec2FeatureExtractor(
        feature_size=1,
        sampling_rate=16000,
        padding_value=0.0,
        do_normalize=True,
        return_attention_mask=True,
    )
    processor = transformers.Wav2Vec2Processor(
        feature_extractor=feature_extractor,
        tokenizer=tokenizer,
    )
    assert processor.feature_extractor.sampling_rate == sampling_rate

    model = fm.Model.from_pretrained(
        model_path,
        config=config,
    )
    model.freeze_feature_extractor()
    model.train()

    # training

    def data_collator(data):
        # Read each (file, start, end) segment from disk, clip to
        # max_duration_sec, and pad the batch to equal length.

        files = [d["file"] for d in data]
        starts = [d["start"] for d in data]
        ends = [d["end"] for d in data]
        targets = [d["targets"] for d in data]

        signals = []
        for file, start, end in zip(
            files,
            starts,
            ends,
        ):
            offset = start
            duration = end - offset
            if max_duration_sec is not None:
                duration = min(duration, max_duration_sec)
            signal, _ = audiofile.read(
                file,
                offset=offset,
                duration=duration,
            )
            signals.append(signal.squeeze())

        input_values = processor(
            signals,
            sampling_rate=sampling_rate,
            padding=True,
        )
        batch = processor.pad(
            input_values,
            padding=True,
            return_tensors="pt",
        )

        batch["labels"] = torch.tensor(targets)

        return batch

    def compute_metrics(p: transformers.EvalPrediction):
        # labels arrive as tuples; the first element is the class id

        truth_gender = p.label_ids[:, 0].astype(int)
        preds = p.predictions
        preds_gender = np.argmax(preds, axis=1)

        scores = {}

        for name, metric in metrics_gender.items():
            scores[f"gender-{name}"] = metric(truth_gender, preds_gender)

        # "combined" drives model selection (metric_for_best_model)
        scores["combined"] = scores["gender-UAR"]

        return scores

    # inverse-frequency class weights for the loss
    targets = pd.DataFrame(dataset["train"]["targets"])
    counts = targets[0].value_counts().sort_index()
    train_weights = 1 / counts
    train_weights /= train_weights.sum()

    print(train_weights)

    # NOTE(review): requires CUDA; fails on CPU-only machines.
    criterion_gender = torch.nn.CrossEntropyLoss(
        weight=torch.Tensor(train_weights).to("cuda"),
    )

    class Trainer(transformers.Trainer):
        # Override compute_loss to apply the class-weighted criterion
        # (the model itself returns only logits).

        def compute_loss(
            self,
            model,
            inputs,
            return_outputs=False,
        ):

            targets = inputs.pop("labels").squeeze()
            targets_gender = targets.type(torch.long)

            outputs = model(**inputs)
            logits_gender = outputs[0].squeeze()

            loss_gender = criterion_gender(logits_gender, targets_gender)

            loss = loss_gender

            return (loss, outputs) if return_outputs else loss

    # evaluate/save/log 5 times per epoch, at least once
    num_steps = len(dataset["train"]) // (batch_size * accumulation_steps) // 5
    num_steps = max(1, num_steps)
    print(num_steps)

    training_args = transformers.TrainingArguments(
        output_dir=model_root,
        logging_dir=log_root,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=accumulation_steps,
        evaluation_strategy="steps",
        num_train_epochs=5.0,
        fp16=True,
        save_steps=num_steps,
        eval_steps=num_steps,
        logging_steps=num_steps,
        learning_rate=1e-4,
        save_total_limit=2,
        metric_for_best_model="combined",
        greater_is_better=True,
        load_best_model_at_end=True,
        remove_unused_columns=False,
    )

    trainer = Trainer(
        model=model,
        data_collator=data_collator,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=dataset["train"],
        eval_dataset=dataset["dev"],
        tokenizer=processor.feature_extractor,
        callbacks=[transformers.integrations.TensorBoardCallback()],
    )

    trainer.train()
    trainer.save_model(torch_root)

    print("DONE")
279
+
280
+
281
def main(src_dir):
    """Parse the command line and run doit() on the chosen config."""
    parser = argparse.ArgumentParser(description="Call the nkululeko framework.")
    parser.add_argument("--config", default="exp.ini", help="The base configuration")
    args = parser.parse_args()
    # Fall back to an exp.ini next to this file when no config was given.
    config_file = args.config if args.config is not None else f"{src_dir}/exp.ini"
    doit(config_file)


if __name__ == "__main__":
    cwd = os.path.dirname(os.path.abspath(__file__))
    main(cwd)  # use this if you want to state the config file path on command line