nkululeko 0.84.0__py3-none-any.whl → 0.85.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nkululeko/augmenting/resampler.py +9 -4
- nkululeko/constants.py +1 -1
- nkululeko/experiment.py +6 -1
- nkululeko/feat_extract/feats_whisper.py +3 -6
- nkululeko/modelrunner.py +56 -33
- nkululeko/models/finetune_model.py +190 -0
- nkululeko/models/model.py +1 -1
- nkululeko/models/model_tuned.py +506 -0
- nkululeko/resample.py +76 -54
- nkululeko/test_pretrain.py +200 -11
- nkululeko/utils/util.py +53 -32
- {nkululeko-0.84.0.dist-info → nkululeko-0.85.0.dist-info}/METADATA +9 -1
- {nkululeko-0.84.0.dist-info → nkululeko-0.85.0.dist-info}/RECORD +16 -14
- {nkululeko-0.84.0.dist-info → nkululeko-0.85.0.dist-info}/LICENSE +0 -0
- {nkululeko-0.84.0.dist-info → nkululeko-0.85.0.dist-info}/WHEEL +0 -0
- {nkululeko-0.84.0.dist-info → nkululeko-0.85.0.dist-info}/top_level.txt +0 -0
nkululeko/models/model_tuned.py ADDED
@@ -0,0 +1,506 @@
|
|
1
|
+
"""
|
2
|
+
Code based on @jwagner
|
3
|
+
"""
|
4
|
+
|
5
|
+
import audiofile
|
6
|
+
import audeer
|
7
|
+
import audmetric
|
8
|
+
import datasets
|
9
|
+
import pandas as pd
|
10
|
+
import transformers
|
11
|
+
from nkululeko.utils.util import Util
|
12
|
+
import nkululeko.glob_conf as glob_conf
|
13
|
+
from nkululeko.models.model import Model as BaseModel
|
14
|
+
|
15
|
+
# import nkululeko.models.finetune_model as fm
|
16
|
+
from nkululeko.reporting.reporter import Reporter
|
17
|
+
import torch
|
18
|
+
import ast
|
19
|
+
import numpy as np
|
20
|
+
from sklearn.metrics import recall_score
|
21
|
+
from collections import OrderedDict
|
22
|
+
import os
|
23
|
+
import json
|
24
|
+
import pickle
|
25
|
+
import dataclasses
|
26
|
+
import typing
|
27
|
+
|
28
|
+
import torch
|
29
|
+
import transformers
|
30
|
+
from transformers.models.wav2vec2.modeling_wav2vec2 import (
|
31
|
+
Wav2Vec2PreTrainedModel,
|
32
|
+
Wav2Vec2Model,
|
33
|
+
)
|
34
|
+
|
35
|
+
|
36
|
+
class Pretrained_model(BaseModel):
    """Fine-tune a pretrained wav2vec2 checkpoint as a classifier.

    Audio is read on the fly from the (file, start, end) index of the train
    and test dataframes; the ``feats_*`` arguments are only forwarded to the
    base class.
    """

    is_classifier = True

    def __init__(self, df_train, df_test, feats_train, feats_test):
        """Constructor taking the configuration and all dataframes"""
        super().__init__(df_train, df_test, feats_train, feats_test)
        super().set_model_type("ann")
        self.name = "finetuned_wav2vec2"
        self.model_type = "finetuned"
        self.target = glob_conf.config["DATA"]["target"]
        labels = glob_conf.labels
        # predict_sample() maps class probabilities to these labels
        self.labels = labels
        self.class_num = len(labels)
        device = self.util.config_val("MODEL", "device", "cpu")
        self.batch_size = int(self.util.config_val("MODEL", "batch_size", "8"))
        if device != "cpu":
            self.util.debug(f"running on device {device}")
            os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
            os.environ["CUDA_VISIBLE_DEVICES"] = device
        self.df_train, self.df_test = df_train, df_test
        self.epoch_num = int(self.util.config_val("EXP", "epochs", 1))

        self._init_model()

    def _init_model(self):
        """Build the HF datasets, the processor and the pretrained model."""
        model_path = "facebook/wav2vec2-large-robust-ft-swbd-300h"
        self.num_layers = None
        self.sampling_rate = 16000
        self.max_duration_sec = 8.0
        self.accumulation_steps = 4

        # create dataset
        dataset = {}
        target_name = glob_conf.target
        data_sources = {
            "train": pd.DataFrame(self.df_train[target_name]),
            "dev": pd.DataFrame(self.df_test[target_name]),
        }

        for split, df in data_sources.items():
            df[target_name] = df[target_name].astype("float")
            # one tuple of target values per row, in a column named "targets"
            y = pd.Series(
                data=df.itertuples(index=False, name=None),
                index=df.index,
                dtype=object,
                name="targets",
            )
            # turn the (file, start, end) index into columns; start/end are
            # converted from timedeltas to seconds
            df = y.reset_index()
            df.start = df.start.dt.total_seconds()
            df.end = df.end.dt.total_seconds()
            dataset[split] = datasets.Dataset.from_pandas(df)

        self.dataset = datasets.DatasetDict(dataset)

        # load pre-trained model
        le = glob_conf.label_encoder
        target_mapping = {label: idx for idx, label in enumerate(le.classes_)}
        target_mapping_reverse = {idx: label for label, idx in target_mapping.items()}

        self.config = transformers.AutoConfig.from_pretrained(
            model_path,
            num_labels=len(target_mapping),
            label2id=target_mapping,
            id2label=target_mapping_reverse,
            finetuning_task=target_name,
        )
        if self.num_layers is not None:
            self.config.num_hidden_layers = self.num_layers
        setattr(self.config, "sampling_rate", self.sampling_rate)
        setattr(self.config, "data", self.util.get_data_name())

        # NOTE: an empty vocabulary is written to ./vocab.json and the
        # tokenizer files are saved into the current working directory.
        vocab_dict = {}
        with open("vocab.json", "w") as vocab_file:
            json.dump(vocab_dict, vocab_file)
        tokenizer = transformers.Wav2Vec2CTCTokenizer("./vocab.json")
        tokenizer.save_pretrained(".")

        feature_extractor = transformers.Wav2Vec2FeatureExtractor(
            feature_size=1,
            sampling_rate=16000,
            padding_value=0.0,
            do_normalize=True,
            return_attention_mask=True,
        )
        self.processor = transformers.Wav2Vec2Processor(
            feature_extractor=feature_extractor,
            tokenizer=tokenizer,
        )
        assert self.processor.feature_extractor.sampling_rate == self.sampling_rate

        self.model = Model.from_pretrained(
            model_path,
            config=self.config,
        )
        self.model.freeze_feature_extractor()
        self.model.train()
        self.model_initialized = True

    def set_model_type(self, type):
        self.model_type = type

    def is_ann(self):
        """Return True if the model type is "ann"."""
        return self.model_type == "ann"

    def set_testdata(self, data_df, feats_df):
        self.df_test, self.feats_test = data_df, feats_df

    def reset_test(self, df_test, feats_test):
        self.df_test, self.feats_test = df_test, feats_test

    def set_id(self, run, epoch):
        """Remember run/epoch and derive the model store path from them."""
        self.run = run
        self.epoch = epoch
        model_dir = self.util.get_path("model_dir")
        name = f"{self.util.get_exp_name(only_train=True)}_{self.run}_{self.epoch:03d}.model"
        self.store_path = model_dir + name

    def data_collator(self, data):
        """Read the audio of a batch of dataset rows and build model inputs.

        Each row carries "file", "start" and "end" (seconds) and "targets".
        Segments are cut to ``max_duration_sec``; open-ended segments (a
        missing end) are read until the end of the file, subject to that cap.
        """
        signals = []
        for item in data:
            offset = item["start"]
            end = item["end"]
            if end is None or np.isnan(end):
                duration = None
            else:
                duration = end - offset
            if self.max_duration_sec is not None:
                duration = (
                    self.max_duration_sec
                    if duration is None
                    else min(duration, self.max_duration_sec)
                )
            signal, _ = audiofile.read(
                item["file"],
                offset=offset,
                duration=duration,
            )
            signals.append(signal.squeeze())
        targets = [item["targets"] for item in data]

        input_values = self.processor(
            signals,
            sampling_rate=self.sampling_rate,
            padding=True,
        )
        batch = self.processor.pad(
            input_values,
            padding=True,
            return_tensors="pt",
        )

        batch["labels"] = torch.tensor(targets)

        return batch

    def compute_metrics(self, p: transformers.EvalPrediction):
        """Compute UAR and accuracy from the trainer's eval predictions."""
        metrics = {
            "UAR": audmetric.unweighted_average_recall,
            "ACC": audmetric.accuracy,
        }

        truth = p.label_ids[:, 0].astype(int)
        preds = np.argmax(p.predictions, axis=1)
        return {name: metric(truth, preds) for name, metric in metrics.items()}

    def train(self):
        """Train the model"""
        model_root = self.util.get_path("model_dir")
        log_root = os.path.join(self.util.get_exp_dir(), "log")
        audeer.mkdir(log_root)
        self.torch_root = audeer.path(model_root, "torch")
        conf_file = os.path.join(self.torch_root, "config.json")
        if os.path.isfile(conf_file):
            self.util.debug(f"reusing finetuned model: {conf_file}")
            self.load(self.run, self.epoch)
            return

        # Inverse-frequency class weights with one entry per class index.
        # Classes absent from the training set get a finite placeholder
        # count: their weight is never used by the loss, but a missing entry
        # would make the weight vector shorter than the number of logits.
        targets = pd.DataFrame(self.dataset["train"]["targets"])[0].astype(int)
        counts = (
            targets.value_counts()
            .reindex(range(self.config.num_labels), fill_value=0)
            .clip(lower=1)
        )
        train_weights = 1 / counts
        train_weights /= train_weights.sum()
        class_weights = torch.tensor(train_weights.to_numpy(), dtype=torch.float32)

        class Trainer(transformers.Trainer):
            def compute_loss(
                self,
                model,
                inputs,
                return_outputs=False,
                **kwargs,
            ):
                # **kwargs absorbs extra arguments (e.g. num_items_in_batch)
                # that newer transformers versions pass to compute_loss.
                targets = inputs.pop("labels").reshape(-1).long()
                outputs = model(**inputs)
                logits = outputs[0]
                # move the weights to wherever the trainer put the model
                loss = torch.nn.functional.cross_entropy(
                    logits,
                    targets,
                    weight=class_weights.to(logits.device),
                )
                return (loss, outputs) if return_outputs else loss

        # evaluate/save/log about five times per epoch
        num_steps = (
            len(self.dataset["train"])
            // (self.batch_size * self.accumulation_steps)
            // 5
        )
        num_steps = max(1, num_steps)

        training_args = transformers.TrainingArguments(
            output_dir=model_root,
            logging_dir=log_root,
            per_device_train_batch_size=self.batch_size,
            per_device_eval_batch_size=self.batch_size,
            gradient_accumulation_steps=self.accumulation_steps,
            evaluation_strategy="steps",
            num_train_epochs=self.epoch_num,
            # mixed precision needs CUDA; fp16=True raises on CPU-only hosts
            fp16=torch.cuda.is_available(),
            save_steps=num_steps,
            eval_steps=num_steps,
            logging_steps=num_steps,
            learning_rate=1e-4,
            save_total_limit=2,
            metric_for_best_model="UAR",
            greater_is_better=True,
            load_best_model_at_end=True,
            remove_unused_columns=False,
            report_to="none",
        )

        trainer = Trainer(
            model=self.model,
            data_collator=self.data_collator,
            args=training_args,
            compute_metrics=self.compute_metrics,
            train_dataset=self.dataset["train"],
            eval_dataset=self.dataset["dev"],
            tokenizer=self.processor.feature_extractor,
            callbacks=[transformers.integrations.TensorBoardCallback()],
        )
        trainer.train()
        trainer.save_model(self.torch_root)
        self.load(self.run, self.epoch)

    def get_predictions(self):
        """Return the index of the highest-scoring class for each test segment.

        Raises:
            ValueError: if a file's sampling rate differs from the model's.
        """
        # make sure dropout is off and Model.forward returns probabilities
        self.model.eval()
        results = []
        for (file, start, end), _ in audeer.progress_bar(
            self.df_test.iterrows(),
            total=len(self.df_test),
            desc=f"Predicting {len(self.df_test)} audiofiles",
        ):
            # `end == pd.NaT` is always False; NaT must be tested with isna
            if pd.isna(end):
                signal, sr = audiofile.read(file, offset=start, always_2d=True)
            else:
                signal, sr = audiofile.read(
                    file, duration=end - start, offset=start, always_2d=True
                )
            if sr != self.sampling_rate:
                raise ValueError(
                    f"{file}: sampling rate {sr} != expected {self.sampling_rate}"
                )
            with torch.no_grad():
                predictions = self.model.predict(signal)
            results.append(predictions.argmax())
        return results

    def predict(self):
        """Predict the whole eval feature set"""
        predictions = self.get_predictions()
        report = Reporter(
            self.df_test[self.target].to_numpy().astype(float),
            predictions,
            self.run,
            self.epoch,
        )
        return report

    def predict_sample(self, signal):
        """Predict one sample.

        For classification, returns a dict mapping each label to its score;
        otherwise returns the raw model output.
        """
        is_classification = self.util.exp_is_classification()
        predictions = self.model.predict(signal)
        if is_classification:
            return {cat: predictions[i] for i, cat in enumerate(self.labels)}
        return predictions

    def store(self):
        self.util.debug("stored: ")

    def load(self, run, epoch):
        """Load the finetuned model saved under ``<model_dir>/torch``."""
        self.set_id(run, epoch)
        # torch_root is normally set by train(); derive it here as well so a
        # stored model can be loaded without calling train() first.
        if getattr(self, "torch_root", None) is None:
            self.torch_root = audeer.path(self.util.get_path("model_dir"), "torch")
        self.model = Model.from_pretrained(
            self.torch_root,
            config=self.config,
        )

    def load_path(self, path, run, epoch):
        """Load a pickled classifier from `path` into self.clf."""
        self.set_id(run, epoch)
        # NOTE(review): pickle.load executes arbitrary code; only load trusted files.
        with open(path, "rb") as handle:
            self.clf = pickle.load(handle)
