nkululeko 0.95.5__py3-none-any.whl → 0.95.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nkululeko/constants.py CHANGED
@@ -1,2 +1,2 @@
1
- VERSION="0.95.5"
1
+ VERSION="0.95.7"
2
2
  SAMPLING_RATE = 16000
nkululeko/data/dataset.py CHANGED
@@ -294,12 +294,21 @@ class Dataset:
294
294
  # try to get the gender values
295
295
  if "gender" in source_df:
296
296
  df_local["gender"] = source_df["gender"]
297
- got_gender = True
297
+ else:
298
+ # try to get the gender via the speaker description
299
+ gender_map = db['speaker'].get().to_dict()['gender']
300
+ df_local['gender'] = df_local['speaker'].map(gender_map).astype(str)
301
+ got_gender = True
298
302
  except (KeyError, ValueError, audformat.errors.BadKeyError):
299
303
  pass
300
304
  try:
301
305
  # try to get the age values
302
- df_local["age"] = source_df["age"].astype(int)
306
+ if "age" in source_df:
307
+ df_local["age"] = source_df["age"].astype(int)
308
+ else:
309
+ # try to get the age via the speaker description
310
+ age_map = db['speaker'].get().to_dict()['age']
311
+ df_local['age'] = df_local['speaker'].map(age_map).astype(int)
303
312
  got_age = True
304
313
  except (KeyError, ValueError, audformat.errors.BadKeyError):
305
314
  pass
nkululeko/explore.py CHANGED
@@ -65,6 +65,7 @@ def main():
65
65
  try:
66
66
  # load the experiment
67
67
  expr.load(f"{util.get_save_name()}")
68
+ expr.util.set_config(config)
68
69
  needs_feats = True
69
70
  experiment_loaded = True
70
71
  except FileNotFoundError:
nkululeko/plots.py CHANGED
@@ -9,6 +9,7 @@ from scipy import stats
9
9
  import seaborn as sns
10
10
  from sklearn.manifold import TSNE
11
11
 
12
+ import audeer
12
13
  from audmetric import concordance_cc as ccc
13
14
 
14
15
  import nkululeko.glob_conf as glob_conf
@@ -218,7 +219,7 @@ class Plots:
218
219
 
219
220
  def save_plot(self, ax, caption, header, filename, type_s):
220
221
  # one up because of the runs
221
- fig_dir = os.path.dirname(self.util.get_path("fig_dir"))
222
+ fig_dir = audeer.path(self.util.get_path("fig_dir"), "..")
222
223
  fig_plots = ax.figure
223
224
  # avoid warning
224
225
  # plt.tight_layout()
@@ -12,13 +12,9 @@ from scipy.special import softmax
12
12
  from scipy.stats import entropy
13
13
  from scipy.stats import pearsonr
14
14
  from sklearn.metrics import ConfusionMatrixDisplay
15
- from sklearn.metrics import RocCurveDisplay
16
- from sklearn.metrics import auc
17
15
  from sklearn.metrics import classification_report
18
16
  from sklearn.metrics import confusion_matrix
19
17
  from sklearn.metrics import r2_score
20
- from sklearn.metrics import roc_auc_score
21
- from sklearn.metrics import roc_curve
22
18
 
23
19
  # from torch import is_tensor
24
20
  from audmetric import accuracy
@@ -186,6 +182,7 @@ class Reporter:
186
182
  if not file_name.endswith(".csv"):
187
183
  file_name = file_name + ".csv"
188
184
  self.probas = probas
185
+ self.plot_proba_conf()
189
186
  probas.to_csv(file_name)
190
187
  self.util.debug(f"Saved probabilities to {file_name}")
191
188
  plots = Plots()
@@ -196,10 +193,27 @@ class Reporter:
196
193
  ax,
197
194
  caption,
198
195
  "Uncertainty",
199
- "uncertainty_samples",
196
+ "uncertainty",
200
197
  "samples",
201
198
  )
202
199
 
200
+ def plot_proba_conf(self):
201
+ uncertainty_threshold = self.util.config_val("PLOT", "uncertainty_threshold", False)
202
+ if uncertainty_threshold:
203
+ uncertainty_threshold = float(uncertainty_threshold)
204
+ old_size = self.probas.shape[0]
205
+ df = self.probas[self.probas["uncertainty"] < uncertainty_threshold]
206
+ new_size = df.shape[0]
207
+ difference = old_size - new_size
208
+ self.util.debug(
209
+ f"Filtered probabilities: {old_size} -> {new_size} ({difference}) samples with uncertainty < {uncertainty_threshold}"
210
+ )
211
+ truths = df["truth"].values
212
+ preds = df["predicted"].values
213
+ self._plot_confmat(truths, preds, f"uncertainty_less_than_{uncertainty_threshold}_cnf",
214
+ epoch=None, test_result=None)
215
+
216
+
203
217
  def set_id(self, run, epoch):
