nkululeko 0.84.0__tar.gz → 0.84.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nkululeko-0.84.0 → nkululeko-0.84.1}/CHANGELOG.md +4 -0
- {nkululeko-0.84.0/nkululeko.egg-info → nkululeko-0.84.1}/PKG-INFO +5 -1
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/augmenting/resampler.py +9 -4
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/constants.py +1 -1
- nkululeko-0.84.1/nkululeko/models/finetune_model.py +181 -0
- nkululeko-0.84.1/nkululeko/resample.py +100 -0
- nkululeko-0.84.1/nkululeko/test_pretrain.py +294 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/utils/util.py +53 -32
- {nkululeko-0.84.0 → nkululeko-0.84.1/nkululeko.egg-info}/PKG-INFO +5 -1
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko.egg-info/SOURCES.txt +1 -0
- nkululeko-0.84.0/nkululeko/resample.py +0 -78
- nkululeko-0.84.0/nkululeko/test_pretrain.py +0 -117
- {nkululeko-0.84.0 → nkululeko-0.84.1}/LICENSE +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/README.md +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/aesdd/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/androids/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/androids_orig/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/androids_test/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/ased/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/asvp-esd/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/baved/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/cafe/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/clac/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/cmu-mosei/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/demos/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/ekorpus/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/emns/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/emofilm/convert_to_16k.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/emofilm/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/emorynlp/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/emov-db/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/emovo/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/emozionalmente/create.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/enterface/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/esd/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/gerparas/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/iemocap/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/jl/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/jtes/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/meld/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/mesd/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/mess/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/mlendsnd/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/msp-improv/process_database2.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/msp-podcast/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/oreau2/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/portuguese/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/ravdess/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/ravdess/process_database_speaker.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/savee/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/shemo/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/subesco/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/tess/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/thorsten-emotional/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/urdu/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/data/vivae/process_database.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/docs/source/conf.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/meta/demos/demo_best_model.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/meta/demos/my_experiment.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/meta/demos/my_experiment_local.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/meta/demos/plot_faster_anim.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/__init__.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/aug_train.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/augment.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/augmenting/__init__.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/augmenting/augmenter.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/augmenting/randomsplicer.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/augmenting/randomsplicing.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/__init__.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/ap_age.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/ap_arousal.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/ap_dominance.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/ap_gender.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/ap_mos.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/ap_pesq.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/ap_sdr.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/ap_snr.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/ap_stoi.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/ap_valence.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/autopredict/estimate_snr.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/cacheddataset.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/data/__init__.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/data/dataset.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/data/dataset_csv.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/demo.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/demo_feats.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/demo_predictor.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/experiment.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/explore.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/export.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/__init__.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_agender.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_agender_agender.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_analyser.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_auddim.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_audmodel.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_clap.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_hubert.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_import.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_mld.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_mos.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_opensmile.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_oxbow.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_praat.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_snr.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_spectra.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_spkrec.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_squim.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_trill.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_wav2vec2.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_wavlm.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feats_whisper.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/featureset.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feat_extract/feinberg_praat.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/feature_extractor.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/file_checker.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/filter_data.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/glob_conf.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/losses/__init__.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/losses/loss_ccc.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/losses/loss_softf1loss.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/modelrunner.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/__init__.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_bayes.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_cnn.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_gmm.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_knn.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_knn_reg.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_lin_reg.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_mlp.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_mlp_regression.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_svm.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_svr.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_tree.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_tree_reg.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_xgb.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/models/model_xgr.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/multidb.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/nkuluflag.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/nkululeko.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/plots.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/predict.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/reporting/__init__.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/reporting/defines.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/reporting/latex_writer.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/reporting/report.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/reporting/report_item.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/reporting/reporter.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/reporting/result.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/runmanager.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/scaler.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/segment.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/segmenting/__init__.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/segmenting/seg_inaspeechsegmenter.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/segmenting/seg_silero.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/syllable_nuclei.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/test.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/test_predictor.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/utils/__init__.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/utils/files.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko/utils/stats.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko.egg-info/dependency_links.txt +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko.egg-info/requires.txt +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/nkululeko.egg-info/top_level.txt +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/pyproject.toml +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/setup.cfg +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/setup.py +0 -0
- {nkululeko-0.84.0 → nkululeko-0.84.1}/venv/bin/activate_this.py +0 -0
PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nkululeko
-Version: 0.84.0
+Version: 0.84.1
 Summary: Machine learning audio prediction experiments based on templates
 Home-page: https://github.com/felixbur/nkululeko
 Author: Felix Burkhardt