|
362
|
+
|
363
|
+
|
364
|
+
@dataclasses.dataclass
class ModelOutput(transformers.file_utils.ModelOutput):
    """Output container of Model.forward.

    Field order matters: ``output[0]`` is ``logits_cat``. All fields default
    to None; ``hidden_states`` and ``cnn_features`` are only filled when
    forward() is called with ``return_hidden=True``.
    """

    # class scores (softmax probabilities when the model is not training)
    logits_cat: typing.Optional[torch.FloatTensor] = None
    # mean-pooled hidden states (a tensor, not a tuple of per-layer states)
    hidden_states: typing.Optional[torch.FloatTensor] = None
    # CNN feature-encoder output, transposed so that time is the last axis
    cnn_features: typing.Optional[torch.FloatTensor] = None
|
370
|
+
|
371
|
+
|
372
|
+
class ModelHead(torch.nn.Module):
    """Classification head: dropout, dense + tanh, dropout, linear projection."""

    def __init__(self, config, num_labels):
        super().__init__()
        # The attribute names below appear in saved state-dict keys, and the
        # creation order fixes the order of random initialisation.
        width = config.hidden_size
        self.dense = torch.nn.Linear(width, width)
        self.dropout = torch.nn.Dropout(config.final_dropout)
        self.out_proj = torch.nn.Linear(width, num_labels)

    def forward(self, features, **kwargs):
        """Map pooled features to one score per label; kwargs are ignored."""
        hidden = torch.tanh(self.dense(self.dropout(features)))
        return self.out_proj(self.dropout(hidden))
