nkululeko 0.81.2__py3-none-any.whl → 0.81.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nkululeko/constants.py CHANGED
@@ -1,2 +1,2 @@
- VERSION="0.81.2"
+ VERSION="0.81.3"
  SAMPLING_RATE = 16000
nkululeko/feat_extract/feats_agender_agender.py CHANGED
@@ -32,10 +32,11 @@ class AgenderAgenderSet(Featureset):
  audeer.extract_archive(archive_path, model_root)
  device = self.util.config_val("MODEL", "device", "cpu")
  self.model = audonnx.load(model_root, device=device)
- pytorch_total_params = sum(p.numel() for p in self.model.parameters())
- self.util.debug(
- f"initialized agender model with {pytorch_total_params} parameters in total"
- )
+ # pytorch_total_params = sum(p.numel() for p in self.model.parameters())
+ # self.util.debug(
+ # f"initialized agender model with {pytorch_total_params} parameters in total"
+ # )
+ self.util.debug("initialized agender model")
  self.model_loaded = True

  def extract(self):
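Note: the parameter count was dropped because audonnx.load returns an ONNX model, which has no torch-style .parameters() method. If a count were still wanted, a minimal sketch (assuming the onnx package is installed; the path is illustrative, not nkululeko's actual cache location) could sum the sizes of the graph initializers:

    import numpy as np
    import onnx

    # Illustrative path only; nkululeko stores the downloaded model in its own cache dir.
    onnx_model = onnx.load("model.onnx")
    n_params = sum(int(np.prod(init.dims)) for init in onnx_model.graph.initializer)
    print(f"initialized agender model with {n_params} parameters in total")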
nkululeko/feat_extract/feinberg_praat.py CHANGED
@@ -1,47 +1,46 @@
- """
- This is a copy of David R. Feinberg's Praat scripts
+ """This is a copy of David R. Feinberg's Praat scripts.
  https://github.com/drfeinberg/PraatScripts
- taken June 23rd 2022
+ taken June 23rd 2022.
  """

  #!/usr/bin/env python3
+ import math
+ import statistics
+
  import numpy as np
  import pandas as pd
- import math
- from tqdm import tqdm
  import parselmouth
- import statistics
- from nkululeko.utils.util import Util
- import audiofile
  from parselmouth.praat import call
  from scipy.stats.mstats import zscore
  from sklearn.decomposition import PCA
- from sklearn.preprocessing import StandardScaler
+ from tqdm import tqdm
+
+ import audiofile


  # This is the function to measure source acoustics using default male parameters.


- def measurePitch(voiceID, f0min, f0max, unit):
- sound = parselmouth.Sound(voiceID) # read the sound
+ def measure_pitch(voice_id, f0min, f0max, unit):
+ sound = parselmouth.Sound(voice_id) # read the sound
  duration = call(sound, "Get total duration") # duration
  pitch = call(sound, "To Pitch", 0.0, f0min, f0max) # create a praat pitch object
- meanF0 = call(pitch, "Get mean", 0, 0, unit) # get mean pitch
- stdevF0 = call(
+ mean_f0 = call(pitch, "Get mean", 0, 0, unit) # get mean pitch
+ stdev_f0 = call(
  pitch, "Get standard deviation", 0, 0, unit
  ) # get standard deviation
  harmonicity = call(sound, "To Harmonicity (cc)", 0.01, f0min, 0.1, 1.0)
  hnr = call(harmonicity, "Get mean", 0, 0)
- pointProcess = call(sound, "To PointProcess (periodic, cc)", f0min, f0max)
- localJitter = call(pointProcess, "Get jitter (local)", 0, 0, 0.0001, 0.02, 1.3)
- localabsoluteJitter = call(
- pointProcess, "Get jitter (local, absolute)", 0, 0, 0.0001, 0.02, 1.3
+ point_process = call(sound, "To PointProcess (periodic, cc)", f0min, f0max)
+ local_jitter = call(point_process, "Get jitter (local)", 0, 0, 0.0001, 0.02, 1.3)
+ localabsolute_jitter = call(
+ point_process, "Get jitter (local, absolute)", 0, 0, 0.0001, 0.02, 1.3
  )
- rapJitter = call(pointProcess, "Get jitter (rap)", 0, 0, 0.0001, 0.02, 1.3)
- ppq5Jitter = call(pointProcess, "Get jitter (ppq5)", 0, 0, 0.0001, 0.02, 1.3)
- ddpJitter = call(pointProcess, "Get jitter (ddp)", 0, 0, 0.0001, 0.02, 1.3)
- localShimmer = call(
- [sound, pointProcess],
+ rap_jitter = call(point_process, "Get jitter (rap)", 0, 0, 0.0001, 0.02, 1.3)
+ ppq5_jitter = call(point_process, "Get jitter (ppq5)", 0, 0, 0.0001, 0.02, 1.3)
+ ddp_jitter = call(point_process, "Get jitter (ddp)", 0, 0, 0.0001, 0.02, 1.3)
+ local_shimmer = call(
+ [sound, point_process],
  "Get shimmer (local)",
  0,
  0,
@@ -50,8 +49,8 @@ def measurePitch(voiceID, f0min, f0max, unit):
  1.3,
  1.6,
  )
- localdbShimmer = call(
- [sound, pointProcess],
+ localdb_shimmer = call(
+ [sound, point_process],
  "Get shimmer (local_dB)",
  0,
  0,
@@ -60,8 +59,8 @@ def measurePitch(voiceID, f0min, f0max, unit):
  1.3,
  1.6,
  )
- apq3Shimmer = call(
- [sound, pointProcess],
+ apq3_shimmer = call(
+ [sound, point_process],
  "Get shimmer (apq3)",
  0,
  0,
@@ -70,8 +69,8 @@ def measurePitch(voiceID, f0min, f0max, unit):
  1.3,
  1.6,
  )
- aqpq5Shimmer = call(
- [sound, pointProcess],
+ aqpq5_shimmer = call(
+ [sound, point_process],
  "Get shimmer (apq5)",
  0,
  0,
@@ -80,8 +79,8 @@ def measurePitch(voiceID, f0min, f0max, unit):
  1.3,
  1.6,
  )
- apq11Shimmer = call(
- [sound, pointProcess],
+ apq11_shimmer = call(
+ [sound, point_process],
  "Get shimmer (apq11)",
  0,
  0,
@@ -90,26 +89,26 @@ def measurePitch(voiceID, f0min, f0max, unit):
  1.3,
  1.6,
  )
- ddaShimmer = call(
- [sound, pointProcess], "Get shimmer (dda)", 0, 0, 0.0001, 0.02, 1.3, 1.6
+ dda_shimmer = call(
+ [sound, point_process], "Get shimmer (dda)", 0, 0, 0.0001, 0.02, 1.3, 1.6
  )

  return (
  duration,
- meanF0,
- stdevF0,
+ mean_f0,
+ stdev_f0,
  hnr,
- localJitter,
- localabsoluteJitter,
- rapJitter,
- ppq5Jitter,
- ddpJitter,
- localShimmer,
- localdbShimmer,
- apq3Shimmer,
- aqpq5Shimmer,
- apq11Shimmer,
- ddaShimmer,
+ local_jitter,
+ localabsolute_jitter,
+ rap_jitter,
+ ppq5_jitter,
+ ddp_jitter,
+ local_shimmer,
+ localdb_shimmer,
+ apq3_shimmer,
+ aqpq5_shimmer,
+ apq11_shimmer,
+ dda_shimmer,
  )

@@ -120,13 +119,13 @@ def measurePitch(voiceID, f0min, f0max, unit):
  # Adapted from: DOI 10.17605/OSF.IO/K2BHS
  # This function measures formants using Formant Position formula
  # def measureFormants(sound, wave_file, f0min,f0max):
- def measureFormants(sound, f0min, f0max):
+ def measure_formants(sound, f0min, f0max):
  sound = parselmouth.Sound(sound) # read the sound
  # pitch = call(sound, "To Pitch (cc)", 0, f0min, 15, 'no', 0.03, 0.45, 0.01, 0.35, 0.14, f0max)
- pointProcess = call(sound, "To PointProcess (periodic, cc)", f0min, f0max)
+ point_process = call(sound, "To PointProcess (periodic, cc)", f0min, f0max)

  formants = call(sound, "To Formant (burg)", 0.0025, 5, 5000, 0.025, 50)
- numPoints = call(pointProcess, "Get number of points")
+ num_points = call(point_process, "Get number of points")

  f1_list = []
  f2_list = []
@@ -134,9 +133,9 @@ def measureFormants(sound, f0min, f0max):
  f4_list = []

  # Measure formants only at glottal pulses
- for point in range(0, numPoints):
+ for point in range(0, num_points):
  point += 1
- t = call(pointProcess, "Get time from index", point)
+ t = call(point_process, "Get time from index", point)
  f1 = call(formants, "Get value at time", 1, t, "Hertz", "Linear")
  f2 = call(formants, "Get value at time", 2, t, "Hertz", "Linear")
  f3 = call(formants, "Get value at time", 3, t, "Hertz", "Linear")
@@ -179,7 +178,7 @@ def measureFormants(sound, f0min, f0max):
  # ## This function runs a 2-factor Principle Components Analysis (PCA) on Jitter and Shimmer


- def runPCA(df):
+ def run_pca(df):
  # z-score the Jitter and Shimmer measurements
  measures = [
  "localJitter",
@@ -211,19 +210,19 @@ def runPCA(df):
  # PCA
  pca = PCA(n_components=2)
  try:
- principalComponents = pca.fit_transform(x)
- if np.any(np.isnan(principalComponents)):
+ principal_components = pca.fit_transform(x)
+ if np.any(np.isnan(principal_components)):
  print("pc is nan")
- print(f"count: {np.count_nonzero(np.isnan(principalComponents))}")
- print(principalComponents)
- principalComponents = np.nan_to_num(principalComponents)
+ print(f"count: {np.count_nonzero(np.isnan(principal_components))}")
+ print(principal_components)
+ principal_components = np.nan_to_num(principal_components)
  except ValueError:
  print("need more than one file for pca")
- principalComponents = [[0, 0]]
- principalDf = pd.DataFrame(
- data=principalComponents, columns=["JitterPCA", "ShimmerPCA"]
+ principal_components = [[0, 0]]
+ principal_df = pd.DataFrame(
+ data=principal_components, columns=["JitterPCA", "ShimmerPCA"]
  )
- return principalDf
+ return principal_df


  # ## This block of code runs the above functions on all of the '.wav' files in the /audio folder
@@ -231,22 +230,21 @@ def runPCA(df):

  def compute_features(file_index):
  # create lists to put the results
- file_list = []
  duration_list = []
- mean_F0_list = []
- sd_F0_list = []
+ mean_f0_list = []
+ sd_f0_list = []
  hnr_list = []
- localJitter_list = []
- localabsoluteJitter_list = []
- rapJitter_list = []
- ppq5Jitter_list = []
- ddpJitter_list = []
- localShimmer_list = []
- localdbShimmer_list = []
- apq3Shimmer_list = []
- aqpq5Shimmer_list = []
- apq11Shimmer_list = []
- ddaShimmer_list = []
+ local_jitter_list = []
+ localabsolute_jitter_list = []
+ rap_jitter_list = []
+ ppq5_jitter_list = []
+ ddp_jitter_list = []
+ local_shimmer_list = []
+ localdb_shimmer_list = []
+ apq3_shimmer_list = []
+ aqpq5_shimmer_list = []
+ apq11_shimmer_list = []
+ dda_shimmer_list = []
  f1_mean_list = []
  f2_mean_list = []
  f3_mean_list = []
@@ -268,21 +266,21 @@ def compute_features(file_index):
  sound = parselmouth.Sound(values=signal, sampling_frequency=sampling_rate)
  (
  duration,
- meanF0,
- stdevF0,
+ mean_f0,
+ stdev_f0,
  hnr,
- localJitter,
- localabsoluteJitter,
- rapJitter,
- ppq5Jitter,
- ddpJitter,
- localShimmer,
- localdbShimmer,
- apq3Shimmer,
- aqpq5Shimmer,
- apq11Shimmer,
- ddaShimmer,
- ) = measurePitch(sound, 75, 300, "Hertz")
+ local_jitter,
+ localabsolute_jitter,
+ rap_jitter,
+ ppq5_jitter,
+ ddp_jitter,
+ local_shimmer,
+ localdb_shimmer,
+ apq3_shimmer,
+ aqpq5_shimmer,
+ apq11_shimmer,
+ dda_shimmer,
+ ) = measure_pitch(sound, 75, 300, "Hertz")
  (
  f1_mean,
  f2_mean,
@@ -292,28 +290,28 @@ def compute_features(file_index):
  f2_median,
  f3_median,
  f4_median,
- ) = measureFormants(sound, 75, 300)
+ ) = measure_formants(sound, 75, 300)
  # file_list.append(wave_file) # make an ID list
  except (statistics.StatisticsError, parselmouth.PraatError) as errors:
  print(f"error on file {wave_file}: {errors}")

  duration_list.append(duration) # make duration list
- mean_F0_list.append(meanF0) # make a mean F0 list
- sd_F0_list.append(stdevF0) # make a sd F0 list
+ mean_f0_list.append(mean_f0) # make a mean F0 list
+ sd_f0_list.append(stdev_f0) # make a sd F0 list
  hnr_list.append(hnr) # add HNR data

  # add raw jitter and shimmer measures
- localJitter_list.append(localJitter)
- localabsoluteJitter_list.append(localabsoluteJitter)
- rapJitter_list.append(rapJitter)
- ppq5Jitter_list.append(ppq5Jitter)
- ddpJitter_list.append(ddpJitter)
- localShimmer_list.append(localShimmer)
- localdbShimmer_list.append(localdbShimmer)
- apq3Shimmer_list.append(apq3Shimmer)
- aqpq5Shimmer_list.append(aqpq5Shimmer)
- apq11Shimmer_list.append(apq11Shimmer)
- ddaShimmer_list.append(ddaShimmer)
+ local_jitter_list.append(local_jitter)
+ localabsolute_jitter_list.append(localabsolute_jitter)
+ rap_jitter_list.append(rap_jitter)
+ ppq5_jitter_list.append(ppq5_jitter)
+ ddp_jitter_list.append(ddp_jitter)
+ local_shimmer_list.append(local_shimmer)
+ localdb_shimmer_list.append(localdb_shimmer)
+ apq3_shimmer_list.append(apq3_shimmer)
+ aqpq5_shimmer_list.append(aqpq5_shimmer)
+ apq11_shimmer_list.append(apq11_shimmer)
+ dda_shimmer_list.append(dda_shimmer)

  # add the formant data
  f1_mean_list.append(f1_mean)
@@ -330,20 +328,20 @@ def compute_features(file_index):
  np.column_stack(
  [
  duration_list,
- mean_F0_list,
- sd_F0_list,
+ mean_f0_list,
+ sd_f0_list,
  hnr_list,
- localJitter_list,
- localabsoluteJitter_list,
- rapJitter_list,
- ppq5Jitter_list,
- ddpJitter_list,
- localShimmer_list,
- localdbShimmer_list,
- apq3Shimmer_list,
- aqpq5Shimmer_list,
- apq11Shimmer_list,
- ddaShimmer_list,
+ local_jitter_list,
+ localabsolute_jitter_list,
+ rap_jitter_list,
+ ppq5_jitter_list,
+ ddp_jitter_list,
+ local_shimmer_list,
+ localdb_shimmer_list,
+ apq3_shimmer_list,
+ aqpq5_shimmer_list,
+ apq11_shimmer_list,
+ dda_shimmer_list,
  f1_mean_list,
  f2_mean_list,
  f3_mean_list,
@@ -382,7 +380,7 @@ def compute_features(file_index):
  )

  # add pca data
- pcaData = runPCA(df) # Run jitter and shimmer PCA
+ pcaData = run_pca(df) # Run jitter and shimmer PCA
  df = pd.concat([df, pcaData], axis=1) # Add PCA data
  # reload the data so it's all numbers
  df.to_csv("processed_results.csv", index=False)
nkululeko/models/model.py CHANGED
@@ -4,7 +4,7 @@ import pandas as pd
  import numpy as np
  import nkululeko.glob_conf as glob_conf
  import sklearn.utils
- from nkululeko.reporter import Reporter
+ from nkululeko.reporting.reporter import Reporter
  import ast
  from sklearn.model_selection import GridSearchCV
  import pickle
nkululeko/models/model_cnn.py CHANGED
@@ -20,7 +20,7 @@ from PIL import Image
  from nkululeko.utils.util import Util
  import nkululeko.glob_conf as glob_conf
  from nkululeko.models.model import Model
- from nkululeko.reporter import Reporter
+ from nkululeko.reporting.reporter import Reporter
  from nkululeko.losses.loss_softf1loss import SoftF1Loss

nkululeko/models/model_mlp.py CHANGED
@@ -2,7 +2,7 @@
  from nkululeko.utils.util import Util
  import nkululeko.glob_conf as glob_conf
  from nkululeko.models.model import Model
- from nkululeko.reporter import Reporter
+ from nkululeko.reporting.reporter import Reporter
  import torch
  import ast
  import numpy as np
nkululeko/models/model_mlp_regression.py CHANGED
@@ -1,16 +1,20 @@
  # model_mlp.py
- from nkululeko.utils.util import Util
- import nkululeko.glob_conf as glob_conf
- from nkululeko.models.model import Model
- from nkululeko.reporter import Reporter
- import torch
  import ast
- import numpy as np
- from sklearn.metrics import mean_squared_error, mean_absolute_error
  from collections import OrderedDict
- from nkululeko.losses.loss_ccc import ConcordanceCorCoeff
  import os

+ import numpy as np
+ import torch
+
+ from audmetric import concordance_cc
+ from audmetric import mean_absolute_error
+ from audmetric import mean_squared_error
+
+ import nkululeko.glob_conf as glob_conf
+ from nkululeko.losses.loss_ccc import ConcordanceCorCoeff
+ from nkululeko.models.model import Model
+ from nkululeko.reporting.reporter import Reporter
+

  class MLP_Reg_model(Model):
  """MLP = multi layer perceptron"""
@@ -201,7 +205,7 @@ class MLP_Reg_model(Model):
  elif measure == "mae":
  result = mean_absolute_error(targets.numpy(), predictions.numpy())
  elif measure == "ccc":
- result = Reporter.ccc(targets.numpy(), predictions.numpy())
+ result = concordance_cc(targets.numpy(), predictions.numpy())
  else:
  self.util.error(f"unknown measure: {measure}")
  return result, targets, predictions
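The former Reporter.ccc static method is replaced here by audmetric.concordance_cc. A small standalone sketch with made-up values:

    import numpy as np
    from audmetric import concordance_cc

    truth = np.array([0.1, 0.4, 0.5, 0.8, 0.9])  # illustrative regression targets
    pred = np.array([0.2, 0.3, 0.6, 0.7, 0.8])   # illustrative predictions
    print(f"CCC: {concordance_cc(truth, pred):.3f}")  # value in [-1, 1]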
nkululeko/reporter.py → nkululeko/reporting/reporter.py RENAMED
@@ -2,25 +2,27 @@ import ast
  import glob
  import json
  import math
+
+ from confidence_intervals import evaluate_with_conf_int
  import matplotlib.pyplot as plt
  import numpy as np
  from scipy.stats import pearsonr
- from sklearn.metrics import (
- ConfusionMatrixDisplay,
- accuracy_score,
- classification_report,
- confusion_matrix,
- mean_squared_error,
- mean_absolute_error,
- r2_score,
- recall_score,
- )
- from sklearn.utils import resample
+ from sklearn.metrics import ConfusionMatrixDisplay
+ from sklearn.metrics import classification_report
+ from sklearn.metrics import confusion_matrix
+ from sklearn.metrics import r2_score
+ from torch import is_tensor
+
+ from audmetric import accuracy
+ from audmetric import concordance_cc
+ from audmetric import mean_absolute_error
+ from audmetric import mean_squared_error
+ from audmetric import unweighted_average_recall

  import nkululeko.glob_conf as glob_conf
- from nkululeko.reporting.report_item import ReportItem
- from nkululeko.result import Result
  from nkululeko.reporting.defines import Header
+ from nkululeko.reporting.report_item import ReportItem
+ from nkululeko.reporting.result import Result
  from nkululeko.utils.util import Util


@@ -44,11 +46,11 @@ class Reporter:
  self.result.measure = self.MEASURE

  def __init__(self, truths, preds, run, epoch):
- """Initialization with ground truth und predictions vector"""
+ """Initialization with ground truth und predictions vector."""
  self.util = Util("reporter")
  self.format = self.util.config_val("PLOT", "format", "png")
- self.truths = truths
- self.preds = preds
+ self.truths = np.asarray(truths)
+ self.preds = np.asarray(preds)
  self.result = Result(0, 0, 0, 0, "unknown")
  self.run = run
  self.epoch = epoch
@@ -56,30 +58,57 @@ class Reporter:
  self.cont_to_cat = False
  if len(self.truths) > 0 and len(self.preds) > 0:
  if self.util.exp_is_classification():
- self.result.test = recall_score(
- self.truths, self.preds, average="macro"
+ uar, (upper, lower) = evaluate_with_conf_int(
+ self.preds,
+ unweighted_average_recall,
+ self.truths,
+ num_bootstraps=1000,
+ alpha=5,
  )
- self.result.loss = 1 - accuracy_score(self.truths, self.preds)
+ self.result.test = uar
+ self.result.set_upper_lower(upper, lower)
+ self.result.loss = 1 - accuracy(self.truths, self.preds)
  else:
  # regression experiment
  if self.measure == "mse":
- self.result.test = mean_squared_error(self.truths, self.preds)
+ test_result, (upper, lower) = evaluate_with_conf_int(
+ self.preds,
+ mean_squared_error,
+ self.truths,
+ num_bootstraps=1000,
+ alpha=5,
+ )
  elif self.measure == "mae":
- self.result.test = mean_absolute_error(self.truths, self.preds)
+ test_result, (upper, lower) = evaluate_with_conf_int(
+ self.preds,
+ mean_absolute_error,
+ self.truths,
+ num_bootstraps=1000,
+ alpha=5,
+ )
  elif self.measure == "ccc":
- self.result.test = self.ccc(self.truths, self.preds)
+ test_result, (upper, lower) = evaluate_with_conf_int(
+ self.preds,
+ concordance_cc,
+ self.truths,
+ num_bootstraps=1000,
+ alpha=5,
+ )
+
  if math.isnan(self.result.test):
  self.util.debug(f"Truth: {self.truths}")
  self.util.debug(f"Predict.: {self.preds}")
- self.util.debug(f"Result is NAN: setting to -1")
+ self.util.debug("Result is NAN: setting to -1")
  self.result.test = -1
  else:
  self.util.error(f"unknown measure: {self.measure}")

+ self.result.test = test_result
+ self.result.set_upper_lower(upper, lower)
  # train and loss are being set by the model

  def set_id(self, run, epoch):
- """Make the report identifiable with run and epoch index"""
+ """Make the report identifiable with run and epoch index."""
  self.run = run
  self.epoch = epoch

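The confidence intervals come from the new confidence-intervals dependency. Mirroring the call pattern above (1000 bootstrap resamples, alpha=5), a standalone sketch with made-up labels:

    import numpy as np
    from audmetric import unweighted_average_recall
    from confidence_intervals import evaluate_with_conf_int

    truths = np.array([0, 1, 0, 0, 1, 1, 0, 1])  # illustrative ground truth
    preds = np.array([0, 1, 1, 0, 1, 0, 0, 1])   # illustrative predictions
    uar, (upper, lower) = evaluate_with_conf_int(
        preds, unweighted_average_recall, truths, num_bootstraps=1000, alpha=5
    )
    print(f"UAR {uar:.3f} (+-{upper:.3f}/{lower:.3f})")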
@@ -97,9 +126,12 @@ class Reporter:
  self._plot_confmat(self.truths, self.preds, plot_name, epoch)

  def plot_per_speaker(self, result_df, plot_name, function):
- """Plot a confusion matrix with the mode category per speakers
+ """Plot a confusion matrix with the mode category per speakers.
+
  Args:
- * result_df: a pandas dataframe with columns: preds, truths and speaker
+ result_df: a pandas dataframe with columns: preds, truths and speaker.
+ plot_name: name for the figure.
+ function: either mode or mean.
  """
  speakers = result_df.speaker.unique()
  pred = np.zeros(0)
@@ -128,8 +160,14 @@
  fig_dir = self.util.get_path("fig_dir")
  labels = glob_conf.labels
  fig = plt.figure() # figsize=[5, 5]
- uar = recall_score(truths, preds, average="macro")
- acc = accuracy_score(truths, preds)
+ uar, (upper, lower) = evaluate_with_conf_int(
+ self.preds,
+ unweighted_average_recall,
+ self.truths,
+ num_bootstraps=1000,
+ alpha=5,
+ )
+ acc = accuracy(truths, preds)
  cm = confusion_matrix(
  truths, preds, normalize=None
  ) # normalize must be one of {'true', 'pred', 'all', None}
@@ -138,6 +176,7 @@
  f"mismatch between confmatrix dim ({cm.shape[0]}) and labels"
  f" length ({len(labels)}: {labels})"
  )
+
  try:
  disp = ConfusionMatrixDisplay(
  confusion_matrix=cm, display_labels=labels
@@ -150,12 +189,23 @@

  reg_res = ""
  if not self.is_classification:
- reg_res = f", {self.MEASURE}: {self.result.test:.3f}"
+ reg_res = f"{self.result.test:.3f} {self.MEASURE}"
+
+ uar_str = str(int(uar * 1000) / 1000.0)[1:]
+ acc_str = str(int(acc * 1000) / 1000.0)[1:]
+ up_str = str(int(upper * 1000) / 1000.0)[1:]
+ low_str = str(int(lower * 1000) / 1000.0)[1:]

  if epoch != 0:
- plt.title(f"Confusion Matrix, UAR: {uar:.3f}{reg_res}, Epoch: {epoch}")
+ plt.title(
+ f"Confusion Matrix, UAR: {uar_str} "
+ + f"(+-{up_str}/{low_str}), {reg_res}, Epoch: {epoch}"
+ )
  else:
- plt.title(f"Confusion Matrix, UAR: {uar:.3f}{reg_res}")
+ plt.title(
+ f"Confusion Matrix, UAR: {uar_str} "
+ + f"(+-{up_str}/{low_str}) {reg_res}"
+ )
  img_path = f"{fig_dir}{plot_name}.{self.format}"
  plt.savefig(img_path)
  fig.clear()
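The string formatting above truncates (rather than rounds) to three decimals and drops the leading zero. A worked example with assumed scores:

    uar, upper, lower = 0.8237, 0.8571, 0.7904        # assumed values
    uar_str = str(int(uar * 1000) / 1000.0)[1:]        # "0.823" -> ".823"
    up_str = str(int(upper * 1000) / 1000.0)[1:]       # ".857"
    low_str = str(int(lower * 1000) / 1000.0)[1:]      # ".79"
    print(f"Confusion Matrix, UAR: {uar_str} (+-{up_str}/{low_str})")
    # Confusion Matrix, UAR: .823 (+-.857/.79)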
@@ -172,9 +222,10 @@
  )

  res_dir = self.util.get_path("res_dir")
- uar = int(uar * 1000) / 1000.0
- acc = int(acc * 1000) / 1000.0
- rpt = f"epoch: {epoch}, UAR: {uar}, ACC: {acc}"
+ rpt = (
+ f"epoch: {epoch}, UAR: {uar_str}"
+ + f", (+-{up_str}/{low_str}), ACC: {acc_str}"
+ )
  # print(rpt)
  self.util.debug(rpt)
  file_name = f"{res_dir}{self.util.get_exp_name()}_conf.txt"
@@ -182,7 +233,7 @@
  text_file.write(rpt)

  def print_results(self, epoch):
- """Print all evaluation values to text file"""
+ """Print all evaluation values to text file."""
  res_dir = self.util.get_path("res_dir")
  file_name = f"{res_dir}{self.util.get_exp_name()}_{epoch}.txt"
  if self.util.exp_is_classification():
@@ -279,19 +330,3 @@
  plt.legend()
  plt.savefig(f"{fig_dir}{out_name}.{self.format}")
  plt.close()
-
- @staticmethod
- def ccc(ground_truth, prediction):
- mean_gt = np.mean(ground_truth, 0)
- mean_pred = np.mean(prediction, 0)
- var_gt = np.var(ground_truth, 0)
- var_pred = np.var(prediction, 0)
- v_pred = prediction - mean_pred
- v_gt = ground_truth - mean_gt
- cor = sum(v_pred * v_gt) / (np.sqrt(sum(v_pred**2)) * np.sqrt(sum(v_gt**2)))
- sd_gt = np.std(ground_truth)
- sd_pred = np.std(prediction)
- numerator = 2 * cor * sd_gt * sd_pred
- denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
- ccc = numerator / denominator
- return ccc
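For reference, the deleted helper implemented Lin's concordance correlation coefficient, which audmetric.concordance_cc computes as well:

    CCC = \frac{2 \rho \sigma_x \sigma_y}{\sigma_x^2 + \sigma_y^2 + (\mu_x - \mu_y)^2}

where x is the ground truth, y the prediction, rho their Pearson correlation, and mu, sigma the respective means and standard deviations.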
nkululeko/result.py → nkululeko/reporting/result.py RENAMED
@@ -12,6 +12,11 @@ class Result:
  def get_result(self):
  return self.test

+ def set_upper_lower(self, upper, lower):
+ """Set the upper and lower bound of confidence interval."""
+ self.upper = upper
+ self.lower = lower
+
  def get_test_result(self):
  return f"test: {self.test:.3f} {self.measure}"

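A short usage sketch of the extended Result container (positional constructor arguments as in the Reporter's Result(0, 0, 0, 0, ...) call, with the last one being the measure name; the other argument meanings and the numeric values are assumed):

    from nkululeko.reporting.result import Result

    res = Result(0, 0, 0, 0, "uar")     # argument order assumed from the Reporter call above
    res.test = 0.823
    res.set_upper_lower(0.857, 0.790)   # new in 0.81.3: confidence interval bounds
    print(res.get_test_result())        # "test: 0.823 uar"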
nkululeko/runmanager.py CHANGED
@@ -1,6 +1,6 @@
  # runmanager.py

- from nkululeko.reporter import Reporter
+ from nkululeko.reporting.reporter import Reporter
  from nkululeko.utils.util import Util
  import nkululeko.glob_conf as glob_conf
  from nkululeko.modelrunner import Modelrunner
nkululeko-0.81.2.dist-info/METADATA → nkululeko-0.81.3.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nkululeko
- Version: 0.81.2
+ Version: 0.81.3
  Summary: Machine learning audio prediction experiments based on templates
  Home-page: https://github.com/felixbur/nkululeko
  Author: Felix Burkhardt
@@ -18,7 +18,9 @@ Requires-Dist: audformat
  Requires-Dist: audinterface
  Requires-Dist: audiofile
  Requires-Dist: audiomentations
+ Requires-Dist: audmetric
  Requires-Dist: audonnx
+ Requires-Dist: confidence-intervals
  Requires-Dist: datasets
  Requires-Dist: imageio
  Requires-Dist: laion-clap
@@ -321,6 +323,10 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
  Changelog
  =========

+ Version 0.81.3
+ --------------
+ * added confidence intervals to result reporting
+
  Version 0.81.2
  --------------
  * added a parselmouth.Praat error if pitch out of range
nkululeko-0.81.2.dist-info/RECORD → nkululeko-0.81.3.dist-info/RECORD RENAMED
@@ -2,7 +2,7 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
  nkululeko/aug_train.py,sha256=YhuZnS_WVWnun9G-M6g5n6rbRxoVREz6Zh7k6qprFNQ,3194
  nkululeko/augment.py,sha256=4MG0apTAG5RgkuJrYEjGgDdbodZWi_HweSPNI1JJ5QA,3051
  nkululeko/cacheddataset.py,sha256=lIJ6hUo5LoxSrzXtWV8mzwO7wRtUETWnOQ4ws2XfL1E,969
- nkululeko/constants.py,sha256=zujT9J62h5BIBCxzigDt23S5plsfoyutXsGMdK_xkAM,39
+ nkululeko/constants.py,sha256=hx9HFHOlApn60yieWI1qr4PbrKeT3EFK1aaDMxlt5xU,39
  nkululeko/demo.py,sha256=me8EdjN-zrzClVy9FEmqbTQyDDON88W8vPpWEE8T0cI,2500
  nkululeko/demo_feats.py,sha256=sAeGFojhEj9WEDFtG3SzPBmyYJWLF2rkbpp65m8Ujo4,2025
  nkululeko/demo_predictor.py,sha256=CQL6DO7QxwmwoB_6DlgDS-pdG1KuvemYJ1NEpMjmMk8,4733
@@ -18,10 +18,8 @@ nkululeko/multidb.py,sha256=4ceCu9LFrMGlrcgtz4pWuOQb2KA3jR5uo3FjZgAEBD4,5732
  nkululeko/nkululeko.py,sha256=Ty8cdusXUec9BHml8Gsp1r7DXuvIBMFXUckMpzILBnQ,1966
  nkululeko/plots.py,sha256=K88ZRPFGX_r03BT742H06Dde20xZYdltv7dxjgUiAFA,23025
  nkululeko/predict.py,sha256=dRXX-sQVESa7cNi_56S6UkUOa_pV1g_K4xYtYVM1SJs,1876
- nkululeko/reporter.py,sha256=Gg0dsZclMmdTRUju7yWM3tBVhEZno9VSKD4Tcu_1pJI,11497
  nkululeko/resample.py,sha256=Yzfr_rInG9afPZFnEjiQ3EKRdMSwyYKVQwt9-yNGJn8,2233
- nkululeko/result.py,sha256=kLeEyHQxPzqgCcTadgwvGd2b8gJGpdaf5feHqshjPH0,574
- nkululeko/runmanager.py,sha256=YNjYLzf4KrtcOyiDLF06YLs3nU3U7n_hY_VH4fYFuh0,7451
+ nkululeko/runmanager.py,sha256=JNBm7JJN8QU8qEqfWr4eS6rkPnBWoVdIUTynHctCPpw,7461
  nkululeko/scaler.py,sha256=4nkIqoajkIkuTPK0Z02ifMN_awl6fP_i-GBYdoGYgGM,4101
  nkululeko/segment.py,sha256=YLKckX44tbvTb3LrdgYw9X4guzuF27sutl92z9DkpZU,4835
  nkululeko/syllable_nuclei.py,sha256=Sky-C__MeUDaxqHnDl2TGLLYOYvsahD35TUjWGeG31k,10047
@@ -49,7 +47,7 @@ nkululeko/data/dataset.py,sha256=n6v_vVdA0EsZ-NaTgnYfPlCT4QCcD02mJJb-oD7SaSU,272
  nkululeko/data/dataset_csv.py,sha256=v3lSjF23EVjoP460QOfhdcqbWAlBQWlBOuaYujZoS4s,3407
  nkululeko/feat_extract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  nkululeko/feat_extract/feats_agender.py,sha256=_lAL6IxJDJH2bhIvd7yarTqQryx7FjbQXAgY0mJP-KI,3192
- nkululeko/feat_extract/feats_agender_agender.py,sha256=ckQN8K02vdPLUWIylxg5Z4X145gBlPbSDbyBRjabLD0,3278
+ nkululeko/feat_extract/feats_agender_agender.py,sha256=5dA7YA-YGxODovMC7ynMk3bnpPjfs0ApvSfjqvoSZY0,3346
  nkululeko/feat_extract/feats_analyser.py,sha256=_5oz4y-NZCEBgfNP2GZ9WNqQR50Hbykm0TvDVomWP0U,11399
  nkululeko/feat_extract/feats_audmodel.py,sha256=TRCkLqPgnyWN-OAcO69pPZF2FIbBy5ERb5ZY22qh6iA,3108
  nkululeko/feat_extract/feats_audmodel_dim.py,sha256=yg39CSR0b54AJyOAlXO3M1ohyY9Rbrjf18pllsoQ03g,3078
@@ -69,20 +67,20 @@ nkululeko/feat_extract/feats_trill.py,sha256=PpygJK_W6QoBNeSah9npQPiQlJxLWFn6TSO
  nkululeko/feat_extract/feats_wav2vec2.py,sha256=sFf-WkLUgKUQsFxGO9m2hS3uYoGkv95mZavCEZyWFGA,5072
  nkululeko/feat_extract/feats_wavlm.py,sha256=RhI0oWIsknnxTVmdnNS_xJO1NnUUR0CUNDWH1yTpNLk,4683
  nkululeko/feat_extract/featureset.py,sha256=-ynkdor8iX7BFx10aIbB3LfwxrrzPoBGz9kXwyAJO9M,1375
- nkululeko/feat_extract/feinberg_praat.py,sha256=7V1VhVMu4QrXkdcXpmqCbpStXfpmOHtfx5GzxXWukz8,21287
+ nkululeko/feat_extract/feinberg_praat.py,sha256=EP9pMALjlKdiYInLQdrZ7MmE499Mq-ISRCgqbqL3Rxc,21304
  nkululeko/losses/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  nkululeko/losses/loss_ccc.py,sha256=NOK0y0fxKUnU161B5geap6Fmn8QzoPl2MqtPiV8IuJE,976
  nkululeko/losses/loss_softf1loss.py,sha256=5gW-PuiqeAZcRgfwjueIOQtMokOjZWgQnVIv59HKTCo,1309
  nkululeko/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- nkululeko/models/model.py,sha256=SZ2HQ3KiF5fcmrTcvko1E95EQQeFIaPCG90DvZVHbBA,11638
+ nkululeko/models/model.py,sha256=8gjRsjSYLWZvfcyTCWhbZ741rkHhx8lxCS2NlSOLP1Y,11648
  nkululeko/models/model_bayes.py,sha256=wI7-sCwibqXMCHviu349TYjgJXXNXym-Z6ZM83uxlFQ,378
- nkululeko/models/model_cnn.py,sha256=iyXeRsAMVeRST1j_D2AUngE02CtVkg6vWwQc1BOaBl0,9716
+ nkululeko/models/model_cnn.py,sha256=j4NTp7quWqInzOPfpiMrTcfMbXkOsdlFF9ns0tW_ld4,9726
  nkululeko/models/model_gmm.py,sha256=onovzGBeguwZ-upXtuDLaBw9sd6fDDQslVBOrz1Z8TE,645
  nkululeko/models/model_knn.py,sha256=5tGqiPo2JTw9VLmD-MXNZKFJ5RTLA6uv_blJDJ9lScA,573
  nkululeko/models/model_knn_reg.py,sha256=Fbuk6Ku6eyrbbMEk7rB5dwfhvQOMsdZk6HI_0T0gYPw,580
  nkululeko/models/model_lin_reg.py,sha256=NBTnY2ULuhUBt5ArYQwskZ2Vq4BBDGkqd9SYBFl7Ql4,392
- nkululeko/models/model_mlp.py,sha256=IjiiupLxm5ddb73-eU5Ad79Gb6enurR1fgGY-7NkbFc,9097
- nkululeko/models/model_mlp_regression.py,sha256=F0SaU1qAjnGmTTg-ti1s-XmFYVUYxSV0TJw0_jMxlKU,10054
+ nkululeko/models/model_mlp.py,sha256=lYhGrkqEj6fa6a_tcPrqEoorOpM7t7bjSfFLKEV6pu4,9107
+ nkululeko/models/model_mlp_regression.py,sha256=NP1yEsqvpDcDBWWzDq7W4SHnXC1kE4fAo4A9aBCq3cY,10083
  nkululeko/models/model_svm.py,sha256=dqDQbfRCtlW3RNqpHDGVsj3ikc131gKURHj5VzAcCr0,867
  nkululeko/models/model_svr.py,sha256=p-Mb4Bn54yOe1upuHQKNpfj4ttOmQnm9pCB7ECkJkJQ,699
  nkululeko/models/model_tree.py,sha256=soXjV523eRvRZ-jbX7X_3S73Wto1B9bm7ZzzDmgYzTc,390
@@ -94,6 +92,8 @@ nkululeko/reporting/defines.py,sha256=IsY1YgKRMaABpylVKjBJgJ5bNCEbGCVA_E6pivraqS
  nkululeko/reporting/latex_writer.py,sha256=qiCRSmB4KOD_za4oHu5x-PhwjZohzfo8wecMOwlXZwc,1886
  nkululeko/reporting/report.py,sha256=W0rcigDdjBvxZQ3pZja_gvToILYvaZ1BFtnN2qFRfYI,1060
  nkululeko/reporting/report_item.py,sha256=siWeGNgo4bAE46YBMNcsdf3jTMTy76BO9Fi6DTvDig4,533
+ nkululeko/reporting/reporter.py,sha256=wwpY0gA-8E8d26XH3DSmXm3X0BkBw2Y0YyEiUiNU_Y0,12670
+ nkululeko/reporting/result.py,sha256=nSN5or-Py2GPRWHkWpGRh7UCi1W0er7WLEHz8fYLk-A,742
  nkululeko/segmenting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  nkululeko/segmenting/seg_inaspeechsegmenter.py,sha256=pmLHuXsaqvcdYxB4PSW9l1mbQWZZBJFhi_CGabqydas,1947
  nkululeko/segmenting/seg_silero.py,sha256=lLytS38KzARS17omwv8VBw-zz60RVSXGSvZ5EvWlcWQ,3301
@@ -101,8 +101,8 @@ nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
  nkululeko/utils/stats.py,sha256=29otJpUp1VqbtDKmlLkPPzBmVfTFiHZ70rUdR4860rM,2788
  nkululeko/utils/util.py,sha256=_Z6OMJ3f-8TdETW9eqJYY5hwNRS5XCt9azzRnqoTTZE,12330
- nkululeko-0.81.2.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
- nkululeko-0.81.2.dist-info/METADATA,sha256=-Oo7DH0SM9gF8F0c65DLjGIt6rnUUPF_Ah_OgJrxDRA,34523
- nkululeko-0.81.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- nkululeko-0.81.2.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
- nkululeko-0.81.2.dist-info/RECORD,,
+ nkululeko-0.81.3.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
+ nkululeko-0.81.3.dist-info/METADATA,sha256=72Q5q8KeaEP3I0TrVzswdI4g0Fc0hnCG-kPFZke8YM8,34664
+ nkululeko-0.81.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ nkululeko-0.81.3.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
+ nkululeko-0.81.3.dist-info/RECORD,,