@@ -333,6 +333,10 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schuller
 Changelog
 =========
 
+Version 0.84.1
+--------------
+* made resample independent of config file
+
 Version 0.84.0
 --------------
 * added SHAP analysis
```
nkululeko/augmenting/resampler.py

```diff
@@ -12,16 +12,19 @@ from nkululeko.utils.util import Util
 
 
 class Resampler:
-    def __init__(self, df, not_testing=True):
+    def __init__(self, df, replace, not_testing=True):
         self.SAMPLING_RATE = 16000
         self.df = df
         self.util = Util("resampler", has_config=not_testing)
         self.util.warn(f"all files might be resampled to {self.SAMPLING_RATE}")
         self.not_testing = not_testing
+        self.replace = eval(self.util.config_val(
+            "RESAMPLE", "replace", "False")) if not not_testing else replace
 
     def resample(self):
         files = self.df.index.get_level_values(0).values
-        replace = eval(self.util.config_val("RESAMPLE", "replace", "False"))
+        # replace = eval(self.util.config_val("RESAMPLE", "replace", "False"))
+        replace = self.replace
         if self.not_testing:
             store = self.util.get_path("store")
         else:
@@ -42,7 +45,8 @@ class Resampler:
                 continue
             if org_sr != self.SAMPLING_RATE:
                 self.util.debug(f"resampling {f} (sr = {org_sr})")