204
218
  """Make the report identifiable with run and epoch index."""
205
219
  self.run = run
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nkululeko
3
- Version: 0.95.5
3
+ Version: 0.95.7
4
4
  Summary: Machine learning audio prediction experiments based on templates
5
5
  Home-page: https://github.com/felixbur/nkululeko
6
6
  Author: Felix Burkhardt
@@ -4,14 +4,14 @@ nkululeko/aug_train.py,sha256=wpiHCJ7zsW38kumg3ypwXZe2HQrhUblAnv7P2QeJnAc,3525
4
4
  nkululeko/augment.py,sha256=3RzaxB3gRxovgJVjHXi0glprW01J7RaHhUkqotW2T3U,2955
5
5
  nkululeko/balance.py,sha256=r7opXbrqAipm2euPPaOmLlA5J10p2bHQgO5kWk2x9ro,8702
6
6
  nkululeko/cacheddataset.py,sha256=XFpWZmbJRg0pvhnIgYf0TkclxllD-Fctu-Ol0PF_00c,969
7
- nkululeko/constants.py,sha256=uY1Jr5zRXhQbcZ07E355HAsT4h-soeECnBVXSukC-wY,39
7
+ nkululeko/constants.py,sha256=6jfPRCrnqqRsGqz83bT34_5gPBbTiIAsnhzVWUrKXl4,39
8
8
  nkululeko/demo-ft.py,sha256=iD9Pzp9QjyAv31q1cDZ75vPez7Ve8A4Cfukv5yfZdrQ,770
9
9
  nkululeko/demo.py,sha256=tu7Al2l5MCLVegkDC-NE2wcuc_YE7NRbgOlPW3yhGEs,4940
10
10
  nkululeko/demo_feats.py,sha256=BvZjeNFTlERIRlq34OHM4Z96jdDQAhB01BGQAUcX9dM,2026
11
11
  nkululeko/demo_predictor.py,sha256=lDF-xOxRdEAclOmbepAYg-BQXQdGkHfq2n74PTIoop8,4872
12
12
  nkululeko/ensemble.py,sha256=71V-rre61H3J4sh7lu-OTo4I2_g7mm_rQxwW1ARDHgY,12782
13
13
  nkululeko/experiment.py,sha256=BAc220lktt_tvifl-m-ZIPO7Nwi-HzDBNyTfjPDbQkE,38397
14
- nkululeko/explore.py,sha256=aDVHwuo-lkih7VZrbb_zFKg5fowSrAIcx0V9wf0SRGo,4175
14
+ nkululeko/explore.py,sha256=PjNcLuPdvWqCqYXUvGhd0hBijIhzdyi3ED1RF6o5Gjk,4212
15
15
  nkululeko/export.py,sha256=U-V4acxtuL6qKt6oAsVcM5TTeWogYUJ3GU-lA6rq6d4,4336
16
16
  nkululeko/feature_extractor.py,sha256=CsKmBoxwNClRGu20ox_eCxMG4u_1OH8Y83FYw7GfUwA,4230
17
17
  nkululeko/file_checker.py,sha256=xJY0Q6w47pnmgJVK5rcAKPYBrCpV7eBT4_3YBzTx-H8,3454
@@ -24,7 +24,7 @@ nkululeko/nkuluflag.py,sha256=_83LqLr2bSHjnVJuPeSAHCIyuiIbRxgpFKW6CwanWFM,3728
24
24
  nkululeko/nkululeko.py,sha256=6ALPMMIz6l0O3IRaP0q4b59ZUxpfzNqLQUqZMf5t3Zo,1976
25
25
  nkululeko/optim.py,sha256=Pn_02irXYJJmNG1yWA9GImHirpbXXywV61MalZb2wVA,1658
26
26
  nkululeko/optimizationrunner.py,sha256=UfWU_gOPaHUVjvYaw3AoF9HoDGYxIjbCyTGmi1PVu3s,44283
27
- nkululeko/plots.py,sha256=lUxgyoriYTwdpHZvBBQ4e41v77deQrt0PcRDLJWijys,27503
27
+ nkululeko/plots.py,sha256=rVkOGWB7yLkZ1dGg_MXeKhPOtiquiYIyCam4KYOdJQY,27519
28
28
  nkululeko/predict.py,sha256=PWv1Pc39lrxqqIWrYszVk5SL37dDL93CHgcruItNID8,2211
29
29
  nkululeko/resample.py,sha256=rn3-M1A-iwVGibfQNGyeYNa7briD24lIN9Szq_1uTJo,5194
30
30
  nkululeko/runmanager.py,sha256=YtGQP0UyyQTKkilncB1XYM-T8oatzGcZEOcj5SorjJw,8902