|
392
|
+
|
393
|
+
|
394
|
+
class Model(Wav2Vec2PreTrainedModel):
    """wav2vec2 encoder followed by mean pooling and a classification head."""

    def __init__(self, config):
        super().__init__(config)

        self.wav2vec2 = Wav2Vec2Model(config)
        # One logit per class as configured (was hard-coded to 2, which
        # broke every task with more than two labels). The attribute name
        # `cat` is kept because it is part of saved state-dict keys.
        self.cat = ModelHead(config, config.num_labels)
        self.init_weights()

    def freeze_feature_extractor(self):
        """Freeze the parameters of the CNN feature encoder."""
        self.wav2vec2.feature_extractor._freeze_parameters()

    def pooling(
        self,
        hidden_states,
        attention_mask,
    ):
        """Average frame-wise hidden states over the time axis.

        Without an attention mask all frames are averaged; otherwise the mask
        is converted to frame resolution via the HF helper and only frames it
        marks as valid contribute to the mean.
        """
        if attention_mask is None:  # For evaluation with batch_size==1
            outputs = torch.mean(hidden_states, dim=1)
        else:
            attention_mask = self._get_feature_vector_attention_mask(
                hidden_states.shape[1],
                attention_mask,
            )
            hidden_states = hidden_states * torch.reshape(
                attention_mask,
                (-1, attention_mask.shape[-1], 1),
            )
            outputs = torch.sum(hidden_states, dim=1)
            attention_sum = torch.sum(attention_mask, dim=1)
            outputs = outputs / torch.reshape(attention_sum, (-1, 1))

        return outputs

    def forward(
        self,
        input_values,
        attention_mask=None,
        labels=None,
        return_hidden=False,
    ):
        """Run encoder, pooling and head.

        `labels` is accepted but unused (the loss is computed by the trainer).
        Class scores are raw logits in training mode and softmax
        probabilities otherwise. With `return_hidden`, the pooled states and
        the (time-last) CNN features are returned as well.
        """
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
        )

        cnn_features = outputs.extract_features
        hidden_states_framewise = outputs.last_hidden_state
        hidden_states = self.pooling(
            hidden_states_framewise,
            attention_mask,
        )
        logits_cat = self.cat(hidden_states)

        if not self.training:
            logits_cat = torch.softmax(logits_cat, dim=1)

        if return_hidden:
            # make time last axis
            cnn_features = torch.transpose(cnn_features, 1, 2)

            return ModelOutput(
                logits_cat=logits_cat,
                hidden_states=hidden_states,
                cnn_features=cnn_features,
            )

        return ModelOutput(
            logits_cat=logits_cat,
        )

    def predict(self, signal):
        """Return the class scores of a single numpy signal.

        A 1-D signal is treated as a batch of one. Inference runs without
        gradient tracking on the model's device; the result is a numpy array.
        """
        tensor = torch.from_numpy(signal)
        if tensor.dim() == 1:
            tensor = tensor.unsqueeze(0)
        with torch.no_grad():
            result = self(tensor.to(self.device))
        return result[0].detach().cpu().numpy()[0]
