nkululeko 0.86.5__py3-none-any.whl → 0.86.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nkululeko/constants.py CHANGED
@@ -1,2 +1,2 @@
1
- VERSION="0.86.5"
1
+ VERSION="0.86.7"
2
2
  SAMPLING_RATE = 16000
nkululeko/data/dataset.py CHANGED
@@ -150,6 +150,13 @@ class Dataset:
150
150
  self.got_speaker = got_speaker2 or self.got_speaker
151
151
  self.got_gender = got_gender2 or self.got_gender
152
152
  self.got_age = got_age2 or self.got_age
153
+ if audformat.is_filewise_index(df_target.index):
154
+ try:
155
+ df_target = df_target.loc[df.index.get_level_values("file")]
156
+ df_target = df_target.set_index(df.index)
157
+ except KeyError:
158
+ # best-effort re-indexing: ignore a failed lookup and carry on
159
+ pass
153
160
  if got_target2:
154
161
  df[self.target] = df_target[self.target]
155
162
  if got_speaker2:
@@ -255,6 +262,7 @@ class Dataset:
255
262
  df = pd.DataFrame()
256
263
  for table in df_files:
257
264
  source_df = db.tables[table].df
265
+ # check if columns should be renamed
258
266
  source_df = self._check_cols(source_df)
259
267
  # create a dataframe with the index (the filenames)
260
268
  df_local = pd.DataFrame(index=source_df.index)
nkululeko/modelrunner.py CHANGED
@@ -30,6 +30,8 @@ class Modelrunner:
30
30
  # initialize a new model
31
31
  model_type = glob_conf.config["MODEL"]["type"]
32
32
  self._select_model(model_type)
33
+ self.best_performance = 0
34
+ self.best_epoch = 0
33
35
 
34
36
  def do_epochs(self):
35
37
  # initialize results
@@ -51,7 +53,8 @@ class Modelrunner:
51
53
  # epochs are handled by Huggingface API
52
54
  self.model.train()
53
55
  report = self.model.predict()
54
- # todo: findout the best epoch
56
+ # todo: find out the best epoch, no need
57
+ # since load_best_model_at_end is given in training args
55
58
  epoch = epoch_num
56
59
  report.set_id(self.run, epoch)
57
60
  plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
@@ -77,10 +80,15 @@ class Modelrunner:
77
80
  report.set_id(self.run, epoch)
78
81
  plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
79
82
  reports.append(report)
83
+ test_score_metric = report.get_result().get_test_result()
80
84
  self.util.debug(
81
- f"run: {self.run} epoch: {epoch}: result: "
82
- f"{reports[-1].get_result().get_test_result()}"
85
+ f"run: {self.run} epoch: {epoch}: result: {test_score_metric}"
83
86
  )
87
+ # print(f"performance: {performance.split(' ')[1]}")
88
+ performance = float(test_score_metric.split(' ')[1])
89
+ if performance > self.best_performance:
90
+ self.best_performance = performance
91
+ self.best_epoch = epoch
84
92
  if plot_epochs:
85
93
  self.util.debug(f"plotting conf matrix to {plot_name}")
86
94
  report.plot_confmatrix(plot_name, epoch)
@@ -110,11 +118,14 @@ class Modelrunner:
110
118
  f"reached patience ({str(patience)}): early stopping"
111
119
  )
112
120
  break
121
+ # After training, report the best performance and epoch
122
+ best_report = reports[self.best_epoch]
123
+ # self.util.debug(f"Best score at epoch: {self.best_epoch}, UAR: {self.best_performance}") # move to reporter below
113
124
 
114
125
  if not plot_epochs:
115
126
  # Do at least one confusion matrix plot
116
127
  self.util.debug(f"plotting confusion matrix to {plot_name}")
117
- reports[-1].plot_confmatrix(plot_name, epoch)
128
+ best_report.plot_confmatrix(plot_name, self.best_epoch)
118
129
  return reports, epoch
119
130
 
120
131
  def _select_model(self, model_type):
@@ -122,7 +122,7 @@ class Reporter:
122
122
  self.truths = np.digitize(self.truths, bins) - 1
123
123
  self.preds = np.digitize(self.preds, bins) - 1
124
124
 
125
- def plot_confmatrix(self, plot_name, epoch):
125
+ def plot_confmatrix(self, plot_name, epoch=None):
126
126
  if not self.util.exp_is_classification():
127
127
  self.continuous_to_categorical()
128
128
  self._plot_confmat(self.truths, self.preds, plot_name, epoch)
@@ -156,9 +156,11 @@ class Reporter:
156
156
  pred = np.digitize(pred, bins) - 1
157
157
  self._plot_confmat(truth, pred.astype("int"), plot_name, 0)
158
158
 
