nkululeko 0.85.1__py3-none-any.whl → 0.86.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nkululeko/constants.py +1 -1
- nkululeko/data/dataset_csv.py +7 -4
- nkululeko/experiment.py +18 -14
- nkululeko/feat_extract/feats_opensmile.py +25 -25
- nkululeko/feat_extract/featureset.py +4 -4
- nkululeko/models/model_tuned.py +149 -88
- {nkululeko-0.85.1.dist-info → nkululeko-0.86.0.dist-info}/METADATA +11 -1
- {nkululeko-0.85.1.dist-info → nkululeko-0.86.0.dist-info}/RECORD +11 -12
- nkululeko/models/finetune_model.py +0 -190
- {nkululeko-0.85.1.dist-info → nkululeko-0.86.0.dist-info}/LICENSE +0 -0
- {nkululeko-0.85.1.dist-info → nkululeko-0.86.0.dist-info}/WHEEL +0 -0
- {nkululeko-0.85.1.dist-info → nkululeko-0.86.0.dist-info}/top_level.txt +0 -0
nkululeko/constants.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
|
-
VERSION="0.
|
1
|
+
VERSION="0.86.0"
|
2
2
|
SAMPLING_RATE = 16000
|
nkululeko/data/dataset_csv.py
CHANGED
@@ -21,7 +21,7 @@ class Dataset_CSV(Dataset):
|
|
21
21
|
# exp_root = self.util.config_val("EXP", "root", "")
|
22
22
|
# data_file = os.path.join(exp_root, data_file)
|
23
23
|
root = os.path.dirname(data_file)
|
24
|
-
audio_path = self.util.config_val_data(self.name, "audio_path", "")
|
24
|
+
audio_path = self.util.config_val_data(self.name, "audio_path", "./")
|
25
25
|
df = pd.read_csv(data_file)
|
26
26
|
# special treatment for segmented dataframes with only one column:
|
27
27
|
if "start" in df.columns and len(df.columns) == 4:
|
@@ -49,7 +49,8 @@ class Dataset_CSV(Dataset):
|
|
49
49
|
.map(lambda x: root + "/" + audio_path + "/" + x)
|
50
50
|
.values
|
51
51
|
)
|
52
|
-
df = df.set_index(df.index.set_levels(
|
52
|
+
df = df.set_index(df.index.set_levels(
|
53
|
+
file_index, level="file"))
|
53
54
|
else:
|
54
55
|
if not isinstance(df, pd.DataFrame):
|
55
56
|
df = pd.DataFrame(df)
|
@@ -63,7 +64,8 @@ class Dataset_CSV(Dataset):
|
|
63
64
|
self.db = None
|
64
65
|
self.got_target = True
|
65
66
|
self.is_labeled = self.got_target
|
66
|
-
self.start_fresh = eval(
|
67
|
+
self.start_fresh = eval(
|
68
|
+
self.util.config_val("DATA", "no_reuse", "False"))
|
67
69
|
is_index = False
|
68
70
|
try:
|
69
71
|
if self.is_labeled and not "class_label" in self.df.columns:
|
@@ -90,7 +92,8 @@ class Dataset_CSV(Dataset):
|
|
90
92
|
f" {self.got_gender}, got age: {self.got_age}"
|
91
93
|
)
|
92
94
|
self.util.debug(r_string)
|
93
|
-
glob_conf.report.add_item(ReportItem(
|
95
|
+
glob_conf.report.add_item(ReportItem(
|
96
|
+
"Data", "Loaded report", r_string))
|
94
97
|
|
95
98
|
def prepare(self):
|
96
99
|
super().prepare()
|
nkululeko/experiment.py
CHANGED
@@ -30,15 +30,14 @@ from nkululeko.utils.util import Util
|
|
30
30
|
|
31
31
|
|
32
32
|
class Experiment:
|
33
|
-
"""Main class specifying an experiment"""
|
33
|
+
"""Main class specifying an experiment."""
|
34
34
|
|
35
35
|
def __init__(self, config_obj):
|
36
|
-
"""
|
37
|
-
Parameters
|
38
|
-
----------
|
39
|
-
config_obj : a config parser object that sets the experiment parameters and being set as a global object.
|
40
|
-
"""
|
36
|
+
"""Constructor.
|
41
37
|
|
38
|
+
Args:
|
39
|
+
- config_obj : a config parser object that sets the experiment parameters and being set as a global object.
|
40
|
+
"""
|
42
41
|
self.set_globals(config_obj)
|
43
42
|
self.name = glob_conf.config["EXP"]["name"]
|
44
43
|
self.root = os.path.join(glob_conf.config["EXP"]["root"], "")
|
@@ -109,14 +108,13 @@ class Experiment:
|
|
109
108
|
# print keys/column
|
110
109
|
dbs = ",".join(list(self.datasets.keys()))
|
111
110
|
labels = self.util.config_val("DATA", "labels", False)
|
111
|
+
auto_labels = list(next(iter(self.datasets.values())).df[self.target].unique())
|
112
112
|
if labels:
|
113
113
|
self.labels = ast.literal_eval(labels)
|
114
114
|
self.util.debug(f"Target labels (from config): {labels}")
|
115
115
|
else:
|
116
|
-
self.labels =
|
117
|
-
|
118
|
-
)
|
119
|
-
self.util.debug(f"Target labels (from database): {labels}")
|
116
|
+
self.labels = auto_labels
|
117
|
+
self.util.debug(f"Target labels (from database): {auto_labels}")
|
120
118
|
glob_conf.set_labels(self.labels)
|
121
119
|
self.util.debug(f"loaded databases {dbs}")
|
122
120
|
|
@@ -373,14 +371,18 @@ class Experiment:
|
|
373
371
|
f" ({self.df_test.shape[0]})"
|
374
372
|
)
|
375
373
|
self.df_test = self.df_test[self.df_test.index.isin(self.feats_test.index)]
|
376
|
-
self.util.warn(f"
|
374
|
+
self.util.warn(f"new test labels shape: {self.df_test.shape[0]}")
|
377
375
|
|
378
376
|
self._check_scale()
|
377
|
+
# store = self.util.get_path("store")
|
378
|
+
# store_format = self.util.config_val("FEATS", "store_format", "pkl")
|
379
|
+
# storage = f"{store}test_feats.{store_format}"
|
380
|
+
# self.util.write_store(self.feats_test, storage, store_format)
|
381
|
+
# storage = f"{store}train_feats.{store_format}"
|
382
|
+
# self.util.write_store(self.feats_train, storage, store_format)
|
379
383
|
|
380
384
|
def augment(self):
|
381
|
-
"""
|
382
|
-
Augment the selected samples
|
383
|
-
"""
|
385
|
+
"""Augment the selected samples."""
|
384
386
|
from nkululeko.augmenting.augmenter import Augmenter
|
385
387
|
|
386
388
|
sample_selection = self.util.config_val("AUGMENT", "sample_selection", "all")
|
@@ -577,6 +579,8 @@ class Experiment:
|
|
577
579
|
)
|
578
580
|
|
579
581
|
def _check_scale(self):
|
582
|
+
self.util.save_to_store(self.feats_train, "feats_train")
|
583
|
+
self.util.save_to_store(self.feats_test, "feats_test")
|
580
584
|
scale_feats = self.util.config_val("FEATS", "scale", False)
|
581
585
|
# print the scale
|
582
586
|
self.util.debug(f"scaler: {scale_feats}")
|
@@ -65,28 +65,28 @@ class Opensmileset(Featureset):
|
|
65
65
|
feats = smile.process_signal(signal, sr)
|
66
66
|
return feats.to_numpy()
|
67
67
|
|
68
|
-
def filter(self):
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
68
|
+
# def filter(self):
|
69
|
+
# # use only the features that are indexed in the target dataframes
|
70
|
+
# self.df = self.df[self.df.index.isin(self.data_df.index)]
|
71
|
+
# try:
|
72
|
+
# # use only some features
|
73
|
+
# selected_features = ast.literal_eval(
|
74
|
+
# glob_conf.config["FEATS"]["os.features"]
|
75
|
+
# )
|
76
|
+
# self.util.debug(f"selecting features from opensmile: {selected_features}")
|
77
|
+
# sel_feats_df = pd.DataFrame()
|
78
|
+
# hit = False
|
79
|
+
# for feat in selected_features:
|
80
|
+
# try:
|
81
|
+
# sel_feats_df[feat] = self.df[feat]
|
82
|
+
# hit = True
|
83
|
+
# except KeyError:
|
84
|
+
# pass
|
85
|
+
# if hit:
|
86
|
+
# self.df = sel_feats_df
|
87
|
+
# self.util.debug(
|
88
|
+
# "new feats shape after selecting opensmile features:"
|
89
|
+
# f" {self.df.shape}"
|
90
|
+
# )
|
91
|
+
# except KeyError:
|
92
|
+
# pass
|
@@ -15,7 +15,7 @@ class Featureset:
|
|
15
15
|
self.name = name
|
16
16
|
self.data_df = data_df
|
17
17
|
self.util = Util("featureset")
|
18
|
-
self.
|
18
|
+
self.feats_type = feats_type
|
19
19
|
|
20
20
|
def extract(self):
|
21
21
|
pass
|
@@ -25,8 +25,7 @@ class Featureset:
|
|
25
25
|
self.df = self.df[self.df.index.isin(self.data_df.index)]
|
26
26
|
try:
|
27
27
|
# use only some features
|
28
|
-
selected_features = ast.literal_eval(
|
29
|
-
glob_conf.config["FEATS"]["features"])
|
28
|
+
selected_features = ast.literal_eval(glob_conf.config["FEATS"]["features"])
|
30
29
|
self.util.debug(f"selecting features: {selected_features}")
|
31
30
|
sel_feats_df = pd.DataFrame()
|
32
31
|
hit = False
|
@@ -35,11 +34,12 @@ class Featureset:
|
|
35
34
|
sel_feats_df[feat] = self.df[feat]
|
36
35
|
hit = True
|
37
36
|
except KeyError:
|
37
|
+
self.util.warn(f"non existent feature in {self.feats_type}: {feat}")
|
38
38
|
pass
|
39
39
|
if hit:
|
40
40
|
self.df = sel_feats_df
|
41
41
|
self.util.debug(
|
42
|
-
f"new feats shape after selecting features: {self.df.shape}"
|
42
|
+
f"new feats shape after selecting features for {self.feats_type}: {self.df.shape}"
|
43
43
|
)
|
44
44
|
except KeyError:
|
45
45
|
pass
|
nkululeko/models/model_tuned.py
CHANGED
@@ -1,6 +1,4 @@
|
|
1
|
-
"""
|
2
|
-
Code based on @jwagner.
|
3
|
-
"""
|
1
|
+
"""Code based on @jwagner."""
|
4
2
|
|
5
3
|
import dataclasses
|
6
4
|
import json
|
@@ -27,8 +25,6 @@ from nkululeko.reporting.reporter import Reporter
|
|
27
25
|
|
28
26
|
class TunedModel(BaseModel):
|
29
27
|
|
30
|
-
is_classifier = True
|
31
|
-
|
32
28
|
def __init__(self, df_train, df_test, feats_train, feats_test):
|
33
29
|
"""Constructor taking the configuration and all dataframes."""
|
34
30
|
super().__init__(df_train, df_test, feats_train, feats_test)
|
@@ -37,25 +33,47 @@ class TunedModel(BaseModel):
|
|
37
33
|
self.target = glob_conf.config["DATA"]["target"]
|
38
34
|
labels = glob_conf.labels
|
39
35
|
self.class_num = len(labels)
|
40
|
-
device = self.util.config_val("MODEL", "device",
|
41
|
-
|
42
|
-
|
43
|
-
|
36
|
+
device = self.util.config_val("MODEL", "device", False)
|
37
|
+
if not device:
|
38
|
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
39
|
+
else:
|
40
|
+
self.device = device
|
41
|
+
if self.device != "cpu":
|
44
42
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
45
|
-
os.environ["CUDA_VISIBLE_DEVICES"] = device
|
43
|
+
os.environ["CUDA_VISIBLE_DEVICES"] = self.device
|
44
|
+
self.util.debug(f"running on device {self.device}")
|
45
|
+
self.is_classifier = self.util.exp_is_classification()
|
46
|
+
if self.is_classifier:
|
47
|
+
self.measure = "uar"
|
48
|
+
else:
|
49
|
+
self.measure = self.util.config_val("MODEL", "measure", "ccc")
|
50
|
+
self.util.debug(f"evaluation metrics: {self.measure}")
|
51
|
+
self.batch_size = int(self.util.config_val("MODEL", "batch_size", "8"))
|
52
|
+
self.util.debug(f"batch size: {self.batch_size}")
|
53
|
+
self.learning_rate = float(
|
54
|
+
self.util.config_val("MODEL", "learning_rate", 0.0001)
|
55
|
+
)
|
46
56
|
self.df_train, self.df_test = df_train, df_test
|
47
57
|
self.epoch_num = int(self.util.config_val("EXP", "epochs", 1))
|
48
|
-
|
58
|
+
drop = self.util.config_val("MODEL", "drop", False)
|
59
|
+
self.drop = 0.1
|
60
|
+
if drop:
|
61
|
+
self.drop = float(drop)
|
62
|
+
self.util.debug(f"init: training with dropout: {self.drop}")
|
49
63
|
self._init_model()
|
50
64
|
|
51
65
|
def _init_model(self):
|
52
66
|
model_path = "facebook/wav2vec2-large-robust-ft-swbd-300h"
|
67
|
+
pretrained_model = self.util.config_val("MODEL", "pretrained_model", model_path)
|
53
68
|
self.num_layers = None
|
54
69
|
self.sampling_rate = 16000
|
55
70
|
self.max_duration_sec = 8.0
|
56
71
|
self.accumulation_steps = 4
|
57
|
-
# create dataset
|
58
72
|
|
73
|
+
# print finetuning information via debug
|
74
|
+
self.util.debug(f"Finetuning from model: {pretrained_model}")
|
75
|
+
|
76
|
+
# create dataset
|
59
77
|
dataset = {}
|
60
78
|
target_name = glob_conf.target
|
61
79
|
data_sources = {
|
@@ -76,22 +94,32 @@ class TunedModel(BaseModel):
|
|
76
94
|
self.dataset = datasets.DatasetDict(dataset)
|
77
95
|
|
78
96
|
# load pre-trained model
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
97
|
+
if self.is_classifier:
|
98
|
+
le = glob_conf.label_encoder
|
99
|
+
mapping = dict(zip(le.classes_, range(len(le.classes_))))
|
100
|
+
target_mapping = {k: int(v) for k, v in mapping.items()}
|
101
|
+
target_mapping_reverse = {
|
102
|
+
value: key for key, value in target_mapping.items()
|
103
|
+
}
|
104
|
+
self.config = transformers.AutoConfig.from_pretrained(
|
105
|
+
model_path,
|
106
|
+
num_labels=len(target_mapping),
|
107
|
+
label2id=target_mapping,
|
108
|
+
id2label=target_mapping_reverse,
|
109
|
+
finetuning_task=target_name,
|
110
|
+
)
|
111
|
+
else:
|
112
|
+
self.config = transformers.AutoConfig.from_pretrained(
|
113
|
+
model_path,
|
114
|
+
num_labels=1,
|
115
|
+
finetuning_task=target_name,
|
116
|
+
)
|
91
117
|
if self.num_layers is not None:
|
92
118
|
self.config.num_hidden_layers = self.num_layers
|
119
|
+
self.config.final_dropout = self.drop
|
93
120
|
setattr(self.config, "sampling_rate", self.sampling_rate)
|
94
121
|
setattr(self.config, "data", self.util.get_data_name())
|
122
|
+
setattr(self.config, "is_classifier", self.is_classifier)
|
95
123
|
|
96
124
|
vocab_dict = {}
|
97
125
|
with open("vocab.json", "w") as vocab_file:
|
@@ -113,7 +141,7 @@ class TunedModel(BaseModel):
|
|
113
141
|
assert self.processor.feature_extractor.sampling_rate == self.sampling_rate
|
114
142
|
|
115
143
|
self.model = Model.from_pretrained(
|
116
|
-
|
144
|
+
pretrained_model,
|
117
145
|
config=self.config,
|
118
146
|
)
|
119
147
|
self.model.freeze_feature_extractor()
|
@@ -170,7 +198,7 @@ class TunedModel(BaseModel):
|
|
170
198
|
return_tensors="pt",
|
171
199
|
)
|
172
200
|
|
173
|
-
batch["labels"] = torch.
|
201
|
+
batch["labels"] = torch.Tensor(targets)
|
174
202
|
|
175
203
|
return batch
|
176
204
|
|
@@ -180,14 +208,25 @@ class TunedModel(BaseModel):
|
|
180
208
|
"UAR": audmetric.unweighted_average_recall,
|
181
209
|
"ACC": audmetric.accuracy,
|
182
210
|
}
|
211
|
+
metrics_reg = {
|
212
|
+
"PCC": audmetric.pearson_cc,
|
213
|
+
"CCC": audmetric.concordance_cc,
|
214
|
+
"MSE": audmetric.mean_squared_error,
|
215
|
+
"MAE": audmetric.mean_absolute_error,
|
216
|
+
}
|
183
217
|
|
184
218
|
# truth = p.label_ids[:, 0].astype(int)
|
185
219
|
truth = p.label_ids
|
186
220
|
preds = p.predictions
|
187
221
|
preds = np.argmax(preds, axis=1)
|
188
222
|
scores = {}
|
189
|
-
|
190
|
-
|
223
|
+
if self.is_classifier:
|
224
|
+
for name, metric in metrics.items():
|
225
|
+
scores[f"{name}"] = metric(truth, preds)
|
226
|
+
else:
|
227
|
+
for name, metric in metrics_reg.items():
|
228
|
+
scores[f"{name}"] = metric(truth, preds)
|
229
|
+
|
191
230
|
return scores
|
192
231
|
|
193
232
|
def train(self):
|
@@ -203,23 +242,24 @@ class TunedModel(BaseModel):
|
|
203
242
|
return
|
204
243
|
targets = pd.DataFrame(self.dataset["train"]["targets"])
|
205
244
|
counts = targets[0].value_counts().sort_index()
|
206
|
-
train_weights = 1 / counts
|
207
|
-
train_weights /= train_weights.sum()
|
208
|
-
self.util.debug("train weights: {train_weights}")
|
209
|
-
criterion = torch.nn.CrossEntropyLoss(
|
210
|
-
weight=torch.Tensor(train_weights).to("cuda"),
|
211
|
-
)
|
212
|
-
# criterion = torch.nn.CrossEntropyLoss()
|
213
245
|
|
214
|
-
|
246
|
+
if self.is_classifier:
|
247
|
+
train_weights = 1 / counts
|
248
|
+
train_weights /= train_weights.sum()
|
249
|
+
self.util.debug(f"train weights: {train_weights}")
|
250
|
+
criterion = torch.nn.CrossEntropyLoss(
|
251
|
+
weight=torch.Tensor(train_weights).to("cuda"),
|
252
|
+
)
|
253
|
+
else:
|
254
|
+
criterion = ConcordanceCorCoeff()
|
215
255
|
|
256
|
+
class Trainer(transformers.Trainer):
|
216
257
|
def compute_loss(
|
217
258
|
self,
|
218
259
|
model,
|
219
260
|
inputs,
|
220
261
|
return_outputs=False,
|
221
262
|
):
|
222
|
-
|
223
263
|
targets = inputs.pop("labels").squeeze()
|
224
264
|
targets = targets.type(torch.long)
|
225
265
|
|
@@ -236,7 +276,8 @@ class TunedModel(BaseModel):
|
|
236
276
|
// 5
|
237
277
|
)
|
238
278
|
num_steps = max(1, num_steps)
|
239
|
-
|
279
|
+
|
280
|
+
metrics_for_best_model = self.measure.upper()
|
240
281
|
|
241
282
|
training_args = transformers.TrainingArguments(
|
242
283
|
output_dir=model_root,
|
@@ -246,13 +287,14 @@ class TunedModel(BaseModel):
|
|
246
287
|
gradient_accumulation_steps=self.accumulation_steps,
|
247
288
|
evaluation_strategy="steps",
|
248
289
|
num_train_epochs=self.epoch_num,
|
249
|
-
fp16=
|
290
|
+
fp16=self.device == "cuda",
|
250
291
|
save_steps=num_steps,
|
251
292
|
eval_steps=num_steps,
|
252
293
|
logging_steps=num_steps,
|
253
|
-
|
294
|
+
logging_strategy="epoch",
|
295
|
+
learning_rate=self.learning_rate,
|
254
296
|
save_total_limit=2,
|
255
|
-
metric_for_best_model=
|
297
|
+
metric_for_best_model=metrics_for_best_model,
|
256
298
|
greater_is_better=True,
|
257
299
|
load_best_model_at_end=True,
|
258
300
|
remove_unused_columns=False,
|
@@ -271,6 +313,7 @@ class TunedModel(BaseModel):
|
|
271
313
|
)
|
272
314
|
trainer.train()
|
273
315
|
trainer.save_model(self.torch_root)
|
316
|
+
self.util.debug(f"saved best model to {self.torch_root}")
|
274
317
|
self.load(self.run, self.epoch)
|
275
318
|
|
276
319
|
def get_predictions(self):
|
@@ -305,7 +348,7 @@ class TunedModel(BaseModel):
|
|
305
348
|
def predict_sample(self, signal):
|
306
349
|
"""Predict one sample"""
|
307
350
|
prediction = {}
|
308
|
-
if self.
|
351
|
+
if self.is_classifier:
|
309
352
|
# get the class probabilities
|
310
353
|
predictions = self.model.predict(signal)
|
311
354
|
# pred = self.clf.predict(features)
|
@@ -337,8 +380,19 @@ class TunedModel(BaseModel):
|
|
337
380
|
@dataclasses.dataclass
|
338
381
|
class ModelOutput(transformers.file_utils.ModelOutput):
|
339
382
|
|
340
|
-
|
383
|
+
logits: torch.FloatTensor = None
|
384
|
+
hidden_states: typing.Tuple[torch.FloatTensor] = None
|
385
|
+
cnn_features: torch.FloatTensor = None
|
386
|
+
|
387
|
+
|
388
|
+
@dataclasses.dataclass
|
389
|
+
class ModelOutputReg(transformers.file_utils.ModelOutput):
|
390
|
+
|
391
|
+
logits: torch.FloatTensor
|
341
392
|
hidden_states: typing.Tuple[torch.FloatTensor] = None
|
393
|
+
attentions: typing.Tuple[torch.FloatTensor] = None
|
394
|
+
logits_framewise: torch.FloatTensor = None
|
395
|
+
hidden_states_framewise: torch.FloatTensor = None
|
342
396
|
cnn_features: torch.FloatTensor = None
|
343
397
|
|
344
398
|
|
@@ -368,10 +422,14 @@ class Model(Wav2Vec2PreTrainedModel):
|
|
368
422
|
|
369
423
|
def __init__(self, config):
|
370
424
|
|
425
|
+
if not hasattr(config, "add_adapter"):
|
426
|
+
setattr(config, "add_adapter", False)
|
427
|
+
|
371
428
|
super().__init__(config)
|
372
429
|
|
373
430
|
self.wav2vec2 = Wav2Vec2Model(config)
|
374
|
-
self.
|
431
|
+
self.head = ModelHead(config)
|
432
|
+
self.is_classifier = config.is_classifier
|
375
433
|
self.init_weights()
|
376
434
|
|
377
435
|
def freeze_feature_extractor(self):
|
@@ -407,39 +465,44 @@ class Model(Wav2Vec2PreTrainedModel):
|
|
407
465
|
labels=None,
|
408
466
|
return_hidden=False,
|
409
467
|
):
|
410
|
-
|
411
468
|
outputs = self.wav2vec2(
|
412
469
|
input_values,
|
413
470
|
attention_mask=attention_mask,
|
414
471
|
)
|
415
|
-
|
416
472
|
cnn_features = outputs.extract_features
|
417
473
|
hidden_states_framewise = outputs.last_hidden_state
|
418
474
|
hidden_states = self.pooling(
|
419
475
|
hidden_states_framewise,
|
420
476
|
attention_mask,
|
421
477
|
)
|
422
|
-
|
423
|
-
|
478
|
+
logits = self.head(hidden_states)
|
424
479
|
if not self.training:
|
425
|
-
|
480
|
+
logits = torch.softmax(logits, dim=1)
|
426
481
|
|
427
482
|
if return_hidden:
|
428
|
-
|
429
483
|
# make time last axis
|
430
484
|
cnn_features = torch.transpose(cnn_features, 1, 2)
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
485
|
+
if self.is_classifier:
|
486
|
+
return ModelOutput(
|
487
|
+
logits=logits,
|
488
|
+
hidden_states=hidden_states,
|
489
|
+
cnn_features=cnn_features,
|
490
|
+
)
|
491
|
+
else:
|
492
|
+
return ModelOutputReg(
|
493
|
+
logits=logits,
|
494
|
+
hidden_states=hidden_states,
|
495
|
+
cnn_features=cnn_features,
|
496
|
+
)
|
438
497
|
else:
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
498
|
+
if self.is_classifier:
|
499
|
+
return ModelOutput(
|
500
|
+
logits=logits,
|
501
|
+
)
|
502
|
+
else:
|
503
|
+
return ModelOutputReg(
|
504
|
+
logits=logits,
|
505
|
+
)
|
443
506
|
|
444
507
|
def predict(self, signal):
|
445
508
|
result = self(torch.from_numpy(signal))
|
@@ -447,33 +510,31 @@ class Model(Wav2Vec2PreTrainedModel):
|
|
447
510
|
return result
|
448
511
|
|
449
512
|
|
450
|
-
class
|
451
|
-
|
452
|
-
def __init__(self, config):
|
453
|
-
super().__init__(config)
|
513
|
+
class ConcordanceCorCoeff(torch.nn.Module):
|
454
514
|
|
455
|
-
def
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
515
|
+
def __init__(self):
|
516
|
+
super().__init__()
|
517
|
+
self.mean = torch.mean
|
518
|
+
self.var = torch.var
|
519
|
+
self.sum = torch.sum
|
520
|
+
self.sqrt = torch.sqrt
|
521
|
+
self.std = torch.std
|
522
|
+
|
523
|
+
def forward(self, prediction, ground_truth):
|
524
|
+
ground_truth = ground_truth.float()
|
525
|
+
mean_gt = self.mean(ground_truth, 0)
|
526
|
+
mean_pred = self.mean(prediction, 0)
|
527
|
+
var_gt = self.var(ground_truth, 0)
|
528
|
+
var_pred = self.var(prediction, 0)
|
529
|
+
v_pred = prediction - mean_pred
|
530
|
+
v_gt = ground_truth - mean_gt
|
531
|
+
cor = self.sum(v_pred * v_gt) / (
|
532
|
+
self.sqrt(self.sum(v_pred**2)) * self.sqrt(self.sum(v_gt**2))
|
473
533
|
)
|
534
|
+
sd_gt = self.std(ground_truth)
|
535
|
+
sd_pred = self.std(prediction)
|
536
|
+
numerator = 2 * cor * sd_gt * sd_pred
|
537
|
+
denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
|
538
|
+
ccc = numerator / denominator
|
474
539
|
|
475
|
-
return
|
476
|
-
output.hidden_states,
|
477
|
-
output.logits_cat,
|
478
|
-
output.cnn_features,
|
479
|
-
)
|
540
|
+
return 1 - ccc
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: nkululeko
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.86.0
|
4
4
|
Summary: Machine learning audio prediction experiments based on templates
|
5
5
|
Home-page: https://github.com/felixbur/nkululeko
|
6
6
|
Author: Felix Burkhardt
|
@@ -333,6 +333,16 @@ F. Burkhardt, Johannes Wagner, Hagen Wierstorf, Florian Eyben and Björn Schulle
|
|
333
333
|
Changelog
|
334
334
|
=========
|
335
335
|
|
336
|
+
Version 0.86.0
|
337
|
+
--------------
|
338
|
+
* added regression to finetuning
|
339
|
+
* added other transformer models to finetuning
|
340
|
+
* added output the train/dev features sets actually used by the model
|
341
|
+
|
342
|
+
Version 0.85.2
|
343
|
+
--------------
|
344
|
+
* added data, and automatic task label detection
|
345
|
+
|
336
346
|
Version 0.85.1
|
337
347
|
--------------
|
338
348
|
* fixed bug in model_finetuned that label_num was constant 2
|
@@ -2,11 +2,11 @@ nkululeko/__init__.py,sha256=62f8HiEzJ8rG2QlTFJXUCMpvuH3fKI33DoJSj33mscc,63
|
|
2
2
|
nkululeko/aug_train.py,sha256=YhuZnS_WVWnun9G-M6g5n6rbRxoVREz6Zh7k6qprFNQ,3194
|
3
3
|
nkululeko/augment.py,sha256=4MG0apTAG5RgkuJrYEjGgDdbodZWi_HweSPNI1JJ5QA,3051
|
4
4
|
nkululeko/cacheddataset.py,sha256=lIJ6hUo5LoxSrzXtWV8mzwO7wRtUETWnOQ4ws2XfL1E,969
|
5
|
-
nkululeko/constants.py,sha256=
|
5
|
+
nkululeko/constants.py,sha256=hvi1X27m7vcqkB_Rgl7alourAusZB1mjPxdW4ChdVyU,39
|
6
6
|
nkululeko/demo.py,sha256=8bl15Kitoesnz8oa8yrs52T6YCSOhWbbq9PnZ8Hj6D0,3232
|
7
7
|
nkululeko/demo_feats.py,sha256=sAeGFojhEj9WEDFtG3SzPBmyYJWLF2rkbpp65m8Ujo4,2025
|
8
8
|
nkululeko/demo_predictor.py,sha256=es56xbT8ifkS_vnrlb5NTZT54gNmeUtNlA4zVA_gnN8,4757
|
9
|
-
nkululeko/experiment.py,sha256=
|
9
|
+
nkululeko/experiment.py,sha256=gUJsBMWuadqxEVzuPVToQzFHC9FRUadptP49kTcBiGs,30962
|
10
10
|
nkululeko/explore.py,sha256=lDzRoW_Taa5u4BBABZLD89BcQWnYlrftJR4jgt1yyj0,2609
|
11
11
|
nkululeko/export.py,sha256=mHeEAAmtZuxdyebLlbSzPrHSi9OMgJHbk35d3DTxRBc,4632
|
12
12
|
nkululeko/feature_extractor.py,sha256=8mssYKmo4LclVI-hiLmJEDZ0ZPyDavFG2YwtXcrGzwM,3976
|
@@ -46,7 +46,7 @@ nkululeko/autopredict/ap_valence.py,sha256=n-hctRKySzhmJtowuMOTUu0T_ld3uK5pnfOzW
|
|
46
46
|
nkululeko/autopredict/estimate_snr.py,sha256=S-bpS0xFkwWc4Ch75UrjbS8y538lQ0U3g_iLRFXureY,5048
|
47
47
|
nkululeko/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
48
48
|
nkululeko/data/dataset.py,sha256=JGzMD6HIvkFkYBekmbmslIKc5ADaCj06T-8gpqH_kFo,27650
|
49
|
-
nkululeko/data/dataset_csv.py,sha256=
|
49
|
+
nkululeko/data/dataset_csv.py,sha256=vTnjIc2UdSJT7foL-ltE9MWrZTCg0nplwKdEtMPxt2o,3933
|
50
50
|
nkululeko/feat_extract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
51
51
|
nkululeko/feat_extract/feats_agender.py,sha256=Qm69G4kqAyTVVk7wwRgrXlNwGaDMGRYyKGpuf0vOEgM,3113
|
52
52
|
nkululeko/feat_extract/feats_agender_agender.py,sha256=tgH2BnwcxpvuLmOkrMbVdBSX0Onfz2MG12FsddalRKI,3424
|
@@ -58,7 +58,7 @@ nkululeko/feat_extract/feats_hubert.py,sha256=cLoUzSLjSYBkQnftjacSL7ES3O7Ysh_KrP
|
|
58
58
|
nkululeko/feat_extract/feats_import.py,sha256=rj1p8lz19tCAC8hLzzZAwZ0M6gzwH3BzfabFUgal0yw,1622
|
59
59
|
nkululeko/feat_extract/feats_mld.py,sha256=Vvu7GZOkn7Vda8eIOXqHjg78zegkFe3vTUaCXyVM0eA,2021
|
60
60
|
nkululeko/feat_extract/feats_mos.py,sha256=KXNt7QYEfxkvr6UyVhig2aWQBaIvovlrR4gPuP03gmo,4174
|
61
|
-
nkululeko/feat_extract/feats_opensmile.py,sha256=
|
61
|
+
nkululeko/feat_extract/feats_opensmile.py,sha256=g6ZsAxjjGGvGfrr5fngWC-NJ8E7CP1kYZwrlodZJzzU,4028
|
62
62
|
nkululeko/feat_extract/feats_oxbow.py,sha256=CmIG9cbHTJTJVnzgCPdQpYpnlewWExpsr5ZcK8Malyo,4980
|
63
63
|
nkululeko/feat_extract/feats_praat.py,sha256=kZrS6srzH7WoWEd2prp1Dxw6g9JklFQGTNq5zzPpHzg,3105
|
64
64
|
nkululeko/feat_extract/feats_snr.py,sha256=9dqZ-4RpK98iJEssM3ttozNd18LWlZYM_QVXvp5xDcs,2829
|
@@ -69,13 +69,12 @@ nkululeko/feat_extract/feats_trill.py,sha256=K2ahhdpwpjgg3WZS1POg3UMP2U44i8cLZZv
|
|
69
69
|
nkululeko/feat_extract/feats_wav2vec2.py,sha256=9WUMfyddB_3nx79g7mZoQrRynhM1uEBWuOotRq8bxoU,5268
|
70
70
|
nkululeko/feat_extract/feats_wavlm.py,sha256=ulxpGjifUFx2ZgGmY32SmBJGIuvkYHoLb2n1LZ8KMwA,4703
|
71
71
|
nkululeko/feat_extract/feats_whisper.py,sha256=0N7Vj65OVi2PNoB_NrDjWT5lP6xZNKxFOZZIoxkJvcA,4533
|
72
|
-
nkululeko/feat_extract/featureset.py,sha256=
|
72
|
+
nkululeko/feat_extract/featureset.py,sha256=ll7tyKAdr--TDShyOYJg0FB4I9NqBq0Ni1k_kUJ-2Vw,1541
|
73
73
|
nkululeko/feat_extract/feinberg_praat.py,sha256=EP9pMALjlKdiYInLQdrZ7MmE499Mq-ISRCgqbqL3Rxc,21304
|
74
74
|
nkululeko/losses/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
75
75
|
nkululeko/losses/loss_ccc.py,sha256=NOK0y0fxKUnU161B5geap6Fmn8QzoPl2MqtPiV8IuJE,976
|
76
76
|
nkululeko/losses/loss_softf1loss.py,sha256=5gW-PuiqeAZcRgfwjueIOQtMokOjZWgQnVIv59HKTCo,1309
|
77
77
|
nkululeko/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
78
|
-
nkululeko/models/finetune_model.py,sha256=OMlzDyUFNXZ2xSiqqH8tbzey_KzPJ4jsoYT-4KrWFKM,5091
|
79
78
|
nkululeko/models/model.py,sha256=PUCqF2r_dEfmFsZn6Cgr1UIzYvxziLH6nSqZ5-vuN1o,11639
|
80
79
|
nkululeko/models/model_bayes.py,sha256=WJFZ8wFKwWATz6MhmjeZIi1Pal1viU549WL_PjXDSy8,406
|
81
80
|
nkululeko/models/model_cnn.py,sha256=bJxqwe6FnVR2hFeqN6EXexYGgvKYFED1VOhBXVlLWaE,9954
|
@@ -89,7 +88,7 @@ nkululeko/models/model_svm.py,sha256=rsME3KvKvNG7bdE5lbvYUu85WZhaASZxxmdNDIVJRZ4
|
|
89
88
|
nkululeko/models/model_svr.py,sha256=_YZeksqB3eBENGlg3g9RwYFlk9rQQ-XCeNBKLlGGVoE,725
|
90
89
|
nkululeko/models/model_tree.py,sha256=rf16faUm4o2LJgkoYpeY998b8DQIvXZ73_m1IS3TnnE,417
|
91
90
|
nkululeko/models/model_tree_reg.py,sha256=IgQcPTE-304HQLYSKPF8Z4ot_Ur9dH01fZjS0nXke_M,428
|
92
|
-
nkululeko/models/model_tuned.py,sha256=
|
91
|
+
nkululeko/models/model_tuned.py,sha256=J5CemIAW_WhZIQgppFgPChrsMJvGYzJlCvJC8O62l9M,18049
|
93
92
|
nkululeko/models/model_xgb.py,sha256=Thgx5ESdIok4v72mKh4plxpo4smGcKALWNCJTDScY0M,447
|
94
93
|
nkululeko/models/model_xgr.py,sha256=aGBtNGLWjOE_2rICGYGFxmT8DtnHYsIl1lIpMtghHsY,418
|
95
94
|
nkululeko/reporting/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -106,8 +105,8 @@ nkululeko/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
106
105
|
nkululeko/utils/files.py,sha256=UiGAtZRWYjHSvlmPaTMtzyNNGE6qaLaxQkybctS7iRM,4021
|
107
106
|
nkululeko/utils/stats.py,sha256=1yUq0FTOyqkU8TwUocJRYdJaqMU5SlOBBRUun9STo2M,2829
|
108
107
|
nkululeko/utils/util.py,sha256=b1IHFucRNuF9Iyv5IJeK4AEg0Rga0xKG80UM5GWWdHA,13816
|
109
|
-
nkululeko-0.
|
110
|
-
nkululeko-0.
|
111
|
-
nkululeko-0.
|
112
|
-
nkululeko-0.
|
113
|
-
nkululeko-0.
|
108
|
+
nkululeko-0.86.0.dist-info/LICENSE,sha256=0zGP5B_W35yAcGfHPS18Q2B8UhvLRY3dQq1MhpsJU_U,1076
|
109
|
+
nkululeko-0.86.0.dist-info/METADATA,sha256=KrHrjQ6rc4oGxN4EJ_TuZ0dVGGI-qIxw8dY1RBTCnLo,36852
|
110
|
+
nkululeko-0.86.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
111
|
+
nkululeko-0.86.0.dist-info/top_level.txt,sha256=DPFNNSHPjUeVKj44dVANAjuVGRCC3MusJ08lc2a8xFA,10
|
112
|
+
nkululeko-0.86.0.dist-info/RECORD,,
|
@@ -1,190 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
Code based on @jwagner
|
3
|
-
"""
|
4
|
-
|
5
|
-
import dataclasses
|
6
|
-
import typing
|
7
|
-
|
8
|
-
import torch
|
9
|
-
import transformers
|
10
|
-
from transformers.models.wav2vec2.modeling_wav2vec2 import (
|
11
|
-
Wav2Vec2PreTrainedModel,
|
12
|
-
Wav2Vec2Model,
|
13
|
-
)
|
14
|
-
|
15
|
-
|
16
|
-
class ConcordanceCorCoeff(torch.nn.Module):
|
17
|
-
|
18
|
-
def __init__(self):
|
19
|
-
|
20
|
-
super().__init__()
|
21
|
-
|
22
|
-
self.mean = torch.mean
|
23
|
-
self.var = torch.var
|
24
|
-
self.sum = torch.sum
|
25
|
-
self.sqrt = torch.sqrt
|
26
|
-
self.std = torch.std
|
27
|
-
|
28
|
-
def forward(self, prediction, ground_truth):
|
29
|
-
|
30
|
-
mean_gt = self.mean(ground_truth, 0)
|
31
|
-
mean_pred = self.mean(prediction, 0)
|
32
|
-
var_gt = self.var(ground_truth, 0)
|
33
|
-
var_pred = self.var(prediction, 0)
|
34
|
-
v_pred = prediction - mean_pred
|
35
|
-
v_gt = ground_truth - mean_gt
|
36
|
-
cor = self.sum(v_pred * v_gt) / (
|
37
|
-
self.sqrt(self.sum(v_pred**2)) * self.sqrt(self.sum(v_gt**2))
|
38
|
-
)
|
39
|
-
sd_gt = self.std(ground_truth)
|
40
|
-
sd_pred = self.std(prediction)
|
41
|
-
numerator = 2 * cor * sd_gt * sd_pred
|
42
|
-
denominator = var_gt + var_pred + (mean_gt - mean_pred) ** 2
|
43
|
-
ccc = numerator / denominator
|
44
|
-
|
45
|
-
return 1 - ccc
|
46
|
-
|
47
|
-
|
48
|
-
@dataclasses.dataclass
|
49
|
-
class ModelOutput(transformers.file_utils.ModelOutput):
|
50
|
-
|
51
|
-
logits_cat: torch.FloatTensor = None
|
52
|
-
hidden_states: typing.Tuple[torch.FloatTensor] = None
|
53
|
-
cnn_features: torch.FloatTensor = None
|
54
|
-
|
55
|
-
|
56
|
-
class ModelHead(torch.nn.Module):
|
57
|
-
|
58
|
-
def __init__(self, config, num_labels):
|
59
|
-
|
60
|
-
super().__init__()
|
61
|
-
|
62
|
-
self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
|
63
|
-
self.dropout = torch.nn.Dropout(config.final_dropout)
|
64
|
-
self.out_proj = torch.nn.Linear(config.hidden_size, num_labels)
|
65
|
-
|
66
|
-
def forward(self, features, **kwargs):
|
67
|
-
|
68
|
-
x = features
|
69
|
-
x = self.dropout(x)
|
70
|
-
x = self.dense(x)
|
71
|
-
x = torch.tanh(x)
|
72
|
-
x = self.dropout(x)
|
73
|
-
x = self.out_proj(x)
|
74
|
-
|
75
|
-
return x
|
76
|
-
|
77
|
-
|
78
|
-
class Model(Wav2Vec2PreTrainedModel):
|
79
|
-
|
80
|
-
def __init__(self, config):
|
81
|
-
|
82
|
-
super().__init__(config)
|
83
|
-
|
84
|
-
self.wav2vec2 = Wav2Vec2Model(config)
|
85
|
-
self.cat = ModelHead(config, 2)
|
86
|
-
self.init_weights()
|
87
|
-
|
88
|
-
def freeze_feature_extractor(self):
|
89
|
-
self.wav2vec2.feature_extractor._freeze_parameters()
|
90
|
-
|
91
|
-
def pooling(
|
92
|
-
self,
|
93
|
-
hidden_states,
|
94
|
-
attention_mask,
|
95
|
-
):
|
96
|
-
|
97
|
-
if attention_mask is None: # For evaluation with batch_size==1
|
98
|
-
outputs = torch.mean(hidden_states, dim=1)
|
99
|
-
else:
|
100
|
-
attention_mask = self._get_feature_vector_attention_mask(
|
101
|
-
hidden_states.shape[1],
|
102
|
-
attention_mask,
|
103
|
-
)
|
104
|
-
hidden_states = hidden_states * torch.reshape(
|
105
|
-
attention_mask,
|
106
|
-
(-1, attention_mask.shape[-1], 1),
|
107
|
-
)
|
108
|
-
outputs = torch.sum(hidden_states, dim=1)
|
109
|
-
attention_sum = torch.sum(attention_mask, dim=1)
|
110
|
-
outputs = outputs / torch.reshape(attention_sum, (-1, 1))
|
111
|
-
|
112
|
-
return outputs
|
113
|
-
|
114
|
-
def forward(
|
115
|
-
self,
|
116
|
-
input_values,
|
117
|
-
attention_mask=None,
|
118
|
-
labels=None,
|
119
|
-
return_hidden=False,
|
120
|
-
):
|
121
|
-
|
122
|
-
outputs = self.wav2vec2(
|
123
|
-
input_values,
|
124
|
-
attention_mask=attention_mask,
|
125
|
-
)
|
126
|
-
|
127
|
-
cnn_features = outputs.extract_features
|
128
|
-
hidden_states_framewise = outputs.last_hidden_state
|
129
|
-
hidden_states = self.pooling(
|
130
|
-
hidden_states_framewise,
|
131
|
-
attention_mask,
|
132
|
-
)
|
133
|
-
logits_cat = self.cat(hidden_states)
|
134
|
-
|
135
|
-
if not self.training:
|
136
|
-
logits_cat = torch.softmax(logits_cat, dim=1)
|
137
|
-
|
138
|
-
if return_hidden:
|
139
|
-
|
140
|
-
# make time last axis
|
141
|
-
cnn_features = torch.transpose(cnn_features, 1, 2)
|
142
|
-
|
143
|
-
return ModelOutput(
|
144
|
-
logits_cat=logits_cat,
|
145
|
-
hidden_states=hidden_states,
|
146
|
-
cnn_features=cnn_features,
|
147
|
-
)
|
148
|
-
|
149
|
-
else:
|
150
|
-
|
151
|
-
return ModelOutput(
|
152
|
-
logits_cat=logits_cat,
|
153
|
-
)
|
154
|
-
|
155
|
-
def predict(self, signal):
|
156
|
-
result = self(torch.from_numpy(signal))
|
157
|
-
result = result[0].detach().numpy()[0]
|
158
|
-
return result
|
159
|
-
|
160
|
-
|
161
|
-
class ModelWithPreProcessing(Model):
|
162
|
-
|
163
|
-
def __init__(self, config):
|
164
|
-
super().__init__(config)
|
165
|
-
|
166
|
-
def forward(
|
167
|
-
self,
|
168
|
-
input_values,
|
169
|
-
):
|
170
|
-
# Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm():
|
171
|
-
# normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
|
172
|
-
|
173
|
-
mean = input_values.mean()
|
174
|
-
|
175
|
-
# var = input_values.var()
|
176
|
-
# raises: onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for the node ReduceProd_3:ReduceProd(11)
|
177
|
-
|
178
|
-
var = torch.square(input_values - mean).mean()
|
179
|
-
input_values = (input_values - mean) / torch.sqrt(var + 1e-7)
|
180
|
-
|
181
|
-
output = super().forward(
|
182
|
-
input_values,
|
183
|
-
return_hidden=True,
|
184
|
-
)
|
185
|
-
|
186
|
-
return (
|
187
|
-
output.hidden_states,
|
188
|
-
output.logits_cat,
|
189
|
-
output.cnn_features,
|
190
|
-
)
|
File without changes
|
File without changes
|
File without changes
|