nkululeko 0.84.0__tar.gz → 0.85.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nkululeko-0.84.0 → nkululeko-0.85.0}/CHANGELOG.md +8 -0
- {nkululeko-0.84.0/nkululeko.egg-info → nkululeko-0.85.0}/PKG-INFO +9 -1
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/augmenting/resampler.py +9 -4
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/constants.py +1 -1
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/experiment.py +6 -1
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_whisper.py +3 -6
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/modelrunner.py +56 -33
- nkululeko-0.85.0/nkululeko/models/finetune_model.py +190 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/models/model.py +1 -1
- nkululeko-0.85.0/nkululeko/models/model_tuned.py +506 -0
- nkululeko-0.85.0/nkululeko/resample.py +100 -0
- nkululeko-0.85.0/nkululeko/test_pretrain.py +306 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/utils/util.py +53 -32
- {nkululeko-0.84.0 → nkululeko-0.85.0/nkululeko.egg-info}/PKG-INFO +9 -1
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko.egg-info/SOURCES.txt +2 -0
- nkululeko-0.84.0/nkululeko/resample.py +0 -78
- nkululeko-0.84.0/nkululeko/test_pretrain.py +0 -117
- {nkululeko-0.84.0 → nkululeko-0.85.0}/LICENSE +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/README.md +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/aesdd/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/androids/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/androids_orig/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/androids_test/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/ased/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/asvp-esd/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/baved/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/cafe/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/clac/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/cmu-mosei/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/demos/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/ekorpus/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/emns/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/emofilm/convert_to_16k.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/emofilm/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/emorynlp/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/emov-db/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/emovo/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/emozionalmente/create.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/enterface/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/esd/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/gerparas/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/iemocap/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/jl/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/jtes/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/meld/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/mesd/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/mess/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/mlendsnd/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/msp-improv/process_database2.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/msp-podcast/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/oreau2/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/portuguese/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/ravdess/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/ravdess/process_database_speaker.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/savee/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/shemo/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/subesco/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/tess/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/thorsten-emotional/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/urdu/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/data/vivae/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/docs/source/conf.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/meta/demos/demo_best_model.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/meta/demos/my_experiment.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/meta/demos/my_experiment_local.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/meta/demos/plot_faster_anim.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/__init__.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/aug_train.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/augment.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/augmenting/__init__.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/augmenting/augmenter.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/augmenting/randomsplicer.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/augmenting/randomsplicing.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/autopredict/__init__.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/autopredict/ap_age.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/autopredict/ap_arousal.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/autopredict/ap_dominance.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/autopredict/ap_gender.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/autopredict/ap_mos.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/autopredict/ap_pesq.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/autopredict/ap_sdr.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/autopredict/ap_snr.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/autopredict/ap_stoi.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/autopredict/ap_valence.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/autopredict/estimate_snr.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/cacheddataset.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/data/__init__.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/data/dataset.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/data/dataset_csv.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/demo.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/demo_feats.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/demo_predictor.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/explore.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/export.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/__init__.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_agender.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_agender_agender.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_analyser.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_auddim.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_audmodel.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_clap.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_hubert.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_import.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_mld.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_mos.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_opensmile.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_oxbow.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_praat.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_snr.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_spectra.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_spkrec.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_squim.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_trill.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_wav2vec2.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feats_wavlm.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/featureset.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feat_extract/feinberg_praat.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/feature_extractor.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/file_checker.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/filter_data.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/glob_conf.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/losses/__init__.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/losses/loss_ccc.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/losses/loss_softf1loss.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/models/__init__.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/models/model_bayes.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/models/model_cnn.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/models/model_gmm.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/models/model_knn.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/models/model_knn_reg.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/models/model_lin_reg.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/models/model_mlp.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/models/model_mlp_regression.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/models/model_svm.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/models/model_svr.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/models/model_tree.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/models/model_tree_reg.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/models/model_xgb.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/models/model_xgr.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/multidb.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/nkuluflag.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/nkululeko.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/plots.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/predict.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/reporting/__init__.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/reporting/defines.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/reporting/latex_writer.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/reporting/report.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/reporting/report_item.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/reporting/reporter.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/reporting/result.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/runmanager.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/scaler.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/segment.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/segmenting/__init__.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/segmenting/seg_inaspeechsegmenter.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/segmenting/seg_silero.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/syllable_nuclei.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/test.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/test_predictor.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/utils/__init__.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/utils/files.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko/utils/stats.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko.egg-info/dependency_links.txt +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko.egg-info/requires.txt +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/nkululeko.egg-info/top_level.txt +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/pyproject.toml +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/setup.cfg +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/setup.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.85.0}/venv/bin/activate_this.py +0 -0
@@ -1,6 +1,14 @@
|
|
1
1
|
Changelog
|
2
2
|
=========
|
3
3
|
|
4
|
+
Version 0.85.0
|
5
|
+
--------------
|
6
|
+
* first version with finetuning wav2vec2 layers
|
7
|
+
|
8
|
+
Version 0.84.1
|
9
|
+
--------------
|
10
|
+
* made resample independent of config file
|
11
|
+
|
4
12
|
Version 0.84.0
|
5
13
|
--------------
|
6
14
|
* added SHAP analysis
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: nkululeko
|
3
|
-
Version: 0.84.0
|
3
|
+
Version: 0.85.0
|
4
4
|
Summary: Machine learning audio prediction experiments based on templates
|
5
5
|
Home-page: https://github.com/felixbur/nkululeko
|
6
6
|
Author: Felix Burkhardt
|
@@ -333,6 +333,14 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
|
|
333
333
|
Changelog
|
334
334
|
=========
|
335
335
|
|
336
|
+
Version 0.85.0
|
337
|
+
--------------
|
338
|
+
* first version with finetuning wav2vec2 layers
|
339
|
+
|
340
|
+
Version 0.84.1
|
341
|
+
--------------
|
342
|
+
* made resample independent of config file
|
343
|
+
|
336
344
|
Version 0.84.0
|
337
345
|
--------------
|
338
346
|
* added SHAP analysis
|
@@ -12,16 +12,19 @@ from nkululeko.utils.util import Util
|
|
12
12
|
|
13
13
|
|
14
14
|
class Resampler:
|
15
|
-
def __init__(self, df, not_testing=True):
|
15
|
+
def __init__(self, df, replace, not_testing=True):
|
16
16
|
self.SAMPLING_RATE = 16000
|
17
17
|
self.df = df
|
18
18
|
self.util = Util("resampler", has_config=not_testing)
|
19
19
|
self.util.warn(f"all files might be resampled to {self.SAMPLING_RATE}")
|
20
20
|
self.not_testing = not_testing
|
21
|
+
self.replace = eval(self.util.config_val(
|
22
|
+
"RESAMPLE", "replace", "False")) if not not_testing else replace
|
21
23
|
|
22
24
|
def resample(self):
|
23
25
|
files = self.df.index.get_level_values(0).values
|
24
|
-
replace = eval(self.util.config_val("RESAMPLE", "replace", "False"))
|
26
|
+
# replace = eval(self.util.config_val("RESAMPLE", "replace", "False"))
|
27
|
+
replace = self.replace
|
25
28
|
if self.not_testing:
|
26
29
|
store = self.util.get_path("store")
|
27
30
|
else:
|
@@ -42,7 +45,8 @@ class Resampler:
|
|
42
45
|
continue
|
43
46
|
if org_sr != self.SAMPLING_RATE:
|
44
47
|
self.util.debug(f"resampling {f} (sr = {org_sr})")
|
45
|
-
resampler = torchaudio.transforms.Resample(org_sr, self.SAMPLING_RATE)
|
48
|
+
resampler = torchaudio.transforms.Resample(
|
49
|
+
org_sr, self.SAMPLING_RATE)
|
46
50
|
signal = resampler(signal)
|
47
51
|
if replace:
|
48
52
|
torchaudio.save(
|
@@ -59,7 +63,8 @@ class Resampler:
|
|
59
63
|
self.df = self.df.set_index(
|
60
64
|
self.df.index.set_levels(new_files, level="file")
|
61
65
|
)
|
62
|
-
target_file = self.util.config_val("RESAMPLE", "target", "resampled.csv")
|
66
|
+
target_file = self.util.config_val(
|
67
|
+
"RESAMPLE", "target", "resampled.csv")
|
63
68
|
# remove encoded labels
|
64
69
|
target = self.util.config_val("DATA", "target", "emotion")
|
65
70
|
if "class_label" in self.df.columns:
|
@@ -1,2 +1,2 @@
|
|
1
|
-
VERSION="0.84.0"
|
1
|
+
VERSION="0.85.0"
|
2
2
|
SAMPLING_RATE = 16000
|
@@ -340,7 +340,12 @@ class Experiment:
|
|
340
340
|
df_train, df_test = self.df_train, self.df_test
|
341
341
|
feats_name = "_".join(ast.literal_eval(glob_conf.config["DATA"]["databases"]))
|
342
342
|
self.feats_test, self.feats_train = pd.DataFrame(), pd.DataFrame()
|
343
|
-
feats_types = self.util.config_val_list("FEATS", "type", [
|
343
|
+
feats_types = self.util.config_val_list("FEATS", "type", [])
|
344
|
+
# for some models no features are needed
|
345
|
+
if len(feats_types) == 0:
|
346
|
+
self.util.debug("no feature extractor specified.")
|
347
|
+
self.feats_train, self.feats_test = pd.DataFrame(), pd.DataFrame()
|
348
|
+
return
|
344
349
|
self.feature_extractor = FeatureExtractor(
|
345
350
|
df_train, feats_types, feats_name, "train"
|
346
351
|
)
|
@@ -32,22 +32,19 @@ class Whisper(Featureset):
|
|
32
32
|
model_name = f"openai/{self.feat_type}"
|
33
33
|
self.model = WhisperModel.from_pretrained(model_name).to(self.device)
|
34
34
|
print(f"intialized Whisper model on {self.device}")
|
35
|
-
self.feature_extractor = AutoFeatureExtractor.from_pretrained(
|
36
|
-
model_name)
|
35
|
+
self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
|
37
36
|
self.model_initialized = True
|
38
37
|
|
39
38
|
def extract(self):
|
40
39
|
"""Extract the features or load them from disk if present."""
|
41
40
|
store = self.util.get_path("store")
|
42
41
|
storage = f"{store}{self.name}.pkl"
|
43
|
-
extract = self.util.config_val(
|
44
|
-
"FEATS", "needs_feature_extraction", False)
|
42
|
+
extract = self.util.config_val("FEATS", "needs_feature_extraction", False)
|
45
43
|
no_reuse = eval(self.util.config_val("FEATS", "no_reuse", "False"))
|
46
44
|
if extract or no_reuse or not os.path.isfile(storage):
|
47
45
|
if not self.model_initialized:
|
48
46
|
self.init_model()
|
49
|
-
self.util.debug(
|
50
|
-
"extracting whisper embeddings, this might take a while...")
|
47
|
+
self.util.debug("extracting whisper embeddings, this might take a while...")
|
51
48
|
emb_series = []
|
52
49
|
for (file, start, end), _ in audeer.progress_bar(
|
53
50
|
self.data_df.iterrows(),
|
@@ -47,16 +47,12 @@ class Modelrunner:
|
|
47
47
|
highest = 0
|
48
48
|
else:
|
49
49
|
highest = 100000
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
self.model.load(self.run, epoch)
|
54
|
-
self.util.debug(f"reusing model: {self.model.store_path}")
|
55
|
-
self.model.reset_test(self.df_test, self.feats_test)
|
56
|
-
else:
|
57
|
-
self.model.set_id(self.run, epoch)
|
58
|
-
self.model.train()
|
50
|
+
if self.model.model_type == "finetuned":
|
51
|
+
# epochs are handled by Huggingface API
|
52
|
+
self.model.train()
|
59
53
|
report = self.model.predict()
|
54
|
+
# todo: findout the best epoch
|
55
|
+
epoch = epoch_num
|
60
56
|
report.set_id(self.run, epoch)
|
61
57
|
plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
|
62
58
|
reports.append(report)
|
@@ -67,32 +63,53 @@ class Modelrunner:
|
|
67
63
|
if plot_epochs:
|
68
64
|
self.util.debug(f"plotting conf matrix to {plot_name}")
|
69
65
|
report.plot_confmatrix(plot_name, epoch)
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
patience = int(patience)
|
78
|
-
result = report.result.get_result()
|
79
|
-
if self.util.high_is_good():
|
80
|
-
if result > highest:
|
81
|
-
highest = result
|
82
|
-
patience_counter = 0
|
83
|
-
else:
|
84
|
-
patience_counter += 1
|
66
|
+
else:
|
67
|
+
# for all epochs
|
68
|
+
for epoch in range(epoch_num):
|
69
|
+
if only_test:
|
70
|
+
self.model.load(self.run, epoch)
|
71
|
+
self.util.debug(f"reusing model: {self.model.store_path}")
|
72
|
+
self.model.reset_test(self.df_test, self.feats_test)
|
85
73
|
else:
|
86
|
-
|
87
|
-
|
88
|
-
|
74
|
+
self.model.set_id(self.run, epoch)
|
75
|
+
self.model.train()
|
76
|
+
report = self.model.predict()
|
77
|
+
report.set_id(self.run, epoch)
|
78
|
+
plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
|
79
|
+
reports.append(report)
|
80
|
+
self.util.debug(
|
81
|
+
f"run: {self.run} epoch: {epoch}: result: "
|
82
|
+
f"{reports[-1].get_result().get_test_result()}"
|
83
|
+
)
|
84
|
+
if plot_epochs:
|
85
|
+
self.util.debug(f"plotting conf matrix to {plot_name}")
|
86
|
+
report.plot_confmatrix(plot_name, epoch)
|
87
|
+
store_models = self.util.config_val("EXP", "save", False)
|
88
|
+
plot_best_model = self.util.config_val("PLOT", "best_model", False)
|
89
|
+
if (store_models or plot_best_model) and (
|
90
|
+
not only_test
|
91
|
+
): # in any case the model needs to be stored to disk.
|
92
|
+
self.model.store()
|
93
|
+
if patience:
|
94
|
+
patience = int(patience)
|
95
|
+
result = report.result.get_result()
|
96
|
+
if self.util.high_is_good():
|
97
|
+
if result > highest:
|
98
|
+
highest = result
|
99
|
+
patience_counter = 0
|
100
|
+
else:
|
101
|
+
patience_counter += 1
|
89
102
|
else:
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
103
|
+
if result < highest:
|
104
|
+
highest = result
|
105
|
+
patience_counter = 0
|
106
|
+
else:
|
107
|
+
patience_counter += 1
|
108
|
+
if patience_counter >= patience:
|
109
|
+
self.util.debug(
|
110
|
+
f"reached patience ({str(patience)}): early stopping"
|
111
|
+
)
|
112
|
+
break
|
96
113
|
|
97
114
|
if not plot_epochs:
|
98
115
|
# Do at least one confusion matrix plot
|
@@ -133,6 +150,12 @@ class Modelrunner:
|
|
133
150
|
self.model = Bayes_model(
|
134
151
|
self.df_train, self.df_test, self.feats_train, self.feats_test
|
135
152
|
)
|
153
|
+
elif model_type == "finetune":
|
154
|
+
from nkululeko.models.model_tuned import Pretrained_model
|
155
|
+
|
156
|
+
self.model = Pretrained_model(
|
157
|
+
self.df_train, self.df_test, self.feats_train, self.feats_test
|
158
|
+
)
|
136
159
|
elif model_type == "gmm":
|
137
160
|
from nkululeko.models.model_gmm import GMM_model
|
138
161
|
|
@@ -0,0 +1,190 @@
|
|
1
|
+
"""
|
2
|
+
Code based on @jwagner
|
3
|
+
"""
|
4
|
+
|
5
|
+
import dataclasses
|
6
|
+
import typing
|
7
|
+
|
8
|
+
import torch
|
9
|
+
import transformers
|
10
|
+
from transformers.models.wav2vec2.modeling_wav2vec2 import (
|
11
|
+
Wav2Vec2PreTrainedModel,
|
12
|
+
Wav2Vec2Model,
|
13
|
+
)
|
14
|
+
|
15
|
+
|
16
|
+
class ConcordanceCorCoeff(torch.nn.Module):
    """Loss module that returns ``1 - CCC`` (concordance correlation coefficient).

    The coefficient is computed along dimension 0 as::

        ccc = 2 * cov(pred, gt) / (var(pred) + var(gt) + (mean(pred) - mean(gt)) ** 2)

    using the unbiased (n - 1) estimators, which is what ``torch.var``
    uses by default.  For inputs with more than one dimension every
    column along dim 0 gets its own value.
    """

    def __init__(self):
        super().__init__()

        # Kept as instance attributes so that code which accessed (or
        # replaced) these callables on an instance keeps working.
        self.mean = torch.mean
        self.var = torch.var
        self.sum = torch.sum
        self.sqrt = torch.sqrt
        self.std = torch.std

    def forward(self, prediction, ground_truth):
        """Return ``1 - ccc`` for ``prediction`` against ``ground_truth``.

        Args:
            prediction: tensor of predicted values; samples along dim 0.
            ground_truth: tensor of reference values, same shape.

        Returns:
            Tensor holding ``1 - ccc`` (0 for perfect agreement, 2 for
            perfect disagreement).  The result is still NaN when both
            inputs are constant and equal (zero denominator) or when
            fewer than two samples are given.
        """
        mean_gt = self.mean(ground_truth, 0)
        mean_pred = self.mean(prediction, 0)
        var_gt = self.var(ground_truth, 0)
        var_pred = self.var(prediction, 0)
        v_pred = prediction - mean_pred
        v_gt = ground_truth - mean_gt

        # Unbiased covariance along dim 0.  Computing it directly (instead
        # of correlation * std_gt * std_pred) avoids a 0/0 = NaN when only
        # one of the inputs is constant, and keeps every statistic on the
        # same axis for multi-dimensional input.
        n_samples = prediction.shape[0] if prediction.dim() > 0 else 1
        covariance = self.sum(v_pred * v_gt, 0) / (n_samples - 1)

        numerator = 2 * covariance
        denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
        ccc = numerator / denominator

        return 1 - ccc
|
46
|
+
|
47
|
+
|
48
|
+
@dataclasses.dataclass
class ModelOutput(transformers.file_utils.ModelOutput):
    """Container for the values returned by the model's ``forward``.

    Every field defaults to ``None``; a field stays ``None`` when the
    caller did not ask for it.
    """

    # Output of the classification head.
    logits_cat: torch.FloatTensor = None
    # Hidden representation handed to the classification head.
    hidden_states: typing.Tuple[torch.FloatTensor] = None
    # Features taken from the wav2vec2 output's ``extract_features``.
    cnn_features: torch.FloatTensor = None
|
54
|
+
|
55
|
+
|
56
|
+
class ModelHead(torch.nn.Module):
    """Classification head: dropout -> dense -> tanh -> dropout -> projection."""

    def __init__(self, config, num_labels):
        """Build the head.

        Args:
            config: object providing ``hidden_size`` and ``final_dropout``.
            num_labels: size of the output projection.
        """
        super().__init__()

        width = config.hidden_size
        self.dense = torch.nn.Linear(width, width)
        self.dropout = torch.nn.Dropout(config.final_dropout)
        self.out_proj = torch.nn.Linear(width, num_labels)

    def forward(self, features, **kwargs):
        """Map ``features`` to ``num_labels`` outputs; extra kwargs are ignored."""
        # The same dropout module is applied before and after the dense layer.
        activated = torch.tanh(self.dense(self.dropout(features)))
        return self.out_proj(self.dropout(activated))
|
76
|
+
|
77
|
+
|
78
|
+
class Model(Wav2Vec2PreTrainedModel):
    """wav2vec2 encoder followed by a pooled two-output classification head."""

    def __init__(self, config):
        """Create the wav2vec2 backbone and the ``cat`` head from ``config``."""
        super().__init__(config)

        self.wav2vec2 = Wav2Vec2Model(config)
        self.cat = ModelHead(config, 2)
        self.init_weights()

    def freeze_feature_extractor(self):
        """Freeze the parameters of the wav2vec2 feature extractor."""
        self.wav2vec2.feature_extractor._freeze_parameters()

    def pooling(
        self,
        hidden_states,
        attention_mask,
    ):
        """Average ``hidden_states`` over dim 1, honouring ``attention_mask``.

        Without a mask a plain mean is taken; with a mask, masked frames
        are zeroed and the sum is divided by the number of kept frames.
        """
        if attention_mask is None:  # For evaluation with batch_size==1
            return torch.mean(hidden_states, dim=1)

        # Translate the sample-level mask into a frame-level mask.
        frame_mask = self._get_feature_vector_attention_mask(
            hidden_states.shape[1],
            attention_mask,
        )
        masked_states = hidden_states * torch.reshape(
            frame_mask,
            (-1, frame_mask.shape[-1], 1),
        )
        summed = torch.sum(masked_states, dim=1)
        frames_kept = torch.sum(frame_mask, dim=1)

        return summed / torch.reshape(frames_kept, (-1, 1))

    def forward(
        self,
        input_values,
        attention_mask=None,
        labels=None,
        return_hidden=False,
    ):
        """Run the encoder, pool its output and apply the head.

        Args:
            input_values: input passed to the wav2vec2 model.
            attention_mask: optional mask, forwarded to the encoder and
                to ``pooling``.
            labels: accepted but not used inside this method.
            return_hidden: when true, also return the pooled hidden
                states and the transposed CNN features.

        Returns:
            A ``ModelOutput``; ``logits_cat`` is passed through softmax
            whenever the module is not in training mode.
        """
        encoder_out = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
        )

        cnn_features = encoder_out.extract_features
        pooled = self.pooling(
            encoder_out.last_hidden_state,
            attention_mask,
        )
        logits_cat = self.cat(pooled)

        if not self.training:
            logits_cat = torch.softmax(logits_cat, dim=1)

        if not return_hidden:
            return ModelOutput(
                logits_cat=logits_cat,
            )

        # make time last axis
        cnn_features = torch.transpose(cnn_features, 1, 2)

        return ModelOutput(
            logits_cat=logits_cat,
            hidden_states=pooled,
            cnn_features=cnn_features,
        )

    def predict(self, signal):
        """Run the model on a numpy ``signal`` and return the first row of
        the first output field as a numpy array."""
        output = self(torch.from_numpy(signal))
        return output[0].detach().numpy()[0]
|
159
|
+
|
160
|
+
|
161
|
+
class ModelWithPreProcessing(Model):
    """``Model`` variant that normalises its input inside ``forward``."""

    def __init__(self, config):
        """Delegate construction to ``Model``."""
        super().__init__(config)

    def forward(
        self,
        input_values,
    ):
        """Normalise ``input_values`` to zero mean / unit variance, then run the model.

        Returns:
            Tuple ``(hidden_states, logits_cat, cnn_features)`` taken from
            the parent's output with ``return_hidden=True``.
        """
        # Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm():
        # normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)

        mean = input_values.mean()
        centered = input_values - mean

        # var = input_values.var()
        # raises: onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for the node ReduceProd_3:ReduceProd(11)

        # Variance is spelled out by hand because of the error quoted above.
        var = torch.square(centered).mean()
        input_values = centered / torch.sqrt(var + 1e-7)

        output = super().forward(
            input_values,
            return_hidden=True,
        )

        return (
            output.hidden_states,
            output.logits_cat,
            output.cnn_features,
        )
|