torch-rechub 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- torch_rechub/basic/activation.py +54 -52
- torch_rechub/basic/callback.py +32 -32
- torch_rechub/basic/features.py +94 -57
- torch_rechub/basic/initializers.py +92 -0
- torch_rechub/basic/layers.py +720 -240
- torch_rechub/basic/loss_func.py +34 -0
- torch_rechub/basic/metaoptimizer.py +72 -0
- torch_rechub/basic/metric.py +250 -0
- torch_rechub/models/matching/__init__.py +11 -0
- torch_rechub/models/matching/comirec.py +188 -0
- torch_rechub/models/matching/dssm.py +66 -0
- torch_rechub/models/matching/dssm_facebook.py +79 -0
- torch_rechub/models/matching/dssm_senet.py +75 -0
- torch_rechub/models/matching/gru4rec.py +87 -0
- torch_rechub/models/matching/mind.py +101 -0
- torch_rechub/models/matching/narm.py +76 -0
- torch_rechub/models/matching/sasrec.py +140 -0
- torch_rechub/models/matching/sine.py +151 -0
- torch_rechub/models/matching/stamp.py +83 -0
- torch_rechub/models/matching/youtube_dnn.py +71 -0
- torch_rechub/models/matching/youtube_sbc.py +98 -0
- torch_rechub/models/multi_task/__init__.py +5 -4
- torch_rechub/models/multi_task/aitm.py +84 -0
- torch_rechub/models/multi_task/esmm.py +55 -45
- torch_rechub/models/multi_task/mmoe.py +58 -52
- torch_rechub/models/multi_task/ple.py +130 -104
- torch_rechub/models/multi_task/shared_bottom.py +45 -44
- torch_rechub/models/ranking/__init__.py +11 -3
- torch_rechub/models/ranking/afm.py +63 -0
- torch_rechub/models/ranking/bst.py +63 -0
- torch_rechub/models/ranking/dcn.py +38 -0
- torch_rechub/models/ranking/dcn_v2.py +69 -0
- torch_rechub/models/ranking/deepffm.py +123 -0
- torch_rechub/models/ranking/deepfm.py +41 -41
- torch_rechub/models/ranking/dien.py +191 -0
- torch_rechub/models/ranking/din.py +91 -81
- torch_rechub/models/ranking/edcn.py +117 -0
- torch_rechub/models/ranking/fibinet.py +50 -0
- torch_rechub/models/ranking/widedeep.py +41 -41
- torch_rechub/trainers/__init__.py +2 -1
- torch_rechub/trainers/{trainer.py → ctr_trainer.py} +128 -111
- torch_rechub/trainers/match_trainer.py +170 -0
- torch_rechub/trainers/mtl_trainer.py +206 -144
- torch_rechub/utils/__init__.py +0 -0
- torch_rechub/utils/data.py +360 -0
- torch_rechub/utils/match.py +274 -0
- torch_rechub/utils/mtl.py +126 -0
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/LICENSE +21 -21
- torch_rechub-0.0.3.dist-info/METADATA +177 -0
- torch_rechub-0.0.3.dist-info/RECORD +55 -0
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/WHEEL +1 -1
- torch_rechub/basic/utils.py +0 -168
- torch_rechub-0.0.1.dist-info/METADATA +0 -105
- torch_rechub-0.0.1.dist-info/RECORD +0 -26
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/top_level.txt +0 -0
torch_rechub/trainers/match_trainer.py (new file)
@@ -0,0 +1,170 @@
+import os
+import torch
+import tqdm
+from sklearn.metrics import roc_auc_score
+from ..basic.callback import EarlyStopper
+from ..basic.loss_func import BPRLoss
+
+
+class MatchTrainer(object):
+    """A general trainer for matching/retrieval models.
+
+    Args:
+        model (nn.Module): any matching model.
+        mode (int, optional): the training mode, `{0: point-wise, 1: pair-wise, 2: list-wise}`. Defaults to 0.
+        optimizer_fn (torch.optim): optimizer function of pytorch (default = `torch.optim.Adam`).
+        optimizer_params (dict): parameters of optimizer_fn.
+        scheduler_fn (torch.optim.lr_scheduler): torch scheduling class, e.g. `torch.optim.lr_scheduler.StepLR`.
+        scheduler_params (dict): parameters of scheduler_fn.
+        n_epoch (int): number of training epochs.
+        earlystop_patience (int): how many epochs to wait after the last validation auc improvement (default = 10).
+        device (str): `"cpu"` or `"cuda:0"`.
+        gpus (list): ids of the gpus for multi-gpu training (default = []). If the length is >= 1, the model will be wrapped by nn.DataParallel.
+        model_path (str): the path where the model is saved (default = "./"). Note that only the weights that are best on the validation data are saved.
+    """
+
+    def __init__(
+        self,
+        model,
+        mode=0,
+        optimizer_fn=torch.optim.Adam,
+        optimizer_params=None,
+        scheduler_fn=None,
+        scheduler_params=None,
+        n_epoch=10,
+        earlystop_patience=10,
+        device="cpu",
+        gpus=None,
+        model_path="./",
+    ):
+        self.model = model  # keep a handle so weights are saved the same way on one gpu or many
+        if gpus is None:
+            gpus = []
+        self.gpus = gpus
+        if len(gpus) > 1:
+            print('parallel running on these gpus:', gpus)
+            self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
+        self.device = torch.device(device)  # e.g. torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.model.to(self.device)
+        if optimizer_params is None:
+            optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
+        self.mode = mode
+        if mode == 0:  # point-wise loss: binary cross entropy
+            self.criterion = torch.nn.BCELoss()
+        elif mode == 1:  # pair-wise loss
+            self.criterion = BPRLoss()
+        elif mode == 2:  # list-wise loss: softmax cross entropy
+            self.criterion = torch.nn.CrossEntropyLoss()
+        else:
+            raise ValueError("mode must be one of [0, 1, 2], but got %s" % mode)
+        self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params)  # default optimizer
+        self.scheduler = None
+        if scheduler_fn is not None:
+            self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
+        self.evaluate_fn = roc_auc_score  # default evaluation function
+        self.n_epoch = n_epoch
+        self.early_stopper = EarlyStopper(patience=earlystop_patience)
+        self.model_path = model_path
+
+    def train_one_epoch(self, data_loader, log_interval=10):
+        self.model.train()
+        total_loss = 0
+        tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
+        for i, (x_dict, y) in enumerate(tk0):
+            x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  # move tensors to device
+            y = y.to(self.device)
+            if self.mode == 0:
+                y = y.float()  # binary_cross_entropy expects float targets
+            else:
+                y = y.long()  # CrossEntropyLoss expects long targets
+            if self.mode == 1:  # pair-wise
+                pos_score, neg_score = self.model(x_dict)
+                loss = self.criterion(pos_score, neg_score)
+            else:
+                y_pred = self.model(x_dict)
+                loss = self.criterion(y_pred, y)
+            # used for debugging
+            # if i == 0:
+            #     print()
+            #     if self.mode == 0:
+            #         print('pred: ', [f'{float(each):5.2g}' for each in y_pred.detach().cpu().tolist()])
+            #         print('truth:', [f'{float(each):5.2g}' for each in y.detach().cpu().tolist()])
+            #     elif self.mode == 2:
+            #         pred = y_pred.detach().cpu().mean(0)
+            #         pred = torch.softmax(pred, dim=0).tolist()
+            #         print('pred: ', [f'{float(each):4.2g}' for each in pred])
+            #     elif self.mode == 1:
+            #         print('pos:', [f'{float(each):5.2g}' for each in pos_score.detach().cpu().tolist()])
+            #         print('neg: ', [f'{float(each):5.2g}' for each in neg_score.detach().cpu().tolist()])
+
+            self.model.zero_grad()
+            loss.backward()
+            self.optimizer.step()
+            total_loss += loss.item()
+            if (i + 1) % log_interval == 0:
+                tk0.set_postfix(loss=total_loss / log_interval)
+                total_loss = 0
+
+    def fit(self, train_dataloader, val_dataloader=None):
+        for epoch_i in range(self.n_epoch):
+            print('epoch:', epoch_i)
+            self.train_one_epoch(train_dataloader)
+            if self.scheduler is not None:
+                if epoch_i % self.scheduler.step_size == 0:
+                    print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
+                self.scheduler.step()  # update lr once per epoch
+
+            if val_dataloader:
+                auc = self.evaluate(self.model, val_dataloader)
+                print('epoch:', epoch_i, 'validation: auc:', auc)
+                if self.early_stopper.stop_training(auc, self.model.state_dict()):
+                    print(f'validation: best auc: {self.early_stopper.best_auc}')
+                    self.model.load_state_dict(self.early_stopper.best_weights)
+                    break
+        torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth"))  # save the best-auc model
+
+    def evaluate(self, model, data_loader):
+        model.eval()
+        targets, predicts = list(), list()
+        with torch.no_grad():
+            tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
+            for i, (x_dict, y) in enumerate(tk0):
+                x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
+                y = y.to(self.device)
+                y_pred = model(x_dict)
+                targets.extend(y.tolist())
+                predicts.extend(y_pred.tolist())
+        return self.evaluate_fn(targets, predicts)
+
+    def predict(self, model, data_loader):
+        model.eval()
+        predicts = list()
+        with torch.no_grad():
+            tk0 = tqdm.tqdm(data_loader, desc="predict", smoothing=0, mininterval=1.0)
+            for i, (x_dict, y) in enumerate(tk0):
+                x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
+                y = y.to(self.device)
+                y_pred = model(x_dict)
+                predicts.extend(y_pred.tolist())
+        return predicts
+
+    def inference_embedding(self, model, mode, data_loader, model_path):
+        assert mode in ["user", "item"], "Invalid mode={}.".format(mode)
+        model.mode = mode
+        model.load_state_dict(torch.load(os.path.join(model_path, "model.pth")))
+        model = model.to(self.device)
+        model.eval()
+        predicts = []
+        with torch.no_grad():
+            tk0 = tqdm.tqdm(data_loader, desc="%s inference" % (mode), smoothing=0, mininterval=1.0)
+            for i, x_dict in enumerate(tk0):
+                x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
+                y_pred = model(x_dict)
+                predicts.append(y_pred.data)
+        return torch.cat(predicts, dim=0)
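The new MatchTrainer drives all three modes from one training loop: mode 0 feeds sigmoid scores to BCELoss, mode 1 expects the model to return a (pos_score, neg_score) pair for BPRLoss, and mode 2 feeds scores over sampled candidates to CrossEntropyLoss. Below is a minimal point-wise sketch of the API added above. ToyTower and ToyData are hypothetical placeholders, not part of the package; only MatchTrainer and its methods come from this diff, and the import path assumes trainers/__init__.py re-exports it, as its +2 -1 change suggests:

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, Dataset
    from torch_rechub.trainers import MatchTrainer

    class ToyTower(nn.Module):
        """Stand-in two-tower model: takes a dict of feature tensors, returns sigmoid scores."""
        def __init__(self):
            super().__init__()
            self.user_emb = nn.Embedding(100, 8)
            self.item_emb = nn.Embedding(100, 8)

        def forward(self, x_dict):
            u = self.user_emb(x_dict["user_id"])
            v = self.item_emb(x_dict["item_id"])
            return torch.sigmoid((u * v).sum(dim=1))

    class ToyData(Dataset):
        """256 fake (x_dict, label) pairs; the default collate batches the dicts."""
        def __len__(self):
            return 256

        def __getitem__(self, idx):
            x = {"user_id": torch.tensor(idx % 100), "item_id": torch.tensor((idx * 7) % 100)}
            return x, torch.tensor(idx % 2)

    dl = DataLoader(ToyData(), batch_size=32)
    trainer = MatchTrainer(ToyTower(), mode=0, n_epoch=1, device="cpu")  # mode 0: point-wise BCE
    trainer.fit(dl, val_dataloader=dl)  # trains, validates by auc, saves model.pth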
torch_rechub/trainers/mtl_trainer.py
@@ -1,145 +1,207 @@
-import os
-import tqdm
-import numpy as np
-import torch
-import torch.nn as nn
-from ..basic.callback import EarlyStopper
-from ..
-from ..models.multi_task import ESMM
-[remaining removed lines of the old MTLTrainer not preserved in this rendering]
+import os
+import tqdm
+import numpy as np
+import torch
+import torch.nn as nn
+from ..basic.callback import EarlyStopper
+from ..utils.data import get_loss_func, get_metric_func
+from ..models.multi_task import ESMM
+from ..utils.mtl import shared_task_layers, gradnorm, MetaBalance
+
+
+class MTLTrainer(object):
+    """A trainer for multi-task learning.
+
+    Args:
+        model (nn.Module): any multi-task learning model.
+        task_types (list): types of the tasks; only "classification" and "regression" are supported.
+        optimizer_fn (torch.optim): optimizer function of pytorch (default = `torch.optim.Adam`).
+        optimizer_params (dict): parameters of optimizer_fn.
+        scheduler_fn (torch.optim.lr_scheduler): torch scheduling class, e.g. `torch.optim.lr_scheduler.StepLR`.
+        scheduler_params (dict): parameters of scheduler_fn.
+        adaptive_params (dict): parameters of the adaptive loss weighting method; supports `{"method": "uwl"}`, `{"method": "metabalance"}` and `{"method": "gradnorm", "alpha": 0.16}`.
+        n_epoch (int): number of training epochs.
+        earlystop_taskid (int): id of the task whose validation metric drives early stopping (default = 0).
+        earlystop_patience (int): how many epochs to wait after the last validation auc improvement (default = 10).
+        device (str): `"cpu"` or `"cuda:0"`.
+        gpus (list): ids of the gpus for multi-gpu training (default = []). If the length is >= 1, the model will be wrapped by nn.DataParallel.
+        model_path (str): the path where the model is saved (default = "./"). Note that only the weights that are best on the validation data are saved.
+    """
+
+    def __init__(
+        self,
+        model,
+        task_types,
+        optimizer_fn=torch.optim.Adam,
+        optimizer_params=None,
+        scheduler_fn=None,
+        scheduler_params=None,
+        adaptive_params=None,
+        n_epoch=10,
+        earlystop_taskid=0,
+        earlystop_patience=10,
+        device="cpu",
+        gpus=None,
+        model_path="./",
+    ):
+        self.model = model
+        if gpus is None:
+            gpus = []
+        if optimizer_params is None:
+            optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
+        self.task_types = task_types
+        self.n_task = len(task_types)
+        self.loss_weight = None
+        self.adaptive_method = None
+        if adaptive_params is not None:
+            if adaptive_params["method"] == "uwl":
+                self.adaptive_method = "uwl"
+                self.loss_weight = nn.ParameterList(nn.Parameter(torch.zeros(1)) for _ in range(self.n_task))
+                self.model.add_module("loss weight", self.loss_weight)
+            elif adaptive_params["method"] == "metabalance":
+                self.adaptive_method = "metabalance"
+                share_layers, task_layers = shared_task_layers(self.model)
+                self.meta_optimizer = MetaBalance(share_layers)
+                self.share_optimizer = optimizer_fn(share_layers, **optimizer_params)
+                self.task_optimizer = optimizer_fn(task_layers, **optimizer_params)
+            elif adaptive_params["method"] == "gradnorm":
+                self.adaptive_method = "gradnorm"
+                self.alpha = adaptive_params.get("alpha", 0.16)
+                share_layers = shared_task_layers(self.model)[0]
+                # gradnorm measures each task's gradient on the last fully connected shared layer weight (ndim == 2)
+                for i in range(len(share_layers)):
+                    if share_layers[-(i + 1)].ndim == 2:
+                        self.last_share_layer = share_layers[-(i + 1)]
+                        break
+                self.initial_task_loss = None
+                self.loss_weight = nn.ParameterList(nn.Parameter(torch.ones(1)) for _ in range(self.n_task))
+                self.model.add_module("loss weight", self.loss_weight)
+        if self.adaptive_method != "metabalance":
+            self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params)  # default Adam optimizer
+        self.scheduler = None
+        if scheduler_fn is not None:
+            self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
+        self.loss_fns = [get_loss_func(task_type) for task_type in task_types]
+        self.evaluate_fns = [get_metric_func(task_type) for task_type in task_types]
+        self.n_epoch = n_epoch
+        self.earlystop_taskid = earlystop_taskid
+        self.early_stopper = EarlyStopper(patience=earlystop_patience)
+
+        self.gpus = gpus
+        if len(gpus) > 1:
+            print('parallel running on these gpus:', gpus)
+            self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
+        self.device = torch.device(device)  # e.g. torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.model.to(self.device)
+        self.model_path = model_path
+
+    def train_one_epoch(self, data_loader):
+        self.model.train()
+        total_loss = np.zeros(self.n_task)
+        tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
+        for iter_i, (x_dict, ys) in enumerate(tk0):
+            x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  # move tensors to device
+            ys = ys.to(self.device)
+            y_preds = self.model(x_dict)
+            loss_list = [self.loss_fns[i](y_preds[:, i], ys[:, i].float()) for i in range(self.n_task)]
+            if isinstance(self.model, ESMM):
+                loss = sum(loss_list[1:])  # ESMM only computes the loss of the ctr and ctcvr tasks
+            else:
+                if self.adaptive_method is not None:
+                    if self.adaptive_method == "uwl":
+                        loss = 0
+                        for loss_i, w_i in zip(loss_list, self.loss_weight):
+                            w_i = torch.clamp(w_i, min=0)
+                            loss += 2 * loss_i * torch.exp(-w_i) + w_i  # uncertainty weighting
+                else:
+                    loss = sum(loss_list) / self.n_task
+            if self.adaptive_method == 'metabalance':
+                self.share_optimizer.zero_grad()
+                self.task_optimizer.zero_grad()
+                self.meta_optimizer.step(loss_list)
+                self.share_optimizer.step()
+                self.task_optimizer.step()
+            elif self.adaptive_method == "gradnorm":
+                self.optimizer.zero_grad()
+                if self.initial_task_loss is None:
+                    self.initial_task_loss = [l.item() for l in loss_list]
+                gradnorm(loss_list, self.loss_weight, self.last_share_layer, self.initial_task_loss, self.alpha)
+                self.optimizer.step()
+                # renormalize the loss weights so they sum to n_task
+                loss_weight_sum = sum([w.item() for w in self.loss_weight])
+                normalize_coeff = len(self.loss_weight) / loss_weight_sum
+                for w in self.loss_weight:
+                    w.data = w.data * normalize_coeff
+            else:
+                self.model.zero_grad()
+                loss.backward()
+                self.optimizer.step()
+            total_loss += np.array([l.item() for l in loss_list])
+        log_dict = {"task_%d" % (i): total_loss[i] / (iter_i + 1) for i in range(self.n_task)}
+        loss_list = [total_loss[i] / (iter_i + 1) for i in range(self.n_task)]
+        print("train loss: ", log_dict)
+        if self.loss_weight:
+            print("loss weight: ", [w.item() for w in self.loss_weight])
+
+        return loss_list
+
+    def fit(self, train_dataloader, val_dataloader, mode='base', seed=0):
+        total_log = []
+
+        for epoch_i in range(self.n_epoch):
+            _log_per_epoch = self.train_one_epoch(train_dataloader)
+
+            if self.scheduler is not None:
+                if epoch_i % self.scheduler.step_size == 0:
+                    print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
+                self.scheduler.step()  # update lr once per epoch
+            scores = self.evaluate(self.model, val_dataloader)
+            print('epoch:', epoch_i, 'validation scores: ', scores)
+
+            for score in scores:
+                _log_per_epoch.append(score)
+
+            total_log.append(_log_per_epoch)
+
+            if self.early_stopper.stop_training(scores[self.earlystop_taskid], self.model.state_dict()):
+                print('validation best auc of main task %d: %.6f' %
+                      (self.earlystop_taskid, self.early_stopper.best_auc))
+                self.model.load_state_dict(self.early_stopper.best_weights)
+                break
+
+        torch.save(self.model.state_dict(), os.path.join(self.model_path, "model_{}_{}.pth".format(mode, seed)))  # save the best-auc model
+
+        return total_log
+
+    def evaluate(self, model, data_loader):
+        model.eval()
+        targets, predicts = list(), list()
+        with torch.no_grad():
+            tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
+            for i, (x_dict, ys) in enumerate(tk0):
+                x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
+                ys = ys.to(self.device)
+                y_preds = model(x_dict)
+                targets.extend(ys.tolist())
+                predicts.extend(y_preds.tolist())
+        targets, predicts = np.array(targets), np.array(predicts)
+        scores = [self.evaluate_fns[i](targets[:, i], predicts[:, i]) for i in range(self.n_task)]
+        return scores
+
+    def predict(self, model, data_loader):
+        model.eval()
+        predicts = list()
+        with torch.no_grad():
+            tk0 = tqdm.tqdm(data_loader, desc="predict", smoothing=0, mininterval=1.0)
+            for i, x_dict in enumerate(tk0):
+                x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
+                y_preds = model(x_dict)
+                predicts.extend(y_preds.tolist())
         return predicts
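MTLTrainer follows the same pattern, but each batch carries a (batch, n_task) label tensor, each task gets its own loss and metric via get_loss_func/get_metric_func, and adaptive_params selects a loss-weighting strategy. A minimal sketch with uncertainty weighting ("uwl"); ToyMTL and ToyData below are hypothetical placeholders, not part of the package, and only MTLTrainer and its arguments come from this diff:

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, Dataset
    from torch_rechub.trainers import MTLTrainer

    class ToyMTL(nn.Module):
        """Stand-in multi-task model: returns per-task scores of shape (batch, n_task)."""
        def __init__(self):
            super().__init__()
            self.bottom = nn.Embedding(100, 8)  # shared bottom
            self.towers = nn.ModuleList(nn.Linear(8, 1) for _ in range(2))  # one tower per task

        def forward(self, x_dict):
            h = self.bottom(x_dict["user_id"])
            return torch.sigmoid(torch.cat([t(h) for t in self.towers], dim=1))

    class ToyData(Dataset):
        """256 fake (x_dict, ys) pairs with two binary labels per sample."""
        def __len__(self):
            return 256

        def __getitem__(self, idx):
            return {"user_id": torch.tensor(idx % 100)}, torch.tensor([idx % 2, (idx // 2) % 2])

    dl = DataLoader(ToyData(), batch_size=32)
    trainer = MTLTrainer(ToyMTL(),
                         task_types=["classification", "classification"],
                         adaptive_params={"method": "uwl"},  # uncertainty-weighted losses
                         n_epoch=1,
                         device="cpu")
    log = trainer.fit(dl, dl)  # returns per-epoch train losses plus validation scores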