nkululeko 0.83.3__py3-none-any.whl → 0.84.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nkululeko/augmenting/resampler.py +9 -4
- nkululeko/constants.py +1 -1
- nkululeko/demo.py +6 -7
- nkululeko/demo_predictor.py +4 -3
- nkululeko/experiment.py +15 -12
- nkululeko/explore.py +29 -23
- nkululeko/feat_extract/feats_analyser.py +33 -0
- nkululeko/glob_conf.py +5 -0
- nkululeko/models/finetune_model.py +181 -0
- nkululeko/models/model.py +1 -0
- nkululeko/models/model_bayes.py +1 -0
- nkululeko/models/model_cnn.py +6 -9
- nkululeko/models/model_gmm.py +2 -3
- nkululeko/models/model_knn.py +1 -0
- nkululeko/models/model_knn_reg.py +1 -0
- nkululeko/models/model_lin_reg.py +1 -0
- nkululeko/models/model_mlp.py +17 -7
- nkululeko/models/model_mlp_regression.py +7 -12
- nkululeko/models/model_svm.py +1 -0
- nkululeko/models/model_svr.py +1 -0
- nkululeko/models/model_tree.py +1 -0
- nkululeko/models/model_tree_reg.py +1 -0
- nkululeko/models/model_xgb.py +5 -3
- nkululeko/models/model_xgr.py +6 -4
- nkululeko/resample.py +76 -54
- nkululeko/test_pretrain.py +294 -0
- nkululeko/utils/util.py +81 -35
- {nkululeko-0.83.3.dist-info → nkululeko-0.84.1.dist-info}/METADATA +10 -1
- {nkululeko-0.83.3.dist-info → nkululeko-0.84.1.dist-info}/RECORD +32 -30
- {nkululeko-0.83.3.dist-info → nkululeko-0.84.1.dist-info}/LICENSE +0 -0
- {nkululeko-0.83.3.dist-info → nkululeko-0.84.1.dist-info}/WHEEL +0 -0
- {nkululeko-0.83.3.dist-info → nkululeko-0.84.1.dist-info}/top_level.txt +0 -0
@@ -25,6 +25,7 @@ class MLP_Reg_model(Model):
|
|
25
25
|
def __init__(self, df_train, df_test, feats_train, feats_test):
|
26
26
|
"""Constructor taking the configuration and all dataframes"""
|
27
27
|
super().__init__(df_train, df_test, feats_train, feats_test)
|
28
|
+
self.name = "mlp_reg"
|
28
29
|
super().set_model_type("ann")
|
29
30
|
self.target = glob_conf.config["DATA"]["target"]
|
30
31
|
labels = glob_conf.labels
|
@@ -52,8 +53,7 @@ class MLP_Reg_model(Model):
|
|
52
53
|
drop = self.util.config_val("MODEL", "drop", False)
|
53
54
|
if drop:
|
54
55
|
self.util.debug(f"training with dropout: {drop}")
|
55
|
-
self.model = self.MLP(
|
56
|
-
feats_train.shape[1], layers, 1, drop).to(self.device)
|
56
|
+
self.model = self.MLP(feats_train.shape[1], layers, 1, drop).to(self.device)
|
57
57
|
self.learning_rate = float(
|
58
58
|
self.util.config_val("MODEL", "learning_rate", 0.0001)
|
59
59
|
)
|
@@ -96,10 +96,8 @@ class MLP_Reg_model(Model):
|
|
96
96
|
_, truths, predictions = self.evaluate_model(
|
97
97
|
self.model, self.testloader, self.device
|
98
98
|
)
|
99
|
-
result, _, _ = self.evaluate_model(
|
100
|
-
|
101
|
-
report = Reporter(truths.numpy(), predictions.numpy(),
|
102
|
-
self.run, self.epoch)
|
99
|
+
result, _, _ = self.evaluate_model(self.model, self.trainloader, self.device)
|
100
|
+
report = Reporter(truths.numpy(), predictions.numpy(), self.run, self.epoch)
|
103
101
|
try:
|
104
102
|
report.result.loss = self.loss
|
105
103
|
except AttributeError: # if the model was loaded from disk the loss is unknown
|
@@ -133,11 +131,9 @@ class MLP_Reg_model(Model):
|
|
133
131
|
|
134
132
|
def __getitem__(self, item):
    """Return the ``(features, labels)`` pair for the sample at position ``item``.

    The position is translated to an index key via ``self.df.index``; that key
    selects the row in ``self.df_features`` and the ``self.label`` cell in
    ``self.df``. Both results are cast to float32 and squeezed.
    """
    key = self.df.index[item]
    feature_row = self.df_features.loc[key, :].values
    features = feature_row.astype("float32").squeeze()
    label_value = self.df.loc[key, self.label]
    # wrap the single label in a list so it becomes an array before squeezing
    labels = np.array([label_value]).astype("float32").squeeze()
    return features, labels
|
143
139
|
|
@@ -194,8 +190,7 @@ class MLP_Reg_model(Model):
|
|
194
190
|
end_index = (index + 1) * loader.batch_size
|
195
191
|
if end_index > len(loader.dataset):
|
196
192
|
end_index = len(loader.dataset)
|
197
|
-
logits[start_index:end_index] = model(
|
198
|
-
features.to(device)).reshape(-1)
|
193
|
+
logits[start_index:end_index] = model(features.to(device)).reshape(-1)
|
199
194
|
targets[start_index:end_index] = labels
|
200
195
|
loss = self.criterion(
|
201
196
|
logits[start_index:end_index].to(
|
nkululeko/models/model_svm.py
CHANGED
@@ -11,6 +11,7 @@ class SVM_model(Model):
|
|
11
11
|
|
12
12
|
def __init__(self, df_train, df_test, feats_train, feats_test):
|
13
13
|
super().__init__(df_train, df_test, feats_train, feats_test)
|
14
|
+
self.name = "svm"
|
14
15
|
c = float(self.util.config_val("MODEL", "C_val", "0.001"))
|
15
16
|
if eval(self.util.config_val("MODEL", "class_weight", "False")):
|
16
17
|
class_weight = "balanced"
|
nkululeko/models/model_svr.py
CHANGED
@@ -11,6 +11,7 @@ class SVR_model(Model):
|
|
11
11
|
|
12
12
|
def __init__(self, df_train, df_test, feats_train, feats_test):
|
13
13
|
super().__init__(df_train, df_test, feats_train, feats_test)
|
14
|
+
self.name = "svr"
|
14
15
|
c = float(self.util.config_val("MODEL", "C_val", "0.001"))
|
15
16
|
# kernel{‘linear’, ‘poly’, ‘rbf’, ‘sigmoid’, ‘precomputed’} or callable, default=’rbf’
|
16
17
|
kernel = self.util.config_val("MODEL", "kernel", "rbf")
|
nkululeko/models/model_tree.py
CHANGED
nkululeko/models/model_xgb.py
CHANGED
@@ -7,9 +7,11 @@ from nkululeko.models.model import Model
|
|
7
7
|
class XGB_model(Model):
|
8
8
|
"""An XGBoost model"""
|
9
9
|
|
10
|
-
|
11
|
-
|
12
|
-
|
10
|
+
def __init__(self, df_train, df_test, feats_train, feats_test):
    """Set up the base model, then attach an XGBoost classifier.

    All four arguments are forwarded unchanged to the parent constructor.
    """
    super().__init__(df_train, df_test, feats_train, feats_test)
    # the underlying estimator, created with its default parameters
    self.clf = XGBClassifier()
    self.is_classifier = True
    self.name = "xgb"
|
13
15
|
|
14
16
|
def get_type(self):
|
15
17
|
return "xgb"
|
nkululeko/models/model_xgr.py
CHANGED
@@ -5,8 +5,10 @@ from nkululeko.models.model import Model
|
|
5
5
|
|
6
6
|
|
7
7
|
class XGR_model(Model):
|
8
|
-
"""An XGBoost model"""
|
8
|
+
"""An XGBoost regression model"""
|
9
9
|
|
10
|
-
|
11
|
-
|
12
|
-
|
10
|
+
def __init__(self, df_train, df_test, feats_train, feats_test):
    """Set up the base model, then attach an XGBoost regressor.

    All four arguments are forwarded unchanged to the parent constructor.
    """
    super().__init__(df_train, df_test, feats_train, feats_test)
    # the underlying estimator, created with its default parameters
    self.clf = XGBRegressor()
    self.is_classifier = False
    self.name = "xgr"
|
nkululeko/resample.py
CHANGED
@@ -1,78 +1,100 @@
|
|
1
1
|
# resample.py
|
2
|
-
# change the sampling rate for train
|
2
|
+
# change the sampling rate for audio file or INI file (train, test, all)
|
3
3
|
|
4
4
|
import argparse
|
5
5
|
import configparser
|
6
6
|
import os
|
7
|
-
|
8
7
|
import pandas as pd
|
9
|
-
|
8
|
+
import audformat
|
10
9
|
from nkululeko.augmenting.resampler import Resampler
|
10
|
+
from nkululeko.utils.util import Util
|
11
|
+
|
11
12
|
from nkululeko.constants import VERSION
|
12
13
|
from nkululeko.experiment import Experiment
|
13
|
-
from nkululeko.utils.util import Util
|
14
14
|
|
15
15
|
|
16
16
|
def main(src_dir):
    """Command-line entry point of the nkululeko RESAMPLE module.

    Two modes are supported:

    * ``--file PATH``: resample a single audio file (``--replace`` controls
      whether the original file is replaced).
    * ``--config INI``: load the experiment described by the INI file and
      resample its train split, test split, or both, as selected by
      ``[RESAMPLE] sample_selection`` (``all`` | ``train`` | ``test``).

    Args:
        src_dir: directory of this script; accepted for interface
            compatibility but not used.

    Exits with status 1 if neither ``--file`` nor ``--config`` is given or if
    the configuration file does not exist.
    """
    # local import so this function does not depend on the module import block
    import sys

    parser = argparse.ArgumentParser(
        description="Call the nkululeko RESAMPLE framework.")
    parser.add_argument("--config", default=None,
                        help="The base configuration")
    parser.add_argument("--file", default=None,
                        help="The input audio file to resample")
    parser.add_argument("--replace", action="store_true",
                        help="Replace the original audio file")

    args = parser.parse_args()

    if args.file is None and args.config is None:
        print("ERROR: Either --file or --config argument must be provided.")
        # sys.exit instead of the interactive-only exit(); non-zero status
        # lets calling scripts detect the failure
        sys.exit(1)

    if args.file is not None:
        if args.config is not None:
            # --file takes precedence; say so instead of ignoring silently
            print(
                f"WARNING: --file given, ignoring --config {args.config}")

        # Load the audio file into a DataFrame
        files = pd.Series([args.file])
        df_sample = pd.DataFrame(index=files)
        df_sample.index = audformat.utils.to_segmented_index(
            df_sample.index, allow_nat=False
        )

        # Resample the audio file
        util = Util("resampler", has_config=False)
        util.debug(f"Resampling audio file: {args.file}")
        rs = Resampler(df_sample, not_testing=True, replace=args.replace)
        rs.resample()
    else:
        config_file = args.config

        # Test if the configuration file exists
        if not os.path.isfile(config_file):
            print(f"ERROR: no such file: {config_file}")
            sys.exit(1)

        # Load one configuration per experiment
        config = configparser.ConfigParser()
        config.read(config_file)
        # Create a new experiment
        expr = Experiment(config)
        module = "resample"
        expr.set_module(module)
        util = Util(module)
        util.debug(
            f"running {expr.name} from config {config_file}, nkululeko version"
            f" {VERSION}"
        )

        if util.config_val("EXP", "no_warnings", False):
            import warnings
            warnings.filterwarnings("ignore")

        # Load the data
        expr.load_datasets()

        # Split into train and test
        expr.fill_train_and_tests()
        util.debug(
            f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}")

        sample_selection = util.config_val(
            "RESAMPLE", "sample_selection", "all")
        if sample_selection == "all":
            df = pd.concat([expr.df_train, expr.df_test])
        elif sample_selection == "train":
            df = expr.df_train
        elif sample_selection == "test":
            df = expr.df_test
        else:
            # NOTE(review): presumably util.error terminates the program;
            # if it returns, `df` below is unbound -- confirm
            util.error(
                f"unknown selection specifier {sample_selection}, should be [all |"
                " train | test]"
            )
        util.debug(f"resampling {sample_selection}: {df.shape[0]} samples")
        # NOTE(review): the default here is the *string* "False", which is
        # truthy; confirm that Resampler parses this value rather than
        # testing it directly
        replace = util.config_val("RESAMPLE", "replace", "False")
        rs = Resampler(df, replace=replace)
        rs.resample()
|
74
96
|
|
75
97
|
|
76
98
|
if __name__ == "__main__":
    # hand the directory containing this script to the entry point
    main(os.path.dirname(os.path.abspath(__file__)))
|
@@ -0,0 +1,294 @@
|
|
1
|
+
# test_pretrain.py
|
2
|
+
import argparse
|
3
|
+
import configparser
|
4
|
+
import os.path
|
5
|
+
|
6
|
+
import datasets
|
7
|
+
import numpy as np
|
8
|
+
import pandas as pd
|
9
|
+
import torch
|
10
|
+
import transformers
|
11
|
+
|
12
|
+
import audeer
|
13
|
+
import audiofile
|
14
|
+
import audmetric
|
15
|
+
|
16
|
+
from nkululeko.constants import VERSION
|
17
|
+
import nkululeko.experiment as exp
|
18
|
+
import nkululeko.models.finetune_model as fm
|
19
|
+
import nkululeko.glob_conf as glob_conf
|
20
|
+
from nkululeko.utils.util import Util
|
21
|
+
import json
|
22
|
+
|
23
|
+
|
24
|
+
def doit(config_file):
    """Fine-tune a pretrained wav2vec2 model on the experiment's data.

    Loads the nkululeko experiment described by ``config_file``, converts its
    train and test splits to Hugging Face datasets ("train" / "dev"), and
    fine-tunes ``facebook/wav2vec2-large-robust-ft-swbd-300h`` with a
    class-weighted cross-entropy loss using ``transformers.Trainer``.

    Side effects: creates the directories ``log`` and ``model`` and writes
    ``vocab.json`` plus tokenizer files into the current working directory;
    the final model is saved under ``model/torch``.

    Args:
        config_file: path to the experiment INI file.

    Exits with status 1 if ``config_file`` does not exist.
    """
    # local import so this function does not depend on the module import block
    import sys

    # test if the configuration file exists
    if not os.path.isfile(config_file):
        print(f"ERROR: no such file: {config_file}")
        # non-zero status so calling scripts can detect the failure
        sys.exit(1)

    # load one configuration per experiment
    config = configparser.ConfigParser()
    config.read(config_file)

    # create a new experiment
    expr = exp.Experiment(config)
    module = "test_pretrain"
    expr.set_module(module)
    util = Util(module)
    util.debug(
        f"running {expr.name} from config {config_file}, nkululeko version"
        f" {VERSION}"
    )

    if util.config_val("EXP", "no_warnings", False):
        import warnings

        warnings.filterwarnings("ignore")

    # load the data
    expr.load_datasets()

    # split into train and test
    expr.fill_train_and_tests()
    util.debug(f"train shape : {expr.df_train.shape}, test shape:{expr.df_test.shape}")

    log_root = audeer.mkdir("log")
    model_root = audeer.mkdir("model")
    torch_root = audeer.path(model_root, "torch")

    metrics_gender = {
        "UAR": audmetric.unweighted_average_recall,
        "ACC": audmetric.accuracy,
    }

    sampling_rate = 16000
    max_duration_sec = 8.0

    model_path = "facebook/wav2vec2-large-robust-ft-swbd-300h"
    num_layers = None

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # keep "3" as the default GPU, but respect a value set by the caller
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")
    # fall back to CPU instead of crashing when no CUDA device is visible
    device = "cuda" if torch.cuda.is_available() else "cpu"

    batch_size = 16
    accumulation_steps = 4

    # create dataset
    dataset = {}
    target_name = glob_conf.target
    data_sources = {
        "train": pd.DataFrame(expr.df_train[target_name]),
        "dev": pd.DataFrame(expr.df_test[target_name]),
    }

    for split in ["train", "dev"]:
        df = data_sources[split]
        df[target_name] = df[target_name].astype("float")

        # one tuple of target values per row, stored as an object column
        y = pd.Series(
            data=df.itertuples(index=False, name=None),
            index=df.index,
            dtype=object,
            name="targets",
        )

        # assumes a segmented (file, start, end) index -- the columns
        # accessed below come from resetting it
        df = y.reset_index()
        df.start = df.start.dt.total_seconds()
        df.end = df.end.dt.total_seconds()

        print(f"{split}: {len(df)}")

        dataset[split] = datasets.Dataset.from_pandas(df)

    dataset = datasets.DatasetDict(dataset)

    # load pre-trained model
    le = glob_conf.label_encoder
    mapping = dict(zip(le.classes_, range(len(le.classes_))))
    target_mapping = {k: int(v) for k, v in mapping.items()}
    target_mapping_reverse = {value: key for key, value in target_mapping.items()}

    config = transformers.AutoConfig.from_pretrained(
        model_path,
        num_labels=len(target_mapping),
        label2id=target_mapping,
        id2label=target_mapping_reverse,
        finetuning_task=target_name,
    )
    if num_layers is not None:
        config.num_hidden_layers = num_layers
    config.sampling_rate = sampling_rate
    config.data = util.get_data_name()

    # the processor requires a tokenizer; an empty vocabulary file is
    # written to the working directory to construct one
    vocab_dict = {}
    with open("vocab.json", "w") as vocab_file:
        json.dump(vocab_dict, vocab_file)
    tokenizer = transformers.Wav2Vec2CTCTokenizer("./vocab.json")
    tokenizer.save_pretrained(".")

    feature_extractor = transformers.Wav2Vec2FeatureExtractor(
        feature_size=1,
        sampling_rate=sampling_rate,
        padding_value=0.0,
        do_normalize=True,
        return_attention_mask=True,
    )
    processor = transformers.Wav2Vec2Processor(
        feature_extractor=feature_extractor,
        tokenizer=tokenizer,
    )
    assert processor.feature_extractor.sampling_rate == sampling_rate

    model = fm.Model.from_pretrained(
        model_path,
        config=config,
    )
    model.freeze_feature_extractor()
    model.train()

    # training

    def data_collator(data):
        """Read the audio segment of every item and build a padded batch."""
        files = [d["file"] for d in data]
        starts = [d["start"] for d in data]
        ends = [d["end"] for d in data]
        targets = [d["targets"] for d in data]

        signals = []
        for file, start, end in zip(files, starts, ends):
            offset = start
            duration = end - offset
            if max_duration_sec is not None:
                # cap the segment length
                duration = min(duration, max_duration_sec)
            signal, _ = audiofile.read(
                file,
                offset=offset,
                duration=duration,
            )
            signals.append(signal.squeeze())

        input_values = processor(
            signals,
            sampling_rate=sampling_rate,
            padding=True,
        )
        batch = processor.pad(
            input_values,
            padding=True,
            return_tensors="pt",
        )

        batch["labels"] = torch.tensor(targets)

        return batch

    def compute_metrics(p: transformers.EvalPrediction):
        """Score the arg-max predictions against the first label column."""
        truth_gender = p.label_ids[:, 0].astype(int)
        preds = p.predictions
        preds_gender = np.argmax(preds, axis=1)

        scores = {}

        for name, metric in metrics_gender.items():
            scores[f"gender-{name}"] = metric(truth_gender, preds_gender)

        # "combined" is the metric used for best-model selection below
        scores["combined"] = scores["gender-UAR"]

        return scores

    # inverse-frequency class weights, normalised to sum to 1
    targets = pd.DataFrame(dataset["train"]["targets"])
    counts = targets[0].value_counts().sort_index()
    train_weights = 1 / counts
    train_weights /= train_weights.sum()

    print(train_weights)

    criterion_gender = torch.nn.CrossEntropyLoss(
        # explicit conversion of the Series values to a float32 tensor
        weight=torch.tensor(train_weights.to_numpy(), dtype=torch.float32).to(device),
    )

    class Trainer(transformers.Trainer):
        """Trainer that applies the class-weighted cross-entropy loss."""

        def compute_loss(
            self,
            model,
            inputs,
            return_outputs=False,
        ):
            targets = inputs.pop("labels").squeeze()
            targets_gender = targets.type(torch.long)

            outputs = model(**inputs)
            logits_gender = outputs[0].squeeze()

            loss = criterion_gender(logits_gender, targets_gender)

            return (loss, outputs) if return_outputs else loss

    # evaluate / save / log a fifth of the optimiser steps per epoch apart,
    # but at least every step
    num_steps = len(dataset["train"]) // (batch_size * accumulation_steps) // 5
    num_steps = max(1, num_steps)
    print(num_steps)

    training_args = transformers.TrainingArguments(
        output_dir=model_root,
        logging_dir=log_root,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=accumulation_steps,
        evaluation_strategy="steps",
        num_train_epochs=5.0,
        # mixed precision only when training on a CUDA device
        fp16=(device == "cuda"),
        save_steps=num_steps,
        eval_steps=num_steps,
        logging_steps=num_steps,
        learning_rate=1e-4,
        save_total_limit=2,
        metric_for_best_model="combined",
        greater_is_better=True,
        load_best_model_at_end=True,
        # the collator needs the file/start/end columns
        remove_unused_columns=False,
    )

    trainer = Trainer(
        model=model,
        data_collator=data_collator,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=dataset["train"],
        eval_dataset=dataset["dev"],
        tokenizer=processor.feature_extractor,
        callbacks=[transformers.integrations.TensorBoardCallback()],
    )

    trainer.train()
    trainer.save_model(torch_root)

    print("DONE")
|
279
|
+
|
280
|
+
|
281
|
+
def main(src_dir):
    """Parse the command line and run the fine-tuning experiment.

    Args:
        src_dir: directory of this script; accepted for interface
            compatibility but not used, because ``--config`` always has a
            value (its default is ``exp.ini``).
    """
    parser = argparse.ArgumentParser(description="Call the nkululeko framework.")
    parser.add_argument("--config", default="exp.ini", help="The base configuration")
    args = parser.parse_args()
    # the default above guarantees args.config is set, so no fallback to
    # src_dir is needed
    doit(args.config)
|
290
|
+
|
291
|
+
|
292
|
+
if __name__ == "__main__":
    # use this if you want to state the config file path on command line
    main(os.path.dirname(os.path.abspath(__file__)))
|