nkululeko 0.84.0__py3-none-any.whl → 0.84.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nkululeko/augmenting/resampler.py CHANGED
@@ -12,16 +12,19 @@ from nkululeko.utils.util import Util


 class Resampler:
-    def __init__(self, df, not_testing=True):
+    def __init__(self, df, replace, not_testing=True):
         self.SAMPLING_RATE = 16000
         self.df = df
         self.util = Util("resampler", has_config=not_testing)
         self.util.warn(f"all files might be resampled to {self.SAMPLING_RATE}")
         self.not_testing = not_testing
+        self.replace = eval(self.util.config_val(
+            "RESAMPLE", "replace", "False")) if not not_testing else replace

     def resample(self):
         files = self.df.index.get_level_values(0).values
-        replace = eval(self.util.config_val("RESAMPLE", "replace", "False"))
+        # replace = eval(self.util.config_val("RESAMPLE", "replace", "False"))
+        replace = self.replace
         if self.not_testing:
             store = self.util.get_path("store")
         else:
@@ -42,7 +45,8 @@ class Resampler:
                 continue
             if org_sr != self.SAMPLING_RATE:
                 self.util.debug(f"resampling {f} (sr = {org_sr})")
-                resampler = torchaudio.transforms.Resample(org_sr, self.SAMPLING_RATE)
+                resampler = torchaudio.transforms.Resample(
+                    org_sr, self.SAMPLING_RATE)
                 signal = resampler(signal)
                 if replace:
                     torchaudio.save(
@@ -59,7 +63,8 @@ class Resampler:
         self.df = self.df.set_index(
             self.df.index.set_levels(new_files, level="file")
         )
-        target_file = self.util.config_val("RESAMPLE", "target", "resampled.csv")
+        target_file = self.util.config_val(
+            "RESAMPLE", "target", "resampled.csv")
         # remove encoded labels
         target = self.util.config_val("DATA", "target", "emotion")
         if "class_label" in self.df.columns:
nkululeko/constants.py CHANGED
@@ -1,2 +1,2 @@
-VERSION="0.84.0"
+VERSION="0.84.1"
 SAMPLING_RATE = 16000
nkululeko/models/finetune_model.py ADDED
@@ -0,0 +1,181 @@
+import dataclasses
+import typing
+
+import torch
+import transformers
+from transformers.models.wav2vec2.modeling_wav2vec2 import (
+    Wav2Vec2PreTrainedModel,
+    Wav2Vec2Model,
+)
+
+
+class ConcordanceCorCoeff(torch.nn.Module):
+
+    def __init__(self):
+
+        super().__init__()
+
+        self.mean = torch.mean
+        self.var = torch.var
+        self.sum = torch.sum
+        self.sqrt = torch.sqrt
+        self.std = torch.std
+
+    def forward(self, prediction, ground_truth):
+
+        mean_gt = self.mean(ground_truth, 0)
+        mean_pred = self.mean(prediction, 0)
+        var_gt = self.var(ground_truth, 0)
+        var_pred = self.var(prediction, 0)
+        v_pred = prediction - mean_pred
+        v_gt = ground_truth - mean_gt
+        cor = self.sum(v_pred * v_gt) / (
+            self.sqrt(self.sum(v_pred**2)) * self.sqrt(self.sum(v_gt**2))
+        )
+        sd_gt = self.std(ground_truth)
+        sd_pred = self.std(prediction)
+        numerator = 2 * cor * sd_gt * sd_pred
+        denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
+        ccc = numerator / denominator
+
+        return 1 - ccc
+
+
+@dataclasses.dataclass
+class ModelOutput(transformers.file_utils.ModelOutput):
+
+    logits_cat: torch.FloatTensor = None
+    hidden_states: typing.Tuple[torch.FloatTensor] = None
+    cnn_features: torch.FloatTensor = None
+
+
+class ModelHead(torch.nn.Module):
+
+    def __init__(self, config, num_labels):
+
+        super().__init__()
+
+        self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = torch.nn.Dropout(config.final_dropout)
+        self.out_proj = torch.nn.Linear(config.hidden_size, num_labels)
+
+    def forward(self, features, **kwargs):
+
+        x = features
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+
+        return x
+
+
+class Model(Wav2Vec2PreTrainedModel):
+
+    def __init__(self, config):
+
+        super().__init__(config)
+
+        self.wav2vec2 = Wav2Vec2Model(config)
+        self.cat = ModelHead(config, 2)
+        self.init_weights()
+
+    def freeze_feature_extractor(self):
+        self.wav2vec2.feature_extractor._freeze_parameters()
+
+    def pooling(
+        self,
+        hidden_states,
+        attention_mask,
+    ):
+
+        if attention_mask is None:  # For evaluation with batch_size==1
+            outputs = torch.mean(hidden_states, dim=1)
+        else:
+            attention_mask = self._get_feature_vector_attention_mask(
+                hidden_states.shape[1],
+                attention_mask,
+            )
+            hidden_states = hidden_states * torch.reshape(
+                attention_mask,
+                (-1, attention_mask.shape[-1], 1),
+            )
+            outputs = torch.sum(hidden_states, dim=1)
+            attention_sum = torch.sum(attention_mask, dim=1)
+            outputs = outputs / torch.reshape(attention_sum, (-1, 1))
+
+        return outputs
+
+    def forward(
+        self,
+        input_values,
+        attention_mask=None,
+        labels=None,
+        return_hidden=False,
+    ):
+
+        outputs = self.wav2vec2(
+            input_values,
+            attention_mask=attention_mask,
+        )
+
+        cnn_features = outputs.extract_features
+        hidden_states_framewise = outputs.last_hidden_state
+        hidden_states = self.pooling(
+            hidden_states_framewise,
+            attention_mask,
+        )
+        logits_cat = self.cat(hidden_states)
+
+        if not self.training:
+            logits_cat = torch.softmax(logits_cat, dim=1)
+
+        if return_hidden:
+
+            # make time last axis
+            cnn_features = torch.transpose(cnn_features, 1, 2)
+
+            return ModelOutput(
+                logits_cat=logits_cat,
+                hidden_states=hidden_states,
+                cnn_features=cnn_features,
+            )
+
+        else:
+
+            return ModelOutput(
+                logits_cat=logits_cat,
+            )
+
+
+class ModelWithPreProcessing(Model):
+
+    def __init__(self, config):
+        super().__init__(config)
+
+    def forward(
+        self,
+        input_values,
+    ):
+        # Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm():
+        # normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
+
+        mean = input_values.mean()
+
+        # var = input_values.var()
+        # raises: onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for the node ReduceProd_3:ReduceProd(11)
+
+        var = torch.square(input_values - mean).mean()
+        input_values = (input_values - mean) / torch.sqrt(var + 1e-7)
+
+        output = super().forward(
+            input_values,
+            return_hidden=True,
+        )
+
+        return (
+            output.hidden_states,
+            output.logits_cat,
+            output.cnn_features,
+        )
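
For reference, the `ConcordanceCorCoeff` module in this new file returns 1 - CCC as a loss, where CCC is Lin's concordance correlation coefficient. In the code's own names (`cor`, `sd_pred`/`sd_gt`, `var_pred`/`var_gt`, `mean_pred`/`mean_gt`):

    \mathrm{CCC} = \frac{2 \cdot \mathrm{cor} \cdot \mathrm{sd_{pred}} \cdot \mathrm{sd_{gt}}}{\mathrm{var_{pred}} + \mathrm{var_{gt}} + (\mathrm{mean_{pred}} - \mathrm{mean_{gt}})^2}

so identical prediction and ground-truth series give CCC = 1 and a loss of 0. `ModelWithPreProcessing.forward` folds the feature extractor's zero-mean/unit-variance normalization into the graph, computing the variance as `torch.square(x - mean).mean()` because, per the inline comment, `Tensor.var()` does not export to ONNX Runtime.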
nkululeko/resample.py CHANGED
@@ -1,78 +1,100 @@
 # resample.py
-# change the sampling rate for train and test splits
+# change the sampling rate for audio file or INI file (train, test, all)

 import argparse
 import configparser
 import os
-
 import pandas as pd
-
+import audformat
 from nkululeko.augmenting.resampler import Resampler
+from nkululeko.utils.util import Util
+
 from nkululeko.constants import VERSION
 from nkululeko.experiment import Experiment
-from nkululeko.utils.util import Util


 def main(src_dir):
     parser = argparse.ArgumentParser(
-        description="Call the nkululeko RESAMPLE framework.")
-    parser.add_argument("--config", default="exp.ini",
+        description="Call the nkululeko RESAMPLE framework.")
+    parser.add_argument("--config", default=None,
                         help="The base configuration")
+    parser.add_argument("--file", default=None,
+                        help="The input audio file to resample")
+    parser.add_argument("--replace", action="store_true",
+                        help="Replace the original audio file")
+
     args = parser.parse_args()
-    if args.config is not None:
-        config_file = args.config
-    else:
-        config_file = f"{src_dir}/exp.ini"

-    # test if the configuration file exists
-    if not os.path.isfile(config_file):
-        print(f"ERROR: no such file: {config_file}")
+    if args.file is None and args.config is None:
+        print("ERROR: Either --file or --config argument must be provided.")
         exit()

-    # load one configuration per experiment
-    config = configparser.ConfigParser()
-    config.read(config_file)
-    # create a new experiment
-    expr = Experiment(config)
-    module = "resample"
-    expr.set_module(module)
-    util = Util(module)
-    util.debug(
-        f"running {expr.name} from config {config_file}, nkululeko version"
-        f" {VERSION}"
-    )
-
-    if util.config_val("EXP", "no_warnings", False):
-        import warnings
-
-        warnings.filterwarnings("ignore")
-
-    # load the data
-    expr.load_datasets()
-
-    # split into train and test
-    expr.fill_train_and_tests()
-    util.debug(
-        f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}")
-
-    sample_selection = util.config_val("RESAMPLE", "sample_selection", "all")
-    if sample_selection == "all":
-        df = pd.concat([expr.df_train, expr.df_test])
-    elif sample_selection == "train":
-        df = expr.df_train
-    elif sample_selection == "test":
-        df = expr.df_test
+    if args.file is not None:
+        # Load the audio file into a DataFrame
+        files = pd.Series([args.file])
+        df_sample = pd.DataFrame(index=files)
+        df_sample.index = audformat.utils.to_segmented_index(
+            df_sample.index, allow_nat=False
+        )
+
+        # Resample the audio file
+        util = Util("resampler", has_config=False)
+        util.debug(f"Resampling audio file: {args.file}")
+        rs = Resampler(df_sample, not_testing=True, replace=args.replace)
+        rs.resample()
     else:
-        util.error(
-            f"unknown selection specifier {sample_selection}, should be [all |"
-            " train | test]"
+        # Existing code for handling INI file
+        config_file = args.config
+
+        # Test if the configuration file exists
+        if not os.path.isfile(config_file):
+            print(f"ERROR: no such file: {config_file}")
+            exit()
+
+        # Load one configuration per experiment
+        config = configparser.ConfigParser()
+        config.read(config_file)
+        # Create a new experiment
+        expr = Experiment(config)
+        module = "resample"
+        expr.set_module(module)
+        util = Util(module)
+        util.debug(
+            f"running {expr.name} from config {config_file}, nkululeko version"
+            f" {VERSION}"
         )
-    util.debug(f"resampling {sample_selection}: {df.shape[0]} samples")
-    rs = Resampler(df)
-    rs.resample()
-    print("DONE")
+
+        if util.config_val("EXP", "no_warnings", False):
+            import warnings
+            warnings.filterwarnings("ignore")
+
+        # Load the data
+        expr.load_datasets()
+
+        # Split into train and test
+        expr.fill_train_and_tests()
+        util.debug(
+            f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}")
+
+        sample_selection = util.config_val(
+            "RESAMPLE", "sample_selection", "all")
+        if sample_selection == "all":
+            df = pd.concat([expr.df_train, expr.df_test])
+        elif sample_selection == "train":
+            df = expr.df_train
+        elif sample_selection == "test":
+            df = expr.df_test
+        else:
+            util.error(
+                f"unknown selection specifier {sample_selection}, should be [all |"
+                " train | test]"
+            )
+        util.debug(f"resampling {sample_selection}: {df.shape[0]} samples")
+        replace = util.config_val("RESAMPLE", "replace", "False")
+        rs = Resampler(df, replace=replace)
+        rs.resample()


 if __name__ == "__main__":
     cwd = os.path.dirname(os.path.abspath(__file__))
-    main(cwd)  # use this if you want to state the config file path on command line
+    main(cwd)
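
Taken together with the Resampler change above, the reworked resample.py exposes two entry points. A usage sketch (file name and INI contents are placeholders; running the module via `python -m` is one way to reach its `main()`):

    # standalone file mode, new in 0.84.1, no experiment configuration needed
    python -m nkululeko.resample --file example.wav --replace

    # classic INI-driven mode
    python -m nkululeko.resample --config exp.ini

where exp.ini could set the options the module reads from the [RESAMPLE] section:

    [RESAMPLE]
    sample_selection = all
    replace = True
    target = resampled.csv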
nkululeko/test_pretrain.py CHANGED
@@ -11,11 +11,14 @@ import transformers

 import audeer
 import audiofile
+import audmetric

 from nkululeko.constants import VERSION
 import nkululeko.experiment as exp
+import nkululeko.models.finetune_model as fm
 import nkululeko.glob_conf as glob_conf
 from nkululeko.utils.util import Util
+import json


 def doit(config_file):
@@ -50,28 +53,42 @@ def doit(config_file):
     expr.fill_train_and_tests()
     util.debug(f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}")

+    log_root = audeer.mkdir("log")
+    model_root = audeer.mkdir("model")
+    torch_root = audeer.path(model_root, "torch")
+
+    metrics_gender = {
+        "UAR": audmetric.unweighted_average_recall,
+        "ACC": audmetric.accuracy,
+    }
+
     sampling_rate = 16000
     max_duration_sec = 8.0

     model_path = "facebook/wav2vec2-large-robust-ft-swbd-300h"
     num_layers = None

+    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+    os.environ["CUDA_VISIBLE_DEVICES"] = "3"
+
     batch_size = 16
     accumulation_steps = 4
-
     # create dataset

     dataset = {}
+    target_name = glob_conf.target
     data_sources = {
-        "train": pd.DataFrame(expr.df_train[glob_conf.target]),
-        "dev": pd.DataFrame(expr.df_test[glob_conf.target]),
+        "train": pd.DataFrame(expr.df_train[target_name]),
+        "dev": pd.DataFrame(expr.df_test[target_name]),
     }

     for split in ["train", "dev"]:
+        df = data_sources[split]
+        df[target_name] = df[target_name].astype("float")

         y = pd.Series(
-            data=data_sources[split].itertuples(index=False, name=None),
-            index=data_sources[split].index,
+            data=df.itertuples(index=False, name=None),
+            index=df.index,
             dtype=object,
             name="labels",
         )
@@ -80,23 +97,183 @@ def doit(config_file):
         df = y.reset_index()
         df.start = df.start.dt.total_seconds()
         df.end = df.end.dt.total_seconds()
+
         print(f"{split}: {len(df)}")
+
         ds = datasets.Dataset.from_pandas(df)
         dataset[split] = ds

-        dataset = datasets.DatasetDict(dataset)
+    dataset = datasets.DatasetDict(dataset)
+
+    # load pre-trained model
+    le = glob_conf.label_encoder
+    mapping = dict(zip(le.classes_, range(len(le.classes_))))
+    target_mapping = {k: int(v) for k, v in mapping.items()}
+    target_mapping_reverse = {value: key for key, value in target_mapping.items()}

     config = transformers.AutoConfig.from_pretrained(
         model_path,
-        num_labels=len(util.la),
-        label2id=data.gender_mapping,
-        id2label=data.gender_mapping_reverse,
-        finetuning_task="age-gender",
+        num_labels=len(target_mapping),
+        label2id=target_mapping,
+        id2label=target_mapping_reverse,
+        finetuning_task=target_name,
     )
     if num_layers is not None:
         config.num_hidden_layers = num_layers
     setattr(config, "sampling_rate", sampling_rate)
-    setattr(config, "data", ",".join(sources))
+    setattr(config, "data", util.get_data_name())
+
+    vocab_dict = {}
+    with open("vocab.json", "w") as vocab_file:
+        json.dump(vocab_dict, vocab_file)
+    tokenizer = transformers.Wav2Vec2CTCTokenizer("./vocab.json")
+    tokenizer.save_pretrained(".")
+
+    feature_extractor = transformers.Wav2Vec2FeatureExtractor(
+        feature_size=1,
+        sampling_rate=16000,
+        padding_value=0.0,
+        do_normalize=True,
+        return_attention_mask=True,
+    )
+    processor = transformers.Wav2Vec2Processor(
+        feature_extractor=feature_extractor,
+        tokenizer=tokenizer,
+    )
+    assert processor.feature_extractor.sampling_rate == sampling_rate
+
+    model = fm.Model.from_pretrained(
+        model_path,
+        config=config,
+    )
+    model.freeze_feature_extractor()
+    model.train()
+
+    # training
+
+    def data_collator(data):
+
+        files = [d["file"] for d in data]
+        starts = [d["start"] for d in data]
+        ends = [d["end"] for d in data]
+        targets = [d["targets"] for d in data]
+
+        signals = []
+        for file, start, end in zip(
+            files,
+            starts,
+            ends,
+        ):
+            offset = start
+            duration = end - offset
+            if max_duration_sec is not None:
+                duration = min(duration, max_duration_sec)
+            signal, _ = audiofile.read(
+                file,
+                offset=offset,
+                duration=duration,
+            )
+            signals.append(signal.squeeze())
+
+        input_values = processor(
+            signals,
+            sampling_rate=sampling_rate,
+            padding=True,
+        )
+        batch = processor.pad(
+            input_values,
+            padding=True,
+            return_tensors="pt",
+        )
+
+        batch["labels"] = torch.tensor(targets)
+
+        return batch
+
+    def compute_metrics(p: transformers.EvalPrediction):
+
+        truth_gender = p.label_ids[:, 0].astype(int)
+        preds = p.predictions
+        preds_gender = np.argmax(preds, axis=1)
+
+        scores = {}
+
+        for name, metric in metrics_gender.items():
+            scores[f"gender-{name}"] = metric(truth_gender, preds_gender)
+
+        scores["combined"] = scores["gender-UAR"]
+
+        return scores
+
+    targets = pd.DataFrame(dataset["train"]["targets"])
+    counts = targets[0].value_counts().sort_index()
+    train_weights = 1 / counts
+    train_weights /= train_weights.sum()
+
+    print(train_weights)
+
+    criterion_gender = torch.nn.CrossEntropyLoss(
+        weight=torch.Tensor(train_weights).to("cuda"),
+    )
+
+    class Trainer(transformers.Trainer):
+
+        def compute_loss(
+            self,
+            model,
+            inputs,
+            return_outputs=False,
+        ):
+
+            targets = inputs.pop("labels").squeeze()
+            targets_gender = targets.type(torch.long)
+
+            outputs = model(**inputs)
+            logits_gender = outputs[0].squeeze()
+
+            loss_gender = criterion_gender(logits_gender, targets_gender)
+
+            loss = loss_gender
+
+            return (loss, outputs) if return_outputs else loss
+
+    num_steps = len(dataset["train"]) // (batch_size * accumulation_steps) // 5
+    num_steps = max(1, num_steps)
+    print(num_steps)
+
+    training_args = transformers.TrainingArguments(
+        output_dir=model_root,
+        logging_dir=log_root,
+        per_device_train_batch_size=batch_size,
+        per_device_eval_batch_size=batch_size,
+        gradient_accumulation_steps=accumulation_steps,
+        evaluation_strategy="steps",
+        num_train_epochs=5.0,
+        fp16=True,
+        save_steps=num_steps,
+        eval_steps=num_steps,
+        logging_steps=num_steps,
+        learning_rate=1e-4,
+        save_total_limit=2,
+        metric_for_best_model="combined",
+        greater_is_better=True,
+        load_best_model_at_end=True,
+        remove_unused_columns=False,
+    )
+
+    trainer = Trainer(
+        model=model,
+        data_collator=data_collator,
+        args=training_args,
+        compute_metrics=compute_metrics,
+        train_dataset=dataset["train"],
+        eval_dataset=dataset["dev"],
+        tokenizer=processor.feature_extractor,
+        callbacks=[transformers.integrations.TensorBoardCallback()],
+    )
+
+    trainer.train()
+    trainer.save_model(torch_root)

     print("DONE")

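For orientation on the scheduling arithmetic in this script: with batch_size = 16 and accumulation_steps = 4, one optimizer step consumes 64 samples, so num_steps = len(dataset["train"]) // 64 // 5 evaluates, logs and checkpoints roughly five times per epoch; a hypothetical 3,200-sample training set gives 3200 // 64 // 5 = 10, and max(1, num_steps) guards against tiny datasets.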
nkululeko/utils/util.py CHANGED
@@ -33,43 +33,58 @@ class Util:
         else:
             self.caller = ""
         if has_config:
-            import nkululeko.glob_conf as glob_conf
-
-            self.config = glob_conf.config
-            self.got_data_roots = self.config_val("DATA", "root_folders", False)
-            if self.got_data_roots:
-                # if there is a global data rootfolder file, read from there
-                if not os.path.isfile(self.got_data_roots):
-                    self.error(f"no such file: {self.got_data_roots}")
-                self.data_roots = configparser.ConfigParser()
-                self.data_roots.read(self.got_data_roots)
-                # self.debug(f"getting data roots from {self.got_data_roots}")
+            try:
+                import nkululeko.glob_conf as glob_conf
+                self.config = glob_conf.config
+                self.got_data_roots = self.config_val(
+                    "DATA", "root_folders", False)
+                if self.got_data_roots:
+                    # if there is a global data rootfolder file, read from there
+                    if not os.path.isfile(self.got_data_roots):
+                        self.error(f"no such file: {self.got_data_roots}")
+                    self.data_roots = configparser.ConfigParser()
+                    self.data_roots.read(self.got_data_roots)
+            except (ModuleNotFoundError, AttributeError):
+                self.config = None
+                self.got_data_roots = False

     def get_path(self, entry):
         """
         This method allows the user to get the directory path for the given argument.
         """
-        root = os.path.join(self.config["EXP"]["root"], "")
-        name = self.config["EXP"]["name"]
-        try:
-            entryn = self.config["EXP"][entry]
-        except KeyError:
-            # some default values
+        if self.config is None:
+            # If no configuration file is provided, use default paths
             if entry == "fig_dir":
-                entryn = "./images/"
+                dir_name = "./images/"
             elif entry == "res_dir":
-                entryn = "./results/"
+                dir_name = "./results/"
             elif entry == "model_dir":
-                entryn = "./models/"
+                dir_name = "./models/"
             else:
-                entryn = "./store/"
-
-        # Expand image, model and result directories with run index
-        if entry == "fig_dir" or entry == "res_dir" or entry == "model_dir":
-            run = self.config_val("EXP", "run", 0)
-            entryn = entryn + f"run_{run}/"
+                dir_name = "./store/"
+        else:
+            root = os.path.join(self.config["EXP"]["root"], "")
+            name = self.config["EXP"]["name"]
+            try:
+                entryn = self.config["EXP"][entry]
+            except KeyError:
+                # some default values
+                if entry == "fig_dir":
+                    entryn = "./images/"
+                elif entry == "res_dir":
+                    entryn = "./results/"
+                elif entry == "model_dir":
+                    entryn = "./models/"
+                else:
+                    entryn = "./store/"
+
+            # Expand image, model and result directories with run index
+            if entry == "fig_dir" or entry == "res_dir" or entry == "model_dir":
+                run = self.config_val("EXP", "run", 0)
+                entryn = entryn + f"run_{run}/"
+
+            dir_name = f"{root}{name}/{entryn}"

-        dir_name = f"{root}{name}/{entryn}"
         audeer.mkdir(dir_name)
         return dir_name

@@ -101,7 +116,8 @@ class Util:
             )
             return default
         if not default in self.stopvals:
-            self.debug(f"value for {key} not found, using default: {default}")
+            self.debug(
+                f"value for {key} not found, using default: {default}")
         return default

     def set_config(self, config):
@@ -138,7 +154,8 @@ class Util:
         if len(df) == 0:
             return df
         if not isinstance(df.index, pd.MultiIndex):
-            df.index = audformat.utils.to_segmented_index(df.index, allow_nat=False)
+            df.index = audformat.utils.to_segmented_index(
+                df.index, allow_nat=False)
         return df

     def _get_value_descript(self, section, name):
@@ -243,11 +260,14 @@ class Util:
         print(df.head(1))

     def config_val(self, section, key, default):
+        if self.config is None:
+            return default
         try:
             return self.config[section][key]
         except KeyError:
-            if not default in self.stopvals:
-                self.debug(f"value for {key} not found, using default: {default}")
+            if default not in self.stopvals:
+                self.debug(
+                    f"value for {key} not found, using default: {default}")
             return default

     def config_val_list(self, section, key, default):
@@ -255,7 +275,8 @@ class Util:
             return ast.literal_eval(self.config[section][key])
         except KeyError:
             if not default in self.stopvals:
-                self.debug(f"value for {key} not found, using default: {default}")
+                self.debug(
+                    f"value for {key} not found, using default: {default}")
             return default

     def continuous_to_categorical(self, series):
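
The net effect of the Util changes is graceful degradation when no experiment configuration exists. A minimal sketch, assuming a fresh interpreter in which glob_conf.config was never initialised, so the new except (ModuleNotFoundError, AttributeError) branch fires and self.config stays None:

    from nkululeko.utils.util import Util

    util = Util("demo")  # has_config defaults to True
    # config_val() now short-circuits to the default when no config is loaded:
    print(util.config_val("RESAMPLE", "replace", "False"))  # -> "False"
    # get_path() falls back to the built-in directories and creates them:
    print(util.get_path("store"))  # -> "./store/"

Note that config_val returns defaults as raw strings, which is why the Resampler passes the value through eval.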
nkululeko-0.84.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nkululeko
-Version: 0.84.0
+Version: 0.84.1
 Summary: Machine learning audio prediction experiments based on templates
 Home-page: https://github.com/felixbur/nkululeko
 Author: Felix Burkhardt
@@ -333,6 +333,10 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
 Changelog
 =========

+Version 0.84.1
+--------------
+* made resample independent of config file
+
 Version 0.84.0
 --------------
 * added SHAP analysis
nkululeko-0.84.1.dist-info/RECORD CHANGED
@@ -2,7 +2,7 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
 nkululeko/aug_train.py,sha256=YhuZnS_WVWnun9G-M6g5n6rbRxoVREz6Zh7k6qprFNQ,3194
 nkululeko/augment.py,sha256=4MG0apTAG5RgkuJrYEjGgDdbodZWi_HweSPNI1JJ5QA,3051
 nkululeko/cacheddataset.py,sha256=lIJ6hUo5LoxSrzXtWV8mzwO7wRtUETWnOQ4ws2XfL1E,969
-nkululeko/constants.py,sha256=Ij7mvqjHA328NaCRJL2JvyYgAPfkfYpVq_WiS735RQY,39
+nkululeko/constants.py,sha256=31GQXyAN-nrfQCNIt6_aSkBVeE_J3GO-PklTEy6EgBg,39
 nkululeko/demo.py,sha256=8bl15Kitoesnz8oa8yrs52T6YCSOhWbbq9PnZ8Hj6D0,3232
 nkululeko/demo_feats.py,sha256=sAeGFojhEj9WEDFtG3SzPBmyYJWLF2rkbpp65m8Ujo4,2025
 nkululeko/demo_predictor.py,sha256=es56xbT8ifkS_vnrlb5NTZT54gNmeUtNlA4zVA_gnN8,4757
@@ -19,19 +19,19 @@ nkululeko/nkuluflag.py,sha256=PGWSmZz-PiiHLgcZJAoGOI_Y-sZDVI1ksB8p5r7riWM,3725
 nkululeko/nkululeko.py,sha256=Kn3s2E3yyH8cJ7z6lkMxrnqtCxTu7-qfe9Zr_ONTD5g,1968
 nkululeko/plots.py,sha256=nd9tF_61DyAx7oGZF8gTrHXazkgFjFe4eClxu1nQ_XU,23276
 nkululeko/predict.py,sha256=sF091sSSLnEWcISx9ZcULLie3tY5XeFsQJd6b3vrxFg,2409
-nkululeko/resample.py,sha256=3WbxkwgyTe_fW38046Rjxk3knOkFdhqn2C4nfhbUurQ,2287
+nkululeko/resample.py,sha256=IPtYqU0nhZ-CqO_O1jJN0EvpfjxHZdFRwdTpEJOVuaQ,3354
 nkululeko/runmanager.py,sha256=eTM1DNQKt1lxYhzt4vZyZluPXW9sWlIJHNQzex4lkJU,7624
 nkululeko/scaler.py,sha256=4nkIqoajkIkuTPK0Z02ifMN_awl6fP_i-GBYdoGYgGM,4101
 nkululeko/segment.py,sha256=YLKckX44tbvTb3LrdgYw9X4guzuF27sutl92z9DkpZU,4835
 nkululeko/syllable_nuclei.py,sha256=Sky-C__MeUDaxqHnDl2TGLLYOYvsahD35TUjWGeG31k,10047
 nkululeko/test.py,sha256=1w624vo5KTzmFC8BUStGlLDmIEAFuJUz7J0W-gp7AxI,1677
 nkululeko/test_predictor.py,sha256=_w5J8CxH6hmW3mLTKbdfmywl5QpdNAnW1Y8TE5GtlfE,3237
-nkululeko/test_pretrain.py,sha256=aoN-C9M4Zn9LwseIWQdMMpEGclnkWs-gJXyItU5V0fk,3109
+nkululeko/test_pretrain.py,sha256=4b_39l01dySei_e0ys2NKo9Gipf1Fukp1GvhQllFHt8,8131
 nkululeko/augmenting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nkululeko/augmenting/augmenter.py,sha256=XAt0dpmlnKxqyysqCgV3rcz-pRIvOz7rU7dmGDCVAzs,2905
 nkululeko/augmenting/randomsplicer.py,sha256=Z5rxdKKUpuncLWuTS6xVfVKUeVbeiYU_dLRHQ5fcg4Y,2669
 nkululeko/augmenting/randomsplicing.py,sha256=ldym9vZNsZIU5BAAaJVaOmAgmVHNs4a5i5K3bW-WAQU,1791
-nkululeko/augmenting/resampler.py,sha256=cRrn27w_f2I6aN0CftlTuHT2edi7pTREh3Yc6BxhcGU,3335
+nkululeko/augmenting/resampler.py,sha256=nOBsiQpX6p4jXsP7x6wak78F3B5YYYRmC_iHX8iuOXs,3542
 nkululeko/autopredict/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nkululeko/autopredict/ap_age.py,sha256=2Wn5E-Jd49sTn40WqaMcYtUEl4zEq3OY75XmjOpdxsA,1095
 nkululeko/autopredict/ap_arousal.py,sha256=ymt0diu4v1osw3VxJbSglsVKDAJYRzebQ2TTfFMKKxk,1024
@@ -75,6 +75,7 @@ nkululeko/losses/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
 nkululeko/losses/loss_ccc.py,sha256=NOK0y0fxKUnU161B5geap6Fmn8QzoPl2MqtPiV8IuJE,976
 nkululeko/losses/loss_softf1loss.py,sha256=5gW-PuiqeAZcRgfwjueIOQtMokOjZWgQnVIv59HKTCo,1309
 nkululeko/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+nkululeko/models/finetune_model.py,sha256=bx9NsFpEqf_mBohcrf-9lWjrC4AtOIJ7holNXwaFo2Y,4910
 nkululeko/models/model.py,sha256=fL6LB6I9Oqo_OWUIptqiu6abuxVYYv8bW2a3m4XSLqU,11601
 nkululeko/models/model_bayes.py,sha256=WJFZ8wFKwWATz6MhmjeZIi1Pal1viU549WL_PjXDSy8,406
 nkululeko/models/model_cnn.py,sha256=bJxqwe6FnVR2hFeqN6EXexYGgvKYFED1VOhBXVlLWaE,9954
@@ -103,9 +104,9 @@ nkululeko/segmenting/seg_silero.py,sha256=lLytS38KzARS17omwv8VBw-zz60RVSXGSvZ5Ev
 nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
 nkululeko/utils/stats.py,sha256=1yUq0FTOyqkU8TwUocJRYdJaqMU5SlOBBRUun9STo2M,2829
-nkululeko/utils/util.py,sha256=lVKcIYHeN8wt7ap8o1UTx5z6nvOY40twJ_C4CshT42Y,13068
-nkululeko-0.84.0.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
-nkululeko-0.84.0.dist-info/METADATA,sha256=RJnEnBwqdKRLs4J16zOzsps0GXmdVvzPMi1_2hpZh-Q,36346
-nkululeko-0.84.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-nkululeko-0.84.0.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
-nkululeko-0.84.0.dist-info/RECORD,,
+nkululeko/utils/util.py,sha256=b1IHFucRNuF9Iyv5IJeK4AEg0Rga0xKG80UM5GWWdHA,13816
+nkululeko-0.84.1.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
+nkululeko-0.84.1.dist-info/METADATA,sha256=Y647w-vkRjPG7fssLTEF_Aa_pP74aN-WPCGv6r0_NcE,36420
+nkululeko-0.84.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+nkululeko-0.84.1.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
+nkululeko-0.84.1.dist-info/RECORD,,