|
475
|
+
|
476
|
+
|
477
|
+
class ModelWithPreProcessing(Model):
    """Model variant that normalises raw audio inside forward().

    forward() returns a plain tuple (pooled hidden states, class scores,
    CNN features) instead of a ModelOutput.
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        input_values,
    ):
        # Mirrors Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm():
        #   (x - mean(x)) / sqrt(var(x) + 1e-7)
        # The variance is written out by hand because input_values.var()
        # raised in onnxruntime: NOT_IMPLEMENTED ... ReduceProd_3:ReduceProd(11)
        centered = input_values - input_values.mean()
        variance = torch.square(centered).mean()
        normalized = centered / torch.sqrt(variance + 1e-7)

        result = super().forward(normalized, return_hidden=True)
        return result.hidden_states, result.logits_cat, result.cnn_features
|
nkululeko/resample.py
CHANGED
@@ -1,78 +1,100 @@
|
|
1
1
|
# resample.py
|
2
|
-
# change the sampling rate for train
|
2
|
+
# change the sampling rate for audio file or INI file (train, test, all)
|
3
3
|
|
4
4
|
import argparse
|
5
5
|
import configparser
|
6
6
|
import os
|
7
|
-
|
8
7
|
import pandas as pd
|
9
|
-
|
8
|
+
import audformat
|
10
9
|
from nkululeko.augmenting.resampler import Resampler
|
10
|
+
from nkululeko.utils.util import Util
|
11
|
+
|
11
12
|
from nkululeko.constants import VERSION
|
12
13
|
from nkululeko.experiment import Experiment
|
13
|
-
from nkululeko.utils.util import Util
|
14
14
|
|
15
15
|
|
16
16
|
def _to_bool(value):
    """Interpret a config value such as "False", "true" or "1" as a bool.

    configparser returns strings, and every non-empty string (including
    "False") is truthy, so raw config values must not be used as flags.
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() in ("true", "1", "yes", "on")


