torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +54 -54
  3. torch_rechub/basic/callback.py +33 -33
  4. torch_rechub/basic/features.py +87 -94
  5. torch_rechub/basic/initializers.py +92 -92
  6. torch_rechub/basic/layers.py +994 -720
  7. torch_rechub/basic/loss_func.py +223 -34
  8. torch_rechub/basic/metaoptimizer.py +76 -72
  9. torch_rechub/basic/metric.py +251 -250
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -11
  14. torch_rechub/models/matching/comirec.py +193 -188
  15. torch_rechub/models/matching/dssm.py +72 -66
  16. torch_rechub/models/matching/dssm_facebook.py +77 -79
  17. torch_rechub/models/matching/dssm_senet.py +28 -16
  18. torch_rechub/models/matching/gru4rec.py +85 -87
  19. torch_rechub/models/matching/mind.py +103 -101
  20. torch_rechub/models/matching/narm.py +82 -76
  21. torch_rechub/models/matching/sasrec.py +143 -140
  22. torch_rechub/models/matching/sine.py +148 -151
  23. torch_rechub/models/matching/stamp.py +81 -83
  24. torch_rechub/models/matching/youtube_dnn.py +75 -71
  25. torch_rechub/models/matching/youtube_sbc.py +98 -98
  26. torch_rechub/models/multi_task/__init__.py +7 -5
  27. torch_rechub/models/multi_task/aitm.py +83 -84
  28. torch_rechub/models/multi_task/esmm.py +56 -55
  29. torch_rechub/models/multi_task/mmoe.py +58 -58
  30. torch_rechub/models/multi_task/ple.py +116 -130
  31. torch_rechub/models/multi_task/shared_bottom.py +45 -45
  32. torch_rechub/models/ranking/__init__.py +14 -11
  33. torch_rechub/models/ranking/afm.py +65 -63
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -63
  36. torch_rechub/models/ranking/dcn.py +38 -38
  37. torch_rechub/models/ranking/dcn_v2.py +59 -69
  38. torch_rechub/models/ranking/deepffm.py +131 -123
  39. torch_rechub/models/ranking/deepfm.py +43 -42
  40. torch_rechub/models/ranking/dien.py +191 -191
  41. torch_rechub/models/ranking/din.py +93 -91
  42. torch_rechub/models/ranking/edcn.py +101 -117
  43. torch_rechub/models/ranking/fibinet.py +42 -50
  44. torch_rechub/models/ranking/widedeep.py +41 -41
  45. torch_rechub/trainers/__init__.py +4 -3
  46. torch_rechub/trainers/ctr_trainer.py +288 -128
  47. torch_rechub/trainers/match_trainer.py +336 -170
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +356 -207
  50. torch_rechub/trainers/seq_trainer.py +427 -0
  51. torch_rechub/utils/data.py +492 -360
  52. torch_rechub/utils/hstu_utils.py +198 -0
  53. torch_rechub/utils/match.py +457 -274
  54. torch_rechub/utils/model_utils.py +233 -0
  55. torch_rechub/utils/mtl.py +136 -126
  56. torch_rechub/utils/onnx_export.py +220 -0
  57. torch_rechub/utils/visualization.py +271 -0
  58. torch_rechub-0.0.5.dist-info/METADATA +402 -0
  59. torch_rechub-0.0.5.dist-info/RECORD +64 -0
  60. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
  61. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
  62. torch_rechub-0.0.3.dist-info/METADATA +0 -177
  63. torch_rechub-0.0.3.dist-info/RECORD +0 -55
  64. torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