-                resampler = torchaudio.transforms.Resample(org_sr, self.SAMPLING_RATE)
+                resampler = torchaudio.transforms.Resample(
+                    org_sr, self.SAMPLING_RATE)
                 signal = resampler(signal)
                 if replace:
                     torchaudio.save(
@@ -59,7 +63,8 @@ class Resampler:
             self.df = self.df.set_index(
                 self.df.index.set_levels(new_files, level="file")
             )
-            target_file = self.util.config_val("RESAMPLE", "target", "resampled.csv")
+            target_file = self.util.config_val(
+                "RESAMPLE", "target", "resampled.csv")
             # remove encoded labels
             target = self.util.config_val("DATA", "target", "emotion")
             if "class_label" in self.df.columns:
```
|
@@ -1,2 +1,2 @@
|
|
1
|
-
VERSION="0.84.
|
1
|
+
VERSION="0.84.1"
|
2
2
|
SAMPLING_RATE = 16000
|
nkululeko/models/finetune_model.py (new file)

```diff
@@ -0,0 +1,181 @@
+import dataclasses
+import typing
+
+import torch
+import transformers
+from transformers.models.wav2vec2.modeling_wav2vec2 import (
+    Wav2Vec2PreTrainedModel,
+    Wav2Vec2Model,
+)
+
+
+class ConcordanceCorCoeff(torch.nn.Module):
+
+    def __init__(self):
+
+        super().__init__()
+
+        self.mean = torch.mean
+        self.var = torch.var
+        self.sum = torch.sum
+        self.sqrt = torch.sqrt
+        self.std = torch.std
+
+    def forward(self, prediction, ground_truth):
+
+        mean_gt = self.mean(ground_truth, 0)
+        mean_pred = self.mean(prediction, 0)
+        var_gt = self.var(ground_truth, 0)
+        var_pred = self.var(prediction, 0)
+        v_pred = prediction - mean_pred
+        v_gt = ground_truth - mean_gt
+        cor = self.sum(v_pred * v_gt) / (
+            self.sqrt(self.sum(v_pred**2)) * self.sqrt(self.sum(v_gt**2))
+        )
+        sd_gt = self.std(ground_truth)
+        sd_pred = self.std(prediction)
+        numerator = 2 * cor * sd_gt * sd_pred
+        denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
+        ccc = numerator / denominator
+
+        return 1 - ccc
+
+
+@dataclasses.dataclass
+class ModelOutput(transformers.file_utils.ModelOutput):
+
+    logits_cat: torch.FloatTensor = None
+    hidden_states: typing.Tuple[torch.FloatTensor] = None
+    cnn_features: torch.FloatTensor = None
+
+
+class ModelHead(torch.nn.Module):
+
+    def __init__(self, config, num_labels):
+
+        super().__init__()
+
+        self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = torch.nn.Dropout(config.final_dropout)
+        self.out_proj = torch.nn.Linear(config.hidden_size, num_labels)
+
+    def forward(self, features, **kwargs):
+
+        x = features
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+
+        return x
+
+
+class Model(Wav2Vec2PreTrainedModel):
+
+    def __init__(self, config):
+
+        super().__init__(config)
+
+        self.wav2vec2 = Wav2Vec2Model(config)
+        self.cat = ModelHead(config, 2)
+        self.init_weights()
+
+    def freeze_feature_extractor(self):
+        self.wav2vec2.feature_extractor._freeze_parameters()
+
+    def pooling(
+        self,
+        hidden_states,
+        attention_mask,
+    ):
+
+        if attention_mask is None:  # For evaluation with batch_size==1
+            outputs = torch.mean(hidden_states, dim=1)
+        else:
+            attention_mask = self._get_feature_vector_attention_mask(
+                hidden_states.shape[1],
+                attention_mask,
+            )
+            hidden_states = hidden_states * torch.reshape(
+                attention_mask,
+                (-1, attention_mask.shape[-1], 1),
+            )
+            outputs = torch.sum(hidden_states, dim=1)
+            attention_sum = torch.sum(attention_mask, dim=1)
+            outputs = outputs / torch.reshape(attention_sum, (-1, 1))
+
+        return outputs
+
+    def forward(
+        self,
+        input_values,
+        attention_mask=None,
+        labels=None,
+        return_hidden=False,
+    ):
+
+        outputs = self.wav2vec2(
+            input_values,
+            attention_mask=attention_mask,
+        )
+
+        cnn_features = outputs.extract_features
+        hidden_states_framewise = outputs.last_hidden_state
+        hidden_states = self.pooling(
+            hidden_states_framewise,
+            attention_mask,
+        )
+        logits_cat = self.cat(hidden_states)
+
+        if not self.training:
+            logits_cat = torch.softmax(logits_cat, dim=1)
+
+        if return_hidden:
+
+            # make time last axis
+            cnn_features = torch.transpose(cnn_features, 1, 2)
+
+            return ModelOutput(
+                logits_cat=logits_cat,
+                hidden_states=hidden_states,
+                cnn_features=cnn_features,
+            )
+
+        else:
+
+            return ModelOutput(
+                logits_cat=logits_cat,
+            )
+
+
+class ModelWithPreProcessing(Model):
+
+    def __init__(self, config):
+        super().__init__(config)
+
+    def forward(
+        self,
+        input_values,
+    ):
+        # Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm():
+        # normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
+
+        mean = input_values.mean()
+
+        # var = input_values.var()
+        # raises: onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for the node ReduceProd_3:ReduceProd(11)
+
+        var = torch.square(input_values - mean).mean()
+        input_values = (input_values - mean) / torch.sqrt(var + 1e-7)
+
+        output = super().forward(
+            input_values,
+            return_hidden=True,
+        )
+
+        return (
+            output.hidden_states,
+            output.logits_cat,
+            output.cnn_features,
+        )
```
nkululeko/resample.py (new file)

```diff
@@ -0,0 +1,100 @@
+# resample.py
+# change the sampling rate for audio file or INI file (train, test, all)
+
+import argparse
+import configparser
+import os
+import pandas as pd
+import audformat
+from nkululeko.augmenting.resampler import Resampler
+from nkululeko.utils.util import Util
+
+from nkululeko.constants import VERSION
+from nkululeko.experiment import Experiment
+
+
+def main(src_dir):
+    parser = argparse.ArgumentParser(
+        description="Call the nkululeko RESAMPLE framework.")
+    parser.add_argument("--config", default=None,
+                        help="The base configuration")
+    parser.add_argument("--file", default=None,
+                        help="The input audio file to resample")
+    parser.add_argument("--replace", action="store_true",
+                        help="Replace the original audio file")
+
+    args = parser.parse_args()
+
+    if args.file is None and args.config is None:
+        print("ERROR: Either --file or --config argument must be provided.")
+        exit()
+
+    if args.file is not None:
+        # Load the audio file into a DataFrame
+        files = pd.Series([args.file])
+        df_sample = pd.DataFrame(index=files)
+        df_sample.index = audformat.utils.to_segmented_index(
+            df_sample.index, allow_nat=False
+        )
+
+        # Resample the audio file
+        util = Util("resampler", has_config=False)
+        util.debug(f"Resampling audio file: {args.file}")
+        rs = Resampler(df_sample, not_testing=True, replace=args.replace)
+        rs.resample()
+    else:
+        # Existing code for handling INI file
+        config_file = args.config
+
+        # Test if the configuration file exists
+        if not os.path.isfile(config_file):
+            print(f"ERROR: no such file: {config_file}")
+            exit()
+
+        # Load one configuration per experiment
+        config = configparser.ConfigParser()
+        config.read(config_file)
+        # Create a new experiment
+        expr = Experiment(config)
+        module = "resample"
+        expr.set_module(module)
+        util = Util(module)
+        util.debug(
+            f"running {expr.name} from config {config_file}, nkululeko version"
+            f" {VERSION}"
+        )
+
+        if util.config_val("EXP", "no_warnings", False):
+            import warnings
+            warnings.filterwarnings("ignore")
+
+        # Load the data
+        expr.load_datasets()
+
+        # Split into train and test
+        expr.fill_train_and_tests()
+        util.debug(
+            f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}")
+
+        sample_selection = util.config_val(
+            "RESAMPLE", "sample_selection", "all")
+        if sample_selection == "all":
+            df = pd.concat([expr.df_train, expr.df_test])
+        elif sample_selection == "train":
+            df = expr.df_train
+        elif sample_selection == "test":
+            df = expr.df_test
+        else:
+            util.error(
+                f"unknown selection specifier {sample_selection}, should be [all |"
+                " train | test]"
+            )
+        util.debug(f"resampling {sample_selection}: {df.shape[0]} samples")
+        replace = util.config_val("RESAMPLE", "replace", "False")
+        rs = Resampler(df, replace=replace)
+        rs.resample()
+
+
+if __name__ == "__main__":
+    cwd = os.path.dirname(os.path.abspath(__file__))
+    main(cwd)
```
nkululeko/test_pretrain.py (new file)

```diff
@@ -0,0 +1,294 @@
+# test_pretrain.py
+import argparse
+import configparser
+import os.path
+
+import datasets
+import numpy as np
+import pandas as pd
+import torch
+import transformers
+
+import audeer
+import audiofile
+import audmetric
+
+from nkululeko.constants import VERSION
+import nkululeko.experiment as exp
+import nkululeko.models.finetune_model as fm
+import nkululeko.glob_conf as glob_conf
+from nkululeko.utils.util import Util
+import json
+
+
+def doit(config_file):
+    # test if the configuration file exists
+    if not os.path.isfile(config_file):
+        print(f"ERROR: no such file: {config_file}")
+        exit()
+
+    # load one configuration per experiment
+    config = configparser.ConfigParser()
+    config.read(config_file)
+
+    # create a new experiment
+    expr = exp.Experiment(config)
+    module = "test_pretrain"
+    expr.set_module(module)
+    util = Util(module)
+    util.debug(
+        f"running {expr.name} from config {config_file}, nkululeko version"
+        f" {VERSION}"
+    )
+
+    if util.config_val("EXP", "no_warnings", False):
+        import warnings
+
+        warnings.filterwarnings("ignore")
+
+    # load the data
+    expr.load_datasets()
+
+    # split into train and test
+    expr.fill_train_and_tests()
+    util.debug(f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}")
+
+    log_root = audeer.mkdir("log")
+    model_root = audeer.mkdir("model")
+    torch_root = audeer.path(model_root, "torch")
+
+    metrics_gender = {
+        "UAR": audmetric.unweighted_average_recall,
+        "ACC": audmetric.accuracy,
+    }
+
+    sampling_rate = 16000
+    max_duration_sec = 8.0
+
+    model_path = "facebook/wav2vec2-large-robust-ft-swbd-300h"
+    num_layers = None
+
+    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+    os.environ["CUDA_VISIBLE_DEVICES"] = "3"
+
+    batch_size = 16
+    accumulation_steps = 4
+    # create dataset
+
+    dataset = {}
+    target_name = glob_conf.target
+    data_sources = {
+        "train": pd.DataFrame(expr.df_train[target_name]),
+        "dev": pd.DataFrame(expr.df_test[target_name]),
+    }
+
+    for split in ["train", "dev"]:
+        df = data_sources[split]
+        df[target_name] = df[target_name].astype("float")
+
+        y = pd.Series(
+            data=df.itertuples(index=False, name=None),
+            index=df.index,
+            dtype=object,
+            name="labels",
+        )
+
+        y.name = "targets"
+        df = y.reset_index()
+        df.start = df.start.dt.total_seconds()
+        df.end = df.end.dt.total_seconds()
+
+        print(f"{split}: {len(df)}")
+
+        ds = datasets.Dataset.from_pandas(df)
+        dataset[split] = ds
+
+    dataset = datasets.DatasetDict(dataset)
+
+    # load pre-trained model
+    le = glob_conf.label_encoder
+    mapping = dict(zip(le.classes_, range(len(le.classes_))))
+    target_mapping = {k: int(v) for k, v in mapping.items()}
+    target_mapping_reverse = {value: key for key, value in target_mapping.items()}
+
+    config = transformers.AutoConfig.from_pretrained(
+        model_path,
+        num_labels=len(target_mapping),
+        label2id=target_mapping,
+        id2label=target_mapping_reverse,
+        finetuning_task=target_name,
+    )
+    if num_layers is not None:
+        config.num_hidden_layers = num_layers
+    setattr(config, "sampling_rate", sampling_rate)
+    setattr(config, "data", util.get_data_name())
+
+    vocab_dict = {}
+    with open("vocab.json", "w") as vocab_file:
+        json.dump(vocab_dict, vocab_file)
+    tokenizer = transformers.Wav2Vec2CTCTokenizer("./vocab.json")
+    tokenizer.save_pretrained(".")
+
+    feature_extractor = transformers.Wav2Vec2FeatureExtractor(
+        feature_size=1,
+        sampling_rate=16000,
+        padding_value=0.0,
+        do_normalize=True,
+        return_attention_mask=True,
+    )
+    processor = transformers.Wav2Vec2Processor(
+        feature_extractor=feature_extractor,
+        tokenizer=tokenizer,
+    )
+    assert processor.feature_extractor.sampling_rate == sampling_rate
+
+    model = fm.Model.from_pretrained(
+        model_path,
+        config=config,
+    )
+    model.freeze_feature_extractor()
+    model.train()
+
+    # training
+
+    def data_collator(data):
+
+        files = [d["file"] for d in data]
+        starts = [d["start"] for d in data]
+        ends = [d["end"] for d in data]
+        targets = [d["targets"] for d in data]
+
+        signals = []
+        for file, start, end in zip(
+            files,
+            starts,
+            ends,
+        ):
+            offset = start
+            duration = end - offset
+            if max_duration_sec is not None:
+                duration = min(duration, max_duration_sec)
+            signal, _ = audiofile.read(
+                file,
+                offset=offset,
+                duration=duration,
+            )
+            signals.append(signal.squeeze())
+
+        input_values = processor(
+            signals,
+            sampling_rate=sampling_rate,
+            padding=True,
+        )
+        batch = processor.pad(
+            input_values,
+            padding=True,
+            return_tensors="pt",
+        )
+
+        batch["labels"] = torch.tensor(targets)
+
+        return batch
+
+    def compute_metrics(p: transformers.EvalPrediction):
+
+        truth_gender = p.label_ids[:, 0].astype(int)
+        preds = p.predictions
+        preds_gender = np.argmax(preds, axis=1)
+
+        scores = {}
+
+        for name, metric in metrics_gender.items():
+            scores[f"gender-{name}"] = metric(truth_gender, preds_gender)
+
+        scores["combined"] = scores["gender-UAR"]
+
+        return scores
+
+    targets = pd.DataFrame(dataset["train"]["targets"])
+    counts = targets[0].value_counts().sort_index()
+    train_weights = 1 / counts
+    train_weights /= train_weights.sum()
+
+    print(train_weights)
+
+    criterion_gender = torch.nn.CrossEntropyLoss(
+        weight=torch.Tensor(train_weights).to("cuda"),
+    )
+
+    class Trainer(transformers.Trainer):
+
+        def compute_loss(
+            self,
+            model,
+            inputs,
+            return_outputs=False,
+        ):
+
+            targets = inputs.pop("labels").squeeze()
+            targets_gender = targets.type(torch.long)
+
+            outputs = model(**inputs)
+            logits_gender = outputs[0].squeeze()
+
+            loss_gender = criterion_gender(logits_gender, targets_gender)
+
+            loss = loss_gender
+
+            return (loss, outputs) if return_outputs else loss
+
+    num_steps = len(dataset["train"]) // (batch_size * accumulation_steps) // 5
+    num_steps = max(1, num_steps)
+    print(num_steps)
+
+    training_args = transformers.TrainingArguments(
+        output_dir=model_root,
+        logging_dir=log_root,
+        per_device_train_batch_size=batch_size,
+        per_device_eval_batch_size=batch_size,
+        gradient_accumulation_steps=accumulation_steps,
+        evaluation_strategy="steps",
+        num_train_epochs=5.0,
+        fp16=True,
+        save_steps=num_steps,
+        eval_steps=num_steps,
+        logging_steps=num_steps,
+        learning_rate=1e-4,
+        save_total_limit=2,
+        metric_for_best_model="combined",
+        greater_is_better=True,
+        load_best_model_at_end=True,
+        remove_unused_columns=False,
+    )
+
+    trainer = Trainer(
+        model=model,
+        data_collator=data_collator,
+        args=training_args,
+        compute_metrics=compute_metrics,
+        train_dataset=dataset["train"],
+        eval_dataset=dataset["dev"],
+        tokenizer=processor.feature_extractor,
+        callbacks=[transformers.integrations.TensorBoardCallback()],
+    )
+
+    trainer.train()
+    trainer.save_model(torch_root)
+
+    print("DONE")
+
+
+def main(src_dir):
+    parser = argparse.ArgumentParser(description="Call the nkululeko framework.")
+    parser.add_argument("--config", default="exp.ini", help="The base configuration")
+    args = parser.parse_args()
+    if args.config is not None:
+        config_file = args.config
+    else:
+        config_file = f"{src_dir}/exp.ini"
+    doit(config_file)
+
+
+if __name__ == "__main__":
+    cwd = os.path.dirname(os.path.abspath(__file__))
+    main(cwd)  # use this if you want to state the config file path on command line
```