159
- def _plot_confmat(self, truths, preds, plot_name, epoch):
159
+ def _plot_confmat(self, truths, preds, plot_name, epoch=None):
160
160
  # print(truths)
161
161
  # print(preds)
162
+ if epoch is None:
163
+ epoch = self.epoch
162
164
  fig_dir = self.util.get_path("fig_dir")
163
165
  labels = glob_conf.labels
164
166
  fig = plt.figure() # figsize=[5, 5]
@@ -225,7 +227,7 @@ class Reporter:
225
227
 
226
228
  res_dir = self.util.get_path("res_dir")
227
229
  rpt = (
228
- f"epoch: {epoch}, UAR: {uar_str}"
230
+ f"Best score at epoch: {epoch}, UAR: {uar_str}"
229
231
  + f", (+-{up_str}/{low_str}), ACC: {acc_str}"
230
232
  )
231
233
  # print(rpt)
@@ -237,7 +239,9 @@ class Reporter:
237
239
  def set_filename_add(self, my_string):
238
240
  self.filenameadd = f"_{my_string}"
239
241
 
240
- def print_results(self, epoch):
242
+ def print_results(self, epoch=None):
243
+ if epoch is None:
244
+ epoch = self.epoch
241
245
  """Print all evaluation values to text file."""
242
246
  res_dir = self.util.get_path("res_dir")
243
247
  file_name = f"{res_dir}{self.util.get_exp_name()}_{epoch}{self.filenameadd}.txt"
@@ -262,12 +266,14 @@ class Reporter:
262
266
  c_res = rpt[l]["f1-score"]
263
267
  c_ress[i] = float(f"{c_res:.3f}")
264
268
  self.util.debug(f"labels: {labels}")
265
- f1_per_class = f"result per class (F1 score): {c_ress}"
269
+ f1_per_class = (
270
+ f"result per class (F1 score): {c_ress} from epoch: {epoch}"
271
+ )
266
272
  if len(np.unique(self.truths)) == 2:
267
273
  fpr, tpr, _ = roc_curve(self.truths, self.preds)
268
274
  auc_score = auc(fpr, tpr)
269
275
  pauc_score = roc_auc_score(self.truths, self.preds, max_fpr=0.1)
270
- auc_pauc = f"auc: {auc_score:.3f}, pauc: {pauc_score:.3f}"
276
+ auc_pauc = f"auc: {auc_score:.3f}, pauc: {pauc_score:.3f} from epoch: {epoch}"
271
277
  self.util.debug(auc_pauc)
272
278
  self.util.debug(f1_per_class)
273
279
  rpt_str = f"{json.dumps(rpt)}\n{f1_per_class}"
nkululeko/runmanager.py CHANGED
@@ -63,8 +63,7 @@ class Runmanager:
63
63
  )
64
64
  self.reports, last_epoch = self.modelrunner.do_epochs()
65
65
  # wrap up the run
66
- plot_anim_progression = self.util.config_val(
67
- "PLOT", "anim_progression", 0)
66
+ plot_anim_progression = self.util.config_val("PLOT", "anim_progression", 0)
68
67
  if plot_anim_progression:
69
68
  plot_name_suggest = self.util.get_exp_name()
70
69
  plot_name = (
@@ -88,8 +87,7 @@ class Runmanager:
88
87
  + "_epoch_progression"
89
88
  )
90
89
  self.util.debug(f"plotting progression to {plot_name}")
91
- self.reports[-1].plot_epoch_progression(
92
- self.reports, plot_name)
90
+ self.reports[-1].plot_epoch_progression(self.reports, plot_name)
93
91
  # remember the best run
94
92
  best_report = self.get_best_result(self.reports)
95
93
  plot_best_model = self.util.config_val("PLOT", "best_model", False)
@@ -107,9 +105,10 @@ class Runmanager:
107
105
  )
108
106
  self.print_model(best_report, plot_name)
109
107
  # finally, print out the numbers for this run
110
- self.reports[-1].print_results(
111
- int(self.util.config_val("EXP", "epochs", 1))
112
- )
108
+ # self.reports[-1].print_results(
109
+ # int(self.util.config_val("EXP", "epochs", 1))
110
+ # )
111
+ best_report.print_results(best_report.epoch)
113
112
  self.best_results.append(best_report)
114
113
  self.last_epochs.append(last_epoch)
115
114
 
@@ -145,19 +144,18 @@ class Runmanager:
145
144
  )
146
145
  self.print_model(report, plot_name)
147
146
 
148
- def print_model(self, report, plot_name):
147
+ def print_model(self, reporter, plot_name):
149
148
  """Print a confusion matrix for a special report.
150
149
 
151
150
  Args:
152
151
  reporter: for which report (will be computed newly from model)
153
152
  plot_name: name of plot file
154
153
  """
