nkululeko 0.81.2__py3-none-any.whl → 0.81.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nkululeko/constants.py +1 -1
- nkululeko/feat_extract/feats_agender_agender.py +5 -4
- nkululeko/feat_extract/feinberg_praat.py +114 -116
- nkululeko/models/model.py +1 -1
- nkululeko/models/model_cnn.py +1 -1
- nkululeko/models/model_mlp.py +1 -1
- nkululeko/models/model_mlp_regression.py +13 -9
- nkululeko/{reporter.py → reporting/reporter.py} +86 -51
- nkululeko/{result.py → reporting/result.py} +5 -0
- nkululeko/runmanager.py +1 -1
- {nkululeko-0.81.2.dist-info → nkululeko-0.81.3.dist-info}/METADATA +7 -1
- {nkululeko-0.81.2.dist-info → nkululeko-0.81.3.dist-info}/RECORD +15 -15
- {nkululeko-0.81.2.dist-info → nkululeko-0.81.3.dist-info}/LICENSE +0 -0
- {nkululeko-0.81.2.dist-info → nkululeko-0.81.3.dist-info}/WHEEL +0 -0
- {nkululeko-0.81.2.dist-info → nkululeko-0.81.3.dist-info}/top_level.txt +0 -0
nkululeko/constants.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
|
-
VERSION="0.81.
|
1
|
+
VERSION="0.81.3"
|
2
2
|
SAMPLING_RATE = 16000
|
@@ -32,10 +32,11 @@ class AgenderAgenderSet(Featureset):
|
|
32
32
|
audeer.extract_archive(archive_path, model_root)
|
33
33
|
device = self.util.config_val("MODEL", "device", "cpu")
|
34
34
|
self.model = audonnx.load(model_root, device=device)
|
35
|
-
pytorch_total_params = sum(p.numel() for p in self.model.parameters())
|
36
|
-
self.util.debug(
|
37
|
-
|
38
|
-
)
|
35
|
+
# pytorch_total_params = sum(p.numel() for p in self.model.parameters())
|
36
|
+
# self.util.debug(
|
37
|
+
# f"initialized agender model with {pytorch_total_params} parameters in total"
|
38
|
+
# )
|
39
|
+
self.util.debug("initialized agender model")
|
39
40
|
self.model_loaded = True
|
40
41
|
|
41
42
|
def extract(self):
|
@@ -1,47 +1,46 @@
|
|
1
|
-
"""
|
2
|
-
This is a copy of David R. Feinberg's Praat scripts
|
1
|
+
"""This is a copy of David R. Feinberg's Praat scripts.
|
3
2
|
https://github.com/drfeinberg/PraatScripts
|
4
|
-
taken June 23rd 2022
|
3
|
+
taken June 23rd 2022.
|
5
4
|
"""
|
6
5
|
|
7
6
|
#!/usr/bin/env python3
|
7
|
+
import math
|
8
|
+
import statistics
|
9
|
+
|
8
10
|
import numpy as np
|
9
11
|
import pandas as pd
|
10
|
-
import math
|
11
|
-
from tqdm import tqdm
|
12
12
|
import parselmouth
|
13
|
-
import statistics
|
14
|
-
from nkululeko.utils.util import Util
|
15
|
-
import audiofile
|
16
13
|
from parselmouth.praat import call
|
17
14
|
from scipy.stats.mstats import zscore
|
18
15
|
from sklearn.decomposition import PCA
|
19
|
-
from
|
16
|
+
from tqdm import tqdm
|
17
|
+
|
18
|
+
import audiofile
|
20
19
|
|
21
20
|
|
22
21
|
# This is the function to measure source acoustics using default male parameters.
|
23
22
|
|
24
23
|
|
25
|
-
def
|
26
|
-
sound = parselmouth.Sound(
|
24
|
+
def measure_pitch(voice_id, f0min, f0max, unit):
|
25
|
+
sound = parselmouth.Sound(voice_id) # read the sound
|
27
26
|
duration = call(sound, "Get total duration") # duration
|
28
27
|
pitch = call(sound, "To Pitch", 0.0, f0min, f0max) # create a praat pitch object
|
29
|
-
|
30
|
-
|
28
|
+
mean_f0 = call(pitch, "Get mean", 0, 0, unit) # get mean pitch
|
29
|
+
stdev_f0 = call(
|
31
30
|
pitch, "Get standard deviation", 0, 0, unit
|
32
31
|
) # get standard deviation
|
33
32
|
harmonicity = call(sound, "To Harmonicity (cc)", 0.01, f0min, 0.1, 1.0)
|
34
33
|
hnr = call(harmonicity, "Get mean", 0, 0)
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
34
|
+
point_process = call(sound, "To PointProcess (periodic, cc)", f0min, f0max)
|
35
|
+
local_jitter = call(point_process, "Get jitter (local)", 0, 0, 0.0001, 0.02, 1.3)
|
36
|
+
localabsolute_jitter = call(
|
37
|
+
point_process, "Get jitter (local, absolute)", 0, 0, 0.0001, 0.02, 1.3
|
39
38
|
)
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
[sound,
|
39
|
+
rap_jitter = call(point_process, "Get jitter (rap)", 0, 0, 0.0001, 0.02, 1.3)
|
40
|
+
ppq5_jitter = call(point_process, "Get jitter (ppq5)", 0, 0, 0.0001, 0.02, 1.3)
|
41
|
+
ddp_jitter = call(point_process, "Get jitter (ddp)", 0, 0, 0.0001, 0.02, 1.3)
|
42
|
+
local_shimmer = call(
|
43
|
+
[sound, point_process],
|
45
44
|
"Get shimmer (local)",
|
46
45
|
0,
|
47
46
|
0,
|
@@ -50,8 +49,8 @@ def measurePitch(voiceID, f0min, f0max, unit):
|
|
50
49
|
1.3,
|
51
50
|
1.6,
|
52
51
|
)
|
53
|
-
|
54
|
-
[sound,
|
52
|
+
localdb_shimmer = call(
|
53
|
+
[sound, point_process],
|
55
54
|
"Get shimmer (local_dB)",
|
56
55
|
0,
|
57
56
|
0,
|
@@ -60,8 +59,8 @@ def measurePitch(voiceID, f0min, f0max, unit):
|
|
60
59
|
1.3,
|
61
60
|
1.6,
|
62
61
|
)
|
63
|
-
|
64
|
-
[sound,
|
62
|
+
apq3_shimmer = call(
|
63
|
+
[sound, point_process],
|
65
64
|
"Get shimmer (apq3)",
|
66
65
|
0,
|
67
66
|
0,
|
@@ -70,8 +69,8 @@ def measurePitch(voiceID, f0min, f0max, unit):
|
|
70
69
|
1.3,
|
71
70
|
1.6,
|
72
71
|
)
|
73
|
-
|
74
|
-
[sound,
|
72
|
+
aqpq5_shimmer = call(
|
73
|
+
[sound, point_process],
|
75
74
|
"Get shimmer (apq5)",
|
76
75
|
0,
|
77
76
|
0,
|
@@ -80,8 +79,8 @@ def measurePitch(voiceID, f0min, f0max, unit):
|
|
80
79
|
1.3,
|
81
80
|
1.6,
|
82
81
|
)
|
83
|
-
|
84
|
-
[sound,
|
82
|
+
apq11_shimmer = call(
|
83
|
+
[sound, point_process],
|
85
84
|
"Get shimmer (apq11)",
|
86
85
|
0,
|
87
86
|
0,
|
@@ -90,26 +89,26 @@ def measurePitch(voiceID, f0min, f0max, unit):
|
|
90
89
|
1.3,
|
91
90
|
1.6,
|
92
91
|
)
|
93
|
-
|
94
|
-
[sound,
|
92
|
+
dda_shimmer = call(
|
93
|
+
[sound, point_process], "Get shimmer (dda)", 0, 0, 0.0001, 0.02, 1.3, 1.6
|
95
94
|
)
|
96
95
|
|
97
96
|
return (
|
98
97
|
duration,
|
99
|
-
|
100
|
-
|
98
|
+
mean_f0,
|
99
|
+
stdev_f0,
|
101
100
|
hnr,
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
101
|
+
local_jitter,
|
102
|
+
localabsolute_jitter,
|
103
|
+
rap_jitter,
|
104
|
+
ppq5_jitter,
|
105
|
+
ddp_jitter,
|
106
|
+
local_shimmer,
|
107
|
+
localdb_shimmer,
|
108
|
+
apq3_shimmer,
|
109
|
+
aqpq5_shimmer,
|
110
|
+
apq11_shimmer,
|
111
|
+
dda_shimmer,
|
113
112
|
)
|
114
113
|
|
115
114
|
|
@@ -120,13 +119,13 @@ def measurePitch(voiceID, f0min, f0max, unit):
|
|
120
119
|
# Adapted from: DOI 10.17605/OSF.IO/K2BHS
|
121
120
|
# This function measures formants using Formant Position formula
|
122
121
|
# def measureFormants(sound, wave_file, f0min,f0max):
|
123
|
-
def
|
122
|
+
def measure_formants(sound, f0min, f0max):
|
124
123
|
sound = parselmouth.Sound(sound) # read the sound
|
125
124
|
# pitch = call(sound, "To Pitch (cc)", 0, f0min, 15, 'no', 0.03, 0.45, 0.01, 0.35, 0.14, f0max)
|
126
|
-
|
125
|
+
point_process = call(sound, "To PointProcess (periodic, cc)", f0min, f0max)
|
127
126
|
|
128
127
|
formants = call(sound, "To Formant (burg)", 0.0025, 5, 5000, 0.025, 50)
|
129
|
-
|
128
|
+
num_points = call(point_process, "Get number of points")
|
130
129
|
|
131
130
|
f1_list = []
|
132
131
|
f2_list = []
|
@@ -134,9 +133,9 @@ def measureFormants(sound, f0min, f0max):
|
|
134
133
|
f4_list = []
|
135
134
|
|
136
135
|
# Measure formants only at glottal pulses
|
137
|
-
for point in range(0,
|
136
|
+
for point in range(0, num_points):
|
138
137
|
point += 1
|
139
|
-
t = call(
|
138
|
+
t = call(point_process, "Get time from index", point)
|
140
139
|
f1 = call(formants, "Get value at time", 1, t, "Hertz", "Linear")
|
141
140
|
f2 = call(formants, "Get value at time", 2, t, "Hertz", "Linear")
|
142
141
|
f3 = call(formants, "Get value at time", 3, t, "Hertz", "Linear")
|
@@ -179,7 +178,7 @@ def measureFormants(sound, f0min, f0max):
|
|
179
178
|
# ## This function runs a 2-factor Principle Components Analysis (PCA) on Jitter and Shimmer
|
180
179
|
|
181
180
|
|
182
|
-
def
|
181
|
+
def run_pca(df):
|
183
182
|
# z-score the Jitter and Shimmer measurements
|
184
183
|
measures = [
|
185
184
|
"localJitter",
|
@@ -211,19 +210,19 @@ def runPCA(df):
|
|
211
210
|
# PCA
|
212
211
|
pca = PCA(n_components=2)
|
213
212
|
try:
|
214
|
-
|
215
|
-
if np.any(np.isnan(
|
213
|
+
principal_components = pca.fit_transform(x)
|
214
|
+
if np.any(np.isnan(principal_components)):
|
216
215
|
print("pc is nan")
|
217
|
-
print(f"count: {np.count_nonzero(np.isnan(
|
218
|
-
print(
|
219
|
-
|
216
|
+
print(f"count: {np.count_nonzero(np.isnan(principal_components))}")
|
217
|
+
print(principal_components)
|
218
|
+
principal_components = np.nan_to_num(principal_components)
|
220
219
|
except ValueError:
|
221
220
|
print("need more than one file for pca")
|
222
|
-
|
223
|
-
|
224
|
-
data=
|
221
|
+
principal_components = [[0, 0]]
|
222
|
+
principal_df = pd.DataFrame(
|
223
|
+
data=principal_components, columns=["JitterPCA", "ShimmerPCA"]
|
225
224
|
)
|
226
|
-
return
|
225
|
+
return principal_df
|
227
226
|
|
228
227
|
|
229
228
|
# ## This block of code runs the above functions on all of the '.wav' files in the /audio folder
|
@@ -231,22 +230,21 @@ def runPCA(df):
|
|
231
230
|
|
232
231
|
def compute_features(file_index):
|
233
232
|
# create lists to put the results
|
234
|
-
file_list = []
|
235
233
|
duration_list = []
|
236
|
-
|
237
|
-
|
234
|
+
mean_f0_list = []
|
235
|
+
sd_f0_list = []
|
238
236
|
hnr_list = []
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
237
|
+
local_jitter_list = []
|
238
|
+
localabsolute_jitter_list = []
|
239
|
+
rap_jitter_list = []
|
240
|
+
ppq5_jitter_list = []
|
241
|
+
ddp_jitter_list = []
|
242
|
+
local_shimmer_list = []
|
243
|
+
localdb_shimmer_list = []
|
244
|
+
apq3_shimmer_list = []
|
245
|
+
aqpq5_shimmer_list = []
|
246
|
+
apq11_shimmer_list = []
|
247
|
+
dda_shimmer_list = []
|
250
248
|
f1_mean_list = []
|
251
249
|
f2_mean_list = []
|
252
250
|
f3_mean_list = []
|
@@ -268,21 +266,21 @@ def compute_features(file_index):
|
|
268
266
|
sound = parselmouth.Sound(values=signal, sampling_frequency=sampling_rate)
|
269
267
|
(
|
270
268
|
duration,
|
271
|
-
|
272
|
-
|
269
|
+
mean_f0,
|
270
|
+
stdev_f0,
|
273
271
|
hnr,
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
) =
|
272
|
+
local_jitter,
|
273
|
+
localabsolute_jitter,
|
274
|
+
rap_jitter,
|
275
|
+
ppq5_jitter,
|
276
|
+
ddp_jitter,
|
277
|
+
local_shimmer,
|
278
|
+
localdb_shimmer,
|
279
|
+
apq3_shimmer,
|
280
|
+
aqpq5_shimmer,
|
281
|
+
apq11_shimmer,
|
282
|
+
dda_shimmer,
|
283
|
+
) = measure_pitch(sound, 75, 300, "Hertz")
|
286
284
|
(
|
287
285
|
f1_mean,
|
288
286
|
f2_mean,
|
@@ -292,28 +290,28 @@ def compute_features(file_index):
|
|
292
290
|
f2_median,
|
293
291
|
f3_median,
|
294
292
|
f4_median,
|
295
|
-
) =
|
293
|
+
) = measure_formants(sound, 75, 300)
|
296
294
|
# file_list.append(wave_file) # make an ID list
|
297
295
|
except (statistics.StatisticsError, parselmouth.PraatError) as errors:
|
298
296
|
print(f"error on file {wave_file}: {errors}")
|
299
297
|
|
300
298
|
duration_list.append(duration) # make duration list
|
301
|
-
|
302
|
-
|
299
|
+
mean_f0_list.append(mean_f0) # make a mean F0 list
|
300
|
+
sd_f0_list.append(stdev_f0) # make a sd F0 list
|
303
301
|
hnr_list.append(hnr) # add HNR data
|
304
302
|
|
305
303
|
# add raw jitter and shimmer measures
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
304
|
+
local_jitter_list.append(local_jitter)
|
305
|
+
localabsolute_jitter_list.append(localabsolute_jitter)
|
306
|
+
rap_jitter_list.append(rap_jitter)
|
307
|
+
ppq5_jitter_list.append(ppq5_jitter)
|
308
|
+
ddp_jitter_list.append(ddp_jitter)
|
309
|
+
local_shimmer_list.append(local_shimmer)
|
310
|
+
localdb_shimmer_list.append(localdb_shimmer)
|
311
|
+
apq3_shimmer_list.append(apq3_shimmer)
|
312
|
+
aqpq5_shimmer_list.append(aqpq5_shimmer)
|
313
|
+
apq11_shimmer_list.append(apq11_shimmer)
|
314
|
+
dda_shimmer_list.append(dda_shimmer)
|
317
315
|
|
318
316
|
# add the formant data
|
319
317
|
f1_mean_list.append(f1_mean)
|
@@ -330,20 +328,20 @@ def compute_features(file_index):
|
|
330
328
|
np.column_stack(
|
331
329
|
[
|
332
330
|
duration_list,
|
333
|
-
|
334
|
-
|
331
|
+
mean_f0_list,
|
332
|
+
sd_f0_list,
|
335
333
|
hnr_list,
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
334
|
+
local_jitter_list,
|
335
|
+
localabsolute_jitter_list,
|
336
|
+
rap_jitter_list,
|
337
|
+
ppq5_jitter_list,
|
338
|
+
ddp_jitter_list,
|
339
|
+
local_shimmer_list,
|
340
|
+
localdb_shimmer_list,
|
341
|
+
apq3_shimmer_list,
|
342
|
+
aqpq5_shimmer_list,
|
343
|
+
apq11_shimmer_list,
|
344
|
+
dda_shimmer_list,
|
347
345
|
f1_mean_list,
|
348
346
|
f2_mean_list,
|
349
347
|
f3_mean_list,
|
@@ -382,7 +380,7 @@ def compute_features(file_index):
|
|
382
380
|
)
|
383
381
|
|
384
382
|
# add pca data
|
385
|
-
pcaData =
|
383
|
+
pcaData = run_pca(df) # Run jitter and shimmer PCA
|
386
384
|
df = pd.concat([df, pcaData], axis=1) # Add PCA data
|
387
385
|
# reload the data so it's all numbers
|
388
386
|
df.to_csv("processed_results.csv", index=False)
|
nkululeko/models/model.py
CHANGED
@@ -4,7 +4,7 @@ import pandas as pd
|
|
4
4
|
import numpy as np
|
5
5
|
import nkululeko.glob_conf as glob_conf
|
6
6
|
import sklearn.utils
|
7
|
-
from nkululeko.reporter import Reporter
|
7
|
+
from nkululeko.reporting.reporter import Reporter
|
8
8
|
import ast
|
9
9
|
from sklearn.model_selection import GridSearchCV
|
10
10
|
import pickle
|
nkululeko/models/model_cnn.py
CHANGED
@@ -20,7 +20,7 @@ from PIL import Image
|
|
20
20
|
from nkululeko.utils.util import Util
|
21
21
|
import nkululeko.glob_conf as glob_conf
|
22
22
|
from nkululeko.models.model import Model
|
23
|
-
from nkululeko.reporter import Reporter
|
23
|
+
from nkululeko.reporting.reporter import Reporter
|
24
24
|
from nkululeko.losses.loss_softf1loss import SoftF1Loss
|
25
25
|
|
26
26
|
|
nkululeko/models/model_mlp.py
CHANGED
@@ -1,16 +1,20 @@
|
|
1
1
|
# model_mlp.py
|
2
|
-
from nkululeko.utils.util import Util
|
3
|
-
import nkululeko.glob_conf as glob_conf
|
4
|
-
from nkululeko.models.model import Model
|
5
|
-
from nkululeko.reporter import Reporter
|
6
|
-
import torch
|
7
2
|
import ast
|
8
|
-
import numpy as np
|
9
|
-
from sklearn.metrics import mean_squared_error, mean_absolute_error
|
10
3
|
from collections import OrderedDict
|
11
|
-
from nkululeko.losses.loss_ccc import ConcordanceCorCoeff
|
12
4
|
import os
|
13
5
|
|
6
|
+
import numpy as np
|
7
|
+
import torch
|
8
|
+
|
9
|
+
from audmetric import concordance_cc
|
10
|
+
from audmetric import mean_absolute_error
|
11
|
+
from audmetric import mean_squared_error
|
12
|
+
|
13
|
+
import nkululeko.glob_conf as glob_conf
|
14
|
+
from nkululeko.losses.loss_ccc import ConcordanceCorCoeff
|
15
|
+
from nkululeko.models.model import Model
|
16
|
+
from nkululeko.reporting.reporter import Reporter
|
17
|
+
|
14
18
|
|
15
19
|
class MLP_Reg_model(Model):
|
16
20
|
"""MLP = multi layer perceptron"""
|
@@ -201,7 +205,7 @@ class MLP_Reg_model(Model):
|
|
201
205
|
elif measure == "mae":
|
202
206
|
result = mean_absolute_error(targets.numpy(), predictions.numpy())
|
203
207
|
elif measure == "ccc":
|
204
|
-
result =
|
208
|
+
result = concordance_cc(targets.numpy(), predictions.numpy())
|
205
209
|
else:
|
206
210
|
self.util.error(f"unknown measure: {measure}")
|
207
211
|
return result, targets, predictions
|
@@ -2,25 +2,27 @@ import ast
|
|
2
2
|
import glob
|
3
3
|
import json
|
4
4
|
import math
|
5
|
+
|
6
|
+
from confidence_intervals import evaluate_with_conf_int
|
5
7
|
import matplotlib.pyplot as plt
|
6
8
|
import numpy as np
|
7
9
|
from scipy.stats import pearsonr
|
8
|
-
from sklearn.metrics import
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
from
|
10
|
+
from sklearn.metrics import ConfusionMatrixDisplay
|
11
|
+
from sklearn.metrics import classification_report
|
12
|
+
from sklearn.metrics import confusion_matrix
|
13
|
+
from sklearn.metrics import r2_score
|
14
|
+
from torch import is_tensor
|
15
|
+
|
16
|
+
from audmetric import accuracy
|
17
|
+
from audmetric import concordance_cc
|
18
|
+
from audmetric import mean_absolute_error
|
19
|
+
from audmetric import mean_squared_error
|
20
|
+
from audmetric import unweighted_average_recall
|
19
21
|
|
20
22
|
import nkululeko.glob_conf as glob_conf
|
21
|
-
from nkululeko.reporting.report_item import ReportItem
|
22
|
-
from nkululeko.result import Result
|
23
23
|
from nkululeko.reporting.defines import Header
|
24
|
+
from nkululeko.reporting.report_item import ReportItem
|
25
|
+
from nkululeko.reporting.result import Result
|
24
26
|
from nkululeko.utils.util import Util
|
25
27
|
|
26
28
|
|
@@ -44,11 +46,11 @@ class Reporter:
|
|
44
46
|
self.result.measure = self.MEASURE
|
45
47
|
|
46
48
|
def __init__(self, truths, preds, run, epoch):
|
47
|
-
"""Initialization with ground truth und predictions vector"""
|
49
|
+
"""Initialization with ground truth und predictions vector."""
|
48
50
|
self.util = Util("reporter")
|
49
51
|
self.format = self.util.config_val("PLOT", "format", "png")
|
50
|
-
self.truths = truths
|
51
|
-
self.preds = preds
|
52
|
+
self.truths = np.asarray(truths)
|
53
|
+
self.preds = np.asarray(preds)
|
52
54
|
self.result = Result(0, 0, 0, 0, "unknown")
|
53
55
|
self.run = run
|
54
56
|
self.epoch = epoch
|
@@ -56,30 +58,57 @@ class Reporter:
|
|
56
58
|
self.cont_to_cat = False
|
57
59
|
if len(self.truths) > 0 and len(self.preds) > 0:
|
58
60
|
if self.util.exp_is_classification():
|
59
|
-
|
60
|
-
self.
|
61
|
+
uar, (upper, lower) = evaluate_with_conf_int(
|
62
|
+
self.preds,
|
63
|
+
unweighted_average_recall,
|
64
|
+
self.truths,
|
65
|
+
num_bootstraps=1000,
|
66
|
+
alpha=5,
|
61
67
|
)
|
62
|
-
self.result.
|
68
|
+
self.result.test = uar
|
69
|
+
self.result.set_upper_lower(upper, lower)
|
70
|
+
self.result.loss = 1 - accuracy(self.truths, self.preds)
|
63
71
|
else:
|
64
72
|
# regression experiment
|
65
73
|
if self.measure == "mse":
|
66
|
-
|
74
|
+
test_result, (upper, lower) = evaluate_with_conf_int(
|
75
|
+
self.preds,
|
76
|
+
mean_squared_error,
|
77
|
+
self.truths,
|
78
|
+
num_bootstraps=1000,
|
79
|
+
alpha=5,
|
80
|
+
)
|
67
81
|
elif self.measure == "mae":
|
68
|
-
|
82
|
+
test_result, (upper, lower) = evaluate_with_conf_int(
|
83
|
+
self.preds,
|
84
|
+
mean_absolute_error,
|
85
|
+
self.truths,
|
86
|
+
num_bootstraps=1000,
|
87
|
+
alpha=5,
|
88
|
+
)
|
69
89
|
elif self.measure == "ccc":
|
70
|
-
|
90
|
+
test_result, (upper, lower) = evaluate_with_conf_int(
|
91
|
+
self.preds,
|
92
|
+
concordance_cc,
|
93
|
+
self.truths,
|
94
|
+
num_bootstraps=1000,
|
95
|
+
alpha=5,
|
96
|
+
)
|
97
|
+
|
71
98
|
if math.isnan(self.result.test):
|
72
99
|
self.util.debug(f"Truth: {self.truths}")
|
73
100
|
self.util.debug(f"Predict.: {self.preds}")
|
74
|
-
self.util.debug(
|
101
|
+
self.util.debug("Result is NAN: setting to -1")
|
75
102
|
self.result.test = -1
|
76
103
|
else:
|
77
104
|
self.util.error(f"unknown measure: {self.measure}")
|
78
105
|
|
106
|
+
self.result.test = test_result
|
107
|
+
self.result.set_upper_lower(upper, lower)
|
79
108
|
# train and loss are being set by the model
|
80
109
|
|
81
110
|
def set_id(self, run, epoch):
|
82
|
-
"""Make the report identifiable with run and epoch index"""
|
111
|
+
"""Make the report identifiable with run and epoch index."""
|
83
112
|
self.run = run
|
84
113
|
self.epoch = epoch
|
85
114
|
|
@@ -97,9 +126,12 @@ class Reporter:
|
|
97
126
|
self._plot_confmat(self.truths, self.preds, plot_name, epoch)
|
98
127
|
|
99
128
|
def plot_per_speaker(self, result_df, plot_name, function):
|
100
|
-
"""Plot a confusion matrix with the mode category per speakers
|
129
|
+
"""Plot a confusion matrix with the mode category per speakers.
|
130
|
+
|
101
131
|
Args:
|
102
|
-
|
132
|
+
result_df: a pandas dataframe with columns: preds, truths and speaker.
|
133
|
+
plot_name: name for the figure.
|
134
|
+
function: either mode or mean.
|
103
135
|
"""
|
104
136
|
speakers = result_df.speaker.unique()
|
105
137
|
pred = np.zeros(0)
|
@@ -128,8 +160,14 @@ class Reporter:
|
|
128
160
|
fig_dir = self.util.get_path("fig_dir")
|
129
161
|
labels = glob_conf.labels
|
130
162
|
fig = plt.figure() # figsize=[5, 5]
|
131
|
-
uar
|
132
|
-
|
163
|
+
uar, (upper, lower) = evaluate_with_conf_int(
|
164
|
+
self.preds,
|
165
|
+
unweighted_average_recall,
|
166
|
+
self.truths,
|
167
|
+
num_bootstraps=1000,
|
168
|
+
alpha=5,
|
169
|
+
)
|
170
|
+
acc = accuracy(truths, preds)
|
133
171
|
cm = confusion_matrix(
|
134
172
|
truths, preds, normalize=None
|
135
173
|
) # normalize must be one of {'true', 'pred', 'all', None}
|
@@ -138,6 +176,7 @@ class Reporter:
|
|
138
176
|
f"mismatch between confmatrix dim ({cm.shape[0]}) and labels"
|
139
177
|
f" length ({len(labels)}: {labels})"
|
140
178
|
)
|
179
|
+
|
141
180
|
try:
|
142
181
|
disp = ConfusionMatrixDisplay(
|
143
182
|
confusion_matrix=cm, display_labels=labels
|
@@ -150,12 +189,23 @@ class Reporter:
|
|
150
189
|
|
151
190
|
reg_res = ""
|
152
191
|
if not self.is_classification:
|
153
|
-
reg_res = f"
|
192
|
+
reg_res = f"{self.result.test:.3f} {self.MEASURE}"
|
193
|
+
|
194
|
+
uar_str = str(int(uar * 1000) / 1000.0)[1:]
|
195
|
+
acc_str = str(int(acc * 1000) / 1000.0)[1:]
|
196
|
+
up_str = str(int(upper * 1000) / 1000.0)[1:]
|
197
|
+
low_str = str(int(lower * 1000) / 1000.0)[1:]
|
154
198
|
|
155
199
|
if epoch != 0:
|
156
|
-
plt.title(
|
200
|
+
plt.title(
|
201
|
+
f"Confusion Matrix, UAR: {uar_str} "
|
202
|
+
+ f"(+-{up_str}/{low_str}), {reg_res}, Epoch: {epoch}"
|
203
|
+
)
|
157
204
|
else:
|
158
|
-
plt.title(
|
205
|
+
plt.title(
|
206
|
+
f"Confusion Matrix, UAR: {uar_str} "
|
207
|
+
+ f"(+-{up_str}/{low_str}) {reg_res}"
|
208
|
+
)
|
159
209
|
img_path = f"{fig_dir}{plot_name}.{self.format}"
|
160
210
|
plt.savefig(img_path)
|
161
211
|
fig.clear()
|
@@ -172,9 +222,10 @@ class Reporter:
|
|
172
222
|
)
|
173
223
|
|
174
224
|
res_dir = self.util.get_path("res_dir")
|
175
|
-
|
176
|
-
|
177
|
-
|
225
|
+
rpt = (
|
226
|
+
f"epoch: {epoch}, UAR: {uar_str}"
|
227
|
+
+ f", (+-{up_str}/{low_str}), ACC: {acc_str}"
|
228
|
+
)
|
178
229
|
# print(rpt)
|
179
230
|
self.util.debug(rpt)
|
180
231
|
file_name = f"{res_dir}{self.util.get_exp_name()}_conf.txt"
|
@@ -182,7 +233,7 @@ class Reporter:
|
|
182
233
|
text_file.write(rpt)
|
183
234
|
|
184
235
|
def print_results(self, epoch):
|
185
|
-
"""Print all evaluation values to text file"""
|
236
|
+
"""Print all evaluation values to text file."""
|
186
237
|
res_dir = self.util.get_path("res_dir")
|
187
238
|
file_name = f"{res_dir}{self.util.get_exp_name()}_{epoch}.txt"
|
188
239
|
if self.util.exp_is_classification():
|
@@ -279,19 +330,3 @@ class Reporter:
|
|
279
330
|
plt.legend()
|
280
331
|
plt.savefig(f"{fig_dir}{out_name}.{self.format}")
|
281
332
|
plt.close()
|
282
|
-
|
283
|
-
@staticmethod
|
284
|
-
def ccc(ground_truth, prediction):
|
285
|
-
mean_gt = np.mean(ground_truth, 0)
|
286
|
-
mean_pred = np.mean(prediction, 0)
|
287
|
-
var_gt = np.var(ground_truth, 0)
|
288
|
-
var_pred = np.var(prediction, 0)
|
289
|
-
v_pred = prediction - mean_pred
|
290
|
-
v_gt = ground_truth - mean_gt
|
291
|
-
cor = sum(v_pred * v_gt) / (np.sqrt(sum(v_pred**2)) * np.sqrt(sum(v_gt**2)))
|
292
|
-
sd_gt = np.std(ground_truth)
|
293
|
-
sd_pred = np.std(prediction)
|
294
|
-
numerator = 2 * cor * sd_gt * sd_pred
|
295
|
-
denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
|
296
|
-
ccc = numerator / denominator
|
297
|
-
return ccc
|
@@ -12,6 +12,11 @@ class Result:
|
|
12
12
|
def get_result(self):
|
13
13
|
return self.test
|
14
14
|
|
15
|
+
def set_upper_lower(self, upper, lower):
|
16
|
+
"""Set the upper and lower bound of confidence interval."""
|
17
|
+
self.upper = upper
|
18
|
+
self.lower = lower
|
19
|
+
|
15
20
|
def get_test_result(self):
|
16
21
|
return f"test: {self.test:.3f} {self.measure}"
|
17
22
|
|
nkululeko/runmanager.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: nkululeko
|
3
|
-
Version: 0.81.
|
3
|
+
Version: 0.81.3
|
4
4
|
Summary: Machine learning audio prediction experiments based on templates
|
5
5
|
Home-page: https://github.com/felixbur/nkululeko
|
6
6
|
Author: Felix Burkhardt
|
@@ -18,7 +18,9 @@ Requires-Dist: audformat
|
|
18
18
|
Requires-Dist: audinterface
|
19
19
|
Requires-Dist: audiofile
|
20
20
|
Requires-Dist: audiomentations
|
21
|
+
Requires-Dist: audmetric
|
21
22
|
Requires-Dist: audonnx
|
23
|
+
Requires-Dist: confidence-intervals
|
22
24
|
Requires-Dist: datasets
|
23
25
|
Requires-Dist: imageio
|
24
26
|
Requires-Dist: laion-clap
|
@@ -321,6 +323,10 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
|
|
321
323
|
Changelog
|
322
324
|
=========
|
323
325
|
|
326
|
+
Version 0.81.3
|
327
|
+
--------------
|
328
|
+
* added confidence intervals to result reporting
|
329
|
+
|
324
330
|
Version 0.81.2
|
325
331
|
--------------
|
326
332
|
* added a parselmouth.Praat error if pitch out of range
|
@@ -2,7 +2,7 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
|
|
2
2
|
nkululeko/aug_train.py,sha256=YhuZnS_WVWnun9G-M6g5n6rbRxoVREz6Zh7k6qprFNQ,3194
|
3
3
|
nkululeko/augment.py,sha256=4MG0apTAG5RgkuJrYEjGgDdbodZWi_HweSPNI1JJ5QA,3051
|
4
4
|
nkululeko/cacheddataset.py,sha256=lIJ6hUo5LoxSrzXtWV8mzwO7wRtUETWnOQ4ws2XfL1E,969
|
5
|
-
nkululeko/constants.py,sha256=
|
5
|
+
nkululeko/constants.py,sha256=hx9HFHOlApn60yieWI1qr4PbrKeT3EFK1aaDMxlt5xU,39
|
6
6
|
nkululeko/demo.py,sha256=me8EdjN-zrzClVy9FEmqbTQyDDON88W8vPpWEE8T0cI,2500
|
7
7
|
nkululeko/demo_feats.py,sha256=sAeGFojhEj9WEDFtG3SzPBmyYJWLF2rkbpp65m8Ujo4,2025
|
8
8
|
nkululeko/demo_predictor.py,sha256=CQL6DO7QxwmwoB_6DlgDS-pdG1KuvemYJ1NEpMjmMk8,4733
|
@@ -18,10 +18,8 @@ nkululeko/multidb.py,sha256=4ceCu9LFrMGlrcgtz4pWuOQb2KA3jR5uo3FjZgAEBD4,5732
|
|
18
18
|
nkululeko/nkululeko.py,sha256=Ty8cdusXUec9BHml8Gsp1r7DXuvIBMFXUckMpzILBnQ,1966
|
19
19
|
nkululeko/plots.py,sha256=K88ZRPFGX_r03BT742H06Dde20xZYdltv7dxjgUiAFA,23025
|
20
20
|
nkululeko/predict.py,sha256=dRXX-sQVESa7cNi_56S6UkUOa_pV1g_K4xYtYVM1SJs,1876
|
21
|
-
nkululeko/reporter.py,sha256=Gg0dsZclMmdTRUju7yWM3tBVhEZno9VSKD4Tcu_1pJI,11497
|
22
21
|
nkululeko/resample.py,sha256=Yzfr_rInG9afPZFnEjiQ3EKRdMSwyYKVQwt9-yNGJn8,2233
|
23
|
-
nkululeko/
|
24
|
-
nkululeko/runmanager.py,sha256=YNjYLzf4KrtcOyiDLF06YLs3nU3U7n_hY_VH4fYFuh0,7451
|
22
|
+
nkululeko/runmanager.py,sha256=JNBm7JJN8QU8qEqfWr4eS6rkPnBWoVdIUTynHctCPpw,7461
|
25
23
|
nkululeko/scaler.py,sha256=4nkIqoajkIkuTPK0Z02ifMN_awl6fP_i-GBYdoGYgGM,4101
|
26
24
|
nkululeko/segment.py,sha256=YLKckX44tbvTb3LrdgYw9X4guzuF27sutl92z9DkpZU,4835
|
27
25
|
nkululeko/syllable_nuclei.py,sha256=Sky-C__MeUDaxqHnDl2TGLLYOYvsahD35TUjWGeG31k,10047
|
@@ -49,7 +47,7 @@ nkululeko/data/dataset.py,sha256=n6v_vVdA0EsZ-NaTgnYfPlCT4QCcD02mJJb-oD7SaSU,272
|
|
49
47
|
nkululeko/data/dataset_csv.py,sha256=v3lSjF23EVjoP460QOfhdcqbWAlBQWlBOuaYujZoS4s,3407
|
50
48
|
nkululeko/feat_extract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
51
49
|
nkululeko/feat_extract/feats_agender.py,sha256=_lAL6IxJDJH2bhIvd7yarTqQryx7FjbQXAgY0mJP-KI,3192
|
52
|
-
nkululeko/feat_extract/feats_agender_agender.py,sha256=
|
50
|
+
nkululeko/feat_extract/feats_agender_agender.py,sha256=5dA7YA-YGxODovMC7ynMk3bnpPjfs0ApvSfjqvoSZY0,3346
|
53
51
|
nkululeko/feat_extract/feats_analyser.py,sha256=_5oz4y-NZCEBgfNP2GZ9WNqQR50Hbykm0TvDVomWP0U,11399
|
54
52
|
nkululeko/feat_extract/feats_audmodel.py,sha256=TRCkLqPgnyWN-OAcO69pPZF2FIbBy5ERb5ZY22qh6iA,3108
|
55
53
|
nkululeko/feat_extract/feats_audmodel_dim.py,sha256=yg39CSR0b54AJyOAlXO3M1ohyY9Rbrjf18pllsoQ03g,3078
|
@@ -69,20 +67,20 @@ nkululeko/feat_extract/feats_trill.py,sha256=PpygJK_W6QoBNeSah9npQPiQlJxLWFn6TSO
|
|
69
67
|
nkululeko/feat_extract/feats_wav2vec2.py,sha256=sFf-WkLUgKUQsFxGO9m2hS3uYoGkv95mZavCEZyWFGA,5072
|
70
68
|
nkululeko/feat_extract/feats_wavlm.py,sha256=RhI0oWIsknnxTVmdnNS_xJO1NnUUR0CUNDWH1yTpNLk,4683
|
71
69
|
nkululeko/feat_extract/featureset.py,sha256=-ynkdor8iX7BFx10aIbB3LfwxrrzPoBGz9kXwyAJO9M,1375
|
72
|
-
nkululeko/feat_extract/feinberg_praat.py,sha256=
|
70
|
+
nkululeko/feat_extract/feinberg_praat.py,sha256=EP9pMALjlKdiYInLQdrZ7MmE499Mq-ISRCgqbqL3Rxc,21304
|
73
71
|
nkululeko/losses/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
74
72
|
nkululeko/losses/loss_ccc.py,sha256=NOK0y0fxKUnU161B5geap6Fmn8QzoPl2MqtPiV8IuJE,976
|
75
73
|
nkululeko/losses/loss_softf1loss.py,sha256=5gW-PuiqeAZcRgfwjueIOQtMokOjZWgQnVIv59HKTCo,1309
|
76
74
|
nkululeko/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
77
|
-
nkululeko/models/model.py,sha256=
|
75
|
+
nkululeko/models/model.py,sha256=8gjRsjSYLWZvfcyTCWhbZ741rkHhx8lxCS2NlSOLP1Y,11648
|
78
76
|
nkululeko/models/model_bayes.py,sha256=wI7-sCwibqXMCHviu349TYjgJXXNXym-Z6ZM83uxlFQ,378
|
79
|
-
nkululeko/models/model_cnn.py,sha256=
|
77
|
+
nkululeko/models/model_cnn.py,sha256=j4NTp7quWqInzOPfpiMrTcfMbXkOsdlFF9ns0tW_ld4,9726
|
80
78
|
nkululeko/models/model_gmm.py,sha256=onovzGBeguwZ-upXtuDLaBw9sd6fDDQslVBOrz1Z8TE,645
|
81
79
|
nkululeko/models/model_knn.py,sha256=5tGqiPo2JTw9VLmD-MXNZKFJ5RTLA6uv_blJDJ9lScA,573
|
82
80
|
nkululeko/models/model_knn_reg.py,sha256=Fbuk6Ku6eyrbbMEk7rB5dwfhvQOMsdZk6HI_0T0gYPw,580
|
83
81
|
nkululeko/models/model_lin_reg.py,sha256=NBTnY2ULuhUBt5ArYQwskZ2Vq4BBDGkqd9SYBFl7Ql4,392
|
84
|
-
nkululeko/models/model_mlp.py,sha256=
|
85
|
-
nkululeko/models/model_mlp_regression.py,sha256=
|
82
|
+
nkululeko/models/model_mlp.py,sha256=lYhGrkqEj6fa6a_tcPrqEoorOpM7t7bjSfFLKEV6pu4,9107
|
83
|
+
nkululeko/models/model_mlp_regression.py,sha256=NP1yEsqvpDcDBWWzDq7W4SHnXC1kE4fAo4A9aBCq3cY,10083
|
86
84
|
nkululeko/models/model_svm.py,sha256=dqDQbfRCtlW3RNqpHDGVsj3ikc131gKURHj5VzAcCr0,867
|
87
85
|
nkululeko/models/model_svr.py,sha256=p-Mb4Bn54yOe1upuHQKNpfj4ttOmQnm9pCB7ECkJkJQ,699
|
88
86
|
nkululeko/models/model_tree.py,sha256=soXjV523eRvRZ-jbX7X_3S73Wto1B9bm7ZzzDmgYzTc,390
|
@@ -94,6 +92,8 @@ nkululeko/reporting/defines.py,sha256=IsY1YgKRMaABpylVKjBJgJ5bNCEbGCVA_E6pivraqS
|
|
94
92
|
nkululeko/reporting/latex_writer.py,sha256=qiCRSmB4KOD_za4oHu5x-PhwjZohzfo8wecMOwlXZwc,1886
|
95
93
|
nkululeko/reporting/report.py,sha256=W0rcigDdjBvxZQ3pZja_gvToILYvaZ1BFtnN2qFRfYI,1060
|
96
94
|
nkululeko/reporting/report_item.py,sha256=siWeGNgo4bAE46YBMNcsdf3jTMTy76BO9Fi6DTvDig4,533
|
95
|
+
nkululeko/reporting/reporter.py,sha256=wwpY0gA-8E8d26XH3DSmXm3X0BkBw2Y0YyEiUiNU_Y0,12670
|
96
|
+
nkululeko/reporting/result.py,sha256=nSN5or-Py2GPRWHkWpGRh7UCi1W0er7WLEHz8fYLk-A,742
|
97
97
|
nkululeko/segmenting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
98
98
|
nkululeko/segmenting/seg_inaspeechsegmenter.py,sha256=pmLHuXsaqvcdYxB4PSW9l1mbQWZZBJFhi_CGabqydas,1947
|
99
99
|
nkululeko/segmenting/seg_silero.py,sha256=lLytS38KzARS17omwv8VBw-zz60RVSXGSvZ5EvWlcWQ,3301
|
@@ -101,8 +101,8 @@ nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
101
101
|
nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
|
102
102
|
nkululeko/utils/stats.py,sha256=29otJpUp1VqbtDKmlLkPPzBmVfTFiHZ70rUdR4860rM,2788
|
103
103
|
nkululeko/utils/util.py,sha256=_Z6OMJ3f-8TdETW9eqJYY5hwNRS5XCt9azzRnqoTTZE,12330
|
104
|
-
nkululeko-0.81.
|
105
|
-
nkululeko-0.81.
|
106
|
-
nkululeko-0.81.
|
107
|
-
nkululeko-0.81.
|
108
|
-
nkululeko-0.81.
|
104
|
+
nkululeko-0.81.3.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
|
105
|
+
nkululeko-0.81.3.dist-info/METADATA,sha256=72Q5q8KeaEP3I0TrVzswdI4g0Fc0hnCG-kPFZke8YM8,34664
|
106
|
+
nkululeko-0.81.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
107
|
+
nkululeko-0.81.3.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
|
108
|
+
nkululeko-0.81.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|