nkululeko 0.84.1__tar.gz → 0.85.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nkululeko-0.84.1 → nkululeko-0.85.1}/CHANGELOG.md +8 -0
- {nkululeko-0.84.1/nkululeko.egg-info → nkululeko-0.85.1}/PKG-INFO +9 -1
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/constants.py +1 -1
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/experiment.py +6 -1
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_whisper.py +3 -6
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/modelrunner.py +56 -33
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/models/finetune_model.py +9 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/models/model.py +1 -1
- nkululeko-0.85.1/nkululeko/models/model_tuned.py +479 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/test_pretrain.py +16 -4
- {nkululeko-0.84.1 → nkululeko-0.85.1/nkululeko.egg-info}/PKG-INFO +9 -1
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko.egg-info/SOURCES.txt +1 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/LICENSE +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/README.md +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/aesdd/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/androids/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/androids_orig/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/androids_test/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/ased/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/asvp-esd/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/baved/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/cafe/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/clac/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/cmu-mosei/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/demos/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/ekorpus/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/emns/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/emofilm/convert_to_16k.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/emofilm/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/emorynlp/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/emov-db/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/emovo/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/emozionalmente/create.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/enterface/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/esd/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/gerparas/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/iemocap/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/jl/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/jtes/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/meld/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/mesd/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/mess/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/mlendsnd/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/msp-improv/process_database2.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/msp-podcast/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/oreau2/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/portuguese/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/ravdess/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/ravdess/process_database_speaker.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/savee/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/shemo/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/subesco/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/tess/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/thorsten-emotional/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/urdu/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/data/vivae/process_database.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/docs/source/conf.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/meta/demos/demo_best_model.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/meta/demos/my_experiment.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/meta/demos/my_experiment_local.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/meta/demos/plot_faster_anim.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/__init__.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/aug_train.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/augment.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/augmenting/__init__.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/augmenting/augmenter.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/augmenting/randomsplicer.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/augmenting/randomsplicing.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/augmenting/resampler.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/autopredict/__init__.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/autopredict/ap_age.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/autopredict/ap_arousal.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/autopredict/ap_dominance.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/autopredict/ap_gender.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/autopredict/ap_mos.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/autopredict/ap_pesq.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/autopredict/ap_sdr.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/autopredict/ap_snr.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/autopredict/ap_stoi.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/autopredict/ap_valence.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/autopredict/estimate_snr.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/cacheddataset.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/data/__init__.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/data/dataset.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/data/dataset_csv.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/demo.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/demo_feats.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/demo_predictor.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/explore.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/export.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/__init__.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_agender.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_agender_agender.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_analyser.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_auddim.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_audmodel.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_clap.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_hubert.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_import.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_mld.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_mos.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_opensmile.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_oxbow.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_praat.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_snr.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_spectra.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_spkrec.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_squim.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_trill.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_wav2vec2.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feats_wavlm.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/featureset.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feat_extract/feinberg_praat.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/feature_extractor.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/file_checker.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/filter_data.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/glob_conf.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/losses/__init__.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/losses/loss_ccc.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/losses/loss_softf1loss.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/models/__init__.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/models/model_bayes.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/models/model_cnn.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/models/model_gmm.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/models/model_knn.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/models/model_knn_reg.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/models/model_lin_reg.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/models/model_mlp.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/models/model_mlp_regression.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/models/model_svm.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/models/model_svr.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/models/model_tree.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/models/model_tree_reg.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/models/model_xgb.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/models/model_xgr.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/multidb.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/nkuluflag.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/nkululeko.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/plots.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/predict.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/reporting/__init__.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/reporting/defines.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/reporting/latex_writer.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/reporting/report.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/reporting/report_item.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/reporting/reporter.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/reporting/result.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/resample.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/runmanager.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/scaler.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/segment.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/segmenting/__init__.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/segmenting/seg_inaspeechsegmenter.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/segmenting/seg_silero.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/syllable_nuclei.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/test.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/test_predictor.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/utils/__init__.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/utils/files.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/utils/stats.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko/utils/util.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko.egg-info/dependency_links.txt +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko.egg-info/requires.txt +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/nkululeko.egg-info/top_level.txt +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/pyproject.toml +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/setup.cfg +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/setup.py +0 -0
- {nkululeko-0.84.1 → nkululeko-0.85.1}/venv/bin/activate_this.py +0 -0
@@ -1,6 +1,14 @@
|
|
1
1
|
Changelog
|
2
2
|
=========
|
3
3
|
|
4
|
+
Version 0.85.1
|
5
|
+
--------------
|
6
|
+
* fixed bug in model_finetuned that label_num was constant 2
|
7
|
+
|
8
|
+
Version 0.85.0
|
9
|
+
--------------
|
10
|
+
* first version with finetuning wav2vec2 layers
|
11
|
+
|
4
12
|
Version 0.84.1
|
5
13
|
--------------
|
6
14
|
* made resample independent of config file
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: nkululeko
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.85.1
|
4
4
|
Summary: Machine learning audio prediction experiments based on templates
|
5
5
|
Home-page: https://github.com/felixbur/nkululeko
|
6
6
|
Author: Felix Burkhardt
|
@@ -333,6 +333,14 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
|
|
333
333
|
Changelog
|
334
334
|
=========
|
335
335
|
|
336
|
+
Version 0.85.1
|
337
|
+
--------------
|
338
|
+
* fixed bug in model_finetuned that label_num was constant 2
|
339
|
+
|
340
|
+
Version 0.85.0
|
341
|
+
--------------
|
342
|
+
* first version with finetuning wav2vec2 layers
|
343
|
+
|
336
344
|
Version 0.84.1
|
337
345
|
--------------
|
338
346
|
* made resample independent of config file
|
@@ -1,2 +1,2 @@
|
|
1
|
-
VERSION="0.
|
1
|
+
VERSION="0.85.1"
|
2
2
|
SAMPLING_RATE = 16000
|
@@ -340,7 +340,12 @@ class Experiment:
|
|
340
340
|
df_train, df_test = self.df_train, self.df_test
|
341
341
|
feats_name = "_".join(ast.literal_eval(glob_conf.config["DATA"]["databases"]))
|
342
342
|
self.feats_test, self.feats_train = pd.DataFrame(), pd.DataFrame()
|
343
|
-
feats_types = self.util.config_val_list("FEATS", "type", [
|
343
|
+
feats_types = self.util.config_val_list("FEATS", "type", [])
|
344
|
+
# for some models no features are needed
|
345
|
+
if len(feats_types) == 0:
|
346
|
+
self.util.debug("no feature extractor specified.")
|
347
|
+
self.feats_train, self.feats_test = pd.DataFrame(), pd.DataFrame()
|
348
|
+
return
|
344
349
|
self.feature_extractor = FeatureExtractor(
|
345
350
|
df_train, feats_types, feats_name, "train"
|
346
351
|
)
|
@@ -32,22 +32,19 @@ class Whisper(Featureset):
|
|
32
32
|
model_name = f"openai/{self.feat_type}"
|
33
33
|
self.model = WhisperModel.from_pretrained(model_name).to(self.device)
|
34
34
|
print(f"intialized Whisper model on {self.device}")
|
35
|
-
self.feature_extractor = AutoFeatureExtractor.from_pretrained(
|
36
|
-
model_name)
|
35
|
+
self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
|
37
36
|
self.model_initialized = True
|
38
37
|
|
39
38
|
def extract(self):
|
40
39
|
"""Extract the features or load them from disk if present."""
|
41
40
|
store = self.util.get_path("store")
|
42
41
|
storage = f"{store}{self.name}.pkl"
|
43
|
-
extract = self.util.config_val(
|
44
|
-
"FEATS", "needs_feature_extraction", False)
|
42
|
+
extract = self.util.config_val("FEATS", "needs_feature_extraction", False)
|
45
43
|
no_reuse = eval(self.util.config_val("FEATS", "no_reuse", "False"))
|
46
44
|
if extract or no_reuse or not os.path.isfile(storage):
|
47
45
|
if not self.model_initialized:
|
48
46
|
self.init_model()
|
49
|
-
self.util.debug(
|
50
|
-
"extracting whisper embeddings, this might take a while...")
|
47
|
+
self.util.debug("extracting whisper embeddings, this might take a while...")
|
51
48
|
emb_series = []
|
52
49
|
for (file, start, end), _ in audeer.progress_bar(
|
53
50
|
self.data_df.iterrows(),
|
@@ -47,16 +47,12 @@ class Modelrunner:
|
|
47
47
|
highest = 0
|
48
48
|
else:
|
49
49
|
highest = 100000
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
self.model.load(self.run, epoch)
|
54
|
-
self.util.debug(f"reusing model: {self.model.store_path}")
|
55
|
-
self.model.reset_test(self.df_test, self.feats_test)
|
56
|
-
else:
|
57
|
-
self.model.set_id(self.run, epoch)
|
58
|
-
self.model.train()
|
50
|
+
if self.model.model_type == "finetuned":
|
51
|
+
# epochs are handled by Huggingface API
|
52
|
+
self.model.train()
|
59
53
|
report = self.model.predict()
|
54
|
+
# todo: findout the best epoch
|
55
|
+
epoch = epoch_num
|
60
56
|
report.set_id(self.run, epoch)
|
61
57
|
plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
|
62
58
|
reports.append(report)
|
@@ -67,32 +63,53 @@ class Modelrunner:
|
|
67
63
|
if plot_epochs:
|
68
64
|
self.util.debug(f"plotting conf matrix to {plot_name}")
|
69
65
|
report.plot_confmatrix(plot_name, epoch)
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
patience = int(patience)
|
78
|
-
result = report.result.get_result()
|
79
|
-
if self.util.high_is_good():
|
80
|
-
if result > highest:
|
81
|
-
highest = result
|
82
|
-
patience_counter = 0
|
83
|
-
else:
|
84
|
-
patience_counter += 1
|
66
|
+
else:
|
67
|
+
# for all epochs
|
68
|
+
for epoch in range(epoch_num):
|
69
|
+
if only_test:
|
70
|
+
self.model.load(self.run, epoch)
|
71
|
+
self.util.debug(f"reusing model: {self.model.store_path}")
|
72
|
+
self.model.reset_test(self.df_test, self.feats_test)
|
85
73
|
else:
|
86
|
-
|
87
|
-
|
88
|
-
|
74
|
+
self.model.set_id(self.run, epoch)
|
75
|
+
self.model.train()
|
76
|
+
report = self.model.predict()
|
77
|
+
report.set_id(self.run, epoch)
|
78
|
+
plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
|
79
|
+
reports.append(report)
|
80
|
+
self.util.debug(
|
81
|
+
f"run: {self.run} epoch: {epoch}: result: "
|
82
|
+
f"{reports[-1].get_result().get_test_result()}"
|
83
|
+
)
|
84
|
+
if plot_epochs:
|
85
|
+
self.util.debug(f"plotting conf matrix to {plot_name}")
|
86
|
+
report.plot_confmatrix(plot_name, epoch)
|
87
|
+
store_models = self.util.config_val("EXP", "save", False)
|
88
|
+
plot_best_model = self.util.config_val("PLOT", "best_model", False)
|
89
|
+
if (store_models or plot_best_model) and (
|
90
|
+
not only_test
|
91
|
+
): # in any case the model needs to be stored to disk.
|
92
|
+
self.model.store()
|
93
|
+
if patience:
|
94
|
+
patience = int(patience)
|
95
|
+
result = report.result.get_result()
|
96
|
+
if self.util.high_is_good():
|
97
|
+
if result > highest:
|
98
|
+
highest = result
|
99
|
+
patience_counter = 0
|
100
|
+
else:
|
101
|
+
patience_counter += 1
|
89
102
|
else:
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
103
|
+
if result < highest:
|
104
|
+
highest = result
|
105
|
+
patience_counter = 0
|
106
|
+
else:
|
107
|
+
patience_counter += 1
|
108
|
+
if patience_counter >= patience:
|
109
|
+
self.util.debug(
|
110
|
+
f"reached patience ({str(patience)}): early stopping"
|
111
|
+
)
|
112
|
+
break
|
96
113
|
|
97
114
|
if not plot_epochs:
|
98
115
|
# Do at least one confusion matrix plot
|
@@ -133,6 +150,12 @@ class Modelrunner:
|
|
133
150
|
self.model = Bayes_model(
|
134
151
|
self.df_train, self.df_test, self.feats_train, self.feats_test
|
135
152
|
)
|
153
|
+
elif model_type == "finetune":
|
154
|
+
from nkululeko.models.model_tuned import TunedModel
|
155
|
+
|
156
|
+
self.model = TunedModel(
|
157
|
+
self.df_train, self.df_test, self.feats_train, self.feats_test
|
158
|
+
)
|
136
159
|
elif model_type == "gmm":
|
137
160
|
from nkululeko.models.model_gmm import GMM_model
|
138
161
|
|
@@ -1,3 +1,7 @@
|
|
1
|
+
"""
|
2
|
+
Code based on @jwagner
|
3
|
+
"""
|
4
|
+
|
1
5
|
import dataclasses
|
2
6
|
import typing
|
3
7
|
|
@@ -148,6 +152,11 @@ class Model(Wav2Vec2PreTrainedModel):
|
|
148
152
|
logits_cat=logits_cat,
|
149
153
|
)
|
150
154
|
|
155
|
+
def predict(self, signal):
|
156
|
+
result = self(torch.from_numpy(signal))
|
157
|
+
result = result[0].detach().numpy()[0]
|
158
|
+
return result
|
159
|
+
|
151
160
|
|
152
161
|
class ModelWithPreProcessing(Model):
|
153
162
|
|