def main(src_dir):
    """Resample one audio file (--file) or the data of an experiment (--config).

    `src_dir` is currently unused. Exits with status 1 on invalid arguments.
    """
    import sys

    parser = argparse.ArgumentParser(
        description="Call the nkululeko RESAMPLE framework.")
    parser.add_argument("--config", default=None,
                        help="The base configuration")
    parser.add_argument("--file", default=None,
                        help="The input audio file to resample")
    parser.add_argument("--replace", action="store_true",
                        help="Replace the original audio file")

    args = parser.parse_args()

    if args.file is None and args.config is None:
        print("ERROR: Either --file or --config argument must be provided.")
        sys.exit(1)

    if args.file is not None:
        # Load the audio file into a DataFrame
        files = pd.Series([args.file])
        df_sample = pd.DataFrame(index=files)
        df_sample.index = audformat.utils.to_segmented_index(
            df_sample.index, allow_nat=False
        )

        # Resample the audio file
        util = Util("resampler", has_config=False)
        util.debug(f"Resampling audio file: {args.file}")
        rs = Resampler(df_sample, not_testing=True, replace=args.replace)
        rs.resample()
        return

    config_file = args.config

    # Test if the configuration file exists
    if not os.path.isfile(config_file):
        print(f"ERROR: no such file: {config_file}")
        sys.exit(1)

    # Load one configuration per experiment
    config = configparser.ConfigParser()
    config.read(config_file)
    # Create a new experiment
    expr = Experiment(config)
    module = "resample"
    expr.set_module(module)
    util = Util(module)
    util.debug(
        f"running {expr.name} from config {config_file}, nkululeko version"
        f" {VERSION}"
    )

    if _to_bool(util.config_val("EXP", "no_warnings", False)):
        import warnings
        warnings.filterwarnings("ignore")

    # Load the data
    expr.load_datasets()

    # Split into train and test
    expr.fill_train_and_tests()
    util.debug(
        f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}")

    sample_selection = util.config_val(
        "RESAMPLE", "sample_selection", "all")
    if sample_selection == "all":
        df = pd.concat([expr.df_train, expr.df_test])
    elif sample_selection == "train":
        df = expr.df_train
    elif sample_selection == "test":
        df = expr.df_test
    else:
        util.error(
            f"unknown selection specifier {sample_selection}, should be [all |"
            " train | test]"
        )
        # never continue with an undefined selection
        return
    util.debug(f"resampling {sample_selection}: {df.shape[0]} samples")
    # pass a real bool (like the --file path does); the raw string "False"
    # would be truthy and could make the resampler overwrite the originals
    replace = _to_bool(util.config_val("RESAMPLE", "replace", "False"))
    rs = Resampler(df, replace=replace)
    rs.resample()
|
74
96
|
|
75
97
|
|
76
98
|
if __name__ == "__main__":
    # run with the directory that contains this script
    main(os.path.dirname(os.path.abspath(__file__)))
|