torch-rechub 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. torch_rechub/basic/activation.py +54 -52
  2. torch_rechub/basic/callback.py +32 -32
  3. torch_rechub/basic/features.py +94 -57
  4. torch_rechub/basic/initializers.py +92 -0
  5. torch_rechub/basic/layers.py +720 -240
  6. torch_rechub/basic/loss_func.py +34 -0
  7. torch_rechub/basic/metaoptimizer.py +72 -0
  8. torch_rechub/basic/metric.py +250 -0
  9. torch_rechub/models/matching/__init__.py +11 -0
  10. torch_rechub/models/matching/comirec.py +188 -0
  11. torch_rechub/models/matching/dssm.py +66 -0
  12. torch_rechub/models/matching/dssm_facebook.py +79 -0
  13. torch_rechub/models/matching/dssm_senet.py +75 -0
  14. torch_rechub/models/matching/gru4rec.py +87 -0
  15. torch_rechub/models/matching/mind.py +101 -0
  16. torch_rechub/models/matching/narm.py +76 -0
  17. torch_rechub/models/matching/sasrec.py +140 -0
  18. torch_rechub/models/matching/sine.py +151 -0
  19. torch_rechub/models/matching/stamp.py +83 -0
  20. torch_rechub/models/matching/youtube_dnn.py +71 -0
  21. torch_rechub/models/matching/youtube_sbc.py +98 -0
  22. torch_rechub/models/multi_task/__init__.py +5 -4
  23. torch_rechub/models/multi_task/aitm.py +84 -0
  24. torch_rechub/models/multi_task/esmm.py +55 -45
  25. torch_rechub/models/multi_task/mmoe.py +58 -52
  26. torch_rechub/models/multi_task/ple.py +130 -104
  27. torch_rechub/models/multi_task/shared_bottom.py +45 -44
  28. torch_rechub/models/ranking/__init__.py +11 -3
  29. torch_rechub/models/ranking/afm.py +63 -0
  30. torch_rechub/models/ranking/bst.py +63 -0
  31. torch_rechub/models/ranking/dcn.py +38 -0
  32. torch_rechub/models/ranking/dcn_v2.py +69 -0
  33. torch_rechub/models/ranking/deepffm.py +123 -0
  34. torch_rechub/models/ranking/deepfm.py +41 -41
  35. torch_rechub/models/ranking/dien.py +191 -0
  36. torch_rechub/models/ranking/din.py +91 -81
  37. torch_rechub/models/ranking/edcn.py +117 -0
  38. torch_rechub/models/ranking/fibinet.py +50 -0
  39. torch_rechub/models/ranking/widedeep.py +41 -41
  40. torch_rechub/trainers/__init__.py +2 -1
  41. torch_rechub/trainers/{trainer.py → ctr_trainer.py} +128 -111
  42. torch_rechub/trainers/match_trainer.py +170 -0
  43. torch_rechub/trainers/mtl_trainer.py +206 -144
  44. torch_rechub/utils/__init__.py +0 -0
  45. torch_rechub/utils/data.py +360 -0
  46. torch_rechub/utils/match.py +274 -0
  47. torch_rechub/utils/mtl.py +126 -0
  48. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/LICENSE +21 -21
  49. torch_rechub-0.0.3.dist-info/METADATA +177 -0
  50. torch_rechub-0.0.3.dist-info/RECORD +55 -0
  51. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/WHEEL +1 -1
  52. torch_rechub/basic/utils.py +0 -168
  53. torch_rechub-0.0.1.dist-info/METADATA +0 -105
  54. torch_rechub-0.0.1.dist-info/RECORD +0 -26
  55. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,117 @@
+ """
+ Date: create on 09/13/2022
+ References:
+     paper: (KDD'21) EDCN: Enhancing Explicit and Implicit Feature Interactions via Information Sharing for Parallel Deep CTR Models
+     url: https://dlp-kdd.github.io/assets/pdf/DLP-KDD_2021_paper_12.pdf
+ Authors: lailai, lailai_zxy@tju.edu.cn
+ """
+
+ import torch
+ from torch import nn
+ from ...basic.layers import LR, MLP, CrossLayer, EmbeddingLayer
+
+
+ class EDCN(torch.nn.Module):
+     """Enhanced Deep & Cross Network, which shares information between the parallel cross and deep towers through bridge and regulation modules.
+
+     Args:
+         features (list[Feature Class]): the features used by the whole module.
+         n_cross_layers (int): the number of cross (explicit feature-interaction) layers.
+         mlp_params (dict): the params of the last MLP module, keys include: `{"dims":list, "activation":str, "dropout":float, "output_layer":bool}`
+         bridge_type (str): the type of bridge interaction function, one of ["hadamard_product", "pointwise_addition", "concatenation", "attention_pooling"].
+         use_regulation_module (bool): whether to use the regulation module (default=True).
+         temperature (int): the temperature coefficient that controls the softmax distribution in the regulation module.
+     """
+
+     def __init__(self, features, n_cross_layers, mlp_params, bridge_type="hadamard_product", use_regulation_module=True, temperature=1):
+         super().__init__()
+         self.features = features
+         self.n_cross_layers = n_cross_layers
+         self.num_fields = len(features)
+         self.dims = sum([fea.embed_dim for fea in features])
+         self.fea_dims = [fea.embed_dim for fea in features]
+         self.embedding = EmbeddingLayer(features)
+         self.cross_layers = nn.ModuleList([CrossLayer(self.dims) for _ in range(n_cross_layers)])
+         self.bridge_modules = nn.ModuleList([BridgeModule(self.dims, bridge_type) for _ in range(n_cross_layers)])
+         self.regulation_modules = nn.ModuleList([
+             RegulationModule(self.num_fields, self.fea_dims, tau=temperature, use_regulation=use_regulation_module)
+             for _ in range(n_cross_layers)
+         ])
+         mlp_params["dims"] = [self.dims, self.dims]
+         self.mlps = nn.ModuleList([MLP(self.dims, output_layer=False, **mlp_params) for _ in range(n_cross_layers)])
+         self.linear = LR(self.dims * 3)
+
+     def forward(self, x):
+         embed_x = self.embedding(x, self.features, squeeze_dim=True)
+         cross_i, deep_i = self.regulation_modules[0](embed_x)
+         cross_0 = cross_i
+         for i in range(self.n_cross_layers):
+             if i > 0:
+                 cross_i, deep_i = self.regulation_modules[i](bridge_i)
+             cross_i = cross_i + self.cross_layers[i](cross_0, cross_i)
+             deep_i = self.mlps[i](deep_i)
+             bridge_i = self.bridge_modules[i](cross_i, deep_i)
+         x_stack = torch.cat([cross_i, deep_i, bridge_i], dim=1)
+         y = self.linear(x_stack)
+         return torch.sigmoid(y.squeeze(1))
+
+
+ class BridgeModule(torch.nn.Module):
+     def __init__(self, input_dim, bridge_type):
+         super(BridgeModule, self).__init__()
+         assert bridge_type in ["hadamard_product", "pointwise_addition", "concatenation", "attention_pooling"], \
+             'bridge_type={} is not supported'.format(bridge_type)
+         self.bridge_type = bridge_type
+         if bridge_type == "concatenation":
+             self.concat_pooling = nn.Sequential(nn.Linear(input_dim * 2, input_dim), nn.ReLU())
+         elif bridge_type == "attention_pooling":
+             self.attention_x = nn.Sequential(nn.Linear(input_dim, input_dim), nn.ReLU(),
+                                              nn.Linear(input_dim, input_dim, bias=False), nn.Softmax(dim=-1))
+             self.attention_h = nn.Sequential(nn.Linear(input_dim, input_dim), nn.ReLU(),
+                                              nn.Linear(input_dim, input_dim, bias=False), nn.Softmax(dim=-1))
+
+     def forward(self, x, h):
+         if self.bridge_type == "hadamard_product":
+             out = x * h
+         elif self.bridge_type == "pointwise_addition":
+             out = x + h
+         elif self.bridge_type == "concatenation":
+             out = self.concat_pooling(torch.cat([x, h], dim=-1))
+         elif self.bridge_type == "attention_pooling":
+             out = self.attention_x(x) * x + self.attention_h(h) * h
+         return out
+
+
+ class RegulationModule(torch.nn.Module):
+     def __init__(self, num_fields, dims, tau, use_regulation=True):
+         super(RegulationModule, self).__init__()
+         self.use_regulation = use_regulation
+         if self.use_regulation:
+             self.num_fields = num_fields
+             self.dims = dims
+             self.tau = tau
+             self.g1 = nn.Parameter(torch.ones(num_fields))
+             self.g2 = nn.Parameter(torch.ones(num_fields))
+
+     def forward(self, x):
+         if self.use_regulation:
+             # softmax over the field gates, then broadcast each field's weight across its embedding dims
+             gate1 = (self.g1 / self.tau).softmax(dim=-1)
+             gate2 = (self.g2 / self.tau).softmax(dim=-1)
+             g1 = torch.cat([gate1[i].expand(self.dims[i]) for i in range(self.num_fields)], dim=-1)
+             g2 = torch.cat([gate2[i].expand(self.dims[i]) for i in range(self.num_fields)], dim=-1)
+             out1, out2 = g1 * x, g2 * x
+         else:
+             out1, out2 = x, x
+         return out1, out2
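For orientation, here is a minimal usage sketch of the new EDCN model. It assumes EDCN is exported from `torch_rechub.models.ranking` (consistent with the `torch_rechub/models/ranking/__init__.py` change in the file list); the feature names and sizes are illustrative:

    import torch
    from torch_rechub.basic.features import SparseFeature
    from torch_rechub.models.ranking import EDCN

    # two toy sparse fields; EDCN works on the concatenated field embeddings (2 * 16 = 32 dims here)
    features = [SparseFeature("user_id", vocab_size=100, embed_dim=16),
                SparseFeature("item_id", vocab_size=200, embed_dim=16)]
    model = EDCN(features, n_cross_layers=2,
                 mlp_params={"activation": "relu", "dropout": 0.2},  # "dims" is overwritten internally
                 bridge_type="hadamard_product", use_regulation_module=True, temperature=1)
    batch = {"user_id": torch.randint(0, 100, (8,)),
             "item_id": torch.randint(0, 200, (8,))}
    y = model(batch)  # shape [8]: per-sample sigmoid CTR scores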
@@ -0,0 +1,50 @@
+ """
+ Date: create on 10/19/2022
+ References:
+     paper: (RecSys'19) FiBiNET: Combining Feature Importance and Bilinear Feature Interaction for Click-Through Rate Prediction
+     url: https://dl.acm.org/doi/abs/10.1145/3298689.3347043
+ Authors: lailai, lailai_zxy@tju.edu.cn
+ """
+ import torch
+ from torch import nn
+ from ...basic.layers import MLP, EmbeddingLayer, SENETLayer, BiLinearInteractionLayer
+ from ...basic.features import SparseFeature
+
+
+ class FiBiNet(torch.nn.Module):
+     """FiBiNET model, which combines SENET-based feature importance with bilinear feature interactions.
+
+     Args:
+         features (list[Feature Class]): the features used by the whole module.
+         mlp_params (dict): the params of the last MLP module, keys include: `{"dims":list, "activation":str, "dropout":float, "output_layer":bool}`
+         reduction_ratio (int): hidden-layer reduction factor of the SENET layer.
+         bilinear_type (str): the type of bilinear interaction function, one of ["field_all", "field_each", "field_interaction"]: "field_all" shares one W across all features, "field_each" uses one W_i per feature field, and "field_interaction" uses one W_ij per pair of feature fields.
+     """
+
+     def __init__(self, features, mlp_params, reduction_ratio=3, bilinear_type="field_interaction", **kwargs):
+         super(FiBiNet, self).__init__()
+         self.features = features
+         self.embedding = EmbeddingLayer(features)
+         embedding_dim = max([fea.embed_dim for fea in features])
+         num_fields = len([fea.embed_dim for fea in features if isinstance(fea, SparseFeature) and fea.shared_with is None])
+         self.senet_layer = SENETLayer(num_fields, reduction_ratio)
+         self.bilinear_interaction = BiLinearInteractionLayer(embedding_dim, num_fields, bilinear_type)
+         self.dims = num_fields * (num_fields - 1) * embedding_dim
+         self.mlp = MLP(self.dims, **mlp_params)
+
+     def forward(self, x):
+         embed_x = self.embedding(x, self.features)
+         embed_senet = self.senet_layer(embed_x)
+         embed_bi1 = self.bilinear_interaction(embed_x)
+         embed_bi2 = self.bilinear_interaction(embed_senet)
+         shallow_part = torch.flatten(torch.cat([embed_bi1, embed_bi2], dim=1), start_dim=1)
+         mlp_out = self.mlp(shallow_part)
+         return torch.sigmoid(mlp_out.squeeze(1))
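A minimal usage sketch of the new FiBiNet model, again assuming it is exported from `torch_rechub.models.ranking` and using illustrative field names and sizes (all fields share one embed_dim, since the embeddings are stacked field-wise):

    import torch
    from torch_rechub.basic.features import SparseFeature
    from torch_rechub.models.ranking import FiBiNet

    features = [SparseFeature("user_id", vocab_size=100, embed_dim=16),
                SparseFeature("item_id", vocab_size=200, embed_dim=16),
                SparseFeature("city", vocab_size=30, embed_dim=16)]
    model = FiBiNet(features,
                    mlp_params={"dims": [64, 32], "dropout": 0.2, "activation": "relu"},
                    reduction_ratio=3, bilinear_type="field_interaction")
    batch = {"user_id": torch.randint(0, 100, (8,)),
             "item_id": torch.randint(0, 200, (8,)),
             "city": torch.randint(0, 30, (8,))}
    y = model(batch)  # shape [8]; MLP input is num_fields*(num_fields-1)*embed_dim = 96 dims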
@@ -1,41 +1,41 @@
- """
- Date: create on 22/04/2022
- References:
-     paper: (DLRS'2016) Wide & Deep Learning for Recommender Systems
-     url: https://arxiv.org/abs/1606.07792
- Authors: Mincai Lai, laimincai@shanghaitech.edu.cn
- """
-
- import torch
-
- from ...basic.layers import LR, MLP, EmbeddingLayer
-
-
- class WideDeep(torch.nn.Module):
-     """Wide & Deep Learning model.
-
-     Args:
-         wide_features (list): the list of `Feature Class`, trained by the wide part module.
-         deep_features (list): the list of `Feature Class`, trained by the deep part module.
-         mlp_params (dict): the params of the last MLP module, keys include: `{"dims":list, "activation":str, "dropout":float, "output_layer":bool}`
-     """
-
-     def __init__(self, wide_features, deep_features, mlp_params):
-         super(WideDeep, self).__init__()
-         self.wide_features = wide_features
-         self.deep_features = deep_features
-         self.wide_dims = sum([fea.embed_dim for fea in wide_features])
-         self.deep_dims = sum([fea.embed_dim for fea in deep_features])
-         self.linear = LR(self.wide_dims)
-         self.embedding = EmbeddingLayer(wide_features + deep_features)
-         self.mlp = MLP(self.deep_dims, **mlp_params)
-
-     def forward(self, x):
-         input_wide = self.embedding(x, self.wide_features, squeeze_dim=True)  #[batch_size, wide_dims]
-         input_deep = self.embedding(x, self.deep_features, squeeze_dim=True)  #[batch_size, deep_dims]
-
-         y_wide = self.linear(input_wide)  #[batch_size, 1]
-         y_deep = self.mlp(input_deep)  #[batch_size, 1]
-         y = y_wide + y_deep
-         y = torch.sigmoid(y.squeeze(1))
-         return y
+ """
+ Date: create on 22/04/2022
+ References:
+     paper: (DLRS'2016) Wide & Deep Learning for Recommender Systems
+     url: https://arxiv.org/abs/1606.07792
+ Authors: Mincai Lai, laimincai@shanghaitech.edu.cn
+ """
+
+ import torch
+
+ from ...basic.layers import LR, MLP, EmbeddingLayer
+
+
+ class WideDeep(torch.nn.Module):
+     """Wide & Deep Learning model.
+
+     Args:
+         wide_features (list): the list of `Feature Class`, trained by the wide part module.
+         deep_features (list): the list of `Feature Class`, trained by the deep part module.
+         mlp_params (dict): the params of the last MLP module, keys include: `{"dims":list, "activation":str, "dropout":float, "output_layer":bool}`
+     """
+
+     def __init__(self, wide_features, deep_features, mlp_params):
+         super(WideDeep, self).__init__()
+         self.wide_features = wide_features
+         self.deep_features = deep_features
+         self.wide_dims = sum([fea.embed_dim for fea in wide_features])
+         self.deep_dims = sum([fea.embed_dim for fea in deep_features])
+         self.linear = LR(self.wide_dims)
+         self.embedding = EmbeddingLayer(wide_features + deep_features)
+         self.mlp = MLP(self.deep_dims, **mlp_params)
+
+     def forward(self, x):
+         input_wide = self.embedding(x, self.wide_features, squeeze_dim=True)  #[batch_size, wide_dims]
+         input_deep = self.embedding(x, self.deep_features, squeeze_dim=True)  #[batch_size, deep_dims]
+
+         y_wide = self.linear(input_wide)  #[batch_size, 1]
+         y_deep = self.mlp(input_deep)  #[batch_size, 1]
+         y = y_wide + y_deep
+         y = torch.sigmoid(y.squeeze(1))
+         return y
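The two sides of this hunk are textually identical, so the WideDeep change is whitespace/line-ending only. For reference, a minimal usage sketch with illustrative feature names and sizes, assuming WideDeep is exported from `torch_rechub.models.ranking`:

    import torch
    from torch_rechub.basic.features import SparseFeature
    from torch_rechub.models.ranking import WideDeep

    wide_features = [SparseFeature("site_id", vocab_size=50, embed_dim=8)]
    deep_features = [SparseFeature("user_id", vocab_size=100, embed_dim=16),
                     SparseFeature("item_id", vocab_size=200, embed_dim=16)]
    model = WideDeep(wide_features, deep_features,
                     mlp_params={"dims": [64, 32], "dropout": 0.2, "activation": "relu"})
    batch = {"site_id": torch.randint(0, 50, (8,)),
             "user_id": torch.randint(0, 100, (8,)),
             "item_id": torch.randint(0, 200, (8,))}
    y = model(batch)  # shape [8]: sigmoid(wide logit + deep logit)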
@@ -1,2 +1,3 @@
- from .trainer import CTRTrainer
+ from .ctr_trainer import CTRTrainer
+ from .match_trainer import MatchTrainer
  from .mtl_trainer import MTLTrainer
@@ -1,111 +1,128 @@
- import os
- import torch
- import tqdm
- from sklearn.metrics import roc_auc_score
- from ..basic.callback import EarlyStopper
-
-
- class CTRTrainer(object):
-     """A general trainer for single task learning.
-
-     Args:
-         model (nn.Module): any single task learning model.
-         optimizer_fn (torch.optim): optimizer function of pytorch (default = `torch.optim.Adam`).
-         optimizer_params (dict): parameters of optimizer_fn.
-         scheduler_fn (torch.optim.lr_scheduler): torch scheduling class, e.g. `torch.optim.lr_scheduler.StepLR`.
-         scheduler_params (dict): parameters of scheduler_fn.
-         n_epoch (int): epoch number of training.
-         earlystop_patience (int): how long to wait after the last validation auc improvement (default=10).
-         device (str): `"cpu"` or `"cuda:0"`
-         gpus (list): ids of multiple gpus (default=[]). If the length >= 1, the model will be wrapped by nn.DataParallel.
-         model_path (str): the path where the model is saved (default="./"). Note that only the best weights on the validation data are saved.
-     """
-
-     def __init__(
-         self,
-         model,
-         optimizer_fn=torch.optim.Adam,
-         optimizer_params={
-             "lr": 1e-3,
-             "weight_decay": 1e-5
-         },
-         scheduler_fn=None,
-         scheduler_params=None,
-         n_epoch=10,
-         earlystop_patience=10,
-         device="cpu",
-         gpus=[],
-         model_path="./",
-     ):
-         self.model = model  #for uniform weights save method in one gpu or multi gpu
-         self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params)  #default optimizer
-         self.scheduler = None
-         if scheduler_fn is not None:
-             self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
-         self.criterion = torch.nn.BCELoss()  #default loss binary cross_entropy
-         self.evaluate_fn = roc_auc_score  #default evaluate function
-         self.n_epoch = n_epoch
-         self.early_stopper = EarlyStopper(patience=earlystop_patience)
-         self.device = torch.device(device)  #torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-         self.gpus = gpus
-         if len(gpus) > 1:
-             print('parallel running on these gpus:', gpus)
-             self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
-         self.model_path = model_path
-
-     def train_one_epoch(self, data_loader, log_interval=10):
-         self.model.train()
-         total_loss = 0
-         tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
-         for i, (x_dict, y) in enumerate(tk0):
-             x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  #tensor to GPU
-             y = y.to(self.device)
-             y_pred = self.model(x_dict)
-             loss = self.criterion(y_pred, y.float())
-             self.model.zero_grad()
-             loss.backward()
-             self.optimizer.step()
-             total_loss += loss.item()
-             if (i + 1) % log_interval == 0:
-                 tk0.set_postfix(loss=total_loss / log_interval)
-                 total_loss = 0
-
-     def fit(self, train_dataloader, val_dataloader):
-         self.model.to(self.device)
-         for epoch_i in range(self.n_epoch):
-             self.train_one_epoch(train_dataloader)
-             if self.scheduler is not None:
-                 if epoch_i % self.scheduler.step_size == 0:
-                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
-                 self.scheduler.step()  #update lr in epoch level by scheduler
-             auc = self.evaluate(self.model, val_dataloader)
-             print('epoch:', epoch_i, 'validation: auc:', auc)
-             if self.early_stopper.stop_training(auc, self.model.state_dict()):
-                 print(f'validation: best auc: {self.early_stopper.best_auc}')
-                 self.model.load_state_dict(self.early_stopper.best_weights)
-                 torch.save(self.early_stopper.best_weights, os.path.join(self.model_path, "model.pth"))  #save best auc model
-                 break
-
-     def evaluate(self, model, data_loader):
-         model.eval()
-         targets, predicts = list(), list()
-         with torch.no_grad():
-             tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
-             for i, (x_dict, y) in enumerate(tk0):
-                 x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
-                 y = y.to(self.device)
-                 y_pred = model(x_dict)
-                 targets.extend(y.tolist())
-                 predicts.extend(y_pred.tolist())
-         return self.evaluate_fn(targets, predicts)
-
-     def predict(self, model, data_loader):
-         model.eval()
-         predicts = list()
-         with torch.no_grad():
-             tk0 = tqdm.tqdm(data_loader, desc="predict", smoothing=0, mininterval=1.0)
-             for i, x_dict in enumerate(tk0):
-                 x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
-                 y_pred = model(x_dict)
-                 predicts.extend(y_pred.tolist())
-         return predicts
+ import os
+ import torch
+ import tqdm
+ from sklearn.metrics import roc_auc_score
+ from ..basic.callback import EarlyStopper
+
+
+ class CTRTrainer(object):
+     """A general trainer for single task learning.
+
+     Args:
+         model (nn.Module): any single task learning model.
+         optimizer_fn (torch.optim): optimizer function of pytorch (default = `torch.optim.Adam`).
+         optimizer_params (dict): parameters of optimizer_fn.
+         scheduler_fn (torch.optim.lr_scheduler): torch scheduling class, e.g. `torch.optim.lr_scheduler.StepLR`.
+         scheduler_params (dict): parameters of scheduler_fn.
+         n_epoch (int): epoch number of training.
+         earlystop_patience (int): how long to wait after the last validation auc improvement (default=10).
+         device (str): `"cpu"` or `"cuda:0"`
+         gpus (list): ids of multiple gpus (default=None). If the length >= 1, the model will be wrapped by nn.DataParallel.
+         loss_mode (bool, optional): if True, the model returns predictions only; if False, the model returns `(y_pred, other_loss)` and `other_loss` is added to the criterion. Defaults to True.
+         model_path (str): the path where the model is saved (default="./"). Note that only the best weights on the validation data are saved.
+     """
+
+     def __init__(
+         self,
+         model,
+         optimizer_fn=torch.optim.Adam,
+         optimizer_params=None,
+         scheduler_fn=None,
+         scheduler_params=None,
+         n_epoch=10,
+         earlystop_patience=10,
+         device="cpu",
+         gpus=None,
+         loss_mode=True,
+         model_path="./",
+     ):
+         self.model = model  #for uniform weights save method in one gpu or multi gpu
+         if gpus is None:
+             gpus = []
+         self.gpus = gpus
+         if len(gpus) > 1:
+             print('parallel running on these gpus:', gpus)
+             self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
+         self.device = torch.device(device)  #torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+         self.model.to(self.device)
+         if optimizer_params is None:
+             optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
+         self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params)  #default optimizer
+         self.scheduler = None
+         if scheduler_fn is not None:
+             self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
+         self.loss_mode = loss_mode
+         self.criterion = torch.nn.BCELoss()  #default loss binary cross_entropy
+         self.evaluate_fn = roc_auc_score  #default evaluate function
+         self.n_epoch = n_epoch
+         self.early_stopper = EarlyStopper(patience=earlystop_patience)
+         self.model_path = model_path
+
+     def train_one_epoch(self, data_loader, log_interval=10):
+         self.model.train()
+         total_loss = 0
+         tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
+         for i, (x_dict, y) in enumerate(tk0):
+             x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  #tensor to GPU
+             y = y.to(self.device).float()
+             if self.loss_mode:
+                 y_pred = self.model(x_dict)
+                 loss = self.criterion(y_pred, y)
+             else:
+                 y_pred, other_loss = self.model(x_dict)
+                 loss = self.criterion(y_pred, y) + other_loss
+             self.model.zero_grad()
+             loss.backward()
+             self.optimizer.step()
+             total_loss += loss.item()
+             if (i + 1) % log_interval == 0:
+                 tk0.set_postfix(loss=total_loss / log_interval)
+                 total_loss = 0
+
+     def fit(self, train_dataloader, val_dataloader=None):
+         for epoch_i in range(self.n_epoch):
+             print('epoch:', epoch_i)
+             self.train_one_epoch(train_dataloader)
+             if self.scheduler is not None:
+                 if epoch_i % self.scheduler.step_size == 0:
+                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
+                 self.scheduler.step()  #update lr in epoch level by scheduler
+             if val_dataloader:
+                 auc = self.evaluate(self.model, val_dataloader)
+                 print('epoch:', epoch_i, 'validation: auc:', auc)
+                 if self.early_stopper.stop_training(auc, self.model.state_dict()):
+                     print(f'validation: best auc: {self.early_stopper.best_auc}')
+                     self.model.load_state_dict(self.early_stopper.best_weights)
+                     break
+         torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth"))  #save final weights (the best validation weights if early stopping triggered)
+
+     def evaluate(self, model, data_loader):
+         model.eval()
+         targets, predicts = list(), list()
+         with torch.no_grad():
+             tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
+             for i, (x_dict, y) in enumerate(tk0):
+                 x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
+                 y = y.to(self.device).float().view(-1, 1)  # ensure y is a float tensor of shape [batch_size, 1]
+                 if self.loss_mode:
+                     y_pred = model(x_dict)
+                 else:
+                     y_pred, _ = model(x_dict)
+                 targets.extend(y.tolist())
+                 predicts.extend(y_pred.tolist())
+         return self.evaluate_fn(targets, predicts)
+
+     def predict(self, model, data_loader):
+         model.eval()
+         predicts = list()
+         with torch.no_grad():
+             tk0 = tqdm.tqdm(data_loader, desc="predict", smoothing=0, mininterval=1.0)
+             for i, (x_dict, y) in enumerate(tk0):
+                 x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
+                 if self.loss_mode:
+                     y_pred = model(x_dict)
+                 else:
+                     y_pred, _ = model(x_dict)
+                 predicts.extend(y_pred.tolist())
+         return predicts
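Finally, a minimal end-to-end sketch of the renamed CTRTrainer. The `DictDataset` helper below is hypothetical (the release ships real dataset utilities in torch_rechub/utils/data.py, per the file list); any DataLoader yielding `(x_dict, y)` batches works the same way:

    import torch
    from torch.utils.data import DataLoader, Dataset
    from torch_rechub.basic.features import SparseFeature
    from torch_rechub.models.ranking import WideDeep
    from torch_rechub.trainers import CTRTrainer

    class DictDataset(Dataset):
        # yields ({feature_name: id}, label) pairs; default_collate batches the dict automatically
        def __init__(self, x_dict, y):
            self.x_dict, self.y = x_dict, y
        def __len__(self):
            return len(self.y)
        def __getitem__(self, i):
            return {k: v[i] for k, v in self.x_dict.items()}, self.y[i]

    n = 256  # toy random data, for illustration only
    x = {"user_id": torch.randint(0, 100, (n,)), "item_id": torch.randint(0, 200, (n,))}
    y = torch.randint(0, 2, (n,)).float()
    train_dl = DataLoader(DictDataset(x, y), batch_size=32, shuffle=True)
    val_dl = DataLoader(DictDataset(x, y), batch_size=32)

    model = WideDeep([SparseFeature("user_id", vocab_size=100, embed_dim=16)],
                     [SparseFeature("item_id", vocab_size=200, embed_dim=16)],
                     mlp_params={"dims": [32], "activation": "relu", "dropout": 0.1})
    trainer = CTRTrainer(model, n_epoch=2, earlystop_patience=2, device="cpu", model_path="./")
    trainer.fit(train_dl, val_dl)            # trains, early-stops on validation AUC, saves model.pth
    print(trainer.evaluate(trainer.model, val_dl))  # validation AUC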