nkululeko 0.86.5__py3-none-any.whl → 0.86.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nkululeko/constants.py +1 -1
- nkululeko/data/dataset.py +8 -0
- nkululeko/modelrunner.py +15 -4
- nkululeko/reporting/reporter.py +12 -6
- nkululeko/runmanager.py +9 -11
- {nkululeko-0.86.5.dist-info → nkululeko-0.86.7.dist-info}/METADATA +9 -1
- {nkululeko-0.86.5.dist-info → nkululeko-0.86.7.dist-info}/RECORD +10 -10
- {nkululeko-0.86.5.dist-info → nkululeko-0.86.7.dist-info}/LICENSE +0 -0
- {nkululeko-0.86.5.dist-info → nkululeko-0.86.7.dist-info}/WHEEL +0 -0
- {nkululeko-0.86.5.dist-info → nkululeko-0.86.7.dist-info}/top_level.txt +0 -0
nkululeko/constants.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
|
-
VERSION="0.86.5"
|
1
|
+
VERSION="0.86.7"
|
2
2
|
SAMPLING_RATE = 16000
|
nkululeko/data/dataset.py
CHANGED
@@ -150,6 +150,13 @@ class Dataset:
|
|
150
150
|
self.got_speaker = got_speaker2 or self.got_speaker
|
151
151
|
self.got_gender = got_gender2 or self.got_gender
|
152
152
|
self.got_age = got_age2 or self.got_age
|
153
|
+
if audformat.is_filewise_index(df_target.index):
|
154
|
+
try:
|
155
|
+
df_target = df_target.loc[df.index.get_level_values("file")]
|
156
|
+
df_target = df_target.set_index(df.index)
|
157
|
+
except KeyError:
|
158
|
+
# just a try...
|
159
|
+
pass
|
153
160
|
if got_target2:
|
154
161
|
df[self.target] = df_target[self.target]
|
155
162
|
if got_speaker2:
|
@@ -255,6 +262,7 @@ class Dataset:
|
|
255
262
|
df = pd.DataFrame()
|
256
263
|
for table in df_files:
|
257
264
|
source_df = db.tables[table].df
|
265
|
+
# check if columns should be renamed
|
258
266
|
source_df = self._check_cols(source_df)
|
259
267
|
# create a dataframe with the index (the filenames)
|
260
268
|
df_local = pd.DataFrame(index=source_df.index)
|
nkululeko/modelrunner.py
CHANGED
@@ -30,6 +30,8 @@ class Modelrunner:
|
|
30
30
|
# intialize a new model
|
31
31
|
model_type = glob_conf.config["MODEL"]["type"]
|
32
32
|
self._select_model(model_type)
|
33
|
+
self.best_performance = 0
|
34
|
+
self.best_epoch = 0
|
33
35
|
|
34
36
|
def do_epochs(self):
|
35
37
|
# initialze results
|
@@ -51,7 +53,8 @@ class Modelrunner:
|
|
51
53
|
# epochs are handled by Huggingface API
|
52
54
|
self.model.train()
|
53
55
|
report = self.model.predict()
|
54
|
-
# todo: findout the best epoch
|
56
|
+
# todo: findout the best epoch, no need
|
57
|
+
# since load_best_model_at_end is given in training args
|
55
58
|
epoch = epoch_num
|
56
59
|
report.set_id(self.run, epoch)
|
57
60
|
plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
|
@@ -77,10 +80,15 @@ class Modelrunner:
|
|
77
80
|
report.set_id(self.run, epoch)
|
78
81
|
plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
|
79
82
|
reports.append(report)
|
83
|
+
test_score_metric = report.get_result().get_test_result()
|
80
84
|
self.util.debug(
|
81
|
-
f"run: {self.run} epoch: {epoch}: result: "
|
82
|
-
f"{reports[-1].get_result().get_test_result()}"
|
85
|
+
f"run: {self.run} epoch: {epoch}: result: {test_score_metric}"
|
83
86
|
)
|
87
|
+
# print(f"performance: {performance.split(' ')[1]}")
|
88
|
+
performance = float(test_score_metric.split(' ')[1])
|
89
|
+
if performance > self.best_performance:
|
90
|
+
self.best_performance = performance
|
91
|
+
self.best_epoch = epoch
|
84
92
|
if plot_epochs:
|
85
93
|
self.util.debug(f"plotting conf matrix to {plot_name}")
|
86
94
|
report.plot_confmatrix(plot_name, epoch)
|
@@ -110,11 +118,14 @@ class Modelrunner:
|
|
110
118
|
f"reached patience ({str(patience)}): early stopping"
|
111
119
|
)
|
112
120
|
break
|
121
|
+
# After training, report the best performance and epoch
|
122
|
+
best_report = reports[self.best_epoch]
|
123
|
+
# self.util.debug(f"Best score at epoch: {self.best_epoch}, UAR: {self.best_performance}") # move to reporter below
|
113
124
|
|
114
125
|
if not plot_epochs:
|
115
126
|
# Do at least one confusion matrix plot
|
116
127
|
self.util.debug(f"plotting confusion matrix to {plot_name}")
|
117
|
-
|
128
|
+
best_report.plot_confmatrix(plot_name, self.best_epoch)
|
118
129
|
return reports, epoch
|
119
130
|
|
120
131
|
def _select_model(self, model_type):
|
nkululeko/reporting/reporter.py
CHANGED
@@ -122,7 +122,7 @@ class Reporter:
|
|
122
122
|
self.truths = np.digitize(self.truths, bins) - 1
|
123
123
|
self.preds = np.digitize(self.preds, bins) - 1
|
124
124
|
|
125
|
-
def plot_confmatrix(self, plot_name, epoch):
|
125
|
+
def plot_confmatrix(self, plot_name, epoch=None):
|
126
126
|
if not self.util.exp_is_classification():
|
127
127
|
self.continuous_to_categorical()
|
128
128
|
self._plot_confmat(self.truths, self.preds, plot_name, epoch)
|
@@ -156,9 +156,11 @@ class Reporter:
|
|
156
156
|
pred = np.digitize(pred, bins) - 1
|
157
157
|
self._plot_confmat(truth, pred.astype("int"), plot_name, 0)
|
158
158
|
|
159
|
-
def _plot_confmat(self, truths, preds, plot_name, epoch):
|
159
|
+
def _plot_confmat(self, truths, preds, plot_name, epoch=None):
|
160
160
|
# print(truths)
|
161
161
|
# print(preds)
|
162
|
+
if epoch is None:
|
163
|
+
epoch = self.epoch
|
162
164
|
fig_dir = self.util.get_path("fig_dir")
|
163
165
|
labels = glob_conf.labels
|
164
166
|
fig = plt.figure() # figsize=[5, 5]
|
@@ -225,7 +227,7 @@ class Reporter:
|
|
225
227
|
|
226
228
|
res_dir = self.util.get_path("res_dir")
|
227
229
|
rpt = (
|
228
|
-
f"epoch: {epoch}, UAR: {uar_str}"
|
230
|
+
f"Best score at epoch: {epoch}, UAR: {uar_str}"
|
229
231
|
+ f", (+-{up_str}/{low_str}), ACC: {acc_str}"
|
230
232
|
)
|
231
233
|
# print(rpt)
|
@@ -237,7 +239,9 @@ class Reporter:
|
|
237
239
|
def set_filename_add(self, my_string):
|
238
240
|
self.filenameadd = f"_{my_string}"
|
239
241
|
|
240
|
-
def print_results(self, epoch):
|
242
|
+
def print_results(self, epoch=None):
|
243
|
+
if epoch is None:
|
244
|
+
epoch = self.epoch
|
241
245
|
"""Print all evaluation values to text file."""
|
242
246
|
res_dir = self.util.get_path("res_dir")
|
243
247
|
file_name = f"{res_dir}{self.util.get_exp_name()}_{epoch}{self.filenameadd}.txt"
|
@@ -262,12 +266,14 @@ class Reporter:
|
|
262
266
|
c_res = rpt[l]["f1-score"]
|
263
267
|
c_ress[i] = float(f"{c_res:.3f}")
|
264
268
|
self.util.debug(f"labels: {labels}")
|
265
|
-
f1_per_class =
|
269
|
+
f1_per_class = (
|
270
|
+
f"result per class (F1 score): {c_ress} from epoch: {epoch}"
|
271
|
+
)
|
266
272
|
if len(np.unique(self.truths)) == 2:
|
267
273
|
fpr, tpr, _ = roc_curve(self.truths, self.preds)
|
268
274
|
auc_score = auc(fpr, tpr)
|
269
275
|
pauc_score = roc_auc_score(self.truths, self.preds, max_fpr=0.1)
|
270
|
-
auc_pauc = f"auc: {auc_score:.3f}, pauc: {pauc_score:.3f}"
|
276
|
+
auc_pauc = f"auc: {auc_score:.3f}, pauc: {pauc_score:.3f} from epoch: {epoch}"
|
271
277
|
self.util.debug(auc_pauc)
|
272
278
|
self.util.debug(f1_per_class)
|
273
279
|
rpt_str = f"{json.dumps(rpt)}\n{f1_per_class}"
|
nkululeko/runmanager.py
CHANGED
@@ -63,8 +63,7 @@ class Runmanager:
|
|
63
63
|
)
|
64
64
|
self.reports, last_epoch = self.modelrunner.do_epochs()
|
65
65
|
# wrap up the run
|
66
|
-
plot_anim_progression = self.util.config_val(
|
67
|
-
"PLOT", "anim_progression", 0)
|
66
|
+
plot_anim_progression = self.util.config_val("PLOT", "anim_progression", 0)
|
68
67
|
if plot_anim_progression:
|
69
68
|
plot_name_suggest = self.util.get_exp_name()
|
70
69
|
plot_name = (
|
@@ -88,8 +87,7 @@ class Runmanager:
|
|
88
87
|
+ "_epoch_progression"
|
89
88
|
)
|
90
89
|
self.util.debug(f"plotting progression to {plot_name}")
|
91
|
-
self.reports[-1].plot_epoch_progression(
|
92
|
-
self.reports, plot_name)
|
90
|
+
self.reports[-1].plot_epoch_progression(self.reports, plot_name)
|
93
91
|
# remember the best run
|
94
92
|
best_report = self.get_best_result(self.reports)
|
95
93
|
plot_best_model = self.util.config_val("PLOT", "best_model", False)
|
@@ -107,9 +105,10 @@ class Runmanager:
|
|
107
105
|
)
|
108
106
|
self.print_model(best_report, plot_name)
|
109
107
|
# finally, print out the numbers for this run
|
110
|
-
self.reports[-1].print_results(
|
111
|
-
|
112
|
-
)
|
108
|
+
# self.reports[-1].print_results(
|
109
|
+
# int(self.util.config_val("EXP", "epochs", 1))
|
110
|
+
# )
|
111
|
+
best_report.print_results(best_report.epoch)
|
113
112
|
self.best_results.append(best_report)
|
114
113
|
self.last_epochs.append(last_epoch)
|
115
114
|
|
@@ -145,19 +144,18 @@ class Runmanager:
|
|
145
144
|
)
|
146
145
|
self.print_model(report, plot_name)
|
147
146
|
|
148
|
-
def print_model(self,
|
147
|
+
def print_model(self, reporter, plot_name):
|
149
148
|
"""Print a confusion matrix for a special report.
|
150
149
|
|
151
150
|
Args:
|
152
151
|
report: for which report (will be computed newly from model)
|
153
152
|
plot_name: name of plot file
|
154
153
|
"""
|
155
|
-
epoch = report.epoch
|
156
154
|
# self.load_model(report)
|
157
155
|
# report = self.model.predict()
|
158
156
|
self.util.debug(f"plotting conf matrix to {plot_name}")
|
159
|
-
|
160
|
-
|
157
|
+
reporter.plot_confmatrix(plot_name)
|
158
|
+
reporter.print_results()
|
161
159
|
|
162
160
|
def load_model(self, report):
|
163
161
|
"""Load a model from disk for a specific run and epoch and evaluate it.
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: nkululeko
|
3
|
-
Version: 0.86.5
|
3
|
+
Version: 0.86.7
|
4
4
|
Summary: Machine learning audio prediction experiments based on templates
|
5
5
|
Home-page: https://github.com/felixbur/nkululeko
|
6
6
|
Author: Felix Burkhardt
|
@@ -343,6 +343,14 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
|
|
343
343
|
Changelog
|
344
344
|
=========
|
345
345
|
|
346
|
+
Version 0.86.7
|
347
|
+
--------------
|
348
|
+
* now handles audformat tables where the target is in a filewise index
|
349
|
+
|
350
|
+
Version 0.86.6
|
351
|
+
--------------
|
352
|
+
* the best (not the last) result is now shown at the end
|
353
|
+
|
346
354
|
Version 0.86.5
|
347
355
|
--------------
|
348
356
|
* fix audio path detection in data csv import
|
@@ -2,7 +2,7 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
|
|
2
2
|
nkululeko/aug_train.py,sha256=YhuZnS_WVWnun9G-M6g5n6rbRxoVREz6Zh7k6qprFNQ,3194
|
3
3
|
nkululeko/augment.py,sha256=4MG0apTAG5RgkuJrYEjGgDdbodZWi_HweSPNI1JJ5QA,3051
|
4
4
|
nkululeko/cacheddataset.py,sha256=lIJ6hUo5LoxSrzXtWV8mzwO7wRtUETWnOQ4ws2XfL1E,969
|
5
|
-
nkululeko/constants.py,sha256=
|
5
|
+
nkululeko/constants.py,sha256=CscqJhC7nceHk2wmZd2bBFSeFExtr0HkXt99qpAZU4E,39
|
6
6
|
nkululeko/demo.py,sha256=WSKr-W5uJ9DQfemK923g7Hd5V3kgAn03Er0JX1Pa45I,5142
|
7
7
|
nkululeko/demo_feats.py,sha256=sAeGFojhEj9WEDFtG3SzPBmyYJWLF2rkbpp65m8Ujo4,2025
|
8
8
|
nkululeko/demo_predictor.py,sha256=es56xbT8ifkS_vnrlb5NTZT54gNmeUtNlA4zVA_gnN8,4757
|
@@ -13,14 +13,14 @@ nkululeko/feature_extractor.py,sha256=8mssYKmo4LclVI-hiLmJEDZ0ZPyDavFG2YwtXcrGzw
|
|
13
13
|
nkululeko/file_checker.py,sha256=LoLnL8aHpW-axMQ46qbqrManTs5otG9ShpEZuz9iRSk,3474
|
14
14
|
nkululeko/filter_data.py,sha256=w-X2mhKdYr5DxDIz50E5yzO6Jmzk4jjDBoXsgOOVtcA,7222
|
15
15
|
nkululeko/glob_conf.py,sha256=KL9YJQTHvTztxo1vr25qRRgaPnx4NTg0XrdbovKGMmw,525
|
16
|
-
nkululeko/modelrunner.py,sha256=
|
16
|
+
nkululeko/modelrunner.py,sha256=OU35qwP94GxW_EtL4I2-RhqB-wxbjNvp8CIHNbtnt7Q,11155
|
17
17
|
nkululeko/multidb.py,sha256=fG3VukEWP1vreVN4gB1IRXxwwg4jLftsSEYtu0o1f78,5634
|
18
18
|
nkululeko/nkuluflag.py,sha256=PGWSmZz-PiiHLgcZJAoGOI_Y-sZDVI1ksB8p5r7riWM,3725
|
19
19
|
nkululeko/nkululeko.py,sha256=Kn3s2E3yyH8cJ7z6lkMxrnqtCxTu7-qfe9Zr_ONTD5g,1968
|
20
20
|
nkululeko/plots.py,sha256=C2mwQFK0Vxfl5ZM7CO87tULDoEf7G16ek0nU77bhOc4,23070
|
21
21
|
nkululeko/predict.py,sha256=sF091sSSLnEWcISx9ZcULLie3tY5XeFsQJd6b3vrxFg,2409
|
22
22
|
nkululeko/resample.py,sha256=2d9eao_0sLrGZ_KSl8OVKsPor3BkFrlmMhrpB9WelIs,4267
|
23
|
-
nkululeko/runmanager.py,sha256=
|
23
|
+
nkululeko/runmanager.py,sha256=Na8oPn59lRFiNMsYChRHBRgw40mBcw0Rwl2Kz1RUsA0,7614
|
24
24
|
nkululeko/scaler.py,sha256=4nkIqoajkIkuTPK0Z02ifMN_awl6fP_i-GBYdoGYgGM,4101
|
25
25
|
nkululeko/segment.py,sha256=YLKckX44tbvTb3LrdgYw9X4guzuF27sutl92z9DkpZU,4835
|
26
26
|
nkululeko/syllable_nuclei.py,sha256=Sky-C__MeUDaxqHnDl2TGLLYOYvsahD35TUjWGeG31k,10047
|
@@ -45,7 +45,7 @@ nkululeko/autopredict/ap_stoi.py,sha256=It0Lk-ki-gohA2AzD8nkLAN2WahYvD9rPDGTQuvd
|
|
45
45
|
nkululeko/autopredict/ap_valence.py,sha256=n-hctRKySzhmJtowuMOTUu0T_ld3uK5pnfOzWeWW4VM,1024
|
46
46
|
nkululeko/autopredict/estimate_snr.py,sha256=S-bpS0xFkwWc4Ch75UrjbS8y538lQ0U3g_iLRFXureY,5048
|
47
47
|
nkululeko/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
48
|
-
nkululeko/data/dataset.py,sha256=
|
48
|
+
nkululeko/data/dataset.py,sha256=hUD0NqWCfRaSHG8JNs1MsPb0zjUZAf8FJkg_c0ebq0Q,28046
|
49
49
|
nkululeko/data/dataset_csv.py,sha256=dzOrbKB8t0UATAIYaKAOqHTogmYPBqskt6Hak7VjbSM,4537
|
50
50
|
nkululeko/feat_extract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
51
51
|
nkululeko/feat_extract/feats_agender.py,sha256=Qm69G4kqAyTVVk7wwRgrXlNwGaDMGRYyKGpuf0vOEgM,3113
|
@@ -96,7 +96,7 @@ nkululeko/reporting/defines.py,sha256=IsY1YgKRMaABpylVKjBJgJ5bNCEbGCVA_E6pivraqS
|
|
96
96
|
nkululeko/reporting/latex_writer.py,sha256=qiCRSmB4KOD_za4oHu5x-PhwjZohzfo8wecMOwlXZwc,1886
|
97
97
|
nkululeko/reporting/report.py,sha256=W0rcigDdjBvxZQ3pZja_gvToILYvaZ1BFtnN2qFRfYI,1060
|
98
98
|
nkululeko/reporting/report_item.py,sha256=siWeGNgo4bAE46YBMNcsdf3jTMTy76BO9Fi6DTvDig4,533
|
99
|
-
nkululeko/reporting/reporter.py,sha256=
|
99
|
+
nkululeko/reporting/reporter.py,sha256=S9A62AxdMTEV-9XDUQNxdoevGLXBP52WiDmZ694QMV4,14161
|
100
100
|
nkululeko/reporting/result.py,sha256=nSN5or-Py2GPRWHkWpGRh7UCi1W0er7WLEHz8fYLk-A,742
|
101
101
|
nkululeko/segmenting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
102
102
|
nkululeko/segmenting/seg_inaspeechsegmenter.py,sha256=pmLHuXsaqvcdYxB4PSW9l1mbQWZZBJFhi_CGabqydas,1947
|
@@ -105,8 +105,8 @@ nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
105
105
|
nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
|
106
106
|
nkululeko/utils/stats.py,sha256=1yUq0FTOyqkU8TwUocJRYdJaqMU5SlOBBRUun9STo2M,2829
|
107
107
|
nkululeko/utils/util.py,sha256=ILpfNuaeq-hy1bUkRhVrzO2wG9z9Upaozs9EBoIaMG0,14123
|
108
|
-
nkululeko-0.86.
|
109
|
-
nkululeko-0.86.
|
110
|
-
nkululeko-0.86.
|
111
|
-
nkululeko-0.86.
|
112
|
-
nkululeko-0.86.
|
108
|
+
nkululeko-0.86.7.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
|
109
|
+
nkululeko-0.86.7.dist-info/METADATA,sha256=t5cI43YRp3qmyJj03ACfgCbKoAuLYImDCLS1QkYbMQM,38024
|
110
|
+
nkululeko-0.86.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
111
|
+
nkululeko-0.86.7.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
|
112
|
+
nkululeko-0.86.7.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|