@@ -58,7 +58,7 @@ nkululeko/autopredict/whisper_transcriber.py,sha256=DWDvpRaV5KmUF18ojPEvxnVXm_h_
58
58
  nkululeko/autopredict/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
59
59
  nkululeko/autopredict/tests/test_whisper_transcriber.py,sha256=ilas6j3OUvq_xnQCRZgytQCtyrpNU6tvG5a8kPvVKBQ,5085
60
60
  nkululeko/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
61
- nkululeko/data/dataset.py,sha256=2-I7d91M9f296SANYiL4eTZmcXKs-nj1vqsUEXpp-cA,42461
61
+ nkululeko/data/dataset.py,sha256=uj4rtcAoiEUpoZv8dlgrdzBuUdFrXtU7Pai6wSHY2xU,42997
62
62
  nkululeko/data/dataset_csv.py,sha256=AIbtB6pGk5BSQGIgfokZ7tEGFjmuOq5w2XumRSimVWs,4833
63
63
  nkululeko/feat_extract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
64
64
  nkululeko/feat_extract/feats_agender.py,sha256=onfAQ6-xx_mFMJXEF1IX8cHBmGtGeX6weJmxbkfh1_o,3184
@@ -120,7 +120,7 @@ nkululeko/reporting/defines.py,sha256=0vh-Tlx4fAPpk1o6mP_4x3EkIoqzYMr38IZnj-JM5z
120
120
  nkululeko/reporting/latex_writer.py,sha256=NGwSIfd4nfslDkNUOSZSdqY_VDLA8634thyhe-vj1bY,1824
121
121
  nkululeko/reporting/report.py,sha256=B5eoIKMz46VKDBsi7M9u_iegzAD-E3eGCmolzSFjZ3c,1118
122
122
  nkululeko/reporting/report_item.py,sha256=drkknsyFhGviaPJNmPQtCXJmRhTSSfjNcJt0Bls6JAA,533
123
- nkululeko/reporting/reporter.py,sha256=awBaewERa8xSQtZ0c1KVAQhV77L-BvXSDyU959hQ6qU,21150
123
+ nkululeko/reporting/reporter.py,sha256=ITxM5O9Hoe_1z_59g-GF4b9vciR4shokZxeFzCrDaag,21869
124
124
  nkululeko/reporting/result.py,sha256=G63a2tHCwHhM6NBJgYzsWKWJm4Yu3r4hsCHA2Km7eHU,1073
125
125
  nkululeko/segmenting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
126
126
  nkululeko/segmenting/seg_inaspeechsegmenter.py,sha256=b3t0zdpJYofKWMyKRMtMMX91xeR-k8d5pbnNaQHcsOE,1902
@@ -134,9 +134,9 @@ nkululeko/utils/files.py,sha256=SrrYaU7AB80MZHiV1jcB0h_zigvYLYgSVNTXV4ao38g,4593
134
134
  nkululeko/utils/stats.py,sha256=3Fyx8q8BSKYmiufT6OkRug9RATWmGrr9BaX_y8jziWo,3074
135
135
  nkululeko/utils/unzip.py,sha256=G68f5120TjwACZC3bQcneMniddnwubPbBdMc2L5KBOo,1206
136
136
  nkululeko/utils/util.py,sha256=yHgzfj-8ncgCvyrrrH_NDWCh6VmhAqVYY6Vlgyg-c6E,18585
137
- nkululeko-0.95.5.dist-info/licenses/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
138
- nkululeko-0.95.5.dist-info/METADATA,sha256=jt5I1QkdFV2jSe6vVu4SgJ5Ptlc9GMvKSpxRCpc9Awk,21998
139
- nkululeko-0.95.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
140
- nkululeko-0.95.5.dist-info/entry_points.txt,sha256=lNTkFEdh6Kjo5o95ZAWf_0Lq-4ztGoAoMVSDuPtuyS0,442
141
- nkululeko-0.95.5.dist-info/top_level.txt,sha256=bf1k1YKkqcXemNX_cUgoyKqQ3_GVErPqAY-53J36jkM,19
142
- nkululeko-0.95.5.dist-info/RECORD,,
137
+ nkululeko-0.95.7.dist-info/licenses/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
138
+ nkululeko-0.95.7.dist-info/METADATA,sha256=s_XLh9XUEm_NRApCwnUc8QKkRHWqlva7yxY8Jce0vSI,21998
139
+ nkululeko-0.95.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
140
+ nkululeko-0.95.7.dist-info/entry_points.txt,sha256=lNTkFEdh6Kjo5o95ZAWf_0Lq-4ztGoAoMVSDuPtuyS0,442
141
+ nkululeko-0.95.7.dist-info/top_level.txt,sha256=bf1k1YKkqcXemNX_cUgoyKqQ3_GVErPqAY-53J36jkM,19
142
+ nkululeko-0.95.7.dist-info/RECORD,,