torch-rechub 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (55)
  1. torch_rechub/basic/activation.py +54 -52
  2. torch_rechub/basic/callback.py +32 -32
  3. torch_rechub/basic/features.py +94 -57
  4. torch_rechub/basic/initializers.py +92 -0
  5. torch_rechub/basic/layers.py +720 -240
  6. torch_rechub/basic/loss_func.py +34 -0
  7. torch_rechub/basic/metaoptimizer.py +72 -0
  8. torch_rechub/basic/metric.py +250 -0
  9. torch_rechub/models/matching/__init__.py +11 -0
  10. torch_rechub/models/matching/comirec.py +188 -0
  11. torch_rechub/models/matching/dssm.py +66 -0
  12. torch_rechub/models/matching/dssm_facebook.py +79 -0
  13. torch_rechub/models/matching/dssm_senet.py +75 -0
  14. torch_rechub/models/matching/gru4rec.py +87 -0
  15. torch_rechub/models/matching/mind.py +101 -0
  16. torch_rechub/models/matching/narm.py +76 -0
  17. torch_rechub/models/matching/sasrec.py +140 -0
  18. torch_rechub/models/matching/sine.py +151 -0
  19. torch_rechub/models/matching/stamp.py +83 -0
  20. torch_rechub/models/matching/youtube_dnn.py +71 -0
  21. torch_rechub/models/matching/youtube_sbc.py +98 -0
  22. torch_rechub/models/multi_task/__init__.py +5 -4
  23. torch_rechub/models/multi_task/aitm.py +84 -0
  24. torch_rechub/models/multi_task/esmm.py +55 -45
  25. torch_rechub/models/multi_task/mmoe.py +58 -52
  26. torch_rechub/models/multi_task/ple.py +130 -104
  27. torch_rechub/models/multi_task/shared_bottom.py +45 -44
  28. torch_rechub/models/ranking/__init__.py +11 -3
  29. torch_rechub/models/ranking/afm.py +63 -0
  30. torch_rechub/models/ranking/bst.py +63 -0
  31. torch_rechub/models/ranking/dcn.py +38 -0
  32. torch_rechub/models/ranking/dcn_v2.py +69 -0
  33. torch_rechub/models/ranking/deepffm.py +123 -0
  34. torch_rechub/models/ranking/deepfm.py +41 -41
  35. torch_rechub/models/ranking/dien.py +191 -0
  36. torch_rechub/models/ranking/din.py +91 -81
  37. torch_rechub/models/ranking/edcn.py +117 -0
  38. torch_rechub/models/ranking/fibinet.py +50 -0
  39. torch_rechub/models/ranking/widedeep.py +41 -41
  40. torch_rechub/trainers/__init__.py +2 -1
  41. torch_rechub/trainers/{trainer.py → ctr_trainer.py} +128 -111
  42. torch_rechub/trainers/match_trainer.py +170 -0
  43. torch_rechub/trainers/mtl_trainer.py +206 -144
  44. torch_rechub/utils/__init__.py +0 -0
  45. torch_rechub/utils/data.py +360 -0
  46. torch_rechub/utils/match.py +274 -0
  47. torch_rechub/utils/mtl.py +126 -0
  48. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/LICENSE +21 -21
  49. torch_rechub-0.0.3.dist-info/METADATA +177 -0
  50. torch_rechub-0.0.3.dist-info/RECORD +55 -0
  51. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/WHEEL +1 -1
  52. torch_rechub/basic/utils.py +0 -168
  53. torch_rechub-0.0.1.dist-info/METADATA +0 -105
  54. torch_rechub-0.0.1.dist-info/RECORD +0 -26
  55. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,170 @@
+ import os
+ import torch
+ import tqdm
+ from sklearn.metrics import roc_auc_score
+ from ..basic.callback import EarlyStopper
+ from ..basic.loss_func import BPRLoss
+
+
+ class MatchTrainer(object):
+     """A general trainer for Matching/Retrieval
+
+     Args:
+         model (nn.Module): any matching model.
+         mode (int, optional): the training mode, `{0:point-wise, 1:pair-wise, 2:list-wise}`. Defaults to 0.
+         optimizer_fn (torch.optim): optimizer function of pytorch (default = `torch.optim.Adam`).
+         optimizer_params (dict): parameters of optimizer_fn.
+         scheduler_fn (torch.optim.lr_scheduler) : torch scheduling class, eg. `torch.optim.lr_scheduler.StepLR`.
+         scheduler_params (dict): parameters of optimizer scheduler_fn.
+         n_epoch (int): epoch number of training.
+         earlystop_patience (int): how long to wait after last time validation auc improved (default=10).
+         device (str): `"cpu"` or `"cuda:0"`
+         gpus (list): id of multi gpu (default=[]). If the length >=1, then the model will wrapped by nn.DataParallel.
+         model_path (str): the path you want to save the model (default="./"). Note only save the best weight in the validation data.
+     """
+
+     def __init__(
+         self,
+         model,
+         mode=0,
+         optimizer_fn=torch.optim.Adam,
+         optimizer_params=None,
+         scheduler_fn=None,
+         scheduler_params=None,
+         n_epoch=10,
+         earlystop_patience=10,
+         device="cpu",
+         gpus=None,
+         model_path="./",
+     ):
+         self.model = model  #for uniform weights save method in one gpu or multi gpu
+         if gpus is None:
+             gpus = []
+         self.gpus = gpus
+         if len(gpus) > 1:
+             print('parallel running on these gpus:', gpus)
+             self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
+         self.device = torch.device(device)  #torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+         self.model.to(self.device)
+         if optimizer_params is None:
+             optimizer_params = {
+                 "lr": 1e-3,
+                 "weight_decay": 1e-5
+             }
+         self.mode = mode
+         if mode == 0:  #point-wise loss, binary cross_entropy
+             self.criterion = torch.nn.BCELoss()  #default loss binary cross_entropy
+         elif mode == 1:  #pair-wise loss
+             self.criterion = BPRLoss()
+         elif mode == 2:  #list-wise loss, softmax
+             self.criterion = torch.nn.CrossEntropyLoss()
+         else:
+             raise ValueError("mode only contain value in %s, but got %s" % ([0, 1, 2], mode))
+         self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params)  #default optimizer
+         self.scheduler = None
+         if scheduler_fn is not None:
+             self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
+         self.evaluate_fn = roc_auc_score  #default evaluate function
+         self.n_epoch = n_epoch
+         self.early_stopper = EarlyStopper(patience=earlystop_patience)
+         self.model_path = model_path
+
+     def train_one_epoch(self, data_loader, log_interval=10):
+         self.model.train()
+         total_loss = 0
+         tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
+         for i, (x_dict, y) in enumerate(tk0):
+             x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  #tensor to GPU
+             y = y.to(self.device)
+             if self.mode == 0:
+                 y = y.float()  #torch._C._nn.binary_cross_entropy expected Float
+             else:
+                 y = y.long()
+             if self.mode == 1:  #pair_wise
+                 pos_score, neg_score = self.model(x_dict)
+                 loss = self.criterion(pos_score, neg_score)
+             else:
+                 y_pred = self.model(x_dict)
+                 loss = self.criterion(y_pred, y)
+             # used for debug
+             # if i == 0:
+             #     print()
+             #     if self.mode == 0:
+             #         print('pred: ', [f'{float(each):5.2g}' for each in y_pred.detach().cpu().tolist()])
+             #         print('truth:', [f'{float(each):5.2g}' for each in y.detach().cpu().tolist()])
+             #     elif self.mode == 2:
+             #         pred = y_pred.detach().cpu().mean(0)
+             #         pred = torch.softmax(pred, dim=0).tolist()
+             #         print('pred: ', [f'{float(each):4.2g}' for each in pred])
+             #     elif self.mode == 1:
+             #         print('pos:', [f'{float(each):5.2g}' for each in pos_score.detach().cpu().tolist()])
+             #         print('neg: ', [f'{float(each):5.2g}' for each in neg_score.detach().cpu().tolist()])
+
+             self.model.zero_grad()
+             loss.backward()
+             self.optimizer.step()
+             total_loss += loss.item()
+             if (i + 1) % log_interval == 0:
+                 tk0.set_postfix(loss=total_loss / log_interval)
+                 total_loss = 0
+
+     def fit(self, train_dataloader, val_dataloader=None):
+         for epoch_i in range(self.n_epoch):
+             print('epoch:', epoch_i)
+             self.train_one_epoch(train_dataloader)
+             if self.scheduler is not None:
+                 if epoch_i % self.scheduler.step_size == 0:
+                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
+                 self.scheduler.step()  #update lr in epoch level by scheduler
+
+             if val_dataloader:
+                 auc = self.evaluate(self.model, val_dataloader)
+                 print('epoch:', epoch_i, 'validation: auc:', auc)
+                 if self.early_stopper.stop_training(auc, self.model.state_dict()):
+                     print(f'validation: best auc: {self.early_stopper.best_auc}')
+                     self.model.load_state_dict(self.early_stopper.best_weights)
+                     break
+         torch.save(self.model.state_dict(), os.path.join(self.model_path,
+                                                          "model.pth"))  #save best auc model
+
+
+     def evaluate(self, model, data_loader):
+         model.eval()
+         targets, predicts = list(), list()
+         with torch.no_grad():
+             tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
+             for i, (x_dict, y) in enumerate(tk0):
+                 x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
+                 y = y.to(self.device)
+                 y_pred = model(x_dict)
+                 targets.extend(y.tolist())
+                 predicts.extend(y_pred.tolist())
+         return self.evaluate_fn(targets, predicts)
+
+     def predict(self, model, data_loader):
+         model.eval()
+         predicts = list()
+         with torch.no_grad():
+             tk0 = tqdm.tqdm(data_loader, desc="predict", smoothing=0, mininterval=1.0)
+             for i, (x_dict, y) in enumerate(tk0):
+                 x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
+                 y = y.to(self.device)
+                 y_pred = model(x_dict)
+                 predicts.extend(y_pred.tolist())
+         return predicts
+
+     def inference_embedding(self, model, mode, data_loader, model_path):
+         #inference
+         assert mode in ["user", "item"], "Invalid mode={}.".format(mode)
+         model.mode = mode
+         model.load_state_dict(torch.load(os.path.join(model_path, "model.pth")))
+         model = model.to(self.device)
+         model.eval()
+         predicts = []
+         with torch.no_grad():
+             tk0 = tqdm.tqdm(data_loader, desc="%s inference" % (mode), smoothing=0, mininterval=1.0)
+             for i, x_dict in enumerate(tk0):
+                 x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
+                 y_pred = model(x_dict)
+                 predicts.append(y_pred.data)
+         return torch.cat(predicts, dim=0)
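
The new torch_rechub/trainers/match_trainer.py above adds a retrieval trainer with point-wise (BCE), pair-wise (BPR) and list-wise (softmax) losses, early stopping on validation AUC, and an inference_embedding helper for exporting user/item vectors. The following is a minimal usage sketch, not part of the diff: the toy dataset and model are invented for illustration, and it assumes the updated trainers/__init__.py re-exports MatchTrainer (otherwise import it from torch_rechub.trainers.match_trainer).

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, Dataset
    from torch_rechub.trainers import MatchTrainer  # assumed export; see trainers/__init__.py change above

    # --- hypothetical toy pieces, not part of this diff -------------------------
    class ToyDataset(Dataset):
        """Yields ({"user_id": ..., "item_id": ...}, label) pairs as MatchTrainer expects."""
        def __init__(self, n=256):
            self.user = torch.randint(0, 100, (n,))
            self.item = torch.randint(0, 500, (n,))
            self.y = torch.randint(0, 2, (n,))
        def __len__(self):
            return len(self.y)
        def __getitem__(self, idx):
            return {"user_id": self.user[idx], "item_id": self.item[idx]}, self.y[idx]

    class ToyTwoTower(nn.Module):
        """Minimal point-wise scorer: sigmoid of the user/item embedding dot product."""
        def __init__(self):
            super().__init__()
            self.user_emb = nn.Embedding(100, 16)
            self.item_emb = nn.Embedding(500, 16)
        def forward(self, x):
            u, v = self.user_emb(x["user_id"]), self.item_emb(x["item_id"])
            return torch.sigmoid((u * v).sum(dim=-1))

    train_dl = DataLoader(ToyDataset(), batch_size=32, shuffle=True)
    val_dl = DataLoader(ToyDataset(64), batch_size=32)

    # mode=0 -> point-wise training with BCELoss; validation AUC drives early stopping
    trainer = MatchTrainer(ToyTwoTower(), mode=0, n_epoch=2, device="cpu", model_path="./")
    trainer.fit(train_dl, val_dl)
    # After fitting, trainer.inference_embedding(model=..., mode="user" or "item",
    # data_loader=..., model_path="./") exports embeddings, provided the model
    # switches towers via its `mode` attribute.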
@@ -1,145 +1,207 @@
- import os
- import tqdm
- import numpy as np
- import torch
- import torch.nn as nn
- from ..basic.callback import EarlyStopper
- from ..basic.utils import get_loss_func, get_metric_func
- from ..models.multi_task import ESMM
-
-
- class MTLTrainer(object):
-     """A trainer for multi task learning.
-
-     Args:
-         model (nn.Module): any multi task learning model.
-         task_types (list): types of tasks, only support ["classfication", "regression"].
-         optimizer_fn (torch.optim): optimizer function of pytorch (default = `torch.optim.Adam`).
-         optimizer_params (dict): parameters of optimizer_fn.
-         scheduler_fn (torch.optim.lr_scheduler) : torch scheduling class, eg. `torch.optim.lr_scheduler.StepLR`.
-         scheduler_params (dict): parameters of optimizer scheduler_fn.
-         adaptive_params (dict): parameters of adaptive loss weight method. Now only support `{"method" : "uwl"}`.
-         n_epoch (int): epoch number of training.
-         earlystop_taskid (int): task id of earlystop metrics relies between multi task (default = 0).
-         earlystop_patience (int): how long to wait after last time validation auc improved (default = 10).
-         device (str): `"cpu"` or `"cuda:0"`
-         gpus (list): id of multi gpu (default=[]). If the length >=1, then the model will wrapped by nn.DataParallel.
-         model_path (str): the path you want to save the model (default="./"). Note only save the best weight in the validation data.
-     """
-
-     def __init__(
-         self,
-         model,
-         task_types,
-         optimizer_fn=torch.optim.Adam,
-         optimizer_params={
-             "lr": 1e-3,
-             "weight_decay": 1e-5
-         },
-         scheduler_fn=None,
-         scheduler_params=None,
-         adaptive_params=None,
-         n_epoch=10,
-         earlystop_taskid=0,
-         earlystop_patience=10,
-         device="cpu",
-         gpus=[],
-         model_path="./",
-     ):
-         self.model = model
-         self.task_types = task_types
-         self.n_task = len(task_types)
-         self.loss_weight = None
-         self.adaptive_method = None
-         if adaptive_params is not None:
-             if adaptive_params["method"] == "uwl":
-                 self.adaptive_method = "uwl"
-                 self.loss_weight = nn.ParameterList(nn.Parameter(torch.zeros(1)) for _ in range(self.n_task))
-                 self.model.add_module("loss weight", self.loss_weight)
-         self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params)  #default Adam optimizer
-         self.scheduler = None
-         if scheduler_fn is not None:
-             self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
-         self.loss_fns = [get_loss_func(task_type) for task_type in task_types]
-         self.evaluate_fns = [get_metric_func(task_type) for task_type in task_types]
-         self.n_epoch = n_epoch
-         self.earlystop_taskid = earlystop_taskid
-         self.early_stopper = EarlyStopper(patience=earlystop_patience)
-         self.device = torch.device(device)
-
-         self.gpus = gpus
-         if len(gpus) > 1:
-             print('parallel running on these gpus:', gpus)
-             self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
-         self.model_path = model_path
-
-     def train_one_epoch(self, data_loader):
-         self.model.train()
-         total_loss = np.zeros(self.n_task)
-         tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
-         for iter_i, (x_dict, ys) in enumerate(tk0):
-             x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  #tensor to GPU
-             ys = ys.to(self.device)
-             y_preds = self.model(x_dict)
-             loss_list = [self.loss_fns[i](y_preds[:, i], ys[:, i].float()) for i in range(self.n_task)]
-             if isinstance(self.model, ESMM):
-                 loss = sum(loss_list[1:])  #ESSM only compute loss for ctr and ctcvr task
-             else:
-                 if self.adaptive_method != None:
-                     if self.adaptive_method == "uwl":
-                         loss = 0
-                         for loss_i, w_i in zip(loss_list, self.loss_weight):
-                             w_i = torch.clamp(w_i, min=0)
-                             loss += 2 * loss_i * torch.exp(-w_i) + w_i
-                 else:
-                     loss = sum(loss_list) / self.n_task
-             self.model.zero_grad()
-             loss.backward()
-             self.optimizer.step()
-             total_loss += np.array([l.item() for l in loss_list])
-         log_dict = {"task_%d:" % (i): total_loss[i] / (iter_i + 1) for i in range(self.n_task)}
-         print("train loss: ", log_dict)
-         if self.loss_weight:
-             print("loss weight: ", [w.item() for w in self.loss_weight])
-
-     def fit(self, train_dataloader, val_dataloader):
-         self.model.to(self.device)
-         for epoch_i in range(self.n_epoch):
-             self.train_one_epoch(train_dataloader)
-             if self.scheduler is not None:
-                 if epoch_i % self.scheduler.step_size == 0:
-                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
-                 self.scheduler.step()  #update lr in epoch level by scheduler
-             scores = self.evaluate(self.model, val_dataloader)
-             print('epoch:', epoch_i, 'validation scores: ', scores)
-             if self.early_stopper.stop_training(scores[self.earlystop_taskid], self.model.state_dict()):
-                 print('validation best auc of main task %d: %.6f' % (self.earlystop_taskid, self.early_stopper.best_auc))
-                 self.model.load_state_dict(self.early_stopper.best_weights)
-                 torch.save(self.early_stopper.best_weights, os.path.join(self.model_path, "model.pth"))  #save best auc model
-                 break
-
-     def evaluate(self, model, data_loader):
-         model.eval()
-         targets, predicts = list(), list()
-         with torch.no_grad():
-             tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
-             for i, (x_dict, ys) in enumerate(tk0):
-                 x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  #tensor to GPU
-                 ys = ys.to(self.device)
-                 y_preds = self.model(x_dict)
-                 targets.extend(ys.tolist())
-                 predicts.extend(y_preds.tolist())
-         targets, predicts = np.array(targets), np.array(predicts)
-         scores = [self.evaluate_fns[i](targets[:, i], predicts[:, i]) for i in range(self.n_task)]
-         return scores
-
-     def predict(self, model, data_loader):
-         model.eval()
-         predicts = list()
-         with torch.no_grad():
-             tk0 = tqdm.tqdm(data_loader, desc="predict", smoothing=0, mininterval=1.0)
-             for i, x_dict in enumerate(tk0):
-                 x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
-                 y_preds = model(x_dict)
-                 predicts.extend(y_preds.tolist())
+ import os
+ import tqdm
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from ..basic.callback import EarlyStopper
+ from ..utils.data import get_loss_func, get_metric_func
+ from ..models.multi_task import ESMM
+ from ..utils.mtl import shared_task_layers, gradnorm, MetaBalance
+
+
+ class MTLTrainer(object):
+     """A trainer for multi task learning.
+
+     Args:
+         model (nn.Module): any multi task learning model.
+         task_types (list): types of tasks, only support ["classfication", "regression"].
+         optimizer_fn (torch.optim): optimizer function of pytorch (default = `torch.optim.Adam`).
+         optimizer_params (dict): parameters of optimizer_fn.
+         scheduler_fn (torch.optim.lr_scheduler) : torch scheduling class, eg. `torch.optim.lr_scheduler.StepLR`.
+         scheduler_params (dict): parameters of optimizer scheduler_fn.
+         adaptive_params (dict): parameters of adaptive loss weight method. Now only support `{"method" : "uwl"}`.
+         n_epoch (int): epoch number of training.
+         earlystop_taskid (int): task id of earlystop metrics relies between multi task (default = 0).
+         earlystop_patience (int): how long to wait after last time validation auc improved (default = 10).
+         device (str): `"cpu"` or `"cuda:0"`
+         gpus (list): id of multi gpu (default=[]). If the length >=1, then the model will wrapped by nn.DataParallel.
+         model_path (str): the path you want to save the model (default="./"). Note only save the best weight in the validation data.
+     """
+
+     def __init__(
+         self,
+         model,
+         task_types,
+         optimizer_fn=torch.optim.Adam,
+         optimizer_params=None,
+         scheduler_fn=None,
+         scheduler_params=None,
+         adaptive_params=None,
+         n_epoch=10,
+         earlystop_taskid=0,
+         earlystop_patience=10,
+         device="cpu",
+         gpus=None,
+         model_path="./",
+     ):
+         self.model = model
+         if gpus is None:
+             gpus = []
+         if optimizer_params is None:
+             optimizer_params = {
+                 "lr": 1e-3,
+                 "weight_decay": 1e-5
+             }
+         self.task_types = task_types
+         self.n_task = len(task_types)
+         self.loss_weight = None
+         self.adaptive_method = None
+         if adaptive_params is not None:
+             if adaptive_params["method"] == "uwl":
+                 self.adaptive_method = "uwl"
+                 self.loss_weight = nn.ParameterList(nn.Parameter(torch.zeros(1)) for _ in range(self.n_task))
+                 self.model.add_module("loss weight", self.loss_weight)
+             elif adaptive_params["method"] == "metabalance":
+                 self.adaptive_method = "metabalance"
+                 share_layers, task_layers = shared_task_layers(self.model)
+                 self.meta_optimizer = MetaBalance(share_layers)
+                 self.share_optimizer = optimizer_fn(share_layers, **optimizer_params)
+                 self.task_optimizer = optimizer_fn(task_layers, **optimizer_params)
+             elif adaptive_params["method"] == "gradnorm":
+                 self.adaptive_method = "gradnorm"
+                 self.alpha = adaptive_params.get("alpha", 0.16)
+                 share_layers = shared_task_layers(self.model)[0]
+                 #gradnorm calculate the gradients of each loss on the last fully connected shared layer weight(dimension is 2)
+                 for i in range(len(share_layers)):
+                     if share_layers[-i].ndim == 2:
+                         self.last_share_layer = share_layers[-i]
+                         break
+                 self.initial_task_loss = None
+                 self.loss_weight = nn.ParameterList(nn.Parameter(torch.ones(1)) for _ in range(self.n_task))
+                 self.model.add_module("loss weight", self.loss_weight)
+         if self.adaptive_method != "metabalance":
+             self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params)  #default Adam optimizer
+         self.scheduler = None
+         if scheduler_fn is not None:
+             self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
+         self.loss_fns = [get_loss_func(task_type) for task_type in task_types]
+         self.evaluate_fns = [get_metric_func(task_type) for task_type in task_types]
+         self.n_epoch = n_epoch
+         self.earlystop_taskid = earlystop_taskid
+         self.early_stopper = EarlyStopper(patience=earlystop_patience)
+
+         self.gpus = gpus
+         if len(gpus) > 1:
+             print('parallel running on these gpus:', gpus)
+             self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
+         self.device = torch.device(device)  #torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+         self.model.to(self.device)
+         self.model_path = model_path
+
+
+     def train_one_epoch(self, data_loader):
+         self.model.train()
+         total_loss = np.zeros(self.n_task)
+         tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
+         for iter_i, (x_dict, ys) in enumerate(tk0):
+             x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  #tensor to GPU
+             ys = ys.to(self.device)
+             y_preds = self.model(x_dict)
+             loss_list = [self.loss_fns[i](y_preds[:, i], ys[:, i].float()) for i in range(self.n_task)]
+             if isinstance(self.model, ESMM):
+                 loss = sum(loss_list[1:])  #ESSM only compute loss for ctr and ctcvr task
+             else:
+                 if self.adaptive_method != None:
+                     if self.adaptive_method == "uwl":
+                         loss = 0
+                         for loss_i, w_i in zip(loss_list, self.loss_weight):
+                             w_i = torch.clamp(w_i, min=0)
+                             loss += 2 * loss_i * torch.exp(-w_i) + w_i
+                 else:
+                     loss = sum(loss_list) / self.n_task
+             if self.adaptive_method == 'metabalance':
+                 self.share_optimizer.zero_grad()
+                 self.task_optimizer.zero_grad()
+                 self.meta_optimizer.step(loss_list)
+                 self.share_optimizer.step()
+                 self.task_optimizer.step()
+             elif self.adaptive_method == "gradnorm":
+                 self.optimizer.zero_grad()
+                 if self.initial_task_loss is None:
+                     self.initial_task_loss = [l.item() for l in loss_list]
+                 gradnorm(loss_list, self.loss_weight, self.last_share_layer, self.initial_task_loss, self.alpha)
+                 self.optimizer.step()
+                 # renormalize
+                 loss_weight_sum = sum([w.item() for w in self.loss_weight])
+                 normalize_coeff = len(self.loss_weight) / loss_weight_sum
+                 for w in self.loss_weight:
+                     w.data = w.data * normalize_coeff
+             else:
+                 self.model.zero_grad()
+                 loss.backward()
+                 self.optimizer.step()
+             total_loss += np.array([l.item() for l in loss_list])
+         log_dict = {"task_%d:" % (i): total_loss[i] / (iter_i + 1) for i in range(self.n_task)}
+         loss_list = [total_loss[i] / (iter_i + 1) for i in range(self.n_task)]
+         print("train loss: ", log_dict)
+         if self.loss_weight:
+             print("loss weight: ", [w.item() for w in self.loss_weight])
+
+         return loss_list
+
+
+     def fit(self, train_dataloader, val_dataloader, mode='base', seed=0):
+         total_log = []
+
+         for epoch_i in range(self.n_epoch):
+             _log_per_epoch = self.train_one_epoch(train_dataloader)
+
+             if self.scheduler is not None:
+                 if epoch_i % self.scheduler.step_size == 0:
+                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
+                 self.scheduler.step()  #update lr in epoch level by scheduler
+             scores = self.evaluate(self.model, val_dataloader)
+             print('epoch:', epoch_i, 'validation scores: ', scores)
+
+             for score in scores:
+                 _log_per_epoch.append(score)
+
+             total_log.append(_log_per_epoch)
+
+             if self.early_stopper.stop_training(scores[self.earlystop_taskid], self.model.state_dict()):
+                 print('validation best auc of main task %d: %.6f' %
+                       (self.earlystop_taskid, self.early_stopper.best_auc))
+                 self.model.load_state_dict(self.early_stopper.best_weights)
+                 break
+
+         torch.save(self.model.state_dict(), os.path.join(self.model_path, "model_{}_{}.pth".format(mode, seed)))  #save best auc model
+
+         return total_log
+
+
+     def evaluate(self, model, data_loader):
+         model.eval()
+         targets, predicts = list(), list()
+         with torch.no_grad():
+             tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
+             for i, (x_dict, ys) in enumerate(tk0):
+                 x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  #tensor to GPU
+                 ys = ys.to(self.device)
+                 y_preds = self.model(x_dict)
+                 targets.extend(ys.tolist())
+                 predicts.extend(y_preds.tolist())
+         targets, predicts = np.array(targets), np.array(predicts)
+         scores = [self.evaluate_fns[i](targets[:, i], predicts[:, i]) for i in range(self.n_task)]
+         return scores
+
+
+     def predict(self, model, data_loader):
+         model.eval()
+         predicts = list()
+         with torch.no_grad():
+             tk0 = tqdm.tqdm(data_loader, desc="predict", smoothing=0, mininterval=1.0)
+             for i, x_dict in enumerate(tk0):
+                 x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
+                 y_preds = model(x_dict)
+                 predicts.extend(y_preds.tolist())
          return predicts
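
The rewritten torch_rechub/trainers/mtl_trainer.py above keeps the original uncertainty-weighting option (`{"method": "uwl"}`) and adds MetaBalance and GradNorm adaptive loss weighting from torch_rechub.utils.mtl; fit now also returns a per-epoch log and saves weights as model_{mode}_{seed}.pth. Below is a rough usage sketch, not part of the diff: the toy dataset and model are invented for illustration, the import path assumes MTLTrainer is exported from torch_rechub.trainers, and the exact accepted task_types strings are defined by torch_rechub.utils.data.get_loss_func, which is not shown here.

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, Dataset
    from torch_rechub.trainers import MTLTrainer  # assumed export path

    # --- hypothetical toy pieces, not part of this diff -------------------------
    class ToyMTLDataset(Dataset):
        """Yields (x_dict, ys) where ys stacks one label per task, as MTLTrainer expects."""
        def __init__(self, n=256, n_task=2):
            self.x = torch.randint(0, 100, (n,))
            self.ys = torch.randint(0, 2, (n, n_task))
        def __len__(self):
            return len(self.x)
        def __getitem__(self, idx):
            return {"user_id": self.x[idx]}, self.ys[idx]

    class ToyTwoTaskModel(nn.Module):
        """Shared embedding plus one small head per task; output shape (batch_size, n_task)."""
        def __init__(self, n_task=2):
            super().__init__()
            self.emb = nn.Embedding(100, 16)
            self.towers = nn.ModuleList(nn.Linear(16, 1) for _ in range(n_task))
        def forward(self, x):
            h = self.emb(x["user_id"])
            return torch.sigmoid(torch.cat([tower(h) for tower in self.towers], dim=1))

    train_dl = DataLoader(ToyMTLDataset(), batch_size=32, shuffle=True)
    val_dl = DataLoader(ToyMTLDataset(64), batch_size=32)

    trainer = MTLTrainer(
        ToyTwoTaskModel(),
        task_types=["classification", "classification"],  # strings as accepted by get_loss_func (not shown in this diff)
        adaptive_params={"method": "uwl"},                 # or {"method": "metabalance"} / {"method": "gradnorm", "alpha": 0.16}
        earlystop_taskid=0,                                # early stopping follows the first task's validation score
        n_epoch=2,
        device="cpu",
    )
    total_log = trainer.fit(train_dl, val_dl)  # per-epoch [train losses..., val scores...]; weights saved as model_base_0.pth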