155
- epoch = report.epoch
156
154
  # self.load_model(report)
157
155
  # report = self.model.predict()
158
156
  self.util.debug(f"plotting conf matrix to {plot_name}")
159
- report.plot_confmatrix(plot_name, epoch)
160
- report.print_results(epoch)
157
+ reporter.plot_confmatrix(plot_name)
158
+ reporter.print_results()
161
159
 
162
160
  def load_model(self, report):
163
161
  """Load a model from disk for a specific run and epoch and evaluate it.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nkululeko
3
- Version: 0.86.5
3
+ Version: 0.86.7
4
4
  Summary: Machine learning audio prediction experiments based on templates
5
5
  Home-page: https://github.com/felixbur/nkululeko
6
6
  Author: Felix Burkhardt
@@ -343,6 +343,14 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
343
343
  Changelog
344
344
  =========
345
345
 
346
+ Version 0.86.7
347
+ --------------
348
+ * now handles audformat tables where the target is in a file index
349
+
350
+ Version 0.86.6
351
+ --------------
352
+ * now best (not last) result is shown at end
353
+
346
354
  Version 0.86.5
347
355
  --------------
348
356
  * fix audio path detection in data csv import
@@ -2,7 +2,7 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
2
2
  nkululeko/aug_train.py,sha256=YhuZnS_WVWnun9G-M6g5n6rbRxoVREz6Zh7k6qprFNQ,3194
3
3
  nkululeko/augment.py,sha256=4MG0apTAG5RgkuJrYEjGgDdbodZWi_HweSPNI1JJ5QA,3051
4
4
  nkululeko/cacheddataset.py,sha256=lIJ6hUo5LoxSrzXtWV8mzwO7wRtUETWnOQ4ws2XfL1E,969
5
- nkululeko/constants.py,sha256=ctptCGup_HGCOxioUojLqMivtVfYq8CZDLHJprDr9aE,39
5
+ nkululeko/constants.py,sha256=CscqJhC7nceHk2wmZd2bBFSeFExtr0HkXt99qpAZU4E,39
6
6
  nkululeko/demo.py,sha256=WSKr-W5uJ9DQfemK923g7Hd5V3kgAn03Er0JX1Pa45I,5142
7
7
  nkululeko/demo_feats.py,sha256=sAeGFojhEj9WEDFtG3SzPBmyYJWLF2rkbpp65m8Ujo4,2025
8
8
  nkululeko/demo_predictor.py,sha256=es56xbT8ifkS_vnrlb5NTZT54gNmeUtNlA4zVA_gnN8,4757
@@ -13,14 +13,14 @@ nkululeko/feature_extractor.py,sha256=8mssYKmo4LclVI-hiLmJEDZ0ZPyDavFG2YwtXcrGzw
13
13
  nkululeko/file_checker.py,sha256=LoLnL8aHpW-axMQ46qbqrManTs5otG9ShpEZuz9iRSk,3474
14
14
  nkululeko/filter_data.py,sha256=w-X2mhKdYr5DxDIz50E5yzO6Jmzk4jjDBoXsgOOVtcA,7222
15
15
  nkululeko/glob_conf.py,sha256=KL9YJQTHvTztxo1vr25qRRgaPnx4NTg0XrdbovKGMmw,525
16
- nkululeko/modelrunner.py,sha256=iCmfJxsS2UafcikjRdUqPQuqQMOYA-Ctr3et3HeNR3c,10452
16
+ nkululeko/modelrunner.py,sha256=OU35qwP94GxW_EtL4I2-RhqB-wxbjNvp8CIHNbtnt7Q,11155
17
17
  nkululeko/multidb.py,sha256=fG3VukEWP1vreVN4gB1IRXxwwg4jLftsSEYtu0o1f78,5634
18
18
  nkululeko/nkuluflag.py,sha256=PGWSmZz-PiiHLgcZJAoGOI_Y-sZDVI1ksB8p5r7riWM,3725
19
19
  nkululeko/nkululeko.py,sha256=Kn3s2E3yyH8cJ7z6lkMxrnqtCxTu7-qfe9Zr_ONTD5g,1968
20
20
  nkululeko/plots.py,sha256=C2mwQFK0Vxfl5ZM7CO87tULDoEf7G16ek0nU77bhOc4,23070
21
21
  nkululeko/predict.py,sha256=sF091sSSLnEWcISx9ZcULLie3tY5XeFsQJd6b3vrxFg,2409
22
22
  nkululeko/resample.py,sha256=2d9eao_0sLrGZ_KSl8OVKsPor3BkFrlmMhrpB9WelIs,4267
23
- nkululeko/runmanager.py,sha256=eTM1DNQKt1lxYhzt4vZyZluPXW9sWlIJHNQzex4lkJU,7624
23
+ nkululeko/runmanager.py,sha256=Na8oPn59lRFiNMsYChRHBRgw40mBcw0Rwl2Kz1RUsA0,7614
24
24
  nkululeko/scaler.py,sha256=4nkIqoajkIkuTPK0Z02ifMN_awl6fP_i-GBYdoGYgGM,4101
25
25
  nkululeko/segment.py,sha256=YLKckX44tbvTb3LrdgYw9X4guzuF27sutl92z9DkpZU,4835
26
26
  nkululeko/syllable_nuclei.py,sha256=Sky-C__MeUDaxqHnDl2TGLLYOYvsahD35TUjWGeG31k,10047
@@ -45,7 +45,7 @@ nkululeko/autopredict/ap_stoi.py,sha256=It0Lk-ki-gohA2AzD8nkLAN2WahYvD9rPDGTQuvd
45
45
  nkululeko/autopredict/ap_valence.py,sha256=n-hctRKySzhmJtowuMOTUu0T_ld3uK5pnfOzWeWW4VM,1024
46
46
  nkululeko/autopredict/estimate_snr.py,sha256=S-bpS0xFkwWc4Ch75UrjbS8y538lQ0U3g_iLRFXureY,5048
47
47
  nkululeko/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
48
- nkululeko/data/dataset.py,sha256=JGzMD6HIvkFkYBekmbmslIKc5ADaCj06T-8gpqH_kFo,27650
48
+ nkululeko/data/dataset.py,sha256=hUD0NqWCfRaSHG8JNs1MsPb0zjUZAf8FJkg_c0ebq0Q,28046
49
49
  nkululeko/data/dataset_csv.py,sha256=dzOrbKB8t0UATAIYaKAOqHTogmYPBqskt6Hak7VjbSM,4537
50
50
  nkululeko/feat_extract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
51
51
  nkululeko/feat_extract/feats_agender.py,sha256=Qm69G4kqAyTVVk7wwRgrXlNwGaDMGRYyKGpuf0vOEgM,3113
@@ -96,7 +96,7 @@ nkululeko/reporting/defines.py,sha256=IsY1YgKRMaABpylVKjBJgJ5bNCEbGCVA_E6pivraqS
96
96
  nkululeko/reporting/latex_writer.py,sha256=qiCRSmB4KOD_za4oHu5x-PhwjZohzfo8wecMOwlXZwc,1886
97
97
  nkululeko/reporting/report.py,sha256=W0rcigDdjBvxZQ3pZja_gvToILYvaZ1BFtnN2qFRfYI,1060
98
98
  nkululeko/reporting/report_item.py,sha256=siWeGNgo4bAE46YBMNcsdf3jTMTy76BO9Fi6DTvDig4,533
99
- nkululeko/reporting/reporter.py,sha256=II3QyeneAv8xQDBZ-qE_GJL8_WV_yXqLwBUYqrjqwPo,13938
99
+ nkululeko/reporting/reporter.py,sha256=S9A62AxdMTEV-9XDUQNxdoevGLXBP52WiDmZ694QMV4,14161
100
100
  nkululeko/reporting/result.py,sha256=nSN5or-Py2GPRWHkWpGRh7UCi1W0er7WLEHz8fYLk-A,742
101
101
  nkululeko/segmenting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
102
102
  nkululeko/segmenting/seg_inaspeechsegmenter.py,sha256=pmLHuXsaqvcdYxB4PSW9l1mbQWZZBJFhi_CGabqydas,1947
@@ -105,8 +105,8 @@ nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
105
105
  nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
106
106
  nkululeko/utils/stats.py,sha256=1yUq0FTOyqkU8TwUocJRYdJaqMU5SlOBBRUun9STo2M,2829
107
107
  nkululeko/utils/util.py,sha256=ILpfNuaeq-hy1bUkRhVrzO2wG9z9Upaozs9EBoIaMG0,14123
108
- nkululeko-0.86.5.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
109
- nkululeko-0.86.5.dist-info/METADATA,sha256=HrTVTfGh3KDsmyBFijAp5tMINdiBvHhsC8E0_YwBjwE,37848
110
- nkululeko-0.86.5.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
111
- nkululeko-0.86.5.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
112
- nkululeko-0.86.5.dist-info/RECORD,,
108
+ nkululeko-0.86.7.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
109
+ nkululeko-0.86.7.dist-info/METADATA,sha256=t5cI43YRp3qmyJj03ACfgCbKoAuLYImDCLS1QkYbMQM,38024
110
+ nkululeko-0.86.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
111
+ nkululeko-0.86.7.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
112
+ nkululeko-0.86.7.dist-info/RECORD,,