@@ -1,207 +1,356 @@
1
- import os
2
- import tqdm
3
- import numpy as np
4
- import torch
5
- import torch.nn as nn
6
- from ..basic.callback import EarlyStopper
7
- from ..utils.data import get_loss_func, get_metric_func
8
- from ..models.multi_task import ESMM
9
- from ..utils.mtl import shared_task_layers, gradnorm, MetaBalance
10
-
11
-
12
- class MTLTrainer(object):
13
- """A trainer for multi task learning.
14
-
15
- Args:
16
- model (nn.Module): any multi task learning model.
17
- task_types (list): types of tasks, only support ["classfication", "regression"].
18
- optimizer_fn (torch.optim): optimizer function of pytorch (default = `torch.optim.Adam`).
19
- optimizer_params (dict): parameters of optimizer_fn.
20
- scheduler_fn (torch.optim.lr_scheduler) : torch scheduling class, eg. `torch.optim.lr_scheduler.StepLR`.
21
- scheduler_params (dict): parameters of optimizer scheduler_fn.
22
- adaptive_params (dict): parameters of adaptive loss weight method. Now only support `{"method" : "uwl"}`.
23
- n_epoch (int): epoch number of training.
24
- earlystop_taskid (int): task id of earlystop metrics relies between multi task (default = 0).
25
- earlystop_patience (int): how long to wait after last time validation auc improved (default = 10).
26
- device (str): `"cpu"` or `"cuda:0"`
27
- gpus (list): id of multi gpu (default=[]). If the length >=1, then the model will wrapped by nn.DataParallel.
28
- model_path (str): the path you want to save the model (default="./"). Note only save the best weight in the validation data.
29
- """
30
-
31
- def __init__(
32
- self,
33
- model,
34
- task_types,
35
- optimizer_fn=torch.optim.Adam,
36
- optimizer_params=None,
37
- scheduler_fn=None,
38
- scheduler_params=None,
39
- adaptive_params=None,
40
- n_epoch=10,
41
- earlystop_taskid=0,
42
- earlystop_patience=10,
43
- device="cpu",
44
- gpus=None,
45
- model_path="./",
46
- ):
47
- self.model = model
48
- if gpus is None:
49
- gpus = []
50
- if optimizer_params is None:
51
- optimizer_params = {
52
- "lr": 1e-3,
53
- "weight_decay": 1e-5
54
- }
55
- self.task_types = task_types
56
- self.n_task = len(task_types)
57
- self.loss_weight = None
58
- self.adaptive_method = None
59
- if adaptive_params is not None:
60
- if adaptive_params["method"] == "uwl":
61
- self.adaptive_method = "uwl"
62
- self.loss_weight = nn.ParameterList(nn.Parameter(torch.zeros(1)) for _ in range(self.n_task))
63
- self.model.add_module("loss weight", self.loss_weight)
64
- elif adaptive_params["method"] == "metabalance":
65
- self.adaptive_method = "metabalance"
66
- share_layers, task_layers = shared_task_layers(self.model)
67
- self.meta_optimizer = MetaBalance(share_layers)
68
- self.share_optimizer = optimizer_fn(share_layers, **optimizer_params)
69
- self.task_optimizer = optimizer_fn(task_layers, **optimizer_params)
70
- elif adaptive_params["method"] == "gradnorm":
71
- self.adaptive_method = "gradnorm"
72
- self.alpha = adaptive_params.get("alpha", 0.16)
73
- share_layers = shared_task_layers(self.model)[0]
74
- #gradnorm calculate the gradients of each loss on the last fully connected shared layer weight(dimension is 2)
75
- for i in range(len(share_layers)):
76
- if share_layers[-i].ndim == 2:
77
- self.last_share_layer = share_layers[-i]
78
- break
79
- self.initial_task_loss = None
80
- self.loss_weight = nn.ParameterList(nn.Parameter(torch.ones(1)) for _ in range(self.n_task))
81
- self.model.add_module("loss weight", self.loss_weight)
82
- if self.adaptive_method != "metabalance":
83
- self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params) #default Adam optimizer
84
- self.scheduler = None
85
- if scheduler_fn is not None:
86
- self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
87
- self.loss_fns = [get_loss_func(task_type) for task_type in task_types]
88
- self.evaluate_fns = [get_metric_func(task_type) for task_type in task_types]
89
- self.n_epoch = n_epoch
90
- self.earlystop_taskid = earlystop_taskid
91
- self.early_stopper = EarlyStopper(patience=earlystop_patience)
92
-
93
- self.gpus = gpus
94
- if len(gpus) > 1:
95
- print('parallel running on these gpus:', gpus)
96
- self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
97
- self.device = torch.device(device) #torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
98
- self.model.to(self.device)
99
- self.model_path = model_path
100
-
101
-
102
- def train_one_epoch(self, data_loader):
103
- self.model.train()
104
- total_loss = np.zeros(self.n_task)
105
- tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
106
- for iter_i, (x_dict, ys) in enumerate(tk0):
107
- x_dict = {k: v.to(self.device) for k, v in x_dict.items()} #tensor to GPU
108
- ys = ys.to(self.device)
109
- y_preds = self.model(x_dict)
110
- loss_list = [self.loss_fns[i](y_preds[:, i], ys[:, i].float()) for i in range(self.n_task)]
111
- if isinstance(self.model, ESMM):
112
- loss = sum(loss_list[1:]) #ESSM only compute loss for ctr and ctcvr task
113
- else:
114
- if self.adaptive_method != None:
115
- if self.adaptive_method == "uwl":
116
- loss = 0
117
- for loss_i, w_i in zip(loss_list, self.loss_weight):
118
- w_i = torch.clamp(w_i, min=0)
119
- loss += 2 * loss_i * torch.exp(-w_i) + w_i
120
- else:
121
- loss = sum(loss_list) / self.n_task
122
- if self.adaptive_method == 'metabalance':
123
- self.share_optimizer.zero_grad()
124
- self.task_optimizer.zero_grad()
125
- self.meta_optimizer.step(loss_list)
126
- self.share_optimizer.step()
127
- self.task_optimizer.step()
128
- elif self.adaptive_method == "gradnorm":
129
- self.optimizer.zero_grad()
130
- if self.initial_task_loss is None:
131
- self.initial_task_loss = [l.item() for l in loss_list]
132
- gradnorm(loss_list, self.loss_weight, self.last_share_layer, self.initial_task_loss, self.alpha)
133
- self.optimizer.step()
134
- # renormalize
135
- loss_weight_sum = sum([w.item() for w in self.loss_weight])
136
- normalize_coeff = len(self.loss_weight) / loss_weight_sum
137
- for w in self.loss_weight:
138
- w.data = w.data * normalize_coeff
139
- else:
140
- self.model.zero_grad()
141
- loss.backward()
142
- self.optimizer.step()
143
- total_loss += np.array([l.item() for l in loss_list])
144
- log_dict = {"task_%d:" % (i): total_loss[i] / (iter_i + 1) for i in range(self.n_task)}
145
- loss_list = [total_loss[i] / (iter_i + 1) for i in range(self.n_task)]
146
- print("train loss: ", log_dict)
147
- if self.loss_weight:
148
- print("loss weight: ", [w.item() for w in self.loss_weight])
149
-
150
- return loss_list
151
-
152
-
153
- def fit(self, train_dataloader, val_dataloader, mode = 'base', seed = 0):
154
- total_log = []
155
-
156
- for epoch_i in range(self.n_epoch):
157
- _log_per_epoch = self.train_one_epoch(train_dataloader)
158
-
159
- if self.scheduler is not None:
160
- if epoch_i % self.scheduler.step_size == 0:
161
- print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
162
- self.scheduler.step() #update lr in epoch level by scheduler
163
- scores = self.evaluate(self.model, val_dataloader)
164
- print('epoch:', epoch_i, 'validation scores: ', scores)
165
-
166
- for score in scores:
167
- _log_per_epoch.append(score)
168
-
169
- total_log.append(_log_per_epoch)
170
-
171
- if self.early_stopper.stop_training(scores[self.earlystop_taskid], self.model.state_dict()):
172
- print('validation best auc of main task %d: %.6f' %
173
- (self.earlystop_taskid, self.early_stopper.best_auc))
174
- self.model.load_state_dict(self.early_stopper.best_weights)
175
- break
176
-
177
- torch.save(self.model.state_dict(), os.path.join(self.model_path, "model_{}_{}.pth".format(mode, seed))) #save best auc model
178
-
179
- return total_log
180
-
181
-
182
- def evaluate(self, model, data_loader):
183
- model.eval()
184
- targets, predicts = list(), list()
185
- with torch.no_grad():
186
- tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
187
- for i, (x_dict, ys) in enumerate(tk0):
188
- x_dict = {k: v.to(self.device) for k, v in x_dict.items()} #tensor to GPU
189
- ys = ys.to(self.device)
190
- y_preds = self.model(x_dict)
191
- targets.extend(ys.tolist())
192
- predicts.extend(y_preds.tolist())
193
- targets, predicts = np.array(targets), np.array(predicts)
194
- scores = [self.evaluate_fns[i](targets[:, i], predicts[:, i]) for i in range(self.n_task)]
195
- return scores
196
-
197
-
198
- def predict(self, model, data_loader):
199
- model.eval()
200
- predicts = list()
201
- with torch.no_grad():
202
- tk0 = tqdm.tqdm(data_loader, desc="predict", smoothing=0, mininterval=1.0)
203
- for i, x_dict in enumerate(tk0):
204
- x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
205
- y_preds = model(x_dict)
206
- predicts.extend(y_preds.tolist())
207
- return predicts
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import tqdm
7
+
8
+ from ..basic.callback import EarlyStopper
9
+ from ..basic.loss_func import RegularizationLoss
10
+ from ..models.multi_task import ESMM
11
+ from ..utils.data import get_loss_func, get_metric_func
12
+ from ..utils.mtl import MetaBalance, gradnorm, shared_task_layers
13
+
14
+
15
+ class MTLTrainer(object):
16
+ """A trainer for multi task learning.
17
+
18
+ Args:
19
+ model (nn.Module): any multi task learning model.
20
+ task_types (list): types of tasks, only support ["classification", "regression"].
21
+ optimizer_fn (torch.optim): optimizer function of pytorch (default = `torch.optim.Adam`).
22
+ optimizer_params (dict): parameters of optimizer_fn.
23
+ scheduler_fn (torch.optim.lr_scheduler) : torch scheduling class, eg. `torch.optim.lr_scheduler.StepLR`.
24
+ scheduler_params (dict): parameters of optimizer scheduler_fn.
25
+ adaptive_params (dict): parameters of adaptive loss weight method. Now only support `{"method" : "uwl"}`.
26
+ n_epoch (int): epoch number of training.
27
+ earlystop_taskid (int): task id of earlystop metrics relies between multi task (default = 0).
28
+ earlystop_patience (int): how long to wait after last time validation auc improved (default = 10).
29
+ device (str): `"cpu"` or `"cuda:0"`
30
+ gpus (list): id of multi gpu (default=[]). If the length >=1, then the model will wrapped by nn.DataParallel.
31
+ model_path (str): the path you want to save the model (default="./"). Note only save the best weight in the validation data.
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ model,
37
+ task_types,
38
+ optimizer_fn=torch.optim.Adam,
39
+ optimizer_params=None,
40
+ regularization_params=None,
41
+ scheduler_fn=None,
42
+ scheduler_params=None,
43
+ adaptive_params=None,
44
+ n_epoch=10,
45
+ earlystop_taskid=0,
46
+ earlystop_patience=10,
47
+ device="cpu",
48
+ gpus=None,
49
+ model_path="./",
50
+ ):
51
+ self.model = model
52
+ if gpus is None:
53
+ gpus = []
54
+ if optimizer_params is None:
55
+ optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
56
+ if regularization_params is None:
57
+ regularization_params = {"embedding_l1": 0.0, "embedding_l2": 0.0, "dense_l1": 0.0, "dense_l2": 0.0}
58
+ self.task_types = task_types
59
+ self.n_task = len(task_types)
60
+ self.loss_weight = None
61
+ self.adaptive_method = None
62
+ if adaptive_params is not None:
63
+ if adaptive_params["method"] == "uwl":
64
+ self.adaptive_method = "uwl"
65
+ self.loss_weight = nn.ParameterList(nn.Parameter(torch.zeros(1)) for _ in range(self.n_task))
66
+ self.model.add_module("loss weight", self.loss_weight)
67
+ elif adaptive_params["method"] == "metabalance":
68
+ self.adaptive_method = "metabalance"
69
+ share_layers, task_layers = shared_task_layers(self.model)
70
+ self.meta_optimizer = MetaBalance(share_layers)
71
+ self.share_optimizer = optimizer_fn(share_layers, **optimizer_params)
72
+ self.task_optimizer = optimizer_fn(task_layers, **optimizer_params)
73
+ elif adaptive_params["method"] == "gradnorm":
74
+ self.adaptive_method = "gradnorm"
75
+ self.alpha = adaptive_params.get("alpha", 0.16)
76
+ share_layers = shared_task_layers(self.model)[0]
77
+ # gradnorm calculate the gradients of each loss on the last
78
+ # fully connected shared layer weight(dimension is 2)
79
+ for i in range(len(share_layers)):
80
+ if share_layers[-i].ndim == 2:
81
+ self.last_share_layer = share_layers[-i]
82
+ break
83
+ self.initial_task_loss = None
84
+ self.loss_weight = nn.ParameterList(nn.Parameter(torch.ones(1)) for _ in range(self.n_task))
85
+ self.model.add_module("loss weight", self.loss_weight)
86
+ if self.adaptive_method != "metabalance":
87
+ self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params) # default Adam optimizer
88
+ self.scheduler = None
89
+ if scheduler_fn is not None:
90
+ self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
91
+ self.loss_fns = [get_loss_func(task_type) for task_type in task_types]
92
+ self.evaluate_fns = [get_metric_func(task_type) for task_type in task_types]
93
+ self.n_epoch = n_epoch
94
+ self.earlystop_taskid = earlystop_taskid
95
+ self.early_stopper = EarlyStopper(patience=earlystop_patience)
96
+
97
+ self.gpus = gpus
98
+ if len(gpus) > 1:
99
+ print('parallel running on these gpus:', gpus)
100
+ self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
101
+ # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
102
+ self.device = torch.device(device)
103
+ self.model.to(self.device)
104
+ self.model_path = model_path
105
+ # Initialize regularization loss
106
+ self.reg_loss_fn = RegularizationLoss(**regularization_params)
107
+
108
+ def train_one_epoch(self, data_loader):
109
+ self.model.train()
110
+ total_loss = np.zeros(self.n_task)
111
+ tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
112
+ for iter_i, (x_dict, ys) in enumerate(tk0):
113
+ x_dict = {k: v.to(self.device) for k, v in x_dict.items()} # tensor to GPU
114
+ ys = ys.to(self.device)
115
+ y_preds = self.model(x_dict)
116
+ loss_list = [self.loss_fns[i](y_preds[:, i], ys[:, i].float()) for i in range(self.n_task)]
117
+ if isinstance(self.model, ESMM):
118
+ # ESSM only compute loss for ctr and ctcvr task
119
+ loss = sum(loss_list[1:])
120
+ else:
121
+ if self.adaptive_method is not None:
122
+ if self.adaptive_method == "uwl":
123
+ loss = 0
124
+ for loss_i, w_i in zip(loss_list, self.loss_weight):
125
+ w_i = torch.clamp(w_i, min=0)
126
+ loss += 2 * loss_i * torch.exp(-w_i) + w_i
127
+ else:
128
+ loss = sum(loss_list) / self.n_task
129
+
130
+ # Add regularization loss
131
+ reg_loss = self.reg_loss_fn(self.model)
132
+ loss = loss + reg_loss
133
+ if self.adaptive_method == 'metabalance':
134
+ self.share_optimizer.zero_grad()
135
+ self.task_optimizer.zero_grad()
136
+ self.meta_optimizer.step(loss_list)
137
+ self.share_optimizer.step()
138
+ self.task_optimizer.step()
139
+ elif self.adaptive_method == "gradnorm":
140
+ self.optimizer.zero_grad()
141
+ if self.initial_task_loss is None:
142
+ self.initial_task_loss = [l.item() for l in loss_list]
143
+ gradnorm(loss_list, self.loss_weight, self.last_share_layer, self.initial_task_loss, self.alpha)
144
+ self.optimizer.step()
145
+ # renormalize
146
+ loss_weight_sum = sum([w.item() for w in self.loss_weight])
147
+ normalize_coeff = len(self.loss_weight) / loss_weight_sum
148
+ for w in self.loss_weight:
149
+ w.data = w.data * normalize_coeff
150
+ else:
151
+ self.model.zero_grad()
152
+ loss.backward()
153
+ self.optimizer.step()
154
+ total_loss += np.array([l.item() for l in loss_list])
155
+ log_dict = {"task_%d:" % (i): total_loss[i] / (iter_i + 1) for i in range(self.n_task)}
156
+ loss_list = [total_loss[i] / (iter_i + 1) for i in range(self.n_task)]
157
+ print("train loss: ", log_dict)
158
+ if self.loss_weight:
159
+ print("loss weight: ", [w.item() for w in self.loss_weight])
160
+
161
+ return loss_list
162
+
163
+ def fit(self, train_dataloader, val_dataloader, mode='base', seed=0):
164
+ total_log = []
165
+
166
+ for epoch_i in range(self.n_epoch):
167
+ _log_per_epoch = self.train_one_epoch(train_dataloader)
168
+
169
+ if self.scheduler is not None:
170
+ if epoch_i % self.scheduler.step_size == 0:
171
+ print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
172
+ self.scheduler.step() # update lr in epoch level by scheduler
173
+ scores = self.evaluate(self.model, val_dataloader)
174
+ print('epoch:', epoch_i, 'validation scores: ', scores)
175
+
176
+ for score in scores:
177
+ _log_per_epoch.append(score)
178
+
179
+ total_log.append(_log_per_epoch)
180
+
181
+ if self.early_stopper.stop_training(scores[self.earlystop_taskid], self.model.state_dict()):
182
+ print('validation best auc of main task %d: %.6f' % (self.earlystop_taskid, self.early_stopper.best_auc))
183
+ self.model.load_state_dict(self.early_stopper.best_weights)
184
+ break
185
+
186
+ torch.save(self.model.state_dict(), os.path.join(self.model_path, "model_{}_{}.pth".format(mode, seed))) # save best auc model
187
+
188
+ return total_log
189
+
190
+ def evaluate(self, model, data_loader):
191
+ model.eval()
192
+ targets, predicts = list(), list()
193
+ with torch.no_grad():
194
+ tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
195
+ for i, (x_dict, ys) in enumerate(tk0):
196
+ x_dict = {k: v.to(self.device) for k, v in x_dict.items()} # tensor to GPU
197
+ ys = ys.to(self.device)
198
+ y_preds = self.model(x_dict)
199
+ targets.extend(ys.tolist())
200
+ predicts.extend(y_preds.tolist())
201
+ targets, predicts = np.array(targets), np.array(predicts)
202
+ scores = [self.evaluate_fns[i](targets[:, i], predicts[:, i]) for i in range(self.n_task)]
203
+ return scores
204
+
205
+ def predict(self, model, data_loader):
206
+ model.eval()
207
+ predicts = list()
208
+ with torch.no_grad():
209
+ tk0 = tqdm.tqdm(data_loader, desc="predict", smoothing=0, mininterval=1.0)
210
+ for i, x_dict in enumerate(tk0):
211
+ x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
212
+ y_preds = model(x_dict)
213
+ predicts.extend(y_preds.tolist())
214
+ return predicts
215
+
216
+ def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
217
+ """Export the trained multi-task model to ONNX format.
218
+
219
+ This method exports multi-task learning models (e.g., MMOE, PLE, ESMM, SharedBottom)
220
+ to ONNX format for deployment. The exported model will have multiple outputs
221
+ corresponding to each task.
222
+
223
+ Note:
224
+ The ONNX model will output a tensor of shape [batch_size, n_task] where
225
+ n_task is the number of tasks in the multi-task model.
226
+
227
+ Args:
228
+ output_path (str): Path to save the ONNX model file.
229
+ dummy_input (dict, optional): Example input dict {feature_name: tensor}.
230
+ If not provided, dummy inputs will be generated automatically.
231
+ batch_size (int): Batch size for auto-generated dummy input (default: 2).
232
+ seq_length (int): Sequence length for SequenceFeature (default: 10).
233
+ opset_version (int): ONNX opset version (default: 14).
234
+ dynamic_batch (bool): Enable dynamic batch size (default: True).
235
+ device (str, optional): Device for export ('cpu', 'cuda', etc.).
236
+ If None, defaults to 'cpu' for maximum compatibility.
237
+ verbose (bool): Print export details (default: False).
238
+
239
+ Returns:
240
+ bool: True if export succeeded, False otherwise.
241
+
242
+ Example:
243
+ >>> trainer = MTLTrainer(mmoe_model, task_types=["classification", "classification"], ...)
244
+ >>> trainer.fit(train_dl, val_dl)
245
+ >>> trainer.export_onnx("mmoe.onnx")
246
+
247
+ >>> # Export on specific device
248
+ >>> trainer.export_onnx("mmoe.onnx", device="cpu")
249
+ """
250
+ from ..utils.onnx_export import ONNXExporter
251
+
252
+ # Handle DataParallel wrapped model
253
+ model = self.model.module if hasattr(self.model, 'module') else self.model
254
+
255
+ # Use provided device or default to 'cpu'
256
+ export_device = device if device is not None else 'cpu'
257
+
258
+ exporter = ONNXExporter(model, device=export_device)
259
+ return exporter.export(output_path=output_path, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
260
+
261
+ def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
262
+ """Visualize the model's computation graph.
263
+
264
+ This method generates a visual representation of the model architecture,
265
+ showing layer connections, tensor shapes, and nested module structures.
266
+ It automatically extracts feature information from the model.
267
+
268
+ Parameters
269
+ ----------
270
+ input_data : dict, optional
271
+ Example input dict {feature_name: tensor}.
272
+ If not provided, dummy inputs will be generated automatically.
273
+ batch_size : int, default=2
274
+ Batch size for auto-generated dummy input.
275
+ seq_length : int, default=10
276
+ Sequence length for SequenceFeature.
277
+ depth : int, default=3
278
+ Visualization depth, higher values show more detail.
279
+ Set to -1 to show all layers.
280
+ show_shapes : bool, default=True
281
+ Whether to display tensor shapes.
282
+ expand_nested : bool, default=True
283
+ Whether to expand nested modules.
284
+ save_path : str, optional
285
+ Path to save the graph image (.pdf, .svg, .png).
286
+ If None, displays in Jupyter or opens system viewer.
287
+ graph_name : str, default="model"
288
+ Name for the graph.
289
+ device : str, optional
290
+ Device for model execution. If None, defaults to 'cpu'.
291
+ dpi : int, default=300
292
+ Resolution in dots per inch for output image.
293
+ Higher values produce sharper images suitable for papers.
294
+ **kwargs : dict
295
+ Additional arguments passed to torchview.draw_graph().
296
+
297
+ Returns
298
+ -------
299
+ ComputationGraph
300
+ A torchview ComputationGraph object.
301
+
302
+ Raises
303
+ ------
304
+ ImportError
305
+ If torchview or graphviz is not installed.
306
+
307
+ Notes
308
+ -----
309
+ Default Display Behavior:
310
+ When `save_path` is None (default):
311
+ - In Jupyter/IPython: automatically displays the graph inline
312
+ - In Python script: opens the graph with system default viewer
313
+
314
+ Examples
315
+ --------
316
+ >>> trainer = MTLTrainer(model, task_types=["classification", "classification"])
317
+ >>> trainer.fit(train_dl, val_dl)
318
+ >>>
319
+ >>> # Auto-display in Jupyter (no save_path needed)
320
+ >>> trainer.visualization(depth=4)
321
+ >>>
322
+ >>> # Save to high-DPI PNG for papers
323
+ >>> trainer.visualization(save_path="model.png", dpi=300)
324
+ """
325
+ from ..utils.visualization import TORCHVIEW_AVAILABLE, visualize_model
326
+
327
+ if not TORCHVIEW_AVAILABLE:
328
+ raise ImportError(
329
+ "Visualization requires torchview. "
330
+ "Install with: pip install torch-rechub[visualization]\n"
331
+ "Also ensure graphviz is installed on your system:\n"
332
+ " - Ubuntu/Debian: sudo apt-get install graphviz\n"
333
+ " - macOS: brew install graphviz\n"
334
+ " - Windows: choco install graphviz"
335
+ )
336
+
337
+ # Handle DataParallel wrapped model
338
+ model = self.model.module if hasattr(self.model, 'module') else self.model
339
+
340
+ # Use provided device or default to 'cpu'
341
+ viz_device = device if device is not None else 'cpu'
342
+
343
+ return visualize_model(
344
+ model,
345
+ input_data=input_data,
346
+ batch_size=batch_size,
347
+ seq_length=seq_length,
348
+ depth=depth,
349
+ show_shapes=show_shapes,
350
+ expand_nested=expand_nested,
351
+ save_path=save_path,
352
+ graph_name=graph_name,
353
+ device=viz_device,
354
+ dpi=dpi,
355
+ **kwargs
356
+ )