nkululeko 0.86.7__tar.gz → 0.87.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nkululeko-0.86.7 → nkululeko-0.87.0}/CHANGELOG.md +8 -0
- {nkululeko-0.86.7/nkululeko.egg-info → nkululeko-0.87.0}/PKG-INFO +17 -1
- {nkululeko-0.86.7 → nkululeko-0.87.0}/README.md +8 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/constants.py +1 -1
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/data/dataset_csv.py +12 -14
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/demo.py +4 -8
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/experiment.py +39 -21
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feature_extractor.py +10 -4
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/modelrunner.py +5 -5
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/models/model.py +23 -3
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/models/model_cnn.py +41 -22
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/models/model_mlp.py +37 -17
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/models/model_mlp_regression.py +3 -1
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/plots.py +25 -37
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/reporting/reporter.py +69 -6
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/runmanager.py +8 -11
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/test_predictor.py +1 -6
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/utils/stats.py +11 -7
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/utils/util.py +22 -16
- {nkululeko-0.86.7 → nkululeko-0.87.0/nkululeko.egg-info}/PKG-INFO +17 -1
- {nkululeko-0.86.7 → nkululeko-0.87.0}/LICENSE +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/aesdd/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/androids/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/androids_orig/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/androids_test/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/ased/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/asvp-esd/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/baved/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/cafe/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/clac/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/cmu-mosei/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/demos/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/ekorpus/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/emns/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/emofilm/convert_to_16k.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/emofilm/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/emorynlp/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/emov-db/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/emovo/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/emozionalmente/create.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/enterface/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/esd/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/gerparas/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/iemocap/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/jl/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/jtes/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/meld/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/mesd/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/mess/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/mlendsnd/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/msp-improv/process_database2.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/msp-podcast/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/oreau2/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/portuguese/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/ravdess/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/ravdess/process_database_speaker.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/savee/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/shemo/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/subesco/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/tess/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/thorsten-emotional/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/urdu/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/data/vivae/process_database.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/docs/source/conf.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/meta/demos/demo_best_model.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/meta/demos/my_experiment.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/meta/demos/my_experiment_local.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/meta/demos/plot_faster_anim.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/__init__.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/aug_train.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/augment.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/augmenting/__init__.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/augmenting/augmenter.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/augmenting/randomsplicer.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/augmenting/randomsplicing.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/augmenting/resampler.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/autopredict/__init__.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/autopredict/ap_age.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/autopredict/ap_arousal.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/autopredict/ap_dominance.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/autopredict/ap_gender.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/autopredict/ap_mos.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/autopredict/ap_pesq.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/autopredict/ap_sdr.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/autopredict/ap_snr.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/autopredict/ap_stoi.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/autopredict/ap_valence.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/autopredict/estimate_snr.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/cacheddataset.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/data/__init__.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/data/dataset.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/demo_feats.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/demo_predictor.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/explore.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/export.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/__init__.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_agender.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_agender_agender.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_analyser.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_auddim.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_audmodel.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_clap.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_hubert.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_import.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_mld.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_mos.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_opensmile.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_oxbow.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_praat.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_snr.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_spectra.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_spkrec.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_squim.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_trill.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_wav2vec2.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_wavlm.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feats_whisper.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/featureset.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/feat_extract/feinberg_praat.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/file_checker.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/filter_data.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/glob_conf.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/losses/__init__.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/losses/loss_ccc.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/losses/loss_softf1loss.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/models/__init__.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/models/model_bayes.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/models/model_gmm.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/models/model_knn.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/models/model_knn_reg.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/models/model_lin_reg.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/models/model_svm.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/models/model_svr.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/models/model_tree.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/models/model_tree_reg.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/models/model_tuned.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/models/model_xgb.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/models/model_xgr.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/multidb.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/nkuluflag.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/nkululeko.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/predict.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/reporting/__init__.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/reporting/defines.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/reporting/latex_writer.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/reporting/report.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/reporting/report_item.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/reporting/result.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/resample.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/scaler.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/segment.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/segmenting/__init__.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/segmenting/seg_inaspeechsegmenter.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/segmenting/seg_silero.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/syllable_nuclei.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/test.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/test_pretrain.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/utils/__init__.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko/utils/files.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko.egg-info/SOURCES.txt +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko.egg-info/dependency_links.txt +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko.egg-info/requires.txt +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/nkululeko.egg-info/top_level.txt +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/pyproject.toml +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/setup.cfg +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/setup.py +0 -0
- {nkululeko-0.86.7 → nkululeko-0.87.0}/venv/bin/activate_this.py +0 -0
CHANGELOG.md
@@ -1,6 +1,14 @@
 Changelog
 =========
 
+Version 0.87.0
+--------------
+* added class probability output and uncertainty analysis
+
+Version 0.86.8
+--------------
+* handle single feature sets as strings in the config
+
 Version 0.86.7
 --------------
 * handles now audformat tables where the target is in a file index
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nkululeko
-Version: 0.86.7
+Version: 0.87.0
 Summary: Machine learning audio prediction experiments based on templates
 Home-page: https://github.com/felixbur/nkululeko
 Author: Felix Burkhardt
@@ -51,6 +51,7 @@ Requires-Dist: pylatex
 - [t-SNE plots](#t-sne-plots)
 - [Data distribution](#data-distribution)
 - [Bias checking](#bias-checking)
+- [Uncertainty](#uncertainty)
 - [Documentation](#documentation)
 - [Installation](#installation)
 - [Usage](#usage)
@@ -113,6 +114,13 @@ In cases you might wonder if there's bias in your data. You can try to detect th
 
 <img src="meta/images/emotion-pesq.png" width="500px"/>
 
+### Uncertainty
+Nkululeko estimates uncertainty of model decision (only for classifiers) with entropy over the class-probabilities or logits per sample.
+
+<img src="meta/images/uncertainty.png" width="500px"/>
+
+
+
 ## Documentation
 The documentation, along with extensions of installation, usage, INI file format, and examples, can be found [nkululeko.readthedocs.io](https://nkululeko.readthedocs.io).
 
@@ -343,6 +351,14 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
 Changelog
 =========
 
+Version 0.87.0
+--------------
+* added class probability output and uncertainty analysis
+
+Version 0.86.8
+--------------
+* handle single feature sets as strings in the config
+
 Version 0.86.7
 --------------
 * handles now audformat tables where the target is in a file index
README.md
@@ -7,6 +7,7 @@
 - [t-SNE plots](#t-sne-plots)
 - [Data distribution](#data-distribution)
 - [Bias checking](#bias-checking)
+- [Uncertainty](#uncertainty)
 - [Documentation](#documentation)
 - [Installation](#installation)
 - [Usage](#usage)
@@ -69,6 +70,13 @@ In cases you might wonder if there's bias in your data. You can try to detect th
 
 <img src="meta/images/emotion-pesq.png" width="500px"/>
 
+### Uncertainty
+Nkululeko estimates uncertainty of model decision (only for classifiers) with entropy over the class-probabilities or logits per sample.
+
+<img src="meta/images/uncertainty.png" width="500px"/>
+
+
+
 ## Documentation
 The documentation, along with extensions of installation, usage, INI file format, and examples, can be found [nkululeko.readthedocs.io](https://nkululeko.readthedocs.io).
 
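The new README section above describes uncertainty as entropy over per-sample class probabilities (or logits). A minimal illustrative sketch of that computation follows; it is not part of the package diff, and the function name `sample_uncertainty` is made up here:

```python
import numpy as np

def sample_uncertainty(probs):
    """Shannon entropy per row of an (n_samples, n_classes) probability matrix."""
    probs = np.clip(probs, 1e-12, 1.0)  # guard against log(0)
    return -(probs * np.log(probs)).sum(axis=1)

# A confident prediction yields low entropy, a near-uniform one yields high entropy.
print(sample_uncertainty(np.array([[0.98, 0.01, 0.01],
                                   [0.34, 0.33, 0.33]])))
```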
nkululeko/constants.py
@@ -1,2 +1,2 @@
-VERSION="0.86.7"
+VERSION="0.87.0"
 SAMPLING_RATE = 16000
nkululeko/data/dataset_csv.py
@@ -23,6 +23,9 @@ class Dataset_CSV(Dataset):
         root = os.path.dirname(data_file)
         audio_path = self.util.config_val_data(self.name, "audio_path", "./")
         df = pd.read_csv(data_file)
+        # trim all string values
+        df_obj = df.select_dtypes("object")
+        df[df_obj.columns] = df_obj.apply(lambda x: x.str.strip())
         # special treatment for segmented dataframes with only one column:
         if "start" in df.columns and len(df.columns) == 4:
             index = audformat.segmented_index(
@@ -49,8 +52,7 @@ class Dataset_CSV(Dataset):
                     .map(lambda x: root + "/" + audio_path + "/" + x)
                     .values
                 )
-                df = df.set_index(df.index.set_levels(
-                    file_index, level="file"))
+                df = df.set_index(df.index.set_levels(file_index, level="file"))
             else:
                 if not isinstance(df, pd.DataFrame):
                     df = pd.DataFrame(df)
@@ -59,27 +61,24 @@ class Dataset_CSV(Dataset):
                         lambda x: root + "/" + audio_path + "/" + x
                     )
                 )
-        else:
+        else:  # absolute path is True
             if audformat.index_type(df.index) == "segmented":
                 file_index = (
-                    df.index.levels[0]
-                    .map(lambda x: audio_path + "/" + x)
-                    .values
+                    df.index.levels[0].map(lambda x: audio_path + "/" + x).values
                 )
-                df = df.set_index(df.index.set_levels(
-                    file_index, level="file"))
+                df = df.set_index(df.index.set_levels(file_index, level="file"))
             else:
                 if not isinstance(df, pd.DataFrame):
                     df = pd.DataFrame(df)
-                df = df.set_index(
-                    lambda x: audio_path + "/" + x
+                df = df.set_index(
+                    df.index.to_series().apply(lambda x: audio_path + "/" + x)
+                )
 
         self.df = df
         self.db = None
         self.got_target = True
         self.is_labeled = self.got_target
-        self.start_fresh = eval(
-            self.util.config_val("DATA", "no_reuse", "False"))
+        self.start_fresh = eval(self.util.config_val("DATA", "no_reuse", "False"))
         is_index = False
         try:
             if self.is_labeled and not "class_label" in self.df.columns:
@@ -106,8 +105,7 @@ class Dataset_CSV(Dataset):
             f" {self.got_gender}, got age: {self.got_age}"
         )
         self.util.debug(r_string)
-        glob_conf.report.add_item(ReportItem(
-            "Data", "Loaded report", r_string))
+        glob_conf.report.add_item(ReportItem("Data", "Loaded report", r_string))
 
     def prepare(self):
         super().prepare()
nkululeko/demo.py
@@ -30,10 +30,8 @@ from transformers import pipeline
 
 
 def main(src_dir):
-    parser = argparse.ArgumentParser(
-        description="Call the nkululeko DEMO framework.")
-    parser.add_argument("--config", default="exp.ini",
-                        help="The base configuration")
+    parser = argparse.ArgumentParser(description="Call the nkululeko DEMO framework.")
+    parser.add_argument("--config", default="exp.ini", help="The base configuration")
     parser.add_argument(
         "--file", help="A file that should be processed (16kHz mono wav)"
     )
@@ -84,8 +82,7 @@ def main(src_dir):
    )
 
    def print_pipe(files, outfile):
-        """
-        Prints the pipeline output for a list of files, and optionally writes the results to an output file.
+        """Prints the pipeline output for a list of files, and optionally writes the results to an output file.
 
        Args:
            files (list): A list of file paths to process through the pipeline.
@@ -108,8 +105,7 @@ def main(src_dir):
            f.write("\n".join(results))
 
    if util.get_model_type() == "finetune":
-        model_path = os.path.join(
-            util.get_exp_dir(), "models", "run_0", "torch")
+        model_path = os.path.join(util.get_exp_dir(), "models", "run_0", "torch")
        pipe = pipeline("audio-classification", model=model_path)
    if args.file is not None:
        print_pipe([args.file], args.outfile)
nkululeko/experiment.py
@@ -5,13 +5,13 @@ import pickle
 import random
 import time
 
+import audeer
+import audformat
 import numpy as np
 import pandas as pd
 from sklearn.preprocessing import LabelEncoder
 
-import audeer
-import audformat
-
+import nkululeko.glob_conf as glob_conf
 from nkululeko.data.dataset import Dataset
 from nkululeko.data.dataset_csv import Dataset_CSV
 from nkululeko.demo_predictor import Demo_predictor
@@ -19,8 +19,6 @@ from nkululeko.feat_extract.feats_analyser import FeatureAnalyser
 from nkululeko.feature_extractor import FeatureExtractor
 from nkululeko.file_checker import FileChecker
 from nkululeko.filter_data import DataFilter
-from nkululeko.filter_data import filter_min_dur
-import nkululeko.glob_conf as glob_conf
 from nkululeko.plots import Plots
 from nkululeko.reporting.report import Report
 from nkululeko.runmanager import Runmanager
@@ -109,7 +107,8 @@ class Experiment:
         # print keys/column
         dbs = ",".join(list(self.datasets.keys()))
         labels = self.util.config_val("DATA", "labels", False)
-        auto_labels = list(next(iter(self.datasets.values())).df[self.target].unique())
+        auto_labels = list(
+            next(iter(self.datasets.values())).df[self.target].unique())
         if labels:
             self.labels = ast.literal_eval(labels)
             self.util.debug(f"Using target labels (from config): {labels}")
@@ -159,7 +158,8 @@ class Experiment:
             data.split()
             data.prepare_labels()
             self.df_test = pd.concat(
-                [self.df_test, self.util.make_segmented_index(data.df_test)]
+                [self.df_test, self.util.make_segmented_index(
+                    data.df_test)]
             )
             self.df_test.is_labeled = data.is_labeled
             self.df_test.got_gender = self.got_gender
@@ -260,7 +260,8 @@ class Experiment:
                 test_cats = self.df_test[self.target].unique()
             else:
                 # if there is no target, copy a dummy label
-                self.df_test = self._add_random_target(self.df_test).astype("str")
+                self.df_test = self._add_random_target(
+                    self.df_test).astype("str")
         train_cats = self.df_train[self.target].unique()
         # print(f"df_train: {pd.DataFrame(self.df_train[self.target])}")
         # print(f"train_cats with target {self.target}: {train_cats}")
@@ -268,7 +269,8 @@ class Experiment:
             if type(test_cats) == np.ndarray:
                 self.util.debug(f"Categories test (nd.array): {test_cats}")
             else:
-                self.util.debug(f"Categories test (list): {list(test_cats)}")
+                self.util.debug(
+                    f"Categories test (list): {list(test_cats)}")
             if type(train_cats) == np.ndarray:
                 self.util.debug(f"Categories train (nd.array): {train_cats}")
             else:
@@ -291,7 +293,8 @@ class Experiment:
 
         target_factor = self.util.config_val("DATA", "target_divide_by", False)
         if target_factor:
-            self.df_test[self.target] = self.df_test[self.target] / float(target_factor)
+            self.df_test[self.target] = self.df_test[self.target] / \
+                float(target_factor)
             self.df_train[self.target] = self.df_train[self.target] / float(
                 target_factor
             )
@@ -314,14 +317,16 @@ class Experiment:
     def plot_distribution(self, df_labels):
         """Plot the distribution of samples and speaker per target class and biological sex"""
         plot = Plots()
-        sample_selection = self.util.config_val("EXPL", "sample_selection", "all")
+        sample_selection = self.util.config_val(
+            "EXPL", "sample_selection", "all")
         plot.plot_distributions(df_labels)
         if self.got_speaker:
             plot.plot_distributions_speaker(df_labels)
 
     def extract_test_feats(self):
         self.feats_test = pd.DataFrame()
-        feats_name = "_".join(ast.literal_eval(glob_conf.config["DATA"]["tests"]))
+        feats_name = "_".join(ast.literal_eval(
+            glob_conf.config["DATA"]["tests"]))
         feats_types = self.util.config_val_list("FEATS", "type", ["os"])
         self.feature_extractor = FeatureExtractor(
             self.df_test, feats_types, feats_name, "test"
@@ -338,9 +343,17 @@ class Experiment:
 
         """
         df_train, df_test = self.df_train, self.df_test
-        feats_name = "_".join(ast.literal_eval(glob_conf.config["DATA"]["databases"]))
+        feats_name = "_".join(ast.literal_eval(
+            glob_conf.config["DATA"]["databases"]))
         self.feats_test, self.feats_train = pd.DataFrame(), pd.DataFrame()
-        feats_types = self.util.config_val_list("FEATS", "type", ["os"])
+        feats_types = self.util.config_val("FEATS", "type", "os")
+        # Ensure feats_types is always a list of strings
+        if isinstance(feats_types, str):
+            if feats_types.startswith("[") and feats_types.endswith("]"):
+                feats_types = ast.literal_eval(feats_types)
+            else:
+                feats_types = [feats_types]
+        # print(f"feats_types: {feats_types}")
         # for some models no features are needed
         if len(feats_types) == 0:
             self.util.debug("no feature extractor specified.")
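The hunk above implements the 0.86.8 changelog entry "handle single feature sets as strings in the config": the [FEATS] type value may now be a plain string or a list literal. A standalone sketch of the same normalization, for illustration only; the helper name `normalize_feats_types` is made up here:

```python
import ast

def normalize_feats_types(value):
    # Accept "os" as well as "['os', 'praat']" and always return a list of strings,
    # mirroring the normalization added in experiment.py above.
    if isinstance(value, str):
        if value.startswith("[") and value.endswith("]"):
            return ast.literal_eval(value)
        return [value]
    return value

print(normalize_feats_types("os"))               # ['os']
print(normalize_feats_types("['os', 'praat']"))  # ['os', 'praat']
```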
@@ -372,7 +385,8 @@ class Experiment:
                 f"test feats ({self.feats_test.shape[0]}) != test labels"
                 f" ({self.df_test.shape[0]})"
             )
-            self.df_test = self.df_test[self.df_test.index.isin(self.feats_test.index)]
+            self.df_test = self.df_test[self.df_test.index.isin(
+                self.feats_test.index)]
             self.util.warn(f"new test labels shape: {self.df_test.shape[0]}")
 
         self._check_scale()
@@ -387,7 +401,8 @@ class Experiment:
         """Augment the selected samples."""
         from nkululeko.augmenting.augmenter import Augmenter
 
-        sample_selection = self.util.config_val("AUGMENT", "sample_selection", "all")
+        sample_selection = self.util.config_val(
+            "AUGMENT", "sample_selection", "all")
         if sample_selection == "all":
             df = pd.concat([self.df_train, self.df_test])
         elif sample_selection == "train":
@@ -482,7 +497,8 @@ class Experiment:
         """
         from nkululeko.augmenting.randomsplicer import Randomsplicer
 
-        sample_selection = self.util.config_val("AUGMENT", "sample_selection", "all")
+        sample_selection = self.util.config_val(
+            "AUGMENT", "sample_selection", "all")
         if sample_selection == "all":
             df = pd.concat([self.df_train, self.df_test])
         elif sample_selection == "train":
@@ -503,7 +519,8 @@ class Experiment:
         plot_feats = eval(
             self.util.config_val("EXPL", "feature_distributions", "False")
         )
-        sample_selection = self.util.config_val("EXPL", "sample_selection", "all")
+        sample_selection = self.util.config_val(
+            "EXPL", "sample_selection", "all")
         # get the data labels
         if sample_selection == "all":
             df_labels = pd.concat([self.df_train, self.df_test])
@@ -566,7 +583,8 @@ class Experiment:
         for scat_target in scat_targets:
             if self.util.is_categorical(df_labels[scat_target]):
                 for scatter in scatters:
-                    plots.scatter_plot(df_feats, df_labels, scat_target, scatter)
+                    plots.scatter_plot(
+                        df_feats, df_labels, scat_target, scatter)
             else:
                 self.util.debug(
                     f"{self.name}: binning continuous variable to categories"
@@ -657,7 +675,8 @@ class Experiment:
         preds = best.preds
         speakers = self.df_test.speaker.values
         print(f"{len(truths)} {len(preds)} {len(speakers) }")
-        df = pd.DataFrame(data={"truth": truths, "pred": preds, "speaker": speakers})
+        df = pd.DataFrame(
+            data={"truth": truths, "pred": preds, "speaker": speakers})
         plot_name = "result_combined_per_speaker"
         self.util.debug(
             f"plotting speaker combination ({function}) confusion matrix to"
@@ -733,7 +752,6 @@ class Experiment:
         if model.is_ann():
             print("converting to onnx from torch")
         else:
-            from skl2onnx import to_onnx
 
             print("converting to onnx from sklearn")
         # save the rest
nkululeko/feature_extractor.py
@@ -39,16 +39,20 @@ class FeatureExtractor:
         self.feats = pd.DataFrame()
         for feats_type in self.feats_types:
             store_name = f"{self.data_name}_{feats_type}"
-            self.feat_extractor = self._get_feat_extractor(store_name, feats_type)
+            self.feat_extractor = self._get_feat_extractor(
+                store_name, feats_type)
             self.feat_extractor.extract()
             self.feat_extractor.filter()
-            self.feats = pd.concat([self.feats, self.feat_extractor.df], axis=1)
+            self.feats = pd.concat(
+                [self.feats, self.feat_extractor.df], axis=1)
         return self.feats
 
     def extract_sample(self, signal, sr):
         return self.feat_extractor.extract_sample(signal, sr)
 
     def _get_feat_extractor(self, store_name, feats_type):
+        if isinstance(feats_type, list) and len(feats_type) == 1:
+            feats_type = feats_type[0]
         feat_extractor_class = self._get_feat_extractor_class(feats_type)
         if feat_extractor_class is None:
             self.util.error(f"unknown feats_type: {feats_type}")
@@ -103,13 +107,15 @@ class FeatureExtractor:
         prefix, _, ext = feats_type.partition("-")
         from importlib import import_module
 
-        module = import_module(f"nkululeko.feat_extract.feats_{prefix.lower()}")
+        module = import_module(
+            f"nkululeko.feat_extract.feats_{prefix.lower()}")
         class_name = f"{prefix.capitalize()}"
         return getattr(module, class_name)
 
     def _get_feat_extractor_by_name(self, feats_type):
         from importlib import import_module
 
-        module = import_module(f"nkululeko.feat_extract.feats_{feats_type.lower()}")
+        module = import_module(
+            f"nkululeko.feat_extract.feats_{feats_type.lower()}")
         class_name = f"{feats_type.capitalize()}Set"
         return getattr(module, class_name)
nkululeko/modelrunner.py
@@ -85,7 +85,7 @@ class Modelrunner:
                 f"run: {self.run} epoch: {epoch}: result: {test_score_metric}"
             )
             # print(f"performance: {performance.split(' ')[1]}")
-            performance = float(test_score_metric.split(
+            performance = float(test_score_metric.split(" ")[1])
             if performance > self.best_performance:
                 self.best_performance = performance
                 self.best_epoch = epoch
@@ -204,15 +204,15 @@ class Modelrunner:
                 self.df_train, self.df_test, self.feats_train, self.feats_test
             )
         elif model_type == "cnn":
-            from nkululeko.models.model_cnn import CNN_model
+            from nkululeko.models.model_cnn import CNNModel
 
-            self.model = CNN_model(
+            self.model = CNNModel(
                 self.df_train, self.df_test, self.feats_train, self.feats_test
             )
         elif model_type == "mlp":
-            from nkululeko.models.model_mlp import MLP_model
+            from nkululeko.models.model_mlp import MLPModel
 
-            self.model = MLP_model(
+            self.model = MLPModel(
                 self.df_train, self.df_test, self.feats_train, self.feats_test
             )
         elif model_type == "mlp_reg":
nkululeko/models/model.py
@@ -247,8 +247,25 @@ class Model:
         self.clf.fit(feats, labels)
 
     def get_predictions(self):
-        predictions = self.clf.predict(self.feats_test.to_numpy())
-
+        # predictions = self.clf.predict(self.feats_test.to_numpy())
+        if self.util.exp_is_classification():
+            # make a dataframe for the class probabilities
+            proba_d = {}
+            for c in self.clf.classes_:
+                proba_d[c] = []
+            # get the class probabilities
+            predictions = self.clf.predict_proba(self.feats_test.to_numpy())
+            # pred = self.clf.predict(features)
+            for i, c in enumerate(self.clf.classes_):
+                proba_d[c] = list(predictions.T[i])
+            probas = pd.DataFrame(proba_d)
+            probas = probas.set_index(self.feats_test.index)
+            predictions = probas.idxmax(axis=1).values
+        else:
+            predictions = self.clf.predict(self.feats_test.to_numpy())
+            probas = None
+
+        return predictions, probas
 
     def predict(self):
         if self.feats_test.isna().to_numpy().any():
@@ -263,13 +280,16 @@ class Model:
             )
             return report
         """Predict the whole eval feature set"""
-        predictions = self.get_predictions()
+        predictions, probas = self.get_predictions()
+
         report = Reporter(
             self.df_test[self.target].to_numpy().astype(float),
             predictions,
             self.run,
             self.epoch,
+            probas=probas,
         )
+        report.print_probabilities()
         return report
 
     def get_type(self):
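With the change above, get_predictions() returns a (predictions, probas) pair for classifiers, where probas is a per-class probability DataFrame indexed like the test features. A small illustrative sketch, with invented data, of how the predicted labels fall out of such a frame via idxmax:

```python
import pandas as pd

# Invented example data; in the package the frame is built from clf.predict_proba().
probas = pd.DataFrame(
    {"anger": [0.7, 0.2], "neutral": [0.3, 0.8]},
    index=["wav_0001.wav", "wav_0002.wav"],
)
predictions = probas.idxmax(axis=1).values  # winning class label per sample
print(predictions)  # ['anger' 'neutral']
```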
nkululeko/models/model_cnn.py
@@ -5,33 +5,40 @@ Inspired by code from Su Lei
 
 """
 
+import ast
+from collections import OrderedDict
+
+import numpy as np
+import pandas as pd
+from PIL import Image
+from sklearn.metrics import recall_score
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torchvision
-import torchvision.transforms as transforms
 from torch.utils.data import Dataset
-import
-import numpy as np
-from sklearn.metrics import recall_score
-from collections import OrderedDict
-from PIL import Image
-from traitlets import default
+import torchvision.transforms as transforms
 
-from nkululeko.utils.util import Util
 import nkululeko.glob_conf as glob_conf
+from nkululeko.losses.loss_softf1loss import SoftF1Loss
 from nkululeko.models.model import Model
 from nkululeko.reporting.reporter import Reporter
-from nkululeko.
+from nkululeko.utils.util import Util
 
 
-class CNN_model(Model):
-    """CNN = convolutional neural net"""
+class CNNModel(Model):
+    """CNN = convolutional neural net."""
 
     is_classifier = True
 
     def __init__(self, df_train, df_test, feats_train, feats_test):
-        """Constructor taking
+        """Constructor, taking all dataframes.
+
+        Args:
+            df_train (pd.DataFrame): The train labels.
+            df_test (pd.DataFrame): The test labels.
+            feats_train (pd.DataFrame): The train features.
+            feats_test (pd.DataFrame): The test features.
+        """
         super().__init__(df_train, df_test, feats_train, feats_test)
         super().set_model_type("ann")
         self.name = "cnn"
@@ -147,7 +154,20 @@ class CNN_model(Model):
             self.optimizer.step()
         self.loss = (np.asarray(losses)).mean()
 
-    def
+    def get_probas(self, logits):
+        # make a dataframe for probabilites (logits)
+        proba_d = {}
+        classes = self.df_test[self.target].unique()
+        classes.sort()
+        for c in classes:
+            proba_d[c] = []
+        for i, c in enumerate(classes):
+            proba_d[c] = list(logits.numpy().T[i])
+        probas = pd.DataFrame(proba_d)
+        probas = probas.set_index(self.df_test.index)
+        return probas
+
+    def evaluate(self, model, loader, device):
         logits = torch.zeros(len(loader.dataset), self.class_num)
         targets = torch.zeros(len(loader.dataset))
         model.eval()
@@ -169,14 +189,15 @@ class CNN_model(Model):
         self.loss_eval = (np.asarray(losses)).mean()
         predictions = logits.argmax(dim=1)
         uar = recall_score(targets.numpy(), predictions.numpy(), average="macro")
-        return uar, targets, predictions
+        return uar, targets, predictions, logits
 
     def predict(self):
-        _, truths, predictions = self.
+        _, truths, predictions, logits = self.evaluate(
             self.model, self.testloader, self.device
         )
-        uar, _, _ = self.
-
+        uar, _, _, _ = self.evaluate(self.model, self.trainloader, self.device)
+        probas = self.get_probas(logits)
+        report = Reporter(truths, predictions, self.run, self.epoch, probas=probas)
         try:
             report.result.loss = self.loss
         except AttributeError:  # if the model was loaded from disk the loss is unknown
@@ -189,13 +210,11 @@ class CNN_model(Model):
         return report
 
     def get_predictions(self):
-        _,
-            self.model, self.testloader, self.device
-        )
+        _, _, predictions, _ = self.evaluate(self.model, self.testloader, self.device)
        return predictions.numpy()
 
    def predict_sample(self, features):
-        """Predict one sample"""
+        """Predict one sample."""
        with torch.no_grad():
            logits = self.model(torch.from_numpy(features).to(self.device))
            a = logits.numpy()